Skip to content

Commit 6817eda

Browse files
Andrew Pullinmeta-codesync[bot]
authored andcommitted
Minor speedup for model lowering: Skip redundant run_decompositions when no ops match decomp table (#18496)
Summary: Pull Request resolved: #18496 Adds an early-exit check to _gen_edge_manager_for_partitioners: before calling program.run_decompositions(table), scan the graph for ops that appear in the decomposition table. If none are found, skip the call entirely. Each run_decompositions call performs a full re-export of the program via make_fx(), re-tracing every node through FakeTensor dispatch. On the EDGE_DO_NOT_DECOMP path this function is called up to 3 times; the early-exit eliminates at least one redundant call where the previous pass already decomposed all matching ops. The check recursively walks control flow submodules (cond/map/scan) to avoid incorrectly skipping when decomposable ops are nested. ## Benchmark Model: small CNN feature extractor (~50K params, 9 conv layers with LayerNorm, targeting Ethos-U55 via the ARM/TOSA lowering pipeline). Graph: ~1200 nodes. lower() before: 82 s lower() after: 71 s Delta: -11 s (-13 %) Differential Revision: D96489903
1 parent a94c7c3 commit 6817eda

1 file changed

Lines changed: 36 additions & 5 deletions

File tree

exir/program/_program.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1100,7 +1100,33 @@ def _can_skip_using_EDGE_DO_NOT_DECOMP(
11001100
return check_op_support is None
11011101

11021102

1103-
def _gen_edge_manager_for_partitioners(
1103+
def _has_decomposable_ops(
1104+
program: "ExportedProgram",
1105+
decomp_table: dict,
1106+
) -> bool:
1107+
"""Check if any ops in the graph match the decomposition table.
1108+
1109+
Returns True if the graph contains at least one op that appears in the
1110+
decomposition table, meaning run_decompositions would actually decompose
1111+
something. Returns True for empty tables (functionalization-only path)
1112+
since we can't cheaply determine if the graph needs functionalization.
1113+
"""
1114+
if not decomp_table:
1115+
return True # empty table = functionalize, can't skip cheaply
1116+
1117+
def _graph_has_match(gm: torch.fx.GraphModule) -> bool:
1118+
for node in gm.graph.nodes:
1119+
if node.op == "call_function" and node.target in decomp_table:
1120+
return True
1121+
for _, submod, _ in get_control_flow_submodules(gm):
1122+
if _graph_has_match(submod):
1123+
return True
1124+
return False
1125+
1126+
return _graph_has_match(program.graph_module)
1127+
1128+
1129+
def _gen_edge_manager_for_partitioners( # noqa: C901
11041130
partitioner: Dict[str, List[Partitioner]],
11051131
aten_programs: Dict[str, ExportedProgram],
11061132
config: EdgeCompileConfig,
@@ -1135,7 +1161,8 @@ def _gen_edge_manager_for_partitioners(
11351161
table = _default_decomposition_table()
11361162
for op in config.preserve_ops:
11371163
table.pop(op, None)
1138-
program = program.run_decompositions(table)
1164+
if _has_decomposable_ops(program, table):
1165+
program = program.run_decompositions(table)
11391166

11401167
# Process each partitioner individually using their specific requirements
11411168
for curr_partitioner in partitioners_for_program:
@@ -1155,7 +1182,8 @@ def _gen_edge_manager_for_partitioners(
11551182
if table.pop(op, None) is not None:
11561183
ops_needing_preservation.append(op)
11571184

1158-
program = program.run_decompositions(table)
1185+
if _has_decomposable_ops(program, table):
1186+
program = program.run_decompositions(table)
11591187
final_ops_to_preserve.update(ops_needing_preservation)
11601188
else:
11611189
# EDGE_DO_NOT_DECOMP path for the partitioner
@@ -1169,7 +1197,8 @@ def _gen_edge_manager_for_partitioners(
11691197
table.pop(op, None)
11701198

11711199
# First pass of decompositions with this partitioner's preserved ops
1172-
program = program.run_decompositions(table)
1200+
if _has_decomposable_ops(program, table):
1201+
program = program.run_decompositions(table)
11731202

11741203
# Filter ops using EDGE_DO_NOT_DECOMP
11751204
temp_partitioner_dict = {name: [curr_partitioner]}
@@ -1182,7 +1211,9 @@ def _gen_edge_manager_for_partitioners(
11821211
final_ops_to_preserve.update(preserved_ops)
11831212

11841213
# Second pass of decompositions with this partitioner's preserved ops after filtering
1185-
program = program.run_decompositions(_default_decomposition_table())
1214+
full_table = _default_decomposition_table()
1215+
if _has_decomposable_ops(program, full_table):
1216+
program = program.run_decompositions(full_table)
11861217

11871218
# Restore ops from edge_no_decomp_namespace to aten ops
11881219
_restore_transformed_ops_to_aten_ops(program)

0 commit comments

Comments
 (0)