From 3adaa8913be587780f6fd5c0f75e89e38a827be9 Mon Sep 17 00:00:00 2001 From: ssjia Date: Mon, 9 Mar 2026 09:14:16 -0700 Subject: [PATCH] [ET-VK][qlinear] Look through output view_copy when detecting output quantization When `aten.linear` has 3D+ inputs, it decomposes into `view_copy -> mm -> view_copy`. The output view_copy between mm and the subsequent quantize_per_tensor node was preventing the pattern matcher from detecting output quantization, causing the match to fall through to `linear_q8ta_q8csw` instead of `q8ta_linear_gemv`. This caused a dtype mismatch during FakeTensor re-tracing in FusePatternsPass because `linear_q8ta_q8csw`'s composite implementation does not dequantize its input, producing int8 output where float32 was expected. Mirror the existing input-side view_copy handling (lines 99-104) on the output side so the quantize node is found through the view_copy. Differential Revision: [D95807075](https://our.internmc.facebook.com/intern/diff/D95807075/) [ghstack-poisoned] --- backends/vulkan/patterns/quantized_linear.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/backends/vulkan/patterns/quantized_linear.py b/backends/vulkan/patterns/quantized_linear.py index df80749e72f..6326369d051 100644 --- a/backends/vulkan/patterns/quantized_linear.py +++ b/backends/vulkan/patterns/quantized_linear.py @@ -174,12 +174,20 @@ def __init__(self, mm_node: torch.fx.Node) -> None: # noqa: C901 # Check if the output is also quantized (q → dq → linear → q pattern) # Also handle fused linear+relu (q → dq → linear → relu → q pattern) + # Due to decomposition of aten.linear for 3D+ inputs, there may be a + # view_copy between the mm output and the quantize node. self.quantize_output_node = None self.output_scales_node = None self.output_zeros_node = None self.relu_node = None + self.output_view_copy_node = None if len(self.output_node.users) == 1: cur_node = list(self.output_node.users)[0] + # Skip potential view_copy between linear and output quantize + if utils.is_view_copy_node(cur_node) and len(cur_node.users) == 1: + self.output_view_copy_node = cur_node + self.all_nodes.append(self.output_view_copy_node) + cur_node = list(cur_node.users)[0] if cur_node.target == exir_ops.edge.aten.relu.default: self.relu_node = cur_node if len(cur_node.users) == 1: