From 2658037a638676535178304438d769f211558fa0 Mon Sep 17 00:00:00 2001 From: Andrew Pullin Date: Thu, 18 Jun 2026 14:11:12 -0700 Subject: [PATCH] Quantize moveaxis/movedim so they delegate to Ethos-U (#20314) Summary: The ARM PT2 quantizer's pass-through shared-qspec set in quantization_annotator.py (_one_to_one_shared_input_qspec) covers permute/permute_copy/transpose/view/squeeze etc., but omits aten.moveaxis/aten.movedim. A model that uses torch.moveaxis therefore leaves those ops unquantized: the quantizer brackets each one with dequantize -> moveaxis(float) -> quantize. On lowering, moveaxis decomposes to a float permute_copy. The Ethos-U55 operator-support check (operator_support/ethos_u55_support.py) only delegates permute_copy for int8/int16/int32, so it rejects the float one. Each rejected permute is stranded on the host, splitting the model into many delegated partitions (one NPU island per permute), which bloats the .pte with per-partition delegate overhead and host round-trips. Add aten.moveaxis.int / aten.movedim.int to _one_to_one_shared_input_qspec (guarded with getattr for torch-build variance, mirroring the existing transpose.Dimname handling) so they share the input quantization spec exactly like transpose/permute. They then stay int8, decompose to int8 permute_copy, and delegate to the NPU -- eliminating the host float islands. Impact: a quantized example ensemble (ConvNeXt-style blocks that use torch.moveaxis) that previously lowered into 9 Ethos-U55 partitions now lowers into a single delegate, with zero host permutes and ~24% smaller .pte, with no model changes. Generalizes to any moveaxis/movedim-using model on the Ethos-U backend. Differential Revision: D108478011 --- .../arm/quantizer/quantization_annotator.py | 5 +++++ backends/arm/test/ops/test_permute.py | 17 +++++++++++++++++ .../test/quantizer/test_generic_annotater.py | 13 ++++++++++++- 3 files changed, 34 insertions(+), 1 deletion(-) diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index 7810077a679..c61a4741af5 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -619,6 +619,10 @@ def _get_fixed_qparams_qspec( # dequant -> neg -> requant chain. torch.ops.aten.neg.default, torch.ops.aten.detach_copy.default, + torch.ops.aten.moveaxis.int, + torch.ops.aten.moveaxis.intlist, + torch.ops.aten.movedim.int, + torch.ops.aten.movedim.intlist, } # Dimname has been removed from upstream PyTorch, but there may be a window @@ -630,6 +634,7 @@ def _get_fixed_qparams_qspec( if _transpose_dimname is not None: _one_to_one_shared_input_qspec.add(_transpose_dimname) + _one_to_one_shared_input_or_input_act_qspec: set[OpOverload] = { torch.ops.aten.alias.default, torch.ops.aten.clone.default, diff --git a/backends/arm/test/ops/test_permute.py b/backends/arm/test/ops/test_permute.py index 8864324dbd5..6819929104e 100644 --- a/backends/arm/test/ops/test_permute.py +++ b/backends/arm/test/ops/test_permute.py @@ -78,6 +78,12 @@ def forward(self, x): return torch.permute(x, self.dims) +class SimpleMoveAxis(torch.nn.Module): + + def forward(self, x): + return torch.moveaxis(x, 1, -1) + + @common.parametrize( "test_data", test_data_suite | test_data_suite_fp16 | test_data_suite_bf16 ) @@ -118,6 +124,17 @@ def test_permute_u55_INT(test_data): pipeline.run() +def test_moveaxis_u55_INT(): + pipeline = EthosU55PipelineINT[input_t1]( + SimpleMoveAxis(), + (torch.rand(1, 4, 5, 6),), + "torch.ops.aten.moveaxis.int", + exir_ops="executorch_exir_dialects_edge__ops_aten_permute_copy_default", + run_on_fvp=False, + ) + pipeline.run() + + @common.parametrize("test_data", test_data_suite_u55_reject) def test_permute_u55_INT_not_delegated(test_data: torch.Tensor): test_data, dims = test_data() diff --git a/backends/arm/test/quantizer/test_generic_annotater.py b/backends/arm/test/quantizer/test_generic_annotater.py index dd883e72b1f..4fb57d37054 100644 --- a/backends/arm/test/quantizer/test_generic_annotater.py +++ b/backends/arm/test/quantizer/test_generic_annotater.py @@ -7,7 +7,7 @@ from typing import Any, Callable, Tuple import torch -from executorch.backends.arm.quantizer import is_annotated +from executorch.backends.arm.quantizer import is_annotated, quantization_annotator from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineINT from executorch.backends.test.harness.stages import StageType @@ -89,6 +89,17 @@ def test_transpose_tosa_INT(): ) +def test_moveaxis_movedim_shared_qspec_annotations(): + expected_ops = { + torch.ops.aten.moveaxis.int, + torch.ops.aten.moveaxis.intlist, + torch.ops.aten.movedim.int, + torch.ops.aten.movedim.intlist, + } + + assert expected_ops <= quantization_annotator._one_to_one_shared_input_qspec + + def test_tile_tosa_INT(): check_annotation( SingleOpModel(torch.tile, (torch.randn(4, 4),), dims=(2,)),