From 04e6b5488f47b421edc8f29bacea27aa99c74901 Mon Sep 17 00:00:00 2001 From: ssjia Date: Thu, 19 Feb 2026 11:48:32 -0800 Subject: [PATCH] [ET-VK] Support different input layouts in q8ta_binary operator Previously, the q8ta_binary operator required both inputs to use the same memory layout. This was enforced by using a single `in_layout` specialization constant for both input buffers. However, some models may have inputs with different layouts (e.g., 4W4C and 4C1W) that share the same packed dimension and block size, which should be compatible for binary operations. This change introduces a separate `other_layout` specialization constant for the second input, allowing the shader to correctly load from input_b using its actual layout while input_a continues to use `in_layout`. The C++ side now passes both layout hashes as separate specialization constants to the shader. Differential Revision: [D93768638](https://our.internmc.facebook.com/intern/diff/D93768638/) [ghstack-poisoned] --- backends/vulkan/runtime/graph/ops/glsl/q8ta_binary.glsl | 3 ++- backends/vulkan/runtime/graph/ops/impl/Q8taBinary.cpp | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_binary.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_binary.glsl index 60f437fbdce..be93e800436 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q8ta_binary.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_binary.glsl @@ -46,6 +46,7 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; ${layout_declare_spec_const(C, "int", "out_layout", "CONTIG_LAYOUT_INT")} ${layout_declare_spec_const(C, "int", "in_layout", "CONTIG_LAYOUT_INT")} +${layout_declare_spec_const(C, "int", "other_layout", "CONTIG_LAYOUT_INT")} ${layout_declare_spec_const(C, "int", "block_config", "0")} // Generate loading functions for input buffers @@ -71,7 +72,7 @@ void main() { ivec4 in_block_a = load_int8x4_block_from_t_in_a( in_a_meta, tidx, in_layout, block_outer_dim); ivec4 in_block_b = load_int8x4_block_from_t_in_b( - in_b_meta, tidx, in_layout, block_outer_dim); + in_b_meta, tidx, other_layout, block_outer_dim); ivec4 out_block; diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taBinary.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taBinary.cpp index af934b9b521..05bdd9431c8 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Q8taBinary.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taBinary.cpp @@ -42,6 +42,7 @@ void add_q8ta_binary_node( VK_CHECK_COND(input_a_info.packed_dim == output_info.packed_dim); VK_CHECK_COND(input_b_info.packed_dim == output_info.packed_dim); + VK_CHECK_COND( input_a_info.packed_dim_block_size == output_info.packed_dim_block_size); VK_CHECK_COND( @@ -105,6 +106,7 @@ void add_q8ta_binary_node( // Specialization Constants {graph.hashed_layout_of(packed_int8_output), graph.hashed_layout_of(packed_int8_input_a), + graph.hashed_layout_of(packed_int8_input_b), block_config.as_packed_int()}, // Resize args {block_config_ref},