Skip to content

Commit a67fd35

Browse files
shoumikhin authored and facebook-github-bot committed
Use CapabilityBasedPartitioner in AotiPartitioner (#20384)
Summary: AotiPartitioner (the base for the CUDA and Metal backends) groups the ops it delegates into one partition, by hand. Every other ExecuTorch backend (XNNPACK, Vulkan, CoreML) uses the shared CapabilityBasedPartitioner helper instead. This switches AotiPartitioner to that helper too.

Why:
1. Consistency -- same partitioning path as the other backends, and a real OperatorSupport hook instead of a hand-rolled tagging loop.
2. Correctness -- the hand-rolled single partition can break. A delegate has to be one convex, self-contained chunk of the graph. If the ops being delegated aren't all next to each other (some other, non-delegated node sits in between), putting them all in one partition is invalid and lowering crashes with "AssertionError: Invalid partition, found dependency cycles". CapabilityBasedPartitioner returns several maximal convex partitions instead, each of which fuses cleanly.

No change for the common case: if every op can be delegated, you still get exactly one partition (no extra delegate boundaries). When a non-delegated node splits the delegated ops, this emits one partition (and one delegate boundary) per island, which is the cost of producing a valid program.

Control-flow ops (cond/map/while_loop/scan) keep their branch get_attr operands in the same partition, and constant/buffer tagging is unchanged.

Reviewed By: Gasoonjia
Differential Revision: D109040727
1 parent c9ef423 commit a67fd35

3 files changed

Lines changed: 221 additions & 69 deletions

File tree

backends/aoti/aoti_partitioner.py

Lines changed: 67 additions & 53 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from typing import Callable, Dict, List, Optional, Tuple
7+
from typing import Callable, Dict, List, Mapping, Optional, Tuple
88

99
import torch
1010
from executorch.exir._warnings import experimental
@@ -21,6 +21,8 @@
2121
)
2222
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
2323
from torch.export.exported_program import ExportedProgram
24+
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
25+
from torch.fx.passes.operator_support import OperatorSupportBase
2426

2527

