Skip to content

Commit 194bb19

Browse files
committed
XNNPACK: Fix cat qparams for quantized ViT token paths
Bug: Cat annotation always used the first input as the shared qparam source. In DeiT Tiny, the first input to token concatenation is the class-token path: cls_token -> expand -> cat The patch-token path is a later cat input: conv -> flatten -> transpose -> cat The conv output has annotated activation qparams, and flatten/transpose are qparam-preserving view/layout ops. They should carry the same qparams with SharedQuantizationSpec. Anchoring cat qparams on the class-token path can leave the patch-token static transpose with different input/output qparams. XNNPACK rejects this during runtime initialization because static transpose only reorders bytes. Fix: Choose the first cat input that traces through qparam-preserving ops to annotated output qparams, falling back to the first input otherwise. Also propagate shared qparams through reshape and transpose so static transpose nodes keep identical input and output qparams. Add a quantized XNNPACK regression covering conv, flatten, transpose, cat, and add. Change-Id: I86fafd584c1cb561bd2d4444ea70c1a1b0650066 Signed-off-by: Måns Nilsson <mans.nilsson@arm.com>
1 parent e88fd04 commit 194bb19

2 files changed

Lines changed: 101 additions & 7 deletions

File tree

backends/xnnpack/quantizer/xnnpack_quantizer_utils.py

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
16
# mypy: allow-untyped-defs
27
import itertools
38
from typing import Callable, Optional
@@ -995,6 +1000,48 @@ def _annotate_mul(
9951000
return annotated_partitions
9961001

9971002

1003+
def _has_annotated_qparam_source(node: Node) -> bool:
1004+
"""Check whether a node traces to annotated output qparams.
1005+
1006+
Walk backward through qparam-preserving view/layout ops to find an already
1007+
annotated activation producer.
1008+
"""
1009+
visited: set[Node] = set()
1010+
while node not in visited:
1011+
visited.add(node)
1012+
1013+
quantization_annotation = node.meta.get(Q_ANNOTATION_KEY, None)
1014+
if (
1015+
quantization_annotation is not None
1016+
and quantization_annotation._annotated
1017+
and quantization_annotation.output_qspec is not None
1018+
):
1019+
return True
1020+
1021+
if node.op != "call_function" or not _is_share_obs_or_fq_op(node.target):
1022+
return False
1023+
1024+
prev_node = node.args[0]
1025+
if not isinstance(prev_node, Node):
1026+
return False
1027+
node = prev_node
1028+
1029+
return False
1030+
1031+
1032+
def _get_cat_qparam_source(inputs) -> object:
1033+
"""Choose the input that should own a cat node's shared qparams.
1034+
1035+
Prefer the first input that traces to annotated output qparams. Fall back to
1036+
the first input otherwise.
1037+
"""
1038+
for input_act in inputs:
1039+
if isinstance(input_act, Node) and _has_annotated_qparam_source(input_act):
1040+
return input_act
1041+
1042+
return inputs[0]
1043+
1044+
9981045
# TODO: remove Optional in return type, fix annotated_partitions logic
9991046
@register_annotator("cat")
10001047
def _annotate_cat(
@@ -1014,18 +1061,20 @@ def _annotate_cat(
10141061

10151062
input_act_qspec = get_input_act_qspec(quantization_config)
10161063
inputs = cat_node.args[0]
1064+
input_act_qparam_source = _get_cat_qparam_source(inputs)
10171065

10181066
input_qspec_map = {}
1019-
input_act0 = inputs[0] # type: ignore[index]
1020-
if isinstance(input_act0, Node):
1021-
input_qspec_map[input_act0] = input_act_qspec
1067+
if isinstance(input_act_qparam_source, Node):
1068+
input_qspec_map[input_act_qparam_source] = input_act_qspec
10221069

1023-
shared_with_input0_qspec = SharedQuantizationSpec((input_act0, cat_node)) # type: ignore[arg-type]
1024-
for input_act in inputs[1:]: # type: ignore[index, union-attr]
1070+
shared_with_source_qspec = SharedQuantizationSpec(
1071+
(input_act_qparam_source, cat_node) # type: ignore[arg-type]
1072+
)
1073+
for input_act in inputs: # type: ignore[union-attr]
10251074
if input_act not in input_qspec_map:
1026-
input_qspec_map[input_act] = shared_with_input0_qspec # type: ignore[index]
1075+
input_qspec_map[input_act] = shared_with_source_qspec # type: ignore[index]
10271076

1028-
output_act_qspec = shared_with_input0_qspec
1077+
output_act_qspec = shared_with_source_qspec
10291078

10301079
cat_node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
10311080
input_qspec_map=input_qspec_map,
@@ -1045,12 +1094,14 @@ def _is_share_obs_or_fq_op(op: Callable) -> bool:
10451094
torch.ops.aten.mean.dim,
10461095
torch.ops.aten.permute.default,
10471096
torch.ops.aten.permute_copy.default,
1097+
torch.ops.aten.transpose.int,
10481098
torch.ops.aten.squeeze.dim,
10491099
torch.ops.aten.squeeze_copy.dim,
10501100
# TODO: remove?
10511101
torch.ops.aten.adaptive_avg_pool2d.default,
10521102
torch.ops.aten.view_copy.default,
10531103
torch.ops.aten.view.default,
1104+
torch.ops.aten.reshape.default,
10541105
torch.ops.aten.slice_copy.Tensor,
10551106
torch.ops.aten.flatten.using_ints,
10561107
]

backends/xnnpack/test/ops/test_cat.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3+
# Copyright 2026 Arm Limited and/or its affiliates.
34
#
45
# This source code is licensed under the BSD-style license found in the
56
# LICENSE file in the root directory of this source tree.
@@ -199,6 +200,48 @@ def test_qs8_cat_with_empty_tensor(self):
199200
)
200201
self._test_cat(self.Cat(), inputs, cat_num=4, quant=True)
201202

203+
class CatAfterConvAndTranspose(torch.nn.Module):
204+
def __init__(self):
205+
super().__init__()
206+
self.proj = torch.nn.Conv2d(
207+
in_channels=3,
208+
out_channels=8,
209+
kernel_size=(4, 4),
210+
stride=(4, 4),
211+
bias=False,
212+
)
213+
self.cls_token = torch.nn.Parameter(torch.full((1, 1, 8), 4.0))
214+
self.pos_embed = torch.nn.Parameter(torch.full((1, 5, 8), 0.125))
215+
216+
with torch.no_grad():
217+
self.proj.weight.fill_(0.025)
218+
219+
def forward(self, x):
220+
patch_tokens = self.proj(x).flatten(2).transpose(1, 2)
221+
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
222+
tokens = torch.cat((cls_token, patch_tokens), dim=1)
223+
return tokens + self.pos_embed
224+
225+
def test_qs8_cat_uses_annotated_transpose_path_qparams(self):
226+
inputs = (torch.randn(1, 3, 8, 8),)
227+
(
228+
Tester(self.CatAfterConvAndTranspose(), inputs)
229+
.quantize()
230+
.export()
231+
.check_count({"torch.ops.aten.cat": 1})
232+
.to_edge_transform_and_lower()
233+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
234+
.check_not(
235+
[
236+
"executorch_exir_dialects_edge__ops_aten_cat",
237+
"torch.ops.quantized_decomposed",
238+
]
239+
)
240+
.to_executorch()
241+
.serialize()
242+
.run_method_and_compare_outputs(inputs=inputs)
243+
)
244+
202245
class CatNegativeDim(torch.nn.Module):
203246
def __init__(self):
204247
super().__init__()

0 commit comments

Comments
 (0)