From 694988317b433fdcf0e1541d54f60297a7046eed Mon Sep 17 00:00:00 2001 From: ssjia Date: Wed, 11 Feb 2026 12:15:32 -0800 Subject: [PATCH] [ET-VK][qconv][ez] Make q8ta_im2col shader support stride_w != 1 The im2col path previously had two shaders: the generic `q8ta_im2col` which loaded input elements one int8 at a time and supported all strides, and `q8ta_im2col_4w4c` which loaded packed int32s (4 channels at once) but was restricted to stride_w == 1 because it assumed consecutive width positions in the input (i.e. `input_x_base + i`). This diff replaces both shaders with a single `q8ta_im2col` that uses the efficient packed loading approach from `q8ta_im2col_4w4c` while generalizing the width offset to `input_x_base + i * stride_x`. This removes the stride_w == 1 restriction and deletes the separate `q8ta_im2col_4w4c` shader entirely. The C++ dispatch is updated to reference the unified shader name, and the test gate on `stride.w == 1` for the im2col path is removed. Authored with Claude. Differential Revision: [D93000161](https://our.internmc.facebook.com/intern/diff/D93000161/) [ghstack-poisoned] --- backends/vulkan/op_registry.py | 2 + .../runtime/graph/ops/glsl/q8ta_im2col.glsl | 180 +++++++----------- .../graph/ops/glsl/q8ta_im2col_4w4c.glsl | 130 ------------- .../graph/ops/glsl/q8ta_im2col_4w4c.yaml | 11 -- .../graph/ops/impl/Q8taConv2dIm2Col.cpp | 2 +- .../test/custom_ops/test_q8ta_conv2d.cpp | 6 +- 6 files changed, 78 insertions(+), 253 deletions(-) delete mode 100644 backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col_4w4c.glsl delete mode 100644 backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col_4w4c.yaml diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 9a802b48d9c..2e927260158 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -908,6 +908,7 @@ def register_clone_dim_order(): inputs_storage=utils.ANY_STORAGE, inputs_dtypes=utils.FP_INT_BOOL_T, supports_resize=True, + supports_highdim=True, ) @@ -922,6 +923,7 @@ def register_gather(): inputs_storage=utils.ANY_STORAGE, inputs_dtypes=utils.FP_INT_BOOL_T, supports_resize=True, + supports_highdim=True, ) diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col.glsl index 4c9ca6b5728..ed4e124ac45 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col.glsl @@ -12,14 +12,6 @@ #define PACKED_INT8_OUTPUT_BUFFER -#define TILE_M4 1 -#define TILE_N4 1 -#define TILE_K4 1 - -#define TILE_M 4 -#define TILE_N 4 -#define TILE_K 4 - layout(std430) buffer; #include "indexing.glslh" @@ -36,8 +28,8 @@ ${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")} ${layout_declare_spec_const(C, "int", "apply_bias", "1")} // Layout specialization constants +${layout_declare_spec_const(C, "int", "outp_layout", "CONTIG_LAYOUT_INT")} ${layout_declare_spec_const(C, "int", "inp_layout", "CONTIG_LAYOUT_INT")} -${layout_declare_spec_const(C, "int", "im2col_outp_layout", "CONTIG_LAYOUT_INT")} layout(push_constant) uniform restrict Block { int zp; @@ -45,120 +37,94 @@ layout(push_constant) uniform restrict Block { layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; -#include "conv2d_int8_output_tile_store.glslh" - -// Compute input tensor index from im2col coordinates -TensorIndex4D get_input_tidx( - const int im2col_w, - const int im2col_h, - const int k_in_group, - const int group_idx) { - TensorIndex4D tidx; - tidx.data.w = 0; +void main() { + const int out_buf_idx = int(gl_GlobalInvocationID.x); - const int c_in_group = k_in_group % conv2d_params.in_channels_per_group; - const int row = k_in_group / conv2d_params.in_channels_per_group; - const int kernel_x = row % conv2d_params.kernel_size.x; - const int kernel_y = row / conv2d_params.kernel_size.x; + // Extract sizes from BufferMetadata + const ivec4 im2col_sizes = ivec4(im2col_outp.sizes[0]); + const ivec4 input_sizes = ivec4(inp.sizes[0]); - tidx.data.z = group_idx * conv2d_params.in_channels_per_group + c_in_group; + // im2col block extents + const int im2col_W4 = div_up_4(im2col_sizes.x); + const int im2col_H = im2col_sizes.y; + const int im2col_Z4 = div_up_4(im2col_sizes.z); - tidx.data.x = (im2col_w * conv2d_params.stride.x) - conv2d_params.padding.x + - (kernel_x * conv2d_params.dilation.x); - tidx.data.y = (im2col_h * conv2d_params.stride.y) - conv2d_params.padding.y + - (kernel_y * conv2d_params.dilation.y); + // im2col block index from linear output buffer index + const int c4_idx = out_buf_idx % im2col_Z4; + const int row = out_buf_idx / im2col_Z4; + const int w4_idx = row % im2col_W4; + const int h_idx = row / im2col_W4; - return tidx; -} - -// Load a single int8 value from the input tensor using layout-agnostic indexing -int load_input_element(const TensorIndex4D tidx, const int input_zp) { - // Bounds checking - if (any(lessThan(tidx.data, ivec4(0))) || - any(greaterThanEqual(tidx.data, ivec4(inp.sizes[0])))) { - return input_zp; + // out of bounds check + if (w4_idx >= im2col_W4 || h_idx >= im2col_H || c4_idx >= im2col_Z4) { + return; } - // Use layout-agnostic indexing to get buffer position - int texel_idx; - if (get_outer_packed_dim_block_size(inp_layout) == 1) { - // For 4C or 4C1W layouts: use tensor4d_idx_to_texel_idx - texel_idx = tensor4d_idx_to_texel_idx(inp, tidx, inp_layout); - } else { - // For 4W4C layout: compute index directly - const int w4 = div_4(tidx.data[0]); - const int c4 = div_4(tidx.data[2]); - const int h_stride = int(inp.strides[0][1]); - const int w_stride = int(inp.strides[0][0]); - texel_idx = (tidx.data[1] * h_stride + w4 * w_stride + c4) * 4 + mod_4(tidx.data[0]); - } + const int im2col_w = mul_4(w4_idx); + const int im2col_h = h_idx; + const int im2col_k = mul_4(c4_idx); - // Load packed int32 containing 4 int8 values - const int packed_input = t_packed_int8_input[texel_idx]; + const int group_idx = im2col_k / conv2d_params.K_per_group; + const int k_in_group = im2col_k % conv2d_params.K_per_group; - // Extract the appropriate int8 value based on channel offset within texel - const int c_offset = mod_4(tidx.data[2]); - return extract_8bit_from_packed_int_le(packed_input, c_offset); -} + const int c_in_group = k_in_group % conv2d_params.in_channels_per_group; + const int krow = k_in_group / conv2d_params.in_channels_per_group; + const int kernel_x = krow % conv2d_params.kernel_size.x; + const int kernel_y = krow / conv2d_params.kernel_size.x; -// Load a 4x4 im2col block (4 widths × 4 channels) -ivec4 load_im2col_block( - const int im2col_w_start, - const int im2col_h, - const int k_in_group_start, - const int group_idx) { - ivec4 im2col_block; + // Base input position + const int input_x_base = + (im2col_w * conv2d_params.stride.x) - conv2d_params.padding.x + + (kernel_x * conv2d_params.dilation.x); + const int input_y = + (im2col_h * conv2d_params.stride.y) - conv2d_params.padding.y + + (kernel_y * conv2d_params.dilation.y); + const int input_z = + group_idx * conv2d_params.in_channels_per_group + c_in_group; - for (int r = 0; r < 4; r++) { - const int im2col_w = im2col_w_start + r; - ivec4 row_values; - for (int c = 0; c < 4; c++) { - const int k_in_group = k_in_group_start + c; + // Input tensor extents + const int input_W = input_sizes.x; + const int input_H = input_sizes.y; + const int input_Z4 = div_up_4(input_sizes.z); - if (k_in_group >= conv2d_params.logical_K_per_group) { - row_values[c] = zp; - continue; - } + const int zp_packed = pack_into_int32(ivec4(zp)); + const int z4 = div_4(input_z); - TensorIndex4D input_tidx = - get_input_tidx(im2col_w, im2col_h, k_in_group, group_idx); + // Check if y and z are in bounds (constant for all 4 width elements) + const bool y_z_in_bounds = + (input_y >= 0 && input_y < input_H && z4 >= 0 && z4 < input_Z4); - row_values[c] = load_input_element(input_tidx, zp); + // Load 4 elements from input, one for each output width position. + // Each loaded int contains 4 packed int8 channel values. + ivec4 im2col_block; + for (int i = 0; i < 4; i++) { + const int x = input_x_base + i * conv2d_params.stride.x; + if (!y_z_in_bounds || x < 0 || x >= input_W) { + im2col_block[i] = zp_packed; + } else { + const int x4 = div_4(x); + const int x_mod = mod_4(x); + int scalar_idx; + if (get_outer_packed_dim_block_size(inp_layout) == 1) { + scalar_idx = input_y * int(inp.strides[0][1]) + + x * int(inp.strides[0][0]) + + z4 * int(inp.strides[0][2]); + } else { + scalar_idx = mul_4( + input_y * int(inp.strides[0][1]) + + x4 * int(inp.strides[0][0]) + + z4) + x_mod; + } + im2col_block[i] = t_packed_int8_input[scalar_idx]; } - - im2col_block[r] = pack_into_int32(row_values); } - return im2col_block; -} - -void main() { - const int out_buf_idx = int(gl_GlobalInvocationID.x); - const ivec4 im2col_sizes = ivec4(im2col_outp.sizes[0]); - Conv2dBlockExtents im2col_block_extents = make_block_extents(im2col_sizes); - - Conv2dBlockIndex im2col_block_idx = linear_idx_to_block_idx( - out_buf_idx, im2col_block_extents); + // store_packed_int8_output_tile (with TILE_M4=1, TILE_N4=1) + const int buffer_idx = h_idx * int(im2col_outp.strides[0][1]) + + w4_idx * int(im2col_outp.strides[0][0]) + + c4_idx; - if (block_idx_out_of_bounds(im2col_block_idx, im2col_block_extents)) { - return; + if (w4_idx < im2col_W4 && c4_idx < im2col_Z4) { + t_packed_int8_output[buffer_idx] = im2col_block; } - - // Convert block index to im2col coordinates - const int im2col_w = mul_4(im2col_block_idx.data.x); - const int im2col_h = im2col_block_idx.data.y; - const int im2col_k = mul_4(im2col_block_idx.data.z); - - // Compute group and k offset within group - const int group_idx = im2col_k / conv2d_params.K_per_group; - const int k_in_group = im2col_k % conv2d_params.K_per_group; - - // Load the im2col block using layout-agnostic input access - Int8OutTile int8_im2col_tile; - int8_im2col_tile.data[0][0] = load_im2col_block( - im2col_w, im2col_h, k_in_group, group_idx); - - // Store to output (4W4C format) - store_packed_int8_output_tile( - int8_im2col_tile, im2col_block_idx, im2col_block_extents); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col_4w4c.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col_4w4c.glsl deleted file mode 100644 index 9c5e0ee9066..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col_4w4c.glsl +++ /dev/null @@ -1,130 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#version 450 core - -#define PRECISION ${PRECISION} - -#define PACKED_INT8_OUTPUT_BUFFER - -layout(std430) buffer; - -#include "indexing.glslh" -#include "conv2d_common.glslh" - -${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", "buffer", is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", "buffer", is_scalar_array=True)} - -// Metadata for im2col output and input tensors (layout-agnostic) -${layout_declare_ubo(B, "BufferMetadata", "im2col_outp")} -${layout_declare_ubo(B, "BufferMetadata", "inp")} -${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")} - -${layout_declare_spec_const(C, "int", "apply_bias", "1")} - -// Layout specialization constants -${layout_declare_spec_const(C, "int", "outp_layout", "CONTIG_LAYOUT_INT")} -${layout_declare_spec_const(C, "int", "inp_layout", "CONTIG_LAYOUT_INT")} - -layout(push_constant) uniform restrict Block { - int zp; -}; - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -void main() { - const int out_buf_idx = int(gl_GlobalInvocationID.x); - - // Extract sizes from BufferMetadata - const ivec4 im2col_sizes = ivec4(im2col_outp.sizes[0]); - const ivec4 input_sizes = ivec4(inp.sizes[0]); - - // im2col block extents - const int im2col_W4 = div_up_4(im2col_sizes.x); - const int im2col_H = im2col_sizes.y; - const int im2col_Z4 = div_up_4(im2col_sizes.z); - - // im2col block index from linear output buffer index - const int c4_idx = out_buf_idx % im2col_Z4; - const int row = out_buf_idx / im2col_Z4; - const int w4_idx = row % im2col_W4; - const int h_idx = row / im2col_W4; - - // out of bounds check - if (w4_idx >= im2col_W4 || h_idx >= im2col_H || c4_idx >= im2col_Z4) { - return; - } - - const int im2col_w = mul_4(w4_idx); - const int im2col_h = h_idx; - const int im2col_k = mul_4(c4_idx); - - const int group_idx = im2col_k / conv2d_params.K_per_group; - const int k_in_group = im2col_k % conv2d_params.K_per_group; - - const int c_in_group = k_in_group % conv2d_params.in_channels_per_group; - const int krow = k_in_group / conv2d_params.in_channels_per_group; - const int kernel_x = krow % conv2d_params.kernel_size.x; - const int kernel_y = krow / conv2d_params.kernel_size.x; - - // Base input position - const int input_x_base = - (im2col_w * conv2d_params.stride.x) - conv2d_params.padding.x + - (kernel_x * conv2d_params.dilation.x); - const int input_y = - (im2col_h * conv2d_params.stride.y) - conv2d_params.padding.y + - (kernel_y * conv2d_params.dilation.y); - const int input_z = - group_idx * conv2d_params.in_channels_per_group + c_in_group; - - // Input tensor extents - const int input_W = input_sizes.x; - const int input_H = input_sizes.y; - const int input_Z4 = div_up_4(input_sizes.z); - - const int zp_packed = pack_into_int32(ivec4(zp)); - const int z4 = div_4(input_z); - - // Check if y and z are in bounds (constant for all 4 width elements) - const bool y_z_in_bounds = - (input_y >= 0 && input_y < input_H && z4 >= 0 && z4 < input_Z4); - - // Load 4 elements from input, one for each output width position. - // Each loaded int contains 4 packed int8 channel values. - ivec4 im2col_block; - for (int i = 0; i < 4; i++) { - const int x = input_x_base + i; - if (!y_z_in_bounds || x < 0 || x >= input_W) { - im2col_block[i] = zp_packed; - } else { - const int x4 = div_4(x); - const int x_mod = mod_4(x); - int scalar_idx; - if (get_outer_packed_dim_block_size(inp_layout) == 1) { - scalar_idx = input_y * int(inp.strides[0][1]) - + x * int(inp.strides[0][0]) - + z4 * int(inp.strides[0][2]); - } else { - scalar_idx = mul_4( - input_y * int(inp.strides[0][1]) - + x4 * int(inp.strides[0][0]) - + z4) + x_mod; - } - im2col_block[i] = t_packed_int8_input[scalar_idx]; - } - } - - // store_packed_int8_output_tile (with TILE_M4=1, TILE_N4=1) - const int buffer_idx = h_idx * int(im2col_outp.strides[0][1]) - + w4_idx * int(im2col_outp.strides[0][0]) - + c4_idx; - - if (w4_idx < im2col_W4 && c4_idx < im2col_Z4) { - t_packed_int8_output[buffer_idx] = im2col_block; - } -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col_4w4c.yaml b/backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col_4w4c.yaml deleted file mode 100644 index 0de3d97f324..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/q8ta_im2col_4w4c.yaml +++ /dev/null @@ -1,11 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -q8ta_im2col_4w4c: - parameter_names_with_default_values: - DTYPE: float - shader_variants: - - NAME: q8ta_im2col_4w4c diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dIm2Col.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dIm2Col.cpp index 4bbcc16e43d..f634f8b1773 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dIm2Col.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dIm2Col.cpp @@ -132,7 +132,7 @@ void add_q8ta_im2col_node( // The implementation also requires that input channels is a multiple of 4 VK_CHECK_COND(conv_params.in_channels_per_group % 4 == 0); - std::string kernel_name = "q8ta_im2col_4w4c"; + std::string kernel_name = "q8ta_im2col"; vkapi::ParamsBindList param_buffers = { graph.buffer_meta_ubo(packed_int8_im2col), diff --git a/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp b/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp index 8f445ab7230..0de4f546b0d 100644 --- a/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp +++ b/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp @@ -237,8 +237,7 @@ std::vector generate_quantized_conv2d_easy_cases() { // Test im2col implementation for non-grouped convolutions with input // channels that are a multiple of 4 and stride_w == 1 - if (config.groups == 1 && config.channels.in % 4 == 0 && - config.stride.w == 1) { + if (config.groups == 1 && config.channels.in % 4 == 0) { test_cases.push_back(create_test_case_from_config( config, vkapi::kFloat, @@ -417,8 +416,7 @@ static std::vector generate_quantized_conv2d_test_cases() { // Test im2col implementation for non-grouped convolutions with input // channels that are a multiple of 4 and stride_w == 1 - if (config.groups == 1 && config.channels.in % 4 == 0 && - config.stride.w == 1) { + if (config.groups == 1 && config.channels.in % 4 == 0) { test_cases.push_back(create_test_case_from_config( config, vkapi::kFloat,