From e298f76750a2cc30680e8e816e017cbdb025ac34 Mon Sep 17 00:00:00 2001 From: ssjia Date: Thu, 19 Feb 2026 11:48:50 -0800 Subject: [PATCH] [ET-VK][q8ta] Fix addmm arg indexing in QuantizedLinearMatch QuantizedLinearMatch always used args[1] for the weight and args[0] for the input, which is correct for mm(input, weight) and linear(input, weight, bias?) but wrong for addmm(bias, input, weight) where the weight is at args[2] and the input is at args[1]. This was exposed by a torchao change (D69887498) that added Linear+BatchNorm fusion to prepare_pt2e(). The fusion adds a bias to Linear nodes that previously had none, causing them to decompose to addmm instead of mm in the edge dialect. The pattern matcher then read the input's per-tensor dequantize scale (a float literal) as if it were the weight's per-channel scale (a Node), causing an assertion failure. The fix determines the correct arg indices based on whether the anchor node is addmm. The bias handling at args[0] for addmm was already correct. Authored-by: Claude Differential Revision: [D93768640](https://our.internmc.facebook.com/intern/diff/D93768640/) [ghstack-poisoned] --- backends/vulkan/custom_ops_lib.py | 4 +- backends/vulkan/patterns/quantized_linear.py | 61 +++++++++++++++++--- backends/vulkan/test/test_vulkan_passes.py | 46 +++++++++++++++ 3 files changed, 100 insertions(+), 11 deletions(-) diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py index 78bc87bc159..db9cd731daf 100644 --- a/backends/vulkan/custom_ops_lib.py +++ b/backends/vulkan/custom_ops_lib.py @@ -389,7 +389,7 @@ def q8ta_linear( out = torch.nn.functional.linear(x, weights) if bias is not None: - out = out + bias + out = out + bias[: out.shape[-1]] out = torch.ops.quantized_decomposed.quantize_per_tensor( out, output_scale, output_zero_point, -128, 127, torch.int8 @@ -449,7 +449,7 @@ def q8ta_linear_gemv( out = torch.nn.functional.linear(x, weights) if bias is not None: - out = out + bias + out = out + bias[: out.shape[-1]] out = torch.ops.quantized_decomposed.quantize_per_tensor( out, output_scale, output_zero_point, -128, 127, torch.int8 diff --git a/backends/vulkan/patterns/quantized_linear.py b/backends/vulkan/patterns/quantized_linear.py index 1b6c64af8e3..74e0c1f74eb 100644 --- a/backends/vulkan/patterns/quantized_linear.py +++ b/backends/vulkan/patterns/quantized_linear.py @@ -36,8 +36,14 @@ def __init__(self, mm_node: torch.fx.Node) -> None: # noqa: C901 self.match_found = False self.all_nodes = [self.anchor_node] + # addmm(bias, mat1, mat2) has a different arg layout than + # mm(mat1, mat2) and linear(input, weight, bias?) + is_addmm = self.anchor_node.target == exir_ops.edge.aten.addmm.default + weight_arg_idx = 2 if is_addmm else 1 + input_arg_idx = 1 if is_addmm else 0 + const_node, arg_chain = utils.trace_args_until_placeholder( - self.anchor_node.args[1] + self.anchor_node.args[weight_arg_idx] ) # mat2 is not a constant tensor - no match @@ -84,19 +90,12 @@ def __init__(self, mm_node: torch.fx.Node) -> None: # noqa: C901 # Identify output node self.output_node = self.anchor_node - # The implementation has a limitation that output channels must be a - # multiple of 4. This is to ensure that data loads are aligned well with - # texel boundaries. If this is not true, then don't match the pattern. - out_channels = self.output_node.meta["val"].shape[-1] - if out_channels % 4 != 0: - return - # Identify input node ( self.fp_input_node, self.quantize_input_node, dq_node, - ) = utils.maybe_skip_q_dq_arg_chain(self.anchor_node.args[0]) + ) = utils.maybe_skip_q_dq_arg_chain(self.anchor_node.args[input_arg_idx]) assert self.fp_input_node is not None self.all_nodes.append(self.fp_input_node) @@ -442,12 +441,34 @@ def make_linear_q8ta_q8csw_custom_op( match: QuantizedLinearMatch, weight_tensor: torch.Tensor, ): + # Pad weight_scales to multiple of 4 so GPU shader reads don't go OOB + weight_scales_tensor = get_param_tensor(ep, match.weight_scales_node) + assert weight_scales_tensor is not None + utils.align_width_and_update_state_dict( + ep, match.weight_scales_node, weight_scales_tensor + ) + + # Pad bias to multiple of 4 if present + if match.bias_node is not None: + bias_tensor = get_param_tensor(ep, match.bias_node) + if bias_tensor is not None: + utils.align_width_and_update_state_dict(ep, match.bias_node, bias_tensor) + first_graph_node = list(graph_module.graph.nodes)[0] with graph_module.graph.inserting_before(first_graph_node): weight_tensor_name = utils.get_tensor_name(ep, match.weight_node) # Pre-compute the weight sums which are needed to apply activation zero point # when using integer accumulation. sum_per_output_channel = weight_tensor.sum(dim=1).to(torch.int32).contiguous() + + # Pad weight sums to align OC to multiple of 4 + oc = sum_per_output_channel.shape[0] + if oc % 4 != 0: + num_padding = 4 - (oc % 4) + sum_per_output_channel = F.pad( + sum_per_output_channel, (0, num_padding) + ).contiguous() + sums_name = weight_tensor_name + "_sums" # Sanitize the name sums_name = sums_name.replace(".", "_") @@ -484,10 +505,32 @@ def make_q8ta_linear_custom_op( match: QuantizedLinearMatch, weight_tensor: torch.Tensor, ): + # Pad weight_scales to multiple of 4 so GPU shader reads don't go OOB + weight_scales_tensor = get_param_tensor(ep, match.weight_scales_node) + assert weight_scales_tensor is not None + utils.align_width_and_update_state_dict( + ep, match.weight_scales_node, weight_scales_tensor + ) + + # Pad bias to multiple of 4 if present + if match.bias_node is not None: + bias_tensor = get_param_tensor(ep, match.bias_node) + if bias_tensor is not None: + utils.align_width_and_update_state_dict(ep, match.bias_node, bias_tensor) + first_graph_node = list(graph_module.graph.nodes)[0] with graph_module.graph.inserting_before(first_graph_node): weight_tensor_name = utils.get_tensor_name(ep, match.weight_node) sum_per_output_channel = weight_tensor.sum(dim=1).to(torch.int32).contiguous() + + # Pad weight sums to align OC to multiple of 4 + oc = sum_per_output_channel.shape[0] + if oc % 4 != 0: + num_padding = 4 - (oc % 4) + sum_per_output_channel = F.pad( + sum_per_output_channel, (0, num_padding) + ).contiguous() + sums_name = weight_tensor_name + "_sums" sums_name = sums_name.replace(".", "_") diff --git a/backends/vulkan/test/test_vulkan_passes.py b/backends/vulkan/test/test_vulkan_passes.py index 3488357d155..4efdf73182a 100644 --- a/backends/vulkan/test/test_vulkan_passes.py +++ b/backends/vulkan/test/test_vulkan_passes.py @@ -283,3 +283,49 @@ def forward(self, x): 1, "Expected at least one q8ta_linear_gemv op for batch-1 linear fusion", ) + + def test_fuse_q8ta_linear_gemv_non_aligned_oc(self): + """Test that quantized linear with non-aligned output channels (not multiple of 4) fuses correctly.""" + from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( + get_symmetric_quantization_config, + XNNPACKQuantizer, + ) + + class TwoLinearModule(torch.nn.Module): + def __init__(self): + super().__init__() + # Use non-aligned output channels (9 is not a multiple of 4) + self.linear1 = torch.nn.Linear(128, 9, bias=False) + self.linear2 = torch.nn.Linear(9, 4, bias=False) + + def forward(self, x): + return self.linear2(self.linear1(x)) + + model = TwoLinearModule() + sample_inputs = (torch.randn(1, 128),) + + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config( + is_per_channel=True, + is_dynamic=False, + ) + quantizer.set_global(operator_config) + + edge_program = quantize_and_lower_module(model, sample_inputs, quantizer) + + ep = edge_program._edge_programs["forward"] + fuse_pass = FusePatternsPass() + fuse_pass._exported_program = ep + result = fuse_pass.call(ep.graph_module) + + self.assertTrue(result.modified) + + gm = ep.graph_module + + # The first linear (OC=9, not multiple of 4) should still fuse + q8ta_linear_gemv_count = op_node_count(gm, "et_vk__q8ta_linear_gemv__default") + self.assertGreaterEqual( + q8ta_linear_gemv_count, + 1, + "Expected non-aligned OC linear to fuse into q8ta_linear_gemv", + )