From 25b37e84755cb269794606870364e14b59fc80ed Mon Sep 17 00:00:00 2001 From: morelos Date: Thu, 26 Jun 2025 09:56:22 -0700 Subject: [PATCH 1/2] [ET-VK][Ops] linear_qta8a_qga4w_qta8o impl and shaders # Operator Description The linear_qta8a_qga4w_qta8o operator implements a quantized linear transformation that enables efficient neural network inference through dynamic quantization. This operator performs matrix multiplication between quantized 8-bit activations and 4-bit grouped quantized weights, producing quantized 8-bit outputs. The quantization scheme follows the standard affine mapping where `real_value = scale * (quantized_value - zero_point)`. Input activations use 8-bit signed integers with per-token scale and zero-point parameters, while weights employ 4-bit quantization with group-wise parameters. # Implementation Architecture The operator provides two distinct computational approaches optimized for different matrix multiplication scenarios: the TILED algorithm for general matrix-matrix multiplication (GEMM) and the COOPERATIVE algorithm for matrix-vector multiplication (GEMV). ## TILED Algorithm (GEMM Cases) The tiled implementation processes the output matrix in rectangular blocks. Each thread is responsible for calculating a tile of output values, typically processing 3 rows and 2 columns worth of results in each iteration. The algorithm operates by having each thread load blocks of quantized weights and activations, perform integer arithmetic accumulation, and then apply the necessary scaling operations. Weight data is pre-packed in a specialized format where two 4-bit values are stored in each byte. Each thread loads multiple weight elements simultaneously and unpacks them during computation. The quantization parameters for weights are organized by groups, where each group of consecutive weight elements shares the same scale and zero-point values. ## COOPERATIVE Algorithm (GEMV Cases) The cooperative implementation uses shared memory and thread cooperation where this approach uses workgroups of 64 threads arranged as 8 groups of 8 workers each. The key insight is that GEMV operations have limited parallelism in the output dimension but substantial parallelism in the reduction dimension, making cooperative reduction strategies more effective than independent thread computation. Each group of 8 worker threads collaboratively computes a portion of the output vector. The workers divide the reduction work along the input feature dimension, with each worker processing every 8th element in a strided pattern. # Future Performance Improvements - Making use of dotPacked4x8EXT (this requires upgrading glslc and vulkan) - Fixed point math for pure integer operations - Might be more performant to avoid preloading tensors - Might also be more performant to avoid excessive register overhead by defining the ivec4 within each block operation (allowing more threads to be more register intensive) Differential Revision: [D77173441](https://our.internmc.facebook.com/intern/diff/D77173441/) [ghstack-poisoned] --- .../glsl/linear_qta8a_qga4w_qta8o_coop.glsl | 241 +++++++++++++++ .../glsl/linear_qta8a_qga4w_qta8o_coop.yaml | 26 ++ .../glsl/linear_qta8a_qga4w_qta8o_tiled.glsl | 208 +++++++++++++ .../glsl/linear_qta8a_qga4w_qta8o_tiled.yaml | 26 ++ .../impl/QuantizedLinearQTA8AQGA4WQTA8O.cpp | 288 ++++++++++++++++++ .../linear_qta8a_qga4w_qta8o_test.cpp | 286 +++++++++++++++++ 6 files changed, 1075 insertions(+) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_qta8o_coop.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_qta8o_coop.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_qta8o_tiled.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_qta8o_tiled.yaml create mode 100644 backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQTA8AQGA4WQTA8O.cpp diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_qta8o_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_qta8o_coop.glsl new file mode 100644 index 00000000000..4b1e2b6b7be --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_qta8o_coop.glsl @@ -0,0 +1,241 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define T ${buffer_scalar_type(DTYPE)} +#define VEC4_T ${buffer_gvec_type(DTYPE, 4)} + +#define TILE_ROWS ${TILE_ROWS} + +#define NGROUPS 8 +#define NWORKERS 8 + +${define_required_extensions(DTYPE)} +$if WEIGHT_STORAGE == "buffer": + ${define_required_extensions("uint8")} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_mat1", DTYPE, IN_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_qmat2", "uint8", WEIGHT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_qparams", "float", PARAMS_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_input_scale", "float", "buffer", is_scalar_array=True)} +${layout_declare_tensor(B, "r", "t_input_zero_point", "int", "buffer", is_scalar_array=True)} +${layout_declare_tensor(B, "r", "t_output_scale", "float", "buffer", is_scalar_array=True)} +${layout_declare_tensor(B, "r", "t_output_zero_point", "int", "buffer", is_scalar_array=True)} + +layout(push_constant) uniform restrict Block { + ivec4 out_sizes; + ivec4 mat1_sizes; + ivec4 qmat2_sizes; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +layout(constant_id = 3) const int group_size = 64; + +shared vec4 partial_results[NGROUPS][NWORKERS][TILE_ROWS][2]; + +/* + * This shader computes a linear operator between a quantized int8 input matrix + * x and a weights matrix that is quantized to 4 bits, producing a quantized int8 output. + * + * This shader implements a co-operative algorithm to compute the output. The + * work group size is {NGROUP, 1, NWORKERS}, and each group of NWORKERS threads + * cooperative to compute TILE_ROWS * 2 output texels. Therefore, + * NGROUP * TILE_ROWS * 2 output texels are computed across one work group. + * + * The threads co-operate by each thread computing a partial reduction along the + * K dimension. To illustrate the computation, consider a scalar variant of the + * algorithm that computes the dot product of 2 vectors. Also assume that + * NWORKERS is 8. + * + * Thread 1 in each group will compute: + * (mat1[0] * mat2[0]) + (mat1[8] * mat2[8]) + (mat1[16] * mat2[16]) + ... + * + * Thread 2 in each group will compute: + * (mat1[1] * mat2[1]) + (mat2[9] * mat2[9]) + (mat1[17] * mat2[17]) + ... + * + * Thread 3 in each group will compute: + * (mat1[2] * mat2[2]) + (mat2[10] * mat2[10]) + (mat1[18] * mat2[18]) + ... + * + * The partial accumulations is structured such that memory accesses in each + * loop iteration can be coalesced. + * + * Then, at the end first thread in each group will accumulate the partial + * accumulations computed by each thread to obtain the final result. + * + * Note that this shader assumes that all tensors are width packed. + */ + +void main() { + const uint out_row = gl_GlobalInvocationID.y * TILE_ROWS; + const uint out_col = gl_GlobalInvocationID.x << 3; + const int out_col_texel_idx = int(gl_GlobalInvocationID.x) << 1; + + const uint gid = gl_LocalInvocationID.x; // group id + const uint wid = gl_LocalInvocationID.z; // worker id + + if (out_col >= out_sizes.x || out_row >= out_sizes.y) { + return; + } + + const int num_blocks = mat1_sizes.x / group_size; + + VEC4_T mat1_quantized[TILE_ROWS]; + ivec4 qmat2_quantized[4][2]; + vec4 final_result[TILE_ROWS][2]; + + // Initialize accumulators + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + final_result[r][0] = vec4(0.0); + final_result[r][1] = vec4(0.0); + } + + vec4 scales[2]; + vec4 zeros[2]; + + $if WEIGHT_STORAGE == "buffer": + const int qmat2_stride = qmat2_sizes.x >> 2; + $if PARAMS_STORAGE == "buffer": + const int qparams_y_stride = out_sizes.x >> 2; + const int qparams_z_stride = qparams_y_stride * 2; + + for (int block_idx = 0; block_idx < num_blocks; ++block_idx) { + $if PARAMS_STORAGE == "buffer": + scales[0] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx]; + zeros[0] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + qparams_y_stride]; + + scales[1] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + 1]; + zeros[1] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + 1 + qparams_y_stride]; + $else: + scales[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 0, block_idx), 0); + zeros[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 1, block_idx), 0); + + scales[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 0, block_idx), 0); + zeros[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 1, block_idx), 0); + + ivec4 int32_sums[TILE_ROWS][2]; + int input_sums[TILE_ROWS]; + + // Initialize accumulators for this block + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + int32_sums[r][0] = ivec4(0); + int32_sums[r][1] = ivec4(0); + input_sums[r] = 0; + } + + for (int g_idx = 4 * int(wid); g_idx < group_size; g_idx += (4 * NWORKERS)) { + const int k = block_idx * group_size + g_idx; + + // Preload B (weights) - keep as quantized integers + [[unroll]] for (int r = 0; r < 4; ++r) { + $if WEIGHT_STORAGE == "buffer": + const uvec4 packed_weight_tex = t_qmat2[(k + r) * qmat2_stride + gl_GlobalInvocationID.x]; + $else: + const uvec4 packed_weight_tex = texelFetch( + t_qmat2, + ivec2(gl_GlobalInvocationID.x, k + r), + 0); + + // Unpack 4-bit weights to integers and subtract zero point (8 for 4-bit) + qmat2_quantized[r][0] = ivec4((packed_weight_tex & 0xF0) >> 4) - ivec4(8); + qmat2_quantized[r][1] = ivec4(packed_weight_tex & 0x0F) - ivec4(8); + } + + // Preload A (quantized input) - keep as quantized integers + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + $if IN_STORAGE == "buffer": + mat1_quantized[r] = VEC4_T(t_mat1[((out_row + r) * mat1_sizes.x + k) >> 2] - t_input_zero_point[int(out_row) + r]); + $else: + mat1_quantized[r] = VEC4_T(texelFetch(t_mat1, ivec3(k >> 2, out_row + r, 0), 0) - t_input_zero_point[int(out_row) + r]); + + input_sums[r] += mat1_quantized[r].x + mat1_quantized[r].y + mat1_quantized[r].z + mat1_quantized[r].w; + } + + // Accumulate in integer arithmetic: (input_quantized - input_zero_point) * (weight_quantized - weight_zero_point) + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + int32_sums[r][0] += mat1_quantized[r].x * qmat2_quantized[0][0] + + mat1_quantized[r].y * qmat2_quantized[1][0] + + mat1_quantized[r].z * qmat2_quantized[2][0] + + mat1_quantized[r].w * qmat2_quantized[3][0]; + + int32_sums[r][1] += mat1_quantized[r].x * qmat2_quantized[0][1] + + mat1_quantized[r].y * qmat2_quantized[1][1] + + mat1_quantized[r].z * qmat2_quantized[2][1] + + mat1_quantized[r].w * qmat2_quantized[3][1]; + } + } + + // Incorporates this block's results into the final accumulation + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + if (out_row + r >= out_sizes.y) { + continue; + } + + float input_scale = t_input_scale[int(out_row) + r]; + float input_sum_scalar = float(input_sums[r]); + + final_result[r][0] += input_scale * (vec4(int32_sums[r][0]) * scales[0] + input_sum_scalar * zeros[0]); + final_result[r][1] += input_scale * (vec4(int32_sums[r][1]) * scales[1] + input_sum_scalar * zeros[1]); + } + } + + // Store worker results in shared memory + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + partial_results[gid][wid][r][0] = final_result[r][0]; + partial_results[gid][wid][r][1] = final_result[r][1]; + } + + memoryBarrierShared(); + barrier(); + + // Only the first worker in each group accumulates and writes output + if (wid != 0) { + return; + } + + vec4 cooperative_result[TILE_ROWS][2]; + + for (int r = 0; r < TILE_ROWS; ++r) { + cooperative_result[r][0] = vec4(0.0); + cooperative_result[r][1] = vec4(0.0); + [[unroll]] for (int worker = 0; worker < NWORKERS; ++worker) { + cooperative_result[r][0] += partial_results[gid][worker][r][0]; + cooperative_result[r][1] += partial_results[gid][worker][r][1]; + } + } + + // Apply final output quantization + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + int token_idx = int(out_row) + r; + + float output_scale = t_output_scale[token_idx]; + int output_zero_point = t_output_zero_point[token_idx]; + + VEC4_T quantized_out_0 = VEC4_T(clamp( + ivec4(round(cooperative_result[r][0] / output_scale)) + float(output_zero_point), + -128, 127)); + VEC4_T quantized_out_1 = VEC4_T(clamp( + ivec4(round(cooperative_result[r][1] / output_scale)) + float(output_zero_point), + -128, 127)); + + $if OUT_STORAGE == "buffer": + t_out[((out_row + r) * out_sizes.x + out_col) >> 2] = quantized_out_0; + t_out[((out_row + r) * out_sizes.x + out_col + 4) >> 2] = quantized_out_1; + $else: + imageStore(t_out, ivec3(out_col_texel_idx, out_row + r, 0), quantized_out_0); + imageStore(t_out, ivec3(out_col_texel_idx + 1, out_row + r, 0), quantized_out_1); + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_qta8o_coop.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_qta8o_coop.yaml new file mode 100644 index 00000000000..2d8a979494c --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_qta8o_coop.yaml @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +linear_qta8a_qga4w_qta8o_coop: + parameter_names_with_default_values: + DTYPE: int8 + OUT_STORAGE: texture3d + IN_STORAGE: texture3d + WEIGHT_STORAGE: texture2d + PARAMS_STORAGE: buffer + TILE_ROWS: 1 + shader_variants: + - NAME: linear_qta8a_qga4w_qta8o_coop_texture3d_texture3d_texture2d_int8 + - NAME: linear_qta8a_qga4w_qta8o_coop_buffer_buffer_texture2d_int8 + OUT_STORAGE: buffer + IN_STORAGE: buffer + - NAME: linear_qta8a_qga4w_qta8o_coop_buffer_buffer_buffer_int8 + OUT_STORAGE: buffer + IN_STORAGE: buffer + WEIGHT_STORAGE: buffer + - NAME: linear_qta8a_qga4w_qta8o_coop_buffer_texture2d_buffer_int8 + OUT_STORAGE: buffer + WEIGHT_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_qta8o_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_qta8o_tiled.glsl new file mode 100644 index 00000000000..7b4f2733066 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_qta8o_tiled.glsl @@ -0,0 +1,208 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define T ${buffer_scalar_type(DTYPE)} +#define VEC4_T ${buffer_gvec_type(DTYPE, 4)} + +#define TILE_ROWS ${TILE_ROWS} + +${define_required_extensions(DTYPE)} +$if WEIGHT_STORAGE == "buffer": + ${define_required_extensions("uint8")} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_mat1", DTYPE, IN_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_qmat2", "uint8", WEIGHT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_qparams", "float", PARAMS_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_input_scale", "float", "buffer", is_scalar_array=True)} +${layout_declare_tensor(B, "r", "t_input_zero_point", "int", "buffer", is_scalar_array=True)} +${layout_declare_tensor(B, "r", "t_output_scale", "float", "buffer", is_scalar_array=True)} +${layout_declare_tensor(B, "r", "t_output_zero_point", "int", "buffer", is_scalar_array=True)} + +layout(push_constant) uniform restrict Block { + ivec4 out_sizes; + ivec4 mat1_sizes; + ivec4 qmat2_sizes; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +layout(constant_id = 3) const int group_size = 64; + +/* + * This shader computes a linear operator between a quantized int8 input matrix + * x and a weights matrix that is quantized to 4 bits, producing a quantized int8 output. + * + * The (W, H, C) shape of each tensor is: + * - x: (K, M) - quantized int8 input + * - weights: (N / 2, K) + * - The weights tensor has a data type of `uint8`. Each element in the tensor + * contains 2 4-bit values packed into a uint8. + * - See the pack_int4_linear_weight_transposed_interleave shader to see more + * details on how the weight tensor is stored. + * - qparams: (2, N, number_of_groups) + * - This tensor contains the scales and zeros quantization parameters for the + * weights tensor. The weight tensor is quantized group-wise, which means + * that every `group_size` elements along the K dimension of the weights + * tensor has independent quantization parameters. Along the width dim, the + * first value contains the scale for the group and the second value + * contains the zero point for the group. + * - output: (N, M) - quantized int8 output + * + * Each thread computes a tile of TILE_ROWS * 2 texels of the output tensor. + * + * Note that this shader assumes that all tensors are width packed. + */ + +bool is_main_thread() { + return gl_GlobalInvocationID.x == 0 && gl_GlobalInvocationID.y == 0 && gl_GlobalInvocationID.z == 0; +} + +void main() { + const uint out_row = gl_GlobalInvocationID.y * TILE_ROWS; + const uint out_col = gl_GlobalInvocationID.x << 3; + const int out_col_texel_idx = int(gl_GlobalInvocationID.x) << 1; + + if (out_col >= out_sizes.x || out_row >= out_sizes.y) { + return; + } + + const int num_blocks = mat1_sizes.x / group_size; + + VEC4_T mat1_quantized[TILE_ROWS]; + ivec4 qmat2_quantized[4][2]; + vec4 final_result[TILE_ROWS][2]; + + // Initialize accumulators + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + final_result[r][0] = vec4(0.0); + final_result[r][1] = vec4(0.0); + } + + vec4 scales[2]; + vec4 zeros[2]; + + $if WEIGHT_STORAGE == "buffer": + const int qmat2_stride = qmat2_sizes.x >> 2; + $if PARAMS_STORAGE == "buffer": + const int qparams_y_stride = out_sizes.x >> 2; + const int qparams_z_stride = qparams_y_stride * 2; + + for (int block_idx = 0; block_idx < num_blocks; ++block_idx) { + $if PARAMS_STORAGE == "buffer": + scales[0] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx]; + zeros[0] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + qparams_y_stride]; + + scales[1] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + 1]; + zeros[1] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + 1 + qparams_y_stride]; + $else: + scales[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 0, block_idx), 0); + zeros[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 1, block_idx), 0); + + scales[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 0, block_idx), 0); + zeros[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 1, block_idx), 0); + + ivec4 int32_sums[TILE_ROWS][2]; + int input_sums[TILE_ROWS]; + + // Initialize accumulators + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + int32_sums[r][0] = ivec4(0); + int32_sums[r][1] = ivec4(0); + input_sums[r] = 0; + } + + for (int g_idx = 0; g_idx < group_size; g_idx += 4) { + const int k = block_idx * group_size + g_idx; + + // Preload B (weights) - keep as quantized integers + [[unroll]] for (int r = 0; r < 4; ++r) { + $if WEIGHT_STORAGE == "buffer": + const uvec4 packed_weight_tex = t_qmat2[(k + r) * qmat2_stride + gl_GlobalInvocationID.x]; + $else: + const uvec4 packed_weight_tex = texelFetch( + t_qmat2, + ivec2(gl_GlobalInvocationID.x, k + r), + 0); + + // Unpack 4-bit weights to integers (subtract 8 as the 4-bit zero point) + qmat2_quantized[r][0] = ivec4((packed_weight_tex & 0xF0) >> 4) - 8; + qmat2_quantized[r][1] = ivec4(packed_weight_tex & 0x0F) - 8; + } + + // Preload A (quantized input) - keep as quantized integers + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + $if IN_STORAGE == "buffer": + mat1_quantized[r] = VEC4_T(t_mat1[((out_row + r) * mat1_sizes.x + k) >> 2] - t_input_zero_point[int(out_row) + r]); + $else: + mat1_quantized[r] = VEC4_T(texelFetch(t_mat1, ivec3(k >> 2, out_row + r, 0), 0) - t_input_zero_point[int(out_row) + r]); + + input_sums[r] += mat1_quantized[r].x + mat1_quantized[r].y + mat1_quantized[r].z + mat1_quantized[r].w; + } + + // Accumulate in integer arithmetic: (input_quantized - input_zero_point) * (weight_quantized - weight_zero_point) + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + int32_sums[r][0] += mat1_quantized[r].x * qmat2_quantized[0][0] + + mat1_quantized[r].y * qmat2_quantized[1][0] + + mat1_quantized[r].z * qmat2_quantized[2][0] + + mat1_quantized[r].w * qmat2_quantized[3][0]; + + int32_sums[r][1] += mat1_quantized[r].x * qmat2_quantized[0][1] + + mat1_quantized[r].y * qmat2_quantized[1][1] + + mat1_quantized[r].z * qmat2_quantized[2][1] + + mat1_quantized[r].w * qmat2_quantized[3][1]; + } + } + + // Incorporates this block's results into the final accumulation + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + if (out_row + r >= out_sizes.y) { + continue; + } + + float input_scale = t_input_scale[int(out_row) + r]; + float input_sum_scalar = float(input_sums[r]); + + final_result[r][0] += input_scale * (vec4(int32_sums[r][0]) * scales[0] + input_sum_scalar * zeros[0]); + final_result[r][1] += input_scale * (vec4(int32_sums[r][1]) * scales[1] + input_sum_scalar * zeros[1]); + } + } + + // Apply ALL scaling at the very end + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + if (out_row + r >= out_sizes.y) { + continue; + } + + int token_idx = int(out_row) + r; + float output_scale = t_output_scale[token_idx]; + int output_zero_point = t_output_zero_point[token_idx]; + + VEC4_T quantized_out_0 = VEC4_T(clamp( + ivec4(round(final_result[r][0] / output_scale)) + float(output_zero_point), + -128, 127)); + VEC4_T quantized_out_1 = VEC4_T(clamp( + ivec4(round(final_result[r][1] / output_scale)) + float(output_zero_point), + -128, 127)); + + $if OUT_STORAGE == "buffer": + t_out[((out_row + r) * out_sizes.x + out_col) >> 2] = quantized_out_0; + t_out[((out_row + r) * out_sizes.x + out_col + 4) >> 2] = quantized_out_1; + $else: + imageStore(t_out, ivec3(out_col_texel_idx, out_row + r, 0), quantized_out_0); + imageStore(t_out, ivec3(out_col_texel_idx + 1, out_row + r, 0), quantized_out_1); + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_qta8o_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_qta8o_tiled.yaml new file mode 100644 index 00000000000..9de102cf5f0 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_qta8o_tiled.yaml @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +linear_qta8a_qga4w_qta8o_tiled: + parameter_names_with_default_values: + DTYPE: int8 + OUT_STORAGE: texture3d + IN_STORAGE: texture3d + WEIGHT_STORAGE: texture2d + PARAMS_STORAGE: buffer + TILE_ROWS: 3 + shader_variants: + - NAME: linear_qta8a_qga4w_qta8o_tiled_texture3d_texture3d_texture2d_int8 + - NAME: linear_qta8a_qga4w_qta8o_tiled_buffer_buffer_texture2d_int8 + OUT_STORAGE: buffer + IN_STORAGE: buffer + - NAME: linear_qta8a_qga4w_qta8o_tiled_buffer_buffer_buffer_int8 + OUT_STORAGE: buffer + IN_STORAGE: buffer + WEIGHT_STORAGE: buffer + - NAME: linear_qta8a_qga4w_qta8o_tiled_buffer_texture2d_buffer_int8 + OUT_STORAGE: buffer + WEIGHT_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQTA8AQGA4WQTA8O.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQTA8AQGA4WQTA8O.cpp new file mode 100644 index 00000000000..bd9c1742881 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQTA8AQGA4WQTA8O.cpp @@ -0,0 +1,288 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +#include +#include + +namespace vkcompute { + +void check_linear_qta8a_qga4w_qta8o_args( + ComputeGraph& graph, + const ValueRef mat1, + const ValueRef mat1_scale, + const ValueRef mat1_zero_point, + const ValueRef mat2_data, + const ValueRef group_size, + const ValueRef scales_and_zeros, + const ValueRef output_scale, + const ValueRef output_zero_point, + const ValueRef out) { + VK_CHECK_COND(graph.val_is_tensor(mat1)); + VK_CHECK_COND(graph.val_is_tensor(mat1_scale)); + VK_CHECK_COND(graph.val_is_tensor(mat1_zero_point)); + VK_CHECK_COND(graph.val_is_tref(mat2_data)); + VK_CHECK_COND(graph.val_is_tref(scales_and_zeros)); + VK_CHECK_COND(graph.val_is_tensor(output_scale)); + VK_CHECK_COND(graph.val_is_tensor(output_zero_point)); + + VK_CHECK_COND(graph.dim_of(mat1) <= 3); + VK_CHECK_COND(graph.dim_of(mat2_data) == 2); + VK_CHECK_COND(graph.dim_of(scales_and_zeros) == 3); + + VK_CHECK_COND(graph.size_at(-3, mat1) == 1); + const int K = graph.size_at(-1, mat1); + VK_CHECK_COND(graph.size_at(-1, mat2_data) * 2 == K); + + const int group_size_val = graph.extract_scalar(group_size); + VK_CHECK_COND(K % group_size_val == 0); + // Due to the way weight packing works, group size needs to be a multiple of 8 + VK_CHECK_COND(group_size_val % 8 == 0); + + VK_CHECK_COND(graph.has_standard_axis_map(mat1)); + VK_CHECK_COND(graph.has_standard_axis_map(out)); + + // Check that scale and zero_point tensors are buffer storage with width + // packing + VK_CHECK_COND(graph.is_buffer_storage(mat1_scale)); + VK_CHECK_COND(graph.packed_dim_of(mat1_scale) == WHCN::kWidthDim); + VK_CHECK_COND(graph.is_buffer_storage(mat1_zero_point)); + VK_CHECK_COND(graph.packed_dim_of(mat1_zero_point) == WHCN::kWidthDim); + VK_CHECK_COND(graph.is_buffer_storage(output_scale)); + VK_CHECK_COND(graph.packed_dim_of(output_scale) == WHCN::kWidthDim); + VK_CHECK_COND(graph.is_buffer_storage(output_zero_point)); + VK_CHECK_COND(graph.packed_dim_of(output_zero_point) == WHCN::kWidthDim); + + // Calculate number of tokens for input and output + int64_t input_num_tokens = 1; + const auto mat1_sizes = graph.sizes_of(mat1); + for (size_t i = 0; i < mat1_sizes.size() - 1; i++) { + input_num_tokens *= mat1_sizes[i]; + } + + int64_t output_num_tokens = 1; + const auto out_sizes = graph.sizes_of(out); + for (size_t i = 0; i < out_sizes.size() - 1; i++) { + output_num_tokens *= out_sizes[i]; + } + + // Verify scale and zero_point tensor sizes match number of tokens + const auto mat1_scale_sizes = graph.sizes_of(mat1_scale); + const auto mat1_zero_point_sizes = graph.sizes_of(mat1_zero_point); + const auto output_scale_sizes = graph.sizes_of(output_scale); + const auto output_zero_point_sizes = graph.sizes_of(output_zero_point); + + VK_CHECK_COND(mat1_scale_sizes.size() == 1); + VK_CHECK_COND(mat1_zero_point_sizes.size() == 1); + VK_CHECK_COND(output_scale_sizes.size() == 1); + VK_CHECK_COND(output_zero_point_sizes.size() == 1); + + VK_CHECK_COND(mat1_scale_sizes[0] == input_num_tokens); + VK_CHECK_COND(mat1_zero_point_sizes[0] == input_num_tokens); + VK_CHECK_COND(output_scale_sizes[0] == output_num_tokens); + VK_CHECK_COND(output_zero_point_sizes[0] == output_num_tokens); +} + +void resize_linear_qta8a_qga4w_qta8o_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + (void)extra_args; + + vTensorPtr out = graph->get_tensor(args[0].refs[0]); + vTensorPtr mat1 = graph->get_tensor(args[1].refs[0]); + vTensorPtr mat2 = graph->get_tensor(args[1].refs[1]); + + const int out_cols = utils::val_at(-2, mat1->sizes()); + const int out_rows = utils::val_at(-1, mat2->sizes()) * 2; + + std::vector new_out_sizes(3); + if (mat1->sizes().size() == 2) { + new_out_sizes.resize(2); + new_out_sizes.at(0) = out_cols; + new_out_sizes.at(1) = out_rows; + } else { + new_out_sizes.at(0) = mat1->sizes().at(0); + new_out_sizes.at(1) = out_cols; + new_out_sizes.at(2) = out_rows; + } + + out->virtual_resize(new_out_sizes); +} + +/** + * Determines if the cooperative algorithm should be used based on input tensor + * dimensions. Apply the coop algorithm for vectors (GEMV cases), tiled for + * matrices (GEMM cases). + */ +bool should_use_coop_algorithm_qta8a_qga4w_qta8o( + ComputeGraph* graph, + const ValueRef& mat1) { + const uint32_t M = graph->size_at(-2, mat1); + // Use coop algorithm for vectors (GEMV), tiled for larger matrices (GEMM) + return M == 1; +} + +vkapi::ShaderInfo pick_linear_qta8a_qga4w_qta8o_shader( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + (void)resize_args; + + const ValueRef out = args.at(0).refs.at(0); + const ValueRef mat1 = args.at(1).refs.at(0); + const ValueRef mat2 = args.at(1).refs.at(1); + + const bool use_coop_algorithm = + should_use_coop_algorithm_qta8a_qga4w_qta8o(graph, mat1); + + std::string kernel_name = "linear_qta8a_qga4w_qta8o"; + if (use_coop_algorithm) { + kernel_name += "_coop"; + } else { + kernel_name += "_tiled"; + } + add_storage_type_suffix(kernel_name, graph->storage_type_of(out)); + add_storage_type_suffix(kernel_name, graph->storage_type_of(mat1)); + add_storage_type_suffix(kernel_name, graph->storage_type_of(mat2)); + add_dtype_suffix(kernel_name, graph->dtype_of(out)); + + return VK_KERNEL_FROM_STR(kernel_name); +} + +utils::uvec3 linear_qta8a_qga4w_qta8o_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)resize_args; + const ValueRef out = args.at(0).refs.at(0); + + const bool use_coop_algorithm = + shader.kernel_name.find("_coop") != std::string::npos; + + // C = 1, H = 2, W = 3 + // global_wg_size = {round_up(C / 2f), round_up(H / 3f), W} --> (2W, 1H, 0C) + // --> {1, 1, 3} global + + utils::uvec3 global_wg_size = graph->logical_limits_of(out); + global_wg_size[0] = utils::div_up(global_wg_size[0], uint32_t(2)); + if (!use_coop_algorithm) { // GEMM - TILED + global_wg_size[1] = utils::div_up(global_wg_size[1], uint32_t(3)); + } + + return global_wg_size; +} + +utils::uvec3 linear_qta8a_qga4w_qta8o_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)args; + (void)resize_args; + + const bool use_coop_algorithm = + shader.kernel_name.find("_coop") != std::string::npos; + + utils::uvec3 local_wg_size; + if (use_coop_algorithm) { // GEMV - COOP + local_wg_size = {8, 1, 8}; + } else { // GEMM - TILED + local_wg_size = graph->create_local_wg_size(global_workgroup_size); + } + + return local_wg_size; +} + +void add_linear_qta8a_qga4w_qta8o_node( + ComputeGraph& graph, + const ValueRef mat1, + const ValueRef mat1_scale, + const ValueRef mat1_zero_point, + const ValueRef mat2_data, + const ValueRef group_size, + const ValueRef scales_and_zeros_data, + const ValueRef output_scale, + const ValueRef output_zero_point, + const ValueRef out) { + check_linear_qta8a_qga4w_qta8o_args( + graph, + mat1, + mat1_scale, + mat1_zero_point, + mat2_data, + group_size, + scales_and_zeros_data, + output_scale, + output_zero_point, + out); + const uint32_t group_size_val = graph.extract_scalar(group_size); + + ValueRef mat2 = + prepack_int4_linear_weight_transposed_interleaved(graph, mat2_data); + ValueRef scales_and_zeros = prepack_standard_hw_transposed( + graph, scales_and_zeros_data, utils::kBuffer, utils::kWidthPacked); + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + pick_linear_qta8a_qga4w_qta8o_shader, + linear_qta8a_qga4w_qta8o_global_wg_size, + linear_qta8a_qga4w_qta8o_local_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, + {{mat1, + mat2, + scales_and_zeros, + mat1_scale, + mat1_zero_point, + output_scale, + output_zero_point}, + vkapi::kRead}}, + // Shader params buffers + {}, + // Push Constants + {graph.sizes_pc_of(out), + graph.sizes_pc_of(mat1), + graph.sizes_pc_of(mat2)}, + // Specialization Constants + {SV(group_size_val)}, + // Resize Args + {}, + // Resizing Logic + resize_linear_qta8a_qga4w_qta8o_node)); +} + +void linear_qta8a_qga4w_qta8o( + ComputeGraph& graph, + const std::vector& args) { + return add_linear_qta8a_qga4w_qta8o_node( + graph, + args[0], // quantized input (char tensor) + args[1], // input_scale (float buffer tensor) + args[2], // input_zero_point (int buffer tensor) + args[3], // quantized weights (4-bit packed, byte) + args[4], // group_size (int) + args[5], // weight_scales_and_zeros (float tensor) + args[6], // output_scale (float buffer tensor) + args[7], // output_zero_point (int buffer tensor) + args[8] // quantized output (char tensor) + ); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP( + et_vk.linear_qta8a_qga4w_qta8o.default, linear_qta8a_qga4w_qta8o); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/linear_qta8a_qga4w_qta8o_test.cpp b/backends/vulkan/test/op_tests/linear_qta8a_qga4w_qta8o_test.cpp index 2496c256fb1..518cea2ea11 100644 --- a/backends/vulkan/test/op_tests/linear_qta8a_qga4w_qta8o_test.cpp +++ b/backends/vulkan/test/op_tests/linear_qta8a_qga4w_qta8o_test.cpp @@ -140,3 +140,289 @@ at::Tensor linear_qta8a_qga4w_qta8o_4bit_dequant_impl( return quantized_result; } + +void test_vulkan_linear_qta8a_qga4w_qta8o_impl( + const int B, + const int M, + const int K, + const int N, + const int group_size = 8, + const vkcompute::utils::StorageType in_storage = + vkcompute::utils::kTexture3D, + const vkcompute::utils::StorageType out_storage = + vkcompute::utils::kTexture3D) { + assert(K % group_size == 0); + + // Create per-token quantization parameters for input + const int64_t input_num_tokens = B * M; + at::Tensor input_scale = + at::rand({input_num_tokens}, at::device(at::kCPU).dtype(at::kFloat)) * + 0.1f + + 0.05f; // Range [0.05, 0.15] + at::Tensor input_zero_point = at::randint( + -10, 10, {input_num_tokens}, at::device(at::kCPU).dtype(at::kInt)); + + at::Tensor float_x = + at::rand({B, M, K}, at::device(at::kCPU).dtype(at::kFloat)); + + // Create a reference quantized tensor using per-token quantization + // Mimic per-token quantization using at::quantize_per_channel by reshaping to + // [num_tokens, features] + at::Tensor float_x_reshaped = float_x.view({input_num_tokens, K}); + at::Tensor qx_ref_reshaped = at::quantize_per_channel( + float_x_reshaped, + input_scale.to(at::kDouble), + input_zero_point.to(at::kLong), + 0, // axis 0 for per-token (first dimension after reshape) + c10::ScalarType::QInt8); + + // Convert back to int8 tensor and reshape to original shape + at::Tensor x = + at::int_repr(qx_ref_reshaped).view(float_x.sizes()).to(at::kChar); + + at::Tensor weights_4x2 = + at::randint(0, 256, {N, K / 2}, at::device(at::kCPU).dtype(at::kByte)); + + const int k_groups = K / group_size; + at::Tensor scales_and_zeros = + at::rand({k_groups, N, 2}, at::device(at::kCPU).dtype(at::kFloat)); + + // Create per-token quantization parameters for output + const int64_t output_num_tokens = B * M; + at::Tensor output_scale = + at::rand({output_num_tokens}, at::device(at::kCPU).dtype(at::kFloat)) * + 0.1f + + 0.1f; // Range [0.1, 0.2] + at::Tensor output_zero_point = at::randint( + -10, 10, {output_num_tokens}, at::device(at::kCPU).dtype(at::kInt)); + + at::Tensor out_ref = linear_qta8a_qga4w_qta8o_4bit_dequant_impl( + x, + input_scale, + input_zero_point, + weights_4x2, + group_size, + scales_and_zeros, + output_scale, + output_zero_point); + + // Build Vulkan graph + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(utils::kTexture3D); + ComputeGraph graph(config); + +#define MAKE_TENSORREF_FOR(x) \ + ValueRef r_##x = graph.add_tensorref( \ + x.sizes().vec(), \ + from_at_scalartype(x.scalar_type()), \ + x.const_data_ptr()); + + MAKE_TENSORREF_FOR(weights_4x2); + MAKE_TENSORREF_FOR(scales_and_zeros); + + IOValueRef r_x = graph.add_input_tensor( + x.sizes().vec(), from_at_scalartype(x.scalar_type()), in_storage); + + IOValueRef r_input_scale = graph.add_input_tensor( + input_scale.sizes().vec(), + from_at_scalartype(input_scale.scalar_type()), + utils::kBuffer); + + IOValueRef r_input_zero_point = graph.add_input_tensor( + input_zero_point.sizes().vec(), + from_at_scalartype(input_zero_point.scalar_type()), + utils::kBuffer); + + IOValueRef r_output_scale = graph.add_input_tensor( + output_scale.sizes().vec(), + from_at_scalartype(output_scale.scalar_type()), + utils::kBuffer); + + IOValueRef r_output_zero_point = graph.add_input_tensor( + output_zero_point.sizes().vec(), + from_at_scalartype(output_zero_point.scalar_type()), + utils::kBuffer); + + const ValueRef r_out = graph.add_tensor( + out_ref.sizes().vec(), + from_at_scalartype(out_ref.scalar_type()), + out_storage); + + VK_GET_OP_FN("et_vk.linear_qta8a_qga4w_qta8o.default") + (graph, + {r_x.value, + r_input_scale.value, + r_input_zero_point.value, + r_weights_4x2, + graph.add_scalar(group_size), + r_scales_and_zeros, + r_output_scale.value, + r_output_zero_point.value, + r_out}); + + ValueRef staging_out = graph.set_output_tensor(r_out); + + graph.prepare(); + graph.encode_prepack(); + graph.prepack(); + graph.encode_execute(); + + // + // Run model + // + + graph.propagate_resize(); + graph.copy_into_staging(r_x.staging, x.const_data_ptr(), x.numel()); + graph.copy_into_staging( + r_input_scale.staging, input_scale.const_data_ptr(), input_scale.numel()); + graph.copy_into_staging( + r_input_zero_point.staging, + input_zero_point.const_data_ptr(), + input_zero_point.numel()); + graph.copy_into_staging( + r_output_scale.staging, + output_scale.const_data_ptr(), + output_scale.numel()); + graph.copy_into_staging( + r_output_zero_point.staging, + output_zero_point.const_data_ptr(), + output_zero_point.numel()); + + graph.execute(); + + at::Tensor vk_out = at::empty_like(out_ref); + graph.copy_from_staging( + staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); + + // For quantized int8 operations, allow for 1-unit differences due to rounding + bool is_close = at::allclose(vk_out, out_ref, 1.0, 1.0); + + at::Tensor weights_int = unpack_weights_4x2(weights_4x2); + + if (!is_close) { + std::cout << "out_ref: \n" << out_ref << std::endl; + std::cout << "vk_out: \n" << vk_out << std::endl; + } + + ASSERT_TRUE(is_close); +} + +void test_vulkan_linear_qta8a_qga4w_qta8o( + const int B, + const int M, + const int K, + const int N, + const int group_size = 32) { + test_vulkan_linear_qta8a_qga4w_qta8o_impl( + B, + M, + K, + N, + group_size, + vkcompute::utils::kBuffer, + vkcompute::utils::kBuffer); + + test_vulkan_linear_qta8a_qga4w_qta8o_impl( + B, + M, + K, + N, + group_size, + vkcompute::utils::kTexture3D, + vkcompute::utils::kTexture3D); +} + +TEST( + VulkanLinearQTA8AQGA4WQTA8OTest, + test_vulkan_linear_quant_gemm_custom_groupsize_one) { + test_vulkan_linear_qta8a_qga4w_qta8o( + /*B = */ 1, + /*M = */ 2, + /*K = */ 8, + /*N = */ 8, + /*group_size = */ 8); +} + +TEST( + VulkanLinearQTA8AQGA4WQTA8OTest, + test_vulkan_linear_quant_gemm_custom_groupsize_two) { + test_vulkan_linear_qta8a_qga4w_qta8o( + /*B = */ 1, + /*M = */ 2, + /*K = */ 16, + /*N = */ 8, + /*group_size = */ 8); +} + +TEST(VulkanLinearQTA8AQGA4WQTA8OTest, test_vulkan_linear_quant_gemm_case_one) { + test_vulkan_linear_qta8a_qga4w_qta8o( + /*B = */ 1, + /*M = */ 4, + /*K = */ 64, + /*N = */ 32); +} + +TEST(VulkanLinearQTA8AQGA4WQTA8OTest, test_vulkan_linear_quant_gemm_case_two) { + test_vulkan_linear_qta8a_qga4w_qta8o( + /*B = */ 1, + /*M = */ 4, + /*K = */ 128, + /*N = */ 32); +} + +TEST( + VulkanLinearQTA8AQGA4WQTA8OTest, + test_vulkan_linear_quant_gemm_case_three) { + test_vulkan_linear_qta8a_qga4w_qta8o( + /*B = */ 1, + /*M = */ 8, + /*K = */ 64, + /*N = */ 16); +} + +TEST(VulkanLinearQTA8AQGA4WQTA8OTest, test_vulkan_linear_quant_gemm_case_four) { + test_vulkan_linear_qta8a_qga4w_qta8o( + /*B = */ 1, + /*M = */ 256, + /*K = */ 256, + /*N = */ 256); +} + +TEST( + VulkanLinearQTA8AQGA4WQTA8OTest, + test_vulkan_linear_quant_gemv_custom_groupsize_one) { + test_vulkan_linear_qta8a_qga4w_qta8o( + /*B = */ 1, + /*M = */ 2, + /*K = */ 16, + /*N = */ 8, + /*group_size = */ 8); +} + +TEST(VulkanLinearQTA8AQGA4WQTA8OTest, test_vulkan_linear_quant_gemv_case_one) { + test_vulkan_linear_qta8a_qga4w_qta8o( + /*B = */ 1, + /*M = */ 1, + /*K = */ 256, + /*N = */ 256); +} + +TEST(VulkanLinearQTA8AQGA4WQTA8OTest, test_vulkan_linear_quant_gemv_case_two) { + test_vulkan_linear_qta8a_qga4w_qta8o( + /*B = */ 1, + /*M = */ 1, + /*K = */ 32, + /*N = */ 32); +} + +TEST( + VulkanLinearQTA8AQGA4WQTA8OTest, + test_vulkan_linear_quant_gemv_case_three) { + test_vulkan_linear_qta8a_qga4w_qta8o( + /*B = */ 1, + /*M = */ 1, + /*K = */ 64, + /*N = */ 16); +} From 05c0fcb52ebbcb9b01f8867521b8f34a71cf6452 Mon Sep 17 00:00:00 2001 From: morelos Date: Thu, 26 Jun 2025 10:26:05 -0700 Subject: [PATCH 2/2] Update on "[ET-VK][Ops] linear_qta8a_qga4w_qta8o impl and shaders" # Operator Description The linear_qta8a_qga4w_qta8o operator implements a quantized linear transformation that enables efficient neural network inference through dynamic quantization. This operator performs matrix multiplication between quantized 8-bit activations and 4-bit grouped quantized weights, producing quantized 8-bit outputs. The quantization scheme follows the standard affine mapping where `real_value = scale * (quantized_value - zero_point)`. Input activations use 8-bit signed integers with per-token scale and zero-point parameters, while weights employ 4-bit quantization with group-wise parameters. # Implementation Architecture The operator provides two distinct computational approaches optimized for different matrix multiplication scenarios: the TILED algorithm for general matrix-matrix multiplication (GEMM) and the COOPERATIVE algorithm for matrix-vector multiplication (GEMV). ## TILED Algorithm (GEMM Cases) The tiled implementation processes the output matrix in rectangular blocks. Each thread is responsible for calculating a tile of output values, typically processing 3 rows and 2 columns worth of results in each iteration. The algorithm operates by having each thread load blocks of quantized weights and activations, perform integer arithmetic accumulation, and then apply the necessary scaling operations. Weight data is pre-packed in a specialized format where two 4-bit values are stored in each byte. Each thread loads multiple weight elements simultaneously and unpacks them during computation. The quantization parameters for weights are organized by groups, where each group of consecutive weight elements shares the same scale and zero-point values. ## COOPERATIVE Algorithm (GEMV Cases) The cooperative implementation uses shared memory and thread cooperation where this approach uses workgroups of 64 threads arranged as 8 groups of 8 workers each. The key insight is that GEMV operations have limited parallelism in the output dimension but substantial parallelism in the reduction dimension, making cooperative reduction strategies more effective than independent thread computation. Each group of 8 worker threads collaboratively computes a portion of the output vector. The workers divide the reduction work along the input feature dimension, with each worker processing every 8th element in a strided pattern. # Future Performance Improvements - Making use of dotPacked4x8EXT (this requires upgrading glslc and vulkan) - Fixed point math for pure integer operations - Might be more performant to avoid preloading tensors - Might also be more performant to avoid excessive register overhead by defining the ivec4 within each block operation (allowing more threads to be more register intensive) Differential Revision: [D77173441](https://our.internmc.facebook.com/intern/diff/D77173441/) [ghstack-poisoned] --- .../linear_qta8a_qga4w_qta8o_test.cpp | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/backends/vulkan/test/op_tests/linear_qta8a_qga4w_qta8o_test.cpp b/backends/vulkan/test/op_tests/linear_qta8a_qga4w_qta8o_test.cpp index 518cea2ea11..a9c5ff7c9f2 100644 --- a/backends/vulkan/test/op_tests/linear_qta8a_qga4w_qta8o_test.cpp +++ b/backends/vulkan/test/op_tests/linear_qta8a_qga4w_qta8o_test.cpp @@ -337,6 +337,11 @@ void test_vulkan_linear_qta8a_qga4w_qta8o( TEST( VulkanLinearQTA8AQGA4WQTA8OTest, test_vulkan_linear_quant_gemm_custom_groupsize_one) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } test_vulkan_linear_qta8a_qga4w_qta8o( /*B = */ 1, /*M = */ 2, @@ -348,6 +353,11 @@ TEST( TEST( VulkanLinearQTA8AQGA4WQTA8OTest, test_vulkan_linear_quant_gemm_custom_groupsize_two) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } test_vulkan_linear_qta8a_qga4w_qta8o( /*B = */ 1, /*M = */ 2, @@ -357,6 +367,11 @@ TEST( } TEST(VulkanLinearQTA8AQGA4WQTA8OTest, test_vulkan_linear_quant_gemm_case_one) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } test_vulkan_linear_qta8a_qga4w_qta8o( /*B = */ 1, /*M = */ 4, @@ -365,6 +380,11 @@ TEST(VulkanLinearQTA8AQGA4WQTA8OTest, test_vulkan_linear_quant_gemm_case_one) { } TEST(VulkanLinearQTA8AQGA4WQTA8OTest, test_vulkan_linear_quant_gemm_case_two) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } test_vulkan_linear_qta8a_qga4w_qta8o( /*B = */ 1, /*M = */ 4, @@ -375,6 +395,11 @@ TEST(VulkanLinearQTA8AQGA4WQTA8OTest, test_vulkan_linear_quant_gemm_case_two) { TEST( VulkanLinearQTA8AQGA4WQTA8OTest, test_vulkan_linear_quant_gemm_case_three) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } test_vulkan_linear_qta8a_qga4w_qta8o( /*B = */ 1, /*M = */ 8, @@ -383,6 +408,11 @@ TEST( } TEST(VulkanLinearQTA8AQGA4WQTA8OTest, test_vulkan_linear_quant_gemm_case_four) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } test_vulkan_linear_qta8a_qga4w_qta8o( /*B = */ 1, /*M = */ 256, @@ -393,6 +423,11 @@ TEST(VulkanLinearQTA8AQGA4WQTA8OTest, test_vulkan_linear_quant_gemm_case_four) { TEST( VulkanLinearQTA8AQGA4WQTA8OTest, test_vulkan_linear_quant_gemv_custom_groupsize_one) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } test_vulkan_linear_qta8a_qga4w_qta8o( /*B = */ 1, /*M = */ 2, @@ -402,6 +437,11 @@ TEST( } TEST(VulkanLinearQTA8AQGA4WQTA8OTest, test_vulkan_linear_quant_gemv_case_one) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } test_vulkan_linear_qta8a_qga4w_qta8o( /*B = */ 1, /*M = */ 1, @@ -410,6 +450,11 @@ TEST(VulkanLinearQTA8AQGA4WQTA8OTest, test_vulkan_linear_quant_gemv_case_one) { } TEST(VulkanLinearQTA8AQGA4WQTA8OTest, test_vulkan_linear_quant_gemv_case_two) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } test_vulkan_linear_qta8a_qga4w_qta8o( /*B = */ 1, /*M = */ 1, @@ -420,6 +465,11 @@ TEST(VulkanLinearQTA8AQGA4WQTA8OTest, test_vulkan_linear_quant_gemv_case_two) { TEST( VulkanLinearQTA8AQGA4WQTA8OTest, test_vulkan_linear_quant_gemv_case_three) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } test_vulkan_linear_qta8a_qga4w_qta8o( /*B = */ 1, /*M = */ 1,