Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,18 @@ class ChannelsLastTaggedReshapePass(XNNPACKPass):
exir_ops.edge.aten.linear.default,
}

# Broadcasting binary ops whose operands must share one memory format. A
# dynamic-quant input_to_nhwc may blanket-replace a shared activation
# (replace_all_uses_with), switching one operand of an already-processed
# binary op to NHWC while its other (e.g. per-channel constant) operand
# stays NCHW. These ops are re-converged after the main traversal.
broadcast_binary_ops = {
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.sub.Tensor,
exir_ops.edge.aten.div.Tensor,
}

# Tag which is added to a node's meta to indicate that it uses NHWC format.
# A constant data tensor with this tag assigned for use in a particular
# format in one place cannot be used in other places in the other format
Expand Down Expand Up @@ -567,4 +579,47 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
# to retrace the graph and regenerate metadata
graph_module = super().call(graph_module).graph_module

# Re-converge broadcasting binary ops whose operand layouts diverged.
# The main traversal above processes nodes in topological order, but the
# dynamic-quant path of input_to_nhwc calls replace_all_uses_with, which
# can retroactively switch a binary op's activation operand to NHWC after
# the op was already processed - without converging its other operand
# (e.g. a per-channel constant that stays NCHW). XNNPACK then fails at
# runtime in xnn_reshape_binary_elementwise_nd because the operand shapes
# are not broadcast-compatible. Run this after the retrace above so it
# sees the settled graph, then retrace once more if anything changed.
reconverged = False
for node in list(graph_module.graph.nodes):
if (
node.op != "call_function"
or node.target not in ChannelsLastTaggedReshapePass.broadcast_binary_ops
):
continue
input_nodes = node.all_input_nodes
if len(input_nodes) != 2:
continue
layouts = [
ChannelsLastTaggedReshapePass.is_nhwc_node(input_node)
for input_node in input_nodes
]
if layouts[0] == layouts[1]:
continue
if all(
self.can_be_converted_to_nhwc(input_node) for input_node in input_nodes
):
for input_node in input_nodes:
self.input_to_nhwc(graph_module, input_node, node)
self.mark_as_nhwc_node(node)
else:
for input_node in input_nodes:
self.input_to_nchw(graph_module, input_node, node)
reconverged = True

if reconverged:
graph_module.recompile()
for node in graph_module.graph.nodes:
if ChannelsLastTaggedReshapePass.PARTNER_NODE in node.meta:
node.meta.pop(ChannelsLastTaggedReshapePass.PARTNER_NODE)
graph_module = super().call(graph_module).graph_module

return PassResult(graph_module, True)
Original file line number Diff line number Diff line change
Expand Up @@ -665,3 +665,90 @@ def test_dq_conv_cat_immutable_list(self):
.run_passes(self.PassStage)
.run_method_and_compare_outputs()
)

class DynamicQuantPerChannelBinaryChain(torch.nn.Module):
"""A per-channel broadcasting binary op that is the first consumer of an
input activation, followed by a dynamically-quantized convolution.

This reproduces the graph shape that the dynamic-quant path of
``input_to_nhwc`` mishandles. The input activation feeds a per-channel
``mul``/``add`` (an NCHW ``[1, C, 1, 1]`` constant operand), and the
convolution chain runs *through* that binary op. When the convolution
requests NHWC, ``input_to_nhwc`` traces back through the binary op to the
input activation and calls ``replace_all_uses_with``, switching the binary
op's activation operand to NHWC while its constant operand stays NCHW. At
runtime XNNPACK then fails in ``xnn_reshape_binary_elementwise_nd`` with
``xnn_status_invalid_parameter`` because ``[1, H, W, C]`` and
``[1, C, 1, 1]`` are not broadcast-compatible -- unless the pass
re-converges the binary op's operands.

The quantize/dequantize ops are emitted directly (rather than via the
quantizer) so the graph reliably reproduces the shared back-traced source;
the whole graph delegates to XNNPACK.
"""

def __init__(self):
super().__init__()
out_channels, in_channels, kernel = 8, 8, 3
self.register_buffer(
"weight",
torch.randint(
-127,
127,
(out_channels, in_channels, kernel, kernel),
dtype=torch.int8,
),
)
self.register_buffer(
"weight_scale", torch.rand(out_channels) * 0.02 + 0.001
)
self.register_buffer(
"weight_zero_point", torch.zeros(out_channels, dtype=torch.int64)
)
self.register_buffer("scale", torch.rand(1, out_channels, 1, 1) + 0.5)
self.register_buffer("bias", torch.rand(1, out_channels, 1, 1))

def forward(self, activation):
qd = torch.ops.quantized_decomposed
# Per-channel binary op as the first consumer of the input activation.
scaled = activation * self.scale + self.bias
relued = torch.relu(scaled)
# Dynamic (runtime-chosen) quantization feeding the convolution.
q_scale, q_zero_point = qd.choose_qparams.tensor(
relued, -128, 127, 1e-5, torch.int8
)
dequantized = qd.dequantize_per_tensor.tensor(
qd.quantize_per_tensor.tensor(
relued, q_scale, q_zero_point, -128, 127, torch.int8
),
q_scale,
q_zero_point,
-128,
127,
torch.int8,
)
weight = qd.dequantize_per_channel(
self.weight,
self.weight_scale,
self.weight_zero_point,
0,
-127,
127,
torch.int8,
)
return torch.nn.functional.conv2d(dequantized, weight, padding=1)

def test_dynamic_quant_per_channel_binary_chain_lowers_and_runs(self):
# Regression test: the full XNNPACK lowering of this graph must run
# without an xnn_status_invalid_parameter from a binary op whose operands
# ended up in mismatched (NHWC vs NCHW) memory formats.
model = self.DynamicQuantPerChannelBinaryChain().eval()
activation = torch.randn(1, 8, 16, 16)
(
Tester(model, (activation,))
.export()
.to_edge_transform_and_lower()
.to_executorch()
.serialize()
.run_method_and_compare_outputs()
)
Loading