Skip to content

Commit f78535d

Browse files
[devtools] Add preprocessing support to NumericalComparatorBase (#17695)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #17433 by @Gasoonjia ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/gasoonjia/119/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/119/head Merge bot PR base: https://github.com/pytorch/executorch/tree/main Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/119/orig Differential Revision: [D93169813](https://our.internmc.facebook.com/intern/diff/D93169813/) @diff-train-skip-merge --------- Co-authored-by: gasoonjia <gasoonjia@icloud.com>
1 parent 7edb46d commit f78535d

14 files changed

Lines changed: 1074 additions & 149 deletions

devtools/etrecord/tests/etrecord_test.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1819,3 +1819,122 @@ def test_multi_method_etrecord_generation(self):
18191819
# Verify other ETRecord components are preserved
18201820
self.assertIsNotNone(parsed_etrecord._debug_handle_map)
18211821
self.assertIsNotNone(parsed_etrecord._delegate_map)
1822+
1823+
def test_edge_after_transform_graph_capture(self):
    """Check that custom transform passes yield an 'edge_after_transform' graph in the ETRecord.

    With generate_etrecord=True and transform_passes supplied to
    to_edge_transform_and_lower, the ETRecord's graph_map is expected to hold the
    post-transform graph under 'edge_after_transform/forward'. Backends such as
    Qualcomm can then use that graph as the golden reference when computing
    numeric gaps. The entry must also survive a save/parse round trip.
    """
    from torch.fx.passes.infra.pass_base import PassBase, PassResult

    graph_key = "edge_after_transform/forward"

    class SimpleCustomPass(PassBase):
        """Tags every node's meta so the test can tell that this pass ran."""

        def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
            for node in graph_module.graph.nodes:
                node.meta["custom_pass_applied"] = True
            return PassResult(graph_module=graph_module, modified=True)

    model = models.BasicSinMax()
    exported = export(model, model.get_random_inputs(), strict=True)

    # Lower to edge with one custom pass while asking for an ETRecord.
    manager = to_edge_transform_and_lower(
        exported,
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
        transform_passes=[SimpleCustomPass()],
        generate_etrecord=True,
    )

    # An ETRecord must have been produced.
    self.assertIsNotNone(manager._etrecord)
    record = manager._etrecord

    # The post-transform graph must be registered in graph_map.
    self.assertIsNotNone(record.graph_map)
    self.assertIn(
        graph_key,
        record.graph_map,
        "graph_map should contain 'edge_after_transform/forward' when transform_passes are applied",
    )

    captured_graph = record.graph_map[graph_key]
    self.assertIsNotNone(captured_graph)

    # At least one node of the captured graph must carry the pass marker.
    marker_found = any(
        node.meta.get("custom_pass_applied", False)
        for node in captured_graph.graph.nodes
    )
    self.assertTrue(
        marker_found,
        "The edge_after_transform graph should have the custom pass marker applied",
    )

    # The edge dialect program must still be present alongside the new graph.
    self.assertIsNotNone(record.edge_dialect_program)

    # Round-trip through disk to confirm the extra graph is persisted.
    executorch_manager = manager.to_executorch()

    with tempfile.TemporaryDirectory() as tmpdirname:
        etrecord_path = tmpdirname + "/etrecord_custom_pass.bin"

        executorch_manager.get_etrecord().save(etrecord_path)
        reloaded = parse_etrecord(etrecord_path)

        self.assertIsNotNone(reloaded.graph_map)
        self.assertIn(
            graph_key,
            reloaded.graph_map,
            "Parsed ETRecord should still contain 'edge_after_transform/forward'",
        )

        # Only presence is checked after parsing; the marker is not re-verified.
        reloaded_graph = reloaded.graph_map[graph_key]
        self.assertIsNotNone(reloaded_graph)
1913+
1914+
def test_no_edge_after_transform_without_transform_passes(self):
    """Check that no 'edge_after_transform' entry appears when no transform passes are given.

    Backward-compatibility guard: with generate_etrecord=True but without
    transform_passes, the ETRecord's graph_map (if it exists at all) must not
    contain 'edge_after_transform/forward'.
    """
    model = models.BasicSinMax()
    exported = export(model, model.get_random_inputs(), strict=True)

    # Lower to edge with ETRecord generation but no custom passes.
    manager = to_edge_transform_and_lower(
        exported,
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
        generate_etrecord=True,
    )

    # An ETRecord must still have been produced.
    self.assertIsNotNone(manager._etrecord)
    graph_map = manager._etrecord.graph_map

    # A missing graph_map trivially satisfies the expectation.
    if graph_map is None:
        return

    self.assertNotIn(
        "edge_after_transform/forward",
        graph_map,
        "graph_map should NOT contain 'edge_after_transform/forward' when no transform_passes are applied",
    )

devtools/inspector/_inspector.py

Lines changed: 114 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,13 @@
4242
from executorch.devtools.etrecord import ETRecord, parse_etrecord
4343
from executorch.devtools.inspector._inspector_utils import (
4444
calculate_time_scale_factor,
45-
compare_intermediate_outputs,
4645
create_debug_handle_to_op_node_mapping,
4746
DebugHandle,
4847
display_or_print_df,
4948
EDGE_DIALECT_GRAPH_KEY,
5049
EXCLUDED_COLUMNS_WHEN_PRINTING,
5150
EXCLUDED_EVENTS_FOR_INTERMEDIATE_OUTPUT,
5251
EXCLUDED_EVENTS_WHEN_PRINTING,
53-
find_op_names,
5452
find_populated_event,
5553
FORWARD,
5654
gen_etdump_object,
@@ -1168,31 +1166,94 @@ def _consume_etrecord(self) -> None:
11681166

11691167
def _get_aot_intermediate_outputs_and_op_names(
11701168
self,
1169+
reference_graph: Optional[str] = None,
11711170
disable_debug_handle_valdiation: bool = False,
11721171
) -> Tuple[Dict[DebugHandle, Any], Dict[DebugHandle, List[str]]]:
11731172
"""
11741173
Capture intermediate outputs only if _representative_inputs are provided
1175-
when using bundled program to create the etrecord
1176-
"""
1177-
if self._etrecord._representative_inputs is None:
1178-
return {}, {}
1174+
when using bundled program to create the etrecord.
11791175
1180-
export_program = None
1176+
Args:
1177+
reference_graph: Name of the graph to use as the reference for intermediate
1178+
output capture. Must be one of:
1179+
- "exported_program": Uses the ATen dialect exported program. Requires
1180+
successful debug handle backpropagation, otherwise raises an error.
1181+
- "edge_dialect_exported_program": Uses the Edge dialect program directly.
1182+
- Any other string: Fetches from graph_map (e.g., "edge_after_transform/forward"
1183+
for post-custom-transform graph when transform_passes are applied in
1184+
to_edge_transform_and_lower with generate_etrecord=True).
1185+
disable_debug_handle_valdiation: If True, skip debug handle validation.
11811186
1182-
# Will use the exported program to extract intermediate output if and only if exported_program has been provided, and it is one of the ancestors of the edge_dialect_program
1183-
if self._etrecord.exported_program and propagate_back_debug_handle(
1184-
self._etrecord.exported_program,
1185-
self._etrecord.export_graph_id,
1186-
self._etrecord.edge_dialect_program,
1187-
disable_debug_handle_valdiation,
1188-
):
1187+
Returns:
1188+
Tuple of (intermediate_outputs, debug_handle_to_op_names) dictionaries.
1189+
1190+
Raises:
1191+
ValueError: If the specified reference_graph is not available or if
1192+
debug handle backpropagation fails for "exported_program".
1193+
"""
1194+
1195+
# Determine the reference graph to use
1196+
if reference_graph is None or reference_graph == "exported_program":
1197+
# Auto-select: try exported_program first, fall back to edge_dialect_exported_program
1198+
if self._etrecord.exported_program and propagate_back_debug_handle(
1199+
self._etrecord.exported_program,
1200+
self._etrecord.export_graph_id,
1201+
self._etrecord.edge_dialect_program,
1202+
disable_debug_handle_valdiation,
1203+
):
1204+
reference_graph = "exported_program"
1205+
elif reference_graph is None:
1206+
log.warning(
1207+
"Either ATen dialect exported program is not in ETRecord, or debug handle "
1208+
"backpropagation failed. Falling back to 'edge_dialect_exported_program'."
1209+
)
1210+
reference_graph = "edge_dialect_exported_program"
1211+
else:
1212+
raise ValueError(
1213+
"Cannot use 'exported_program': Debug handle backpropagation failed or exported program is unavailable. "
1214+
"Please check if the exported program is available in ETRecord, or try to disable debug handle validation."
1215+
)
1216+
if reference_graph == "edge_dialect_exported_program":
1217+
# Explicitly requested edge_dialect_exported_program
1218+
export_program = self._etrecord.edge_dialect_program
1219+
log.info(
1220+
"Using 'edge_dialect_exported_program' (Edge dialect) as reference graph for intermediate output capture"
1221+
)
1222+
elif reference_graph == "exported_program":
11891223
export_program = self._etrecord.exported_program
1190-
else:
1191-
log.warning(
1192-
"Either aten dialect exported program is not in ETRecord, or it is not one of the ancestors of current edge dialect program."
1193-
"Will fall back to use edge dialect program to extract intermediate output",
1224+
log.info(
1225+
"Using 'exported_program' (ATen dialect) as reference graph for intermediate output capture"
11941226
)
1195-
export_program = self._etrecord.edge_dialect_program
1227+
else:
1228+
# Try to fetch from graph_map
1229+
# If no method name is provided (no "/" in the name), try adding "/forward" as default
1230+
lookup_name = reference_graph
1231+
if "/" not in reference_graph:
1232+
lookup_name = f"{reference_graph}/forward"
1233+
log.info(
1234+
f"No method name specified in '{reference_graph}', "
1235+
f"using '{lookup_name}' as default"
1236+
)
1237+
1238+
if (
1239+
self._etrecord.graph_map is not None
1240+
and lookup_name in self._etrecord.graph_map
1241+
):
1242+
export_program = self._etrecord.graph_map[lookup_name]
1243+
log.info(
1244+
f"Using '{lookup_name}' from graph_map as reference graph for intermediate output capture"
1245+
)
1246+
else:
1247+
available_graphs = (
1248+
list(self._etrecord.graph_map.keys())
1249+
if self._etrecord.graph_map
1250+
else []
1251+
)
1252+
raise ValueError(
1253+
f"Reference graph '{lookup_name}' not found. "
1254+
f"Available options: 'exported_program', 'edge_dialect_exported_program', "
1255+
f"or one of the graphs in graph_map: {available_graphs}"
1256+
)
11961257
graph_module = export_program.module()
11971258
aot_debug_handle_to_op_name = get_aot_debug_handle_to_op_name_mapping(
11981259
graph_module
@@ -1408,11 +1469,11 @@ def calculate_numeric_gap(
14081469
self,
14091470
distance: Union[str, NumericalComparatorBase],
14101471
disable_debug_handle_valdiation: bool = False,
1472+
reference_graph: Optional[str] = None,
14111473
):
14121474
"""
14131475
Compares logged intermediate outputs from the exported graph (in ETRecord)
14141476
with runtime outputs (in ETDump) using a user-specific numerical comparator.
1415-
If the exported graph is not supported, the function will fall back to use edge dialect graph.
14161477
14171478
To use this function, you must first generate the ETRecord with representative inputs,
14181479
and then create the Inspector instance with the ETRecord and ETDump. The Inspector can then
@@ -1421,20 +1482,35 @@ def calculate_numeric_gap(
14211482
Args:
14221483
distance: The metrics the inspector will use for gap calculation. Can be either:
14231484
- A string: one of "MSE", "L1", or "SNR" for built-in comparators.
1424-
- A custom NumericalComparatorBase instance: allows you to define custom comparison logic
1425-
by subclassing NumericalComparatorBase and implementing the compare() method.
1426-
disable_debug_handle_validation: Often when aten graph has symbolic shape nodes and inbuilt ops like gt/lt etc.,
1485+
- A custom NumericalComparatorBase instance: allows you to define custom comparison
1486+
logic by subclassing NumericalComparatorBase and implementing the element_compare()
1487+
method. Custom comparators can also override the preprocessing() method to apply
1488+
transformations (e.g., layout conversion, dequantization) before comparison.
1489+
disable_debug_handle_valdiation: Often when aten graph has symbolic shape nodes and inbuilt ops like gt/lt etc.,
14271490
during re-export of such a graph 'from_node' information is lost from node.meta. As a result we lose
14281491
connection between edge IR nodes and aten nodes for such ops. By default we validate that every edge IR
14291492
node has corresponding node in aten IR, and when such validation fails numeric debugger falls back to edge
14301493
IR as reference graph. This flag allows one to override such behavior and make best effort comparison.
1494+
reference_graph: Name of the graph to use as the golden reference for intermediate output capture.
1495+
Must be one of:
1496+
- "exported_program": Uses the ATen dialect exported program. Requires successful debug
1497+
handle backpropagation, otherwise raises an error.
1498+
- "edge_dialect_exported_program": Uses the Edge dialect program directly.
1499+
- Any other string: Fetches from graph_map (e.g., "edge_after_transform/forward" for
1500+
post-custom-transform graph when transform_passes are applied in to_edge_transform_and_lower
1501+
with generate_etrecord=True).
1502+
1503+
If None (default), automatically selects the best available graph:
1504+
- Uses "exported_program" if available and debug handle backpropagation succeeds.
1505+
- Falls back to "edge_dialect_exported_program" otherwise.
14311506
14321507
Returns:
14331508
pd.DataFrame: A DataFrame listing corresponding operator intermediate outputs from both stages and their computed numerical gaps.
14341509
"""
14351510
aot_intermediate_outputs, aot_debug_handle_to_op_names = (
14361511
self._get_aot_intermediate_outputs_and_op_names(
1437-
disable_debug_handle_valdiation
1512+
reference_graph,
1513+
disable_debug_handle_valdiation,
14381514
)
14391515
)
14401516
if len(aot_intermediate_outputs) == 0 or len(aot_debug_handle_to_op_names) == 0:
@@ -1448,48 +1524,27 @@ def calculate_numeric_gap(
14481524
mapping = map_runtime_aot_intermediate_outputs(
14491525
aot_intermediate_outputs, runtime_intermediate_outputs
14501526
)
1527+
1528+
# Get or create comparator
14511529
if isinstance(distance, NumericalComparatorBase):
14521530
comparator = distance
1531+
# Inject inspector if not already set
1532+
if comparator.inspector is None:
1533+
comparator.inspector = self
14531534
else:
14541535
metric = distance.strip().upper()
14551536
if metric == "MSE":
1456-
comparator = MSEComparator()
1537+
comparator = MSEComparator(inspector=self)
14571538
elif metric == "L1":
1458-
comparator = L1Comparator()
1539+
comparator = L1Comparator(inspector=self)
14591540
elif metric == "SNR":
1460-
comparator = SNRComparator()
1541+
comparator = SNRComparator(inspector=self)
14611542
else:
14621543
raise ValueError(f"Unsupported distance metric {distance!r}")
14631544

1464-
rows = []
1465-
for (aot_debug_handle, aot_intermediate_output), (
1466-
runtime_debug_handle,
1467-
runtime_intermediate_output,
1468-
) in mapping.items():
1469-
if aot_intermediate_output is None or runtime_intermediate_output is None:
1470-
continue
1471-
# If aot outputs length is > 1 then comparison fails since we dont really have
1472-
# any instances where runtime intermediate output is a tuple or list
1473-
# This does not happen when edge dialect program is reference for comparison
1474-
# but happens in aten graph where ops like unbind remain undecomposed
1475-
if (
1476-
isinstance(aot_intermediate_output, Sequence)
1477-
and len(aot_intermediate_output) > 1
1478-
):
1479-
continue
1480-
rows.append(
1481-
{
1482-
"aot_ops": find_op_names(
1483-
aot_debug_handle, aot_debug_handle_to_op_names
1484-
),
1485-
"aot_intermediate_output": aot_intermediate_output,
1486-
"runtime_ops": find_op_names(
1487-
runtime_debug_handle, runtime_debug_handle_to_op_names
1488-
),
1489-
"runtime_intermediate_output": runtime_intermediate_output,
1490-
"gap": compare_intermediate_outputs(
1491-
aot_intermediate_output, runtime_intermediate_output, comparator
1492-
),
1493-
}
1494-
)
1495-
return pd.DataFrame(rows)
1545+
# Delegate to comparator's compare method (includes preprocessing)
1546+
return comparator.compare(
1547+
mapping,
1548+
aot_debug_handle_to_op_names,
1549+
runtime_debug_handle_to_op_names,
1550+
)

0 commit comments

Comments
 (0)