class ArmPass(ExportPass):
    """Base class for Arm passes."""

    # NOTE: excerpt. Only the two hooks introduced by this patch are shown
    # here; the remaining ArmPass members (``__init__``, ``get_name``,
    # ``call_operator``, ...) lie outside this view and are unchanged.

    def __init_subclass__(cls, **kwargs) -> None:
        """Derive ``targeted_ops`` for subclasses that do not declare one.

        Nothing is derived when the subclass already resolves a non-``None``
        ``targeted_ops`` or when it defines its own ``call()``. Otherwise
        the first truthy attribute among ``_TARGET_OPS`` and
        ``_supported_ops`` is copied into a set; failing that, the union of
        ``_EDGE_OPS`` and ``_ATEN_OPS`` is used when either is non-empty.
        If no source attribute is found, ``targeted_ops`` stays unset.
        """
        super().__init_subclass__(**kwargs)

        # NOTE(review): getattr() follows the MRO, so a subclass of a pass
        # that already has ``targeted_ops`` keeps the inherited set even if
        # it re-declares ``_TARGET_OPS`` - confirm that is intended.
        #
        # Passes that override call() use _TARGET_OPS for their own graph
        # manipulation logic, not as a fast-copy declaration, so they are
        # left alone.
        # NOTE(review): only this class's own __dict__ is inspected; a
        # ``call`` inherited from an intermediate base does not opt out -
        # confirm against the pass hierarchy.
        if getattr(cls, "targeted_ops", None) is not None or "call" in cls.__dict__:
            return

        for source_attr in ("_TARGET_OPS", "_supported_ops"):
            declared = getattr(cls, source_attr, None)
            if declared:
                cls.targeted_ops = set(declared)  # type: ignore[attr-defined]
                return

        edge_ops = getattr(cls, "_EDGE_OPS", None) or ()
        aten_ops = getattr(cls, "_ATEN_OPS", None) or ()
        if edge_ops or aten_ops:
            cls.targeted_ops = set(edge_ops) | set(aten_ops)  # type: ignore[attr-defined]

    def should_run(self, graph_module: GraphModule) -> bool:
        """Return whether the graph holds any op this pass targets.

        A pass whose ``targeted_ops`` attribute is missing or ``None``
        always runs. Otherwise the pass runs only if some ``call_function``
        node targets one of those ops, either in ``graph_module`` itself or
        in any submodule yielded (recursively) by
        ``get_control_flow_submodules``, so ops nested inside control flow
        are not overlooked.
        """
        targeted = getattr(self, "targeted_ops", None)
        if targeted is None:
            return True

        from executorch.exir.graph_module import get_control_flow_submodules

        def _contains_targeted_op(module: GraphModule) -> bool:
            # Look at this module's own nodes first, then descend.
            if any(
                candidate.op == "call_function" and candidate.target in targeted
                for candidate in module.graph.nodes
            ):
                return True
            return any(
                _contains_targeted_op(nested)
                for _, nested, _ in get_control_flow_submodules(module)
            )

        return _contains_targeted_op(graph_module)
