77
88import logging
99from collections import defaultdict
10- from collections .abc import Sequence
10+ from collections .abc import Callable , Sequence
1111from dataclasses import dataclass , field
12+ from typing import Any , cast
1213
1314from executorch .backends .arm ._passes import (
1415 AccumulateIndexPutPass ,
167168)
168169
169170from executorch .exir import ExportedProgram
170- from executorch .exir .pass_base import ExportPass
171- from executorch .exir .pass_manager import PassManager
171+ from executorch .exir ._program_utils import _get_updated_graph_signature
172+ from executorch .exir .pass_base import (
173+ ExportedProgramPassBase ,
174+ ExportedProgramPassResult ,
175+ ExportPass ,
176+ )
177+ from executorch .exir .pass_manager import ExportedProgramPassManager
172178from torch ._export .utils import _get_shape_env_from_gm
173179from torch .fx import GraphModule
174180from torch .fx .passes .infra .pass_base import PassResult
175- from torch .nn . modules import Module
181+ from torch .fx . passes . infra . pass_manager import PassManager as GraphModulePassManager
176182
177183logger = logging .getLogger (__name__ )
178184
@@ -188,6 +194,50 @@ class PassInsertions:
188194_registered_pass_insertions : dict [type , PassInsertions ] = {}
189195
190196
197+ def _graph_pass_name (graph_pass : Callable [[GraphModule ], PassResult | None ]) -> str :
198+ if isinstance (graph_pass , ExportPass ):
199+ return ArmPass .get_name (graph_pass )
200+ if hasattr (graph_pass , "__name__" ):
201+ return graph_pass .__name__
202+ return type (graph_pass ).__name__
203+
204+
205+ class _ExportedProgramGraphPassAdapter (ExportedProgramPassBase ):
206+ def __init__ (self , graph_pass : Callable [[GraphModule ], PassResult | None ]) -> None :
207+ self .graph_pass = graph_pass
208+
209+ def call (self , exported_program : ExportedProgram ) -> ExportedProgramPassResult :
210+ graph_pass = cast (Any , self .graph_pass )
211+ pass_exported_program = getattr (graph_pass , "exported_program" , None )
212+ if pass_exported_program is not None :
213+ # ExportedProgramPassManager works on a shallow copy; Arm graph
214+ # passes that store an ExportedProgram must update that copy.
215+ graph_pass .exported_program = exported_program
216+
217+ try :
218+ result = self .graph_pass (exported_program .graph_module )
219+ finally :
220+ if pass_exported_program is not None :
221+ graph_pass .exported_program = pass_exported_program
222+
223+ if result is None :
224+ raise TypeError (
225+ f"The result of pass { _graph_pass_name (self .graph_pass )} should be type PassResult."
226+ )
227+
228+ if result .modified :
229+ result .graph_module .recompile ()
230+ exported_program ._graph_module = result .graph_module
231+ exported_program ._graph_signature = _get_updated_graph_signature (
232+ exported_program .graph_signature ,
233+ result .graph_module ,
234+ )
235+ # Arm graph passes do not change symbolic shape constraints, and
236+ # metadata-only fake modes may differ after propagation.
237+
238+ return ExportedProgramPassResult (exported_program , result .modified )
239+
240+
191241def register_pass_insertions_before (
192242 target_pass_type : type , passes : list [ExportPass ]
193243) -> None :
@@ -211,7 +261,7 @@ def clear_registered_pass_insertions() -> None:
211261 _registered_pass_insertions .clear ()
212262
213263
214- class ArmPassManager (PassManager ):
264+ class ArmPassManager (ExportedProgramPassManager ):
215265 def __init__ (self , compile_spec : ArmCompileSpec ) -> None :
216266 self .compile_spec = compile_spec
217267 self .tosa_spec = compile_spec .tosa_spec
@@ -374,8 +424,39 @@ def _tosa_context(self, graph_module: GraphModule) -> TosaLoweringContext:
374424 shape_env = _get_shape_env_from_gm (graph_module )
375425 return TosaLoweringContext (self .tosa_spec , shape_env )
376426
377- def _transform (self , graph_module : GraphModule ):
378- return self (graph_module ).graph_module
427+ def _transform_graph_module (self , graph_module : GraphModule ):
428+ # TFA and control-flow submodule paths operate on bare GraphModules
429+ # without a standalone ExportedProgram to keep in sync.
430+ return GraphModulePassManager (self .passes )(graph_module ).graph_module
431+
432+ def __call__ ( # type: ignore[override]
433+ self ,
434+ module : ExportedProgram | GraphModule ,
435+ override_verifiers : Any | None = None ,
436+ ) -> ExportedProgramPassResult | PassResult :
437+ if isinstance (module , GraphModule ):
438+ if override_verifiers is not None :
439+ raise ValueError ("override_verifiers is only valid for ExportedProgram" )
440+ return GraphModulePassManager (self .passes )(module )
441+ return super ().__call__ (module , override_verifiers )
442+
443+ def _transform (
444+ self ,
445+ exported_program : ExportedProgram ,
446+ graph_module : GraphModule ,
447+ ) -> GraphModule :
448+ if graph_module is exported_program .graph_module :
449+ passes : list [
450+ ExportedProgramPassBase | Callable [[GraphModule ], PassResult | None ]
451+ ] = [_ExportedProgramGraphPassAdapter (p ) for p in self .passes ]
452+ transformed_program = ExportedProgramPassManager (passes )(
453+ exported_program
454+ ).exported_program
455+ exported_program ._graph_module = transformed_program .graph_module
456+ exported_program ._graph_signature = transformed_program .graph_signature
457+ exported_program ._range_constraints = transformed_program .range_constraints
458+ return exported_program .graph_module
459+ return self ._transform_graph_module (graph_module )
379460
380461 def add_pass (self , pipeline_pass ):
381462 if type (pipeline_pass ) in self ._skip_pass_types :
@@ -558,7 +639,7 @@ def _tosa_pipeline(
558639 self ._apply_pass_insertions ()
559640
560641 self .validate_constraints_mandatory ()
561- return self ._transform (graph_module )
642+ return self ._transform (exported_program , graph_module )
562643
563644 def transform_to_backend_pipeline (
564645 self , exported_program : ExportedProgram , graph_module : GraphModule
@@ -663,21 +744,4 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
663744 ]
664745 )
665746
666- return self ._transform (graph_module )
667-
668- def __call__ (self , module : Module ) -> PassResult :
669- try :
670- return super ().__call__ (module )
671- except Exception as e :
672- first_exception = e .__cause__ or e .__context__ or e
673- import re
674-
675- message = e .args [0 ]
676- m = re .search (r"An error occurred when running the '([^']+)' pass" , message )
677- if m :
678- pass_name = m .group (1 )
679- first_exception .args = (
680- f"{ pass_name } : { first_exception .args [0 ]} " ,
681- * first_exception .args [1 :],
682- )
683- raise first_exception
747+ return self ._transform_graph_module (graph_module )
0 commit comments