diff --git a/devtools/inspector/_inspector.py b/devtools/inspector/_inspector.py index 368824f71a3..236597f2bd6 100644 --- a/devtools/inspector/_inspector.py +++ b/devtools/inspector/_inspector.py @@ -1166,31 +1166,94 @@ def _consume_etrecord(self) -> None: def _get_aot_intermediate_outputs_and_op_names( self, + reference_graph: Optional[str] = None, disable_debug_handle_valdiation: bool = False, ) -> Tuple[Dict[DebugHandle, Any], Dict[DebugHandle, List[str]]]: """ Capture intermediate outputs only if _representative_inputs are provided - when using bundled program to create the etrecord - """ - if self._etrecord._representative_inputs is None: - return {}, {} + when using bundled program to create the etrecord. - export_program = None + Args: + reference_graph_name: Name of the graph to use as the reference for intermediate + output capture. Must be one of: + - "exported_program": Uses the ATen dialect exported program. Requires + successful debug handle backpropagation, otherwise raises an error. + - "edge_dialect_exported_program": Uses the Edge dialect program directly. + - Any other string: Fetches from graph_map (e.g., "edge_after_transform/forward" + for post-custom-transform graph when transform_passes are applied in + to_edge_transform_and_lower with generate_etrecord=True). + disable_debug_handle_valdiation: If True, skip debug handle validation. - # Will use the exported program to extract intermediate output if and only if exported_program has been provided, and it is one of the ancestors of the edge_dialect_program - if self._etrecord.exported_program and propagate_back_debug_handle( - self._etrecord.exported_program, - self._etrecord.export_graph_id, - self._etrecord.edge_dialect_program, - disable_debug_handle_valdiation, - ): + Returns: + Tuple of (intermediate_outputs, debug_handle_to_op_names) dictionaries. + + Raises: + ValueError: If the specified reference_graph_name is not available or if + debug handle backpropagation fails for "exported_program". + """ + + # Determine the reference graph to use + if reference_graph is None or reference_graph == "exported_program": + # Auto-select: try exported_program first, fall back to edge_dialect_exported_program + if self._etrecord.exported_program and propagate_back_debug_handle( + self._etrecord.exported_program, + self._etrecord.export_graph_id, + self._etrecord.edge_dialect_program, + disable_debug_handle_valdiation, + ): + reference_graph = "exported_program" + elif reference_graph is None: + log.warning( + "Either ATen dialect exported program is not in ETRecord, or debug handle " + "backpropagation failed. Falling back to 'edge_dialect_exported_program'." + ) + reference_graph = "edge_dialect_exported_program" + else: + raise ValueError( + "Cannot use 'exported_program': Debug handle backpropagation failed or exported program is unavailable. " + "Please check if the exported program is available in ETRecord, or try to disable debug handle validation." + ) + if reference_graph == "edge_dialect_exported_program": + # Explicitly requested edge_dialect_exported_program + export_program = self._etrecord.edge_dialect_program + log.info( + "Using 'edge_dialect_exported_program' (Edge dialect) as reference graph for intermediate output capture" + ) + elif reference_graph == "exported_program": export_program = self._etrecord.exported_program - else: - log.warning( - "Either aten dialect exported program is not in ETRecord, or it is not one of the ancestors of current edge dialect program." - "Will fall back to use edge dialect program to extract intermediate output", + log.info( + "Using 'exported_program' (ATen dialect) as reference graph for intermediate output capture" ) - export_program = self._etrecord.edge_dialect_program + else: + # Try to fetch from graph_map + # If no method name is provided (no "/" in the name), try adding "/forward" as default + lookup_name = reference_graph + if "/" not in reference_graph: + lookup_name = f"{reference_graph}/forward" + log.info( + f"No method name specified in '{reference_graph}', " + f"using '{lookup_name}' as default" + ) + + if ( + self._etrecord.graph_map is not None + and lookup_name in self._etrecord.graph_map + ): + export_program = self._etrecord.graph_map[lookup_name] + log.info( + f"Using '{lookup_name}' from graph_map as reference graph for intermediate output capture" + ) + else: + available_graphs = ( + list(self._etrecord.graph_map.keys()) + if self._etrecord.graph_map + else [] + ) + raise ValueError( + f"Reference graph '{lookup_name}' not found. " + f"Available options: 'exported_program', 'edge_dialect_exported_program', " + f"or one of the graphs in graph_map: {available_graphs}" + ) graph_module = export_program.module() aot_debug_handle_to_op_name = get_aot_debug_handle_to_op_name_mapping( graph_module @@ -1406,11 +1469,11 @@ def calculate_numeric_gap( self, distance: Union[str, NumericalComparatorBase], disable_debug_handle_valdiation: bool = False, + reference_graph: Optional[str] = None, ): """ Compares logged intermediate outputs from the exported graph (in ETRecord) with runtime outputs (in ETDump) using a user-specific numerical comparator. - If the exported graph is not supported, the function will fall back to use edge dialect graph. To use this function, you must first generate the ETRecord with representative inputs, and then create the Inspector instance with the ETRecord and ETDump. The Inspector can then @@ -1423,18 +1486,31 @@ def calculate_numeric_gap( logic by subclassing NumericalComparatorBase and implementing the element_compare() method. Custom comparators can also override the preprocessing() method to apply transformations (e.g., layout conversion, dequantization) before comparison. - disable_debug_handle_validation: Often when aten graph has symbolic shape nodes and inbuilt ops like gt/lt etc., + disable_debug_handle_valdiation: Often when aten graph has symbolic shape nodes and inbuilt ops like gt/lt etc., during re-export of such a graph 'from_node' information is lost from node.meta. As a result we loose connection between edge IR nodes and aten nodes for such ops. By default we validate that every edge IR node has corresponding node in aten IR, and when such validation fails numeric debugger falls back to edge IR as reference graph. This flag allows one to override such behavior and make best effort comparison. + reference_graph: Name of the graph to use as the golden reference for intermediate output capture. + Must be one of: + - "exported_program": Uses the ATen dialect exported program. Requires successful debug + handle backpropagation, otherwise raises an error. + - "edge_dialect_exported_program": Uses the Edge dialect program directly. + - Any other string: Fetches from graph_map (e.g., "edge_after_transform/forward" for + post-custom-transform graph when transform_passes are applied in to_edge_transform_and_lower + with generate_etrecord=True). + + If None (default), automatically selects the best available graph: + - Uses "exported_program" if available and debug handle backpropagation succeeds. + - Falls back to "edge_dialect_exported_program" otherwise. Returns: pd.DataFrame: A DataFrame listing corresponding operator intermediate outputs from both stages and their computed numerical gaps. """ aot_intermediate_outputs, aot_debug_handle_to_op_names = ( self._get_aot_intermediate_outputs_and_op_names( - disable_debug_handle_valdiation + reference_graph, + disable_debug_handle_valdiation, ) ) if len(aot_intermediate_outputs) == 0 or len(aot_debug_handle_to_op_names) == 0: diff --git a/devtools/inspector/tests/inspector_test.py b/devtools/inspector/tests/inspector_test.py index ba199e470a8..f7050593cc2 100644 --- a/devtools/inspector/tests/inspector_test.py +++ b/devtools/inspector/tests/inspector_test.py @@ -17,7 +17,7 @@ from typing import Callable, List, Union -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pandas as pd @@ -682,9 +682,11 @@ def test_calculate_numeric_gap(self): aot_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"} runtime_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"} - inspector_instance._get_aot_intermediate_outputs_and_op_names = lambda x: ( - aot_intermediate_outputs, - aot_debug_handle_to_op_name, + inspector_instance._get_aot_intermediate_outputs_and_op_names = ( + lambda x, y: ( + aot_intermediate_outputs, + aot_debug_handle_to_op_name, + ) ) inspector_instance._get_runtime_intermediate_outputs_and_op_names = ( lambda: (runtime_intermediate_outputs, runtime_debug_handle_to_op_name) @@ -764,9 +766,11 @@ def element_compare(self, a, b): aot_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"} runtime_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"} - inspector_instance._get_aot_intermediate_outputs_and_op_names = lambda x: ( - aot_intermediate_outputs, - aot_debug_handle_to_op_name, + inspector_instance._get_aot_intermediate_outputs_and_op_names = ( + lambda x, y: ( + aot_intermediate_outputs, + aot_debug_handle_to_op_name, + ) ) inspector_instance._get_runtime_intermediate_outputs_and_op_names = ( lambda: (runtime_intermediate_outputs, runtime_debug_handle_to_op_name) @@ -891,9 +895,11 @@ def preprocessing( aot_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"} runtime_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"} - inspector_instance._get_aot_intermediate_outputs_and_op_names = lambda x: ( - aot_intermediate_outputs, - aot_debug_handle_to_op_name, + inspector_instance._get_aot_intermediate_outputs_and_op_names = ( + lambda x, y: ( + aot_intermediate_outputs, + aot_debug_handle_to_op_name, + ) ) inspector_instance._get_runtime_intermediate_outputs_and_op_names = ( lambda: (runtime_intermediate_outputs, runtime_debug_handle_to_op_name) @@ -901,7 +907,9 @@ def preprocessing( # --- Test 1: MSE comparator with scaling preprocessing --- mse_comparator = MSEComparatorWithScaling(scale_factor=2.0) - df_mse = inspector_instance.calculate_numeric_gap(distance=mse_comparator) + df_mse = inspector_instance.calculate_numeric_gap( + distance=mse_comparator, reference_graph="NOT_USED_NAME" + ) # Verify preprocessing was called self.assertTrue(mse_comparator.preprocessing_called) @@ -936,7 +944,9 @@ def preprocessing( # --- Test 2: SNR comparator with the same scaling preprocessing --- snr_comparator = SNRComparatorWithScaling(scale_factor=2.0) - df_snr = inspector_instance.calculate_numeric_gap(distance=snr_comparator) + df_snr = inspector_instance.calculate_numeric_gap( + distance=snr_comparator, reference_graph="NOT_USED_NAME" + ) # Verify preprocessing was called self.assertTrue(snr_comparator.preprocessing_called) @@ -1048,9 +1058,11 @@ def element_compare(self, a, b) -> float: aot_debug_handle_to_op_name = {(0,): "op_0"} runtime_debug_handle_to_op_name = {(0,): "op_0"} - inspector_instance._get_aot_intermediate_outputs_and_op_names = lambda x: ( - aot_intermediate_outputs, - aot_debug_handle_to_op_name, + inspector_instance._get_aot_intermediate_outputs_and_op_names = ( + lambda x, y: ( + aot_intermediate_outputs, + aot_debug_handle_to_op_name, + ) ) inspector_instance._get_runtime_intermediate_outputs_and_op_names = ( lambda: (runtime_intermediate_outputs, runtime_debug_handle_to_op_name) @@ -1091,6 +1103,225 @@ def element_compare(self, a, b) -> float: ) self.assertIn("Invalid runtime debug handle", str(context.exception)) + def test_calculate_numeric_gap_with_reference_graph_name(self): + """Test calculate_numeric_gap with the reference_graph parameter using a custom graph from graph_map.""" + # Create a context manager to patch functions called by Inspector.__init__ + with patch.object( + _inspector, "parse_etrecord", return_value=None + ), patch.object( + _inspector, "gen_etdump_object", return_value=None + ), patch.object( + EventBlock, "_gen_from_etdump" + ), patch.object( + _inspector, "gen_graphs_from_etrecord" + ): + inspector_instance = Inspector( + etdump_path=ETDUMP_PATH, + etrecord=ETRECORD_PATH, + ) + + # Create mock intermediate outputs + aot_intermediate_outputs = { + (0,): torch.tensor([1.0, 2.0, 3.0]), + (1,): torch.tensor([4.0, 5.0, 6.0]), + } + runtime_intermediate_outputs = { + (0,): ([torch.tensor([2.0, 3.0, 4.0])], 1), + (1,): ([torch.tensor([5.0, 6.0, 7.0])], 1), + } + + aot_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"} + runtime_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"} + + # Create a mock graph module for the reference graph + class MockGraphModule: + def __init__(self): + self.graph = MagicMock() + self.graph.nodes = [] + + def module(self): + return self + + mock_graph_module = MockGraphModule() + + # Create a real ETRecord and set up the graph_map with edge_after_transform + from executorch.devtools.etrecord import ETRecord + + mock_etrecord = ETRecord() + mock_etrecord._representative_inputs = torch.tensor([1.0]) + mock_etrecord.exported_program = None + mock_etrecord.edge_dialect_program = mock_graph_module + + # The code adds "/forward" suffix when looking up, so we need "edge_after_transform/forward" + mock_etrecord.graph_map = { + "edge_after_transform/forward": mock_graph_module + } + + inspector_instance._etrecord = mock_etrecord + + # Mock the runtime intermediate outputs + inspector_instance._get_runtime_intermediate_outputs_and_op_names = ( + lambda: (runtime_intermediate_outputs, runtime_debug_handle_to_op_name) + ) + + # Mock IntermediateOutputCapturer and get_aot_debug_handle_to_op_name_mapping + # These are called inside _get_aot_intermediate_outputs_and_op_names when using a custom graph + with patch( + "executorch.devtools.inspector._inspector.IntermediateOutputCapturer" + ) as mock_capturer_class, patch( + "executorch.devtools.inspector._inspector.get_aot_debug_handle_to_op_name_mapping" + ) as mock_get_mapping: + mock_capturer = MagicMock() + mock_capturer.run_and_capture.return_value = aot_intermediate_outputs + mock_capturer_class.return_value = mock_capturer + mock_get_mapping.return_value = aot_debug_handle_to_op_name + + # Test with reference_graph parameter (without /forward suffix) + # The code should automatically add "/forward" when looking up in graph_map + df = inspector_instance.calculate_numeric_gap( + distance="L1", + reference_graph="edge_after_transform", + ) + + self.assertIsInstance(df, pd.DataFrame) + self.assertEqual(len(df), 2) + + def test_calculate_numeric_gap_with_invalid_reference_graph_name(self): + """Test that calculate_numeric_gap raises ValueError for invalid reference_graph.""" + with patch.object( + _inspector, "parse_etrecord", return_value=None + ), patch.object( + _inspector, "gen_etdump_object", return_value=None + ), patch.object( + EventBlock, "_gen_from_etdump" + ), patch.object( + _inspector, "gen_graphs_from_etrecord" + ): + inspector_instance = Inspector( + etdump_path=ETDUMP_PATH, + etrecord=ETRECORD_PATH, + ) + + # Create a real ETRecord with empty graph_map + from executorch.devtools.etrecord import ETRecord + + mock_etrecord = ETRecord() + mock_etrecord._representative_inputs = torch.tensor([1.0]) + mock_etrecord.graph_map = {} + + inspector_instance._etrecord = mock_etrecord + + # Test with non-existent reference_graph + # Since "non_existent_graph" has no "/", it will be looked up as "non_existent_graph/forward" + with self.assertRaises(ValueError) as context: + inspector_instance.calculate_numeric_gap( + distance="L1", + reference_graph="non_existent_graph", + ) + self.assertIn("not found", str(context.exception)) + self.assertIn("non_existent_graph/forward", str(context.exception)) + + def test_calculate_numeric_gap_with_exported_program_name_backprop_failure(self): + """Test that calculate_numeric_gap raises ValueError when exported_program backpropagation fails.""" + with patch.object( + _inspector, "parse_etrecord", return_value=None + ), patch.object( + _inspector, "gen_etdump_object", return_value=None + ), patch.object( + EventBlock, "_gen_from_etdump" + ), patch.object( + _inspector, "gen_graphs_from_etrecord" + ): + inspector_instance = Inspector( + etdump_path=ETDUMP_PATH, + etrecord=ETRECORD_PATH, + ) + + # Create mock graph modules + class MockGraphModule: + def __init__(self): + self.graph = MagicMock() + + def module(self): + return self + + mock_exported_program = MockGraphModule() + mock_edge_dialect_program = MockGraphModule() + + # Create a real ETRecord with exported_program + from executorch.devtools.etrecord import ETRecord + + mock_etrecord = ETRecord() + mock_etrecord._representative_inputs = torch.tensor([1.0]) + mock_etrecord.exported_program = mock_exported_program + mock_etrecord.edge_dialect_program = mock_edge_dialect_program + mock_etrecord.export_graph_id = "graph_id" + mock_etrecord.graph_map = {} + + inspector_instance._etrecord = mock_etrecord + + # Mock propagate_back_debug_handle to return False (backpropagation failure) + with patch( + "executorch.devtools.inspector._inspector.propagate_back_debug_handle" + ) as mock_propagate: + mock_propagate.return_value = False + + # Test with "exported_program" should raise error when backpropagation fails + with self.assertRaises(ValueError) as context: + inspector_instance.calculate_numeric_gap( + distance="L1", + reference_graph="exported_program", + ) + self.assertIn("Cannot use 'exported_program'", str(context.exception)) + self.assertIn("backpropagation failed", str(context.exception)) + + def test_calculate_numeric_gap_with_edge_dialect_exported_program_name(self): + """Test calculate_numeric_gap with edge_dialect_exported_program reference_graph parameter.""" + with patch.object( + _inspector, "parse_etrecord", return_value=None + ), patch.object( + _inspector, "gen_etdump_object", return_value=None + ), patch.object( + EventBlock, "_gen_from_etdump" + ), patch.object( + _inspector, "gen_graphs_from_etrecord" + ): + inspector_instance = Inspector( + etdump_path=ETDUMP_PATH, + etrecord=ETRECORD_PATH, + ) + + # Create mock intermediate outputs (same structure as test_calculate_numeric_gap) + aot_intermediate_outputs = { + (0,): torch.tensor([1.0, 2.0, 3.0]), + } + runtime_intermediate_outputs = { + (0,): ([torch.tensor([2.0, 3.0, 4.0])], 1), + } + + aot_debug_handle_to_op_name = {(0,): "op_0"} + runtime_debug_handle_to_op_name = {(0,): "op_0"} + + # Mock the internal methods like test_calculate_numeric_gap does + inspector_instance._get_aot_intermediate_outputs_and_op_names = ( + lambda x, y: ( + aot_intermediate_outputs, + aot_debug_handle_to_op_name, + ) + ) + inspector_instance._get_runtime_intermediate_outputs_and_op_names = ( + lambda: (runtime_intermediate_outputs, runtime_debug_handle_to_op_name) + ) + + # Test with edge_dialect_exported_program parameter + df = inspector_instance.calculate_numeric_gap( + distance="L1", + reference_graph="edge_dialect_exported_program", + ) + + self.assertIsInstance(df, pd.DataFrame) + self.assertEqual(len(df), 1) + @unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows") def test_transformer_block_xnnpack_numeric_gap_within_tolerance(self): """