Skip to content

Commit c5364fd

Browse files
authored
XNNPACK: Remove no-op expand_copy before partitioning (#19978)
Remove aten.expand_copy nodes when input and output metadata have the same dtype and shape. Static export can leave these shape-preserving expands as portable copy kernels even though they are identities for the lowered graph. Run the cleanup in the normal XNNPACK transform pass path so it can remove inter-delegate expand_copy nodes before partitioning. For the EdgeTAM mask decoder, expand_copy ops drop from 32 to 0, non-delegate kernel calls drop from 114 to 82, and delegate calls drop by 1, resulting in a ~9% speedup measured on SVE2 and SME2 Android devices. cc @GregoryComer @digantdesai @cbilgin @freddan80 @per @zingo @oscarandersson8218 @Sebastian-Larsson @robell @rascani Signed-off-by: Måns Nilsson <mans.nilsson@arm.com>
1 parent b991342 commit c5364fd

3 files changed

Lines changed: 142 additions & 1 deletion

File tree

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
# Copyright 2026 Arm Limited and/or its affiliates.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import torch
9+
from executorch.exir.dialects._ops import ops as exir_ops
10+
from executorch.exir.pass_base import ExportPass, PassResult
11+
12+
13+
class RemoveNoopExpandCopyPass(ExportPass):
    """
    Remove ``expand_copy`` nodes that do not change tensor shape or dtype.

    In static XNNPACK export flows, shape-specialization can turn an expand into
    a materialized copy whose input and output metadata are identical. Such a
    node is an identity for the lowered graph and can be bypassed. The pass
    leaves nodes in place whenever the output shape differs from the input
    shape.
    """

    def _is_noop_expand_copy(self, node: torch.fx.Node) -> bool:
        """
        Return True when ``node`` is an edge ``expand_copy`` whose input and
        output ``meta["val"]`` entries report the same dtype and shape.

        Nodes with a non-Node first argument, or with missing ``val``
        metadata on either side, are conservatively treated as not removable.
        """
        # TODO: Investigate moving this to a shared backend transform. Other
        # backends already carry equivalent no-op expand handling.
        if node.target != exir_ops.edge.aten.expand_copy.default:
            return False

        input_node = node.args[0]
        if not isinstance(input_node, torch.fx.Node):
            return False

        input_value = input_node.meta.get("val")
        output_value = node.meta.get("val")
        if input_value is None or output_value is None:
            return False

        return (
            input_value.dtype == output_value.dtype
            and input_value.shape == output_value.shape
        )

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Bypass every no-op ``expand_copy`` node in ``graph_module``.

        Args:
            graph_module: The graph module to rewrite in place.

        Returns:
            A ``PassResult`` whose ``modified`` flag is True only when at
            least one ``expand_copy`` node was bypassed. When nothing was
            rewritten, the input module is returned untouched.
        """
        graph = graph_module.graph
        modified = False

        # Iterate over a snapshot because uses are rewired while walking.
        for node in list(graph.nodes):
            if not self._is_noop_expand_copy(node):
                continue

            node.replace_all_uses_with(node.args[0])
            modified = True

        if not modified:
            # Nothing was rewired: skip dead-code elimination, recompilation
            # and the re-trace, and report the graph as unchanged.
            return PassResult(graph_module, False)

        # The bypassed expand_copy nodes now have no users; drop them.
        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        # Re-run the base ExportPass over the rewritten graph.
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
# Copyright 2026 Arm Limited and/or its affiliates.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import unittest
9+
10+
import torch
11+
from executorch.backends.xnnpack._passes.remove_noop_expand_copy_pass import (
12+
RemoveNoopExpandCopyPass,
13+
)
14+
from executorch.backends.xnnpack.test.tester import RunPasses, Tester
15+
from executorch.backends.xnnpack.utils.configs import (
16+
get_transform_passes,
17+
get_xnnpack_edge_compile_config,
18+
)
19+
from executorch.exir import to_edge_transform_and_lower
20+
from executorch.exir.dialects._ops import ops as exir_ops
21+
22+
23+
class TestRemoveNoopExpandCopyPass(unittest.TestCase):
    """Exercise ``RemoveNoopExpandCopyPass`` alone and via the transform passes."""

    PassStage = RunPasses([RemoveNoopExpandCopyPass])
    expand_copy_name = "executorch_exir_dialects_edge__ops_aten_expand_copy_default"

    def setUp(self):
        torch._dynamo.reset()

    class NoopExpand(torch.nn.Module):
        """Expands the input to its own shape, then adds one."""

        def forward(self, x):
            return x.expand(x.shape) + 1

    class BroadcastExpand(torch.nn.Module):
        """Expands the input to shape (2, 3), then adds one."""

        def forward(self, x):
            return x.expand(2, 3) + 1

    def _run_pass_and_check(self, module, example_inputs, expected_after):
        # Shared pipeline: one expand_copy must exist after to_edge, and
        # ``expected_after`` of them must remain once the pass has run.
        edge_stage = Tester(module, example_inputs).export().to_edge()
        after_pass = edge_stage.check_count({self.expand_copy_name: 1}).run_passes(
            self.PassStage
        )
        after_pass.check_count(
            {self.expand_copy_name: expected_after}
        ).run_method_and_compare_outputs()

    def test_removes_same_shape_expand_copy(self):
        self._run_pass_and_check(self.NoopExpand(), (torch.randn(2, 3),), 0)

    def test_keeps_broadcasting_expand_copy(self):
        self._run_pass_and_check(self.BroadcastExpand(), (torch.randn(1, 3),), 1)

    def test_transform_passes_remove_same_shape_expand_copy(self):
        exported = torch.export.export(
            self.NoopExpand(), (torch.randn(2, 3),), strict=True
        )
        edge_program = to_edge_transform_and_lower(
            exported,
            transform_passes=get_transform_passes(),
            compile_config=get_xnnpack_edge_compile_config(),
        )
        graph = edge_program.exported_program().graph_module.graph

        # Collect any expand_copy nodes that survived lowering; none should.
        leftover = [
            node
            for node in graph.nodes
            if node.target == exir_ops.edge.aten.expand_copy.default
        ]
        self.assertFalse(leftover)

backends/xnnpack/utils/configs.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3+
# Copyright 2026 Arm Limited and/or its affiliates.
34
#
45
# This source code is licensed under the BSD-style license found in the
56
# LICENSE file in the root directory of this source tree.
67

78
from typing import List
89

910
import executorch.exir as exir
11+
12+
from executorch.backends.xnnpack._passes.remove_noop_expand_copy_pass import (
13+
RemoveNoopExpandCopyPass,
14+
)
1015
from executorch.exir.pass_manager import PassType
1116

1217

@@ -20,7 +25,9 @@ def get_xnnpack_edge_compile_config(
2025

2126

2227
def get_transform_passes(additional_passes=None) -> List[PassType]:
    """
    Build the list of transform passes.

    A fresh ``RemoveNoopExpandCopyPass`` instance always comes first, followed
    by the entries of ``additional_passes`` when that argument is non-empty.
    """
    extra_passes = additional_passes or []
    return [RemoveNoopExpandCopyPass(), *extra_passes]
2532

2633

0 commit comments

Comments
 (0)