Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions backends/arm/_passes/arm_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,25 @@
class ArmPass(ExportPass):
"""Base class for Arm passes."""

def __init_subclass__(cls, **kwargs) -> None:
super().__init_subclass__(**kwargs)
if getattr(cls, "targeted_ops", None) is not None:
return
# Only auto-discover targeted_ops for passes that use the standard
# call_operator() pattern. Passes that override call() use _TARGET_OPS
# for their own graph manipulation logic, not as a fast-copy declaration.
if "call" in cls.__dict__:
return
for attr in ("_TARGET_OPS", "_supported_ops"):
ops = getattr(cls, attr, None)
if ops:
cls.targeted_ops = set(ops) # type: ignore[attr-defined]
return
edge = getattr(cls, "_EDGE_OPS", None)
aten = getattr(cls, "_ATEN_OPS", None)
if edge or aten:
cls.targeted_ops = {*(edge or ()), *(aten or ())} # type: ignore[attr-defined]

def __init__(self, tfa_pass: bool = False, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.submodule_depth = 0
Expand Down Expand Up @@ -81,6 +100,34 @@ def get_name(pass_) -> str:
f"Cannot get name for pass: {pass_}. It must be an instance of ExportPass or have a __name__ attribute."
)

def should_run(self, graph_module: GraphModule) -> bool:
"""Skip this pass if the graph contains none of its targeted ops.

Subclasses that define a ``targeted_ops`` class attribute (a set of
op overloads) get this check for free via inheritance. Passes
without ``targeted_ops`` always run (the default).

Recursively checks control flow submodules (cond/while_loop) so
passes are not incorrectly skipped when targeted ops are nested.

"""
targeted = getattr(self, "targeted_ops", None)
if targeted is None:
return True

from executorch.exir.graph_module import get_control_flow_submodules

def _has_targeted_op(gm: GraphModule) -> bool:
for node in gm.graph.nodes:
if node.op == "call_function" and node.target in targeted:
return True
for _, submod, _ in get_control_flow_submodules(gm):
if _has_targeted_op(submod):
return True
return False

return _has_targeted_op(graph_module)

def call_operator(self, op, args, kwargs, meta, updated: Optional[bool] = False):
if (
op == exir_ops.edge.aten.bmm.default
Expand Down
2 changes: 0 additions & 2 deletions backends/arm/_passes/cast_to_int32_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@
from typing import Set, Type

import torch

from executorch.backends.arm._passes.arm_pass import ArmPass

from executorch.backends.arm.tosa.specification import get_context_spec
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/_passes/conv1d_unsqueeze_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@
from typing import Set, Type

from executorch.backends.arm._passes import ArmOpTargetedPass

from executorch.backends.arm._passes.rewrite_conv_pass import RewriteConvPass
from executorch.backends.arm._passes.size_adjust_input_pass import SizeAdjustInputPass

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

Expand All @@ -36,6 +34,8 @@ class Conv1dUnsqueezePass(ArmOpTargetedPass):
}
target_ops = (exir_ops.edge.aten.convolution.default,)

targeted_ops = {exir_ops.edge.aten.convolution.default}

def call_operator(self, op, args, kwargs, meta):
if op not in self.target_ops:
return super().call_operator(op, args, kwargs, meta)
Expand Down
3 changes: 2 additions & 1 deletion backends/arm/_passes/convert_expand_copy_to_repeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from typing import cast, Set, Type

import torch