args, kwargs, meta): if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/convert_expand_copy_to_repeat.py b/backends/arm/_passes/convert_expand_copy_to_repeat.py index 430dc70bd0c..673239be3e6 100644 --- a/backends/arm/_passes/convert_expand_copy_to_repeat.py +++ b/backends/arm/_passes/convert_expand_copy_to_repeat.py @@ -8,7 +8,6 @@ from typing import cast, Set, Type import torch - from executorch.backends.arm._passes.arm_pass import ArmOpTargetedPass from executorch.backends.arm._passes.unsqueeze_before_repeat_pass import ( UnsqueezeBeforeRepeatPass, @@ -58,6 +57,8 @@ class ConvertExpandCopyToRepeatPass(ArmOpTargetedPass): _passes_required_after: Set[Type[ExportPass]] = {UnsqueezeBeforeRepeatPass} + targeted_ops = {exir_ops.edge.aten.expand_copy.default} + expand_copy = exir_ops.edge.aten.expand_copy.default repeat = exir_ops.edge.aten.repeat.default target_ops = (expand_copy,) diff --git a/backends/arm/_passes/convert_full_like_to_full_pass.py b/backends/arm/_passes/convert_full_like_to_full_pass.py index f7a94424228..299d1aa3fc6 100644 --- a/backends/arm/_passes/convert_full_like_to_full_pass.py +++ b/backends/arm/_passes/convert_full_like_to_full_pass.py @@ -9,7 +9,6 @@ from executorch.backends.arm._passes.fuse_constant_ops_pass import ( ComputeConstantOpsAOTPass, ) - from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -37,6 +36,8 @@ class ConvertFullLikeToFullPass(ArmOpTargetedPass): _passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOTPass} target_ops = (exir_ops.edge.aten.full_like.default,) + targeted_ops = {exir_ops.edge.aten.full_like.default} + def call_operator(self, op, args, kwargs, meta): if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/convert_permute_singleton_to_view_pass.py b/backends/arm/_passes/convert_permute_singleton_to_view_pass.py 
index 0ed5f92f91d..1fdbf56763e 100644 --- a/backends/arm/_passes/convert_permute_singleton_to_view_pass.py +++ b/backends/arm/_passes/convert_permute_singleton_to_view_pass.py @@ -7,10 +7,8 @@ from typing import Sequence, Set, Tuple, Type from executorch.backends.arm._passes.arm_pass import ArmOpTargetedPass - from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass - from torch._ops import OpOverload @@ -36,6 +34,8 @@ class ConvertPermuteSingletonToViewPass(ArmOpTargetedPass): _passes_required_after: Set[Type[ExportPass]] = set() target_ops = _PERMUTE_TARGETS + targeted_ops = set(_PERMUTE_TARGETS) + def call_operator(self, op, args, kwargs, meta): if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/convert_split_to_slice.py b/backends/arm/_passes/convert_split_to_slice.py index 425c1dafdac..f6e5672459a 100644 --- a/backends/arm/_passes/convert_split_to_slice.py +++ b/backends/arm/_passes/convert_split_to_slice.py @@ -21,6 +21,11 @@ class ConvertSplitToSlicePass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = set() + targeted_ops = { + exir_ops.edge.aten.split_with_sizes_copy.default, + exir_ops.edge.aten.split_copy.Tensor, + } + split_ops = ( exir_ops.edge.aten.split_with_sizes_copy.default, exir_ops.edge.aten.split_copy.Tensor, diff --git a/backends/arm/_passes/convert_squeezes_to_view.py b/backends/arm/_passes/convert_squeezes_to_view.py index b79e38cdf10..c75fcd401da 100644 --- a/backends/arm/_passes/convert_squeezes_to_view.py +++ b/backends/arm/_passes/convert_squeezes_to_view.py @@ -28,6 +28,11 @@ class ConvertSqueezesToViewPass(ArmOpTargetedPass): exir_ops.edge.aten.unsqueeze_copy.default, ) + targeted_ops = { + exir_ops.edge.aten.squeeze_copy.dims, + exir_ops.edge.aten.unsqueeze_copy.default, + } + def call_operator(self, op, args, kwargs, meta): if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) diff 
--git a/backends/arm/_passes/convert_to_clamp_pass.py b/backends/arm/_passes/convert_to_clamp_pass.py index 6273759aa55..31ede8a84e0 100644 --- a/backends/arm/_passes/convert_to_clamp_pass.py +++ b/backends/arm/_passes/convert_to_clamp_pass.py @@ -6,11 +6,9 @@ from typing import Set, Tuple, Type from executorch.backends.arm._passes import ArmOpTargetedPass - from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( QuantizeClampArgumentsPass, ) - from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -34,6 +32,8 @@ class ConvertToClampPass(ArmOpTargetedPass): target_ops = edge_operators check_allowed_to_transform = True + targeted_ops = edge_operators + def call_operator(self, op, args, kwargs, meta): if op not in self.target_ops or not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_acosh_pass.py b/backends/arm/_passes/decompose_acosh_pass.py index 3c2cac45e75..018097bfe42 100644 --- a/backends/arm/_passes/decompose_acosh_pass.py +++ b/backends/arm/_passes/decompose_acosh_pass.py @@ -38,6 +38,8 @@ class DecomposeAcoshPass(ArmOpTargetedPass): } target_ops = (edge_acosh_op,) + targeted_ops = {edge_acosh_op} + def call_operator(self, op, args, kwargs, meta, updated=False): if op not in self.target_ops: diff --git a/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py b/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py index 07fd5c9e358..c4b6492674b 100644 --- a/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py +++ b/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py @@ -7,7 +7,6 @@ from typing import Set, Type import torch - from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.decompose_avg_pool2d_pass import ( DecomposeAvgPool2dPass, @@ -55,6 +54,8 @@ class DecomposeAdaptiveAvgPool2dPass(ArmOpTargetedPass): target_ops = edge_ops + 
aten_ops check_allowed_to_transform = True + targeted_ops = {*edge_ops, *aten_ops} + @staticmethod def _is_static_dim(dim) -> bool: return not isinstance(dim, torch.SymInt) diff --git a/backends/arm/_passes/decompose_add_sub_alpha_pass.py b/backends/arm/_passes/decompose_add_sub_alpha_pass.py index 30903fbd3d8..c94251df82c 100644 --- a/backends/arm/_passes/decompose_add_sub_alpha_pass.py +++ b/backends/arm/_passes/decompose_add_sub_alpha_pass.py @@ -61,6 +61,8 @@ class DecomposeAddSubAlphaPass(ArmOpTargetedPass): _passes_required_after: Set[Type[ExportPass]] = set() target_ops = _ADD_OPS + _SUB_OPS + targeted_ops = {*_ADD_OPS, *_SUB_OPS} + def call_operator(self, op, args, kwargs, meta, updated: bool | None = False): if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta, updated) diff --git a/backends/arm/_passes/decompose_addmm_pass.py b/backends/arm/_passes/decompose_addmm_pass.py index d198e1a3b64..8ecad73331d 100644 --- a/backends/arm/_passes/decompose_addmm_pass.py +++ b/backends/arm/_passes/decompose_addmm_pass.py @@ -6,7 +6,6 @@ from typing import Set, Type import torch - from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass @@ -51,6 +50,8 @@ class DecomposeAddmmPass(ArmOpTargetedPass): } target_ops = (edge_addmm, aten_addmm) + targeted_ops = {edge_addmm, aten_addmm} + def call_operator(self, op, args, kwargs, meta): if op not in self.target_ops or not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_as_strided_copy_pass.py b/backends/arm/_passes/decompose_as_strided_copy_pass.py index c8c2a200bd8..fec9a234b93 100644 --- a/backends/arm/_passes/decompose_as_strided_copy_pass.py +++ b/backends/arm/_passes/decompose_as_strided_copy_pass.py @@ -6,7 +6,6 @@ from typing import Dict, 
Optional, Set, Tuple, Type import torch - from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm.common.as_strided_utils import ( contiguous_strides, diff --git a/backends/arm/_passes/decompose_asin_and_acos_pass.py b/backends/arm/_passes/decompose_asin_and_acos_pass.py index 5e0cfd66c32..a4184083971 100644 --- a/backends/arm/_passes/decompose_asin_and_acos_pass.py +++ b/backends/arm/_passes/decompose_asin_and_acos_pass.py @@ -9,7 +9,6 @@ from typing import Set, Type import torch - from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.convert_full_like_to_full_pass import ( ConvertFullLikeToFullPass, @@ -73,6 +72,8 @@ class DecomposeAsinAndAcosPass(ArmOpTargetedPass): } target_ops = edge_asin_op + edge_acos_op + targeted_ops = {*edge_asin_op, *edge_acos_op} + def _build_polynomial( self, coefficients: list[float], variable: torch.Tensor, meta: dict[str, str] ) -> torch.Tensor: diff --git a/backends/arm/_passes/decompose_asinh_pass.py b/backends/arm/_passes/decompose_asinh_pass.py index 5f31c5efedc..b1916b37da6 100644 --- a/backends/arm/_passes/decompose_asinh_pass.py +++ b/backends/arm/_passes/decompose_asinh_pass.py @@ -38,6 +38,8 @@ class DecomposeAsinhPass(ArmOpTargetedPass): } target_ops = edge_asinh_op + targeted_ops = {*edge_asinh_op} + def call_operator(self, op, args, kwargs, meta): if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_atan_pass.py b/backends/arm/_passes/decompose_atan_pass.py index cd33504c972..345c4ad8754 100644 --- a/backends/arm/_passes/decompose_atan_pass.py +++ b/backends/arm/_passes/decompose_atan_pass.py @@ -51,6 +51,8 @@ class DecomposeAtanPass(ArmOpTargetedPass): } target_ops = (edge_atan,) + targeted_ops = {edge_atan} + def _rational_approximation(self, z, ops, meta): """Creates a (2,1) Padé approximation for atan(x) on [-1, 1].""" diff --git 
a/backends/arm/_passes/decompose_atanh_pass.py b/backends/arm/_passes/decompose_atanh_pass.py index c542b94f30d..6e7b7122d4b 100644 --- a/backends/arm/_passes/decompose_atanh_pass.py +++ b/backends/arm/_passes/decompose_atanh_pass.py @@ -48,6 +48,8 @@ class DecomposeAtanhPass(ArmOpTargetedPass): } target_ops = (edge_atanh,) + targeted_ops = {edge_atanh} + def call_operator(self, op, args, kwargs, meta): if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta, updated=False) diff --git a/backends/arm/_passes/decompose_avg_pool2d_pass.py b/backends/arm/_passes/decompose_avg_pool2d_pass.py index 51f2afe8351..1ffa6ecc02b 100644 --- a/backends/arm/_passes/decompose_avg_pool2d_pass.py +++ b/backends/arm/_passes/decompose_avg_pool2d_pass.py @@ -101,6 +101,8 @@ class DecomposeAvgPool2dPass(ArmOpTargetedPass): target_ops = edge_avg_pool2d + aten_avg_pool2d check_allowed_to_transform = True + targeted_ops = {*edge_avg_pool2d, *aten_avg_pool2d} + def call_operator(self, op, args, kwargs, meta): if op not in self.target_ops or not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_cosh_pass.py b/backends/arm/_passes/decompose_cosh_pass.py index 96c73b6cdf2..370560b998c 100644 --- a/backends/arm/_passes/decompose_cosh_pass.py +++ b/backends/arm/_passes/decompose_cosh_pass.py @@ -36,6 +36,8 @@ class DecomposeCoshPass(ArmOpTargetedPass): } target_ops = (edge_cosh,) + targeted_ops = {edge_cosh} + def call_operator(self, op, args, kwargs, meta, updated=False): if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta, updated) diff --git a/backends/arm/_passes/decompose_cosine_similarity_pass.py b/backends/arm/_passes/decompose_cosine_similarity_pass.py index b9e11a68174..810e5a103ca 100644 --- a/backends/arm/_passes/decompose_cosine_similarity_pass.py +++ b/backends/arm/_passes/decompose_cosine_similarity_pass.py @@ -10,7 +10,6 @@ from 
executorch.backends.arm._passes.convert_full_like_to_full_pass import ( ConvertFullLikeToFullPass, ) - from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass from executorch.backends.arm._passes.decompose_sum_pass import DecomposeSumPass from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass @@ -45,6 +44,8 @@ class DecomposeCosineSimilarityPass(ArmOpTargetedPass): target_ops = torch_cosine_similarity check_allowed_to_transform = True + targeted_ops = {*torch_cosine_similarity} + def call_operator(self, op, args, kwargs, meta): if op not in self.target_ops or not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_div_pass.py b/backends/arm/_passes/decompose_div_pass.py index be4d91cd30c..da45bd4ce72 100644 --- a/backends/arm/_passes/decompose_div_pass.py +++ b/backends/arm/_passes/decompose_div_pass.py @@ -42,6 +42,8 @@ class DecomposeDivPass(ArmOpTargetedPass): _passes_required_after: Set[Type[ExportPass]] = {InsertTableOpsPass} target_ops = edge_div_ops + aten_div_ops + targeted_ops = {*edge_div_ops, *aten_div_ops} + def call_operator(self, op, args, kwargs, meta): if op not in self.target_ops or not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_div_tensor_mode.py b/backends/arm/_passes/decompose_div_tensor_mode.py index cc5440b4e5b..56ee697e446 100644 --- a/backends/arm/_passes/decompose_div_tensor_mode.py +++ b/backends/arm/_passes/decompose_div_tensor_mode.py @@ -60,6 +60,8 @@ class DecomposeDivTensorModePass(ArmOpTargetedPass): target_ops = edge_div_mode_ops + aten_div_mode_ops check_allowed_to_transform = True + targeted_ops = {*edge_div_mode_ops, *aten_div_mode_ops} + def call_operator(self, op, args, kwargs, meta): if op not in self.target_ops or not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta) diff --git 
a/backends/arm/_passes/decompose_elu_pass.py b/backends/arm/_passes/decompose_elu_pass.py index 5f94968ad79..d0863026c8e 100644 --- a/backends/arm/_passes/decompose_elu_pass.py +++ b/backends/arm/_passes/decompose_elu_pass.py @@ -123,6 +123,8 @@ class DecomposeEluPass(ArmOpTargetedPass): _passes_required_after: Set[Type[ExportPass]] = set() target_ops = edge_elu_family_ops + targeted_ops = {*edge_elu_family_ops} + def call_operator(self, op, args, kwargs, meta): if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta, updated=False) diff --git a/backends/arm/_passes/decompose_expm1_pass.py b/backends/arm/_passes/decompose_expm1_pass.py index 6898b9fafb2..1fb3377ea02 100644 --- a/backends/arm/_passes/decompose_expm1_pass.py +++ b/backends/arm/_passes/decompose_expm1_pass.py @@ -89,6 +89,8 @@ class DecomposeExpm1Pass(ArmOpTargetedPass): } target_ops = edge_expm1_ops + targeted_ops = {*edge_expm1_ops} + def call_operator(self, op, args, kwargs, meta): if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta, updated=False) diff --git a/backends/arm/_passes/decompose_floor_divide_pass.py b/backends/arm/_passes/decompose_floor_divide_pass.py index d8f451f8af6..2106ea472b2 100644 --- a/backends/arm/_passes/decompose_floor_divide_pass.py +++ b/backends/arm/_passes/decompose_floor_divide_pass.py @@ -55,6 +55,8 @@ class DecomposeFloorDividePass(ArmOpTargetedPass): _passes_required_after: Set[Type[ExportPass]] = {DecomposeDivTensorModePass} target_ops = edge_floor_divide_ops + aten_floor_divide_ops + targeted_ops = {*edge_floor_divide_ops, *aten_floor_divide_ops} + def call_operator(self, op, args, kwargs, meta): if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta, updated=False) diff --git a/backends/arm/_passes/decompose_gelu_pass.py b/backends/arm/_passes/decompose_gelu_pass.py index 85f0b77df21..af24dfe0f4a 100644 --- a/backends/arm/_passes/decompose_gelu_pass.py +++ 
b/backends/arm/_passes/decompose_gelu_pass.py @@ -90,6 +90,8 @@ class DecomposeGeluPass(ArmOpTargetedPass): } target_ops = torch_gelu + edge_gelu + targeted_ops = {*torch_gelu, *edge_gelu} + def call_operator(self, op, args, kwargs, meta): if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_glu_pass.py b/backends/arm/_passes/decompose_glu_pass.py index 5927174a776..2d8a6e20f90 100644 --- a/backends/arm/_passes/decompose_glu_pass.py +++ b/backends/arm/_passes/decompose_glu_pass.py @@ -45,6 +45,8 @@ class DecomposeGluPass(ArmOpTargetedPass): _passes_required_after: Set[Type[ExportPass]] = {InsertTableOpsPass} target_ops = (edge_glu, aten_glu) + targeted_ops = {edge_glu, aten_glu} + def call_operator(self, op, args, kwargs, meta): if op not in self.target_ops or not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_grouped_conv_pass.py b/backends/arm/_passes/decompose_grouped_conv_pass.py index 7a8b744d9e3..6158428ada5 100644 --- a/backends/arm/_passes/decompose_grouped_conv_pass.py +++ b/backends/arm/_passes/decompose_grouped_conv_pass.py @@ -212,7 +212,6 @@ def _get_meta_copy( # Get quantization params of the weights and slice them. w_qarg = new_qparams[1] if DecomposeGroupedConvPass._is_per_channel_qparams(w_qarg): - # For transpose conv, axis=1 corresponds to output channels and # does not align with grouped slicing. 
class DecomposeConvWithInt16ActivationPass(ArmPass):
    """Split a biased convolution with int16 activation into conv + add.

    A convolution whose input is quantized to int16 and that carries a bias
    is rewritten as a convolution without bias followed by an explicit
    addition of the bias.

    The TOSA op requires the bias to be int48, which is hard to represent
    in torch. Instead, the bias-free convolution output is rescaled to
    int32, the bias (reshaped from 1D to [1, C, 1, ...] so it broadcasts
    along the channel dimension) is added, and the sum is rescaled to the
    output dtype.

    """

    _passes_required_after: Set[Type[ExportPass]] = set()
    targeted_ops = {exir_ops.edge.aten.convolution.default}

    def __init__(self) -> None:
        super().__init__()

    def bias_view_shape(
        self, bias: torch.Tensor, activation_rank: int
    ) -> Sequence[int]:
        """Return the shape ``[1, C, 1, ...]`` of rank ``activation_rank``.

        ``C`` is ``bias.shape[0]``; the shape lets the bias broadcast over
        the channels of the convolution output.
        """
        return [1, bias.shape[0], *([1] * (activation_rank - 2))]

    def call_operator(self, op, args, kwargs, meta):
        """Decompose an int16-activation convolution with bias.

        The op is forwarded unchanged unless it is
        ``edge.aten.convolution.default``, the TOSA spec supports integer,
        a bias is present, the activation rank is 4 or 5, and the qparams
        of input 0 have dtype int16.

        Raises:
            ValueError: the TOSA int16 extension is not supported.
            NotImplementedError: the output qparams dtype is not int16.
        """
        if op != exir_ops.edge.aten.convolution.default:
            return super().call_operator(op, args, kwargs, meta)

        tosa_spec = get_context_spec()
        if not tosa_spec.support_integer():
            return super().call_operator(op, args, kwargs, meta)

        # Nothing to split off when the convolution has no bias.
        bias_arg = args[2]
        if bias_arg is None:
            return super().call_operator(op, args, kwargs, meta)

        activation_rank = args[0].data.dim()

        # Check input qparams dtype instead of raw tensor dtype, since the
        # tensor may have been rescaled to int32 by earlier passes while the
        # quantization parameters still indicate the original int16 dtype.
        input_qparams = meta.data.get("input_qparams", {})
        if 0 not in input_qparams:
            return super().call_operator(op, args, kwargs, meta)
        activation_dtype = input_qparams[0].dtype

        if activation_rank not in (4, 5) or activation_dtype != torch.int16:
            return super().call_operator(op, args, kwargs, meta)

        if not tosa_spec.support_extension("int16"):
            raise ValueError(
                "int16 activation for convolution requires TOSA int16 extension"
            )

        # Validate and read all quantization parameters BEFORE emitting any
        # node, so an unsupported configuration fails without leaving a
        # half-built decomposition behind, and so the reads do not depend on
        # how independent ``meta.copy()`` is from ``meta``.
        output_qparams = cast(QuantArgs, meta.data["output_qparams"][0])
        output_dtype = output_qparams.dtype
        if output_dtype != torch.int16:
            raise NotImplementedError(
                f"Decomposition to conv+add only implemented for activation of int16 type, not for {output_dtype}"
            )

        # The bias is assumed to be quantized with the same quantization
        # parameters as the output of the convolution.
        bias_qparams = cast(QuantArgs, input_qparams[2])
        if bias_qparams.per_channel:
            bias_scale = bias_qparams.get_scale_per_channel()
        else:
            bias_scale = [bias_qparams.get_scale_per_tensor()]

        conv_rescale_factors = [1.0] * len(bias_scale)
        final_output_scale = [scale / output_qparams.scale for scale in bias_scale]

        # Split up into convolution + bias: first the convolution without bias.
        no_bias_args = list(args)
        no_bias_args[2] = None
        convolution = super().call_operator(op, tuple(no_bias_args), kwargs, meta)

        # Copy of the meta without the qparams, to be used with the new nodes.
        new_meta = meta.copy()
        new_meta.data.pop("output_qparams", None)
        new_meta.data.pop("input_qparams", None)

        # Reshape the bias to the rank of the convolution output so the
        # addition broadcasts over the channels.
        channel_bias = super().call_operator(
            exir_ops.edge.aten.view_copy.default,
            (bias_arg, self.bias_view_shape(bias_arg.data, activation_rank)),
            {},
            new_meta,
        )

        # The conv will get the output int48 scaled to int32 in the
        # serialization step. To be able to add the bias, the conv output is
        # first rescaled to int32; the resulting int32 sum is then rescaled
        # back to the output dtype.
        conv_output = super().call_operator(
            exir_ops.backend.tosa.RESCALE.default,
            (convolution, torch.int32, conv_rescale_factors, 0, 0),
            {},
            new_meta,
        )

        biased_output = super().call_operator(
            exir_ops.edge.aten.add.Tensor,
            (conv_output, channel_bias),
            {},
            new_meta,
        )

        return super().call_operator(
            exir_ops.backend.tosa.RESCALE.default,
            (
                biased_output,
                output_dtype,
                final_output_scale,
                0,
                0,
            ),
            {},
            new_meta,
        )
DecomposeLeakyReLUPass(ArmOpTargetedPass): target_ops = edge_ops + torch_ops check_allowed_to_transform = True + targeted_ops = {*edge_ops, *torch_ops} + def call_operator(self, op, args, kwargs, meta): if op not in self.target_ops or not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_linalg_vector_norm_pass.py b/backends/arm/_passes/decompose_linalg_vector_norm_pass.py index 1604d861030..4d13a083a18 100644 --- a/backends/arm/_passes/decompose_linalg_vector_norm_pass.py +++ b/backends/arm/_passes/decompose_linalg_vector_norm_pass.py @@ -43,6 +43,8 @@ class DecomposeLinalgVectorNormPass(ArmOpTargetedPass): target_ops = torch_linalg_vector_norm check_allowed_to_transform = True + targeted_ops = torch_linalg_vector_norm + def call_operator(self, op, args, kwargs, meta): if op not in self.target_ops or not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_linear_pass.py b/backends/arm/_passes/decompose_linear_pass.py index b11c6ac6ab3..fdf1c5e8a5e 100644 --- a/backends/arm/_passes/decompose_linear_pass.py +++ b/backends/arm/_passes/decompose_linear_pass.py @@ -32,6 +32,8 @@ class DecomposeLinearPass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = {InsertRescaleInt32Pass} + targeted_ops = {exir_ops.edge.aten.linear.default} + def call(self, graph_module): modified = False for node in graph_module.graph.nodes: diff --git a/backends/arm/_passes/decompose_logit_pass.py b/backends/arm/_passes/decompose_logit_pass.py index 9f9f4744fd0..fd1402e341b 100644 --- a/backends/arm/_passes/decompose_logit_pass.py +++ b/backends/arm/_passes/decompose_logit_pass.py @@ -6,7 +6,6 @@ from typing import Set, Type import torch - from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from 
executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass diff --git a/backends/arm/_passes/decompose_masked_fill_pass.py b/backends/arm/_passes/decompose_masked_fill_pass.py index dfb85da7742..f44141c7fbd 100644 --- a/backends/arm/_passes/decompose_masked_fill_pass.py +++ b/backends/arm/_passes/decompose_masked_fill_pass.py @@ -7,7 +7,6 @@ from typing import Set, Type import torch - from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.convert_full_like_to_full_pass import ( ConvertFullLikeToFullPass, @@ -45,6 +44,8 @@ class DecomposeMaskedFillPass(ArmOpTargetedPass): _passes_required_after: Set[Type[ExportPass]] = {ConvertFullLikeToFullPass} target_ops = aten_ops + edge_ops + targeted_ops = {*edge_ops, *aten_ops} + def call_operator(self, op, args, kwargs, meta, updated=False): if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta, updated) diff --git a/backends/arm/_passes/decompose_maxpool2d_with_dilation_pass.py b/backends/arm/_passes/decompose_maxpool2d_with_dilation_pass.py index 7729b755113..82b31e90768 100644 --- a/backends/arm/_passes/decompose_maxpool2d_with_dilation_pass.py +++ b/backends/arm/_passes/decompose_maxpool2d_with_dilation_pass.py @@ -8,7 +8,6 @@ from typing import Set, Type import torch - from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.size_adjust_input_pass import SizeAdjustInputPass from executorch.exir.dialects._ops import ops as exir_ops @@ -57,6 +56,8 @@ class DecomposeMaxPool2dPass(ArmOpTargetedPass): } target_ops = EDGE_MAXPOOL2D + targeted_ops = set(EDGE_MAXPOOL2D) + def call_operator(self, op, args, kwargs, meta): # Only intercept EXIR edge max_pool2d ops if op not in self.target_ops: diff --git a/backends/arm/_passes/decompose_meandim_pass.py b/backends/arm/_passes/decompose_meandim_pass.py index e1175d5ba1b..792b63af7b7 100644 --- a/backends/arm/_passes/decompose_meandim_pass.py +++ 
b/backends/arm/_passes/decompose_meandim_pass.py @@ -102,6 +102,13 @@ class DecomposeMeanDimPass(ArmOpTargetedPass): ) check_allowed_to_transform = True + targeted_ops = { + exir_ops.edge.aten.mean.dim, + torch.ops.aten.mean.dim, + exir_ops.edge.aten.mean.default, + torch.ops.aten.mean.default, + } + def __init__(self, graph_module, tosa_spec, *args, **kwargs): super().__init__(*args, **kwargs) self._graph_module = graph_module diff --git a/backends/arm/_passes/decompose_ne_pass.py b/backends/arm/_passes/decompose_ne_pass.py index 4dfcf6ad934..c6528cf6004 100644 --- a/backends/arm/_passes/decompose_ne_pass.py +++ b/backends/arm/_passes/decompose_ne_pass.py @@ -59,6 +59,8 @@ class DecomposeNotEqualPass(ArmOpTargetedPass): _passes_required_after: Set[Type[ExportPass]] = set() target_ops = edge_ne_ops + aten_ne_ops + targeted_ops = {*edge_ne_ops, *aten_ne_ops} + def call_operator(self, op, args, kwargs, meta): if op not in self.target_ops or not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_remainder_pass.py b/backends/arm/_passes/decompose_remainder_pass.py index af22cad1624..154367bfb78 100644 --- a/backends/arm/_passes/decompose_remainder_pass.py +++ b/backends/arm/_passes/decompose_remainder_pass.py @@ -51,6 +51,8 @@ class DecomposeRemainderPass(ArmOpTargetedPass): _passes_required_after: Set[Type[ExportPass]] = {DecomposeDivTensorModePass} target_ops = tuple(_decomposition_ops) + targeted_ops = set(_decomposition_ops.keys()) + def call_operator(self, op, args, kwargs, meta, updated=False): if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta, updated) diff --git a/backends/arm/_passes/decompose_select_scatter_pass.py b/backends/arm/_passes/decompose_select_scatter_pass.py index 129e9f05961..0b8bc76c20e 100644 --- a/backends/arm/_passes/decompose_select_scatter_pass.py +++ b/backends/arm/_passes/decompose_select_scatter_pass.py @@ -6,7 +6,6 @@ from 
typing import Set, Type import torch - from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.convert_int64_const_ops_to_int32 import ( ConvertInt64ConstOpsToInt32Pass, @@ -67,6 +66,8 @@ class DecomposeSelectScatterPass(ArmOpTargetedPass): } target_ops = edge_scatter_ops + aten_scatter_ops + targeted_ops = {*edge_scatter_ops, *aten_scatter_ops} + def call_operator(self, op, args, kwargs, meta): if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta, updated=False) diff --git a/backends/arm/_passes/decompose_sign_pass.py b/backends/arm/_passes/decompose_sign_pass.py index 8f7fda8729b..299c570659b 100644 --- a/backends/arm/_passes/decompose_sign_pass.py +++ b/backends/arm/_passes/decompose_sign_pass.py @@ -6,7 +6,6 @@ from typing import Set, Type import torch - from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -52,6 +51,8 @@ class DecomposeSignPass(ArmOpTargetedPass): _passes_required_after: Set[Type[ExportPass]] = set() target_ops = (edge_sign, aten_sign) + targeted_ops = {edge_sign, aten_sign} + def call_operator(self, op, args, kwargs, meta): if op not in self.target_ops or not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_sinh_pass.py b/backends/arm/_passes/decompose_sinh_pass.py index 053b378af83..06462818c1b 100644 --- a/backends/arm/_passes/decompose_sinh_pass.py +++ b/backends/arm/_passes/decompose_sinh_pass.py @@ -41,6 +41,8 @@ class DecomposeSinhPass(ArmOpTargetedPass): } target_ops = (edge_sinh,) + targeted_ops = {edge_sinh} + def call_operator(self, op, args, kwargs, meta): if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_slice_scatter_pass.py b/backends/arm/_passes/decompose_slice_scatter_pass.py index 
edf030f9701..5852561df9b 100644 --- a/backends/arm/_passes/decompose_slice_scatter_pass.py +++ b/backends/arm/_passes/decompose_slice_scatter_pass.py @@ -6,7 +6,6 @@ from typing import Set, Type import torch - from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.backends.arm._passes.accumulate_index_put_pass import ( AccumulateIndexPutPass, @@ -73,6 +72,8 @@ class DecomposeSliceScatterPass(ArmOpTargetedPass): } target_ops = edge_slice_scatter_ops + aten_slice_scatter_ops + targeted_ops = {*edge_slice_scatter_ops, *aten_slice_scatter_ops} + def call_operator(self, op, args, kwargs, meta): if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_softmax_pass.py b/backends/arm/_passes/decompose_softmax_pass.py index d30137c0460..29ed7346cc4 100644 --- a/backends/arm/_passes/decompose_softmax_pass.py +++ b/backends/arm/_passes/decompose_softmax_pass.py @@ -79,6 +79,8 @@ class DecomposeSoftmaxPass(ArmOpTargetedPass): } target_ops = torch_softmax + edge_softmax + targeted_ops = {*torch_softmax, *edge_softmax} + def __init__(self, skip_safe_softmax: bool = False, **kwargs): super().__init__(**kwargs) self._skip_safe_softmax = skip_safe_softmax diff --git a/backends/arm/_passes/decompose_sqrt_pass.py b/backends/arm/_passes/decompose_sqrt_pass.py index ce5a5b6d2a4..7040ca5b488 100644 --- a/backends/arm/_passes/decompose_sqrt_pass.py +++ b/backends/arm/_passes/decompose_sqrt_pass.py @@ -31,6 +31,8 @@ class DecomposeSqrtPass(ArmOpTargetedPass): _passes_required_after: Set[Type[ExportPass]] = {InsertTableOpsPass} target_ops = edge_sqrt_ops + aten_sqrt_ops + targeted_ops = {*edge_sqrt_ops, *aten_sqrt_ops} + def call_operator(self, op, args, kwargs, meta): """Decomposes `sqrt(x)` into `pow(x, 0.5)` for backend support.""" diff --git a/backends/arm/_passes/decompose_sum_pass.py b/backends/arm/_passes/decompose_sum_pass.py index e134ea6abc7..a3b3e3f4f37 100644 --- 
a/backends/arm/_passes/decompose_sum_pass.py +++ b/backends/arm/_passes/decompose_sum_pass.py @@ -49,6 +49,11 @@ class DecomposeSumPass(ArmOpTargetedPass): torch.ops.aten.sum.dim_IntList, ) + targeted_ops = { + exir_ops.edge.aten.sum.dim_IntList, + torch.ops.aten.sum.dim_IntList, + } + def call_operator(self, op, args, kwargs, meta): if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_tan_pass.py b/backends/arm/_passes/decompose_tan_pass.py index 2d655a9937d..8f212fa1940 100644 --- a/backends/arm/_passes/decompose_tan_pass.py +++ b/backends/arm/_passes/decompose_tan_pass.py @@ -19,6 +19,8 @@ class DecomposeTanPass(ArmOpTargetedPass): _passes_required_after: Set[Type[ExportPass]] = {DecomposeDivPass} target_ops = (edge_tan_op,) + targeted_ops = {edge_tan_op} + def call_operator(self, op, args, kwargs, meta, updated=False): if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta, updated) diff --git a/backends/arm/_passes/decompose_tril_pass.py b/backends/arm/_passes/decompose_tril_pass.py index 9108208e73d..8bce88d32e8 100644 --- a/backends/arm/_passes/decompose_tril_pass.py +++ b/backends/arm/_passes/decompose_tril_pass.py @@ -56,6 +56,8 @@ class DecomposeTrilPass(ArmOpTargetedPass): _passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOTPass} target_ops = (torch.ops.aten.tril.default,) + targeted_ops = {torch.ops.aten.tril.default} + def call_operator(self, op, args, kwargs, meta): if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/decompose_unfold_to_gather_pass.py b/backends/arm/_passes/decompose_unfold_to_gather_pass.py index 950290b3b83..26f30f9c264 100644 --- a/backends/arm/_passes/decompose_unfold_to_gather_pass.py +++ b/backends/arm/_passes/decompose_unfold_to_gather_pass.py @@ -8,7 +8,6 @@ from typing import Set, Type import torch - from executorch.backends.arm._passes import 
ArmOpTargetedPass from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import ( ReplaceScalarWithTensorByProfilePass, diff --git a/backends/arm/_passes/decompose_var_pass.py b/backends/arm/_passes/decompose_var_pass.py index 90ea80b6b47..264811b9ea7 100644 --- a/backends/arm/_passes/decompose_var_pass.py +++ b/backends/arm/_passes/decompose_var_pass.py @@ -63,6 +63,12 @@ class DecomposeVarPass(ArmOpTargetedPass): ) check_allowed_to_transform = True + targeted_ops = { + exir_ops.edge.aten.var.correction, + torch.ops.aten.var.correction, + torch.ops.aten.var.dim, + } + def call_operator(self, op, args, kwargs, meta): if op not in self.target_ops or not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/insert_rescales_pass.py b/backends/arm/_passes/insert_rescales_pass.py index 45374c12c3b..4575351e7cc 100644 --- a/backends/arm/_passes/insert_rescales_pass.py +++ b/backends/arm/_passes/insert_rescales_pass.py @@ -15,7 +15,6 @@ from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( get_output_qparams, ) - from executorch.backends.arm._passes.quant_args import QuantArgs from executorch.backends.arm.constants import DQ_OPS, Q_OPS from executorch.exir.dialects._ops import ops as exir_ops @@ -35,6 +34,8 @@ class InsertRescalePass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = set() + targeted_ops = {*DQ_OPS} + def _ensure_uint8_io_only(self, graph_module: GraphModule) -> None: """Ensure uint8 tensors only appear at IO boundaries. @@ -70,7 +71,6 @@ def _ensure_uint8_io_only(self, graph_module: GraphModule) -> None: f"Found internal uint8 tensor at node {node.name} " f"({node.target}). Uint8 is only allowed at IO boundaries." 
) - def fold_dq_q_to_rescale(self, node: Node, user: Node, graph_module: GraphModule): dq_args = QuantArgs.from_operator(node.target, node.args) q_args = QuantArgs.from_operator(user.target, user.args) @@ -141,7 +141,7 @@ class InsertRescaleInt32Pass(ArmPass): # decomposition. _passes_required_after: Set[Type[ExportPass]] = {DecomposeSumPass} - included_targets = [ + targeted_ops = { exir_ops.edge.aten.abs.default, exir_ops.edge.aten.add.Tensor, exir_ops.edge.aten.eq.Tensor, @@ -154,7 +154,9 @@ class InsertRescaleInt32Pass(ArmPass): exir_ops.edge.aten.mul.Tensor, exir_ops.edge.aten.sub.Tensor, exir_ops.edge.aten.sum.dim_IntList, - ] + } + + included_targets = list(targeted_ops) def _int32_qargs(self, s): """Helper creator function for INT32-based QuantArgs.""" @@ -609,8 +611,12 @@ def _get_output_qparams_map(self, node: Node): def _rescale_cond_submodules(self, node: Node, graph_module: GraphModule) -> bool: modified = False - if_graph: GraphModule = cast(GraphModule, graph_module.get_submodule(node.args[1].target)) # type: ignore - else_graph: GraphModule = cast(GraphModule, graph_module.get_submodule(node.args[2].target)) # type: ignore + if_graph: GraphModule = cast( + GraphModule, graph_module.get_submodule(node.args[1].target) # type: ignore[union-attr, arg-type] + ) + else_graph: GraphModule = cast( + GraphModule, graph_module.get_submodule(node.args[2].target) # type: ignore[union-attr, arg-type] + ) input_qparams_map = self._get_input_qparams_map(node, 3) if input_qparams_map: modified |= self._rescale_submodule_inputs(if_graph, input_qparams_map) @@ -624,8 +630,12 @@ def _rescale_cond_submodules(self, node: Node, graph_module: GraphModule) -> boo def _rescale_while_submodules(self, node: Node, graph_module: GraphModule): modified = False - cond_graph: GraphModule = cast(GraphModule, graph_module.get_submodule(node.args[0].target)) # type: ignore - body_graph: GraphModule = cast(GraphModule, graph_module.get_submodule(node.args[1].target)) # type: ignore + 
cond_graph: GraphModule = cast( + GraphModule, graph_module.get_submodule(node.args[0].target) # type: ignore[union-attr, arg-type] + ) + body_graph: GraphModule = cast( + GraphModule, graph_module.get_submodule(node.args[1].target) # type: ignore[union-attr, arg-type] + ) input_qparams_map = self._get_input_qparams_map(node, 2) if input_qparams_map: diff --git a/backends/arm/_passes/insert_table_ops.py b/backends/arm/_passes/insert_table_ops.py index 10b85149dad..582a75e3418 100644 --- a/backends/arm/_passes/insert_table_ops.py +++ b/backends/arm/_passes/insert_table_ops.py @@ -12,12 +12,9 @@ from executorch.backends.arm._passes.arm_pass_utils import create_node from executorch.backends.arm._passes.quant_args import QuantArgs from executorch.backends.transforms.utils import create_constant_placeholder - from executorch.exir import ExportedProgram - from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload - from executorch.exir.pass_base import ExportPass, PassResult from torch.export.graph_signature import InputKind from torch.fx import GraphModule diff --git a/backends/arm/_passes/remove_getitem_pass.py b/backends/arm/_passes/remove_getitem_pass.py index 3ce157d3fd8..122a8330203 100644 --- a/backends/arm/_passes/remove_getitem_pass.py +++ b/backends/arm/_passes/remove_getitem_pass.py @@ -7,8 +7,14 @@ from executorch.backends.arm._passes.arm_pass import ArmPass from executorch.backends.transforms import remove_getitem_op +from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass class RemoveGetItemPass(ArmPass, remove_getitem_op.RemoveGetItemPass): _passes_required_after: Set[Type[ExportPass]] = set() + + targeted_ops = { + exir_ops.edge.aten.max_pool2d_with_indices.default, + exir_ops.edge.aten.max.dim, + } diff --git a/backends/arm/_passes/remove_noop_pass.py b/backends/arm/_passes/remove_noop_pass.py index 5fafc848003..505d34bc7f7 100644 --- 
a/backends/arm/_passes/remove_noop_pass.py +++ b/backends/arm/_passes/remove_noop_pass.py @@ -9,7 +9,6 @@ from typing import Set, Type from executorch.backends.arm._passes import ArmOpTargetedPass - from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -28,6 +27,14 @@ class RemoveNoopPass(ArmOpTargetedPass): exir_ops.edge.aten.detach_copy.default, ) + targeted_ops = { + exir_ops.edge.dim_order_ops._clone_dim_order.default, + exir_ops.edge.dim_order_ops._to_dim_order_copy.default, + exir_ops.edge.aten.alias_copy.default, + exir_ops.edge.aten.copy.default, + exir_ops.edge.aten.detach_copy.default, + } + def call_operator(self, op, args, kwargs, meta): if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/rewrite_bool_bitwise_to_logical_pass.py b/backends/arm/_passes/rewrite_bool_bitwise_to_logical_pass.py index 962bdbbaf6e..372c6f73ab0 100644 --- a/backends/arm/_passes/rewrite_bool_bitwise_to_logical_pass.py +++ b/backends/arm/_passes/rewrite_bool_bitwise_to_logical_pass.py @@ -34,6 +34,8 @@ class RewriteBoolBitwiseToLogicalPass(ArmOpTargetedPass): } target_ops = tuple(_TARGET_TO_LOGICAL) + targeted_ops = set(_TARGET_TO_LOGICAL.keys()) + def call_operator(self, op, args, kwargs, meta): if op not in self.target_ops: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/rewrite_le_lt_to_ge_gt_pass.py b/backends/arm/_passes/rewrite_le_lt_to_ge_gt_pass.py index c73279e65d0..890b72e03d4 100644 --- a/backends/arm/_passes/rewrite_le_lt_to_ge_gt_pass.py +++ b/backends/arm/_passes/rewrite_le_lt_to_ge_gt_pass.py @@ -6,7 +6,6 @@ from typing import Set, Type import torch - from executorch.backends.arm._passes import ArmOpTargetedPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -26,6 +25,8 @@ class RewriteLeLtToGeGtPass(ArmOpTargetedPass): target_ops = tuple(OP_MAP) 
check_allowed_to_transform = True + targeted_ops = {*OP_MAP} + def call_operator(self, op, args, kwargs, meta): if not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/size_adjust_input_pass.py b/backends/arm/_passes/size_adjust_input_pass.py index 1c331b9c329..36d8a2166bd 100644 --- a/backends/arm/_passes/size_adjust_input_pass.py +++ b/backends/arm/_passes/size_adjust_input_pass.py @@ -216,6 +216,8 @@ class SizeAdjustInputPass(ArmPass): RewriteMaxPool2dPass, } + targeted_ops = set(valid_operators) + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: graph = graph_module.graph modified = False diff --git a/examples/models/BUCK b/examples/models/BUCK index a2b6789a95e..c6c8ac21f14 100644 --- a/examples/models/BUCK +++ b/examples/models/BUCK @@ -52,3 +52,12 @@ fbcode_target(_kind = python_library, "//caffe2:torch", ], ) + +fbcode_target(_kind = python_library, + name = "mlperf_tiny", + srcs = glob(["mlperf_tiny/*.py"]), + deps = [ + "//caffe2:torch", + "//executorch/examples/models:model_base", + ], +) diff --git a/exir/pass_base.py b/exir/pass_base.py index c657ac53a91..ab9263844e4 100644 --- a/exir/pass_base.py +++ b/exir/pass_base.py @@ -9,7 +9,7 @@ import operator import traceback from abc import ABC, abstractmethod -from contextlib import nullcontext +from contextlib import contextmanager, nullcontext from dataclasses import dataclass from typing import ( Any, @@ -280,6 +280,117 @@ def ensures(self, exported_program: ExportedProgram) -> None: # noqa: B027 """ +# Namespaces of ops that are safe to cache in the FakeTensor dispatch cache. +# By default, FakeTensorMode only caches ops in {"aten", "prim", "prims"}. +# ExecuTorch passes commonly use quantization and TOSA ops that are +# deterministic and shape-preserving, so we extend caching to cover them +# during pass execution to avoid redundant FakeTensor dispatches. 
+_EXTRA_CACHEABLE_NAMESPACES: frozenset[str] = frozenset( + { + "quantized_decomposed", + "tosa", + "dim_order_ops", + "cortex_m", + } +) + + +@contextmanager +# pyre-ignore[3] +def _extend_faketensor_cache_builtins(): # noqa: C901 + """Temporarily extend FakeTensor dispatch cache to cover ExecuTorch ops. + + The FakeTensor dispatch cache (``FakeTensorMode``) only caches "builtin" + ops whose namespace is in ``{"aten", "prim", "prims"}``. ExecuTorch + passes operate on graphs that contain ``quantized_decomposed``, ``tosa``, + and other non-builtin ops that are nonetheless safe to cache -- they are + deterministic and their output metadata depends only on input metadata. + + Without caching these ops, every pass re-dispatches them through the full + PyTorch stack (~0.5 ms each), leading to tens of seconds of overhead + across 50+ passes on a ~1200-node graph. + + This context manager monkey-patches ``torch._library.utils.is_builtin`` + so that the cache also covers the extra namespaces, then restores the + original function on exit. + """ + import torch._library.utils as _library_utils + + _original_is_builtin = _library_utils.is_builtin + + def _extended_is_builtin(op: torch._ops.OpOverload) -> bool: + if not isinstance(op, torch._ops.OpOverload): + raise AssertionError(f"op must be OpOverload, got {type(op)}") + return op.namespace in {"aten", "prim", "prims"} or ( + op.namespace in _EXTRA_CACHEABLE_NAMESPACES + ) + + try: + _library_utils.is_builtin = _extended_is_builtin # pyre-ignore[8] + + # Evict negative cache entries ("non-builtin" bypass entries) that were + # stored before the extension was active. FakeTensorMode stores + # _DispatchCacheBypassEntry objects as negative cache hits — once stored, + # _validate_cache_key is never re-evaluated for that key. We must evict + # these so the first dispatch under the extension re-evaluates is_builtin + # and creates a proper positive cache entry instead. 
+ # + # There are TWO caches that can hold negative entries: + # 1. FakeTensorMode.cache -- the global (class-level) cache, used when + # the dispatch has no SymInt inputs. + # 2. shape_env.fake_tensor_cache -- per-ShapeEnv cache, used when the + # dispatch involves SymInt/SymFloat inputs (cache_on_shape_env=True). + # We must evict from both. + try: + from torch._subclasses.fake_tensor import ( + _DispatchCacheBypassEntry, + FakeTensorMode, + ) + + def _is_nonbuiltin_bypass(v: object) -> bool: + return ( + isinstance(v, _DispatchCacheBypassEntry) + and v.reason == "non-builtin" + ) + + # 1. Evict from the global class-level cache. + FakeTensorMode.cache = { + k: v + for k, v in FakeTensorMode.cache.items() + if not _is_nonbuiltin_bypass(v) + } + + # 2. Evict from the per-ShapeEnv cache of the currently active + # FakeTensorMode (if any). When ExportPass enters _fx(), the + # FakeTensorMode is already on the dispatch stack before this CM + # is entered, so we can reach its shape_env cache. 
+ try: + from torch.utils._python_dispatch import ( + _get_current_dispatch_mode_stack, + ) + + for mode in _get_current_dispatch_mode_stack(): + if isinstance(mode, FakeTensorMode): + se = getattr(mode, "shape_env", None) + if se is not None: + se_cache = getattr(se, "fake_tensor_cache", None) + if se_cache: + se.fake_tensor_cache = { + k: v + for k, v in se_cache.items() + if not _is_nonbuiltin_bypass(v) + } + except (ImportError, AttributeError): + pass + except (ImportError, AttributeError): + pass # Graceful degradation if internals change + + yield + finally: + _library_utils.is_builtin = _original_is_builtin # pyre-ignore[8] + + + class _ExportPassBase(PassBase): """ Interpreter-based pass class to help users maintain the IR spec while writing @@ -413,12 +524,45 @@ def make_tensor_meta(x: Argument) -> Optional[TensorMetadata]: node.meta["tensor_meta"] = pytree.tree_map(make_tensor_meta, value) + # Types whose nodes are eligible for the fast-copy optimisation in + # ``run_node``. Subclass interpreters (e.g. ``ExportPass``) extend + # this tuple to include dialect-specific overload types such as + # ``EdgeOpOverload``. + _OPERATOR_TARGET_TYPES: Tuple[type, ...] = ( + torch._ops.OpOverload, + torch._ops.OpOverloadPacket, + ) + class ExportInterpreter(fx.Interpreter): def __init__(self, callback: "_ExportPassBase", gm: fx.GraphModule) -> None: super().__init__(gm) self.callback = callback self.node: torch.fx.Node = next(iter(gm.graph.nodes)) + # --- fast-copy bookkeeping --------------------------------- + # When the owning pass declares ``targeted_ops``, cold nodes + # (those whose target is *not* in the set) can be copied into + # the new graph without an expensive FakeTensor dispatch. + targeted: Optional[Set[Any]] = getattr(callback, "targeted_ops", None) + self._targeted_ops: Optional[Set[Any]] = targeted if targeted else None + + # Fast-copy relies on the existing ``n.meta["val"]`` being + # correct for cold nodes. 
If the pass overrides ``call()`` + # it may modify the graph (e.g. insert nodes with metadata + # copied from unrelated ops) before calling ``super().call()``, + # which would make cold-node metadata unreliable. Disable the + # optimisation in that case. + call_overridden = type(callback).call is not _ExportPassBase.call + self._fast_copy_enabled: bool = ( + self._targeted_ops is not None and not call_overridden + ) + + # Maps old-graph nodes to their new-graph equivalents so that + # ``_fast_copy_node`` can remap arguments (including get_attr + # nodes that are stored in ``self.env`` as raw tensors rather + # than ProxyValues). + self._node_remap: Dict[torch.fx.Node, torch.fx.Node] = {} + def placeholder( # pyre-fixme[14] self, target: str, @@ -512,10 +656,113 @@ def call_method( # pyre-fixme[14] ) -> None: raise ExportPassBaseError("call_method is not supported.") + # -- fast-copy helpers ------------------------------------------ + + def _fast_copy_node(self, n: torch.fx.Node) -> "ProxyValue": + """Copy *n* into the new graph without FakeTensor dispatch. + + This is the fast path for "cold" nodes — nodes whose target is + not in the pass's ``targeted_ops``. Instead of running the + full ``_fx`` pipeline (unwrap → dispatch → create_proxy → + set_metadata), we use ``graph.node_copy`` to clone the node + directly and reuse the original ``val`` metadata. + + Typical savings: ~0.4 ms → ~0.02 ms per node. + """ + + tracer = self.callback.tracer + + def _arg_transform(old_node: torch.fx.Node) -> torch.fx.Node: + # 1. Check the remap dict (populated for processed nodes + # whose result is a ProxyValue). + new_node = self._node_remap.get(old_node) + if new_node is not None: + return new_node + # 2. Fallback: extract from ProxyValue in env. + pv = self.env.get(old_node) + if pv is not None and hasattr(pv, "proxy"): + mapped = pv.proxy.node + self._node_remap[old_node] = mapped + return mapped + # 3. 
For get_attr / placeholder nodes that were processed + # via the normal path but returned raw tensors (not + # ProxyValue), they won't be in _node_remap. Copy + # them into the new graph on demand. + if old_node.op in ("get_attr", "placeholder"): + copied = tracer.graph.node_copy( + old_node, lambda x: self._node_remap.get(x, x) + ) + self._node_remap[old_node] = copied + # For get_attr, also register the attribute on the + # new module so GraphModule.__init__ can find it. + if old_node.op == "get_attr": + val = self.fetch_attr(old_node.target) + target_atoms = old_node.target.split(".") + root = tracer.root + for atom in target_atoms[:-1]: + if not hasattr(root, atom): + setattr(root, atom, torch.nn.Module()) + root = getattr(root, atom) + setattr(root, target_atoms[-1], val) + return copied + return old_node + + new_node = tracer.graph.node_copy(n, _arg_transform) + # node_copy already does copy.copy(node.meta) + + val = n.meta.get("val") + proxy = torch.fx.Proxy(new_node, tracer) + result = ProxyValue(val, proxy) + self._node_remap[n] = new_node + return result + def run_node(self, n: torch.fx.Node) -> Argument: self.node = n self.callback.node_debug_str = n.format_node() - return super().run_node(n) + + # Fast-copy path: skip the full interpreter dispatch for cold + # call_function nodes whose operator is not targeted by this + # pass. This avoids the expensive FakeTensor re-dispatch and + # proxy reconstruction for nodes the pass will not modify. + if ( + self._fast_copy_enabled + and n.op == "call_function" + and isinstance(n.target, self.callback._OPERATOR_TARGET_TYPES) + and n.target not in self._targeted_ops # type: ignore[operator] + and n.meta.get("val") is not None + ): + return self._fast_copy_node(n) + + result = super().run_node(n) + + # Record old→new node mapping for fast-copy arg remapping. 
+ if self._fast_copy_enabled and isinstance(result, ProxyValue): + self._node_remap[n] = result.proxy.node + + # After a hot node runs through full dispatch, verify that + # it did not change output shapes. If it did, downstream + # cold nodes' original ``val`` metadata would be stale, so + # we disable the fast-copy optimisation for the remainder + # of this interpreter walk. + if ( + self._fast_copy_enabled + and n.op == "call_function" + and self._targeted_ops is not None + and n.target in self._targeted_ops + and isinstance(result, ProxyValue) + ): + original_val = n.meta.get("val") + new_val = result.data + if isinstance(original_val, torch.Tensor) and isinstance( + new_val, torch.Tensor + ): + if ( + original_val.shape != new_val.shape + or original_val.dtype != new_val.dtype + ): + self._fast_copy_enabled = False + + return result def __init__(self) -> None: self.interpreter = torch.fx.Interpreter( @@ -768,13 +1015,17 @@ def output(self, results: List[Argument], meta: NodeMetadata) -> ProxyValue: def call_submodule( self, graph_module: fx.GraphModule, inputs: Tuple[Argument, ...] 
) -> PassResult: - prev_tracer, self.tracer = self.tracer, self.ExportTracer( - self, graph_module.graph._codegen + prev_tracer, self.tracer = ( + self.tracer, + self.ExportTracer(self, graph_module.graph._codegen), ) self.tracer.fake_tensor_mode = prev_tracer.fake_tensor_mode interpreter = self.ExportInterpreter(self, graph_module) - prev_interpreter, self.interpreter = self.interpreter, torch.fx.Interpreter( - torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph()) + prev_interpreter, self.interpreter = ( + self.interpreter, + torch.fx.Interpreter( + torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph()) + ), ) inputs_data = pytree.tree_map_only(ProxyValue, lambda x: x.data, inputs) with fx_traceback.preserve_node_meta(): @@ -793,12 +1044,33 @@ def call_submodule( True, ) + def should_run(self, graph_module: fx.GraphModule) -> bool: + """Override to declare when this pass can be skipped entirely. + + When this method returns False, the expensive FakeTensor graph + re-interpretation is bypassed and the original graph module is returned + unchanged. Subclasses should override this to inspect the graph cheaply + (e.g. checking whether any node targets an op this pass cares about). + + The default implementation returns True so existing passes are + unaffected. 
+ """ + return True + def call(self, graph_module: fx.GraphModule) -> PassResult: if not getattr(self, "_initialized", False): raise ExportPassBaseError( "ExportPass is not initialized with __init__().", ) + if not getattr(self, "_skip_should_run", False) and not self.should_run( + graph_module + ): + return PassResult(graph_module, False) + + prev_skip = getattr(self, "_skip_should_run", False) + self._skip_should_run = True + inputs = self.inputs(graph_module) fake_tensor_mode = None @@ -817,13 +1089,23 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: self.tracer.fake_tensor_mode = fake_tensor_mode self.fake_tensor_mode = fake_tensor_mode - with fake_tensor_mode, dispatcher_mode: # type: ignore[assignment, union-attr] - result = self.call_submodule(graph_module, tuple(inputs)) - - return result + try: + with fake_tensor_mode, dispatcher_mode: # type: ignore[assignment, union-attr] + with _extend_faketensor_cache_builtins(): + return self.call_submodule(graph_module, tuple(inputs)) + finally: + self._skip_should_run = prev_skip class ExportPass(_ExportPassBase): + # Extend operator target types to include the Edge dialect overloads so + # that the fast-copy optimisation in ``run_node`` also covers Edge ops. + _OPERATOR_TARGET_TYPES: Tuple[type, ...] 
= ( + torch._ops.OpOverload, + torch._ops.OpOverloadPacket, + EdgeOpOverload, + ) + class ExportTracer(_ExportPassBase.ExportTracer): def create_arg(self, a: Argument) -> torch.fx.Node: if isinstance(a, torch.nn.Module): diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py index 20906fe92e9..c361d3b54d2 100644 --- a/exir/tests/test_passes.py +++ b/exir/tests/test_passes.py @@ -570,6 +570,79 @@ class NullPass(ExportPass): self.assertEqual(new_meta["custom"]["test_key"], "test_value") self.assertEqual(new_meta["custom"]["nested"]["a"], 1) + def test_export_pass_should_run_skip(self) -> None: + """Test that should_run=False skips FakeTensor re-interpretation.""" + + class Foo(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + x + + class AlwaysSkipPass(ExportPass): + def should_run(self, graph_module) -> bool: + return False + + def call_operator(self, op, args, kwargs, meta): + raise AssertionError("call_operator should never be reached") + + prog = to_edge(export(Foo(), (torch.ones(3, 2),), strict=True)) + original_gm = prog.exported_program().graph_module + + result = AlwaysSkipPass()(original_gm) + self.assertIsNotNone(result) + self.assertFalse(result.modified) + self.assertIs(result.graph_module, original_gm) + + def test_export_pass_should_run_op_predicate(self) -> None: + """Test should_run with op-based predicate: skip when irrelevant ops.""" + + class Foo(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + x + + class MulOnlyPass(ExportPass): + """A pass that only cares about mul ops.""" + + def should_run(self, graph_module) -> bool: + return any( + node.target == torch.ops.aten.mul.Tensor + for node in graph_module.graph.nodes + if node.op == "call_function" + ) + + def call_operator(self, op, args, kwargs, meta): + raise AssertionError("call_operator should never be reached") + + # Foo only has add ops, so MulOnlyPass should be skipped + prog = to_edge(export(Foo(), 
(torch.ones(3, 2),), strict=True)) + gm = prog.exported_program().graph_module + + result = MulOnlyPass()(gm) + self.assertIsNotNone(result) + self.assertFalse(result.modified) + self.assertIs(result.graph_module, gm) + + def test_export_pass_should_run_true_still_runs(self) -> None: + """Test that should_run=True (default) still runs the pass normally.""" + + call_count = 0 + + class Foo(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + x + + class CountingPass(ExportPass): + def call_operator(self, op, args, kwargs, meta): + nonlocal call_count + call_count += 1 + return super().call_operator(op, args, kwargs, meta) + + prog = to_edge(export(Foo(), (torch.ones(3, 2),), strict=True)) + gm = prog.exported_program().graph_module + + result = CountingPass()(gm) + self.assertIsNotNone(result) + self.assertGreater(call_count, 0) + def test_export_scalar_to_tensor_pass(self) -> None: # Build a graph with a scalar argument where schema expects tensor graph = torch.fx.Graph()