Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions devtools/etrecord/tests/etrecord_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1819,3 +1819,122 @@ def test_multi_method_etrecord_generation(self):
# Verify other ETRecord components are preserved
self.assertIsNotNone(parsed_etrecord._debug_handle_map)
self.assertIsNotNone(parsed_etrecord._delegate_map)

def test_edge_after_transform_graph_capture(self):
"""Test that to_edge_transform_and_lower with transform_passes captures the after-transform graph.

When generate_etrecord=True and transform_passes are applied, the ETRecord should
contain the after-transform graph under the key 'edge_after_transform' in graph_map.
This enables backends like Qualcomm to use the post-custom-transform graph as the
golden reference for numeric gap calculation.
"""
from torch.fx.passes.infra.pass_base import PassBase, PassResult

# Create a simple custom pass that modifies the graph
class SimpleCustomPass(PassBase):
"""A simple pass that adds a marker attribute to each node."""

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
# Mark each node to indicate this pass ran
for node in graph_module.graph.nodes:
node.meta["custom_pass_applied"] = True
return PassResult(graph_module=graph_module, modified=True)

f = models.BasicSinMax()
aten_dialect = export(f, f.get_random_inputs(), strict=True)

# Create edge program with custom transform pass and generate_etrecord=True
transform_passes = [SimpleCustomPass()]

edge_manager = to_edge_transform_and_lower(
aten_dialect,
compile_config=EdgeCompileConfig(_check_ir_validity=False),
transform_passes=transform_passes,
generate_etrecord=True,
)

# Verify that ETRecord was generated
self.assertIsNotNone(edge_manager._etrecord)
etrecord = edge_manager._etrecord

# Verify graph_map exists and contains the 'edge_after_transform' key
self.assertIsNotNone(etrecord.graph_map)
self.assertIn(
"edge_after_transform/forward",
etrecord.graph_map,
"graph_map should contain 'edge_after_transform/forward' when transform_passes are applied",
)

# Verify the captured graph has the custom pass marker
after_transform_graph = etrecord.graph_map["edge_after_transform/forward"]
self.assertIsNotNone(after_transform_graph)

# Check that at least one node has the custom_pass_applied marker
has_marker = False
for node in after_transform_graph.graph.nodes:
if node.meta.get("custom_pass_applied", False):
has_marker = True
break

self.assertTrue(
has_marker,
"The edge_after_transform graph should have the custom pass marker applied",
)

# Verify edge_dialect_program is still the pre-transform graph (original behavior preserved)
self.assertIsNotNone(etrecord.edge_dialect_program)

# Save and parse the ETRecord to verify persistence
et_output = edge_manager.to_executorch()

with tempfile.TemporaryDirectory() as tmpdirname:
etrecord_path = tmpdirname + "/etrecord_custom_pass.bin"

# Get ETRecord and save
complete_etrecord = et_output.get_etrecord()
complete_etrecord.save(etrecord_path)

# Parse ETRecord back
parsed_etrecord = parse_etrecord(etrecord_path)

# Verify the after-transform graph is preserved after save/parse
self.assertIsNotNone(parsed_etrecord.graph_map)
self.assertIn(
"edge_after_transform/forward",
parsed_etrecord.graph_map,
"Parsed ETRecord should still contain 'edge_after_transform/forward'",
)

# Verify the parsed graph still has the marker
parsed_after_transform_graph = parsed_etrecord.graph_map[
"edge_after_transform/forward"
]
self.assertIsNotNone(parsed_after_transform_graph)

def test_no_edge_after_transform_without_transform_passes(self):
"""Test that 'edge_after_transform' is NOT added when no transform_passes are provided.

This ensures backward compatibility - when generate_etrecord=True but no transform_passes
are applied, the ETRecord should NOT have an 'edge_after_transform' entry.
"""
f = models.BasicSinMax()
aten_dialect = export(f, f.get_random_inputs(), strict=True)

# Create edge program WITHOUT transform_passes
edge_manager = to_edge_transform_and_lower(
aten_dialect,
compile_config=EdgeCompileConfig(_check_ir_validity=False),
generate_etrecord=True,
)

# Verify that ETRecord was generated
self.assertIsNotNone(edge_manager._etrecord)
etrecord = edge_manager._etrecord

# Verify that 'edge_after_transform' is NOT in graph_map
if etrecord.graph_map is not None:
self.assertNotIn(
"edge_after_transform/forward",
etrecord.graph_map,
"graph_map should NOT contain 'edge_after_transform/forward' when no transform_passes are applied",
)
116 changes: 96 additions & 20 deletions devtools/inspector/_inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1166,31 +1166,94 @@ def _consume_etrecord(self) -> None:

def _get_aot_intermediate_outputs_and_op_names(
self,
reference_graph: Optional[str] = None,
disable_debug_handle_valdiation: bool = False,
) -> Tuple[Dict[DebugHandle, Any], Dict[DebugHandle, List[str]]]:
"""
Capture intermediate outputs only if _representative_inputs are provided
when using bundled program to create the etrecord
"""
if self._etrecord._representative_inputs is None:
return {}, {}
when using bundled program to create the etrecord.

export_program = None
Args:
reference_graph_name: Name of the graph to use as the reference for intermediate
output capture. Must be one of:
- "exported_program": Uses the ATen dialect exported program. Requires
successful debug handle backpropagation, otherwise raises an error.
- "edge_dialect_exported_program": Uses the Edge dialect program directly.
- Any other string: Fetches from graph_map (e.g., "edge_after_transform/forward"
for post-custom-transform graph when transform_passes are applied in
to_edge_transform_and_lower with generate_etrecord=True).
disable_debug_handle_valdiation: If True, skip debug handle validation.