from executorch.backends.arm._passes.arm_pass import ArmOpTargetedPass
from executorch.backends.arm._passes.unsqueeze_before_repeat_pass import (
UnsqueezeBeforeRepeatPass,
Expand Down Expand Up @@ -58,6 +57,8 @@ class ConvertExpandCopyToRepeatPass(ArmOpTargetedPass):

_passes_required_after: Set[Type[ExportPass]] = {UnsqueezeBeforeRepeatPass}

targeted_ops = {exir_ops.edge.aten.expand_copy.default}

expand_copy = exir_ops.edge.aten.expand_copy.default
repeat = exir_ops.edge.aten.repeat.default
target_ops = (expand_copy,)
Expand Down
3 changes: 2 additions & 1 deletion backends/arm/_passes/convert_full_like_to_full_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from executorch.backends.arm._passes.fuse_constant_ops_pass import (
ComputeConstantOpsAOTPass,
)

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

Expand Down Expand Up @@ -37,6 +36,8 @@ class ConvertFullLikeToFullPass(ArmOpTargetedPass):
_passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOTPass}
target_ops = (exir_ops.edge.aten.full_like.default,)

targeted_ops = {exir_ops.edge.aten.full_like.default}

def call_operator(self, op, args, kwargs, meta):
if op not in self.target_ops:
return super().call_operator(op, args, kwargs, meta)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,8 @@
from typing import Sequence, Set, Tuple, Type

from executorch.backends.arm._passes.arm_pass import ArmOpTargetedPass

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

from torch._ops import OpOverload


Expand All @@ -36,6 +34,8 @@ class ConvertPermuteSingletonToViewPass(ArmOpTargetedPass):
_passes_required_after: Set[Type[ExportPass]] = set()
target_ops = _PERMUTE_TARGETS

targeted_ops = set(_PERMUTE_TARGETS)

def call_operator(self, op, args, kwargs, meta):
if op not in self.target_ops:
return super().call_operator(op, args, kwargs, meta)
Expand Down
5 changes: 5 additions & 0 deletions backends/arm/_passes/convert_split_to_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ class ConvertSplitToSlicePass(ArmPass):

_passes_required_after: Set[Type[ExportPass]] = set()

targeted_ops = {
exir_ops.edge.aten.split_with_sizes_copy.default,
exir_ops.edge.aten.split_copy.Tensor,
}

split_ops = (
exir_ops.edge.aten.split_with_sizes_copy.default,
exir_ops.edge.aten.split_copy.Tensor,
Expand Down
5 changes: 5 additions & 0 deletions backends/arm/_passes/convert_squeezes_to_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ class ConvertSqueezesToViewPass(ArmOpTargetedPass):
exir_ops.edge.aten.unsqueeze_copy.default,
)

targeted_ops = {
exir_ops.edge.aten.squeeze_copy.dims,
exir_ops.edge.aten.unsqueeze_copy.default,
}

def call_operator(self, op, args, kwargs, meta):
if op not in self.target_ops:
return super().call_operator(op, args, kwargs, meta)
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/_passes/convert_to_clamp_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,9 @@
from typing import Set, Tuple, Type

from executorch.backends.arm._passes import ArmOpTargetedPass

from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
QuantizeClampArgumentsPass,
)

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

Expand All @@ -34,6 +32,8 @@ class ConvertToClampPass(ArmOpTargetedPass):
target_ops = edge_operators
check_allowed_to_transform = True

targeted_ops = edge_operators

def call_operator(self, op, args, kwargs, meta):
if op not in self.target_ops or not self.allowed_to_transform(meta):
return super().call_operator(op, args, kwargs, meta)
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/decompose_acosh_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class DecomposeAcoshPass(ArmOpTargetedPass):
}
target_ops = (edge_acosh_op,)

targeted_ops = {edge_acosh_op}

def call_operator(self, op, args, kwargs, meta, updated=False):

if op not in self.target_ops:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from typing import Set, Type

import torch

