From 9cd0d35df50e4a3237fa3bfa84b97431bf1d8856 Mon Sep 17 00:00:00 2001 From: ssjia Date: Mon, 9 Mar 2026 09:14:36 -0700 Subject: [PATCH] [ET-VK][qlinear] Add bmm support to quantized linear pattern detector Some quantized linear projections (e.g. in EdgeTAM's SpatialPerceiver / mask decoder) decompose as aten.bmm instead of aten.mm. Add aten.bmm.default as an anchor node in the quantized linear pattern detector so these nodes can be fused into custom quantized linear ops. Reject bmm nodes with batch dim > 1 since the custom ops assume a single batch. Differential Revision: [D95807072](https://our.internmc.facebook.com/intern/diff/D95807072/) [ghstack-poisoned] --- backends/vulkan/patterns/quantized_linear.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/backends/vulkan/patterns/quantized_linear.py b/backends/vulkan/patterns/quantized_linear.py index 6326369d051..6a077693c14 100644 --- a/backends/vulkan/patterns/quantized_linear.py +++ b/backends/vulkan/patterns/quantized_linear.py @@ -90,6 +90,11 @@ def __init__(self, mm_node: torch.fx.Node) -> None: # noqa: C901 # Identify output node self.output_node = self.anchor_node + # bmm with batch dim > 1 is not supported + is_bmm = self.anchor_node.target == exir_ops.edge.aten.bmm.default + if is_bmm and self.output_node.meta["val"].shape[0] != 1: + return + # Identify primary input node of the anchor. Due to decomposition of aten.linear # there may be a view_copy node between the original input tensor to the linear # op and the actual linear op node. @@ -267,6 +272,7 @@ def is_input_dynamic_perchannel_quantized(self) -> bool: exir_ops.edge.aten.linear.default, exir_ops.edge.aten.mm.default, exir_ops.edge.aten.addmm.default, + exir_ops.edge.aten.bmm.default, }