# Will use the exported program to extract intermediate output if and only if exported_program has been provided, and it is one of the ancestors of the edge_dialect_program
if self._etrecord.exported_program and propagate_back_debug_handle(
self._etrecord.exported_program,
self._etrecord.export_graph_id,
self._etrecord.edge_dialect_program,
disable_debug_handle_valdiation,
):
Returns:
Tuple of (intermediate_outputs, debug_handle_to_op_names) dictionaries.

Raises:
ValueError: If the specified reference_graph_name is not available or if
debug handle backpropagation fails for "exported_program".
"""

# Determine the reference graph to use
if reference_graph is None or reference_graph == "exported_program":
# Auto-select: try exported_program first, fall back to edge_dialect_exported_program
if self._etrecord.exported_program and propagate_back_debug_handle(
self._etrecord.exported_program,
self._etrecord.export_graph_id,
self._etrecord.edge_dialect_program,
disable_debug_handle_valdiation,
):
reference_graph = "exported_program"
elif reference_graph is None:
log.warning(
"Either ATen dialect exported program is not in ETRecord, or debug handle "
"backpropagation failed. Falling back to 'edge_dialect_exported_program'."
)
reference_graph = "edge_dialect_exported_program"
else:
raise ValueError(
"Cannot use 'exported_program': Debug handle backpropagation failed or exported program is unavailable. "
"Please check if the exported program is available in ETRecord, or try to disable debug handle validation."
)
if reference_graph == "edge_dialect_exported_program":
# Explicitly requested edge_dialect_exported_program
export_program = self._etrecord.edge_dialect_program
log.info(
"Using 'edge_dialect_exported_program' (Edge dialect) as reference graph for intermediate output capture"
)
elif reference_graph == "exported_program":
export_program = self._etrecord.exported_program
else:
log.warning(
"Either aten dialect exported program is not in ETRecord, or it is not one of the ancestors of current edge dialect program."
"Will fall back to use edge dialect program to extract intermediate output",
log.info(
"Using 'exported_program' (ATen dialect) as reference graph for intermediate output capture"
)
export_program = self._etrecord.edge_dialect_program
else:
# Try to fetch from graph_map
# If no method name is provided (no "/" in the name), try adding "/forward" as default
lookup_name = reference_graph
if "/" not in reference_graph:
lookup_name = f"{reference_graph}/forward"
log.info(
f"No method name specified in '{reference_graph}', "
f"using '{lookup_name}' as default"
)

if (
self._etrecord.graph_map is not None
and lookup_name in self._etrecord.graph_map
):
export_program = self._etrecord.graph_map[lookup_name]
log.info(
f"Using '{lookup_name}' from graph_map as reference graph for intermediate output capture"
)
else:
available_graphs = (
list(self._etrecord.graph_map.keys())
if self._etrecord.graph_map
else []
)
raise ValueError(
f"Reference graph '{lookup_name}' not found. "
f"Available options: 'exported_program', 'edge_dialect_exported_program', "
f"or one of the graphs in graph_map: {available_graphs}"
)
graph_module = export_program.module()
aot_debug_handle_to_op_name = get_aot_debug_handle_to_op_name_mapping(
graph_module
Expand Down Expand Up @@ -1406,11 +1469,11 @@ def calculate_numeric_gap(
self,
distance: Union[str, NumericalComparatorBase],
disable_debug_handle_valdiation: bool = False,
reference_graph: Optional[str] = None,
):
"""
Compares logged intermediate outputs from the exported graph (in ETRecord)
with runtime outputs (in ETDump) using a user-specific numerical comparator.
If the exported graph is not supported, the function will fall back to use edge dialect graph.

To use this function, you must first generate the ETRecord with representative inputs,
and then create the Inspector instance with the ETRecord and ETDump. The Inspector can then
Expand All @@ -1423,18 +1486,31 @@ def calculate_numeric_gap(
logic by subclassing NumericalComparatorBase and implementing the element_compare()
method. Custom comparators can also override the preprocessing() method to apply
transformations (e.g., layout conversion, dequantization) before comparison.
disable_debug_handle_validation: Often when aten graph has symbolic shape nodes and inbuilt ops like gt/lt etc.,
disable_debug_handle_valdiation: Often when aten graph has symbolic shape nodes and inbuilt ops like gt/lt etc.,
during re-export of such a graph 'from_node' information is lost from node.meta. As a result we loose
connection between edge IR nodes and aten nodes for such ops. By default we validate that every edge IR
node has corresponding node in aten IR, and when such validation fails numeric debugger falls back to edge
IR as reference graph. This flag allows one to override such behavior and make best effort comparison.
reference_graph: Name of the graph to use as the golden reference for intermediate output capture.
Must be one of:
- "exported_program": Uses the ATen dialect exported program. Requires successful debug
handle backpropagation, otherwise raises an error.
- "edge_dialect_exported_program": Uses the Edge dialect program directly.
- Any other string: Fetches from graph_map (e.g., "edge_after_transform/forward" for
post-custom-transform graph when transform_passes are applied in to_edge_transform_and_lower
with generate_etrecord=True).

If None (default), automatically selects the best available graph:
- Uses "exported_program" if available and debug handle backpropagation succeeds.
- Falls back to "edge_dialect_exported_program" otherwise.

Returns:
pd.DataFrame: A DataFrame listing corresponding operator intermediate outputs from both stages and their computed numerical gaps.
"""
aot_intermediate_outputs, aot_debug_handle_to_op_names = (
self._get_aot_intermediate_outputs_and_op_names(
disable_debug_handle_valdiation
reference_graph,
disable_debug_handle_valdiation,
)
)
if len(aot_intermediate_outputs) == 0 or len(aot_debug_handle_to_op_names) == 0:
Expand Down
Loading
Loading