Skip to content

Commit 7d365ec

Browse files
authored
Arm backend: Migrate pass manager to exported program (#20025)
Summary: - Use ExportedProgramPassManager as the Arm pass manager base - Keep GraphModule-only transforms on the non-deprecated FX manager Test: - PYTHONPATH=src:. /Users/usazah01/src/executorch/env/bin/python -m pytest -q -p no:rerunfailures backends/arm/test/misc/test_call_operator_submodule.py backends/arm/test/passes/test_arm_pass_manager_insertions.py backends/arm/test/misc/test_pass_pipeline_config.py backends/arm/test/misc/test_pass_required_order.py cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell @rascani Signed-off-by: Usamah Zaheer <usamah.zaheer@arm.com>
1 parent 14f2017 commit 7d365ec

4 files changed

Lines changed: 114 additions & 48 deletions

File tree

backends/arm/_passes/arm_pass_manager.py

Lines changed: 90 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77

88
import logging
99
from collections import defaultdict
10-
from collections.abc import Sequence
10+
from collections.abc import Callable, Sequence
1111
from dataclasses import dataclass, field
12+
from typing import Any, cast
1213

1314
from executorch.backends.arm._passes import (
1415
AccumulateIndexPutPass,
@@ -167,12 +168,17 @@
167168
)
168169

169170
from executorch.exir import ExportedProgram
170-
from executorch.exir.pass_base import ExportPass
171-
from executorch.exir.pass_manager import PassManager
171+
from executorch.exir._program_utils import _get_updated_graph_signature
172+
from executorch.exir.pass_base import (
173+
ExportedProgramPassBase,
174+
ExportedProgramPassResult,
175+
ExportPass,
176+
)
177+
from executorch.exir.pass_manager import ExportedProgramPassManager
172178
from torch._export.utils import _get_shape_env_from_gm
173179
from torch.fx import GraphModule
174180
from torch.fx.passes.infra.pass_base import PassResult
175-
from torch.nn.modules import Module
181+
from torch.fx.passes.infra.pass_manager import PassManager as GraphModulePassManager
176182

177183
logger = logging.getLogger(__name__)
178184

@@ -188,6 +194,50 @@ class PassInsertions:
188194
_registered_pass_insertions: dict[type, PassInsertions] = {}
189195

190196

197+
def _graph_pass_name(graph_pass: Callable[[GraphModule], PassResult | None]) -> str:
198+
if isinstance(graph_pass, ExportPass):
199+
return ArmPass.get_name(graph_pass)
200+
if hasattr(graph_pass, "__name__"):
201+
return graph_pass.__name__
202+
return type(graph_pass).__name__
203+
204+
205+
class _ExportedProgramGraphPassAdapter(ExportedProgramPassBase):
206+
def __init__(self, graph_pass: Callable[[GraphModule], PassResult | None]) -> None:
207+
self.graph_pass = graph_pass
208+
209+
def call(self, exported_program: ExportedProgram) -> ExportedProgramPassResult:
210+
graph_pass = cast(Any, self.graph_pass)
211+
pass_exported_program = getattr(graph_pass, "exported_program", None)
212+
if pass_exported_program is not None:
213+
# ExportedProgramPassManager works on a shallow copy; Arm graph
214+
# passes that store an ExportedProgram must update that copy.
215+
graph_pass.exported_program = exported_program
216+
217+
try:
218+
result = self.graph_pass(exported_program.graph_module)
219+
finally:
220+
if pass_exported_program is not None:
221+
graph_pass.exported_program = pass_exported_program
222+
223+
if result is None:
224+
raise TypeError(
225+
f"The result of pass {_graph_pass_name(self.graph_pass)} should be type PassResult."
226+
)
227+
228+
if result.modified:
229+
result.graph_module.recompile()
230+
exported_program._graph_module = result.graph_module
231+
exported_program._graph_signature = _get_updated_graph_signature(
232+
exported_program.graph_signature,
233+
result.graph_module,
234+
)
235+
# Arm graph passes do not change symbolic shape constraints, and
236+
# metadata-only fake modes may differ after propagation.
237+
238+
return ExportedProgramPassResult(exported_program, result.modified)
239+
240+
191241
def register_pass_insertions_before(
192242
target_pass_type: type, passes: list[ExportPass]
193243
) -> None:
@@ -211,7 +261,7 @@ def clear_registered_pass_insertions() -> None:
211261
_registered_pass_insertions.clear()
212262

213263

214-
class ArmPassManager(PassManager):
264+
class ArmPassManager(ExportedProgramPassManager):
215265
def __init__(self, compile_spec: ArmCompileSpec) -> None:
216266
self.compile_spec = compile_spec
217267
self.tosa_spec = compile_spec.tosa_spec
@@ -374,8 +424,39 @@ def _tosa_context(self, graph_module: GraphModule) -> TosaLoweringContext:
374424
shape_env = _get_shape_env_from_gm(graph_module)
375425
return TosaLoweringContext(self.tosa_spec, shape_env)
376426

377-
def _transform(self, graph_module: GraphModule):
378-
return self(graph_module).graph_module
427+
def _transform_graph_module(self, graph_module: GraphModule):
428+
# TFA and control-flow submodule paths operate on bare GraphModules
429+
# without a standalone ExportedProgram to keep in sync.
430+
return GraphModulePassManager(self.passes)(graph_module).graph_module
431+
432+
def __call__( # type: ignore[override]
433+
self,
434+
module: ExportedProgram | GraphModule,
435+
override_verifiers: Any | None = None,
436+
) -> ExportedProgramPassResult | PassResult:
437+
if isinstance(module, GraphModule):
438+
if override_verifiers is not None:
439+
raise ValueError("override_verifiers is only valid for ExportedProgram")
440+
return GraphModulePassManager(self.passes)(module)
441+
return super().__call__(module, override_verifiers)
442+
443+
def _transform(
444+
self,
445+
exported_program: ExportedProgram,
446+
graph_module: GraphModule,
447+
) -> GraphModule:
448+
if graph_module is exported_program.graph_module:
449+
passes: list[
450+
ExportedProgramPassBase | Callable[[GraphModule], PassResult | None]
451+
] = [_ExportedProgramGraphPassAdapter(p) for p in self.passes]
452+
transformed_program = ExportedProgramPassManager(passes)(
453+
exported_program
454+
).exported_program
455+
exported_program._graph_module = transformed_program.graph_module
456+
exported_program._graph_signature = transformed_program.graph_signature
457+
exported_program._range_constraints = transformed_program.range_constraints
458+
return exported_program.graph_module
459+
return self._transform_graph_module(graph_module)
379460

380461
def add_pass(self, pipeline_pass):
381462
if type(pipeline_pass) in self._skip_pass_types:
@@ -558,7 +639,7 @@ def _tosa_pipeline(
558639
self._apply_pass_insertions()
559640

560641
self.validate_constraints_mandatory()
561-
return self._transform(graph_module)
642+
return self._transform(exported_program, graph_module)
562643

563644
def transform_to_backend_pipeline(
564645
self, exported_program: ExportedProgram, graph_module: GraphModule
@@ -663,21 +744,4 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
663744
]
664745
)
665746

666-
return self._transform(graph_module)
667-
668-
def __call__(self, module: Module) -> PassResult:
669-
try:
670-
return super().__call__(module)
671-
except Exception as e:
672-
first_exception = e.__cause__ or e.__context__ or e
673-
import re
674-
675-
message = e.args[0]
676-
m = re.search(r"An error occurred when running the '([^']+)' pass", message)
677-
if m:
678-
pass_name = m.group(1)
679-
first_exception.args = (
680-
f"{pass_name}: {first_exception.args[0]}",
681-
*first_exception.args[1:],
682-
)
683-
raise first_exception
747+
return self._transform_graph_module(graph_module)

backends/arm/test/misc/test_call_operator_submodule.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 Arm Limited and/or its affiliates.
1+
# Copyright 2025-2026 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -60,7 +60,7 @@ def test_call_operator_runs_once_for_cond_submodules_tosa_FP() -> None:
6060
recording_pass = _DepthRecordingPass(graph_module)
6161
pass_manager = ArmPassManager(TosaCompileSpec("TOSA-1.00+FP"))
6262
pass_manager.add_pass(recording_pass)
63-
pass_manager._transform(graph_module)
63+
pass_manager._transform_graph_module(graph_module)
6464

6565
assert recording_pass.num_submodules_called == 3
6666
assert recording_pass.depths, "call_operator was never invoked"

backends/arm/test/passes/test_arm_op_targeted_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import operator
7-
from typing import Set, Type
7+
from typing import cast, Set, Type
88

99
import torch
1010
from executorch.backends.arm._passes.arm_pass import ArmOpTargetedPass
@@ -45,7 +45,7 @@ def create_test_pass_manager() -> ArmPassManager:
4545
def run_single_pass(graph_module: GraphModule, test_pass: ExportPass) -> PassResult:
4646
pass_manager = create_test_pass_manager()
4747
pass_manager.add_pass(test_pass)
48-
return pass_manager(graph_module)
48+
return cast(PassResult, pass_manager(graph_module))
4949

5050

5151
class DummyTargetedPass(ArmOpTargetedPass):

exir/passes/__init__.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
to_scratch_op,
3333
)
3434
from executorch.exir.pass_base import ExportPass
35-
from executorch.exir.pass_manager import PassManager, PassType
35+
from executorch.exir.pass_manager import ExportedProgramPassManager, PassType
3636
from executorch.exir.passes.const_prop_pass import ConstPropPass
3737
from executorch.exir.passes.debug_handle_generator_pass import DebugHandleGeneratorPass
3838

