From 289a5704a325c7cba5baf4fb67e8e647a8ef8a66 Mon Sep 17 00:00:00 2001 From: Hyungkeun-Park-Nota Date: Thu, 18 Jun 2026 07:22:54 +0000 Subject: [PATCH] Fix XNNPACK channels-last reshape for per-channel binary ops under dynamic quant ChannelsLastTaggedReshapePass.input_to_nhwc has a dynamic-quant branch that calls input_node.replace_all_uses_with(input_node_nhwc), globally redirecting a shared activation to its NHWC copy. The main traversal visits nodes in topological order, so this can retroactively switch a broadcasting binary op's activation operand to NHWC after that op was already processed, while its other operand (e.g. a per-channel constant) stays NCHW. The two operands then have incompatible logical shapes (e.g. [1, H, W, C] vs [1, C, 1, 1]), and XNNPACK fails at runtime in xnn_reshape_binary_elementwise_nd with xnn_status_invalid_parameter. Lowering succeeds; only execution fails (observed on DETR with w8a8 dynamic quantization). Add a re-convergence sweep after the main traversal that restores the pass's own invariant -- all operands of a node share one memory format -- for broadcasting binary ops (add/mul/sub/div), converging to NHWC when possible (else NCHW). It runs after the retrace so it observes the settled graph, and retraces once more only if anything changed. Graphs with no diverged binary operands are untouched. Add a regression test (DynamicQuantPerChannelBinaryChain) whose per-channel binary op is the first consumer of an input activation, with the convolution chain running through it; it fails to execute (xnn_status_invalid_parameter) without this fix and passes with it. --- .../channels_last_tagged_reshape_pass.py | 55 ++++++++++++ .../test_channels_last_tagged_reshape.py | 87 +++++++++++++++++++ 2 files changed, 142 insertions(+) diff --git a/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py b/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py index 291b6ca7760..5802695f801 100644 --- a/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py +++ b/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py @@ -74,6 +74,18 @@ class ChannelsLastTaggedReshapePass(XNNPACKPass): exir_ops.edge.aten.linear.default, } + # Broadcasting binary ops whose operands must share one memory format. A + # dynamic-quant input_to_nhwc may blanket-replace a shared activation + # (replace_all_uses_with), switching one operand of an already-processed + # binary op to NHWC while its other (e.g. per-channel constant) operand + # stays NCHW. These ops are re-converged after the main traversal. + broadcast_binary_ops = { + exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.mul.Tensor, + exir_ops.edge.aten.sub.Tensor, + exir_ops.edge.aten.div.Tensor, + } + # Tag which is added to a node's meta to indicate that it uses NHWC format. # A constant data tensor with this tag assigned for use in a particular # format in one place cannot be used in other places in the other format @@ -567,4 +579,47 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901 # to retrace the graph and regenerate metadata graph_module = super().call(graph_module).graph_module + # Re-converge broadcasting binary ops whose operand layouts diverged. + # The main traversal above processes nodes in topological order, but the + # dynamic-quant path of input_to_nhwc calls replace_all_uses_with, which + # can retroactively switch a binary op's activation operand to NHWC after + # the op was already processed - without converging its other operand + # (e.g. a per-channel constant that stays NCHW). XNNPACK then fails at + # runtime in xnn_reshape_binary_elementwise_nd because the operand shapes + # are not broadcast-compatible. Run this after the retrace above so it + # sees the settled graph, then retrace once more if anything changed. + reconverged = False + for node in list(graph_module.graph.nodes): + if ( + node.op != "call_function" + or node.target not in ChannelsLastTaggedReshapePass.broadcast_binary_ops + ): + continue + input_nodes = node.all_input_nodes + if len(input_nodes) != 2: + continue + layouts = [ + ChannelsLastTaggedReshapePass.is_nhwc_node(input_node) + for input_node in input_nodes + ] + if layouts[0] == layouts[1]: + continue + if all( + self.can_be_converted_to_nhwc(input_node) for input_node in input_nodes + ): + for input_node in input_nodes: + self.input_to_nhwc(graph_module, input_node, node) + self.mark_as_nhwc_node(node) + else: + for input_node in input_nodes: + self.input_to_nchw(graph_module, input_node, node) + reconverged = True + + if reconverged: + graph_module.recompile() + for node in graph_module.graph.nodes: + if ChannelsLastTaggedReshapePass.PARTNER_NODE in node.meta: + node.meta.pop(ChannelsLastTaggedReshapePass.PARTNER_NODE) + graph_module = super().call(graph_module).graph_module + return PassResult(graph_module, True) diff --git a/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py b/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py index f4a52f25830..adf1c694b22 100644 --- a/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py +++ b/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py @@ -665,3 +665,90 @@ def test_dq_conv_cat_immutable_list(self): .run_passes(self.PassStage) .run_method_and_compare_outputs() ) + + class DynamicQuantPerChannelBinaryChain(torch.nn.Module): + """A per-channel broadcasting binary op that is the first consumer of an + input activation, followed by a dynamically-quantized convolution. + + This reproduces the graph shape that the dynamic-quant path of + ``input_to_nhwc`` mishandles. The input activation feeds a per-channel + ``mul``/``add`` (an NCHW ``[1, C, 1, 1]`` constant operand), and the + convolution chain runs *through* that binary op. When the convolution + requests NHWC, ``input_to_nhwc`` traces back through the binary op to the + input activation and calls ``replace_all_uses_with``, switching the binary + op's activation operand to NHWC while its constant operand stays NCHW. At + runtime XNNPACK then fails in ``xnn_reshape_binary_elementwise_nd`` with + ``xnn_status_invalid_parameter`` because ``[1, H, W, C]`` and + ``[1, C, 1, 1]`` are not broadcast-compatible -- unless the pass + re-converges the binary op's operands. + + The quantize/dequantize ops are emitted directly (rather than via the + quantizer) so the graph reliably reproduces the shared back-traced source; + the whole graph delegates to XNNPACK. + """ + + def __init__(self): + super().__init__() + out_channels, in_channels, kernel = 8, 8, 3 + self.register_buffer( + "weight", + torch.randint( + -127, + 127, + (out_channels, in_channels, kernel, kernel), + dtype=torch.int8, + ), + ) + self.register_buffer( + "weight_scale", torch.rand(out_channels) * 0.02 + 0.001 + ) + self.register_buffer( + "weight_zero_point", torch.zeros(out_channels, dtype=torch.int64) + ) + self.register_buffer("scale", torch.rand(1, out_channels, 1, 1) + 0.5) + self.register_buffer("bias", torch.rand(1, out_channels, 1, 1)) + + def forward(self, activation): + qd = torch.ops.quantized_decomposed + # Per-channel binary op as the first consumer of the input activation. + scaled = activation * self.scale + self.bias + relued = torch.relu(scaled) + # Dynamic (runtime-chosen) quantization feeding the convolution. + q_scale, q_zero_point = qd.choose_qparams.tensor( + relued, -128, 127, 1e-5, torch.int8 + ) + dequantized = qd.dequantize_per_tensor.tensor( + qd.quantize_per_tensor.tensor( + relued, q_scale, q_zero_point, -128, 127, torch.int8 + ), + q_scale, + q_zero_point, + -128, + 127, + torch.int8, + ) + weight = qd.dequantize_per_channel( + self.weight, + self.weight_scale, + self.weight_zero_point, + 0, + -127, + 127, + torch.int8, + ) + return torch.nn.functional.conv2d(dequantized, weight, padding=1) + + def test_dynamic_quant_per_channel_binary_chain_lowers_and_runs(self): + # Regression test: the full XNNPACK lowering of this graph must run + # without an xnn_status_invalid_parameter from a binary op whose operands + # ended up in mismatched (NHWC vs NCHW) memory formats. + model = self.DynamicQuantPerChannelBinaryChain().eval() + activation = torch.randn(1, 8, 16, 16) + ( + Tester(model, (activation,)) + .export() + .to_edge_transform_and_lower() + .to_executorch() + .serialize() + .run_method_and_compare_outputs() + )