diff --git a/devtools/etrecord/tests/etrecord_test.py b/devtools/etrecord/tests/etrecord_test.py index 11463a976b4..a57515bffee 100644 --- a/devtools/etrecord/tests/etrecord_test.py +++ b/devtools/etrecord/tests/etrecord_test.py @@ -1819,3 +1819,122 @@ def test_multi_method_etrecord_generation(self): # Verify other ETRecord components are preserved self.assertIsNotNone(parsed_etrecord._debug_handle_map) self.assertIsNotNone(parsed_etrecord._delegate_map) + + def test_edge_after_transform_graph_capture(self): + """Test that to_edge_transform_and_lower with transform_passes captures the after-transform graph. + + When generate_etrecord=True and transform_passes are applied, the ETRecord should + contain the after-transform graph under the key 'edge_after_transform' in graph_map. + This enables backends like Qualcomm to use the post-custom-transform graph as the + golden reference for numeric gap calculation. + """ + from torch.fx.passes.infra.pass_base import PassBase, PassResult + + # Create a simple custom pass that modifies the graph + class SimpleCustomPass(PassBase): + """A simple pass that adds a marker attribute to each node.""" + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + # Mark each node to indicate this pass ran + for node in graph_module.graph.nodes: + node.meta["custom_pass_applied"] = True + return PassResult(graph_module=graph_module, modified=True) + + f = models.BasicSinMax() + aten_dialect = export(f, f.get_random_inputs(), strict=True) + + # Create edge program with custom transform pass and generate_etrecord=True + transform_passes = [SimpleCustomPass()] + + edge_manager = to_edge_transform_and_lower( + aten_dialect, + compile_config=EdgeCompileConfig(_check_ir_validity=False), + transform_passes=transform_passes, + generate_etrecord=True, + ) + + # Verify that ETRecord was generated + self.assertIsNotNone(edge_manager._etrecord) + etrecord = edge_manager._etrecord + + # Verify graph_map exists and contains the 'edge_after_transform' key + self.assertIsNotNone(etrecord.graph_map) + self.assertIn( + "edge_after_transform/forward", + etrecord.graph_map, + "graph_map should contain 'edge_after_transform/forward' when transform_passes are applied", + ) + + # Verify the captured graph has the custom pass marker + after_transform_graph = etrecord.graph_map["edge_after_transform/forward"] + self.assertIsNotNone(after_transform_graph) + + # Check that at least one node has the custom_pass_applied marker + has_marker = False + for node in after_transform_graph.graph.nodes: + if node.meta.get("custom_pass_applied", False): + has_marker = True + break + + self.assertTrue( + has_marker, + "The edge_after_transform graph should have the custom pass marker applied", + ) + + # Verify edge_dialect_program is still the pre-transform graph (original behavior preserved) + self.assertIsNotNone(etrecord.edge_dialect_program) + + # Save and parse the ETRecord to verify persistence + et_output = edge_manager.to_executorch() + + with tempfile.TemporaryDirectory() as tmpdirname: + etrecord_path = tmpdirname + "/etrecord_custom_pass.bin" + + # Get ETRecord and save + complete_etrecord = et_output.get_etrecord() + complete_etrecord.save(etrecord_path) + + # Parse ETRecord back + parsed_etrecord = parse_etrecord(etrecord_path) + + # Verify the after-transform graph is preserved after save/parse + self.assertIsNotNone(parsed_etrecord.graph_map) + self.assertIn( + "edge_after_transform/forward", + parsed_etrecord.graph_map, + "Parsed ETRecord should still contain 'edge_after_transform/forward'", + ) + + # Verify the parsed graph still has the marker + parsed_after_transform_graph = parsed_etrecord.graph_map[ + "edge_after_transform/forward" + ] + self.assertIsNotNone(parsed_after_transform_graph) + + def test_no_edge_after_transform_without_transform_passes(self): + """Test that 'edge_after_transform' is NOT added when no transform_passes are provided. + + This ensures backward compatibility - when generate_etrecord=True but no transform_passes + are applied, the ETRecord should NOT have an 'edge_after_transform' entry. + """ + f = models.BasicSinMax() + aten_dialect = export(f, f.get_random_inputs(), strict=True) + + # Create edge program WITHOUT transform_passes + edge_manager = to_edge_transform_and_lower( + aten_dialect, + compile_config=EdgeCompileConfig(_check_ir_validity=False), + generate_etrecord=True, + ) + + # Verify that ETRecord was generated + self.assertIsNotNone(edge_manager._etrecord) + etrecord = edge_manager._etrecord + + # Verify that 'edge_after_transform' is NOT in graph_map + if etrecord.graph_map is not None: + self.assertNotIn( + "edge_after_transform/forward", + etrecord.graph_map, + "graph_map should NOT contain 'edge_after_transform/forward' when no transform_passes are applied", + ) diff --git a/devtools/inspector/_inspector.py b/devtools/inspector/_inspector.py index 368824f71a3..236597f2bd6 100644 --- a/devtools/inspector/_inspector.py +++ b/devtools/inspector/_inspector.py @@ -1166,31 +1166,94 @@ def _consume_etrecord(self) -> None: def _get_aot_intermediate_outputs_and_op_names( self, + reference_graph: Optional[str] = None, disable_debug_handle_valdiation: bool = False, ) -> Tuple[Dict[DebugHandle, Any], Dict[DebugHandle, List[str]]]: """ Capture intermediate outputs only if _representative_inputs are provided - when using bundled program to create the etrecord - """ - if self._etrecord._representative_inputs is None: - return {}, {} + when using bundled program to create the etrecord. - export_program = None + Args: + reference_graph_name: Name of the graph to use as the reference for intermediate + output capture. Must be one of: + - "exported_program": Uses the ATen dialect exported program. Requires + successful debug handle backpropagation, otherwise raises an error. + - "edge_dialect_exported_program": Uses the Edge dialect program directly. + - Any other string: Fetches from graph_map (e.g., "edge_after_transform/forward" + for post-custom-transform graph when transform_passes are applied in + to_edge_transform_and_lower with generate_etrecord=True). + disable_debug_handle_valdiation: If True, skip debug handle validation. - # Will use the exported program to extract intermediate output if and only if exported_program has been provided, and it is one of the ancestors of the edge_dialect_program - if self._etrecord.exported_program and propagate_back_debug_handle( - self._etrecord.exported_program, - self._etrecord.export_graph_id, - self._etrecord.edge_dialect_program, - disable_debug_handle_valdiation, - ): + Returns: + Tuple of (intermediate_outputs, debug_handle_to_op_names) dictionaries. + + Raises: + ValueError: If the specified reference_graph_name is not available or if + debug handle backpropagation fails for "exported_program". + """ + + # Determine the reference graph to use + if reference_graph is None or reference_graph == "exported_program": + # Auto-select: try exported_program first, fall back to edge_dialect_exported_program + if self._etrecord.exported_program and propagate_back_debug_handle( + self._etrecord.exported_program, + self._etrecord.export_graph_id, + self._etrecord.edge_dialect_program, + disable_debug_handle_valdiation, + ): + reference_graph = "exported_program" + elif reference_graph is None: + log.warning( + "Either ATen dialect exported program is not in ETRecord, or debug handle " + "backpropagation failed. Falling back to 'edge_dialect_exported_program'." + ) + reference_graph = "edge_dialect_exported_program" + else: + raise ValueError( + "Cannot use 'exported_program': Debug handle backpropagation failed or exported program is unavailable. " + "Please check if the exported program is available in ETRecord, or try to disable debug handle validation." + ) + if reference_graph == "edge_dialect_exported_program": + # Explicitly requested edge_dialect_exported_program + export_program = self._etrecord.edge_dialect_program + log.info( + "Using 'edge_dialect_exported_program' (Edge dialect) as reference graph for intermediate output capture" + ) + elif reference_graph == "exported_program": export_program = self._etrecord.exported_program - else: - log.warning( - "Either aten dialect exported program is not in ETRecord, or it is not one of the ancestors of current edge dialect program." - "Will fall back to use edge dialect program to extract intermediate output", + log.info( + "Using 'exported_program' (ATen dialect) as reference graph for intermediate output capture" ) - export_program = self._etrecord.edge_dialect_program + else: + # Try to fetch from graph_map + # If no method name is provided (no "/" in the name), try adding "/forward" as default + lookup_name = reference_graph + if "/" not in reference_graph: + lookup_name = f"{reference_graph}/forward" + log.info( + f"No method name specified in '{reference_graph}', " + f"using '{lookup_name}' as default" + ) + + if ( + self._etrecord.graph_map is not None + and lookup_name in self._etrecord.graph_map + ): + export_program = self._etrecord.graph_map[lookup_name] + log.info( + f"Using '{lookup_name}' from graph_map as reference graph for intermediate output capture" + ) + else: + available_graphs = ( + list(self._etrecord.graph_map.keys()) + if self._etrecord.graph_map + else [] + ) + raise ValueError( + f"Reference graph '{lookup_name}' not found. " + f"Available options: 'exported_program', 'edge_dialect_exported_program', " + f"or one of the graphs in graph_map: {available_graphs}" + ) graph_module = export_program.module() aot_debug_handle_to_op_name = get_aot_debug_handle_to_op_name_mapping( graph_module @@ -1406,11 +1469,11 @@ def calculate_numeric_gap( self, distance: Union[str, NumericalComparatorBase], disable_debug_handle_valdiation: bool = False, + reference_graph: Optional[str] = None, ): """ Compares logged intermediate outputs from the exported graph (in ETRecord) with runtime outputs (in ETDump) using a user-specific numerical comparator. - If the exported graph is not supported, the function will fall back to use edge dialect graph. To use this function, you must first generate the ETRecord with representative inputs, and then create the Inspector instance with the ETRecord and ETDump. The Inspector can then @@ -1423,18 +1486,31 @@ def calculate_numeric_gap( logic by subclassing NumericalComparatorBase and implementing the element_compare() method. Custom comparators can also override the preprocessing() method to apply transformations (e.g., layout conversion, dequantization) before comparison. - disable_debug_handle_validation: Often when aten graph has symbolic shape nodes and inbuilt ops like gt/lt etc., + disable_debug_handle_valdiation: Often when aten graph has symbolic shape nodes and inbuilt ops like gt/lt etc., during re-export of such a graph 'from_node' information is lost from node.meta. As a result we loose connection between edge IR nodes and aten nodes for such ops. By default we validate that every edge IR node has corresponding node in aten IR, and when such validation fails numeric debugger falls back to edge IR as reference graph. This flag allows one to override such behavior and make best effort comparison. + reference_graph: Name of the graph to use as the golden reference for intermediate output capture. + Must be one of: + - "exported_program": Uses the ATen dialect exported program. Requires successful debug + handle backpropagation, otherwise raises an error. + - "edge_dialect_exported_program": Uses the Edge dialect program directly. + - Any other string: Fetches from graph_map (e.g., "edge_after_transform/forward" for + post-custom-transform graph when transform_passes are applied in to_edge_transform_and_lower + with generate_etrecord=True). + + If None (default), automatically selects the best available graph: + - Uses "exported_program" if available and debug handle backpropagation succeeds. + - Falls back to "edge_dialect_exported_program" otherwise. Returns: pd.DataFrame: A DataFrame listing corresponding operator intermediate outputs from both stages and their computed numerical gaps. """ aot_intermediate_outputs, aot_debug_handle_to_op_names = ( self._get_aot_intermediate_outputs_and_op_names( - disable_debug_handle_valdiation + reference_graph, + disable_debug_handle_valdiation, ) ) if len(aot_intermediate_outputs) == 0 or len(aot_debug_handle_to_op_names) == 0: diff --git a/devtools/inspector/tests/inspector_test.py b/devtools/inspector/tests/inspector_test.py index ba199e470a8..f7050593cc2 100644 --- a/devtools/inspector/tests/inspector_test.py +++ b/devtools/inspector/tests/inspector_test.py @@ -17,7 +17,7 @@ from typing import Callable, List, Union -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pandas as pd @@ -682,9 +682,11 @@ def test_calculate_numeric_gap(self): aot_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"} runtime_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"} - inspector_instance._get_aot_intermediate_outputs_and_op_names = lambda x: ( - aot_intermediate_outputs, - aot_debug_handle_to_op_name, + inspector_instance._get_aot_intermediate_outputs_and_op_names = ( + lambda x, y: ( + aot_intermediate_outputs, + aot_debug_handle_to_op_name, + ) ) inspector_instance._get_runtime_intermediate_outputs_and_op_names = ( lambda: (runtime_intermediate_outputs, runtime_debug_handle_to_op_name) @@ -764,9 +766,11 @@ def element_compare(self, a, b): aot_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"} runtime_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"} - inspector_instance._get_aot_intermediate_outputs_and_op_names = lambda x: ( - aot_intermediate_outputs, - aot_debug_handle_to_op_name, + inspector_instance._get_aot_intermediate_outputs_and_op_names = ( + lambda x, y: ( + aot_intermediate_outputs, + aot_debug_handle_to_op_name, + ) ) inspector_instance._get_runtime_intermediate_outputs_and_op_names = ( lambda: (runtime_intermediate_outputs, runtime_debug_handle_to_op_name) @@ -891,9 +895,11 @@ def preprocessing( aot_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"} runtime_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"} - inspector_instance._get_aot_intermediate_outputs_and_op_names = lambda x: ( - aot_intermediate_outputs, - aot_debug_handle_to_op_name, + inspector_instance._get_aot_intermediate_outputs_and_op_names = ( + lambda x, y: ( + aot_intermediate_outputs, + aot_debug_handle_to_op_name, + ) ) inspector_instance._get_runtime_intermediate_outputs_and_op_names = ( lambda: (runtime_intermediate_outputs, runtime_debug_handle_to_op_name) @@ -901,7 +907,9 @@ def preprocessing( # --- Test 1: MSE comparator with scaling preprocessing --- mse_comparator = MSEComparatorWithScaling(scale_factor=2.0) - df_mse = inspector_instance.calculate_numeric_gap(distance=mse_comparator) + df_mse = inspector_instance.calculate_numeric_gap( + distance=mse_comparator, reference_graph="NOT_USED_NAME" + ) # Verify preprocessing was called self.assertTrue(mse_comparator.preprocessing_called) @@ -936,7 +944,9 @@ def preprocessing( # --- Test 2: SNR comparator with the same scaling preprocessing --- snr_comparator = SNRComparatorWithScaling(scale_factor=2.0) - df_snr = inspector_instance.calculate_numeric_gap(distance=snr_comparator) + df_snr = inspector_instance.calculate_numeric_gap( + distance=snr_comparator, reference_graph="NOT_USED_NAME" + ) # Verify preprocessing was called self.assertTrue(snr_comparator.preprocessing_called) @@ -1048,9 +1058,11 @@ def element_compare(self, a, b) -> float: aot_debug_handle_to_op_name = {(0,): "op_0"} runtime_debug_handle_to_op_name = {(0,): "op_0"} - inspector_instance._get_aot_intermediate_outputs_and_op_names = lambda x: ( - aot_intermediate_outputs, - aot_debug_handle_to_op_name, + inspector_instance._get_aot_intermediate_outputs_and_op_names = ( + lambda x, y: ( + aot_intermediate_outputs, + aot_debug_handle_to_op_name, + ) ) inspector_instance._get_runtime_intermediate_outputs_and_op_names = ( lambda: (runtime_intermediate_outputs, runtime_debug_handle_to_op_name) @@ -1091,6 +1103,225 @@ def element_compare(self, a, b) -> float: ) self.assertIn("Invalid runtime debug handle", str(context.exception)) + def test_calculate_numeric_gap_with_reference_graph_name(self): + """Test calculate_numeric_gap with the reference_graph parameter using a custom graph from graph_map.""" + # Create a context manager to patch functions called by Inspector.__init__ + with patch.object( + _inspector, "parse_etrecord", return_value=None + ), patch.object( + _inspector, "gen_etdump_object", return_value=None + ), patch.object( + EventBlock, "_gen_from_etdump" + ), patch.object( + _inspector, "gen_graphs_from_etrecord" + ): + inspector_instance = Inspector( + etdump_path=ETDUMP_PATH, + etrecord=ETRECORD_PATH, + ) + + # Create mock intermediate outputs + aot_intermediate_outputs = { + (0,): torch.tensor([1.0, 2.0, 3.0]), + (1,): torch.tensor([4.0, 5.0, 6.0]), + } + runtime_intermediate_outputs = { + (0,): ([torch.tensor([2.0, 3.0, 4.0])], 1), + (1,): ([torch.tensor([5.0, 6.0, 7.0])], 1), + } + + aot_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"} + runtime_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"} + + # Create a mock graph module for the reference graph + class MockGraphModule: + def __init__(self): + self.graph = MagicMock() + self.graph.nodes = [] + + def module(self): + return self + + mock_graph_module = MockGraphModule() + + # Create a real ETRecord and set up the graph_map with edge_after_transform + from executorch.devtools.etrecord import ETRecord + + mock_etrecord = ETRecord() + mock_etrecord._representative_inputs = torch.tensor([1.0]) + mock_etrecord.exported_program = None + mock_etrecord.edge_dialect_program = mock_graph_module + + # The code adds "/forward" suffix when looking up, so we need "edge_after_transform/forward" + mock_etrecord.graph_map = { + "edge_after_transform/forward": mock_graph_module + } + + inspector_instance._etrecord = mock_etrecord + + # Mock the runtime intermediate outputs + inspector_instance._get_runtime_intermediate_outputs_and_op_names = ( + lambda: (runtime_intermediate_outputs, runtime_debug_handle_to_op_name) + ) + + # Mock IntermediateOutputCapturer and get_aot_debug_handle_to_op_name_mapping + # These are called inside _get_aot_intermediate_outputs_and_op_names when using a custom graph + with patch( + "executorch.devtools.inspector._inspector.IntermediateOutputCapturer" + ) as mock_capturer_class, patch( + "executorch.devtools.inspector._inspector.get_aot_debug_handle_to_op_name_mapping" + ) as mock_get_mapping: + mock_capturer = MagicMock() + mock_capturer.run_and_capture.return_value = aot_intermediate_outputs + mock_capturer_class.return_value = mock_capturer + mock_get_mapping.return_value = aot_debug_handle_to_op_name + + # Test with reference_graph parameter (without /forward suffix) + # The code should automatically add "/forward" when looking up in graph_map + df = inspector_instance.calculate_numeric_gap( + distance="L1", + reference_graph="edge_after_transform", + ) + + self.assertIsInstance(df, pd.DataFrame) + self.assertEqual(len(df), 2) + + def test_calculate_numeric_gap_with_invalid_reference_graph_name(self): + """Test that calculate_numeric_gap raises ValueError for invalid reference_graph.""" + with patch.object( + _inspector, "parse_etrecord", return_value=None + ), patch.object( + _inspector, "gen_etdump_object", return_value=None + ), patch.object( + EventBlock, "_gen_from_etdump" + ), patch.object( + _inspector, "gen_graphs_from_etrecord" + ): + inspector_instance = Inspector( + etdump_path=ETDUMP_PATH, + etrecord=ETRECORD_PATH, + ) + + # Create a real ETRecord with empty graph_map + from executorch.devtools.etrecord import ETRecord + + mock_etrecord = ETRecord() + mock_etrecord._representative_inputs = torch.tensor([1.0]) + mock_etrecord.graph_map = {} + + inspector_instance._etrecord = mock_etrecord + + # Test with non-existent reference_graph + # Since "non_existent_graph" has no "/", it will be looked up as "non_existent_graph/forward" + with self.assertRaises(ValueError) as context: + inspector_instance.calculate_numeric_gap( + distance="L1", + reference_graph="non_existent_graph", + ) + self.assertIn("not found", str(context.exception)) + self.assertIn("non_existent_graph/forward", str(context.exception)) + + def test_calculate_numeric_gap_with_exported_program_name_backprop_failure(self): + """Test that calculate_numeric_gap raises ValueError when exported_program backpropagation fails.""" + with patch.object( + _inspector, "parse_etrecord", return_value=None + ), patch.object( + _inspector, "gen_etdump_object", return_value=None + ), patch.object( + EventBlock, "_gen_from_etdump" + ), patch.object( + _inspector, "gen_graphs_from_etrecord" + ): + inspector_instance = Inspector( + etdump_path=ETDUMP_PATH, + etrecord=ETRECORD_PATH, + ) + + # Create mock graph modules + class MockGraphModule: + def __init__(self): + self.graph = MagicMock() + + def module(self): + return self + + mock_exported_program = MockGraphModule() + mock_edge_dialect_program = MockGraphModule() + + # Create a real ETRecord with exported_program + from executorch.devtools.etrecord import ETRecord + + mock_etrecord = ETRecord() + mock_etrecord._representative_inputs = torch.tensor([1.0]) + mock_etrecord.exported_program = mock_exported_program + mock_etrecord.edge_dialect_program = mock_edge_dialect_program + mock_etrecord.export_graph_id = "graph_id" + mock_etrecord.graph_map = {} + + inspector_instance._etrecord = mock_etrecord + + # Mock propagate_back_debug_handle to return False (backpropagation failure) + with patch( + "executorch.devtools.inspector._inspector.propagate_back_debug_handle" + ) as mock_propagate: + mock_propagate.return_value = False + + # Test with "exported_program" should raise error when backpropagation fails + with self.assertRaises(ValueError) as context: + inspector_instance.calculate_numeric_gap( + distance="L1", + reference_graph="exported_program", + ) + self.assertIn("Cannot use 'exported_program'", str(context.exception)) + self.assertIn("backpropagation failed", str(context.exception)) + + def test_calculate_numeric_gap_with_edge_dialect_exported_program_name(self): + """Test calculate_numeric_gap with edge_dialect_exported_program reference_graph parameter.""" + with patch.object( + _inspector, "parse_etrecord", return_value=None + ), patch.object( + _inspector, "gen_etdump_object", return_value=None + ), patch.object( + EventBlock, "_gen_from_etdump" + ), patch.object( + _inspector, "gen_graphs_from_etrecord" + ): + inspector_instance = Inspector( + etdump_path=ETDUMP_PATH, + etrecord=ETRECORD_PATH, + ) + + # Create mock intermediate outputs (same structure as test_calculate_numeric_gap) + aot_intermediate_outputs = { + (0,): torch.tensor([1.0, 2.0, 3.0]), + } + runtime_intermediate_outputs = { + (0,): ([torch.tensor([2.0, 3.0, 4.0])], 1), + } + + aot_debug_handle_to_op_name = {(0,): "op_0"} + runtime_debug_handle_to_op_name = {(0,): "op_0"} + + # Mock the internal methods like test_calculate_numeric_gap does + inspector_instance._get_aot_intermediate_outputs_and_op_names = ( + lambda x, y: ( + aot_intermediate_outputs, + aot_debug_handle_to_op_name, + ) + ) + inspector_instance._get_runtime_intermediate_outputs_and_op_names = ( + lambda: (runtime_intermediate_outputs, runtime_debug_handle_to_op_name) + ) + + # Test with edge_dialect_exported_program parameter + df = inspector_instance.calculate_numeric_gap( + distance="L1", + reference_graph="edge_dialect_exported_program", + ) + + self.assertIsInstance(df, pd.DataFrame) + self.assertEqual(len(df), 1) + @unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows") def test_transformer_block_xnnpack_numeric_gap_within_tolerance(self): """ diff --git a/exir/program/_program.py b/exir/program/_program.py index abf413918e5..baacd5eaec4 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -1388,6 +1388,11 @@ def to_edge_transform_and_lower( # noqa: C901 if transform_passes is not None: edge_manager = edge_manager.transform(transform_passes) + if generate_etrecord: + edge_manager._etrecord.add_extra_export_modules( + {"edge_after_transform": copy.deepcopy(edge_manager)} + ) + max_num_partitioners = 0 for partitioner_list in partitioner.values(): max_num_partitioners = max(max_num_partitioners, len(partitioner_list))