4242from executorch .devtools .etrecord import ETRecord , parse_etrecord
4343from executorch .devtools .inspector ._inspector_utils import (
4444 calculate_time_scale_factor ,
45- compare_intermediate_outputs ,
4645 create_debug_handle_to_op_node_mapping ,
4746 DebugHandle ,
4847 display_or_print_df ,
4948 EDGE_DIALECT_GRAPH_KEY ,
5049 EXCLUDED_COLUMNS_WHEN_PRINTING ,
5150 EXCLUDED_EVENTS_FOR_INTERMEDIATE_OUTPUT ,
5251 EXCLUDED_EVENTS_WHEN_PRINTING ,
53- find_op_names ,
5452 find_populated_event ,
5553 FORWARD ,
5654 gen_etdump_object ,
@@ -1168,31 +1166,94 @@ def _consume_etrecord(self) -> None:
11681166
11691167 def _get_aot_intermediate_outputs_and_op_names (
11701168 self ,
1169+ reference_graph : Optional [str ] = None ,
11711170 disable_debug_handle_valdiation : bool = False ,
11721171 ) -> Tuple [Dict [DebugHandle , Any ], Dict [DebugHandle , List [str ]]]:
11731172 """
11741173 Capture intermediate outputs only if _representative_inputs are provided
1175- when using bundled program to create the etrecord
1176- """
1177- if self ._etrecord ._representative_inputs is None :
1178- return {}, {}
1174+ when using bundled program to create the etrecord.
11791175
1180- export_program = None
1176+ Args:
1177+ reference_graph_name: Name of the graph to use as the reference for intermediate
1178+ output capture. Must be one of:
1179+ - "exported_program": Uses the ATen dialect exported program. Requires
1180+ successful debug handle backpropagation, otherwise raises an error.
1181+ - "edge_dialect_exported_program": Uses the Edge dialect program directly.
1182+ - Any other string: Fetches from graph_map (e.g., "edge_after_transform/forward"
1183+ for post-custom-transform graph when transform_passes are applied in
1184+ to_edge_transform_and_lower with generate_etrecord=True).
1185+ disable_debug_handle_valdiation: If True, skip debug handle validation.
11811186
1182- # Will use the exported program to extract intermediate output if and only if exported_program has been provided, and it is one of the ancestors of the edge_dialect_program
1183- if self ._etrecord .exported_program and propagate_back_debug_handle (
1184- self ._etrecord .exported_program ,
1185- self ._etrecord .export_graph_id ,
1186- self ._etrecord .edge_dialect_program ,
1187- disable_debug_handle_valdiation ,
1188- ):
1187+ Returns:
1188+ Tuple of (intermediate_outputs, debug_handle_to_op_names) dictionaries.
1189+
1190+ Raises:
1191+ ValueError: If the specified reference_graph_name is not available or if
1192+ debug handle backpropagation fails for "exported_program".
1193+ """
1194+
1195+ # Determine the reference graph to use
1196+ if reference_graph is None or reference_graph == "exported_program" :
1197+ # Auto-select: try exported_program first, fall back to edge_dialect_exported_program
1198+ if self ._etrecord .exported_program and propagate_back_debug_handle (
1199+ self ._etrecord .exported_program ,
1200+ self ._etrecord .export_graph_id ,
1201+ self ._etrecord .edge_dialect_program ,
1202+ disable_debug_handle_valdiation ,
1203+ ):
1204+ reference_graph = "exported_program"
1205+ elif reference_graph is None :
1206+ log .warning (
1207+ "Either ATen dialect exported program is not in ETRecord, or debug handle "
1208+ "backpropagation failed. Falling back to 'edge_dialect_exported_program'."
1209+ )
1210+ reference_graph = "edge_dialect_exported_program"
1211+ else :
1212+ raise ValueError (
1213+ "Cannot use 'exported_program': Debug handle backpropagation failed or exported program is unavailable. "
1214+ "Please check if the exported program is available in ETRecord, or try to disable debug handle validation."
1215+ )
1216+ if reference_graph == "edge_dialect_exported_program" :
1217+ # Explicitly requested edge_dialect_exported_program
1218+ export_program = self ._etrecord .edge_dialect_program
1219+ log .info (
1220+ "Using 'edge_dialect_exported_program' (Edge dialect) as reference graph for intermediate output capture"
1221+ )
1222+ elif reference_graph == "exported_program" :
11891223 export_program = self ._etrecord .exported_program
1190- else :
1191- log .warning (
1192- "Either aten dialect exported program is not in ETRecord, or it is not one of the ancestors of current edge dialect program."
1193- "Will fall back to use edge dialect program to extract intermediate output" ,
1224+ log .info (
1225+ "Using 'exported_program' (ATen dialect) as reference graph for intermediate output capture"
11941226 )
1195- export_program = self ._etrecord .edge_dialect_program
1227+ else :
1228+ # Try to fetch from graph_map
1229+ # If no method name is provided (no "/" in the name), try adding "/forward" as default
1230+ lookup_name = reference_graph
1231+ if "/" not in reference_graph :
1232+ lookup_name = f"{ reference_graph } /forward"
1233+ log .info (
1234+ f"No method name specified in '{ reference_graph } ', "
1235+ f"using '{ lookup_name } ' as default"
1236+ )
1237+
1238+ if (
1239+ self ._etrecord .graph_map is not None
1240+ and lookup_name in self ._etrecord .graph_map
1241+ ):
1242+ export_program = self ._etrecord .graph_map [lookup_name ]
1243+ log .info (
1244+ f"Using '{ lookup_name } ' from graph_map as reference graph for intermediate output capture"
1245+ )
1246+ else :
1247+ available_graphs = (
1248+ list (self ._etrecord .graph_map .keys ())
1249+ if self ._etrecord .graph_map
1250+ else []
1251+ )
1252+ raise ValueError (
1253+ f"Reference graph '{ lookup_name } ' not found. "
1254+ f"Available options: 'exported_program', 'edge_dialect_exported_program', "
1255+ f"or one of the graphs in graph_map: { available_graphs } "
1256+ )
11961257 graph_module = export_program .module ()
11971258 aot_debug_handle_to_op_name = get_aot_debug_handle_to_op_name_mapping (
11981259 graph_module
@@ -1408,11 +1469,11 @@ def calculate_numeric_gap(
14081469 self ,
14091470 distance : Union [str , NumericalComparatorBase ],
14101471 disable_debug_handle_valdiation : bool = False ,
1472+ reference_graph : Optional [str ] = None ,
14111473 ):
14121474 """
14131475 Compares logged intermediate outputs from the exported graph (in ETRecord)
14141476 with runtime outputs (in ETDump) using a user-specific numerical comparator.
1415- If the exported graph is not supported, the function will fall back to use edge dialect graph.
14161477
14171478 To use this function, you must first generate the ETRecord with representative inputs,
14181479 and then create the Inspector instance with the ETRecord and ETDump. The Inspector can then
@@ -1421,20 +1482,35 @@ def calculate_numeric_gap(
14211482 Args:
14221483 distance: The metrics the inspector will use for gap calculation. Can be either:
14231484 - A string: one of "MSE", "L1", or "SNR" for built-in comparators.
1424- - A custom NumericalComparatorBase instance: allows you to define custom comparison logic
1425- by subclassing NumericalComparatorBase and implementing the compare() method.
1426- disable_debug_handle_validation: Often when aten graph has symbolic shape nodes and inbuilt ops like gt/lt etc.,
1485+ - A custom NumericalComparatorBase instance: allows you to define custom comparison
1486+ logic by subclassing NumericalComparatorBase and implementing the element_compare()
1487+ method. Custom comparators can also override the preprocessing() method to apply
1488+ transformations (e.g., layout conversion, dequantization) before comparison.
1489+ disable_debug_handle_valdiation: Often when aten graph has symbolic shape nodes and inbuilt ops like gt/lt etc.,
14271490 during re-export of such a graph 'from_node' information is lost from node.meta. As a result we loose
14281491 connection between edge IR nodes and aten nodes for such ops. By default we validate that every edge IR
14291492 node has corresponding node in aten IR, and when such validation fails numeric debugger falls back to edge
14301493 IR as reference graph. This flag allows one to override such behavior and make best effort comparison.
1494+ reference_graph: Name of the graph to use as the golden reference for intermediate output capture.
1495+ Must be one of:
1496+ - "exported_program": Uses the ATen dialect exported program. Requires successful debug
1497+ handle backpropagation, otherwise raises an error.
1498+ - "edge_dialect_exported_program": Uses the Edge dialect program directly.
1499+ - Any other string: Fetches from graph_map (e.g., "edge_after_transform/forward" for
1500+ post-custom-transform graph when transform_passes are applied in to_edge_transform_and_lower
1501+ with generate_etrecord=True).
1502+
1503+ If None (default), automatically selects the best available graph:
1504+ - Uses "exported_program" if available and debug handle backpropagation succeeds.
1505+ - Falls back to "edge_dialect_exported_program" otherwise.
14311506
14321507 Returns:
14331508 pd.DataFrame: A DataFrame listing corresponding operator intermediate outputs from both stages and their computed numerical gaps.
14341509 """
14351510 aot_intermediate_outputs , aot_debug_handle_to_op_names = (
14361511 self ._get_aot_intermediate_outputs_and_op_names (
1437- disable_debug_handle_valdiation
1512+ reference_graph ,
1513+ disable_debug_handle_valdiation ,
14381514 )
14391515 )
14401516 if len (aot_intermediate_outputs ) == 0 or len (aot_debug_handle_to_op_names ) == 0 :
@@ -1448,48 +1524,27 @@ def calculate_numeric_gap(
14481524 mapping = map_runtime_aot_intermediate_outputs (
14491525 aot_intermediate_outputs , runtime_intermediate_outputs
14501526 )
1527+
1528+ # Get or create comparator
14511529 if isinstance (distance , NumericalComparatorBase ):
14521530 comparator = distance
1531+ # Inject inspector if not already set
1532+ if comparator .inspector is None :
1533+ comparator .inspector = self
14531534 else :
14541535 metric = distance .strip ().upper ()
14551536 if metric == "MSE" :
1456- comparator = MSEComparator ()
1537+ comparator = MSEComparator (inspector = self )
14571538 elif metric == "L1" :
1458- comparator = L1Comparator ()
1539+ comparator = L1Comparator (inspector = self )
14591540 elif metric == "SNR" :
1460- comparator = SNRComparator ()
1541+ comparator = SNRComparator (inspector = self )
14611542 else :
14621543 raise ValueError (f"Unsupported distance metric { distance !r} " )
14631544
1464- rows = []
1465- for (aot_debug_handle , aot_intermediate_output ), (
1466- runtime_debug_handle ,
1467- runtime_intermediate_output ,
1468- ) in mapping .items ():
1469- if aot_intermediate_output is None or runtime_intermediate_output is None :
1470- continue
1471- # If aot outputs length is > 1 then comparison fails since we dont really have
1472- # any instances where runtime intermediate output is a tuple or list
1473- # This does not happen when edge dialect program is reference for comparison
1474- # but happens in aten graph where ops like unbind remain undecomposed
1475- if (
1476- isinstance (aot_intermediate_output , Sequence )
1477- and len (aot_intermediate_output ) > 1
1478- ):
1479- continue
1480- rows .append (
1481- {
1482- "aot_ops" : find_op_names (
1483- aot_debug_handle , aot_debug_handle_to_op_names
1484- ),
1485- "aot_intermediate_output" : aot_intermediate_output ,
1486- "runtime_ops" : find_op_names (
1487- runtime_debug_handle , runtime_debug_handle_to_op_names
1488- ),
1489- "runtime_intermediate_output" : runtime_intermediate_output ,
1490- "gap" : compare_intermediate_outputs (
1491- aot_intermediate_output , runtime_intermediate_output , comparator
1492- ),
1493- }
1494- )
1495- return pd .DataFrame (rows )
1545+ # Delegate to comparator's compare method (includes preprocessing)
1546+ return comparator .compare (
1547+ mapping ,
1548+ aot_debug_handle_to_op_names ,
1549+ runtime_debug_handle_to_op_names ,
1550+ )
0 commit comments