from executorch.backends.arm._passes import ArmOpTargetedPass
from executorch.backends.arm._passes.decompose_avg_pool2d_pass import (
DecomposeAvgPool2dPass,
Expand Down Expand Up @@ -55,6 +54,8 @@ class DecomposeAdaptiveAvgPool2dPass(ArmOpTargetedPass):
target_ops = edge_ops + aten_ops
check_allowed_to_transform = True

targeted_ops = {*edge_ops, *aten_ops}

@staticmethod
def _is_static_dim(dim) -> bool:
return not isinstance(dim, torch.SymInt)
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/decompose_add_sub_alpha_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ class DecomposeAddSubAlphaPass(ArmOpTargetedPass):
_passes_required_after: Set[Type[ExportPass]] = set()
target_ops = _ADD_OPS + _SUB_OPS

targeted_ops = {*_ADD_OPS, *_SUB_OPS}

def call_operator(self, op, args, kwargs, meta, updated: bool | None = False):
if op not in self.target_ops:
return super().call_operator(op, args, kwargs, meta, updated)
Expand Down
3 changes: 2 additions & 1 deletion backends/arm/_passes/decompose_addmm_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from typing import Set, Type

import torch

from executorch.backends.arm._passes import ArmOpTargetedPass
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
Expand Down Expand Up @@ -51,6 +50,8 @@ class DecomposeAddmmPass(ArmOpTargetedPass):
}
target_ops = (edge_addmm, aten_addmm)

targeted_ops = {edge_addmm, aten_addmm}

def call_operator(self, op, args, kwargs, meta):
if op not in self.target_ops or not self.allowed_to_transform(meta):
return super().call_operator(op, args, kwargs, meta)
Expand Down
1 change: 0 additions & 1 deletion backends/arm/_passes/decompose_as_strided_copy_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from typing import Dict, Optional, Set, Tuple, Type

import torch

from executorch.backends.arm._passes import ArmOpTargetedPass
from executorch.backends.arm.common.as_strided_utils import (
contiguous_strides,
Expand Down
3 changes: 2 additions & 1 deletion backends/arm/_passes/decompose_asin_and_acos_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from typing import Set, Type

import torch

from executorch.backends.arm._passes import ArmOpTargetedPass
from executorch.backends.arm._passes.convert_full_like_to_full_pass import (
ConvertFullLikeToFullPass,
Expand Down Expand Up @@ -73,6 +72,8 @@ class DecomposeAsinAndAcosPass(ArmOpTargetedPass):
}
target_ops = edge_asin_op + edge_acos_op

targeted_ops = {*edge_asin_op, *edge_acos_op}

def _build_polynomial(
self, coefficients: list[float], variable: torch.Tensor, meta: dict[str, str]
) -> torch.Tensor:
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/decompose_asinh_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class DecomposeAsinhPass(ArmOpTargetedPass):
}
target_ops = edge_asinh_op

targeted_ops = {*edge_asinh_op}

def call_operator(self, op, args, kwargs, meta):
if op not in self.target_ops:
return super().call_operator(op, args, kwargs, meta)
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/decompose_atan_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ class DecomposeAtanPass(ArmOpTargetedPass):
}
target_ops = (edge_atan,)

targeted_ops = {edge_atan}

def _rational_approximation(self, z, ops, meta):
"""Creates a (2,1) Padé approximation for atan(x) on [-1, 1]."""

Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/decompose_atanh_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class DecomposeAtanhPass(ArmOpTargetedPass):
}
target_ops = (edge_atanh,)

targeted_ops = {edge_atanh}

def call_operator(self, op, args, kwargs, meta):
if op not in self.target_ops:
return super().call_operator(op, args, kwargs, meta, updated=False)
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/decompose_avg_pool2d_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ class DecomposeAvgPool2dPass(ArmOpTargetedPass):
target_ops = edge_avg_pool2d + aten_avg_pool2d
check_allowed_to_transform = True

targeted_ops = {*edge_avg_pool2d, *aten_avg_pool2d}

def call_operator(self, op, args, kwargs, meta):
if op not in self.target_ops or not self.allowed_to_transform(meta):
return super().call_operator(op, args, kwargs, meta)
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/decompose_cosh_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class DecomposeCoshPass(ArmOpTargetedPass):
}
target_ops = (edge_cosh,)

targeted_ops = {edge_cosh}