2628
@experimental(
@@ -30,12 +32,10 @@ class AotiPartitioner(Partitioner):
3032
"""
3133
Base partitioner for AOTInductor-driven backend integration.
3234
33-
This partitioner creates a single partition containing all operators from the input graph.
34-
It skips core ATen decomposition, allowing the backend to handle decomposition using
35+
Delegates the non-lowered operators to AOTInductor as one or more convex
36+
partitions (a single partition when nothing else has claimed part of the
37+
graph). It skips core ATen decomposition, letting the backend decompose via
3538
AOTInductor's backend-specific decomposition table.
36-
37-
Only operators that cannot be handled by the aoti library will be excluded from
38-
the partition and fall back to ExecuTorch's default or custom handling.
3939
"""
4040

4141
def __init__(self, backend_name: str, compile_spec: List[CompileSpec]) -> None:
@@ -49,62 +49,76 @@ def __init__(self, backend_name: str, compile_spec: List[CompileSpec]) -> None:
4949
self.delegation_spec = DelegationSpec(backend_name, compile_spec)
5050

5151
def partition(self, exported_program: ExportedProgram) -> PartitionResult:
52-
"""
53-
Fully delegate the graph to AOTInductor by tagging all nodes as a single partition.
54-
"""
52+
"""Delegate the non-lowered ops to AOTInductor.
5553
56-
partition_tags: Dict[str, DelegationSpec] = {}
57-
tag = "tag0"
58-
59-
# Tag torch.cond and other control flow operations
60-
def is_control_flow(node: torch.fx.Node) -> bool:
61-
return node.op == "call_function" and node.target in [
62-
torch.ops.higher_order.cond,
63-
torch.ops.higher_order.map_impl,
64-
torch.ops.higher_order.while_loop,
65-
]
66-
67-
# Nodes already lowered by an earlier partitioner (e.g. a preceding
68-
# TensorRT partition) appear as executorch_call_delegate calls and their
69-
# output getitems; re-delegating them would nest a foreign delegate. Tag
70-
# only the remaining non-lowered ops so this partitioner composes after
71-
# others.
54+
Uses CapabilityBasedPartitioner rather than a single tag because a
55+
delegated submodule must be convex: if a node that is not delegated sits
56+
between the delegated ops, one tag would span a non-convex set and fusion
57+
would fail with a dependency cycle.
58+
"""
59+
# Only nodes not already lowered are candidates for this backend.
7260
non_lowered_nodes = set(get_non_lowered_nodes(exported_program.graph))
7361

74-
for node in exported_program.graph.nodes:
75-
if node.op == "call_function":
76-
if node not in non_lowered_nodes:
77-
continue
62+
control_flow_targets = [
63+
torch.ops.higher_order.cond,
64+
torch.ops.higher_order.map_impl,
65+
torch.ops.higher_order.while_loop,
66+
torch.ops.higher_order.scan,
67+
]
68+
69+
class AotiOperatorSupport(OperatorSupportBase):
70+
def is_node_supported(
71+
self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
72+
) -> bool:
73+
return node.op == "call_function" and node in non_lowered_nodes
74+
75+
partitioner = CapabilityBasedPartitioner(
76+
exported_program.graph_module,
77+
AotiOperatorSupport(),
78+
allows_single_node_partition=True,
79+
)
80+
81+
partition_tags: Dict[str, DelegationSpec] = {}
82+
for partition in partitioner.propose_partitions():
83+
tag = f"aoti_{partition.id}"
84+
partition_tags[tag] = self.delegation_spec
85+
for node in partition.nodes:
7886
node.meta["delegation_tag"] = tag
79-
# Tag get_attr nodes that are used by control flow operations
80-
elif node.op == "get_attr":
81-
# Check if any user is a control flow operation
82-
for user in node.users:
83-
if is_control_flow(user):
84-
node.meta["delegation_tag"] = tag
85-
break
8687

87-
partition_tags[tag] = self.delegation_spec
88+
# A control-flow op carries its branch GraphModules as get_attr operands;
89+
# they must share the op's tag so they land inside the same submodule. A
90+
# branch module feeds a single control-flow op, so first match wins.
91+
for node in exported_program.graph.nodes:
92+
if node.op != "get_attr":
93+
continue
94+
for user in node.users:
95+
if (
96+
user.op == "call_function"
97+
and user.target in control_flow_targets
98+
and "delegation_tag" in user.meta
99+
):
100+
node.meta["delegation_tag"] = user.meta["delegation_tag"]
101+
break
88102

89103
tag_constant_data(exported_program)
90104
tag_mutated_buffer(exported_program)
91105

92-
# A constant that still has users feeds only a prior delegate; tagging it
93-
# would fail backend lowering's same-tag check (its user keeps the prior
94-
# tag). tag_constant_data already claimed the ones this partition uses, so
95-
# tag only the genuinely unused constants here.
96-
for node in exported_program.graph.nodes:
97-
if (
98-
node.op == "placeholder"
99-
and not node.users
100-
and "delegation_tag" not in node.meta
101-
and (
102-
is_param(exported_program, node)
103-
or is_buffer(exported_program, node)
104-
or is_lifted_tensor_constant(exported_program, node)
105-
)
106-
):
107-
node.meta["delegation_tag"] = tag
106+
# tag_constant_data only tags constants that have users; tag the
107+
# genuinely unused ones too so none are left dangling.
108+
if partition_tags:
109+
fallback_tag = next(iter(partition_tags))
110+
for node in exported_program.graph.nodes:
111+
if (
112+
node.op == "placeholder"
113+
and not node.users
114+
and "delegation_tag" not in node.meta
115+
and (
116+
is_param(exported_program, node)
117+
or is_buffer(exported_program, node)
118+
or is_lifted_tensor_constant(exported_program, node)
119+
)
120+
):
121+
node.meta["delegation_tag"] = fallback_tag
108122

109123
return PartitionResult(
110124
tagged_exported_program=exported_program, partition_tags=partition_tags

backends/cuda/tests/test_cuda_partitioner.py

Lines changed: 148 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -12,17 +12,18 @@
1212
from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
1313
from executorch.exir.backend.partitioner import PartitionResult
1414
from executorch.exir.delegate import executorch_call_delegate
15-
from torch._export.utils import is_buffer
15+
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
1616
from torch.export import export
17+
from torch.fx.passes.utils.fuser_utils import validate_partition
1718

1819

1920
class TestCudaPartitioner(unittest.TestCase):
2021
"""
2122
Test CUDA partitioner functionality.
2223
23-
After CUDA partitioning, there should be exactly one partitioned graph that contains
24-
all operators from the input graph. This means all operators should be tagged with
25-
the same delegation tag, indicating they will all be executed by the CUDA backend.
24+
A fully delegatable graph collapses to a single partition. When a
25+
non-delegated node splits the delegatable ops, the partitioner emits one
26+
convex partition per island.
2627
"""
2728

2829
def _get_partition_result(
@@ -178,12 +179,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
178179
for node in partition_result.tagged_exported_program.graph.nodes:
179180
if node.op == "placeholder":
180181
# Check if this is a constant (param, buffer, or lifted tensor constant)
181-
from torch._export.utils import (
182-
is_buffer,
183-
is_lifted_tensor_constant,
184-
is_param,
185-
)
186-
187182
is_constant = (
188183
is_param(partition_result.tagged_exported_program, node)
189184
or is_buffer(partition_result.tagged_exported_program, node)
@@ -216,8 +211,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
216211
f"All constant placeholders should be tagged. Found untagged constants: {untagged_constants}",
217212
)
218213

219-
# Verify all tagged constants have the expected tag
220-
expected_tag = "tag0"
214+
# Verify all tagged constants share the (single) partition's tag.
215+
self.assertEqual(len(partition_result.partition_tags), 1)
216+
expected_tag = next(iter(partition_result.partition_tags))
221217
for node in constant_placeholders:
222218
actual_tag = node.meta.get("delegation_tag")
223219
self.assertEqual(
@@ -320,3 +316,143 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
320316
self.assertNotIn("delegation_tag", buffer_placeholder.meta)
321317
self.assertNotIn("delegation_tag", delegate.meta)
322318
self.assertIn("delegation_tag", aten_node.meta)
319+
320+
def test_multiple_partitions_for_split_graph(self) -> None:
321+
"""Ops split by a non-delegated node must land in separate partitions.
322+
323+
One tag over the disconnected islands would be non-convex and fail fusion.
324+
"""
325+
326+
class TwoAddModule(torch.nn.Module):
327+
def forward(self, x: torch.Tensor) -> torch.Tensor:
328+
a = x + 1.0
329+
return a + 2.0
330+
331+
exported_program = export(TwoAddModule(), (torch.randn(3, 4),), strict=True)
332+
graph_module = exported_program.graph_module
333+
graph = graph_module.graph
334+
335+
add_nodes = [
336+
n
337+
for n in graph.nodes
338+
if n.op == "call_function" and n.target != operator.getitem
339+
]
340+
first_add, second_add = add_nodes[0], add_nodes[1]
341+
342+
# Splice an already-lowered region between the two adds so the second add
343+
# depends on the first only through that non-delegated node.
344+
graph_module.lowered_module_0 = torch.nn.Module()
345+
with graph.inserting_before(second_add):
346+
lowered = graph.get_attr("lowered_module_0")
347+
delegate = graph.call_function(
348+
executorch_call_delegate, (lowered, first_add)
349+
)
350+
delegate_output = graph.call_function(operator.getitem, (delegate, 0))
351+
second_add.replace_input_with(first_add, delegate_output)
352+
graph.lint()
353+
354+
result = CudaPartitioner([]).partition(exported_program)
355+
356+
# Separated by the delegate, the adds must land in different partitions.
357+
self.assertEqual(len(result.partition_tags), 2)
358+
self.assertIn("delegation_tag", first_add.meta)
359+
self.assertIn("delegation_tag", second_add.meta)
360+
self.assertNotEqual(
361+
first_add.meta["delegation_tag"], second_add.meta["delegation_tag"]
362+
)
363+
self.assertNotIn("delegation_tag", delegate.meta)
364+
self.assertNotIn("delegation_tag", delegate_output.meta)
365+
366+
# Each partition must be convex on its own so fusion does not cycle.
367+
for tag in result.partition_tags:
368+
tagged = [
369+
n
370+
for n in exported_program.graph.nodes
371+
if n.meta.get("delegation_tag") == tag
372+
]
373+
self.assertTrue(validate_partition(tagged))
374+
375+
def test_control_flow_get_attr_shares_op_tag(self) -> None:
376+
"""A control-flow op's branch get_attrs must share the op's partition tag.
377+
378+
They are not call_function nodes, so the capability partitioner does not
379+
claim them; they must be lowered into the same submodule as the op.
380+
"""
381+
382+
class CondModule(torch.nn.Module):
383+
def forward(self, x: torch.Tensor) -> torch.Tensor:
384+
return torch.cond(x.sum() > 0, torch.sin, torch.cos, (x,))
385+
386+
exported_program = export(CondModule(), (torch.randn(3, 4),), strict=True)
387+
result = CudaPartitioner([]).partition(exported_program)
388+
389+
cond_node = next(
390+
n
391+
for n in exported_program.graph.nodes
392+
if n.op == "call_function" and n.target is torch.ops.higher_order.cond
393+
)
394+
branch_get_attrs = [
395+
arg
396+
for arg in cond_node.args
397+
if isinstance(arg, torch.fx.Node) and arg.op == "get_attr"
398+
]
399+
400+
self.assertEqual(len(branch_get_attrs), 2)
401+
self.assertIn(cond_node.meta["delegation_tag"], result.partition_tags)
402+
for get_attr in branch_get_attrs:
403+
self.assertEqual(
404+
get_attr.meta.get("delegation_tag"),
405+
cond_node.meta["delegation_tag"],
406+
)
407+
408+
def test_shared_constant_across_partitions(self) -> None:
409+
"""A constant read by two partitions is claimed, not dropped.
410+
411+
tag_constant_data assigns it one partition's tag; backend lowering later
412+
duplicates it per consumer, so partitioning must not crash or drop it.
413+
"""
414+
415+
class SharedWeightModule(torch.nn.Module):
416+
def __init__(self) -> None:
417+
super().__init__()
418+
self.register_buffer("w", torch.randn(3, 4))
419+
420+
def forward(self, x: torch.Tensor) -> torch.Tensor:
421+
return (x + self.w) + self.w
422+
423+
exported_program = export(
424+
SharedWeightModule(), (torch.randn(3, 4),), strict=True
425+
)
426+
graph_module = exported_program.graph_module
427+
graph = graph_module.graph
428+
429+
add_nodes = [
430+
n
431+
for n in graph.nodes
432+
if n.op == "call_function" and n.target != operator.getitem
433+
]
434+
first_add, second_add = add_nodes[0], add_nodes[1]
435+
436+
# Split the two adds (both reading w) with an already-lowered region.
437+
graph_module.lowered_module_0 = torch.nn.Module()
438+
with graph.inserting_before(second_add):
439+
lowered = graph.get_attr("lowered_module_0")
440+
delegate = graph.call_function(
441+
executorch_call_delegate, (lowered, first_add)
442+
)
443+
delegate_output = graph.call_function(operator.getitem, (delegate, 0))
444+
second_add.replace_input_with(first_add, delegate_output)
445+
graph.lint()
446+
447+
result = CudaPartitioner([]).partition(exported_program)
448+
449+
# Two islands, and the shared buffer is claimed by one of them, not dropped.
450+
self.assertEqual(len(result.partition_tags), 2)
451+
buffer_placeholder = next(
452+
n
453+
for n in graph.nodes
454+
if n.op == "placeholder" and is_buffer(exported_program, n)
455+
)
456+
self.assertIn(
457+
buffer_placeholder.meta.get("delegation_tag"), result.partition_tags
458+
)

exir/backend/utils.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -390,9 +390,10 @@ def tag_constant_data(edge_program: ExportedProgram) -> None:
390390
"If the data is too large and it's not preferred to copy, please tag the "
391391
"constant node like node.['no_copy'] = True and they won't be copied."
392392
)
393-
# tag the data node with the same tag as the last user
393+
# Pick a deterministic consumer tag so a constant shared across
394+
# partitions is assigned reproducibly across runs.
394395
if len(user_tags) > 0:
395-
node.meta["delegation_tag"] = user_tags.pop()
396+
node.meta["delegation_tag"] = min(user_tags)
396397

397398

398399
def tag_mutated_buffer(edge_program: ExportedProgram) -> None:
@@ -429,9 +430,10 @@ def tag_mutated_buffer(edge_program: ExportedProgram) -> None:
429430
"If the data is too large and it's not preferred to copy, please tag the "
430431
"constant node like node.['no_copy'] = True and they won't be copied."
431432
)
432-
# tag the data node with the same tag as the last user
433+
# Pick a deterministic consumer tag so a buffer shared across
434+
# partitions is assigned reproducibly across runs.
433435
if len(user_tags) > 0:
434-
node.meta["delegation_tag"] = user_tags.pop()
436+
node.meta["delegation_tag"] = min(user_tags)
435437

436438

437439
def is_shape_dynamic(node: torch.fx.Node) -> bool:

0 commit comments

Comments
 (0)