diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py index 7f891409e41..87506f0b773 100644 --- a/backends/vulkan/custom_ops_lib.py +++ b/backends/vulkan/custom_ops_lib.py @@ -390,7 +390,7 @@ def q8ta_linear( out = torch.nn.functional.linear(x, weights) if bias is not None: - out = out + bias + out = out + bias[: out.shape[-1]] if activation == "relu": out = torch.nn.functional.relu(out) @@ -455,7 +455,7 @@ def q8ta_linear_gemv( out = torch.nn.functional.linear(x, weights) if bias is not None: - out = out + bias + out = out + bias[: out.shape[-1]] if activation == "relu": out = torch.nn.functional.relu(out) diff --git a/backends/vulkan/patterns/quantized_linear.py b/backends/vulkan/patterns/quantized_linear.py index f1bcfc775bc..df80749e72f 100644 --- a/backends/vulkan/patterns/quantized_linear.py +++ b/backends/vulkan/patterns/quantized_linear.py @@ -36,8 +36,14 @@ def __init__(self, mm_node: torch.fx.Node) -> None: # noqa: C901 self.match_found = False self.all_nodes = [self.anchor_node] + # addmm(bias, mat1, mat2) has a different arg layout than + # mm(mat1, mat2) and linear(input, weight, bias?) + is_addmm = self.anchor_node.target == exir_ops.edge.aten.addmm.default + weight_arg_idx = 2 if is_addmm else 1 + input_arg_idx = 1 if is_addmm else 0 + const_node, arg_chain = utils.trace_args_until_placeholder( - self.anchor_node.args[1] + self.anchor_node.args[weight_arg_idx] ) # mat2 is not a constant tensor - no match @@ -84,26 +90,64 @@ def __init__(self, mm_node: torch.fx.Node) -> None: # noqa: C901 # Identify output node self.output_node = self.anchor_node - # The implementation has a limitation that output channels must be a - # multiple of 4. This is to ensure that data loads are aligned well with - # texel boundaries. If this is not true, then don't match the pattern. - out_channels = self.output_node.meta["val"].shape[-1] - if out_channels % 4 != 0: - return + # Identify primary input node of the anchor. 
Due to decomposition of aten.linear + # there may be a view_copy node between the original input tensor to the linear + # op and the actual linear op node. + anchor_primary_input_node = self.anchor_node.args[input_arg_idx] + assert isinstance(anchor_primary_input_node, torch.fx.Node) + + # Skip potential view_copy between dq and linear + if utils.is_view_copy_node(anchor_primary_input_node): + self.all_nodes.append(anchor_primary_input_node) + anchor_primary_input_node = anchor_primary_input_node.args[ + 0 + ] # pyre-ignore[16] + assert isinstance(anchor_primary_input_node, torch.fx.Node) + + # By default, assume that the input tensor is not quantized in any way + self.quantize_input_node = None + self.dequantize_input_node = None + self.pattern_input_node = anchor_primary_input_node + + self.input_scales_node = None + self.input_zeros_node = None + + scales_arg_idx = 1 + zeros_arg_idx = 2 - # Identify input node - ( - self.fp_input_node, - self.quantize_input_node, - dq_node, - ) = utils.maybe_skip_q_dq_arg_chain(self.anchor_node.args[0]) - assert self.fp_input_node is not None - self.all_nodes.append(self.fp_input_node) + # If the primary input node comes from a dequantize node, that implies the + # input tensor is quantized (either statically or dynamically). + if utils.is_dequant_node(anchor_primary_input_node): + # Assume that this is a static quantization pattern; the input to the + # pattern is a statically quantized int8 tensor. 
+ self.dequantize_input_node = anchor_primary_input_node + self.all_nodes.append(self.dequantize_input_node) + input_to_dq_node = self.dequantize_input_node.args[0] + self.pattern_input_node = input_to_dq_node + + # torchao dequantize has a slightly different function schema + if ( + self.dequantize_input_node.target + == exir_ops.edge.torchao.dequantize_affine.default + ): + scales_arg_idx = 2 + zeros_arg_idx = 3 + + self.input_scales_node = self.dequantize_input_node.args[scales_arg_idx] + self.input_zeros_node = self.dequantize_input_node.args[zeros_arg_idx] + + # Check for dynamic quantization: input scales are dynamically + # computed via a choose_qparams op + if utils.is_quant_node(input_to_dq_node) and utils.is_dynamic_qscale( + self.input_scales_node + ): + self.quantize_input_node = input_to_dq_node + self.pattern_input_node = self.quantize_input_node.args[0] # The implementation has a limitation that input channels must be a # multiple of 4. This is to ensure that data loads are aligned well with # texel boundaries. If this is not true, then don't match the pattern. 
- in_channels = self.fp_input_node.meta["val"].shape[-1] + in_channels = self.pattern_input_node.meta["val"].shape[-1] if in_channels % 4 != 0: return @@ -124,32 +168,10 @@ def __init__(self, mm_node: torch.fx.Node) -> None: # noqa: C901 self.all_nodes.extend(arg_chain) # If input is not quantized, then we are done - if self.quantize_input_node is None: + if self.dequantize_input_node is None: self.match_found = True return - scales_arg_idx = 1 - zeros_arg_idx = 2 - - # torchao op has a slightly different function schema - if ( - self.quantize_input_node.target - == exir_ops.edge.torchao.quantize_affine.default - ): - scales_arg_idx = 2 - zeros_arg_idx = 3 - - self.input_scales_node = self.quantize_input_node.args[scales_arg_idx] - self.input_zeros_node = self.quantize_input_node.args[zeros_arg_idx] - - assert dq_node is not None - self.all_nodes.extend( - [ - self.quantize_input_node, - dq_node, - ] - ) - # Check if the output is also quantized (q → dq → linear → q pattern) # Also handle fused linear+relu (q → dq → linear → relu → q pattern) self.quantize_output_node = None @@ -172,7 +194,7 @@ def __init__(self, mm_node: torch.fx.Node) -> None: # noqa: C901 self.match_found = True def is_weight_only_quantized(self) -> bool: - return self.quantize_input_node is None + return self.dequantize_input_node is None def has_output_quantization(self) -> bool: return ( @@ -204,7 +226,7 @@ def is_weight_perchannel_quantized(self) -> bool: return scales_shape[0] == weight_shape[-2] def is_input_static_per_tensor_quantized(self) -> bool: - if self.quantize_input_node is None: + if self.dequantize_input_node is None: return False # For static quantization per tensor quantization, the scales and zeros @@ -212,7 +234,7 @@ def is_input_static_per_tensor_quantized(self) -> bool: return isinstance(self.input_scales_node, float) def is_input_dynamic_perchannel_quantized(self) -> bool: - if self.quantize_input_node is None: + if self.dequantize_input_node is None: return False if not 
isinstance(self.input_scales_node, torch.fx.Node): @@ -228,7 +250,7 @@ def is_input_dynamic_perchannel_quantized(self) -> bool: return False scales_shape = self.input_scales_node.meta["val"].shape - input_shape = self.fp_input_node.meta["val"].shape + input_shape = self.pattern_input_node.meta["val"].shape return input_shape[-2] == scales_shape[-1] @@ -366,7 +388,7 @@ def make_linear_q4gsw_op( "call_function", exir_ops.edge.et_vk.linear_q4gsw.default, args=( - match.fp_input_node, + match.pattern_input_node, match.weight_node, match.weight_scales_node, group_size, @@ -430,7 +452,7 @@ def make_linear_dq8ca_q4gsw_op( "call_function", exir_ops.edge.et_vk.linear_dq8ca_q4gsw.default, args=( - match.fp_input_node, + match.pattern_input_node, match.input_scales_node, match.input_zeros_node, match.weight_node, @@ -450,12 +472,34 @@ def make_linear_q8ta_q8csw_custom_op( match: QuantizedLinearMatch, weight_tensor: torch.Tensor, ): + # Pad weight_scales to multiple of 4 so GPU shader reads don't go OOB + weight_scales_tensor = get_param_tensor(ep, match.weight_scales_node) + assert weight_scales_tensor is not None + utils.align_width_and_update_state_dict( + ep, match.weight_scales_node, weight_scales_tensor + ) + + # Pad bias to multiple of 4 if present + if match.bias_node is not None: + bias_tensor = get_param_tensor(ep, match.bias_node) + if bias_tensor is not None: + utils.align_width_and_update_state_dict(ep, match.bias_node, bias_tensor) + first_graph_node = list(graph_module.graph.nodes)[0] with graph_module.graph.inserting_before(first_graph_node): weight_tensor_name = utils.get_tensor_name(ep, match.weight_node) # Pre-compute the weight sums which are needed to apply activation zero point # when using integer accumulation. 
sum_per_output_channel = weight_tensor.sum(dim=1).to(torch.int32).contiguous() + + # Pad weight sums to align OC to multiple of 4 + oc = sum_per_output_channel.shape[0] + if oc % 4 != 0: + num_padding = 4 - (oc % 4) + sum_per_output_channel = F.pad( + sum_per_output_channel, (0, num_padding) + ).contiguous() + sums_name = weight_tensor_name + "_sums" # Sanitize the name sums_name = sums_name.replace(".", "_") @@ -473,7 +517,7 @@ def make_linear_q8ta_q8csw_custom_op( "call_function", exir_ops.edge.et_vk.linear_q8ta_q8csw.default, args=( - match.fp_input_node, + match.pattern_input_node, match.input_scales_node, match.input_zeros_node, match.weight_node, @@ -492,10 +536,32 @@ def make_q8ta_linear_custom_op( match: QuantizedLinearMatch, weight_tensor: torch.Tensor, ): + # Pad weight_scales to multiple of 4 so GPU shader reads don't go OOB + weight_scales_tensor = get_param_tensor(ep, match.weight_scales_node) + assert weight_scales_tensor is not None + utils.align_width_and_update_state_dict( + ep, match.weight_scales_node, weight_scales_tensor + ) + + # Pad bias to multiple of 4 if present + if match.bias_node is not None: + bias_tensor = get_param_tensor(ep, match.bias_node) + if bias_tensor is not None: + utils.align_width_and_update_state_dict(ep, match.bias_node, bias_tensor) + first_graph_node = list(graph_module.graph.nodes)[0] with graph_module.graph.inserting_before(first_graph_node): weight_tensor_name = utils.get_tensor_name(ep, match.weight_node) sum_per_output_channel = weight_tensor.sum(dim=1).to(torch.int32).contiguous() + + # Pad weight sums to align OC to multiple of 4 + oc = sum_per_output_channel.shape[0] + if oc % 4 != 0: + num_padding = 4 - (oc % 4) + sum_per_output_channel = F.pad( + sum_per_output_channel, (0, num_padding) + ).contiguous() + sums_name = weight_tensor_name + "_sums" sums_name = sums_name.replace(".", "_") @@ -508,7 +574,7 @@ def make_q8ta_linear_custom_op( ) # Use gemv variant when batch size is 1 - input_shape = 
match.fp_input_node.meta["val"].shape + input_shape = match.pattern_input_node.meta["val"].shape batch_size = input_shape[-2] if len(input_shape) >= 2 else 1 if batch_size == 1: op_target = exir_ops.edge.et_vk.q8ta_linear_gemv.default @@ -520,7 +586,7 @@ def make_q8ta_linear_custom_op( "call_function", op_target, args=( - match.quantize_input_node, + match.pattern_input_node, match.input_scales_node, match.input_zeros_node, match.weight_node, diff --git a/backends/vulkan/test/test_vulkan_passes.py b/backends/vulkan/test/test_vulkan_passes.py index c5664de1e73..bcd240d8d12 100644 --- a/backends/vulkan/test/test_vulkan_passes.py +++ b/backends/vulkan/test/test_vulkan_passes.py @@ -283,3 +283,123 @@ def forward(self, x): 1, "Expected at least one q8ta_linear_gemv op for batch-1 linear fusion", ) + + def test_fuse_three_chained_q8ta_linears(self): + """Test that 3 consecutive quantized linears fuse into q8ta_linear ops with + correct quant params at each layer boundary. + + Each linear's input scale/zp (args[1], args[2]) must equal its predecessor's + output scale/zp (args[6], args[7]). This is a regression test for a bug where + topological pattern replacement caused later linears to read scale/zp from the + wrong arg position of the already-replaced q8ta_linear node, producing wildly + incorrect quantization parameters (outputs saturating to -128/127). 
+ """ + from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( + get_symmetric_quantization_config, + XNNPACKQuantizer, + ) + + class ThreeLinearModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(256, 128, bias=False) + self.linear2 = torch.nn.Linear(128, 64, bias=False) + self.linear3 = torch.nn.Linear(64, 32, bias=False) + + def forward(self, x): + return self.linear3(self.linear2(self.linear1(x))) + + model = ThreeLinearModule() + # Batch size 4 to select q8ta_linear (not the gemv variant) + sample_inputs = (torch.randn(4, 256),) + + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config( + is_per_channel=True, + is_dynamic=False, + ) + quantizer.set_global(operator_config) + + edge_program = quantize_and_lower_module(model, sample_inputs, quantizer) + + ep = edge_program._edge_programs["forward"] + fuse_pass = FusePatternsPass() + fuse_pass._exported_program = ep + result = fuse_pass.call(ep.graph_module) + + self.assertTrue(result.modified) + + gm = ep.graph_module + + q8ta_nodes = [ + node + for node in gm.graph.nodes + if get_target_canonical_name(node) == "q8ta_linear.default" + ] + self.assertGreaterEqual( + len(q8ta_nodes), + 2, + "Expected at least 2 q8ta_linear ops from 3 chained quantized linears", + ) + + # For each consecutive q8ta_linear pair, the boundary scale/zp must be + # consistent: linear_i.output_scale == linear_{i+1}.input_scale. + # Before the fix, linear_{i+1}.input_scale was incorrectly read from the + # replaced q8ta_linear node's input args instead of the dq node's args. 
+ for i in range(len(q8ta_nodes) - 1): + self.assertEqual( + q8ta_nodes[i].args[6], + q8ta_nodes[i + 1].args[1], + f"q8ta_linear[{i}].output_scale should equal q8ta_linear[{i + 1}].input_scale", + ) + self.assertEqual( + q8ta_nodes[i].args[7], + q8ta_nodes[i + 1].args[2], + f"q8ta_linear[{i}].output_zero_point should equal q8ta_linear[{i + 1}].input_zero_point", + ) + + def test_fuse_q8ta_linear_gemv_non_aligned_oc(self): + """Test that quantized linear with non-aligned output channels (not multiple of 4) fuses correctly.""" + from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( + get_symmetric_quantization_config, + XNNPACKQuantizer, + ) + + class TwoLinearModule(torch.nn.Module): + def __init__(self): + super().__init__() + # Use non-aligned output channels (9 is not a multiple of 4) + self.linear1 = torch.nn.Linear(128, 9, bias=False) + self.linear2 = torch.nn.Linear(9, 4, bias=False) + + def forward(self, x): + return self.linear2(self.linear1(x)) + + model = TwoLinearModule() + sample_inputs = (torch.randn(1, 128),) + + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config( + is_per_channel=True, + is_dynamic=False, + ) + quantizer.set_global(operator_config) + + edge_program = quantize_and_lower_module(model, sample_inputs, quantizer) + + ep = edge_program._edge_programs["forward"] + fuse_pass = FusePatternsPass() + fuse_pass._exported_program = ep + result = fuse_pass.call(ep.graph_module) + + self.assertTrue(result.modified) + + gm = ep.graph_module + + # The first linear (OC=9, not multiple of 4) should still fuse + q8ta_linear_gemv_count = op_node_count(gm, "q8ta_linear_gemv.default") + self.assertGreaterEqual( + q8ta_linear_gemv_count, + 1, + "Expected non-aligned OC linear to fuse into q8ta_linear_gemv", + ) diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py index caa5439bc98..dde9aaac973 100644 --- a/backends/vulkan/utils.py +++ b/backends/vulkan/utils.py @@ -142,6 +142,15 @@ def 
is_choose_qparams_node(node: torch.fx.Node) -> bool: return "choose_qparams" in node_name +def is_dynamic_qscale(node: Any) -> bool: + """Check if a scale node is dynamically computed via a choose_qparams op.""" + return ( + isinstance(node, torch.fx.Node) + and node.target == operator.getitem + and is_choose_qparams_node(node.args[0]) + ) + + def is_dequant_per_channel_node(node: torch.fx.Node) -> bool: if node.op != "call_function": return False