From 8dadd36d296493674b67392776e60335714d1931 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 24 Feb 2026 17:38:30 -0800 Subject: [PATCH 1/2] [devtools] Add preprocessing support to NumericalComparatorBase MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pull Request resolved: https://github.com/pytorch/executorch/pull/17433 This change extends NumericalComparatorBase to support custom preprocessing before numeric gap calculation, enabling backends like Qualcomm to apply necessary tensor transformations (e.g., dequantization, layout conversion) before comparison. Key changes: - Extended NumericalComparatorBase with: - `__init__(inspector)` to store optional Inspector reference for accessing graph metadata - `preprocessing(mapping)` method that can be overridden for custom tensor transformations (default: identity) - `element_compare(a, b)` abstract method for element-level tensor comparison - `compare(mapping, ...)` method that orchestrates the full pipeline: preprocessing → element-wise compare → aggregate to DataFrame This is part of the operator-level numeric discrepancy detector project for ExecuTorch Qualcomm backend (https://github.com/pytorch/executorch/issues/16381). Design doc: https://docs.google.com/document/d/1GaCHiy9InytOsUrl2BKEgOiP1iKTfpCVdWg6QDh0N2E/edit?tab=t.0#heading=h.fcrpnrtb6cud ghstack-source-id: 344497031 Differential Revision: [D93169813](https://our.internmc.facebook.com/intern/diff/D93169813/) --- devtools/inspector/_inspector.py | 57 ++-- devtools/inspector/_inspector_utils.py | 34 -- .../numerical_comparator/__init__.py | 12 +- .../l1_numerical_comparator.py | 12 +- .../mse_numerical_comparator.py | 12 +- .../numerical_comparator_base.py | 250 ++++++++++++++- .../snr_numerical_comparator.py | 12 +- devtools/inspector/tests/inspector_test.py | 298 +++++++++++++++++- .../inspector/tests/inspector_utils_test.py | 21 +- .../inspector/tests/l1_comparator_test.py | 10 +- .../inspector/tests/mse_comparator_test.py | 10 +- .../inspector/tests/snr_comparator_test.py | 10 +- 12 files changed, 616 insertions(+), 122 deletions(-) diff --git a/devtools/inspector/_inspector.py b/devtools/inspector/_inspector.py index 6b6b4f583a6..368824f71a3 100644 --- a/devtools/inspector/_inspector.py +++ b/devtools/inspector/_inspector.py @@ -42,7 +42,6 @@ from executorch.devtools.etrecord import ETRecord, parse_etrecord from executorch.devtools.inspector._inspector_utils import ( calculate_time_scale_factor, - compare_intermediate_outputs, create_debug_handle_to_op_node_mapping, DebugHandle, display_or_print_df, @@ -50,7 +49,6 @@ EXCLUDED_COLUMNS_WHEN_PRINTING, EXCLUDED_EVENTS_FOR_INTERMEDIATE_OUTPUT, EXCLUDED_EVENTS_WHEN_PRINTING, - find_op_names, find_populated_event, FORWARD, gen_etdump_object, @@ -1421,8 +1419,10 @@ def calculate_numeric_gap( Args: distance: The metrics the inspector will use for gap calculation. Can be either: - A string: one of "MSE", "L1", or "SNR" for built-in comparators. - - A custom NumericalComparatorBase instance: allows you to define custom comparison logic - by subclassing NumericalComparatorBase and implementing the compare() method. + - A custom NumericalComparatorBase instance: allows you to define custom comparison + logic by subclassing NumericalComparatorBase and implementing the element_compare() + method. Custom comparators can also override the preprocessing() method to apply + transformations (e.g., layout conversion, dequantization) before comparison. disable_debug_handle_validation: Often when aten graph has symbolic shape nodes and inbuilt ops like gt/lt etc., during re-export of such a graph 'from_node' information is lost from node.meta. As a result we loose connection between edge IR nodes and aten nodes for such ops. By default we validate that every edge IR @@ -1448,48 +1448,27 @@ def calculate_numeric_gap( mapping = map_runtime_aot_intermediate_outputs( aot_intermediate_outputs, runtime_intermediate_outputs ) + + # Get or create comparator if isinstance(distance, NumericalComparatorBase): comparator = distance + # Inject inspector if not already set + if comparator.inspector is None: + comparator.inspector = self else: metric = distance.strip().upper() if metric == "MSE": - comparator = MSEComparator() + comparator = MSEComparator(inspector=self) elif metric == "L1": - comparator = L1Comparator() + comparator = L1Comparator(inspector=self) elif metric == "SNR": - comparator = SNRComparator() + comparator = SNRComparator(inspector=self) else: raise ValueError(f"Unsupported distance metric {distance!r}") - rows = [] - for (aot_debug_handle, aot_intermediate_output), ( - runtime_debug_handle, - runtime_intermediate_output, - ) in mapping.items(): - if aot_intermediate_output is None or runtime_intermediate_output is None: - continue - # If aot outputs length is > 1 then comparison fails since we dont really have - # any instances where runtime intermediate output is a tuple or list - # This does not happen when edge dialect program is reference for comparison - # but happens in aten graph where ops like unbind remain undecomposed - if ( - isinstance(aot_intermediate_output, Sequence) - and len(aot_intermediate_output) > 1 - ): - continue - rows.append( - { - "aot_ops": find_op_names( - aot_debug_handle, aot_debug_handle_to_op_names - ), - "aot_intermediate_output": aot_intermediate_output, - "runtime_ops": find_op_names( - runtime_debug_handle, runtime_debug_handle_to_op_names - ), - "runtime_intermediate_output": runtime_intermediate_output, - "gap": compare_intermediate_outputs( - aot_intermediate_output, runtime_intermediate_output, comparator - ), - } - ) - return pd.DataFrame(rows) + # Delegate to comparator's compare method (includes preprocessing) + return comparator.compare( + mapping, + aot_debug_handle_to_op_names, + runtime_debug_handle_to_op_names, + ) diff --git a/devtools/inspector/_inspector_utils.py b/devtools/inspector/_inspector_utils.py index 878e0ddb7e0..556987e4bbf 100644 --- a/devtools/inspector/_inspector_utils.py +++ b/devtools/inspector/_inspector_utils.py @@ -1068,40 +1068,6 @@ def find_op_names( return result -def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]: - """ - Compare two outputs, handling both sequence and non-sequence cases, - and return a list of comparison results. - Parameters: - a: The first intermediate output to compare. - b: The second intermediate output to compare. - comparator: A comparator object with a `compare` method. - Returns: - List[float]: A list of comparison results. - Raises: - ValueError: If one input is a sequence and the other is not, or if sequences have different lengths. - """ - is_a_sequence = isinstance(a, Sequence) - is_b_sequence = isinstance(b, Sequence) - if is_a_sequence and is_b_sequence: - # Ensure both sequences have the same length - if len(a) != len(b): - raise ValueError( - f"Sequences 'a' ({a}) and 'b' ({b}) must have the same length for comparison. len(a): {len(a)} len(b): {len(b)}." - ) - - # Compare each element in the sequences and return the list of results - return [comparator.compare(x, y) for x, y in zip(a, b)] - elif not is_a_sequence and not is_b_sequence: - # Compare non-sequence items and return the result in a list - return [comparator.compare(a, b)] - else: - # Raise an error if one is a sequence and the other is not - raise ValueError( - f"Both inputs 'a' ({a}) and 'b' ({b}) must be sequences or both must be non-sequences." - ) - - def get_ancestor_node_identifiers(node: Node) -> List[str]: """Get the identifier of the ancestor node of the given node, with the graph id the ancestor node lives in. diff --git a/devtools/inspector/numerical_comparator/__init__.py b/devtools/inspector/numerical_comparator/__init__.py index 0090c50025f..68ccfabe02a 100644 --- a/devtools/inspector/numerical_comparator/__init__.py +++ b/devtools/inspector/numerical_comparator/__init__.py @@ -5,6 +5,8 @@ # LICENSE file in the root directory of this source tree. +# Re-export DebugHandle from _inspector_utils for convenience +from executorch.devtools.inspector._inspector_utils import DebugHandle from executorch.devtools.inspector.numerical_comparator.l1_numerical_comparator import ( L1Comparator, ) @@ -14,6 +16,7 @@ ) from executorch.devtools.inspector.numerical_comparator.numerical_comparator_base import ( + IntermediateOutputMapping, NumericalComparatorBase, ) @@ -22,4 +25,11 @@ ) -__all__ = ["L1Comparator", "MSEComparator", "SNRComparator", "NumericalComparatorBase"] +__all__ = [ + "DebugHandle", + "IntermediateOutputMapping", + "L1Comparator", + "MSEComparator", + "NumericalComparatorBase", + "SNRComparator", +] diff --git a/devtools/inspector/numerical_comparator/l1_numerical_comparator.py b/devtools/inspector/numerical_comparator/l1_numerical_comparator.py index 43f4f170c2f..ddc6233b769 100644 --- a/devtools/inspector/numerical_comparator/l1_numerical_comparator.py +++ b/devtools/inspector/numerical_comparator/l1_numerical_comparator.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any +from typing import Any, Optional, TYPE_CHECKING import torch from executorch.devtools.inspector._inspector_utils import convert_to_float_tensor @@ -12,9 +12,17 @@ NumericalComparatorBase, ) +if TYPE_CHECKING: + from executorch.devtools.inspector._inspector import Inspector + class L1Comparator(NumericalComparatorBase): - def compare(self, a: Any, b: Any) -> float: + """L1 (sum of absolute differences) comparator for numerical discrepancy detection.""" + + def __init__(self, inspector: Optional["Inspector"] = None) -> None: + super().__init__(inspector) + + def element_compare(self, a: Any, b: Any) -> float: """Sum up all these element-wise absolute differences between two tensors.""" t_a = convert_to_float_tensor(a) diff --git a/devtools/inspector/numerical_comparator/mse_numerical_comparator.py b/devtools/inspector/numerical_comparator/mse_numerical_comparator.py index c4693ff2ad4..7a6b323e81a 100644 --- a/devtools/inspector/numerical_comparator/mse_numerical_comparator.py +++ b/devtools/inspector/numerical_comparator/mse_numerical_comparator.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any +from typing import Any, Optional, TYPE_CHECKING import torch from executorch.devtools.inspector._inspector_utils import convert_to_float_tensor @@ -12,9 +12,17 @@ NumericalComparatorBase, ) +if TYPE_CHECKING: + from executorch.devtools.inspector._inspector import Inspector + class MSEComparator(NumericalComparatorBase): - def compare(self, a: Any, b: Any) -> float: + """Mean Squared Error comparator for numerical discrepancy detection.""" + + def __init__(self, inspector: Optional["Inspector"] = None) -> None: + super().__init__(inspector) + + def element_compare(self, a: Any, b: Any) -> float: """Compare mean squared difference between two outputs.""" t_a = convert_to_float_tensor(a) diff --git a/devtools/inspector/numerical_comparator/numerical_comparator_base.py b/devtools/inspector/numerical_comparator/numerical_comparator_base.py index db498980e1f..c4f8a90f78f 100644 --- a/devtools/inspector/numerical_comparator/numerical_comparator_base.py +++ b/devtools/inspector/numerical_comparator/numerical_comparator_base.py @@ -6,21 +6,259 @@ from abc import ABC, abstractmethod -from typing import Any +from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING + +import pandas as pd + +from executorch.devtools.inspector._inspector_utils import DebugHandle + +if TYPE_CHECKING: + from executorch.devtools.inspector._inspector import Inspector + +# Type alias for the mapping used in preprocessing +# Maps (aot_debug_handle, aot_output) -> (runtime_debug_handle, runtime_output) +IntermediateOutputMapping = Dict[Tuple[DebugHandle, Any], Tuple[DebugHandle, Any]] class NumericalComparatorBase(ABC): + """Base class for numerical comparison with optional preprocessing. + + This class provides a framework for comparing intermediate outputs between + AOT (Ahead-of-Time) and runtime execution. Subclasses can override the + `preprocessing` method to transform tensors before comparison (e.g., layout + conversion, dequantization) and must implement `element_compare` for + element-wise comparison logic. + + The `compare` method is the main entry point called by Inspector, which + orchestrates the full comparison pipeline: preprocess -> element-wise compare + -> aggregate results into a DataFrame. + + Attributes: + _inspector: Optional reference to the Inspector instance, which provides + access to the reference graph and other metadata needed for preprocessing. + """ + + def __init__(self, inspector: Optional["Inspector"] = None) -> None: + """Initialize the comparator. + + Args: + inspector: Optional Inspector instance that provides access to the + reference graph and other metadata. Can be set later via the + `inspector` property. + """ + self._inspector: Optional["Inspector"] = inspector + + @property + def inspector(self) -> Optional["Inspector"]: + """Get the Inspector instance.""" + return self._inspector + + @inspector.setter + def inspector(self, value: Optional["Inspector"]) -> None: + """Set the Inspector instance.""" + self._inspector = value + + def preprocessing( + self, mapping: IntermediateOutputMapping + ) -> IntermediateOutputMapping: + """Transform the mapping before comparison. + + Override this method to apply custom preprocessing to the intermediate + outputs before comparison. This is useful for backends like Qualcomm that + require tensor transformations (e.g., dequantization, layout conversion) + before accurate numeric discrepancy measurement. + + The default implementation returns the mapping unchanged. + + Args: + mapping: Dictionary mapping AOT (debug_handle, intermediate_output) pairs + to runtime (debug_handle, intermediate_output) pairs. + + - Key: Tuple[DebugHandle, Any] + - DebugHandle: Tuple[int, ...] - debug handle(s) from AOT graph + - Any: torch.Tensor or sequence - AOT intermediate output + + - Value: Tuple[DebugHandle, Any] + - DebugHandle: Tuple[int, ...] - debug handle(s) from runtime + - Any: torch.Tensor or sequence - runtime intermediate output + + Returns: + The transformed mapping, ready for element-wise comparison. + + Note: + When implementing custom preprocessing, you can access the reference + graph via `self._inspector.get_reference_graph()` to retrieve node + metadata such as quantization parameters or layout information. + """ + return mapping + @abstractmethod - def compare(self, a: Any, b: Any) -> float: - """Compare two intermediate output and return a result. + def element_compare(self, a: Any, b: Any) -> float: + """Compare two tensors and return a scalar distance. - This method should be overridden by subclasses to provide custom comparison logic. + This method should be overridden by subclasses to provide custom + element-wise comparison logic (e.g., MSE, L1, SNR). + + Args: + a: The first intermediate output to compare (typically AOT output). + b: The second intermediate output to compare (typically runtime output). + + Returns: + A numerical result indicating the comparison outcome (e.g., distance, + error metric). Lower values typically indicate better agreement. + """ + pass + + @staticmethod + def _validate_preprocessing_output( + processed_mapping: IntermediateOutputMapping, + ) -> None: + """Validate the output format of preprocessing(). + + Ensures the preprocessed mapping follows the expected format: + Dict[Tuple[DebugHandle, Any], Tuple[DebugHandle, Any]] + + Args: + processed_mapping: The mapping returned by preprocessing(). + + Raises: + TypeError: If processed_mapping is not a dict. + ValueError: If any key or value in the mapping has an invalid format. + """ + if not isinstance(processed_mapping, dict): + raise TypeError( + f"preprocessing() must return a dict, got {type(processed_mapping).__name__}. " + "Expected format: Dict[Tuple[DebugHandle, Any], Tuple[DebugHandle, Any]]" + ) + + for key, value in processed_mapping.items(): + # Validate key format: Tuple[DebugHandle, Any] + if not isinstance(key, tuple) or len(key) != 2: + raise ValueError( + f"Invalid key format in preprocessed mapping: {key}. " + "Expected Tuple[DebugHandle, Any] where DebugHandle is Tuple[int, ...]" + ) + aot_debug_handle, _ = key + if not isinstance(aot_debug_handle, tuple) or not all( + isinstance(x, int) for x in aot_debug_handle + ): + raise ValueError( + f"Invalid AOT debug handle in key: {aot_debug_handle}. " + "Expected Tuple[int, ...]" + ) + + # Validate value format: Tuple[DebugHandle, Any] + if not isinstance(value, tuple) or len(value) != 2: + raise ValueError( + f"Invalid value format in preprocessed mapping: {value}. " + "Expected Tuple[DebugHandle, Any] where DebugHandle is Tuple[int, ...]" + ) + runtime_debug_handle, _ = value + if not isinstance(runtime_debug_handle, tuple) or not all( + isinstance(x, int) for x in runtime_debug_handle + ): + raise ValueError( + f"Invalid runtime debug handle in value: {runtime_debug_handle}. " + "Expected Tuple[int, ...]" + ) + + def _compare_intermediate_outputs(self, a: Any, b: Any) -> List[float]: + """Compare two outputs, handling both sequence and non-sequence cases. Args: a: The first intermediate output to compare. b: The second intermediate output to compare. Returns: - A numerical result indicating the comparison outcome. + List[float]: A list of comparison results. + + Raises: + ValueError: If one input is a sequence and the other is not, + or if sequences have different lengths. """ - pass + is_a_sequence = isinstance(a, Sequence) + is_b_sequence = isinstance(b, Sequence) + if is_a_sequence and is_b_sequence: + if len(a) != len(b): + raise ValueError( + f"Sequences 'a' ({a}) and 'b' ({b}) must have the same length " + f"for comparison. len(a): {len(a)} len(b): {len(b)}." + ) + return [self.element_compare(x, y) for x, y in zip(a, b)] + elif not is_a_sequence and not is_b_sequence: + return [self.element_compare(a, b)] + else: + raise ValueError( + f"Both inputs 'a' ({a}) and 'b' ({b}) must be sequences " + f"or both must be non-sequences." + ) + + def compare( + self, + mapping: IntermediateOutputMapping, + aot_debug_handle_to_op_names: Dict[DebugHandle, List[str]], + runtime_debug_handle_to_op_names: Dict[DebugHandle, List[str]], + ) -> pd.DataFrame: + """Full comparison pipeline: preprocess -> element-wise compare -> aggregate. + + This is the main entry point called by Inspector.calculate_numeric_gap(). + It orchestrates the full comparison pipeline and returns a DataFrame + with the results. + + Args: + mapping: Dictionary mapping AOT (debug_handle, intermediate_output) pairs + to runtime (debug_handle, intermediate_output) pairs. + aot_debug_handle_to_op_names: Mapping from AOT debug handles to operator names. + runtime_debug_handle_to_op_names: Mapping from runtime debug handles to operator names. + + Returns: + pd.DataFrame: A DataFrame with columns: + - aot_ops: List of AOT operator names + - aot_intermediate_output: AOT intermediate output tensor + - runtime_ops: List of runtime operator names + - runtime_intermediate_output: Runtime intermediate output tensor + - gap: List of numerical gap values + """ + from executorch.devtools.inspector._inspector_utils import find_op_names + + # Step 1: Apply preprocessing + processed_mapping = self.preprocessing(mapping) + + # Validate the preprocessed mapping format + self._validate_preprocessing_output(processed_mapping) + + # Step 2: Element-wise comparison and aggregation + rows = [] + for (aot_debug_handle, aot_intermediate_output), ( + runtime_debug_handle, + runtime_intermediate_output, + ) in processed_mapping.items(): + if aot_intermediate_output is None or runtime_intermediate_output is None: + continue + # If aot outputs length is > 1 then comparison fails since we don't really have + # any instances where runtime intermediate output is a tuple or list. + # This does not happen when edge dialect program is reference for comparison + # but happens in aten graph where ops like unbind remain undecomposed. + if ( + isinstance(aot_intermediate_output, Sequence) + and len(aot_intermediate_output) > 1 + ): + continue + rows.append( + { + "aot_ops": find_op_names( + aot_debug_handle, aot_debug_handle_to_op_names + ), + "aot_intermediate_output": aot_intermediate_output, + "runtime_ops": find_op_names( + runtime_debug_handle, runtime_debug_handle_to_op_names + ), + "runtime_intermediate_output": runtime_intermediate_output, + "gap": self._compare_intermediate_outputs( + aot_intermediate_output, runtime_intermediate_output + ), + } + ) + + # Step 3: Build and return DataFrame + return pd.DataFrame(rows) diff --git a/devtools/inspector/numerical_comparator/snr_numerical_comparator.py b/devtools/inspector/numerical_comparator/snr_numerical_comparator.py index efe881a2549..1e474a7eba3 100644 --- a/devtools/inspector/numerical_comparator/snr_numerical_comparator.py +++ b/devtools/inspector/numerical_comparator/snr_numerical_comparator.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. -from typing import Any +from typing import Any, Optional, TYPE_CHECKING import torch from executorch.devtools.inspector._inspector_utils import convert_to_float_tensor @@ -13,9 +13,17 @@ NumericalComparatorBase, ) +if TYPE_CHECKING: + from executorch.devtools.inspector._inspector import Inspector + class SNRComparator(NumericalComparatorBase): - def compare(self, a: Any, b: Any) -> float: + """Signal-to-Noise Ratio comparator for numerical discrepancy detection.""" + + def __init__(self, inspector: Optional["Inspector"] = None) -> None: + super().__init__(inspector) + + def element_compare(self, a: Any, b: Any) -> float: """ Compare the Signal-to-Noise Ratio (SNR) between two inputs Formula: SNR = 10 * log10(original_power / error_power) diff --git a/devtools/inspector/tests/inspector_test.py b/devtools/inspector/tests/inspector_test.py index 422f5d5defe..ba199e470a8 100644 --- a/devtools/inspector/tests/inspector_test.py +++ b/devtools/inspector/tests/inspector_test.py @@ -730,7 +730,7 @@ def test_calculate_numeric_gap_with_custom_comparator(self): # Create a custom comparator that returns the max absolute difference class MaxAbsDiffComparator(NumericalComparatorBase): - def compare(self, a, b): + def element_compare(self, a, b): if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor): return torch.max(torch.abs(a - b)).item() return abs(a - b) @@ -795,6 +795,302 @@ def compare(self, a, b): # For (1,): max(|[4.0, 5.0, 6.0] - [3.0, 6.0, 5.0]|) = max([1.0, 1.0, 1.0]) = 1.0 self.assertEqual(df.iloc[1]["gap"][0], 1.0) + def test_calculate_numeric_gap_with_custom_comparator_and_preprocessing(self): + """Test calculate_numeric_gap with multiple custom comparators sharing the same preprocessing.""" + from executorch.devtools.inspector.numerical_comparator import ( + IntermediateOutputMapping, + NumericalComparatorBase, + ) + from executorch.devtools.inspector.numerical_comparator.snr_numerical_comparator import ( + SNRComparator, + ) + + # Shared preprocessing function that scales runtime tensors by 2x + def scale_runtime_tensors( + mapping: IntermediateOutputMapping, scale_factor: float = 2.0 + ) -> IntermediateOutputMapping: + """Scale runtime tensors by scale_factor before comparison.""" + transformed_mapping = {} + for (aot_handle, aot_output), ( + runtime_handle, + runtime_output, + ) in mapping.items(): + # Scale the runtime output + if isinstance(runtime_output, torch.Tensor): + scaled_runtime_output = runtime_output * scale_factor + else: + scaled_runtime_output = runtime_output + transformed_mapping[(aot_handle, aot_output)] = ( + runtime_handle, + scaled_runtime_output, + ) + return transformed_mapping + + # Create a custom MSE comparator with shared preprocessing + class MSEComparatorWithScaling(NumericalComparatorBase): + def __init__(self, scale_factor: float = 2.0): + super().__init__() + self.scale_factor = scale_factor + self.preprocessing_called = False + + def preprocessing( + self, mapping: IntermediateOutputMapping + ) -> IntermediateOutputMapping: + """Use the shared preprocessing function.""" + self.preprocessing_called = True + return scale_runtime_tensors(mapping, self.scale_factor) + + def element_compare(self, a, b) -> float: + """Compute MSE between two tensors.""" + if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor): + return torch.mean(torch.square(a.float() - b.float())).item() + return (a - b) ** 2 + + # Create an SNR comparator with the same shared preprocessing + class SNRComparatorWithScaling(SNRComparator): + def __init__(self, scale_factor: float = 2.0): + super().__init__() + self.scale_factor = scale_factor + self.preprocessing_called = False + + def preprocessing( + self, mapping: IntermediateOutputMapping + ) -> IntermediateOutputMapping: + """Use the shared preprocessing function.""" + self.preprocessing_called = True + return scale_runtime_tensors(mapping, self.scale_factor) + + # Create a context manager to patch functions called by Inspector.__init__ + with patch.object( + _inspector, "parse_etrecord", return_value=None + ), patch.object( + _inspector, "gen_etdump_object", return_value=None + ), patch.object( + EventBlock, "_gen_from_etdump" + ), patch.object( + _inspector, "gen_graphs_from_etrecord" + ): + inspector_instance = Inspector( + etdump_path=ETDUMP_PATH, + etrecord=ETRECORD_PATH, + ) + + # AOT outputs: [1.0, 2.0, 3.0] and [4.0, 5.0, 6.0] + aot_intermediate_outputs = { + (0,): torch.tensor([1.0, 2.0, 3.0]), + (1,): torch.tensor([4.0, 5.0, 6.0]), + } + + # Runtime outputs: [1.0, 1.0, 1.0] and [2.0, 2.0, 2.0] + # After 2x scaling: [2.0, 2.0, 2.0] and [4.0, 4.0, 4.0] + runtime_intermediate_outputs = { + (0,): ([torch.tensor([1.0, 1.0, 1.0])], 1), + (1,): ([torch.tensor([2.0, 2.0, 2.0])], 1), + } + + aot_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"} + runtime_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"} + + inspector_instance._get_aot_intermediate_outputs_and_op_names = lambda x: ( + aot_intermediate_outputs, + aot_debug_handle_to_op_name, + ) + inspector_instance._get_runtime_intermediate_outputs_and_op_names = ( + lambda: (runtime_intermediate_outputs, runtime_debug_handle_to_op_name) + ) + + # --- Test 1: MSE comparator with scaling preprocessing --- + mse_comparator = MSEComparatorWithScaling(scale_factor=2.0) + df_mse = inspector_instance.calculate_numeric_gap(distance=mse_comparator) + + # Verify preprocessing was called + self.assertTrue(mse_comparator.preprocessing_called) + + # Verify DataFrame structure + self.assertIsInstance(df_mse, pd.DataFrame) + self.assertEqual(len(df_mse), 2) + cols = set(df_mse.columns) + expected_cols = { + "aot_ops", + "aot_intermediate_output", + "runtime_ops", + "runtime_intermediate_output", + "gap", + } + self.assertEqual(cols, expected_cols) + + # Verify the MSE comparison after preprocessing + # For (0,): AOT=[1.0, 2.0, 3.0], Runtime after scaling=[2.0, 2.0, 2.0] + # MSE = mean((1-2)^2 + (2-2)^2 + (3-2)^2) = mean(1 + 0 + 1) = 2/3 + expected_mse_gap_0 = (1.0 + 0.0 + 1.0) / 3.0 + self.assertAlmostEqual( + df_mse.iloc[0]["gap"][0], expected_mse_gap_0, places=5 + ) + + # For (1,): AOT=[4.0, 5.0, 6.0], Runtime after scaling=[4.0, 4.0, 4.0] + # MSE = mean((4-4)^2 + (5-4)^2 + (6-4)^2) = mean(0 + 1 + 4) = 5/3 + expected_mse_gap_1 = (0.0 + 1.0 + 4.0) / 3.0 + self.assertAlmostEqual( + df_mse.iloc[1]["gap"][0], expected_mse_gap_1, places=5 + ) + + # --- Test 2: SNR comparator with the same scaling preprocessing --- + snr_comparator = SNRComparatorWithScaling(scale_factor=2.0) + df_snr = inspector_instance.calculate_numeric_gap(distance=snr_comparator) + + # Verify preprocessing was called + self.assertTrue(snr_comparator.preprocessing_called) + + # Verify DataFrame structure + self.assertIsInstance(df_snr, pd.DataFrame) + self.assertEqual(len(df_snr), 2) + self.assertEqual(set(df_snr.columns), expected_cols) + + # Verify the SNR comparison after preprocessing + # For (0,): AOT=[1.0, 2.0, 3.0], Runtime after scaling=[2.0, 2.0, 2.0] + # signal_power = mean([1.0^2, 2.0^2, 3.0^2]) = mean([1, 4, 9]) = 14/3 + # error = [1.0-2.0, 2.0-2.0, 3.0-2.0] = [-1.0, 0.0, 1.0] + # error_power = mean([1.0, 0.0, 1.0]) = 2/3 + # SNR = 10 * log10(14/3 / (2/3)) = 10 * log10(7) ≈ 8.451 + signal_power_0 = (1.0 + 4.0 + 9.0) / 3.0 # 14/3 + error_power_0 = (1.0 + 0.0 + 1.0) / 3.0 # 2/3 + expected_snr_gap_0 = ( + 10 * torch.log10(torch.tensor(signal_power_0 / error_power_0)).item() + ) + self.assertAlmostEqual( + df_snr.iloc[0]["gap"][0], expected_snr_gap_0, places=5 + ) + + # For (1,): AOT=[4.0, 5.0, 6.0], Runtime after scaling=[4.0, 4.0, 4.0] + # signal_power = mean([4.0^2, 5.0^2, 6.0^2]) = mean([16, 25, 36]) = 77/3 + # error = [4.0-4.0, 5.0-4.0, 6.0-4.0] = [0.0, 1.0, 2.0] + # error_power = mean([0.0, 1.0, 4.0]) = 5/3 + # SNR = 10 * log10(77/3 / (5/3)) = 10 * log10(77/5) ≈ 11.875 + signal_power_1 = (16.0 + 25.0 + 36.0) / 3.0 # 77/3 + error_power_1 = (0.0 + 1.0 + 4.0) / 3.0 # 5/3 + expected_snr_gap_1 = ( + 10 * torch.log10(torch.tensor(signal_power_1 / error_power_1)).item() + ) + self.assertAlmostEqual( + df_snr.iloc[1]["gap"][0], expected_snr_gap_1, places=5 + ) + + def test_calculate_numeric_gap_with_invalid_preprocessing_output(self): + """Test that invalid preprocessing output raises appropriate errors.""" + from executorch.devtools.inspector.numerical_comparator import ( + NumericalComparatorBase, + ) + + # Test 1: preprocessing returns non-dict + class NonDictPreprocessingComparator(NumericalComparatorBase): + def preprocessing(self, mapping): + return "invalid" # Should return a dict + + def element_compare(self, a, b) -> float: + return 0.0 + + # Test 2: preprocessing returns dict with invalid key format + class InvalidKeyFormatComparator(NumericalComparatorBase): + def preprocessing(self, mapping): + return {"invalid_key": ((0,), torch.tensor([1.0]))} + + def element_compare(self, a, b) -> float: + return 0.0 + + # Test 3: preprocessing returns dict with invalid debug handle in key + class InvalidKeyDebugHandleComparator(NumericalComparatorBase): + def preprocessing(self, mapping): + return { + (("not_int",), torch.tensor([1.0])): ((0,), torch.tensor([1.0])) + } + + def element_compare(self, a, b) -> float: + return 0.0 + + # Test 4: preprocessing returns dict with invalid value format + class InvalidValueFormatComparator(NumericalComparatorBase): + def preprocessing(self, mapping): + return {((0,), torch.tensor([1.0])): "invalid_value"} + + def element_compare(self, a, b) -> float: + return 0.0 + + # Test 5: preprocessing returns dict with invalid debug handle in value + class InvalidValueDebugHandleComparator(NumericalComparatorBase): + def preprocessing(self, mapping): + return { + ((0,), torch.tensor([1.0])): (("not_int",), torch.tensor([1.0])) + } + + def element_compare(self, a, b) -> float: + return 0.0 + + with patch.object( + _inspector, "parse_etrecord", return_value=None + ), patch.object( + _inspector, "gen_etdump_object", return_value=None + ), patch.object( + EventBlock, "_gen_from_etdump" + ), patch.object( + _inspector, "gen_graphs_from_etrecord" + ): + inspector_instance = Inspector( + etdump_path=ETDUMP_PATH, + etrecord=ETRECORD_PATH, + ) + + aot_intermediate_outputs = { + (0,): torch.tensor([1.0, 2.0, 3.0]), + } + runtime_intermediate_outputs = { + (0,): ([torch.tensor([1.0, 1.0, 1.0])], 1), + } + aot_debug_handle_to_op_name = {(0,): "op_0"} + runtime_debug_handle_to_op_name = {(0,): "op_0"} + + inspector_instance._get_aot_intermediate_outputs_and_op_names = lambda x: ( + aot_intermediate_outputs, + aot_debug_handle_to_op_name, + ) + inspector_instance._get_runtime_intermediate_outputs_and_op_names = ( + lambda: (runtime_intermediate_outputs, runtime_debug_handle_to_op_name) + ) + + # Test 1: Non-dict return type + with self.assertRaises(TypeError) as context: + inspector_instance.calculate_numeric_gap( + distance=NonDictPreprocessingComparator() + ) + self.assertIn("must return a dict", str(context.exception)) + + # Test 2: Invalid key format + with self.assertRaises(ValueError) as context: + inspector_instance.calculate_numeric_gap( + distance=InvalidKeyFormatComparator() + ) + self.assertIn("Invalid key format", str(context.exception)) + + # Test 3: Invalid debug handle in key + with self.assertRaises(ValueError) as context: + inspector_instance.calculate_numeric_gap( + distance=InvalidKeyDebugHandleComparator() + ) + self.assertIn("Invalid AOT debug handle", str(context.exception)) + + # Test 4: Invalid value format + with self.assertRaises(ValueError) as context: + inspector_instance.calculate_numeric_gap( + distance=InvalidValueFormatComparator() + ) + self.assertIn("Invalid value format", str(context.exception)) + + # Test 5: Invalid debug handle in value + with self.assertRaises(ValueError) as context: + inspector_instance.calculate_numeric_gap( + distance=InvalidValueDebugHandleComparator() + ) + self.assertIn("Invalid runtime debug handle", str(context.exception)) + @unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows") def test_transformer_block_xnnpack_numeric_gap_within_tolerance(self): """ diff --git a/devtools/inspector/tests/inspector_utils_test.py b/devtools/inspector/tests/inspector_utils_test.py index 8c4bb4b38b9..cbdc557f405 100644 --- a/devtools/inspector/tests/inspector_utils_test.py +++ b/devtools/inspector/tests/inspector_utils_test.py @@ -30,7 +30,6 @@ calculate_mse, calculate_snr, calculate_time_scale_factor, - compare_intermediate_outputs, convert_to_float_tensor, create_debug_handle_to_op_node_mapping, EDGE_DIALECT_GRAPH_KEY, @@ -45,7 +44,7 @@ propagate_back_debug_handle, TimeScale, ) -from executorch.devtools.inspector.numerical_comparator import L1Comparator + from executorch.exir import to_edge from executorch.exir.debug_handle_utils import DEBUG_HANDLE_KEY, UNSET_DEBUG_HANDLE from torch.export import export @@ -606,24 +605,6 @@ def test_find_op_names_mixed_single_and_multiple_ops(self): ["op1", "op2", "op3", "op4", "op5", "op6", "op7"], ) - def test_compare_intermediate_outputs_sequences(self): - a = [1.0, 2.0, 3.0] - b = [1.0, 2.5, 3.5] - result = compare_intermediate_outputs(a, b, L1Comparator()) - self.assertEqual(result, [0.0, 0.5, 0.5]) - - def test_compare_intermediate_outputs_diff_len_sequences(self): - a = [1.0, 2.0] - b = [1.0, 2.0, 3.0] - with self.assertRaises(ValueError): - compare_intermediate_outputs(a, b, L1Comparator()) - - def test_compare_intermediate_outputs_sequence_and_non_sequence(self): - a = [1.0, 2.0] - b = 1.0 - with self.assertRaises(ValueError): - compare_intermediate_outputs(a, b, L1Comparator()) - def test_equip_debug_handle_to_export_program_success(self): """Test that propagate_back_debug_handle returns True and properly equips debug handles.""" # Create a test model diff --git a/devtools/inspector/tests/l1_comparator_test.py b/devtools/inspector/tests/l1_comparator_test.py index 1e9f0be9c10..b2c1a86910e 100644 --- a/devtools/inspector/tests/l1_comparator_test.py +++ b/devtools/inspector/tests/l1_comparator_test.py @@ -18,32 +18,32 @@ def test_identical_tensors(self): a = torch.tensor([[1, 2], [3, 4]]) b = torch.tensor([[1, 2], [3, 4]]) expected = 0.0 - result = self.l1_comparator.compare(a, b) + result = self.l1_comparator.element_compare(a, b) self.assertAlmostEqual(result, expected) def test_scalar(self): a = 1 b = 2 expected = 1.0 - result = self.l1_comparator.compare(a, b) + result = self.l1_comparator.element_compare(a, b) self.assertAlmostEqual(result, expected) def test_with_nans_replaced_with_zero(self): a = torch.tensor([3, 2, -1, float("nan")]) b = torch.tensor([float("nan"), 0, -3, 1]) expected = 8.0 - result = self.l1_comparator.compare(a, b) + result = self.l1_comparator.element_compare(a, b) self.assertAlmostEqual(result, expected) def test_shape_mismatch_raises_exception(self): a = torch.tensor([0, 2, -1]) b = torch.tensor([1, 0, -3, 4]) with self.assertRaises(ValueError): - self.l1_comparator.compare(a, b) + self.l1_comparator.element_compare(a, b) def test_2D_tensors(self): a = torch.tensor([[4, 9], [6, 4]]) b = torch.tensor([[1, 2], [3, 5]]) expected = 14.0 - result = self.l1_comparator.compare(a, b) + result = self.l1_comparator.element_compare(a, b) self.assertAlmostEqual(result, expected) diff --git a/devtools/inspector/tests/mse_comparator_test.py b/devtools/inspector/tests/mse_comparator_test.py index b24302e12e8..f9e61af4e88 100644 --- a/devtools/inspector/tests/mse_comparator_test.py +++ b/devtools/inspector/tests/mse_comparator_test.py @@ -18,32 +18,32 @@ def test_identical_tensors(self): a = torch.tensor([[10, 4], [3, 4]]) b = torch.tensor([[10, 4], [3, 4]]) expected = 0.0 - result = self.mse_comparator.compare(a, b) + result = self.mse_comparator.element_compare(a, b) self.assertAlmostEqual(result, expected) def test_scalar(self): a = 10 b = 2 expected = 64.0 - result = self.mse_comparator.compare(a, b) + result = self.mse_comparator.element_compare(a, b) self.assertAlmostEqual(result, expected) def test_with_nans_replaced_with_zero(self): a = torch.tensor([3, 1, -3, float("nan")]) b = torch.tensor([float("nan"), 0, -3, 2]) expected = (9.0 + 1.0 + 0.0 + 4.0) / 4.0 - result = self.mse_comparator.compare(a, b) + result = self.mse_comparator.element_compare(a, b) self.assertAlmostEqual(result, expected) def test_shape_mismatch_raises_exception(self): a = torch.tensor([0, 2, -1]) b = torch.tensor([1, 1, -3, 4]) with self.assertRaises(ValueError): - self.mse_comparator.compare(a, b) + self.mse_comparator.element_compare(a, b) def test_2D_tensors(self): a = torch.tensor([[4, 9], [6, 4]]) b = torch.tensor([[1, 2], [3, 10]]) expected = (9.0 + 49.0 + 9.0 + 36.0) / 4.0 - result = self.mse_comparator.compare(a, b) + result = self.mse_comparator.element_compare(a, b) self.assertAlmostEqual(result, expected) diff --git a/devtools/inspector/tests/snr_comparator_test.py b/devtools/inspector/tests/snr_comparator_test.py index b21e1f3d61a..93d0a2f5deb 100644 --- a/devtools/inspector/tests/snr_comparator_test.py +++ b/devtools/inspector/tests/snr_comparator_test.py @@ -19,27 +19,27 @@ def test_identical_tensors(self): # identical tensors --> error_power == 0 --> SNR is inf a = torch.tensor([[10, 4], [3, 4]]) b = torch.tensor([[10, 4], [3, 4]]) - result = self.snr_comparator.compare(a, b) + result = self.snr_comparator.element_compare(a, b) self.assertTrue(math.isinf(result) and result > 0) def test_scalar(self): # original_power == 1, error_power == 1 --> SNR = 10 * log10(1/1) = 0 a = 1 b = 2 - result = self.snr_comparator.compare(a, b) + result = self.snr_comparator.element_compare(a, b) self.assertAlmostEqual(result, 0.0) def test_with_nans_replaced_with_zero(self): a = torch.tensor([float("nan"), 1.0]) b = torch.tensor([0.0, 1.0]) - result = self.snr_comparator.compare(a, b) + result = self.snr_comparator.element_compare(a, b) self.assertTrue(math.isinf(result) and result > 0) def test_shape_mismatch_raises_exception(self): a = torch.tensor([1, 2, -1]) b = torch.tensor([1, 1, -3, 4]) with self.assertRaises(ValueError): - self.snr_comparator.compare(a, b) + self.snr_comparator.element_compare(a, b) def test_2D_tensors(self): # original_power = mean([16, 81, 36, 16]) = 37.25 @@ -48,5 +48,5 @@ def test_2D_tensors(self): a = torch.tensor([[4, 9], [6, 4]]) b = torch.tensor([[1, 2], [3, 5]]) expected = 10 * math.log10(37.25 / 17.0) - result = self.snr_comparator.compare(a, b) + result = self.snr_comparator.element_compare(a, b) self.assertAlmostEqual(result, expected) From e57876a8d813b207822fa940139a151865d614b3 Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Wed, 25 Feb 2026 00:05:31 -0800 Subject: [PATCH 2/2] [devtools] Auto-record after-transform graph in ETRecord (#17696) This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: https://github.com/pytorch/executorch/pull/17434 by @Gasoonjia ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/gasoonjia/120/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/120/head Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/gasoonjia/119/orig Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/120/orig Differential Revision: [D93176563](https://our.internmc.facebook.com/intern/diff/D93176563/) @diff-train-skip-merge --------- Co-authored-by: gasoonjia --- devtools/etrecord/tests/etrecord_test.py | 119 ++++++++++ devtools/inspector/_inspector.py | 116 +++++++-- devtools/inspector/tests/inspector_test.py | 261 +++++++++++++++++++-- exir/program/_program.py | 5 + 4 files changed, 466 insertions(+), 35 deletions(-) diff --git a/devtools/etrecord/tests/etrecord_test.py b/devtools/etrecord/tests/etrecord_test.py index 11463a976b4..a57515bffee 100644 --- a/devtools/etrecord/tests/etrecord_test.py +++ b/devtools/etrecord/tests/etrecord_test.py @@ -1819,3 +1819,122 @@ def test_multi_method_etrecord_generation(self): # Verify other ETRecord components are preserved self.assertIsNotNone(parsed_etrecord._debug_handle_map) self.assertIsNotNone(parsed_etrecord._delegate_map) + + def test_edge_after_transform_graph_capture(self): + """Test that to_edge_transform_and_lower with transform_passes captures the after-transform graph. + + When generate_etrecord=True and transform_passes are applied, the ETRecord should + contain the after-transform graph under the key 'edge_after_transform' in graph_map. + This enables backends like Qualcomm to use the post-custom-transform graph as the + golden reference for numeric gap calculation. + """ + from torch.fx.passes.infra.pass_base import PassBase, PassResult + + # Create a simple custom pass that modifies the graph + class SimpleCustomPass(PassBase): + """A simple pass that adds a marker attribute to each node.""" + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + # Mark each node to indicate this pass ran + for node in graph_module.graph.nodes: + node.meta["custom_pass_applied"] = True + return PassResult(graph_module=graph_module, modified=True) + + f = models.BasicSinMax() + aten_dialect = export(f, f.get_random_inputs(), strict=True) + + # Create edge program with custom transform pass and generate_etrecord=True + transform_passes = [SimpleCustomPass()] + + edge_manager = to_edge_transform_and_lower( + aten_dialect, + compile_config=EdgeCompileConfig(_check_ir_validity=False), + transform_passes=transform_passes, + generate_etrecord=True, + ) + + # Verify that ETRecord was generated + self.assertIsNotNone(edge_manager._etrecord) + etrecord = edge_manager._etrecord + + # Verify graph_map exists and contains the 'edge_after_transform' key + self.assertIsNotNone(etrecord.graph_map) + self.assertIn( + "edge_after_transform/forward", + etrecord.graph_map, + "graph_map should contain 'edge_after_transform/forward' when transform_passes are applied", + ) + + # Verify the captured graph has the custom pass marker + after_transform_graph = etrecord.graph_map["edge_after_transform/forward"] + self.assertIsNotNone(after_transform_graph) + + # Check that at least one node has the custom_pass_applied marker + has_marker = False + for node in after_transform_graph.graph.nodes: + if node.meta.get("custom_pass_applied", False): + has_marker = True + break + + self.assertTrue( + has_marker, + "The edge_after_transform graph should have the custom pass marker applied", + ) + + # Verify edge_dialect_program is still the pre-transform graph (original behavior preserved) + self.assertIsNotNone(etrecord.edge_dialect_program) + + # Save and parse the ETRecord to verify persistence + et_output = edge_manager.to_executorch() + + with tempfile.TemporaryDirectory() as tmpdirname: + etrecord_path = tmpdirname + "/etrecord_custom_pass.bin" + + # Get ETRecord and save + complete_etrecord = et_output.get_etrecord() + complete_etrecord.save(etrecord_path) + + # Parse ETRecord back + parsed_etrecord = parse_etrecord(etrecord_path) + + # Verify the after-transform graph is preserved after save/parse + self.assertIsNotNone(parsed_etrecord.graph_map) + self.assertIn( + "edge_after_transform/forward", + parsed_etrecord.graph_map, + "Parsed ETRecord should still contain 'edge_after_transform/forward'", + ) + + # Verify the parsed graph still has the marker + parsed_after_transform_graph = parsed_etrecord.graph_map[ + "edge_after_transform/forward" + ] + self.assertIsNotNone(parsed_after_transform_graph) + + def test_no_edge_after_transform_without_transform_passes(self): + """Test that 'edge_after_transform' is NOT added when no transform_passes are provided. + + This ensures backward compatibility - when generate_etrecord=True but no transform_passes + are applied, the ETRecord should NOT have an 'edge_after_transform' entry. + """ + f = models.BasicSinMax() + aten_dialect = export(f, f.get_random_inputs(), strict=True) + + # Create edge program WITHOUT transform_passes + edge_manager = to_edge_transform_and_lower( + aten_dialect, + compile_config=EdgeCompileConfig(_check_ir_validity=False), + generate_etrecord=True, + ) + + # Verify that ETRecord was generated + self.assertIsNotNone(edge_manager._etrecord) + etrecord = edge_manager._etrecord + + # Verify that 'edge_after_transform' is NOT in graph_map + if etrecord.graph_map is not None: + self.assertNotIn( + "edge_after_transform/forward", + etrecord.graph_map, + "graph_map should NOT contain 'edge_after_transform/forward' when no transform_passes are applied", + ) diff --git a/devtools/inspector/_inspector.py b/devtools/inspector/_inspector.py index 368824f71a3..236597f2bd6 100644 --- a/devtools/inspector/_inspector.py +++ b/devtools/inspector/_inspector.py @@ -1166,31 +1166,94 @@ def _consume_etrecord(self) -> None: def _get_aot_intermediate_outputs_and_op_names( self, + reference_graph: Optional[str] = None, disable_debug_handle_valdiation: bool = False, ) -> Tuple[Dict[DebugHandle, Any], Dict[DebugHandle, List[str]]]: """ Capture intermediate outputs only if _representative_inputs are provided - when using bundled program to create the etrecord - """ - if self._etrecord._representative_inputs is None: - return {}, {} + when using bundled program to create the etrecord. - export_program = None + Args: + reference_graph_name: Name of the graph to use as the reference for intermediate + output capture. Must be one of: + - "exported_program": Uses the ATen dialect exported program. Requires + successful debug handle backpropagation, otherwise raises an error. + - "edge_dialect_exported_program": Uses the Edge dialect program directly. + - Any other string: Fetches from graph_map (e.g., "edge_after_transform/forward" + for post-custom-transform graph when transform_passes are applied in + to_edge_transform_and_lower with generate_etrecord=True). + disable_debug_handle_valdiation: If True, skip debug handle validation. - # Will use the exported program to extract intermediate output if and only if exported_program has been provided, and it is one of the ancestors of the edge_dialect_program - if self._etrecord.exported_program and propagate_back_debug_handle( - self._etrecord.exported_program, - self._etrecord.export_graph_id, - self._etrecord.edge_dialect_program, - disable_debug_handle_valdiation, - ): + Returns: + Tuple of (intermediate_outputs, debug_handle_to_op_names) dictionaries. + + Raises: + ValueError: If the specified reference_graph_name is not available or if + debug handle backpropagation fails for "exported_program". + """ + + # Determine the reference graph to use + if reference_graph is None or reference_graph == "exported_program": + # Auto-select: try exported_program first, fall back to edge_dialect_exported_program + if self._etrecord.exported_program and propagate_back_debug_handle( + self._etrecord.exported_program, + self._etrecord.export_graph_id, + self._etrecord.edge_dialect_program, + disable_debug_handle_valdiation, + ): + reference_graph = "exported_program" + elif reference_graph is None: + log.warning( + "Either ATen dialect exported program is not in ETRecord, or debug handle " + "backpropagation failed. Falling back to 'edge_dialect_exported_program'." + ) + reference_graph = "edge_dialect_exported_program" + else: + raise ValueError( + "Cannot use 'exported_program': Debug handle backpropagation failed or exported program is unavailable. " + "Please check if the exported program is available in ETRecord, or try to disable debug handle validation." + ) + if reference_graph == "edge_dialect_exported_program": + # Explicitly requested edge_dialect_exported_program + export_program = self._etrecord.edge_dialect_program + log.info( + "Using 'edge_dialect_exported_program' (Edge dialect) as reference graph for intermediate output capture" + ) + elif reference_graph == "exported_program": export_program = self._etrecord.exported_program - else: - log.warning( - "Either aten dialect exported program is not in ETRecord, or it is not one of the ancestors of current edge dialect program." - "Will fall back to use edge dialect program to extract intermediate output", + log.info( + "Using 'exported_program' (ATen dialect) as reference graph for intermediate output capture" ) - export_program = self._etrecord.edge_dialect_program + else: + # Try to fetch from graph_map + # If no method name is provided (no "/" in the name), try adding "/forward" as default + lookup_name = reference_graph + if "/" not in reference_graph: + lookup_name = f"{reference_graph}/forward" + log.info( + f"No method name specified in '{reference_graph}', " + f"using '{lookup_name}' as default" + ) + + if ( + self._etrecord.graph_map is not None + and lookup_name in self._etrecord.graph_map + ): + export_program = self._etrecord.graph_map[lookup_name] + log.info( + f"Using '{lookup_name}' from graph_map as reference graph for intermediate output capture" + ) + else: + available_graphs = ( + list(self._etrecord.graph_map.keys()) + if self._etrecord.graph_map + else [] + ) + raise ValueError( + f"Reference graph '{lookup_name}' not found. " + f"Available options: 'exported_program', 'edge_dialect_exported_program', " + f"or one of the graphs in graph_map: {available_graphs}" + ) graph_module = export_program.module() aot_debug_handle_to_op_name = get_aot_debug_handle_to_op_name_mapping( graph_module @@ -1406,11 +1469,11 @@ def calculate_numeric_gap( self, distance: Union[str, NumericalComparatorBase], disable_debug_handle_valdiation: bool = False, + reference_graph: Optional[str] = None, ): """ Compares logged intermediate outputs from the exported graph (in ETRecord) with runtime outputs (in ETDump) using a user-specific numerical comparator. - If the exported graph is not supported, the function will fall back to use edge dialect graph. To use this function, you must first generate the ETRecord with representative inputs, and then create the Inspector instance with the ETRecord and ETDump. The Inspector can then @@ -1423,18 +1486,31 @@ def calculate_numeric_gap( logic by subclassing NumericalComparatorBase and implementing the element_compare() method. Custom comparators can also override the preprocessing() method to apply transformations (e.g., layout conversion, dequantization) before comparison. - disable_debug_handle_validation: Often when aten graph has symbolic shape nodes and inbuilt ops like gt/lt etc., + disable_debug_handle_valdiation: Often when aten graph has symbolic shape nodes and inbuilt ops like gt/lt etc., during re-export of such a graph 'from_node' information is lost from node.meta. As a result we loose connection between edge IR nodes and aten nodes for such ops. By default we validate that every edge IR node has corresponding node in aten IR, and when such validation fails numeric debugger falls back to edge IR as reference graph. This flag allows one to override such behavior and make best effort comparison. + reference_graph: Name of the graph to use as the golden reference for intermediate output capture. + Must be one of: + - "exported_program": Uses the ATen dialect exported program. Requires successful debug + handle backpropagation, otherwise raises an error. + - "edge_dialect_exported_program": Uses the Edge dialect program directly. + - Any other string: Fetches from graph_map (e.g., "edge_after_transform/forward" for + post-custom-transform graph when transform_passes are applied in to_edge_transform_and_lower + with generate_etrecord=True). + + If None (default), automatically selects the best available graph: + - Uses "exported_program" if available and debug handle backpropagation succeeds. + - Falls back to "edge_dialect_exported_program" otherwise. Returns: pd.DataFrame: A DataFrame listing corresponding operator intermediate outputs from both stages and their computed numerical gaps. """ aot_intermediate_outputs, aot_debug_handle_to_op_names = ( self._get_aot_intermediate_outputs_and_op_names( - disable_debug_handle_valdiation + reference_graph, + disable_debug_handle_valdiation, ) ) if len(aot_intermediate_outputs) == 0 or len(aot_debug_handle_to_op_names) == 0: diff --git a/devtools/inspector/tests/inspector_test.py b/devtools/inspector/tests/inspector_test.py index ba199e470a8..f7050593cc2 100644 --- a/devtools/inspector/tests/inspector_test.py +++ b/devtools/inspector/tests/inspector_test.py @@ -17,7 +17,7 @@ from typing import Callable, List, Union -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pandas as pd @@ -682,9 +682,11 @@ def test_calculate_numeric_gap(self): aot_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"} runtime_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"} - inspector_instance._get_aot_intermediate_outputs_and_op_names = lambda x: ( - aot_intermediate_outputs, - aot_debug_handle_to_op_name, + inspector_instance._get_aot_intermediate_outputs_and_op_names = ( + lambda x, y: ( + aot_intermediate_outputs, + aot_debug_handle_to_op_name, + ) ) inspector_instance._get_runtime_intermediate_outputs_and_op_names = ( lambda: (runtime_intermediate_outputs, runtime_debug_handle_to_op_name) @@ -764,9 +766,11 @@ def element_compare(self, a, b): aot_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"} runtime_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"} - inspector_instance._get_aot_intermediate_outputs_and_op_names = lambda x: ( - aot_intermediate_outputs, - aot_debug_handle_to_op_name, + inspector_instance._get_aot_intermediate_outputs_and_op_names = ( + lambda x, y: ( + aot_intermediate_outputs, + aot_debug_handle_to_op_name, + ) ) inspector_instance._get_runtime_intermediate_outputs_and_op_names = ( lambda: (runtime_intermediate_outputs, runtime_debug_handle_to_op_name) @@ -891,9 +895,11 @@ def preprocessing( aot_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"} runtime_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"} - inspector_instance._get_aot_intermediate_outputs_and_op_names = lambda x: ( - aot_intermediate_outputs, - aot_debug_handle_to_op_name, + inspector_instance._get_aot_intermediate_outputs_and_op_names = ( + lambda x, y: ( + aot_intermediate_outputs, + aot_debug_handle_to_op_name, + ) ) inspector_instance._get_runtime_intermediate_outputs_and_op_names = ( lambda: (runtime_intermediate_outputs, runtime_debug_handle_to_op_name) @@ -901,7 +907,9 @@ def preprocessing( # --- Test 1: MSE comparator with scaling preprocessing --- mse_comparator = MSEComparatorWithScaling(scale_factor=2.0) - df_mse = inspector_instance.calculate_numeric_gap(distance=mse_comparator) + df_mse = inspector_instance.calculate_numeric_gap( + distance=mse_comparator, reference_graph="NOT_USED_NAME" + ) # Verify preprocessing was called self.assertTrue(mse_comparator.preprocessing_called) @@ -936,7 +944,9 @@ def preprocessing( # --- Test 2: SNR comparator with the same scaling preprocessing --- snr_comparator = SNRComparatorWithScaling(scale_factor=2.0) - df_snr = inspector_instance.calculate_numeric_gap(distance=snr_comparator) + df_snr = inspector_instance.calculate_numeric_gap( + distance=snr_comparator, reference_graph="NOT_USED_NAME" + ) # Verify preprocessing was called self.assertTrue(snr_comparator.preprocessing_called) @@ -1048,9 +1058,11 @@ def element_compare(self, a, b) -> float: aot_debug_handle_to_op_name = {(0,): "op_0"} runtime_debug_handle_to_op_name = {(0,): "op_0"} - inspector_instance._get_aot_intermediate_outputs_and_op_names = lambda x: ( - aot_intermediate_outputs, - aot_debug_handle_to_op_name, + inspector_instance._get_aot_intermediate_outputs_and_op_names = ( + lambda x, y: ( + aot_intermediate_outputs, + aot_debug_handle_to_op_name, + ) ) inspector_instance._get_runtime_intermediate_outputs_and_op_names = ( lambda: (runtime_intermediate_outputs, runtime_debug_handle_to_op_name) @@ -1091,6 +1103,225 @@ def element_compare(self, a, b) -> float: ) self.assertIn("Invalid runtime debug handle", str(context.exception)) + def test_calculate_numeric_gap_with_reference_graph_name(self): + """Test calculate_numeric_gap with the reference_graph parameter using a custom graph from graph_map.""" + # Create a context manager to patch functions called by Inspector.__init__ + with patch.object( + _inspector, "parse_etrecord", return_value=None + ), patch.object( + _inspector, "gen_etdump_object", return_value=None + ), patch.object( + EventBlock, "_gen_from_etdump" + ), patch.object( + _inspector, "gen_graphs_from_etrecord" + ): + inspector_instance = Inspector( + etdump_path=ETDUMP_PATH, + etrecord=ETRECORD_PATH, + ) + + # Create mock intermediate outputs + aot_intermediate_outputs = { + (0,): torch.tensor([1.0, 2.0, 3.0]), + (1,): torch.tensor([4.0, 5.0, 6.0]), + } + runtime_intermediate_outputs = { + (0,): ([torch.tensor([2.0, 3.0, 4.0])], 1), + (1,): ([torch.tensor([5.0, 6.0, 7.0])], 1), + } + + aot_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"} + runtime_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"} + + # Create a mock graph module for the reference graph + class MockGraphModule: + def __init__(self): + self.graph = MagicMock() + self.graph.nodes = [] + + def module(self): + return self + + mock_graph_module = MockGraphModule() + + # Create a real ETRecord and set up the graph_map with edge_after_transform + from executorch.devtools.etrecord import ETRecord + + mock_etrecord = ETRecord() + mock_etrecord._representative_inputs = torch.tensor([1.0]) + mock_etrecord.exported_program = None + mock_etrecord.edge_dialect_program = mock_graph_module + + # The code adds "/forward" suffix when looking up, so we need "edge_after_transform/forward" + mock_etrecord.graph_map = { + "edge_after_transform/forward": mock_graph_module + } + + inspector_instance._etrecord = mock_etrecord + + # Mock the runtime intermediate outputs + inspector_instance._get_runtime_intermediate_outputs_and_op_names = ( + lambda: (runtime_intermediate_outputs, runtime_debug_handle_to_op_name) + ) + + # Mock IntermediateOutputCapturer and get_aot_debug_handle_to_op_name_mapping + # These are called inside _get_aot_intermediate_outputs_and_op_names when using a custom graph + with patch( + "executorch.devtools.inspector._inspector.IntermediateOutputCapturer" + ) as mock_capturer_class, patch( + "executorch.devtools.inspector._inspector.get_aot_debug_handle_to_op_name_mapping" + ) as mock_get_mapping: + mock_capturer = MagicMock() + mock_capturer.run_and_capture.return_value = aot_intermediate_outputs + mock_capturer_class.return_value = mock_capturer + mock_get_mapping.return_value = aot_debug_handle_to_op_name + + # Test with reference_graph parameter (without /forward suffix) + # The code should automatically add "/forward" when looking up in graph_map + df = inspector_instance.calculate_numeric_gap( + distance="L1", + reference_graph="edge_after_transform", + ) + + self.assertIsInstance(df, pd.DataFrame) + self.assertEqual(len(df), 2) + + def test_calculate_numeric_gap_with_invalid_reference_graph_name(self): + """Test that calculate_numeric_gap raises ValueError for invalid reference_graph.""" + with patch.object( + _inspector, "parse_etrecord", return_value=None + ), patch.object( + _inspector, "gen_etdump_object", return_value=None + ), patch.object( + EventBlock, "_gen_from_etdump" + ), patch.object( + _inspector, "gen_graphs_from_etrecord" + ): + inspector_instance = Inspector( + etdump_path=ETDUMP_PATH, + etrecord=ETRECORD_PATH, + ) + + # Create a real ETRecord with empty graph_map + from executorch.devtools.etrecord import ETRecord + + mock_etrecord = ETRecord() + mock_etrecord._representative_inputs = torch.tensor([1.0]) + mock_etrecord.graph_map = {} + + inspector_instance._etrecord = mock_etrecord + + # Test with non-existent reference_graph + # Since "non_existent_graph" has no "/", it will be looked up as "non_existent_graph/forward" + with self.assertRaises(ValueError) as context: + inspector_instance.calculate_numeric_gap( + distance="L1", + reference_graph="non_existent_graph", + ) + self.assertIn("not found", str(context.exception)) + self.assertIn("non_existent_graph/forward", str(context.exception)) + + def test_calculate_numeric_gap_with_exported_program_name_backprop_failure(self): + """Test that calculate_numeric_gap raises ValueError when exported_program backpropagation fails.""" + with patch.object( + _inspector, "parse_etrecord", return_value=None + ), patch.object( + _inspector, "gen_etdump_object", return_value=None + ), patch.object( + EventBlock, "_gen_from_etdump" + ), patch.object( + _inspector, "gen_graphs_from_etrecord" + ): + inspector_instance = Inspector( + etdump_path=ETDUMP_PATH, + etrecord=ETRECORD_PATH, + ) + + # Create mock graph modules + class MockGraphModule: + def __init__(self): + self.graph = MagicMock() + + def module(self): + return self + + mock_exported_program = MockGraphModule() + mock_edge_dialect_program = MockGraphModule() + + # Create a real ETRecord with exported_program + from executorch.devtools.etrecord import ETRecord + + mock_etrecord = ETRecord() + mock_etrecord._representative_inputs = torch.tensor([1.0]) + mock_etrecord.exported_program = mock_exported_program + mock_etrecord.edge_dialect_program = mock_edge_dialect_program + mock_etrecord.export_graph_id = "graph_id" + mock_etrecord.graph_map = {} + + inspector_instance._etrecord = mock_etrecord + + # Mock propagate_back_debug_handle to return False (backpropagation failure) + with patch( + "executorch.devtools.inspector._inspector.propagate_back_debug_handle" + ) as mock_propagate: + mock_propagate.return_value = False + + # Test with "exported_program" should raise error when backpropagation fails + with self.assertRaises(ValueError) as context: + inspector_instance.calculate_numeric_gap( + distance="L1", + reference_graph="exported_program", + ) + self.assertIn("Cannot use 'exported_program'", str(context.exception)) + self.assertIn("backpropagation failed", str(context.exception)) + + def test_calculate_numeric_gap_with_edge_dialect_exported_program_name(self): + """Test calculate_numeric_gap with edge_dialect_exported_program reference_graph parameter.""" + with patch.object( + _inspector, "parse_etrecord", return_value=None + ), patch.object( + _inspector, "gen_etdump_object", return_value=None + ), patch.object( + EventBlock, "_gen_from_etdump" + ), patch.object( + _inspector, "gen_graphs_from_etrecord" + ): + inspector_instance = Inspector( + etdump_path=ETDUMP_PATH, + etrecord=ETRECORD_PATH, + ) + + # Create mock intermediate outputs (same structure as test_calculate_numeric_gap) + aot_intermediate_outputs = { + (0,): torch.tensor([1.0, 2.0, 3.0]), + } + runtime_intermediate_outputs = { + (0,): ([torch.tensor([2.0, 3.0, 4.0])], 1), + } + + aot_debug_handle_to_op_name = {(0,): "op_0"} + runtime_debug_handle_to_op_name = {(0,): "op_0"} + + # Mock the internal methods like test_calculate_numeric_gap does + inspector_instance._get_aot_intermediate_outputs_and_op_names = ( + lambda x, y: ( + aot_intermediate_outputs, + aot_debug_handle_to_op_name, + ) + ) + inspector_instance._get_runtime_intermediate_outputs_and_op_names = ( + lambda: (runtime_intermediate_outputs, runtime_debug_handle_to_op_name) + ) + + # Test with edge_dialect_exported_program parameter + df = inspector_instance.calculate_numeric_gap( + distance="L1", + reference_graph="edge_dialect_exported_program", + ) + + self.assertIsInstance(df, pd.DataFrame) + self.assertEqual(len(df), 1) + @unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows") def test_transformer_block_xnnpack_numeric_gap_within_tolerance(self): """ diff --git a/exir/program/_program.py b/exir/program/_program.py index abf413918e5..baacd5eaec4 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -1388,6 +1388,11 @@ def to_edge_transform_and_lower( # noqa: C901 if transform_passes is not None: edge_manager = edge_manager.transform(transform_passes) + if generate_etrecord: + edge_manager._etrecord.add_extra_export_modules( + {"edge_after_transform": copy.deepcopy(edge_manager)} + ) + max_num_partitioners = 0 for partitioner_list in partitioner.values(): max_num_partitioners = max(max_num_partitioners, len(partitioner_list))