Skip to content

Commit b602537

Browse files
author
ssjia
committed
Update
[ghstack-poisoned]
2 parents 8701dac + 93a2feb commit b602537

248 files changed

Lines changed: 11926 additions & 5908 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/workflows/mlx.yml

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ on:
1313
- backends/mlx/**
1414
- extension/llm/export/**
1515
- extension/audio/**
16+
- examples/models/gemma4_31b/**
1617
- examples/models/parakeet/**
1718
- examples/models/voxtral_realtime/**
1819
- examples/models/qwen3_5_moe/**
@@ -77,6 +78,8 @@ jobs:
7778
backends/mlx/test/test_passes.py \
7879
backends/mlx/test/test_pattern_utils.py \
7980
backends/mlx/test/test_partitioner.py \
81+
backends/mlx/test/test_serialization_dedup.py \
82+
examples/models/gemma4_31b/quant/tests/test_pack_mlx.py \
8083
examples/models/gemma4_31b/tests/test_mlx_pipeline.py \
8184
-v
8285
echo "::endgroup::"
@@ -89,20 +92,16 @@ jobs:
8992
./cmake-out/backends/mlx/test/multi_thread_test_runner
9093
echo "::endgroup::"
9194
92-
echo "::group::Run gated_delta_rule op tests"
93-
${CONDA_RUN} python -m executorch.backends.mlx.model_ops.test_gated_delta_rule run -v
94-
echo "::endgroup::"
95-
96-
echo "::group::Run tq_norm op tests"
97-
${CONDA_RUN} python -m executorch.backends.mlx.model_ops.test_tq_norm run -v
98-
echo "::endgroup::"
99-
100-
echo "::group::Run tq4_compress op tests"
101-
${CONDA_RUN} python -m executorch.backends.mlx.model_ops.test_tq4_compress run -v
102-
echo "::endgroup::"
103-
104-
echo "::group::Run tq_dequant op tests"
105-
${CONDA_RUN} python -m executorch.backends.mlx.model_ops.test_tq_dequant run -v
95+
echo "::group::Run custom_kernel_ops op tests"
96+
# Run every custom_kernel_ops/**/test/test_*.py via its OpTestCase `run`
97+
# CLI. Recurses into per-format subpackages (e.g. gguf/test), so adding a
98+
# new op test file requires no change here.
99+
set -e
100+
for t in $(find backends/mlx/custom_kernel_ops -path '*/test/test_*.py' | sort); do
101+
mod="executorch.$(echo "${t%.py}" | tr '/' '.')"
102+
echo "--- ${mod} ---"
103+
${CONDA_RUN} python -m "${mod}" run -v
104+
done
106105
echo "::endgroup::"
107106
108107
test-mlx-qwen35-moe:

.github/workflows/trunk.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,7 @@ jobs:
258258
- test_arm_backend: test_pytest_models_ethos_u85
259259
- test_arm_backend: test_run_ethos_u85
260260
- test_arm_backend: test_smaller_stories_llama_tosa
261+
- test_arm_backend: test_model_smollm2_135M_ethos_u85
261262
- test_arm_backend: test_memory_allocation
262263
- test_arm_backend: test_ootb_tests_ethos_u
263264
- test_arm_backend: test_ootb_tests_tosa

Makefile

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,8 @@ parakeet-vulkan:
261261

262262
dinov2-cuda:
263263
@echo "==> Building and installing ExecuTorch with CUDA..."
264-
cmake --workflow --preset llm-release-cuda
264+
cmake --preset llm-release-cuda -DEXECUTORCH_BUILD_EXTENSION_IMAGE=ON
265+
cmake --build --preset llm-release-cuda-install
265266
@echo "==> Building DINOv2 runner with CUDA..."
266267
cd examples/models/dinov2 && cmake --workflow --preset dinov2-cuda
267268
@echo ""
@@ -270,7 +271,8 @@ dinov2-cuda:
270271

271272
dinov2-cuda-debug:
272273
@echo "==> Building and installing ExecuTorch with CUDA (debug mode)..."
273-
cmake --workflow --preset llm-debug-cuda
274+
cmake --preset llm-debug-cuda -DEXECUTORCH_BUILD_EXTENSION_IMAGE=ON
275+
cmake --build --preset llm-debug-cuda-install
274276
@echo "==> Building DINOv2 runner with CUDA (debug mode)..."
275277
cd examples/models/dinov2 && cmake --workflow --preset dinov2-cuda-debug
276278
@echo ""

backends/aoti/aoti_partitioner.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@
1414
Partitioner,
1515
PartitionResult,
1616
)
17-
from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer
17+
from executorch.exir.backend.utils import (
18+
get_non_lowered_nodes,
19+
tag_constant_data,
20+
tag_mutated_buffer,
21+
)
1822
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
1923
from torch.export.exported_program import ExportedProgram
2024

