From 3a66602eb2d7c1806cf787b36bec66cf2cef5b1d Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Thu, 26 Jun 2025 00:47:05 -0700 Subject: [PATCH] [et] generate debug handle before opeartor decomposition This diff update the debug handle generation, from each node in the edge program having a individual debug handle, to all nodes having a same ancestor in export graph sharing a same debug handle, which update the start point of tracing our node transformation from edge graph to exported graph. Differential Revision: [D76860368](https://our.internmc.facebook.com/intern/diff/D76860368/) [ghstack-poisoned] --- devtools/inspector/_inspector.py | 25 +++++---- devtools/inspector/_inspector_utils.py | 23 ++++---- devtools/inspector/tests/inspector_test.py | 6 ++- .../inspector/tests/inspector_test_utils.py | 30 ++++------- .../inspector/tests/inspector_utils_test.py | 16 ++++-- exir/passes/debug_handle_generator_pass.py | 53 ++++++++++++++++--- exir/tests/test_passes.py | 24 ++++++--- 7 files changed, 121 insertions(+), 56 deletions(-) diff --git a/devtools/inspector/_inspector.py b/devtools/inspector/_inspector.py index dfff3d0818e..df98da8bf3d 100644 --- a/devtools/inspector/_inspector.py +++ b/devtools/inspector/_inspector.py @@ -654,7 +654,7 @@ def _populate_debugging_related_fields( def _associate_with_op_graph_nodes( self, - debug_handle_to_op_node_map: Dict[int, OperatorNode], + debug_handle_to_op_node_map: Dict[int, List[OperatorNode]], ) -> None: """ Helper function to populate the stack_traces, module_hierarchy and op_types attributes @@ -672,14 +672,21 @@ def _associate_with_op_graph_nodes( debug_handles = [debug_handles] for handle in debug_handles: - node = debug_handle_to_op_node_map.get(handle) - # Attach node metadata including stack traces, module hierarchy and op_types to this event - if node is not None and (metadata := node.metadata) is not None: - self.stack_traces[node.name] = metadata.get("stack_trace") - self.module_hierarchy[node.name] = metadata.get("nn_module_stack") - if node.op: - # TODO: consider having this as a dict from node.name -> node.op - self.op_types += [node.op] + nodes = debug_handle_to_op_node_map.get(handle, None) + if nodes is None: + continue + + for node in nodes: + # Attach node metadata including stack traces, module hierarchy and op_types to this event + if node is not None and (metadata := node.metadata) is not None: + if node.name not in self.stack_traces: + self.stack_traces[node.name] = metadata.get("stack_trace") + self.module_hierarchy[node.name] = metadata.get( + "nn_module_stack" + ) + if node.op: + # TODO: consider having this as a dict from node.name -> node.op + self.op_types += [node.op] @dataclass diff --git a/devtools/inspector/_inspector_utils.py b/devtools/inspector/_inspector_utils.py index 61e2ea4d031..08c2e2b5c91 100644 --- a/devtools/inspector/_inspector_utils.py +++ b/devtools/inspector/_inspector_utils.py @@ -279,14 +279,18 @@ def gen_graphs_from_etrecord( return op_graph_map +# One debug handle should only be associated with one node. We are in the middle of migrating debug handle generation +# from graph after to_edge to graph after torch.export, one every debug handle in exported graph may be associated with multiple nodes in to_edge +# graph. After fully migration, we should bring the bring type as well as the #node check back. +# TODO(gasoonjia): recover the return type to Dict[int, List[OperatorNode], reenable the #node check. def create_debug_handle_to_op_node_mapping( op_graph: OperatorGraph, -) -> Dict[int, OperatorNode]: +) -> Dict[int, List[OperatorNode]]: """ Recursive function to traverse all the operator graph nodes of input op_graph and build a mapping from each debug handle to the operator node that contains the debug handle in its metadata. """ - debug_handle_to_op_node_map: Dict[int, OperatorNode] = {} + debug_handle_to_op_node_map: Dict[int, List[OperatorNode]] = {} # Recursively searches through the metadata of nodes def _extract_debug_handles(graph: OperatorGraph): @@ -296,14 +300,13 @@ def _extract_debug_handles(graph: OperatorGraph): if isinstance(element, OperatorNode) and element.metadata is not None: metadata = element.metadata debug_handle = metadata.get("debug_handle") - if debug_handle is not None: - existing_entry = debug_handle_to_op_node_map.get(debug_handle) - if existing_entry is not None: - raise ValueError( - f"Duplicated debug handle {str(debug_handle)} shared between {element.name} and {existing_entry.name}. " - "No two op nodes of the same graph should have the same debug handle." - ) - debug_handle_to_op_node_map[debug_handle] = element + if debug_handle is None: + continue + + if debug_handle not in debug_handle_to_op_node_map: + debug_handle_to_op_node_map[debug_handle] = [] + + debug_handle_to_op_node_map[debug_handle].append(element) # Start traversing _extract_debug_handles(op_graph) diff --git a/devtools/inspector/tests/inspector_test.py b/devtools/inspector/tests/inspector_test.py index 1460dbd46a2..896228d5334 100644 --- a/devtools/inspector/tests/inspector_test.py +++ b/devtools/inspector/tests/inspector_test.py @@ -182,7 +182,11 @@ def test_inspector_associate_with_op_graph_nodes_single_debug_handle(self): # Call the method that's under testing and verify event_with_single_debug_handle._associate_with_op_graph_nodes( - {debug_handle: node_0} + { + debug_handle: [ + node_0, + ] + } ) expected_stack_traces = {"node_0": "stack_trace_relu"} diff --git a/devtools/inspector/tests/inspector_test_utils.py b/devtools/inspector/tests/inspector_test_utils.py index b9d4b1882b8..86f828d3dc3 100644 --- a/devtools/inspector/tests/inspector_test_utils.py +++ b/devtools/inspector/tests/inspector_test_utils.py @@ -62,25 +62,17 @@ def get_expected_intermediate_outputs(): Returns the expected outputs of the debug handles and intermediate output mapping for this model for the given input. """ return { - (10,): torch.tensor([[[[7.7000, 6.7000], [4.7000, 3.7000]]]]), - (11,): torch.tensor([[7.7000, 6.7000, 4.7000, 3.7000]]), - (12,): torch.tensor( - [ - [0.1000, 0.5000], - [0.2000, 0.6000], - [0.3000, 0.7000], - [0.4000, 0.8000], - ] - ), - (13,): torch.tensor([[5.0000, 14.1200]]), - (14,): torch.tensor([[5.5000, 13.6200]]), - (15,): torch.tensor([[5.4000, 13.5200]]), - (16,): torch.tensor([[10.8000, 6.7600]]), - (17,): torch.tensor([3.0000, 1.5000]), - (18,): torch.tensor([[3.6000, 4.5067]]), - (19,): torch.tensor([[3.6000, 4.5067]]), - (20,): torch.tensor([[0.9734, 0.9891]]), - (21,): [torch.tensor([[0.9734]]), torch.tensor([[0.9891]])], + (1,): torch.tensor([[[[7.7000, 6.7000], [4.7000, 3.7000]]]]), + (2,): torch.tensor([[7.7000, 6.7000, 4.7000, 3.7000]]), + (3,): torch.tensor([[5.0000, 14.1200]]), + (4,): torch.tensor([[5.5000, 13.6200]]), + (5,): torch.tensor([[5.4000, 13.5200]]), + (6,): torch.tensor([[10.8000, 6.7600]]), + (7,): torch.tensor([3.0000, 1.5000]), + (8,): torch.tensor([[3.6000, 4.5067]]), + (9,): torch.tensor([[3.6000, 4.5067]]), + (10,): torch.tensor([[0.9734, 0.9891]]), + (11,): [torch.tensor([[0.9734]]), torch.tensor([[0.9891]])], } diff --git a/devtools/inspector/tests/inspector_utils_test.py b/devtools/inspector/tests/inspector_utils_test.py index 8148d2c36f0..74ddf8ae3f6 100644 --- a/devtools/inspector/tests/inspector_utils_test.py +++ b/devtools/inspector/tests/inspector_utils_test.py @@ -381,7 +381,9 @@ def gen_mock_operator_graph_with_expected_map() -> ( "nn_module_stack": "module_hierarchy_relu", }, ) - mapping[111] = node_fused_conv_relu + mapping[111] = [ + node_fused_conv_relu, + ] node_sin = OperatorNode( "sin", [node_fused_conv_relu], @@ -392,7 +394,9 @@ def gen_mock_operator_graph_with_expected_map() -> ( "nn_module_stack": "module_hierarchy_sin", }, ) - mapping[222] = node_sin + mapping[222] = [ + node_sin, + ] node_cos = OperatorNode( "cos", [node_sin], @@ -403,7 +407,9 @@ def gen_mock_operator_graph_with_expected_map() -> ( "nn_module_stack": "module_hierarchy_cos", }, ) - mapping[333] = node_cos + mapping[333] = [ + node_cos, + ] node_div = OperatorNode( "div", [node_cos], @@ -414,7 +420,9 @@ def gen_mock_operator_graph_with_expected_map() -> ( "nn_module_stack": "module_hierarchy_div", }, ) - mapping[444] = node_div + mapping[444] = [ + node_div, + ] node_output = ValueNode("output", [node_div]) return ( OperatorGraph( diff --git a/exir/passes/debug_handle_generator_pass.py b/exir/passes/debug_handle_generator_pass.py index 7de8676084b..f42d7d646c2 100644 --- a/exir/passes/debug_handle_generator_pass.py +++ b/exir/passes/debug_handle_generator_pass.py @@ -4,10 +4,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Dict + from executorch.exir.graph_module import bfs_trace_with_node_process from executorch.exir.pass_base import ExportPass from torch.export import ExportedProgram -from torch.fx import GraphModule +from torch.fx import GraphModule, Node from torch.fx.passes.infra.pass_base import PassResult @@ -17,18 +19,57 @@ def call(self, graph_module: GraphModule) -> PassResult: to executorch backend, that has a canonical set of quantized operators """ - index = 1 + FROM_NODE_KEY = "from_node" + DEBUG_HANDLE_KEY = "debug_handle" + + source_node_to_debug_handle: Dict[str, int] = {} + + def _get_greatest_ancestor_source_node(node: Node) -> str: + """Get the source of the greatest ancestor node of the given node. The source + here means the name of the node concated with the id the graph it belongs to. + For example, if the node transformation is node a -> b -> c, then the greatest + ancestor node of c is a. + """ + + node_source = node.meta[FROM_NODE_KEY] + node_source = node_source[-1] + + while len(node_source.from_node) > 0: + node_source = node_source.from_node[-1] + + return node_source.name + str(node_source.graph_id) + + def _extract_debug_handles_from_node(node: Node) -> None: + """ + Generate a debug handle based on node's oldest ancestor node's name + and graph id, or return None if the node does not need to be traced. + """ + + if node.op == "placeholder" or node.op == "output": + # placeholder and output nodes don't have debug handle + return + + assert ( + FROM_NODE_KEY in node.meta + ), f"Node {node} does not have meta key {FROM_NODE_KEY}" + + source_node = _get_greatest_ancestor_source_node(node) + + debug_handle = ( + len(source_node_to_debug_handle) + 1 + if source_node not in source_node_to_debug_handle + else source_node_to_debug_handle[source_node] + ) + source_node_to_debug_handle[source_node] = debug_handle - def _extract_debug_handles_from_node(node): - nonlocal index - node.meta["debug_handle"] = index - index += 1 + node.meta[DEBUG_HANDLE_KEY] = debug_handle bfs_trace_with_node_process(graph_module, _extract_debug_handles_from_node) return PassResult(graph_module, True) +# TODO(gasoonjia): generate missing debug handles using `from_node` info def generate_missing_debug_handles(ep: ExportedProgram): """ This pass is used to generate missing debug handles for the graph module and its submodules. diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py index dd4037b64c0..3077447bdd5 100644 --- a/exir/tests/test_passes.py +++ b/exir/tests/test_passes.py @@ -859,11 +859,16 @@ def test_debug_handle_generator_pass(self) -> None: .exported_program() .graph_module ) + + # Every node except input and output should have debug handle for node in graph_module.graph.nodes: - self.assertIn("debug_handle", node.meta) + if node.op != "placeholder" and node.op != "output": + self.assertIn("debug_handle", node.meta) ScalarToTensorPass()(graph_module) + for node in graph_module.graph.nodes: - self.assertIn("debug_handle", node.meta) + if node.op != "placeholder" and node.op != "output": + self.assertIn("debug_handle", node.meta) def test_generate_missing_debug_handles(self) -> None: eager_model = MLP(2, output_size=4) @@ -871,10 +876,15 @@ def test_generate_missing_debug_handles(self) -> None: ep = to_edge(export(eager_model, inputs, strict=True)).exported_program() - list(ep.graph.nodes)[0].meta.pop("debug_handle") - self.assertTrue(list(ep.graph.nodes)[0].meta.get("debug_handle") is None) + # get the first non-placeholder node + first_non_placeholder_node = [ + n for n in ep.graph.nodes if n.op != "placeholder" + ][0] + + first_non_placeholder_node.meta.pop("debug_handle") + self.assertTrue(first_non_placeholder_node.meta.get("debug_handle") is None) generate_missing_debug_handles(ep) - self.assertTrue(list(ep.graph.nodes)[0].meta.get("debug_handle") is not None) + self.assertTrue(first_non_placeholder_node.meta.get("debug_handle") is not None) def test_debug_handle_generator_pass_with_control_flow(self) -> None: def true_nested(y: torch.Tensor) -> torch.Tensor: @@ -928,7 +938,8 @@ def check_debug_handle_metadata(graph_module: torch.fx.GraphModule) -> None: while queue: current_graph_module = queue.pop(0) for node in current_graph_module.graph.nodes: - self.assertIn("debug_handle", node.meta) + if node.op != "placeholder" and node.op != "output": + self.assertIn("debug_handle", node.meta) control_flow_submodules = [ submodule for _, submodule, _ in get_control_flow_submodules( @@ -939,7 +950,6 @@ def check_debug_handle_metadata(graph_module: torch.fx.GraphModule) -> None: DebugHandleGeneratorPass()(graph_module) check_debug_handle_metadata(graph_module) - generate_missing_debug_handles(ep) # Check debug handle still preserved after ScalarToTensorPass ScalarToTensorPass()(graph_module)