From 02f2162e656308fbb23e1d0a724d234c7dc51fac Mon Sep 17 00:00:00 2001
From: ssjia
Date: Mon, 9 Mar 2026 16:45:57 -0700
Subject: [PATCH] [ET-VK][qlinear] Add bmm support to quantized linear pattern detector

Pull Request resolved: https://github.com/pytorch/executorch/pull/18017

Some quantized linear projections (e.g. in EdgeTAM's SpatialPerceiver /
mask decoder) decompose as aten.bmm instead of aten.mm. Add
aten.bmm.default as an anchor node in the quantized linear pattern
detector so these nodes can be fused into custom quantized linear ops.
Reject bmm nodes with batch dim > 1 since the custom ops assume a single
batch.
ghstack-source-id: 349646654
@exported-using-ghexport

Differential Revision: [D95807072](https://our.internmc.facebook.com/intern/diff/D95807072/)
---
 backends/vulkan/patterns/quantized_linear.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/backends/vulkan/patterns/quantized_linear.py b/backends/vulkan/patterns/quantized_linear.py
index 85e3476cad3..b9b307e14f1 100644
--- a/backends/vulkan/patterns/quantized_linear.py
+++ b/backends/vulkan/patterns/quantized_linear.py
@@ -90,6 +90,11 @@ def __init__(self, mm_node: torch.fx.Node) -> None:  # noqa: C901
         # Identify output node
         self.output_node = self.anchor_node
 
+        # bmm with batch dim > 1 is not supported
+        is_bmm = self.anchor_node.target == exir_ops.edge.aten.bmm.default
+        if is_bmm and self.output_node.meta["val"].shape[0] != 1:
+            return
+
         # Identify primary input node of the anchor. Due to decomposition of aten.linear
         # there may be a view_copy node between the original input tensor to the linear
         # op and the actual linear op node.
@@ -268,6 +273,7 @@ def is_input_dynamic_perchannel_quantized(self) -> bool:
     exir_ops.edge.aten.linear.default,
     exir_ops.edge.aten.mm.default,
     exir_ops.edge.aten.addmm.default,
+    exir_ops.edge.aten.bmm.default,
 }
 
 