@@ -60,8 +64,17 @@ def is_control_flow(node: torch.fx.Node) -> bool:
6064
torch.ops.higher_order.while_loop,
6165
]
6266

67+
# Nodes already lowered by an earlier partitioner (e.g. a preceding
68+
# TensorRT partition) appear as executorch_call_delegate calls and their
69+
# output getitems; re-delegating them would nest a foreign delegate. Tag
70+
# only the remaining non-lowered ops so this partitioner composes after
71+
# others.
72+
non_lowered_nodes = set(get_non_lowered_nodes(exported_program.graph))
73+
6374
for node in exported_program.graph.nodes:
6475
if node.op == "call_function":
76+
if node not in non_lowered_nodes:
77+
continue
6578
node.meta["delegation_tag"] = tag
6679
# Tag get_attr nodes that are used by control flow operations
6780
elif node.op == "get_attr":
@@ -76,17 +89,22 @@ def is_control_flow(node: torch.fx.Node) -> bool:
7689
tag_constant_data(exported_program)
7790
tag_mutated_buffer(exported_program)
7891

79-
# Tag constant placeholders that have no users
80-
# tag_constant_data only tags constants that have users with delegation_tag
81-
# but we need to tag all constants for this partition
92+
# A constant that still has users feeds only a prior delegate; tagging it
93+
# would fail backend lowering's same-tag check (its user keeps the prior
94+
# tag). tag_constant_data already claimed the ones this partition uses, so
95+
# tag only the genuinely unused constants here.
8296
for node in exported_program.graph.nodes:
83-
if node.op == "placeholder" and (
84-
is_param(exported_program, node)
85-
or is_buffer(exported_program, node)
86-
or is_lifted_tensor_constant(exported_program, node)
97+
if (
98+
node.op == "placeholder"
99+
and not node.users
100+
and "delegation_tag" not in node.meta
101+
and (
102+
is_param(exported_program, node)
103+
or is_buffer(exported_program, node)
104+
or is_lifted_tensor_constant(exported_program, node)
105+
)
87106
):
88-
if "delegation_tag" not in node.meta:
89-
node.meta["delegation_tag"] = tag
107+
node.meta["delegation_tag"] = tag
90108

91109
return PartitionResult(
92110
tagged_exported_program=exported_program, partition_tags=partition_tags

backends/arm/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ Below is an overview of some of the testing options this script provides:
251251
| `test_arm_backend.sh test_pytest_ops_vkml` | Runs operator unit tests for VKML/VGF specific use-cases. |
252252
| `test_arm_backend.sh test_pytest_models_vkml` | Runs model unit tests for VKML/VGF specific use-cases. |
253253
| `test_arm_backend.sh test_run_vkml` | Runs end-to-end unit tests for VKML/VGF specific use-cases. |
254-
| `test_arm_backend.sh test_model_smollm2_135M` | Runs some models with Corstone FVP. |
254+
| `test_arm_backend.sh test_model_smollm2_135M_ethos_u85` | Runs smollm2_135M for Ethos-U85 specific use-cases. |
255255
| `test_arm_backend.sh test_ootb_tests_ethos_u` | Runs out-of-the-box tests for Ethos-U. |
256256
| `test_arm_backend.sh test_ootb_tests_tosa` | Runs out-of-the-box tests for TOSA. |
257257
| `test_arm_backend.sh test_ootb_tests_vgf` | Runs out-of-the-box tests for VKML/VGF. |

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@
149149
from .replace_scalar_with_tensor_pass import ( # noqa
150150
ReplaceScalarWithTensorByProfilePass,
151151
)
152+
from .rewrite_adaptive_avg_pool2d import RewriteAdaptiveAvgPool2dPass # noqa
152153
from .rewrite_avg_pool2d_pass import RewriteAvgPool2dPass # noqa
153154
from .rewrite_bool_bitwise_to_logical_pass import ( # noqa
154155
RewriteBoolBitwiseToLogicalPass,

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@
131131
RemovePermutesAroundElementwiseTosaOps,
132132
ReplaceInfAndLimitValuesPass,
133133
ReplaceScalarWithTensorByProfilePass,
134+
RewriteAdaptiveAvgPool2dPass,
134135
RewriteAvgPool2dPass,
135136
RewriteBoolBitwiseToLogicalPass,
136137
RewriteBoolToFp32CastViaInt8Pass,
@@ -504,6 +505,7 @@ def _tosa_pipeline(
504505
DecomposeAsStridedCopyPass(),
505506
DecomposeMaxPool2dPass(),
506507
SizeAdjustInputPass(),
508+
RewriteAdaptiveAvgPool2dPass(),
507509
RewriteAvgPool2dPass(),
508510
ComputeConstantOpsAOTPass(exported_program),
509511
FuseConstantArgsPass(exported_program),

backends/arm/_passes/fuse_duplicate_users_pass.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
3434
graph = graph_module.graph
3535
modified = False
3636

37+
node_order = {node: index for index, node in enumerate(graph.nodes)}
3738
producers: Deque[Node] = deque(node for node in graph.nodes)
3839

3940
while producers:
@@ -48,7 +49,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
4849
if len(user_nodes) < 2:
4950
continue
5051

51-
candidate_groups = self._get_candidate_groups(user_nodes)
52+
candidate_groups = self._get_candidate_groups(node_order, user_nodes)
5253

5354
signature_to_user: Dict[Tuple[Hashable, ...], Node] = {}
5455
for group in candidate_groups:
@@ -84,7 +85,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
8485

8586
return PassResult(graph_module, modified)
8687

87-
def _get_candidate_groups(self, user_nodes):
88+
def _get_candidate_groups(self, node_order, user_nodes):
8889
users_by_target: Dict[Tuple[str, Hashable], List[Node]] = {}
8990
for user in user_nodes:
9091
if user.graph is None:
@@ -98,9 +99,12 @@ def _get_candidate_groups(self, user_nodes):
9899
target_signature = (user.op, target_key)
99100
users_by_target.setdefault(target_signature, []).append(user)
100101

101-
candidate_groups = [
102-
group for group in users_by_target.values() if len(group) > 1
103-
]
102+
candidate_groups = []
103+
for group in users_by_target.values():
104+
if len(group) > 1:
105+
candidate_groups.append(
106+
sorted(group, key=lambda node: node_order[node])
107+
)
104108

105109
return candidate_groups
106110

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Set, Type
7+
8+
import torch
9+
from executorch.backends.arm._passes import ArmPass
10+
11+
from executorch.backends.arm._passes.fuse_constant_ops_pass import (
12+
ComputeConstantOpsAOTPass,
13+
)
14+
from executorch.backends.arm.constants import NHWC_INVERSE_ORDER, NHWC_ORDER
15+
from executorch.backends.arm.tosa.specification import (
16+
get_context_shape_env,
17+
get_context_spec,
18+
)
19+
from executorch.exir.dialects._ops import ops as exir_ops
20+
from executorch.exir.pass_base import ExportPass
21+
22+
23+
class RewriteAdaptiveAvgPool2dPass(ArmPass):
    """Lower dynamic-shape adaptive average pooling to one TOSA operator.

    ``aten._adaptive_avg_pool2d`` nodes whose input height or width is
    symbolic are replaced by ``tosa.AVG_POOL2D_ADAPTIVE`` wrapped in an
    NHWC permute and its inverse. Nodes whose spatial dimensions are both
    static are passed through unchanged.

    The rewrite is applied only when, for each spatial dimension, the input
    size modulo the static output size is known to be 0 or 1. Under that
    condition a fixed kernel and stride preserve the adaptive pooling regions
    without materializing a slice/cat decomposition.

    """

    targeted_ops = {exir_ops.edge.aten._adaptive_avg_pool2d.default}
    _passes_required_after: Set[Type[ExportPass]] = {
        ComputeConstantOpsAOTPass,
    }

    @staticmethod
    def _is_symbolic_dim(dim) -> bool:
        """Return True when ``dim`` is a ``torch.SymInt``."""
        return isinstance(dim, torch.SymInt)

    @staticmethod
    def _supports_dynamic_tosa_adaptive() -> bool:
        """Report whether the active TOSA spec can express the rewrite.

        Requires major version 1, minor version >= 1 and the "shape"
        extension. Any failure to obtain the context spec is treated as
        "not supported".

        """
        try:
            spec = get_context_spec()
        except Exception:
            return False

        is_1_1_or_later = spec.version.major == 1 and spec.version.minor >= 1
        return is_1_1_or_later and spec.support_extension("shape")

    @classmethod
    def _get_pool_params(cls, input_size, output_size: int):
        """Derive ``(kernel, stride)`` for a single spatial dimension.

        Returns ``None`` when ``output_size`` is not a plain ``int`` or when
        ``input_size % output_size`` cannot be shown to be exactly 0 or 1.
        Otherwise the stride is ``input_size // output_size`` and the kernel
        is that stride plus the remainder.

        """
        if isinstance(output_size, torch.SymInt):
            return None
        if not isinstance(output_size, int):
            return None

        remainder = input_size % output_size
        if cls._is_symbolic_dim(remainder):
            # The remainder is symbolic: ask the shape environment for its
            # value range and accept it only if it collapses to 0 or 1.
            shape_env = get_context_shape_env()
            try:
                bounds = shape_env.bound_sympy(remainder.node.expr)
            except Exception:
                return None
            if not bounds.is_singleton():
                return None
            extra = int(bounds.upper)
            if extra not in (0, 1):
                return None
        else:
            if remainder not in (0, 1):
                return None
            extra = remainder

        stride = input_size // output_size
        return stride + extra, stride

    def call_operator(self, op, args, kwargs, meta, updated=False):
        """Rewrite a targeted adaptive pool node, or defer to the parent pass.

        Non-targeted ops, inputs with static height and width, and cases
        where either dimension has no valid ``(kernel, stride)`` pair are
        forwarded unchanged to the parent ``call_operator``.

        Raises:
            RuntimeError: if a spatial dimension is symbolic but the active
                TOSA spec lacks version 1.1+ with the shape extension.

        """
        if op not in self.targeted_ops:
            return super().call_operator(op, args, kwargs, meta, updated)

        x = args[0]
        _batch, _channels, input_h, input_w = x.data.shape
        has_dynamic_spatial_dim = self._is_symbolic_dim(
            input_h
        ) or self._is_symbolic_dim(input_w)
        if not has_dynamic_spatial_dim:
            return super().call_operator(op, args, kwargs, meta, updated)

        # Dynamic adaptive lowering requires shape-aware TOSA support.
        if not self._supports_dynamic_tosa_adaptive():
            raise RuntimeError(
                "Dynamic adaptive_avg_pool2d rewrite requires TOSA-1.1 with the shape extension."
            )

        output_h, output_w = args[1]
        h_params = self._get_pool_params(input_h, output_h)
        w_params = self._get_pool_params(input_w, output_w)
        # Fall back when either spatial dimension cannot be expressed as one
        # TOSA adaptive pool.
        if h_params is None or w_params is None:
            return super().call_operator(op, args, kwargs, meta, updated)

        kernel_h, stride_h = h_params
        kernel_w, stride_w = w_params
        kernel = [kernel_h, kernel_w]
        stride = [stride_h, stride_w]

        const_shape = exir_ops.backend.tosa.CONST_SHAPE.default
        make_shape = super().call_shape_operator
        emit = super().call_operator

        pad = make_shape(const_shape, ([0, 0, 0, 0],), {}, meta)
        # Only fully static kernel/stride values become CONST_SHAPE nodes;
        # lists containing symbolic entries are passed through as-is.
        if all(isinstance(k, int) for k in kernel):
            kernel = make_shape(const_shape, (kernel,), {}, meta)
        if all(isinstance(s, int) for s in stride):
            stride = make_shape(const_shape, (stride,), {}, meta)

        # Zero points default to 0 when no quantization parameters are
        # recorded for index 0.
        in_qparams = meta.data.get("input_qparams", {})
        if 0 in in_qparams:
            in_zp_val = in_qparams[0].get_zp_per_tensor()
        else:
            in_zp_val = 0
        input_zp = self.call_scalar(in_zp_val, meta)

        out_qparams = meta.data.get("output_qparams", {})
        if 0 in out_qparams:
            out_zp_val = out_qparams[0].get_zp_per_tensor()
        else:
            out_zp_val = 0
        output_zp = self.call_scalar(out_zp_val, meta)

        if x.data.dtype in (torch.int8, torch.int16):
            acc_type = torch.int32
        else:
            acc_type = torch.float32

        nhwc_input = emit(
            exir_ops.edge.aten.permute_copy.default,
            (x, list(NHWC_ORDER)),
            {},
            meta,
            True,
        )
        pooled = emit(
            exir_ops.backend.tosa.AVG_POOL2D_ADAPTIVE.default,
            (nhwc_input, input_zp, output_zp, kernel, stride, pad, acc_type),
            {},
            meta,
            True,
        )
        return emit(
            exir_ops.edge.aten.permute_copy.default,
            (pooled, list(NHWC_INVERSE_ORDER)),
            {},
            meta,
            True,
        )

backends/arm/_passes/rewrite_avg_pool2d_pass.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,11 @@ def call_operator(self, op, args, kwargs, meta, updated=False):
6565
# Materialize output zero-point as a scalar tensor
6666
output_zp = super().call_scalar(out_zp_val, meta)
6767

68-
# Determine accumulator dtype for AVG_POOL2D: INT32 for integer inputs, FP32 otherwise
68+
# Determine accumulator dtype for AVG_POOL2D: INT32 for int8/int16 inputs, FP16 for FP8 inputs, FP32 otherwise.
6969
if x.data.dtype in (torch.int8, torch.int16):
7070
acc_type = torch.int32
71+
elif x.data.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
72+
acc_type = torch.float16
7173
else:
7274
acc_type = torch.float32
7375

0 commit comments

Comments
 (0)