Re-converge mixed-layout broadcasting binary ops in ChannelsLastTaggedReshapePass.

Summary of what this patch changes (text before the first "diff --git" header
is descriptive only and is skipped by patch tools):

* channels_last_tagged_reshape_pass.py
  - Adds the class attribute `broadcast_binary_ops`, a set holding the edge
    ops add.Tensor, mul.Tensor, sub.Tensor and div.Tensor.
  - After the existing retrace in `call()`, walks a snapshot of the graph
    nodes. For every call_function node whose target is in
    `broadcast_binary_ops` and that has exactly two distinct input nodes
    whose `is_nhwc_node` results differ, it either converts both inputs with
    `input_to_nhwc` and tags the node via `mark_as_nhwc_node` (when
    `can_be_converted_to_nhwc` holds for both inputs), or converts both
    inputs with `input_to_nchw`.
  - If any node was touched, recompiles the graph module, drops the
    PARTNER_NODE entry from every node's meta, and retraces once more via
    `super().call(...)`.
  - NOTE(review): the sweep is a single pass. The patch's own comments say
    `input_to_nhwc` can rewire uses of a shared activation; whether a call
    made during this sweep can disturb a binary op that the sweep has
    already visited cannot be determined from this patch - confirm that one
    pass is sufficient.

* test_channels_last_tagged_reshape.py
  - Adds the module `DynamicQuantPerChannelBinaryChain` (per-channel
    mul/add on the input, relu, quantize/dequantize with runtime-chosen
    qparams, then conv2d with a per-channel dequantized int8 weight).
  - Adds a regression test that runs export -> to_edge_transform_and_lower
    -> to_executorch -> serialize -> run_method_and_compare_outputs on it.
  - NOTE(review): the module's buffers and the test input come from
    torch.rand / torch.randint / torch.randn and no seed is set in the
    added code - confirm whether the surrounding test class seeds the RNG.

diff --git a/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py b/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py
index 291b6ca7760..5802695f801 100644
--- a/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py
+++ b/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py
@@ -74,6 +74,18 @@ class ChannelsLastTaggedReshapePass(XNNPACKPass):
         exir_ops.edge.aten.linear.default,
     }
 
+    # Broadcasting binary ops whose operands must share one memory format. A
+    # dynamic-quant input_to_nhwc may blanket-replace a shared activation
+    # (replace_all_uses_with), switching one operand of an already-processed
+    # binary op to NHWC while its other (e.g. per-channel constant) operand
+    # stays NCHW. These ops are re-converged after the main traversal.
+    broadcast_binary_ops = {
+        exir_ops.edge.aten.add.Tensor,
+        exir_ops.edge.aten.mul.Tensor,
+        exir_ops.edge.aten.sub.Tensor,
+        exir_ops.edge.aten.div.Tensor,
+    }
+
     # Tag which is added to a node's meta to indicate that it uses NHWC format.
     # A constant data tensor with this tag assigned for use in a particular
     # format in one place cannot be used in other places in the other format
@@ -567,4 +579,47 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
         # to retrace the graph and regenerate metadata
         graph_module = super().call(graph_module).graph_module
 
+        # Re-converge broadcasting binary ops whose operand layouts diverged.
+        # The main traversal above processes nodes in topological order, but the
+        # dynamic-quant path of input_to_nhwc calls replace_all_uses_with, which
+        # can retroactively switch a binary op's activation operand to NHWC after
+        # the op was already processed - without converging its other operand
+        # (e.g. a per-channel constant that stays NCHW). XNNPACK then fails at
+        # runtime in xnn_reshape_binary_elementwise_nd because the operand shapes
+        # are not broadcast-compatible. Run this after the retrace above so it
+        # sees the settled graph, then retrace once more if anything changed.
+        reconverged = False
+        for node in list(graph_module.graph.nodes):
+            if (
+                node.op != "call_function"
+                or node.target not in ChannelsLastTaggedReshapePass.broadcast_binary_ops
+            ):
+                continue
+            input_nodes = node.all_input_nodes
+            if len(input_nodes) != 2:
+                continue
+            layouts = [
+                ChannelsLastTaggedReshapePass.is_nhwc_node(input_node)
+                for input_node in input_nodes
+            ]
+            if layouts[0] == layouts[1]:
+                continue
+            if all(
+                self.can_be_converted_to_nhwc(input_node) for input_node in input_nodes
+            ):
+                for input_node in input_nodes:
+                    self.input_to_nhwc(graph_module, input_node, node)
+                self.mark_as_nhwc_node(node)
+            else:
+                for input_node in input_nodes:
+                    self.input_to_nchw(graph_module, input_node, node)
+            reconverged = True
+
+        if reconverged:
+            graph_module.recompile()
+            for node in graph_module.graph.nodes:
+                if ChannelsLastTaggedReshapePass.PARTNER_NODE in node.meta:
+                    node.meta.pop(ChannelsLastTaggedReshapePass.PARTNER_NODE)
+            graph_module = super().call(graph_module).graph_module
+
         return PassResult(graph_module, True)
