Skip to content
Merged
57 changes: 18 additions & 39 deletions devtools/inspector/_inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,13 @@
from executorch.devtools.etrecord import ETRecord, parse_etrecord
from executorch.devtools.inspector._inspector_utils import (
calculate_time_scale_factor,
compare_intermediate_outputs,
create_debug_handle_to_op_node_mapping,
DebugHandle,
display_or_print_df,
EDGE_DIALECT_GRAPH_KEY,
EXCLUDED_COLUMNS_WHEN_PRINTING,
EXCLUDED_EVENTS_FOR_INTERMEDIATE_OUTPUT,
EXCLUDED_EVENTS_WHEN_PRINTING,
find_op_names,
find_populated_event,
FORWARD,
gen_etdump_object,
Expand Down Expand Up @@ -1421,8 +1419,10 @@ def calculate_numeric_gap(
Args:
distance: The metrics the inspector will use for gap calculation. Can be either:
- A string: one of "MSE", "L1", or "SNR" for built-in comparators.
- A custom NumericalComparatorBase instance: allows you to define custom comparison logic
by subclassing NumericalComparatorBase and implementing the compare() method.
- A custom NumericalComparatorBase instance: allows you to define custom comparison
logic by subclassing NumericalComparatorBase and implementing the element_compare()
method. Custom comparators can also override the preprocessing() method to apply
transformations (e.g., layout conversion, dequantization) before comparison.
disable_debug_handle_validation: Often when aten graph has symbolic shape nodes and inbuilt ops like gt/lt etc.,
during re-export of such a graph 'from_node' information is lost from node.meta. As a result we loose
connection between edge IR nodes and aten nodes for such ops. By default we validate that every edge IR
Expand All @@ -1448,48 +1448,27 @@ def calculate_numeric_gap(
mapping = map_runtime_aot_intermediate_outputs(
aot_intermediate_outputs, runtime_intermediate_outputs
)

# Get or create comparator
if isinstance(distance, NumericalComparatorBase):
comparator = distance
# Inject inspector if not already set
if comparator.inspector is None:
comparator.inspector = self
else:
metric = distance.strip().upper()
if metric == "MSE":
comparator = MSEComparator()
comparator = MSEComparator(inspector=self)
elif metric == "L1":
comparator = L1Comparator()
comparator = L1Comparator(inspector=self)
elif metric == "SNR":
comparator = SNRComparator()
comparator = SNRComparator(inspector=self)
else:
raise ValueError(f"Unsupported distance metric {distance!r}")

rows = []
for (aot_debug_handle, aot_intermediate_output), (
runtime_debug_handle,
runtime_intermediate_output,
) in mapping.items():
if aot_intermediate_output is None or runtime_intermediate_output is None:
continue
# If aot outputs length is > 1 then comparison fails since we dont really have
# any instances where runtime intermediate output is a tuple or list
# This does not happen when edge dialect program is reference for comparison
# but happens in aten graph where ops like unbind remain undecomposed
if (
isinstance(aot_intermediate_output, Sequence)
and len(aot_intermediate_output) > 1
):
continue
rows.append(
{
"aot_ops": find_op_names(
aot_debug_handle, aot_debug_handle_to_op_names
),
"aot_intermediate_output": aot_intermediate_output,
"runtime_ops": find_op_names(
runtime_debug_handle, runtime_debug_handle_to_op_names
),
"runtime_intermediate_output": runtime_intermediate_output,
"gap": compare_intermediate_outputs(
aot_intermediate_output, runtime_intermediate_output, comparator
),
}
)
return pd.DataFrame(rows)
# Delegate to comparator's compare method (includes preprocessing)
return comparator.compare(
mapping,
aot_debug_handle_to_op_names,
runtime_debug_handle_to_op_names,
)
34 changes: 0 additions & 34 deletions devtools/inspector/_inspector_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1068,40 +1068,6 @@ def find_op_names(
return result


def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]:
"""
Compare two outputs, handling both sequence and non-sequence cases,
and return a list of comparison results.
Parameters:
a: The first intermediate output to compare.
b: The second intermediate output to compare.
comparator: A comparator object with a `compare` method.
Returns:
List[float]: A list of comparison results.
Raises:
ValueError: If one input is a sequence and the other is not, or if sequences have different lengths.
"""
is_a_sequence = isinstance(a, Sequence)
is_b_sequence = isinstance(b, Sequence)
if is_a_sequence and is_b_sequence:
# Ensure both sequences have the same length
if len(a) != len(b):
raise ValueError(
f"Sequences 'a' ({a}) and 'b' ({b}) must have the same length for comparison. len(a): {len(a)} len(b): {len(b)}."
)

# Compare each element in the sequences and return the list of results
return [comparator.compare(x, y) for x, y in zip(a, b)]
elif not is_a_sequence and not is_b_sequence:
# Compare non-sequence items and return the result in a list
return [comparator.compare(a, b)]
else:
# Raise an error if one is a sequence and the other is not
raise ValueError(
f"Both inputs 'a' ({a}) and 'b' ({b}) must be sequences or both must be non-sequences."
)


def get_ancestor_node_identifiers(node: Node) -> List[str]:
"""Get the identifier of the ancestor node of the given node, with the graph id the ancestor node lives in.

Expand Down
12 changes: 11 additions & 1 deletion devtools/inspector/numerical_comparator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
# LICENSE file in the root directory of this source tree.


# Re-export DebugHandle from _inspector_utils for convenience
from executorch.devtools.inspector._inspector_utils import DebugHandle
from executorch.devtools.inspector.numerical_comparator.l1_numerical_comparator import (
L1Comparator,
)
Expand All @@ -14,6 +16,7 @@
)

from executorch.devtools.inspector.numerical_comparator.numerical_comparator_base import (
IntermediateOutputMapping,
NumericalComparatorBase,
)

Expand All @@ -22,4 +25,11 @@
)


__all__ = ["L1Comparator", "MSEComparator", "SNRComparator", "NumericalComparatorBase"]
__all__ = [
"DebugHandle",
"IntermediateOutputMapping",
"L1Comparator",
"MSEComparator",
"NumericalComparatorBase",
"SNRComparator",
]
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,25 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any
from typing import Any, Optional, TYPE_CHECKING

import torch
from executorch.devtools.inspector._inspector_utils import convert_to_float_tensor
from executorch.devtools.inspector.numerical_comparator.numerical_comparator_base import (
NumericalComparatorBase,
)

if TYPE_CHECKING:
from executorch.devtools.inspector._inspector import Inspector


class L1Comparator(NumericalComparatorBase):
def compare(self, a: Any, b: Any) -> float:
"""L1 (sum of absolute differences) comparator for numerical discrepancy detection."""

def __init__(self, inspector: Optional["Inspector"] = None) -> None:
super().__init__(inspector)

def element_compare(self, a: Any, b: Any) -> float:
"""Sum up all these element-wise absolute differences between two tensors."""

t_a = convert_to_float_tensor(a)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,25 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any
from typing import Any, Optional, TYPE_CHECKING

import torch
from executorch.devtools.inspector._inspector_utils import convert_to_float_tensor
from executorch.devtools.inspector.numerical_comparator.numerical_comparator_base import (
NumericalComparatorBase,
)

if TYPE_CHECKING:
from executorch.devtools.inspector._inspector import Inspector


class MSEComparator(NumericalComparatorBase):
def compare(self, a: Any, b: Any) -> float:
"""Mean Squared Error comparator for numerical discrepancy detection."""

def __init__(self, inspector: Optional["Inspector"] = None) -> None:
super().__init__(inspector)

def element_compare(self, a: Any, b: Any) -> float:
"""Compare mean squared difference between two outputs."""

t_a = convert_to_float_tensor(a)
Expand Down
Loading
Loading