def call_operator(self, op, args, kwargs, meta, updated=False):
if op not in self.target_ops:
return super().call_operator(op, args, kwargs, meta, updated)
Expand Down
3 changes: 2 additions & 1 deletion backends/arm/_passes/decompose_cosine_similarity_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from executorch.backends.arm._passes.convert_full_like_to_full_pass import (
ConvertFullLikeToFullPass,
)

from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass
from executorch.backends.arm._passes.decompose_sum_pass import DecomposeSumPass
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
Expand Down Expand Up @@ -45,6 +44,8 @@ class DecomposeCosineSimilarityPass(ArmOpTargetedPass):
target_ops = torch_cosine_similarity
check_allowed_to_transform = True

targeted_ops = {*torch_cosine_similarity}

def call_operator(self, op, args, kwargs, meta):
if op not in self.target_ops or not self.allowed_to_transform(meta):
return super().call_operator(op, args, kwargs, meta)
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/decompose_div_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class DecomposeDivPass(ArmOpTargetedPass):
_passes_required_after: Set[Type[ExportPass]] = {InsertTableOpsPass}
target_ops = edge_div_ops + aten_div_ops

targeted_ops = {*edge_div_ops, *aten_div_ops}

def call_operator(self, op, args, kwargs, meta):
if op not in self.target_ops or not self.allowed_to_transform(meta):
return super().call_operator(op, args, kwargs, meta)
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/decompose_div_tensor_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ class DecomposeDivTensorModePass(ArmOpTargetedPass):
target_ops = edge_div_mode_ops + aten_div_mode_ops
check_allowed_to_transform = True

targeted_ops = {*edge_div_mode_ops, *aten_div_mode_ops}

def call_operator(self, op, args, kwargs, meta):
if op not in self.target_ops or not self.allowed_to_transform(meta):
return super().call_operator(op, args, kwargs, meta)
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/decompose_elu_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ class DecomposeEluPass(ArmOpTargetedPass):
_passes_required_after: Set[Type[ExportPass]] = set()
target_ops = edge_elu_family_ops

targeted_ops = {*edge_elu_family_ops}

def call_operator(self, op, args, kwargs, meta):
if op not in self.target_ops:
return super().call_operator(op, args, kwargs, meta, updated=False)
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/decompose_expm1_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ class DecomposeExpm1Pass(ArmOpTargetedPass):
}
target_ops = edge_expm1_ops

targeted_ops = {*edge_expm1_ops}

def call_operator(self, op, args, kwargs, meta):
if op not in self.target_ops:
return super().call_operator(op, args, kwargs, meta, updated=False)
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/decompose_floor_divide_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ class DecomposeFloorDividePass(ArmOpTargetedPass):
_passes_required_after: Set[Type[ExportPass]] = {DecomposeDivTensorModePass}
target_ops = edge_floor_divide_ops + aten_floor_divide_ops

targeted_ops = {*edge_floor_divide_ops, *aten_floor_divide_ops}

def call_operator(self, op, args, kwargs, meta):
if op not in self.target_ops:
return super().call_operator(op, args, kwargs, meta, updated=False)
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/decompose_gelu_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ class DecomposeGeluPass(ArmOpTargetedPass):
}
target_ops = torch_gelu + edge_gelu

targeted_ops = {*torch_gelu, *edge_gelu}

def call_operator(self, op, args, kwargs, meta):
if op not in self.target_ops:
return super().call_operator(op, args, kwargs, meta)
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/decompose_glu_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ class DecomposeGluPass(ArmOpTargetedPass):
_passes_required_after: Set[Type[ExportPass]] = {InsertTableOpsPass}
target_ops = (edge_glu, aten_glu)

targeted_ops = {edge_glu, aten_glu}

def call_operator(self, op, args, kwargs, meta):
if op not in self.target_ops or not self.allowed_to_transform(meta):
return super().call_operator(op, args, kwargs, meta)
Expand Down
Loading
Loading