@@ -498,25 +498,27 @@ def dead_code_elimination_pass(graph_module: torch.fx.GraphModule) -> PassResult
498498

499499
# Passes to convert a graph module from ATen to Edge IR
500500

501-
base_pre_op_replace_passes: List[Callable[[torch.nn.Module], PassResult]] = PassManager(
502-
passes=[
503-
# ReplaceSymSizeOpPass need to be run before other passes which inherits
504-
# from ExportPass. ExportPass can not handle OpOverloadPacket in its
505-
# call_function method. The ReplaceSymSizeOpPass pass converts sym size
506-
# ops from OpOverloadPacket to OpOverload.
507-
ReplaceSymSizeOpPass(),
508-
NormalizeTransposePass(),
509-
ReplaceBrokenOpsWithFunctionalOpsPass(),
510-
ScalarToTensorPass(),
511-
SymToTensorPass(),
512-
RemoveNoopPass(),
513-
PruneEmptyTensorsPass(),
514-
RemoveToCopyPass(),
515-
]
516-
).passes
501+
base_pre_op_replace_passes: List[Callable[[torch.nn.Module], PassResult]] = (
502+
ExportedProgramPassManager(
503+
passes=[
504+
# ReplaceSymSizeOpPass need to be run before other passes which inherits
505+
# from ExportPass. ExportPass can not handle OpOverloadPacket in its
506+
# call_function method. The ReplaceSymSizeOpPass pass converts sym size
507+
# ops from OpOverloadPacket to OpOverload.
508+
ReplaceSymSizeOpPass(),
509+
NormalizeTransposePass(),
510+
ReplaceBrokenOpsWithFunctionalOpsPass(),
511+
ScalarToTensorPass(),
512+
SymToTensorPass(),
513+
RemoveNoopPass(),
514+
PruneEmptyTensorsPass(),
515+
RemoveToCopyPass(),
516+
]
517+
).passes
518+
)
517519

518520
base_post_op_replace_passes: List[Callable[[torch.nn.Module], PassResult]] = (
519-
PassManager(
521+
ExportedProgramPassManager(
520522
passes=[
521523
dead_code_elimination_pass,
522524
DebugHandleGeneratorPass(),

0 commit comments

Comments
 (0)