diff --git a/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py b/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py
index f4a52f25830..adf1c694b22 100644
--- a/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py
+++ b/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py
@@ -665,3 +665,90 @@ def test_dq_conv_cat_immutable_list(self):
             .run_passes(self.PassStage)
             .run_method_and_compare_outputs()
         )
+
+    class DynamicQuantPerChannelBinaryChain(torch.nn.Module):
+        """A per-channel broadcasting binary op that is the first consumer of an
+        input activation, followed by a dynamically-quantized convolution.
+
+        This reproduces the graph shape that the dynamic-quant path of
+        ``input_to_nhwc`` mishandles. The input activation feeds a per-channel
+        ``mul``/``add`` (an NCHW ``[1, C, 1, 1]`` constant operand), and the
+        convolution chain runs *through* that binary op. When the convolution
+        requests NHWC, ``input_to_nhwc`` traces back through the binary op to the
+        input activation and calls ``replace_all_uses_with``, switching the binary
+        op's activation operand to NHWC while its constant operand stays NCHW. At
+        runtime XNNPACK then fails in ``xnn_reshape_binary_elementwise_nd`` with
+        ``xnn_status_invalid_parameter`` because ``[1, H, W, C]`` and
+        ``[1, C, 1, 1]`` are not broadcast-compatible -- unless the pass
+        re-converges the binary op's operands.
+
+        The quantize/dequantize ops are emitted directly (rather than via the
+        quantizer) so the graph reliably reproduces the shared back-traced source;
+        the whole graph delegates to XNNPACK.
+        """
+
+        def __init__(self):
+            super().__init__()
+            out_channels, in_channels, kernel = 8, 8, 3
+            self.register_buffer(
+                "weight",
+                torch.randint(
+                    -127,
+                    127,
+                    (out_channels, in_channels, kernel, kernel),
+                    dtype=torch.int8,
+                ),
+            )
+            self.register_buffer(
+                "weight_scale", torch.rand(out_channels) * 0.02 + 0.001
+            )
+            self.register_buffer(
+                "weight_zero_point", torch.zeros(out_channels, dtype=torch.int64)
+            )
+            self.register_buffer("scale", torch.rand(1, out_channels, 1, 1) + 0.5)
+            self.register_buffer("bias", torch.rand(1, out_channels, 1, 1))
+
+        def forward(self, activation):
+            qd = torch.ops.quantized_decomposed
+            # Per-channel binary op as the first consumer of the input activation.
+            scaled = activation * self.scale + self.bias
+            relued = torch.relu(scaled)
+            # Dynamic (runtime-chosen) quantization feeding the convolution.
+            q_scale, q_zero_point = qd.choose_qparams.tensor(
+                relued, -128, 127, 1e-5, torch.int8
+            )
+            dequantized = qd.dequantize_per_tensor.tensor(
+                qd.quantize_per_tensor.tensor(
+                    relued, q_scale, q_zero_point, -128, 127, torch.int8
+                ),
+                q_scale,
+                q_zero_point,
+                -128,
+                127,
+                torch.int8,
+            )
+            weight = qd.dequantize_per_channel(
+                self.weight,
+                self.weight_scale,
+                self.weight_zero_point,
+                0,
+                -127,
+                127,
+                torch.int8,
+            )
+            return torch.nn.functional.conv2d(dequantized, weight, padding=1)
+
+    def test_dynamic_quant_per_channel_binary_chain_lowers_and_runs(self):
+        # Regression test: the full XNNPACK lowering of this graph must run
+        # without an xnn_status_invalid_parameter from a binary op whose operands
+        # ended up in mismatched (NHWC vs NCHW) memory formats.
+        model = self.DynamicQuantPerChannelBinaryChain().eval()
+        activation = torch.randn(1, 8, 16, 16)
+        (
+            Tester(model, (activation,))
+            .export()
+            .to_edge_transform_and_lower()
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )