From 0e34e301c4770f9d56496b7383740261f7252a4d Mon Sep 17 00:00:00 2001 From: morelos Date: Tue, 17 Jun 2025 10:06:34 -0700 Subject: [PATCH 1/3] [ET-VK][Ops] quantization op shaders and impl Pull Request resolved: https://github.com/pytorch/executorch/pull/11369 # Operator Description The quantization operator converts floating-point tensors (fp16/fp32) to lower-precision integer formats (uint8/int8/int32) using affine quantization. This operator supports two quantization modes: - **Per-tensor quantization**: Uses a single scale and zero_point for the entire tensor - **Per-token quantization**: Uses different scale and zero_point values for each "token" (typically rows or channels) The quantization formula is: `quantized_value = clamp(round(input_value / scale) + zero_point, quant_min, quant_max)` **Example**: For a float value `2.5` with `scale=0.1`, `zero_point=128`, `quant_min=0`, `quant_max=255`: - `round(2.5 / 0.1) + 128 = round(25) + 128 = 153` - `clamp(153, 0, 255) = 153` (uint8 output) The quantization parameters serve these purposes: - **scale**: Controls the granularity of quantization (smaller scale = finer precision) - **zero_point**: Maps the floating-point zero to an integer value - **quant_min/quant_max**: Define the valid range for the quantized output type # Shader Algorithm Overview ## Texture Storage Implementation (`quantize_texture.glsl`) The texture-based implementation operates on 3D textures where data is stored in RGBA texel format (4 components per texel): **Per-tensor Mode**: Each compute thread processes one texel position. It loads a 4-component texel from the input texture, and applies quantization to each of the 4 components using shared scale/zero_point. It then writes the quantized 4-component result to the output texture. This method is fairly linear. **Per-token Mode**: We need to calculate the token index based on the spatial position, it'll differ between various cases like 3D and 2D. For instand we might define the token_idx as `z * dims.y + y` for 3D, or just `y` for 2D cases. We then retrieve the per-token scale/zero_point from the texture storage according to the token_idx. We need to do component indexing based on the texel_idx and token_idx: `texel_idx = token_idx / 4`, along with the component id `comp_idx = token_idx % 4` to get the necessary scale/zero_point. We then apply quantization with the corresponding token-specific parameters to the 4 components of the current texel. ## Buffer Storage Implementation (`quantize_buffer.glsl`) The buffer-based implementation operates on linear memory buffers with stride-based indexing: **Per-tensor Mode**: In this case, each compute thread will process one element at its global position. It converts the 3D position to linear buffer indices using stride calculations `tidx_to_bufi(pos, strides)`. It then loads single scalar values from the input buffer and applies quantization using shared scale/zero_point parameters. We then store the quantized result to the output buffer at the corresponding index. **Per-token Mode**: We first calculate the logical tensor position from the linear buffer index through dimension unwrapping. We then determine the token index based on the tensor dimensionality: - 4D: `token_idx = w * (z * y) + z * y + y` - 3D: `token_idx = z * y + y` - 2D: `token_idx = y` We then directly index into scale/zero_point buffers using token_idx and also apply quantization with the token-specific parameters. # Performance Considerations / Future Improvements Current implementation uses default workgroup sizing. Profiling different local workgroup sizes could improve occupancy and cache utilization. Buffer implementation processes one element per thread. Could be optimized to process multiple elements per thread. NOTE: Currently the only input types supported are **half** (fp16) and **float** (fp32). The only output types supported are **byte** (uint8), **char** (int8), **int** (int32). A future diff plans to implement **double** (fp64) input dtype support. ghstack-source-id: 291010148 @exported-using-ghexport Differential Revision: [D75959064](https://our.internmc.facebook.com/intern/diff/D75959064/) --- .../runtime/graph/ops/glsl/quantize.glslh | 25 ++ .../graph/ops/glsl/quantize_buffer.glsl | 179 ++++++++++++ .../graph/ops/glsl/quantize_buffer.yaml | 18 ++ .../graph/ops/glsl/quantize_texture.glsl | 184 +++++++++++++ .../graph/ops/glsl/quantize_texture.yaml | 18 ++ .../runtime/graph/ops/impl/Quantize.cpp | 258 ++++++++++++++++++ .../vulkan/test/op_tests/quantize_test.cpp | 250 ++++++++++++++++- backends/vulkan/test/op_tests/test_utils.cpp | 3 +- 8 files changed, 926 insertions(+), 9 deletions(-) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/quantize.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml create mode 100644 backends/vulkan/runtime/graph/ops/impl/Quantize.cpp diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize.glslh b/backends/vulkan/runtime/graph/ops/glsl/quantize.glslh new file mode 100644 index 00000000000..cde72e41ac7 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize.glslh @@ -0,0 +1,25 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef QUANTIZE_GLSLH +#define QUANTIZE_GLSLH + +OUT_T quantize_val(IN_T value, float scale_val, int zero_point_val) { + float inv_scale = 1.0 / scale_val; + + float rounded_float = round(inv_scale * float(value)); + + int qvalue = zero_point_val + int(rounded_float); + + qvalue = max(qvalue, quant_min); + qvalue = min(qvalue, quant_max); + + return OUT_T(qvalue); +} + +#endif // QUANTIZE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl new file mode 100644 index 00000000000..ea0c2f7dce7 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl @@ -0,0 +1,179 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define IN_T ${buffer_scalar_type(IN_DTYPE)} +#define OUT_T ${buffer_scalar_type(OUT_DTYPE)} + +#define ${MODE} + +${define_active_storage_type("buffer")} +${define_required_extensions(IN_DTYPE)} +${define_required_extensions(OUT_DTYPE)} + +layout(std430) buffer; + +#include "indexing_utils.h" + +${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "buffer")} +${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")} + +$if MODE == "per_tensor": + layout(push_constant) uniform restrict Block { + float scale; + int zero_point; + int quant_min; + int quant_max; + }; +$if MODE == "per_token": + ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + + layout(push_constant) uniform restrict Block { + int num_tokens; + int quant_min; + int quant_max; + }; + +${layout_declare_ubo(B, "int", "out_numel")} +${layout_declare_ubo(B, "ivec4", "t_in_sizes")} +${layout_declare_ubo(B, "ivec4", "t_in_strides")} +${layout_declare_ubo(B, "ivec4", "t_out_sizes")} +${layout_declare_ubo(B, "ivec4", "t_out_strides")} + +${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} +${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")} + +#include "quantize.glslh" + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +const lowp ivec4 out_dim_order = unhash_dim_order(out_layout); +const lowp ivec4 in_dim_order = unhash_dim_order(in_layout); + +/* + * QUANTIZATION SHADER (BUFFER STORAGE) + * + * This shader converts floating-point tensor values to n-bit integer representations + * using pre-computed quantization parameters (scale and zero_point). The quantization + * maps floating-point values to a discrete integer range while preserving the + * original data distribution as much as possible. + * + * ALGORITHM: + * 1. Load floating-point input value from buffer + * 2. Apply quantization formula: qvalue = round(value / scale) + zero_point + * 3. Clamp result to [quant_min, quant_max] range + * 4. Store quantized integer value to output buffer + * + * WORKGROUP CONFIGURATION: + * - Per-Tensor Mode: + * - Global WG Size: {num_elements, 1, 1} (one thread per tensor element) + * - Local WG Size: Default (typically {64, 1, 1} or based on global WG size) + * - Per-Token Mode: + * - Global WG Size: {num_elements, 1, 1} (one thread per tensor element) + * - Local WG Size: Default (typically {64, 1, 1} or based on global WG size) + * + * SUPPORTED CONFIGURATIONS: + * - Per-Tensor Config: Uses linear buffer indexing with stride-based tensor access + * - and supports any tensor layout through stride calculations and dimension ordering + * - Per-Token Config: Assumes width-packed layout (packed_dim = 0) + * - since that is how token index is calculated + * + * QUANTIZATION FORMULA VISUALIZATION: + * For input range [min_val, max_val] mapped to integer range [quant_min, quant_max]: + * + * Floating Point Domain: Integer Domain: + * min_val ────────────────► quant_min + * │ │ + * │ scale = (max_val - min_val) / (quant_max - quant_min) + * │ zero_point = quant_min - round(min_val / scale) + * │ │ + * max_val ────────────────► quant_max + * + * Quantization Process: + * Input: 2.5 (float) + * Step 1: value / scale = 2.5 / 0.1 = 25.0 + * Step 2: round(25.0) + zero_point = 25 + (-128) = -103 + * Step 3: clamp(-103, -128, 127) = -103 + * Output: -103 (int8) + * + * PER-TENSOR QUANTIZATION: + * - Single scale and zero_point values for entire tensor + * - All elements use same quantization parameters + * - Parameters passed as push constants for efficiency + * - Formula: qvalue = clamp(round(value / scale) + zero_point, quant_min, quant_max) + * + * PER-TOKEN QUANTIZATION: + * - Separate scale and zero_point for each token + * - Token = all elements except last dimension (e.g., for [B,S,H]: B*S tokens of H elements) + * - Parameters stored in buffer arrays indexed by token_id + * - Each thread calculates its token_id from tensor coordinates + * - Formula: qvalue = clamp(round(value / scale[token_id]) + zero_point[token_id], quant_min, quant_max) + */ + +#ifdef per_tensor + +void quantize_per_tensor() { + const int out_bufi = int(gl_GlobalInvocationID.x); + + if (out_bufi >= out_numel) { + return; + } + + const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); + const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); + + IN_T value = t_in[in_bufi]; + OUT_T qvalue = quantize_val(value, scale, zero_point); + + t_out[out_bufi] = qvalue; +} + +#else + +void quantize_per_token() { + const int out_bufi = int(gl_GlobalInvocationID.x); + + if (out_bufi >= out_numel) { + return; + } + + const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); + const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); + + IN_T value = t_in[in_bufi]; + + int token_idx = 0; + + if (t_out_sizes.w > 1) { + // 4D tensor + token_idx = out_tidx.w * (t_out_sizes.z * t_out_sizes.y) + out_tidx.z * t_out_sizes.y + out_tidx.y; + } else if (t_out_sizes.z > 1) { + // 3D tensor + token_idx = out_tidx.z * t_out_sizes.y + out_tidx.y; + } else if (t_out_sizes.y > 1) { + // 2D tensor + token_idx = out_tidx.y; + } + // For 1D tensor, token_idx remains 0 + + token_idx = min(token_idx, num_tokens - 1); + + OUT_T qvalue = quantize_val(value, t_scale[token_idx], t_zero_point[token_idx]); + + t_out[out_bufi] = qvalue; +} + +#endif + +void main() { + quantize_${MODE}(); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml new file mode 100644 index 00000000000..90af2590936 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml @@ -0,0 +1,18 @@ +quantize_buffer: + parameter_names_with_default_values: + IN_DTYPE: float + OUT_DTYPE: int32 + MODE: per_tensor + generate_variant_forall: + IN_DTYPE: + - VALUE: half + - VALUE: float + OUT_DTYPE: + - VALUE: uint8 + - VALUE: int8 + - VALUE: int32 + shader_variants: + - NAME: quantize_per_tensor_buffer + MODE: per_tensor + - NAME: quantize_per_token_buffer + MODE: per_token diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl new file mode 100644 index 00000000000..9ba7074f75b --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl @@ -0,0 +1,184 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define IN_T ${buffer_scalar_type(IN_DTYPE)} +#define FVEC4_T ${texel_load_type(IN_DTYPE, "texture3d")} + +#define OUT_T ${buffer_scalar_type(OUT_DTYPE)} +#define IVEC4_T ${texel_load_type(OUT_DTYPE, "texture3d")} + +#define ${MODE} + +${define_active_storage_type("texture3d")} +${define_required_extensions(IN_DTYPE)} +${define_required_extensions(OUT_DTYPE)} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")} + +$if MODE == "per_tensor": + layout(push_constant) uniform restrict Block { + float scale; + int zero_point; + int quant_min; + int quant_max; + }; +$if MODE == "per_token": + ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + + layout(push_constant) uniform restrict Block { + int num_tokens; + int quant_min; + int quant_max; + }; + +${layout_declare_ubo(B, "ivec3", "t_in_limits")} +${layout_declare_ubo(B, "ivec3", "t_out_limits")} + +#include "indexing_utils.h" +#include "quantize.glslh" + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +/* + * QUANTIZATION SHADER (TEXTURE STORAGE) + * + * This shader converts floating-point tensor values to n-bit integer representations + * using pre-computed quantization parameters (scale and zero_point). The quantization + * maps floating-point values to a discrete integer range while preserving the + * original data distribution as much as possible. + * + * ALGORITHM: + * 1. Load floating-point texel (4 values) from 3D texture + * 2. Apply quantization formula to each component: qvalue = round(value / scale) + zero_point + * 3. Clamp each result to [quant_min, quant_max] range + * 4. Store quantized integer texel to output texture + * + * WORKGROUP CONFIGURATION: + * - Per-Tensor Mode: + * - Global WG Size: {W, H, C/4} for input size (W, H, C) with width-packing + * - Local WG Size: Default (typically {8, 8, 1} or based on global WG size) + * - Per-Token Mode: + * - Global WG Size: {W, H, C/4} for input size (W, H, C) with width-packing + * - Local WG Size: Default (typically {8, 8, 1} or based on global WG size) + * + * SUPPORTED CONFIGURATIONS: + * - Texture Storage: Uses 3D texture indexing with texel-based processing + * - Assumes width-packed layout (packed_dim = 0) in current implementation + * - Handles texel padding for non-multiple-of-4 tensor dimensions + * - For per-token mode: scale/zero_point tensors must use buffer storage + * + * QUANTIZATION FORMULA VISUALIZATION: + * For input range [min_val, max_val] mapped to integer range [quant_min, quant_max]: + * + * Floating Point Domain: Integer Domain: + * min_val ────────────────► quant_min + * │ │ + * │ scale = (max_val - min_val) / (quant_max - quant_min) + * │ zero_point = quant_min - round(min_val / scale) + * │ │ + * max_val ────────────────► quant_max + * + * Texel Quantization Process: + * Input Texel: [2.5, -1.0, 0.5, 3.2] (float4) + * Per-component quantization with scale=0.1, zero_point=-128: + * Component 0: round(2.5 / 0.1) + (-128) = 25 + (-128) = -103 + * Component 1: round(-1.0 / 0.1) + (-128) = -10 + (-128) = -138 → clamp to -128 + * Component 2: round(0.5 / 0.1) + (-128) = 5 + (-128) = -123 + * Component 3: round(3.2 / 0.1) + (-128) = 32 + (-128) = -96 + * Output Texel: [-103, -128, -123, -96] (int4) + * + * PER-TENSOR QUANTIZATION: + * - Single scale and zero_point values for entire tensor + * - All texel components use same quantization parameters + * - Parameters passed as push constants for efficiency + * - Each thread processes one texel (4 elements) independently + * - Formula: qvalue[i] = clamp(round(value[i] / scale) + zero_point, quant_min, quant_max) + * + * PER-TOKEN QUANTIZATION: + * - Separate scale and zero_point for each token + * - Token = all elements except last dimension (e.g., for [B,S,H]: B*S tokens of H elements) + * - Parameters stored in buffer arrays indexed by token_id + * - Each thread calculates token_id from its 3D texture position + * - Scale/zero_point buffers accessed directly (not as textures) + * - Formula: qvalue[i] = clamp(round(value[i] / scale[token_id]) + zero_point[token_id], quant_min, quant_max) + */ + +#ifdef per_tensor + +void quantize_per_tensor() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (any(greaterThanEqual(pos, t_in_limits))) { + return; + } + + FVEC4_T intex = load_texel(t_in, pos); + IVEC4_T outtex; + + [[unroll]] for (int i = 0; i < 4; ++i) { + IN_T value = IN_T(intex[i]); + OUT_T qvalue = quantize_val(value, scale, zero_point); + outtex[i] = qvalue; + } + write_texel(t_out, pos, outtex); +} + +#else + +void quantize_per_token() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (any(greaterThanEqual(pos, t_in_limits))) { + return; + } + + FVEC4_T intex = load_texel(t_in, pos); + + int token_idx = 0; + ivec3 dims = t_in_limits; + + if (dims.z > 1) { + // 3D tensor + token_idx = pos.z * dims.y + pos.y; + } else if (dims.y > 1) { + // 2D tensor + token_idx = pos.y; + } + // For 1D tensor, token_idx remains 0 + + token_idx = min(token_idx, num_tokens - 1); + + // Scale and zero_point are prepacked as buffers, so direct access + float scale_val = t_scale[token_idx]; + int zero_point_val = t_zero_point[token_idx]; + + IVEC4_T outtex; + [[unroll]] for (int i = 0; i < 4; ++i) { + IN_T value = IN_T(intex[i]); + OUT_T qvalue = quantize_val(value, scale_val, zero_point_val); + outtex[i] = qvalue; + } + + write_texel(t_out, pos, outtex); +} + +#endif + +void main() { + quantize_${MODE}(); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml new file mode 100644 index 00000000000..042eb0f8196 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml @@ -0,0 +1,18 @@ +quantize_texture: + parameter_names_with_default_values: + IN_DTYPE: float + OUT_DTYPE: int32 + MODE: per_tensor + generate_variant_forall: + IN_DTYPE: + - VALUE: half + - VALUE: float + OUT_DTYPE: + - VALUE: uint8 + - VALUE: int8 + - VALUE: int32 + shader_variants: + - NAME: quantize_per_tensor_texture3d + MODE: per_tensor + - NAME: quantize_per_token_texture3d + MODE: per_token diff --git a/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp new file mode 100644 index 00000000000..35712d59fb9 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp @@ -0,0 +1,258 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#include +#include +#include + +namespace vkcompute { + +namespace { + +void resize_quantize_output( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + (void)extra_args; + const ValueRef out = args.at(0).refs.at(0); + const ValueRef in = args.at(1).refs.at(0); + graph->virtual_resize(out, graph->sizes_of(in)); +} + +} // namespace + +void add_quantize_per_tensor_node( + ComputeGraph& graph, + const ValueRef& input, + const ValueRef& scale, + const ValueRef& zero_point, + const ValueRef& quant_min, + const ValueRef& quant_max, + const ValueRef& output) { + std::string kernel_name("quantize_per_tensor"); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(output)); + + float scale_val = static_cast(graph.get_double(scale)); + int zero_point_val = static_cast(graph.get_int(zero_point)); + int quant_min_val = static_cast(graph.get_int(quant_min)); + int quant_max_val = static_cast(graph.get_int(quant_max)); + + vkapi::ParamsBindList param_ubos; + std::vector push_constants; + + if (graph.is_buffer_storage(input)) { + param_ubos = { + graph.numel_ubo(input), + graph.sizes_ubo(input), + graph.strides_ubo(input), + graph.sizes_ubo(output), + graph.strides_ubo(output)}; + push_constants = { + PushConstantDataInfo(&scale_val, sizeof(float)), + PushConstantDataInfo(&zero_point_val, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + } else { + param_ubos = { + graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)}; + push_constants = { + PushConstantDataInfo(&scale_val, sizeof(float)), + PushConstantDataInfo(&zero_point_val, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + } + + vkapi::SpecVarList spec_vars = { + graph.hashed_layout_of(output), + graph.hashed_layout_of(input), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{output, vkapi::kWrite}, {input, vkapi::kRead}}, + // Shader param buffers + param_ubos, + // Push Constants + push_constants, + // Specialization Constants + spec_vars, + // Resize Args + {}, + // Resizing Logic + resize_quantize_output)); +} + +void add_quantize_per_token_node( + ComputeGraph& graph, + const ValueRef& input, + const ValueRef& scale, + const ValueRef& zero_point, + const ValueRef& quant_min, + const ValueRef& quant_max, + const ValueRef& output) { + std::string kernel_name("quantize_per_token"); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(output)); + + int quant_min_val = static_cast(graph.get_int(quant_min)); + int quant_max_val = static_cast(graph.get_int(quant_max)); + + int num_tokens = static_cast(graph.sizes_of(scale)[0]); + + vkapi::ParamsBindList param_ubos; + std::vector push_constants; + + if (graph.is_buffer_storage(input)) { + param_ubos = { + graph.numel_ubo(input), + graph.sizes_ubo(input), + graph.strides_ubo(input), + graph.sizes_ubo(output), + graph.strides_ubo(output), + }; + push_constants = { + PushConstantDataInfo(&num_tokens, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + } else { + param_ubos = { + graph.logical_limits_ubo(input), + graph.logical_limits_ubo(output), + }; + push_constants = { + PushConstantDataInfo(&num_tokens, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + } + + vkapi::SpecVarList spec_vars = { + graph.hashed_layout_of(output), + graph.hashed_layout_of(input), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{output, vkapi::kWrite}, + {input, vkapi::kRead}, + {{scale, zero_point}, vkapi::kRead}}, + // Shader param buffers + param_ubos, + // Push Constants + push_constants, + // Specialization Constants + spec_vars, + // Resize Args + {}, + // Resizing Logic + resize_quantize_output)); +} + +void quantize_per_tensor_impl( + ComputeGraph& graph, + const std::vector& args) { + int arg_idx = 0; + const ValueRef input = args[arg_idx++]; + const ValueRef scale = args[arg_idx++]; + const ValueRef zero_point = args[arg_idx++]; + const ValueRef quant_min = args[arg_idx++]; + const ValueRef quant_max = args[arg_idx++]; + const ValueRef output = args[arg_idx++]; + + // Check tensor types + VK_CHECK_COND(graph.val_is_tensor(input)); + VK_CHECK_COND(graph.val_is_tensor(output)); + + // Verify input is a floating point type + VK_CHECK_COND( + graph.dtype_of(input) == vkapi::kFloat || + graph.dtype_of(input) == vkapi::kHalf); + + add_quantize_per_tensor_node( + graph, input, scale, zero_point, quant_min, quant_max, output); +} + +void quantize_per_token_impl( + ComputeGraph& graph, + const std::vector& args) { + int arg_idx = 0; + const ValueRef input = args[arg_idx++]; + const ValueRef scale = args[arg_idx++]; + const ValueRef zero_point = args[arg_idx++]; + const ValueRef quant_min = args[arg_idx++]; + const ValueRef quant_max = args[arg_idx++]; + const ValueRef output = args[arg_idx++]; + + // Check tensor types + VK_CHECK_COND(graph.val_is_tensor(input)); + VK_CHECK_COND(graph.val_is_tensor(scale)); + VK_CHECK_COND(graph.val_is_tensor(zero_point)); + VK_CHECK_COND(graph.val_is_tensor(output)); + + // Verify input is a floating point type + VK_CHECK_COND( + graph.dtype_of(input) == vkapi::kFloat || + graph.dtype_of(input) == vkapi::kHalf); + + // Check that scale and zero_point have buffer storage and width packing + VK_CHECK_COND(graph.is_buffer_storage(scale)); + VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); + VK_CHECK_COND(graph.is_buffer_storage(zero_point)); + VK_CHECK_COND(graph.packed_dim_of(zero_point) == WHCN::kWidthDim); + + // Check that tensors with texture storage have standard axis map + if (!graph.is_buffer_storage(input)) { + VK_CHECK_COND(graph.has_standard_axis_map(input)); + } + if (!graph.is_buffer_storage(output)) { + VK_CHECK_COND(graph.has_standard_axis_map(output)); + } + + // Calculate number of tokens (product of all dimensions except the last one) + int64_t num_tokens = 1; + const auto input_sizes = graph.sizes_of(input); + for (size_t i = 0; i < input_sizes.size() - 1; i++) { + num_tokens *= input_sizes[i]; + } + + const auto scale_sizes = graph.sizes_of(scale); + const auto zero_point_sizes = graph.sizes_of(zero_point); + + VK_CHECK_COND(scale_sizes.size() == 1); + VK_CHECK_COND(zero_point_sizes.size() == 1); + VK_CHECK_COND(scale_sizes[0] == num_tokens); + VK_CHECK_COND(zero_point_sizes[0] == num_tokens); + + add_quantize_per_token_node( + graph, input, scale, zero_point, quant_min, quant_max, output); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(quantize_per_tensor.default, quantize_per_tensor_impl); + VK_REGISTER_OP(quantize_per_token.default, quantize_per_token_impl); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/quantize_test.cpp b/backends/vulkan/test/op_tests/quantize_test.cpp index 8b79dc1ce6b..7ea98b14fb2 100644 --- a/backends/vulkan/test/op_tests/quantize_test.cpp +++ b/backends/vulkan/test/op_tests/quantize_test.cpp @@ -21,6 +21,9 @@ #include #include +#include + +float eps = 1e-7; namespace torch { namespace executor { @@ -383,6 +386,8 @@ void test_reference_quantize_per_tensor( // Reshape back to original dimensions input = flat_input.reshape(input_sizes_int64); + scale = scale < eps ? eps : scale; + // Get reference output at::Tensor reference_out = quantize_per_tensor_reference_impl( input, scale, zero_point, quant_min, quant_max, dtype); @@ -435,6 +440,8 @@ void test_vulkan_quantize_per_tensor_impl( at::Tensor input = at::rand(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); + scale = scale < eps ? eps : scale; + // Get reference output at::Tensor reference_out = torch::executor::native::quantize_per_tensor_aten( input, scale, zero_point, quant_min, quant_max, dtype); @@ -490,7 +497,7 @@ void test_vulkan_quantize_per_tensor_impl( at::Tensor reference_int = reference_out.to(at::kInt); at::Tensor vk_int = vk_out.to(at::kInt); - const bool output_correct = at::equal(reference_int, vk_int); + const bool output_correct = at::allclose(reference_int, vk_int); if (!output_correct) { at::Tensor diffs = at::abs(reference_int - vk_int); @@ -500,6 +507,10 @@ void test_vulkan_quantize_per_tensor_impl( std::cout << " zero_point: " << zero_point << std::endl; std::cout << " quant_min: " << quant_min << std::endl; std::cout << " quant_max: " << quant_max << std::endl; + std::cout << " storage type: " + << (in_storage == vkcompute::utils::kBuffer ? "buffer" + : "texture") + << std::endl; std::cout << "input:" << std::endl; std::cout << input << std::endl; @@ -564,9 +575,89 @@ TEST( at::kInt); } +TEST( + VulkanQuantizePerTensorTest, + test_vulkan_quantize_per_tensor_float_to_uint8) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_quantize_per_tensor( + {5, 3, 2, 4}, // input sizes + 0.01, // scale + 1, // zero_point + 0, // quant_min + 255, // quant_max + at::kFloat, + at::kByte); +} + +TEST( + VulkanQuantizePerTensorTest, + test_vulkan_quantize_per_tensor_float_to_int8) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_quantize_per_tensor( + {5, 3, 2, 4}, // input sizes + 0.01, // scale + 1, // zero_point + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); +} + +TEST( + VulkanQuantizePerTensorTest, + test_vulkan_quantize_per_tensor_float_to_int32) { + test_vulkan_quantize_per_tensor( + {5, 3, 2, 4}, // input sizes + 0.01, // scale + 1, // zero_point + -2147483648, // quant_min + 2147483647, // quant_max + at::kFloat, + at::kInt); +} + +TEST( + VulkanQuantizePerTensorTest, + test_vulkan_quantize_per_tensor_float_to_int32_small_scale) { + test_vulkan_quantize_per_tensor( + {2, 8, 1, 3}, // input sizes + 0.0, // scale + 20, // zero_point + -2147483648, // quant_min + 2147483647, // quant_max + at::kFloat, + at::kInt); +} + +TEST( + VulkanQuantizePerTensorTest, + test_vulkan_quantize_per_tensor_half_to_int8) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_float16_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_quantize_per_tensor( + {2, 3}, // input sizes + 0.01, // scale + 1, // zero_point + -128, // quant_min + 127, // quant_max + at::kHalf, // input dtype + at::kChar); // output dtype +} + void test_reference_quantize_per_token( const std::vector& input_sizes, - const std::vector& scales, + const std::vector& pre_scales, const std::vector& zero_points, int64_t quant_min, int64_t quant_max, @@ -595,9 +686,14 @@ void test_reference_quantize_per_token( } // Verify that the number of tokens matches the size of scales and zero_points - ASSERT_EQ(num_tokens, scales.size()); + ASSERT_EQ(num_tokens, pre_scales.size()); ASSERT_EQ(num_tokens, zero_points.size()); + std::vector scales = pre_scales; + for (auto& s : scales) { + s = s < eps ? eps : s; + } + // Create scale and zero_point tensors at::Tensor scale_tensor = at::tensor(scales, at::device(at::kCPU).dtype(at::kDouble)); @@ -646,7 +742,7 @@ void test_reference_quantize_per_token( void test_vulkan_quantize_per_token_impl( const std::vector& input_sizes, - const std::vector& scales, + const std::vector& pre_scales, const std::vector& zero_points, int64_t quant_min, int64_t quant_max, @@ -662,9 +758,14 @@ void test_vulkan_quantize_per_token_impl( num_tokens *= input_sizes[i]; } - ASSERT_EQ(num_tokens, scales.size()); + ASSERT_EQ(num_tokens, pre_scales.size()); ASSERT_EQ(num_tokens, zero_points.size()); + std::vector scales = pre_scales; + for (auto& s : scales) { + s = s < eps ? eps : s; + } + // Create input tensor with random values std::vector input_sizes_int64( input_sizes.begin(), input_sizes.end()); @@ -688,9 +789,15 @@ void test_vulkan_quantize_per_token_impl( IOValueRef r_input = graph.add_input_tensor( input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage); IOValueRef r_scale = graph.add_input_tensor( - scale_tensor.sizes().vec(), vkapi::kFloat, in_storage); + scale_tensor.sizes().vec(), + vkapi::kFloat, + utils::kBuffer, + utils::kWidthPacked); IOValueRef r_zero_point = graph.add_input_tensor( - zero_point_tensor.sizes().vec(), vkapi::kInt, in_storage); + zero_point_tensor.sizes().vec(), + vkapi::kInt, + utils::kBuffer, + utils::kWidthPacked); const ValueRef r_quant_min = graph.add_scalar(quant_min); const ValueRef r_quant_max = graph.add_scalar(quant_max); @@ -744,7 +851,7 @@ void test_vulkan_quantize_per_token_impl( at::Tensor reference_int = reference_out.to(at::kInt); at::Tensor vk_int = vk_out.to(at::kInt); - const bool output_correct = at::equal(reference_int, vk_int); + const bool output_correct = at::allclose(reference_int, vk_int); if (!output_correct) { at::Tensor diffs = at::abs(reference_int - vk_int); @@ -841,3 +948,130 @@ TEST( at::kHalf, at::kByte); } + +TEST( + VulkanQuantizePerTensorTest, + test_vulkan_quantize_per_token_float_to_uint8) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales = { + -0.5, -0.3, -0.2, 0, 0.1, 0.8, 0.1, 0.2, 0.3, 0.4}; + std::vector zero_points = {-8, 0, 15, 20, 19, 12, 47, 1, -50, -12}; + + test_vulkan_quantize_per_token( + {5, 2, 4}, // input sizes (5*2=10 tokens) + scales, + zero_points, + 0, // quant_min + 255, // quant_max + at::kFloat, + at::kByte); +} + +TEST( + VulkanQuantizePerTensorTest, + test_vulkan_quantize_per_token_float_to_int8) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales = { + -0.5, -0.3, -0.2, 0, 0.1, 0.8, 0.1, 0.2, 0.3, 0.4}; + std::vector zero_points = {-8, 0, 15, 20, 19, 12, 47, 1, -50, -12}; + + test_vulkan_quantize_per_token( + {5, 2, 4}, // input sizes (5 tokens) + scales, + zero_points, + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); +} + +TEST( + VulkanQuantizePerTensorTest, + test_vulkan_quantize_per_token_float_to_int32) { + std::vector scales = { + -0.5, -0.3, -0.2, 0, 0.1, 0.8, 0.1, 0.2, 0.3, 0.4}; + std::vector zero_points = {-8, 0, 15, 20, 19, 12, 47, 1, -50, -12}; + + test_vulkan_quantize_per_token( + {5, 2, 4}, // input sizes (5*2=10 tokens) + scales, + zero_points, + -2147483648, // quant_min + 2147483647, // quant_max + at::kFloat, + at::kInt); +} + +TEST( + VulkanQuantizePerTensorTest, + test_vulkan_quantize_per_token_float_to_int32_small_scales) { + std::vector scales = { + 0, + 2.9387358770557188e-39f, + 1.40129846e-45f, + 1.17549435e-38f, + 0.0000000000001}; + std::vector zero_points = {20, -10, 15, 200, 50}; + + test_vulkan_quantize_per_token( + {5, 2}, // input sizes (3 tokens) + scales, + zero_points, + -2147483648, // quant_min + 2147483647, // quant_max + at::kFloat, + at::kInt); +} + +TEST( + VulkanQuantizePerTensorTest, + test_vulkan_quantize_per_token_float_to_uint8_many_tokens) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales(18, 0.1); + std::vector zero_points(18, 5); + + // Alternate scale values + for (size_t i = 0; i < scales.size(); i++) { + scales[i] = (i % 2 == 0) ? 0.3 : -0.5; + } + + test_vulkan_quantize_per_token( + {3, 3, 2, 3}, // input sizes (3*3*2=18 tokens) + scales, + zero_points, + 0, // quant_min + 125, // quant_max + at::kFloat, + at::kByte); +} + +TEST(VulkanQuantizePerTensorTest, test_vulkan_quantize_per_token_half_to_int8) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_float16_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales = {0.1, 0.2}; + std::vector zero_points = {0, 5}; + + test_vulkan_quantize_per_token( + {2, 2}, // input sizes (2*2=4 tokens) + scales, + zero_points, + -128, // quant_min + 127, // quant_max + at::kHalf, // input dtype + at::kChar); // output dtype +} diff --git a/backends/vulkan/test/op_tests/test_utils.cpp b/backends/vulkan/test/op_tests/test_utils.cpp index 196f079be2c..c5702abd079 100644 --- a/backends/vulkan/test/op_tests/test_utils.cpp +++ b/backends/vulkan/test/op_tests/test_utils.cpp @@ -94,7 +94,8 @@ vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) { case c10::kInt: return vkapi::kInt; case c10::kLong: - return vkapi::kLong; + // No support for 64-bit integers + return vkapi::kInt; case c10::kChar: return vkapi::kChar; case c10::kByte: From 8fe89d6c0581708838af6def999dc4d4b46abcd3 Mon Sep 17 00:00:00 2001 From: morelos Date: Tue, 17 Jun 2025 10:06:35 -0700 Subject: [PATCH 2/3] [ET-VK][Ops] dequantization op shaders and impl Pull Request resolved: https://github.com/pytorch/executorch/pull/11483 # Operator Description The dequantization operator converts lower-precision integer tensors (uint8/int8/int32) back to floating-point formats (fp16/fp32) using affine dequantization. This operator supports two dequantization modes: - **Per-tensor dequantization**: Uses a single scale and zero_point for the entire tensor - **Per-token dequantization**: Uses different scale and zero_point values for each "token" (typically rows or channels) The dequantization formula is: `dequantized_value = (quantized_value - zero_point) * scale` **Example**: For a quantized uint8 value `153` with `scale=0.1`, `zero_point=128`: - `(153 - 128) * 0.1 = 25 * 0.1 = 2.5` (float output) The dequantization parameters serve these purposes: - **scale**: Controls the granularity of reconstruction (same scale used during quantization) - **zero_point**: Maps the integer zero representation back to floating-point zero - **quant_min/quant_max**: Define the valid range that was used during original quantization (for validation) # Shader Algorithm Overview ## Texture Storage Implementation (`dequantize_texture.glsl`) The texture-based implementation operates on 3D textures where data is stored in RGBA texel format (4 components per texel): **Per-tensor Mode**: Each compute thread processes one texel position. It loads a 4-component integer texel from the input texture, and applies dequantization to each of the 4 components using shared scale/zero_point parameters. It then writes the dequantized 4-component floating-point result to the output texture. This method processes all components uniformly with the same dequantization parameters. **Per-token Mode**: We need to calculate the token index based on the spatial position, it'll differ between various cases like 3D and 2D. For instance we might define the token_idx as `z * dims.y + y` for 3D, or just `y` for 2D cases. We then retrieve the per-token scale/zero_point from the texture storage according to the token_idx. We need to do component indexing based on the texel_idx and token_idx: `texel_idx = token_idx / 4`, along with the component id `comp_idx = token_idx % 4` to get the necessary scale/zero_point values. We then apply dequantization with the corresponding token-specific parameters to the 4 components of the current texel, converting each integer component to its floating-point representation. ## Buffer Storage Implementation (`dequantize_buffer.glsl`) The buffer-based implementation operates on linear memory buffers with stride-based indexing: **Per-tensor Mode**: In this case, each compute thread will process one element at its global position. It converts the 3D position to linear buffer indices using stride calculations `tidx_to_bufi(pos, strides)`. It then loads single quantized integer values from the input buffer and applies dequantization using shared scale/zero_point parameters. We then store the dequantized floating-point result to the output buffer at the corresponding index. **Per-token Mode**: We first calculate the logical tensor position from the linear buffer index through dimension unwrapping. We then determine the token index based on the tensor dimensionality: - 4D: `token_idx = w * (z * y) + z * y + y` - 3D: `token_idx = z * y + y` - 2D: `token_idx = y` We then directly index into scale/zero_point buffers using token_idx and apply dequantization with the token-specific parameters, converting the quantized integer value back to its original floating-point representation. # Performance Considerations / Future Improvements Current implementation uses default workgroup sizing. Buffer implementation processes one element per thread. Could be optimized to process multiple elements per thread for better throughput. NOTE: Currently the only input types supported are **byte** (uint8), **char** (int8), **int** (int32). The only output types supported are **half** (fp16) and **float** (fp32). A future diff plans to implement **double** (fp64) output dtype support. ghstack-source-id: 291010146 @exported-using-ghexport Differential Revision: [D76267107](https://our.internmc.facebook.com/intern/diff/D76267107/) --- .../runtime/graph/ops/glsl/dequantize.glslh | 16 + .../graph/ops/glsl/dequantize_buffer.glsl | 183 ++++++++++++ .../graph/ops/glsl/dequantize_buffer.yaml | 18 ++ .../graph/ops/glsl/dequantize_texture.glsl | 190 ++++++++++++ .../graph/ops/glsl/dequantize_texture.yaml | 18 ++ .../runtime/graph/ops/impl/Dequantize.cpp | 274 ++++++++++++++++++ .../vulkan/test/op_tests/dequantize_test.cpp | 245 +++++++++++++++- 7 files changed, 936 insertions(+), 8 deletions(-) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/dequantize.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml create mode 100644 backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize.glslh b/backends/vulkan/runtime/graph/ops/glsl/dequantize.glslh new file mode 100644 index 00000000000..7194bebda35 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize.glslh @@ -0,0 +1,16 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef DEQUANTIZE_GLSLH +#define DEQUANTIZE_GLSLH + +OUT_T dequantize_val(IN_T qvalue, float scale_val, int zero_point_val) { + return OUT_T(float(int(qvalue) - zero_point_val) * scale_val); +} + +#endif // DEQUANTIZE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl new file mode 100644 index 00000000000..2a1f62719a0 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl @@ -0,0 +1,183 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define IN_T ${buffer_scalar_type(IN_DTYPE)} +#define OUT_T ${buffer_scalar_type(OUT_DTYPE)} + +#define ${MODE} + +${define_active_storage_type("buffer")} +${define_required_extensions(IN_DTYPE)} +${define_required_extensions(OUT_DTYPE)} + +layout(std430) buffer; + +#include "indexing_utils.h" + +${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "buffer")} +${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")} + +$if MODE == "per_tensor": + layout(push_constant) uniform restrict Block { + float scale; + int zero_point; + int quant_min; + int quant_max; + }; +$if MODE == "per_token": + ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + + layout(push_constant) uniform restrict Block { + int num_tokens; + int quant_min; + int quant_max; + }; + +${layout_declare_ubo(B, "int", "out_numel")} +${layout_declare_ubo(B, "ivec4", "t_in_sizes")} +${layout_declare_ubo(B, "ivec4", "t_in_strides")} +${layout_declare_ubo(B, "ivec4", "t_out_sizes")} +${layout_declare_ubo(B, "ivec4", "t_out_strides")} + +${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} +${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")} + +#include "dequantize.glslh" + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +const lowp ivec4 out_dim_order = unhash_dim_order(out_layout); +const lowp ivec4 in_dim_order = unhash_dim_order(in_layout); + +/* + * DEQUANTIZATION SHADER (BUFFER STORAGE) + * + * This shader converts n-bit integer tensor values back to floating-point representations + * using pre-computed quantization parameters (scale and zero_point). The dequantization + * reconstructs the original floating-point values from their discrete integer representations + * with minimal precision loss. + * + * ALGORITHM: + * 1. Load quantized integer value from buffer + * 2. Apply dequantization formula: value = (qvalue - zero_point) * scale + * 3. Store reconstructed floating-point value to output buffer + * + * WORKGROUP CONFIGURATION: + * - Per-Tensor Mode: + * - Global WG Size: {num_elements, 1, 1} (one thread per tensor element) + * - Local WG Size: Default (typically {64, 1, 1} or based on global WG size) + * - Per-Token Mode: + * - Global WG Size: {num_elements, 1, 1} (one thread per tensor element) + * - Local WG Size: Default (typically {64, 1, 1} or based on global WG size) + * + * SUPPORTED CONFIGURATIONS: + * - Buffer Storage: Uses linear buffer indexing with stride-based tensor access + * - Per-Tensor: Supports any tensor layout through stride calculations and dimension ordering + * - Per-Token: Supports only width packed tensors (packed_dim = 0) and standard axis mapping + * - Scale/zero_point tensors: Must use buffer storage with width packing (packed_dim = 0) + * + * DEQUANTIZATION FORMULA VISUALIZATION: + * For integer range [quant_min, quant_max] mapped back to [min_val, max_val]: + * + * Integer Domain: Floating Point Domain: + * quant_min ──────────────► min_val + * │ │ + * │ scale = (max_val - min_val) / (quant_max - quant_min) + * │ zero_point = quant_min - round(min_val / scale) + * │ │ + * quant_max ──────────────► max_val + * + * Dequantization Process: + * Input: -103 (int8) + * Step 1: qvalue - zero_point = -103 - (-128) = 25 + * Step 2: result * scale = 25 * 0.1 = 2.5 + * Output: 2.5 (float) + * + * PER-TENSOR DEQUANTIZATION: + * - Single scale and zero_point values for entire tensor + * - All elements use same dequantization parameters + * - Parameters passed as push constants for efficiency + * - Formula: value = (qvalue - zero_point) * scale + * + * PER-TOKEN DEQUANTIZATION: + * - Separate scale and zero_point for each token + * - Token = all elements except last dimension (e.g., for [B,S,H]: B*S tokens of H elements) + * - Parameters stored in buffer arrays indexed by token_id + * - Each thread calculates its token_id from tensor coordinates + * - Formula: value = (qvalue - zero_point[token_id]) * scale[token_id] + * + * Token ID calculation for element at tensor index (w, z, y, x): + * - 4D tensor: token_id = w * (sizes.z * sizes.y) + z * sizes.y + y + * - 3D tensor: token_id = z * sizes.y + y + * - 2D tensor: token_id = y + * - 1D tensor: token_id = 0 + */ + +#ifdef per_tensor + +void dequantize_per_tensor() { + const int out_bufi = int(gl_GlobalInvocationID.x); + + if (out_bufi >= out_numel) { + return; + } + + const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); + const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); + + IN_T qvalue = t_in[in_bufi]; + OUT_T value = dequantize_val(qvalue, scale, zero_point); + + t_out[out_bufi] = value; +} + +#else + +void dequantize_per_token() { + const int out_bufi = int(gl_GlobalInvocationID.x); + + if (out_bufi >= out_numel) { + return; + } + + const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); + const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); + + IN_T qvalue = t_in[in_bufi]; + + int token_idx = 0; + + if (t_out_sizes.w > 1) { + // 4D tensor + token_idx = out_tidx.w * (t_out_sizes.z * t_out_sizes.y) + out_tidx.z * t_out_sizes.y + out_tidx.y; + } else if (t_out_sizes.z > 1) { + // 3D tensor + token_idx = out_tidx.z * t_out_sizes.y + out_tidx.y; + } else if (t_out_sizes.y > 1) { + // 2D tensor + token_idx = out_tidx.y; + } + // For 1D tensor, token_idx remains 0 + + token_idx = min(token_idx, num_tokens - 1); + + OUT_T value = dequantize_val(qvalue, t_scale[token_idx], t_zero_point[token_idx]); + + t_out[out_bufi] = value; +} + +#endif + +void main() { + dequantize_${MODE}(); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml new file mode 100644 index 00000000000..4e434935356 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml @@ -0,0 +1,18 @@ +dequantize_buffer: + parameter_names_with_default_values: + IN_DTYPE: int32 + OUT_DTYPE: float + MODE: per_tensor + generate_variant_forall: + IN_DTYPE: + - VALUE: uint8 + - VALUE: int8 + - VALUE: int32 + OUT_DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: dequantize_per_tensor_buffer + MODE: per_tensor + - NAME: dequantize_per_token_buffer + MODE: per_token diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl new file mode 100644 index 00000000000..cfc61dd1816 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl @@ -0,0 +1,190 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define IN_T ${buffer_scalar_type(IN_DTYPE)} +#define IVEC4_T ${texel_load_type(IN_DTYPE, "texture3d")} + +#define OUT_T ${buffer_scalar_type(OUT_DTYPE)} +#define FVEC4_T ${texel_load_type(OUT_DTYPE, "texture3d")} + +#define ${MODE} + +${define_active_storage_type("texture3d")} +${define_required_extensions(IN_DTYPE)} +${define_required_extensions(OUT_DTYPE)} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")} + +$if MODE == "per_tensor": + layout(push_constant) uniform restrict Block { + float scale; + int zero_point; + int quant_min; + int quant_max; + }; +$if MODE == "per_token": + ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + + layout(push_constant) uniform restrict Block { + int num_tokens; + int quant_min; + int quant_max; + }; + +${layout_declare_ubo(B, "ivec3", "t_in_limits")} +${layout_declare_ubo(B, "ivec3", "t_out_limits")} + +#include "indexing_utils.h" +#include "dequantize.glslh" + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +/* + * DEQUANTIZATION SHADER (TEXTURE STORAGE) + * + * This shader converts n-bit integer tensor values back to floating-point representations + * using pre-computed quantization parameters (scale and zero_point). The dequantization + * reconstructs the original floating-point values from their discrete integer representations + * with minimal precision loss. + * + * ALGORITHM: + * 1. Load quantized integer texel (4 values) from 3D texture + * 2. Apply dequantization formula to each component: value = (qvalue - zero_point) * scale + * 3. Store reconstructed floating-point texel to output texture + * + * WORKGROUP CONFIGURATION: + * - Per-Tensor Mode: + * - Global WG Size: {W, H, C/4} for input size (W, H, C) with width-packing + * - Local WG Size: Default (typically {8, 8, 1} or based on global WG size) + * - Per-Token Mode: + * - Global WG Size: {W, H, C/4} for input size (W, H, C) with width-packing + * - Local WG Size: Default (typically {8, 8, 1} or based on global WG size) + * + * SUPPORTED CONFIGURATIONS: + * - Texture Storage: Uses 3D texture indexing with texel-based processing + * - Assumes width-packed layout (packed_dim = 0) for input/output textures + * - Handles texel padding for non-multiple-of-4 tensor dimensions + * - For per-token mode: scale/zero_point tensors must use buffer storage + * - Input/output textures: Must use standard axis mapping for per-token mode + * + * DEQUANTIZATION FORMULA VISUALIZATION: + * For integer range [quant_min, quant_max] mapped back to [min_val, max_val]: + * + * Integer Domain: Floating Point Domain: + * quant_min ──────────────► min_val + * │ │ + * │ scale = (max_val - min_val) / (quant_max - quant_min) + * │ zero_point = quant_min - round(min_val / scale) + * │ │ + * quant_max ──────────────► max_val + * + * Texel Dequantization Process: + * Input Texel: [-103, -128, -123, -96] (int4) + * Per-component dequantization with scale=0.1, zero_point=-128: + * Component 0: (-103 - (-128)) * 0.1 = 25 * 0.1 = 2.5 + * Component 1: (-128 - (-128)) * 0.1 = 0 * 0.1 = 0.0 + * Component 2: (-123 - (-128)) * 0.1 = 5 * 0.1 = 0.5 + * Component 3: (-96 - (-128)) * 0.1 = 32 * 0.1 = 3.2 + * Output Texel: [2.5, 0.0, 0.5, 3.2] (float4) + * + * PER-TENSOR DEQUANTIZATION: + * - Single scale and zero_point values for entire tensor + * - All texel components use same dequantization parameters + * - Parameters passed as push constants for efficiency + * - Each thread processes one texel (4 elements) independently + * - Formula: value[i] = (qvalue[i] - zero_point) * scale + * + * PER-TOKEN DEQUANTIZATION: + * - Separate scale and zero_point for each token + * - Token = all elements except last dimension (e.g., for [B,S,H]: B*S tokens of H elements) + * - Parameters stored in buffer arrays indexed by token_id + * - Each thread calculates token_id from its 3D texture position + * - Scale/zero_point buffers accessed directly (not as textures) + * - Formula: value[i] = (qvalue[i] - zero_point[token_id]) * scale[token_id] + * + * Token ID calculation for texel at position (x, y, z): + * - 3D tensor: token_id = z * texture_height + y + * - 2D tensor: token_id = y + * - 1D tensor: token_id = 0 + */ + +#ifdef per_tensor + +void dequantize_per_tensor() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + // Skip if out of bounds + if (any(greaterThanEqual(pos, t_in_limits))) { + return; + } + + IVEC4_T intex = load_texel(t_in, pos); + FVEC4_T outtex; + + [[unroll]] for (int i = 0; i < 4; ++i) { + IN_T qvalue = IN_T(intex[i]); + OUT_T value = dequantize_val(qvalue, scale, zero_point); + outtex[i] = value; + } + write_texel(t_out, pos, outtex); +} + +#else + +void dequantize_per_token() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (any(greaterThanEqual(pos, t_in_limits))) { + return; + } + + IVEC4_T intex = load_texel(t_in, pos); + + int token_idx = 0; + ivec3 dims = t_in_limits; + + if (dims.z > 1) { + // 3D tensor + token_idx = pos.z * dims.y + pos.y; + } else if (dims.y > 1) { + // 2D tensor + token_idx = pos.y; + } + // For 1D tensor, token_idx remains 0 + + token_idx = min(token_idx, num_tokens - 1); + + // Scale and zero_point are prepacked as buffers, so direct access + float scale_val = t_scale[token_idx]; + int zero_point_val = t_zero_point[token_idx]; + + FVEC4_T outtex; + [[unroll]] for (int i = 0; i < 4; ++i) { + IN_T qvalue = IN_T(intex[i]); + OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val); + outtex[i] = value; + } + + write_texel(t_out, pos, outtex); +} + +#endif + +void main() { + dequantize_${MODE}(); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml new file mode 100644 index 00000000000..fc8c18468ed --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml @@ -0,0 +1,18 @@ +dequantize_texture: + parameter_names_with_default_values: + IN_DTYPE: int32 + OUT_DTYPE: float + MODE: per_tensor + generate_variant_forall: + IN_DTYPE: + - VALUE: uint8 + - VALUE: int8 + - VALUE: int32 + OUT_DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: dequantize_per_tensor_texture3d + MODE: per_tensor + - NAME: dequantize_per_token_texture3d + MODE: per_token diff --git a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp new file mode 100644 index 00000000000..77a51ce24f9 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp @@ -0,0 +1,274 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#include +#include +#include + +namespace vkcompute { + +namespace { + +void resize_dequantize_output( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + (void)extra_args; + const ValueRef out = args.at(0).refs.at(0); + const ValueRef in = args.at(1).refs.at(0); + graph->virtual_resize(out, graph->sizes_of(in)); +} + +} // namespace + +void add_dequantize_per_tensor_node( + ComputeGraph& graph, + const ValueRef& input, + const ValueRef& scale, + const ValueRef& zero_point, + const ValueRef& quant_min, + const ValueRef& quant_max, + const ValueRef& output) { + std::string kernel_name("dequantize_per_tensor"); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(output)); + + float scale_val = static_cast(graph.get_double(scale)); + int zero_point_val = static_cast(graph.get_int(zero_point)); + int quant_min_val = static_cast(graph.get_int(quant_min)); + int quant_max_val = static_cast(graph.get_int(quant_max)); + + vkapi::ParamsBindList param_ubos; + std::vector push_constants; + + if (graph.is_buffer_storage(input)) { + param_ubos = { + graph.numel_ubo(input), + graph.sizes_ubo(input), + graph.strides_ubo(input), + graph.sizes_ubo(output), + graph.strides_ubo(output)}; + push_constants = { + PushConstantDataInfo(&scale_val, sizeof(float)), + PushConstantDataInfo(&zero_point_val, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + } else { + param_ubos = { + graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)}; + push_constants = { + PushConstantDataInfo(&scale_val, sizeof(float)), + PushConstantDataInfo(&zero_point_val, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + } + + vkapi::SpecVarList spec_vars = { + graph.hashed_layout_of(output), + graph.hashed_layout_of(input), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{output, vkapi::kWrite}, {input, vkapi::kRead}}, + // Shader param buffers + param_ubos, + // Push Constants + push_constants, + // Specialization Constants + spec_vars, + // Resize Args + {}, + // Resizing Logic + resize_dequantize_output)); +} + +void add_dequantize_per_token_node( + ComputeGraph& graph, + const ValueRef& input, + const ValueRef& scale, + const ValueRef& zero_point, + const ValueRef& quant_min, + const ValueRef& quant_max, + const ValueRef& output) { + std::string kernel_name("dequantize_per_token"); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(output)); + + int quant_min_val = static_cast(graph.get_int(quant_min)); + int quant_max_val = static_cast(graph.get_int(quant_max)); + + int num_tokens = static_cast(graph.sizes_of(scale)[0]); + + vkapi::ParamsBindList param_ubos; + std::vector push_constants; + + if (graph.is_buffer_storage(input)) { + param_ubos = { + graph.numel_ubo(input), + graph.sizes_ubo(input), + graph.strides_ubo(input), + graph.sizes_ubo(output), + graph.strides_ubo(output), + }; + push_constants = { + PushConstantDataInfo(&num_tokens, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + } else { + param_ubos = { + graph.logical_limits_ubo(input), + graph.logical_limits_ubo(output), + }; + push_constants = { + PushConstantDataInfo(&num_tokens, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + } + + vkapi::SpecVarList spec_vars = { + graph.hashed_layout_of(output), + graph.hashed_layout_of(input), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{output, vkapi::kWrite}, + {input, vkapi::kRead}, + {{scale, zero_point}, vkapi::kRead}}, + // Shader param buffers + param_ubos, + // Push Constants + push_constants, + // Specialization Constants + spec_vars, + // Resize Args + {}, + // Resizing Logic + resize_dequantize_output)); +} + +void dequantize_per_tensor_impl( + ComputeGraph& graph, + const std::vector& args) { + int arg_idx = 0; + const ValueRef input = args[arg_idx++]; + const ValueRef scale = args[arg_idx++]; + const ValueRef zero_point = args[arg_idx++]; + const ValueRef quant_min = args[arg_idx++]; + const ValueRef quant_max = args[arg_idx++]; + const ValueRef output = args[arg_idx++]; + + // Check tensor types + VK_CHECK_COND(graph.val_is_tensor(input)); + VK_CHECK_COND(graph.val_is_tensor(output)); + + // Verify input is an integer type + VK_CHECK_COND( + graph.dtype_of(input) == vkapi::kByte || + graph.dtype_of(input) == vkapi::kChar || + graph.dtype_of(input) == vkapi::kShort || + graph.dtype_of(input) == vkapi::kInt); + + // Verify output is a floating point type + VK_CHECK_COND( + graph.dtype_of(output) == vkapi::kHalf || + graph.dtype_of(output) == vkapi::kFloat || + graph.dtype_of(output) == vkapi::kDouble); + + add_dequantize_per_tensor_node( + graph, input, scale, zero_point, quant_min, quant_max, output); +} + +void dequantize_per_token_impl( + ComputeGraph& graph, + const std::vector& args) { + int arg_idx = 0; + const ValueRef input = args[arg_idx++]; + const ValueRef scale = args[arg_idx++]; + const ValueRef zero_point = args[arg_idx++]; + const ValueRef quant_min = args[arg_idx++]; + const ValueRef quant_max = args[arg_idx++]; + const ValueRef output = args[arg_idx++]; + + // Check tensor types + VK_CHECK_COND(graph.val_is_tensor(input)); + VK_CHECK_COND(graph.val_is_tensor(scale)); + VK_CHECK_COND(graph.val_is_tensor(zero_point)); + VK_CHECK_COND(graph.val_is_tensor(output)); + + // Verify input is an integer type + VK_CHECK_COND( + graph.dtype_of(input) == vkapi::kByte || + graph.dtype_of(input) == vkapi::kChar || + graph.dtype_of(input) == vkapi::kShort || + graph.dtype_of(input) == vkapi::kInt); + + // Verify output is a floating point type + VK_CHECK_COND( + graph.dtype_of(output) == vkapi::kHalf || + graph.dtype_of(output) == vkapi::kFloat || + graph.dtype_of(output) == vkapi::kDouble); + + // Check that scale and zero_point have buffer storage and width packing + VK_CHECK_COND(graph.is_buffer_storage(scale)); + VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); + VK_CHECK_COND(graph.is_buffer_storage(zero_point)); + VK_CHECK_COND(graph.packed_dim_of(zero_point) == WHCN::kWidthDim); + + // Check that tensors with texture storage have standard axis map + if (!graph.is_buffer_storage(input)) { + VK_CHECK_COND(graph.has_standard_axis_map(input)); + } + if (!graph.is_buffer_storage(output)) { + VK_CHECK_COND(graph.has_standard_axis_map(output)); + } + + // Calculate number of tokens (product of all dimensions except the last one) + int64_t num_tokens = 1; + const auto input_sizes = graph.sizes_of(input); + for (size_t i = 0; i < input_sizes.size() - 1; i++) { + num_tokens *= input_sizes[i]; + } + + const auto scale_sizes = graph.sizes_of(scale); + const auto zero_point_sizes = graph.sizes_of(zero_point); + + VK_CHECK_COND(scale_sizes.size() == 1); + VK_CHECK_COND(zero_point_sizes.size() == 1); + VK_CHECK_COND(scale_sizes[0] == num_tokens); + VK_CHECK_COND(zero_point_sizes[0] == num_tokens); + + add_dequantize_per_token_node( + graph, input, scale, zero_point, quant_min, quant_max, output); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(dequantize_per_tensor.default, dequantize_per_tensor_impl); + VK_REGISTER_OP(dequantize_per_token.default, dequantize_per_token_impl); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/dequantize_test.cpp b/backends/vulkan/test/op_tests/dequantize_test.cpp index 7b155c8f98b..1ec0602a4f2 100644 --- a/backends/vulkan/test/op_tests/dequantize_test.cpp +++ b/backends/vulkan/test/op_tests/dequantize_test.cpp @@ -20,6 +20,7 @@ #include "test_utils.h" #include +#include #include #include @@ -481,6 +482,8 @@ void test_reference_dequantize_per_tensor( std::cout << " zero_point: " << zero_point << std::endl; std::cout << " quant_min: " << quant_min << std::endl; std::cout << " quant_max: " << quant_max << std::endl; + std::cout << " input dtype: " << dtype << std::endl; + std::cout << " output dtype: " << out_dtype << std::endl; std::cout << "input:" << std::endl; std::cout << input << std::endl; @@ -598,8 +601,15 @@ void test_vulkan_dequantize_per_tensor_impl( graph.copy_from_staging( staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); - // Compare outputs - const bool output_correct = at::allclose(reference_out, vk_out); + // Compare outputs with appropriate tolerance for half precision + bool output_correct; + if (out_dtype == at::kHalf) { + // Use higher tolerance for half precision due to limited precision + output_correct = + at::allclose(reference_out, vk_out, /*rtol=*/1e-2, /*atol=*/1e-2); + } else { + output_correct = at::allclose(reference_out, vk_out); + } if (!output_correct) { std::cout << "\n" << "Failed with parameters: " << std::endl; @@ -611,6 +621,8 @@ void test_vulkan_dequantize_per_tensor_impl( << (in_storage == vkcompute::utils::kBuffer ? "buffer" : "texture") << std::endl; + std::cout << " input dtype: " << dtype << std::endl; + std::cout << " output dtype: " << out_dtype << std::endl; std::cout << "input:" << std::endl; std::cout << input << std::endl; @@ -623,7 +635,6 @@ void test_vulkan_dequantize_per_tensor_impl( ASSERT_TRUE(output_correct); } -// Test cases for dequantize_per_tensor TEST( VulkanDequantizePerTensorTest, test_reference_dequantize_per_tensor_uint8_to_float) { @@ -689,6 +700,99 @@ TEST( at::kHalf); // output dtype } +TEST( + VulkanDequantizePerTensorTest, + test_vulkan_dequantize_per_tensor_uint8_to_float) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_dequantize_per_tensor( + {2, 3, 4}, // input sizes + 0.1, // scale + 5, // zero_point + 0, // quant_min + 255, // quant_max + at::kByte, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTensorTest, + test_vulkan_dequantize_per_tensor_int8_to_float) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_dequantize_per_tensor( + {3, 4}, // input sizes + 0.05, // scale + 0, // zero_point + -128, // quant_min + 127, // quant_max + at::kChar, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTensorTest, + test_vulkan_dequantize_per_tensor_int32_to_float) { + test_vulkan_dequantize_per_tensor( + {2, 4, 3, 12}, // input sizes + 0.0001, // scale + 100, // zero_point + -2147483648, // quant_min + 2147483647, // quant_max + at::kInt, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTensorTest, + test_vulkan_dequantize_per_tensor_int8_to_half) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_float16_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_dequantize_per_tensor( + {2, 3}, // input sizes + 0.05, // scale + 10, // zero_point + -128, // quant_min + 127, // quant_max + at::kChar, // input dtype + at::kHalf); // output dtype +} + +TEST( + VulkanDequantizePerTensorTest, + test_vulkan_dequantize_per_tensor_int32_to_half) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_float16_buffers_support()) { + GTEST_SKIP(); + } + // Use much smaller scale to avoid overflow to infinity in half precision + // Half precision max value is ~65504, so with int32 values around 2e9, + // we need scales smaller than 65504/2e9 ≈ 3e-5 to avoid overflow + test_vulkan_dequantize_per_tensor( + {7}, // input sizes + 1e-5, // scale (much smaller to avoid overflow) + 5, // zero_point + std::numeric_limits::min(), // quant_min + std::numeric_limits::max(), // quant_max + at::kInt, // input dtype + at::kHalf); // output dtype +} + void test_reference_dequantize_per_token( const std::vector& input_sizes, const std::vector& scales, @@ -793,6 +897,8 @@ void test_reference_dequantize_per_token( std::cout << "" << std::endl; std::cout << " quant_min: " << quant_min << std::endl; std::cout << " quant_max: " << quant_max << std::endl; + std::cout << " input dtype: " << dtype << std::endl; + std::cout << " output dtype: " << out_dtype << std::endl; std::cout << "input:" << std::endl; std::cout << input << std::endl; @@ -894,9 +1000,15 @@ void test_vulkan_dequantize_per_token_impl( IOValueRef r_input = graph.add_input_tensor( input.sizes().vec(), from_at_scalartype(dtype), in_storage); IOValueRef r_scale = graph.add_input_tensor( - scale_tensor.sizes().vec(), vkapi::kFloat, in_storage); + scale_tensor.sizes().vec(), + vkapi::kFloat, + utils::kBuffer, + utils::kWidthPacked); IOValueRef r_zero_point = graph.add_input_tensor( - zero_point_tensor.sizes().vec(), vkapi::kInt, in_storage); + zero_point_tensor.sizes().vec(), + vkapi::kInt, + utils::kBuffer, + utils::kWidthPacked); const ValueRef r_quant_min = graph.add_scalar(quant_min); const ValueRef r_quant_max = graph.add_scalar(quant_max); @@ -946,8 +1058,15 @@ void test_vulkan_dequantize_per_token_impl( graph.copy_from_staging( staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); - // Compare outputs - const bool output_correct = at::allclose(reference_out, vk_out); + // Compare outputs with appropriate tolerance for half precision + bool output_correct; + if (out_dtype == at::kHalf) { + // Use higher tolerance for half precision due to limited precision + output_correct = + at::allclose(reference_out, vk_out, /*rtol=*/1e-2, /*atol=*/1e-2); + } else { + output_correct = at::allclose(reference_out, vk_out); + } if (!output_correct) { std::cout << "\n" << "Failed with parameters: " << std::endl; @@ -967,6 +1086,8 @@ void test_vulkan_dequantize_per_token_impl( << (in_storage == vkcompute::utils::kBuffer ? "buffer" : "texture") << std::endl; + std::cout << " input dtype: " << dtype << std::endl; + std::cout << " output dtype: " << out_dtype << std::endl; std::cout << "input:" << std::endl; std::cout << input << std::endl; @@ -979,7 +1100,6 @@ void test_vulkan_dequantize_per_token_impl( ASSERT_TRUE(output_correct); } -// Test cases for dequantize_per_token TEST( VulkanDequantizePerTokenTest, test_reference_dequantize_per_token_uint8_to_float) { @@ -1059,3 +1179,112 @@ TEST( at::kInt, // input dtype at::kHalf); // output dtype } + +TEST( + VulkanDequantizePerTokenTest, + test_vulkan_dequantize_per_token_uint8_to_float) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales = {0.1, 0.2, 0.3, 0.4, 0.5, 0.6}; + std::vector zero_points = {5, 10, 15, 20, 25, 30}; + + test_vulkan_dequantize_per_token( + {2, 3, 6}, // input sizes (2*3=6 tokens) + scales, + zero_points, + 0, // quant_min + 255, // quant_max + at::kByte, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTokenTest, + test_vulkan_dequantize_per_token_int8_to_float) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales = {0.05, 0.0}; + std::vector zero_points = {10, -5}; + + test_vulkan_dequantize_per_token( + {2, 2}, // input sizes (2*2=4 tokens) + scales, + zero_points, + -128, // quant_min + 127, // quant_max + at::kChar, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTokenTest, + test_vulkan_dequantize_per_token_int32_to_float) { + std::vector scales = { + 0.0001, 0.0002, 0.0003, 0.0, 0.0011, 0.0102, 0.1003, 0.0}; + std::vector zero_points = {100, -100, 50, -50, 12, -6, 4, -24}; + + test_vulkan_dequantize_per_token( + {2, 2, 2, 12}, // input sizes (2*2=4 tokens) + scales, + zero_points, + -2147483648, // quant_min + 2147483647, // quant_max + at::kInt, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTokenTest, + test_vulkan_dequantize_per_token_int8_to_half) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_float16_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales = {0.05, 0.2}; + std::vector zero_points = {2, -5}; + + test_vulkan_dequantize_per_token( + {2, 2}, // input sizes (2=4 tokens) + scales, + zero_points, + -128, // quant_min + 127, // quant_max + at::kChar, // input dtype + at::kHalf); // output dtype +} + +TEST( + VulkanDequantizePerTokenTest, + test_vulkan_dequantize_per_token_int32_to_half) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_float16_buffers_support()) { + GTEST_SKIP(); + } + // Use much smaller scales to avoid overflow to infinity in half precision + // Half precision max value is ~65504, so with int32 values around 2e9, + // we need scales smaller than 65504/2e9 ≈ 3e-5 to avoid overflow + std::vector scales = {1e-5, 2e-5, 1.5e-5}; + std::vector zero_points = {20, -15, 1}; + + test_vulkan_dequantize_per_token( + {3, 6}, // input sizes (3 tokens) + scales, + zero_points, + std::numeric_limits::min(), // quant_min + std::numeric_limits::max(), // quant_max + at::kInt, // input dtype + at::kHalf); // output dtype +} From e44a3d19753071de79682c62ea844a08cabbe276 Mon Sep 17 00:00:00 2001 From: morelos Date: Tue, 17 Jun 2025 10:06:36 -0700 Subject: [PATCH 3/3] [ET-VK][Ops] choose_qparams op shaders and impl Pull Request resolved: https://github.com/pytorch/executorch/pull/11557 # Operator Description The choose_qparams operator computes optimal quantization parameters (scale and zero_point) from floating-point input tensors. This operator analyzes the statistical distribution of input data to determine the best quantization mapping for subsequent quantization operations. It supports two computation modes: - **Per-tensor quantization**: Computes a single scale and zero_point for the entire tensor based on global min/max values - **Per-token quantization**: Computes separate scale and zero_point values for each "token" (typically rows or channels) based on per-token min/max values The parameter calculation formulas are: - `scale = (data_max - data_min) / (quant_max - quant_min)` - `zero_point = clamp(round(quant_min - data_min/scale), quant_min, quant_max)` **Example**: For input data with `min=-2.5`, `max=7.3`, `quant_min=0`, `quant_max=255`: - `scale = (7.3 - (-2.5)) / (255 - 0) = 9.8 / 255 = 0.0384` - `zero_point = clamp(round(0 - (-2.5)/0.0384), 0, 255) = clamp(65, 0, 255) = 65` The quantization parameters serve these purposes: - **scale**: Determines the precision/granularity of the quantization mapping - **zero_point**: Ensures that floating-point zero maps to an exact integer value - **quant_min/quant_max**: Define the target quantization range (e.g., 0-255 for uint8) # Shader Algorithm Overview ## Texture Storage Implementation (`choose_qparams_texture.glsl`) The texture-based implementation uses a parallel reduction algorithms to efficiently compute min/max values across 3D textures with RGBA texel format: **Per-tensor Mode**: Each compute thread processes multiple texels using strided access patterns across the entire tensor. For each texel, it converts linear indices to 3D coordinates using `z = idx/(x*y), y = (idx%(x*y))/x, x = idx%x`, then loads 4-component texel data. The implementation validates each component against padding boundaries by calculating `valid_elements = min(4, remaining_elements)` to avoid processing padded data. Thread-local min/max reduction processes valid components while filtering NaN and infinity values. The algorithm then performs intra-workgroup reduction using shared memory arrays `shared_min[NWORKERS]` and `shared_max[NWORKERS]`. A tree reduction pattern halves the stride iteratively: `stride = workgroup_size/2; stride > 0; stride >>= 1`, combining results from `shared_min[local_id + stride]` with proper infinity handling. Finally, the master thread (local_id == 0, group_id == 0) computes the final scale and zero_point using the `calculate_scale_and_zero_point()` function and writes results to output textures. **Per-token Mode**: This mode implements a more complex multi-workgroup coordination strategy where each workgroup processes multiple tokens. The algorithm calculates `tokens_per_workgroup = (num_tokens + total_workgroups - 1) / total_workgroups` to distribute work evenly. For each assigned token, it determines the texel range using `token_start_texel = token_id * texels_per_token` and processes texels within that range using strided access `texel_idx = token_start_texel + local_id; texel_idx < token_end_texel; texel_idx += workgroup_size`. The same padding validation and component processing logic applies, but scoped to the current token's data. After thread-local reduction, it performs the same tree reduction pattern within the workgroup. The master thread computes token-specific scale/zero_point and converts the linear token_id back to 3D output coordinates using `out_z = token_id/(x*y), out_y = (token_id%(x*y))/x, out_x = token_id%x` for writing results. Workgroup synchronization via `barrier()` ensures proper coordination between token processing iterations. ## Buffer Storage Implementation (`choose_qparams_buffer.glsl`) The buffer-based implementation operates on linear memory with simpler indexing but maintains the same parallel reduction strategy: **Per-tensor Mode**: Each compute thread processes multiple elements using strided access across the entire linear buffer: `for (i = global_id; i < total_elements; i += total_threads)`. Direct buffer access `t_in[i]` loads scalar values with NaN/infinity filtering. Thread-local min/max reduction accumulates valid values. The same shared memory tree reduction pattern applies: threads store results in `shared_min[local_id]` and `shared_max[local_id]`, then perform logarithmic reduction with stride halving. The master thread (local_id == 0) computes final parameters and directly writes to output buffers: `t_scale[0] = scale_val; t_zero_point[0] = zero_point_val`. **Per-token Mode**: This mode distributes tokens across workgroups using `tokens_per_workgroup = (num_tokens + total_workgroups - 1) / total_workgroups` for load balancing. Each workgroup processes its assigned token range `[start_token, end_token)`. For each token, it calculates the linear element range: `token_start = token_id * token_size; token_end = token_start + token_size`. Threads process elements within the token using strided access: `for (i = token_start + local_id; i < token_end; i += workgroup_size)`. The same thread-local reduction and shared memory tree reduction patterns apply, but scoped to the current token's data. The master thread computes token-specific parameters and writes directly to the output arrays: `t_scale[token_id] = scale_val; t_zero_point[token_id] = zero_point_val`. Workgroup synchronization ensures proper coordination between token processing iterations. # Performance Considerations / Future Improvements Current implementation uses a parallel reduction algorithms with shared memory optimization, but several areas offer improvement opportunities: - The tree reduction pattern achieves O(log N) complexity within workgroups, but the current implementation uses fixed 64-thread workgroups. Dynamic workgroup sizing based on tensor dimensions could improve occupancy. - Fixed 64-thread workgroups match the NWORKERS constant, but profiling different sizes (32, 128, 256) could reveal better performance characteristics for different tensor sizes and GPU architectures. NOTE: Currently the only input type supported is **float** (fp32). The output types are **float** for scale and **int** for zero_point. ghstack-source-id: 290041468 @exported-using-ghexport ghstack-source-id: 291010147 Differential Revision: [D76436933](https://our.internmc.facebook.com/intern/diff/D76436933/) --- .../graph/ops/glsl/choose_qparams.glslh | 70 +++ .../graph/ops/glsl/choose_qparams_buffer.glsl | 278 ++++++++++++ .../graph/ops/glsl/choose_qparams_buffer.yaml | 12 + .../ops/glsl/choose_qparams_texture.glsl | 398 ++++++++++++++++++ .../ops/glsl/choose_qparams_texture.yaml | 12 + .../runtime/graph/ops/impl/ChooseQParams.cpp | 347 +++++++++++++++ .../test/op_tests/choose_qparams_test.cpp | 96 +++++ 7 files changed, 1213 insertions(+) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.yaml create mode 100644 backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh new file mode 100644 index 00000000000..66620e9b174 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef CHOOSE_QPARAMS_GLSLH +#define CHOOSE_QPARAMS_GLSLH + +// equivalent of the eps defined in the cpu implementation +#define SMALL_SCALE_THRESHOLD 6.1e-5 + +// Calculate scale and zero point from min and max values +void calculate_scale_and_zero_point( + float min_val, + float max_val, + int qmin, + int qmax, + out float scale_val, + out int zero_point_val) { + // ensure we have zero included in our range + min_val = min(min_val, 0.0); + max_val = max(max_val, 0.0); + + scale_val = (max_val - min_val) / float(qmax - qmin); + + // Handle zero or very small scale + if (scale_val == 0.0 || isinf(1.0 / scale_val)) { + scale_val = 0.1; + } + + // Cut off small scale + if (scale_val < SMALL_SCALE_THRESHOLD) { + float org_scale = scale_val; + scale_val = SMALL_SCALE_THRESHOLD; + + // Adjust min and max based on new scale + if (min_val == 0.0) { + max_val = SMALL_SCALE_THRESHOLD * float(qmax - qmin); + } else if (max_val == 0.0) { + min_val = -SMALL_SCALE_THRESHOLD * float(qmax - qmin); + } else { + float amplifier = SMALL_SCALE_THRESHOLD / org_scale; + min_val *= amplifier; + max_val *= amplifier; + } + } + + // Calculate zero point + float zero_point_from_min = float(qmin) - min_val / scale_val; + float zero_point_from_max = float(qmax) - max_val / scale_val; + float zero_point_from_min_error = abs(float(qmin)) - abs(min_val / scale_val); + float zero_point_from_max_error = abs(float(qmax)) - abs(max_val / scale_val); + float initial_zero_point = zero_point_from_min_error < zero_point_from_max_error + ? zero_point_from_min + : zero_point_from_max; + + // Nudge zero point to integer + if (initial_zero_point < float(qmin)) { + zero_point_val = qmin; + } else if (initial_zero_point > float(qmax)) { + zero_point_val = qmax; + } else { + zero_point_val = int(round(initial_zero_point)); + } +} + +#endif // CHOOSE_QPARAMS_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl new file mode 100644 index 00000000000..dcbfe493f34 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl @@ -0,0 +1,278 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define IN_T ${buffer_scalar_type(IN_DTYPE)} + +#define ${MODE} + +${define_active_storage_type("buffer")} +${define_required_extensions(IN_DTYPE)} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_scale", "float", "buffer")} +${layout_declare_tensor(B, "w", "t_zero_point", "int", "buffer")} +${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")} + +$if MODE == "per_tensor": + layout(push_constant) uniform restrict Block { + int quant_min; + int quant_max; + }; +$else: + layout(push_constant) uniform restrict Block { + int num_tokens; + int quant_min; + int quant_max; + }; + +${layout_declare_ubo(B, "ivec4", "t_in_sizes")} +${layout_declare_ubo(B, "ivec4", "t_in_strides")} +${layout_declare_ubo(B, "ivec4", "t_scale_sizes")} +${layout_declare_ubo(B, "ivec4", "t_scale_strides")} +${layout_declare_ubo(B, "ivec4", "t_zero_point_sizes")} +${layout_declare_ubo(B, "ivec4", "t_zero_point_strides")} + +#include "indexing_utils.h" +#include "choose_qparams.glslh" + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#define NWORKERS 64 + +// Shared memory for reduction - must match local work group size +shared float shared_min[NWORKERS]; +shared float shared_max[NWORKERS]; + +/* + * QUANTIZATION PARAMETER COMPUTATION SHADER (BUFFER STORAGE) + * + * This shader computes quantization parameters (scale and zero_point) for converting + * floating-point tensors to n-bit integer representations while preserving the + * original data range as much as possible. + * + * ALGORITHM: + * 1. Find global min/max values across tensor elements using parallel reduction + * 2. Use tree reduction with shared memory for efficient min/max computation + * 3. Calculate scale = (max - min) / (quant_max - quant_min) + * 4. Calculate zero_point to map floating-point zero to integer value + * + * WORKGROUP CONFIGURATION: + * - Per-Tensor Mode: + * - Global WG Size: {1, 1, 1} (single workgroup processes entire tensor) + * - Local WG Size: {64, 1, 1} (matches NWORKERS for shared memory) + * - Per-Token Mode: + * - Global WG Size: {num_tokens, 1, 1} (one workgroup per token) + * - Local WG Size: {64, 1, 1} (matches NWORKERS for shared memory) + * + * SUPPORTED CONFIGURATIONS: + * - Buffer Storage: Uses simple linear indexing through buffer elements + * - No axis mapping or packing considerations - processes elements sequentially + * - Works with any tensor layout since it accesses buffer data linearly + * + * TREE REDUCTION VISUALIZATION FOR MIN/MAX FINDING: + * For 8 threads processing elements [10, 1, 8, 1, 0, 2, 3, 5]: + * + * Initial shared_min/shared_max arrays populated by each thread: + * shared_min: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | + * shared_max: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | + * Thread: | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | + * + * Stride 1 (compare pairs, keep min/max): + * shared_min: | 1 | | 1 | | 0 | | 3 | | (min(10,1), min(8,1), min(0,2), min(3,5)) + * shared_max: | 10 | | 8 | | 2 | | 5 | | (max(10,1), max(8,1), max(0,2), max(3,5)) + * Active: | 0 | | 2 | | 4 | | 6 | | + * + * Stride 2 (compare pairs, keep min/max): + * shared_min: | 0 | | | | 0 | | | | (min(1,1), min(0,3)) + * shared_max: | 10 | | | | 5 | | | | (max(10,8), max(2,5)) + * Active: | 0 | | | | 4 | | | | + * + * Stride 4 (final comparison): + * shared_min: | 0 | | | | | | | | (min(0,0) = 0) + * shared_max: | 10 | | | | | | | | (max(10,5) = 10) + * Active: | 0 | | | | | | | | + * + * Final result: global_min = 0, global_max = 10 (stored in shared_min[0], shared_max[0]) + * + * PER-TENSOR QUANTIZATION: + * - Single workgroup processes entire tensor with strided access + * - Each thread processes elements [thread_id, thread_id + 64, thread_id + 128, ...] + * - Tree reduction combines all thread results into global min/max + * - Output: Single scale and zero_point values + * + * PER-TOKEN QUANTIZATION: + * - Multiple workgroups, each processing one token + * - Token = all elements except last dimension (e.g., for [B,S,H]: B*S tokens of H elements) + * - Each workgroup finds min/max within its assigned token + * - Output: Array of scale and zero_point values (one per token) + */ + +#ifdef per_tensor + +void choose_qparams_per_tensor() { + uint global_id = gl_GlobalInvocationID.x; + uint local_id = gl_LocalInvocationID.x; + uint total_threads = gl_NumWorkGroups.x * gl_WorkGroupSize.x; + + uint total_elements = uint(t_in_sizes.x * t_in_sizes.y * t_in_sizes.z * t_in_sizes.w); + + // Each thread processes multiple elements with stride + float thread_min = 1.0/0.0; // +infinity + float thread_max = -1.0/0.0; // -infinity + bool found_valid = false; + + for (uint i = global_id; i < total_elements; i += total_threads) { + float val = t_in[i]; + if (!isnan(val) && !isinf(val)) { + if (!found_valid) { + thread_min = val; + thread_max = val; + found_valid = true; + } else { + thread_min = min(thread_min, val); + thread_max = max(thread_max, val); + } + } + } + + // Intra-group reduction using shared memory + shared_min[local_id] = thread_min; + shared_max[local_id] = thread_max; + barrier(); + + // Tree reduction within work group + for (uint stride = gl_WorkGroupSize.x / 2; stride > 0; stride >>= 1) { + if (local_id < stride) { + float other_min = shared_min[local_id + stride]; + float other_max = shared_max[local_id + stride]; + + if (!isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) { + shared_min[local_id] = other_min; + } + if (!isinf(other_max) && (isinf(shared_max[local_id]) || other_max > shared_max[local_id])) { + shared_max[local_id] = other_max; + } + } + barrier(); + } + + // Final result calculation (single workgroup only) + if (local_id == 0) { + float global_min = shared_min[0]; + float global_max = shared_max[0]; + + float scale_val; + int zero_point_val; + calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, scale_val, zero_point_val); + + t_scale[0] = scale_val; + t_zero_point[0] = zero_point_val; + } +} + +#else + +void choose_qparams_per_token() { + uint global_id = gl_GlobalInvocationID.x; + uint local_id = gl_LocalInvocationID.x; + uint group_id = gl_WorkGroupID.x; + uint total_workgroups = gl_NumWorkGroups.x; + + uint total_elements = uint(t_in_sizes.x * t_in_sizes.y * t_in_sizes.z * t_in_sizes.w); + uint token_size = total_elements / uint(num_tokens); + + // Calculate how many tokens each workgroup should process + // This handles the case where we have more tokens than workgroups + uint tokens_per_workgroup = (uint(num_tokens) + total_workgroups - 1) / total_workgroups; + + // Calculate which tokens this workgroup is responsible for + uint start_token = group_id * tokens_per_workgroup; + uint end_token = min(start_token + tokens_per_workgroup, uint(num_tokens)); + + // Early exit if this workgroup has no tokens to process + if (start_token >= uint(num_tokens)) { + return; + } + + // Process each token assigned to this workgroup + for (uint token_id = start_token; token_id < end_token; token_id++) { + // Calculate the start and end indices for this token + uint token_start = token_id * token_size; + uint token_end = token_start + token_size; + + // Each thread processes multiple elements within the token with stride + float thread_min = 1.0/0.0; // +infinity + float thread_max = -1.0/0.0; // -infinity + bool found_valid = false; + + // Process elements within this token only + for (uint i = token_start + local_id; i < token_end; i += gl_WorkGroupSize.x) { + float val = t_in[i]; + if (!isnan(val) && !isinf(val)) { + if (!found_valid) { + thread_min = val; + thread_max = val; + found_valid = true; + } else { + thread_min = min(thread_min, val); + thread_max = max(thread_max, val); + } + } + } + + // Intra-group reduction using shared memory + shared_min[local_id] = thread_min; + shared_max[local_id] = thread_max; + barrier(); + + // Tree reduction within work group + for (uint stride = gl_WorkGroupSize.x / 2; stride > 0; stride >>= 1) { + if (local_id < stride) { + float other_min = shared_min[local_id + stride]; + float other_max = shared_max[local_id + stride]; + + if (!isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) { + shared_min[local_id] = other_min; + } + if (!isinf(other_max) && (isinf(shared_max[local_id]) || other_max > shared_max[local_id])) { + shared_max[local_id] = other_max; + } + } + barrier(); + } + + // Final calculation for this token + if (local_id == 0) { + float token_min = shared_min[0]; + float token_max = shared_max[0]; + + float scale_val; + int zero_point_val; + calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, scale_val, zero_point_val); + + t_scale[token_id] = scale_val; + t_zero_point[token_id] = zero_point_val; + } + + // Synchronize before processing next token + barrier(); + } +} + +#endif + +void main() { + choose_qparams_${MODE}(); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.yaml new file mode 100644 index 00000000000..c37039f68e9 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.yaml @@ -0,0 +1,12 @@ +choose_qparams_buffer: + parameter_names_with_default_values: + IN_DTYPE: float + MODE: per_tensor + generate_variant_forall: + IN_DTYPE: + - VALUE: float + shader_variants: + - NAME: choose_qparams_tensor_buffer + MODE: per_tensor + - NAME: choose_qparams_per_token_asymmetric_buffer + MODE: per_token diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl new file mode 100644 index 00000000000..282f1de170a --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl @@ -0,0 +1,398 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define IN_T ${buffer_scalar_type(IN_DTYPE)} +#define FVEC4_T ${texel_load_type(IN_DTYPE, "texture3d")} + +#define ${MODE} + +${define_active_storage_type("texture3d")} +${define_required_extensions(IN_DTYPE)} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_scale", "float", "texture3d")} +${layout_declare_tensor(B, "w", "t_zero_point", "int", "texture3d")} +${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")} + +$if MODE == "per_tensor": + layout(push_constant) uniform restrict Block { + int quant_min; + int quant_max; + }; +$else: + layout(push_constant) uniform restrict Block { + int num_tokens; + int quant_min; + int quant_max; + }; + +${layout_declare_ubo(B, "ivec3", "t_in_limits")} +${layout_declare_ubo(B, "ivec3", "t_scale_limits")} +${layout_declare_ubo(B, "ivec3", "t_zero_point_limits")} + +#include "indexing_utils.h" +#include "choose_qparams.glslh" + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#define NWORKERS 64 + +// Shared memory for reduction - must match local work group size +shared float shared_min[NWORKERS]; +shared float shared_max[NWORKERS]; + +/* + * QUANTIZATION PARAMETER COMPUTATION SHADER (TEXTURE STORAGE) + * + * This shader computes quantization parameters (scale and zero_point) for converting + * floating-point tensors to n-bit integer representations while preserving the + * original data range as much as possible. + * + * ALGORITHM: + * 1. Find global min/max values across tensor elements using parallel reduction + * 2. Use tree reduction with shared memory for efficient min/max computation + * 3. Calculate scale = (max - min) / (quant_max - quant_min) + * 4. Calculate zero_point to map floating-point zero to integer value + * + * WORKGROUP CONFIGURATION: + * - Per-Tensor Mode: + * - Global WG Size: Default (typically {num_elements, 1, 1}) + * - Local WG Size: Default (typically {64, 1, 1}) + * - Per-Token Mode: + * - Global WG Size: Default (typically based on tensor dimensions) + * - Local WG Size: Default (typically {64, 1, 1}, or based on global WG size) + * + * SUPPORTED CONFIGURATIONS: + * - Texture Storage: Uses 3D texture indexing with linear texel iteration + * - Assumes width-packed layout (packed_dim = 0) in current implementation + * - Handles texel padding for non-multiple-of-4 tensor dimensions + * - Note: Axis mapping support depends on indexing utilities + * + * TREE REDUCTION VISUALIZATION FOR MIN/MAX FINDING: + * For 8 threads processing elements [10, 1, 8, 1, 0, 2, 3, 5]: + * + * Initial shared_min/shared_max arrays populated by each thread: + * shared_min: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | + * shared_max: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | + * Thread: | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | + * + * Stride 1 (compare pairs, keep min/max): + * shared_min: | 1 | | 1 | | 0 | | 3 | | (min(10,1), min(8,1), min(0,2), min(3,5)) + * shared_max: | 10 | | 8 | | 2 | | 5 | | (max(10,1), max(8,1), max(0,2), max(3,5)) + * Active: | 0 | | 2 | | 4 | | 6 | | + * + * Stride 2 (compare pairs, keep min/max): + * shared_min: | 0 | | | | 0 | | | | (min(1,1), min(0,3)) + * shared_max: | 10 | | | | 5 | | | | (max(10,8), max(2,5)) + * Active: | 0 | | | | 4 | | | | + * + * Stride 4 (final comparison): + * shared_min: | 0 | | | | | | | | (min(0,0) = 0) + * shared_max: | 10 | | | | | | | | (max(10,5) = 10) + * Active: | 0 | | | | | | | | + * + * Final result: global_min = 0, global_max = 10 (stored in shared_min[0], shared_max[0]) + * + * PER-TENSOR QUANTIZATION: + * - Single workgroup processes entire tensor + * - Each thread processes multiple texels with stride + * - Thread 0: texels [0, 64, 128, ...] -> elements [0-3, 256-259, 512-515, ...] + * - Thread 1: texels [1, 65, 129, ...] -> elements [4-7, 260-263, 516-519, ...] + * - Tree reduction combines all thread results into global min/max + * - Output: Single scale and zero_point values + * + * PER-TOKEN QUANTIZATION: + * - Multiple workgroups, each processing subset of tokens + * - Token = all elements except last dimension (e.g., for [B,S,H]: B*S tokens of H elements) + * - Each workgroup processes multiple tokens if num_tokens > num_workgroups + * - Within each token, threads process texels containing token elements + * - Output: Array of scale and zero_point values (one per token) + */ + +#ifdef per_tensor + +void choose_qparams_per_tensor() { + uint global_id = gl_GlobalInvocationID.x; + uint local_id = gl_LocalInvocationID.x; + uint group_id = gl_WorkGroupID.x; + uint total_threads = gl_NumWorkGroups.x * gl_WorkGroupSize.x; + + uint total_texels = uint(t_in_limits.x * t_in_limits.y * t_in_limits.z); + + // Each thread processes multiple texels with stride + float thread_min = 1.0/0.0; // +infinity + float thread_max = -1.0/0.0; // -infinity + bool found_valid = false; + + // Process texels with stride across all threads + for (uint texel_idx = global_id; texel_idx < total_texels; texel_idx += total_threads) { + // Convert linear texel index to 3D coordinates + uint z = texel_idx / uint(t_in_limits.x * t_in_limits.y); + uint remainder = texel_idx % uint(t_in_limits.x * t_in_limits.y); + uint y = remainder / uint(t_in_limits.x); + uint x = remainder % uint(t_in_limits.x); + ivec3 texel_pos = ivec3(int(x), int(y), int(z)); + + FVEC4_T texel_data = load_texel(t_in, texel_pos); + + // For texture storage, we assume width-packed (packed_dim = 0) + // Calculate number of valid elements in this texel (handle padding) + int packed_dim = 0; // Width dimension is packed + ivec4 sizes = ivec4(t_in_limits, 1); // Convert limits to sizes format + ivec4 tensor_coord = to_tensor_idx(texel_pos, sizes, packed_dim); + + // Calculate total tensor elements to determine padding + int total_elements = t_in_limits.x * t_in_limits.y * t_in_limits.z * 4; + int linear_tensor_idx = tensor_coord.x + tensor_coord.y * sizes.x + + tensor_coord.z * sizes.x * sizes.y; + int remaining_elements = total_elements - (linear_tensor_idx); + int valid_elements = min(4, remaining_elements); + + // Find min/max within this texel, considering only valid elements + if (valid_elements >= 1 && !isnan(texel_data.x) && !isinf(texel_data.x)) { + if (!found_valid) { + thread_min = texel_data.x; + thread_max = texel_data.x; + found_valid = true; + } else { + thread_min = min(thread_min, texel_data.x); + thread_max = max(thread_max, texel_data.x); + } + } + + if (valid_elements >= 2 && !isnan(texel_data.y) && !isinf(texel_data.y)) { + if (!found_valid) { + thread_min = texel_data.y; + thread_max = texel_data.y; + found_valid = true; + } else { + thread_min = min(thread_min, texel_data.y); + thread_max = max(thread_max, texel_data.y); + } + } + + if (valid_elements >= 3 && !isnan(texel_data.z) && !isinf(texel_data.z)) { + if (!found_valid) { + thread_min = texel_data.z; + thread_max = texel_data.z; + found_valid = true; + } else { + thread_min = min(thread_min, texel_data.z); + thread_max = max(thread_max, texel_data.z); + } + } + + if (valid_elements >= 4 && !isnan(texel_data.w) && !isinf(texel_data.w)) { + if (!found_valid) { + thread_min = texel_data.w; + thread_max = texel_data.w; + found_valid = true; + } else { + thread_min = min(thread_min, texel_data.w); + thread_max = max(thread_max, texel_data.w); + } + } + } + + // Intra-workgroup reduction using shared memory + shared_min[local_id] = thread_min; + shared_max[local_id] = thread_max; + barrier(); + + // Tree reduction within work group + for (uint stride = gl_WorkGroupSize.x / 2; stride > 0; stride >>= 1) { + if (local_id < stride) { + float other_min = shared_min[local_id + stride]; + float other_max = shared_max[local_id + stride]; + + if (!isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) { + shared_min[local_id] = other_min; + } + if (!isinf(other_max) && (isinf(shared_max[local_id]) || other_max > shared_max[local_id])) { + shared_max[local_id] = other_max; + } + } + barrier(); + } + + // Final result calculation (single workgroup only for reliability) + if (local_id == 0 && group_id == 0) { + float global_min = shared_min[0]; + float global_max = shared_max[0]; + + float scale_val; + int zero_point_val; + calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, scale_val, zero_point_val); + + write_texel(t_scale, ivec3(0, 0, 0), vec4(scale_val, 0.0, 0.0, 0.0)); + write_texel(t_zero_point, ivec3(0, 0, 0), ivec4(zero_point_val, 0, 0, 0)); + } +} + +#else + +void choose_qparams_per_token() { + // Each token is processed by multiple workgroups for parallel reduction + uint local_id = gl_LocalInvocationID.x; + uint group_id = gl_WorkGroupID.x; + uint total_workgroups = gl_NumWorkGroups.x; + + uint total_texels = uint(t_in_limits.x * t_in_limits.y * t_in_limits.z); + + // Calculate texels per token (assuming last dimension contains the token data) + // For per-token quantization, we assume tokens are along the last dimension + uint texels_per_token = total_texels / uint(num_tokens); + + // Calculate how many tokens each workgroup should process + uint tokens_per_workgroup = (uint(num_tokens) + total_workgroups - 1) / total_workgroups; + + // Calculate which tokens this workgroup is responsible for + uint start_token = group_id * tokens_per_workgroup; + uint end_token = min(start_token + tokens_per_workgroup, uint(num_tokens)); + + // Process each token assigned to this workgroup + for (uint token_id = start_token; token_id < end_token; token_id++) { + // Calculate the texel range for this token + uint token_start_texel = token_id * texels_per_token; + uint token_end_texel = token_start_texel + texels_per_token; + + // Each thread processes multiple texels within the token + float thread_min = 1.0/0.0; // +infinity + float thread_max = -1.0/0.0; // -infinity + bool found_valid = false; + + // Process texels within this token only + for (uint texel_idx = token_start_texel + local_id; texel_idx < token_end_texel; texel_idx += gl_WorkGroupSize.x) { + // Convert linear texel index to 3D coordinates + uint z = texel_idx / uint(t_in_limits.x * t_in_limits.y); + uint remainder = texel_idx % uint(t_in_limits.x * t_in_limits.y); + uint y = remainder / uint(t_in_limits.x); + uint x = remainder % uint(t_in_limits.x); + ivec3 texel_pos = ivec3(int(x), int(y), int(z)); + + FVEC4_T texel_data = load_texel(t_in, texel_pos); + + // For texture storage, we assume width-packed (packed_dim = 0) + // Calculate number of valid elements in this texel (handle padding) + int packed_dim = 0; // Width dimension is packed + ivec4 sizes = ivec4(t_in_limits, 1); // Convert limits to sizes format + ivec4 tensor_coord = to_tensor_idx(texel_pos, sizes, packed_dim); + + // Calculate total tensor elements to determine padding + int total_elements = t_in_limits.x * t_in_limits.y * t_in_limits.z * 4; + int linear_tensor_idx = tensor_coord.x + tensor_coord.y * sizes.x + + tensor_coord.z * sizes.x * sizes.y; + int remaining_elements = total_elements - (linear_tensor_idx); + int valid_elements = min(4, remaining_elements); + + // Find min/max within this texel, considering only valid elements + if (valid_elements >= 1 && !isnan(texel_data.x) && !isinf(texel_data.x)) { + if (!found_valid) { + thread_min = texel_data.x; + thread_max = texel_data.x; + found_valid = true; + } else { + thread_min = min(thread_min, texel_data.x); + thread_max = max(thread_max, texel_data.x); + } + } + + if (valid_elements >= 2 && !isnan(texel_data.y) && !isinf(texel_data.y)) { + if (!found_valid) { + thread_min = texel_data.y; + thread_max = texel_data.y; + found_valid = true; + } else { + thread_min = min(thread_min, texel_data.y); + thread_max = max(thread_max, texel_data.y); + } + } + + if (valid_elements >= 3 && !isnan(texel_data.z) && !isinf(texel_data.z)) { + if (!found_valid) { + thread_min = texel_data.z; + thread_max = texel_data.z; + found_valid = true; + } else { + thread_min = min(thread_min, texel_data.z); + thread_max = max(thread_max, texel_data.z); + } + } + + if (valid_elements >= 4 && !isnan(texel_data.w) && !isinf(texel_data.w)) { + if (!found_valid) { + thread_min = texel_data.w; + thread_max = texel_data.w; + found_valid = true; + } else { + thread_min = min(thread_min, texel_data.w); + thread_max = max(thread_max, texel_data.w); + } + } + } + + // Intra-workgroup reduction using shared memory + shared_min[local_id] = thread_min; + shared_max[local_id] = thread_max; + barrier(); + + // Tree reduction within work group + for (uint stride = gl_WorkGroupSize.x / 2; stride > 0; stride >>= 1) { + if (local_id < stride) { + float other_min = shared_min[local_id + stride]; + float other_max = shared_max[local_id + stride]; + + // Handle infinity values properly + if (!isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) { + shared_min[local_id] = other_min; + } + if (!isinf(other_max) && (isinf(shared_max[local_id]) || other_max > shared_max[local_id])) { + shared_max[local_id] = other_max; + } + } + barrier(); + } + + // Final calculation for this token + if (local_id == 0) { + float token_min = shared_min[0]; + float token_max = shared_max[0]; + + float scale_val; + int zero_point_val; + calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, scale_val, zero_point_val); + + // Convert token_id to 3D coordinates for output texture + // Assuming output tensors have the same layout as input but with different dimensions + uint out_z = token_id / uint(t_scale_limits.x * t_scale_limits.y); + uint out_remainder = token_id % uint(t_scale_limits.x * t_scale_limits.y); + uint out_y = out_remainder / uint(t_scale_limits.x); + uint out_x = out_remainder % uint(t_scale_limits.x); + ivec3 out_pos = ivec3(int(out_x), int(out_y), int(out_z)); + + write_texel(t_scale, out_pos, vec4(scale_val, 0.0, 0.0, 0.0)); + write_texel(t_zero_point, out_pos, ivec4(zero_point_val, 0, 0, 0)); + } + + // Synchronize before processing next token + barrier(); + } +} + +#endif + +void main() { + choose_qparams_${MODE}(); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.yaml new file mode 100644 index 00000000000..f3961b87a0f --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.yaml @@ -0,0 +1,12 @@ +choose_qparams_texture: + parameter_names_with_default_values: + IN_DTYPE: float + MODE: per_tensor + generate_variant_forall: + IN_DTYPE: + - VALUE: float + shader_variants: + - NAME: choose_qparams_tensor_texture3d + MODE: per_tensor + - NAME: choose_qparams_per_token_asymmetric_texture3d + MODE: per_token diff --git a/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp new file mode 100644 index 00000000000..1dc2d34afbf --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp @@ -0,0 +1,347 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#include + +namespace vkcompute { + +namespace { + +void resize_choose_qparams_tensor_output( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + (void)extra_args; + const ValueRef scale_out = args.at(0).refs.at(0); + const ValueRef zero_point_out = args.at(0).refs.at(1); + + // Both scale and zero_point are scalar tensors for per-tensor quantization + // Since we use single workgroup approach, no extra buffer space needed + graph->virtual_resize(scale_out, {}); + graph->virtual_resize(zero_point_out, {}); +} + +void resize_choose_qparams_per_token_output( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + (void)extra_args; + const ValueRef scale_out = args.at(0).refs.at(0); + const ValueRef zero_point_out = args.at(0).refs.at(1); + const ValueRef input = args.at(1).refs.at(0); + + // Calculate output sizes for scale and zero_point tensors + const auto input_sizes = graph->sizes_of(input); + std::vector output_sizes; + output_sizes.reserve(input_sizes.size() - 1); + for (size_t i = 0; i < input_sizes.size() - 1; i++) { + output_sizes.push_back(input_sizes[i]); + } + output_sizes.push_back(1); + + graph->virtual_resize(scale_out, output_sizes); + graph->virtual_resize(zero_point_out, output_sizes); +} + +// Custom workgroup size pickers for ChooseQParams operations +utils::uvec3 choose_qparams_pick_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + + // For per-tensor quantization, we want a single workgroup that can handle + // all elements with proper reduction. The shader uses NWORKERS=64 threads. + const ValueRef input = args.at(1).refs.at(0); + + if (graph->is_buffer_storage(input)) { + // For buffer storage, use a single workgroup in X dimension + // The shader will handle strided access across all elements + return {1u, 1u, 1u}; + } else { + // For texture storage, use the default logic + return graph->create_global_wg_size(args.at(0).refs.at(0)); + } +} + +utils::uvec3 choose_qparams_pick_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + + const ValueRef input = args.at(1).refs.at(0); + + if (graph->is_buffer_storage(input)) { + // For buffer storage, use 64 threads in X dimension to match NWORKERS + // This ensures the shared memory arrays are properly sized + return {64u, 1u, 1u}; + } else { + // For texture storage, use the default logic + return graph->create_local_wg_size(global_workgroup_size); + } +} + +utils::uvec3 choose_qparams_per_token_pick_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + + const ValueRef input = args.at(1).refs.at(0); + + if (graph->is_buffer_storage(input)) { + // For per-token quantization, we need one workgroup per token + // Calculate number of tokens (product of all dimensions except the last + // one) + const auto input_sizes = graph->sizes_of(input); + int64_t num_tokens = 1; + for (size_t i = 0; i < input_sizes.size() - 1; i++) { + num_tokens *= input_sizes[i]; + } + + return {static_cast(num_tokens), 1u, 1u}; + } else { + // For texture storage, use the default logic + return graph->create_global_wg_size(args.at(0).refs.at(0)); + } +} + +utils::uvec3 choose_qparams_per_token_pick_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + + const ValueRef input = args.at(1).refs.at(0); + + if (graph->is_buffer_storage(input)) { + // For buffer storage, use 64 threads in X dimension to match NWORKERS + return {64u, 1u, 1u}; + } else { + // For texture storage, use the default logic + return graph->create_local_wg_size(global_workgroup_size); + } +} + +} // namespace + +void add_choose_qparams_tensor_node( + ComputeGraph& graph, + const ValueRef& input, + const ValueRef& quant_min, + const ValueRef& quant_max, + const ValueRef& scale_out, + const ValueRef& zero_point_out) { + std::string kernel_name("choose_qparams_tensor"); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(input)); + + int quant_min_val = static_cast(graph.get_int(quant_min)); + int quant_max_val = static_cast(graph.get_int(quant_max)); + + vkapi::ParamsBindList param_ubos; + + if (graph.is_buffer_storage(input)) { + param_ubos = { + graph.sizes_ubo(input), + graph.strides_ubo(input), + graph.sizes_ubo(scale_out), + graph.strides_ubo(scale_out), + graph.sizes_ubo(zero_point_out), + graph.strides_ubo(zero_point_out)}; + } else { + param_ubos = { + graph.logical_limits_ubo(input), + graph.logical_limits_ubo(scale_out), + graph.logical_limits_ubo(zero_point_out)}; + } + + std::vector push_constants; + push_constants = { + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + choose_qparams_pick_global_wg_size, + choose_qparams_pick_local_wg_size, + // Inputs and Outputs + {{scale_out, vkapi::kWrite}, + {zero_point_out, vkapi::kWrite}, + {input, vkapi::kRead}}, + // Shader param buffers + param_ubos, + // Push Constants + push_constants, + // Specialization Constants + {}, + // Resize Args + {}, + // Resizing Logic + resize_choose_qparams_tensor_output)); +} + +void add_choose_qparams_per_token_asymmetric_node( + ComputeGraph& graph, + const ValueRef& input, + const ValueRef& scale_out, + const ValueRef& zero_point_out) { + std::string kernel_name("choose_qparams_per_token_asymmetric"); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(input)); + + // Calculate number of tokens (product of all dimensions except the last one) + int64_t num_tokens = 1; + const auto input_sizes = graph.sizes_of(input); + for (size_t i = 0; i < input_sizes.size() - 1; i++) { + num_tokens *= input_sizes[i]; + } + + int num_tokens_val = static_cast(num_tokens); + int quant_min_val = -128; // Fixed for asymmetric quantization + int quant_max_val = 127; // Fixed for asymmetric quantization + + vkapi::ParamsBindList param_ubos; + + if (graph.is_buffer_storage(input)) { + param_ubos = { + graph.sizes_ubo(input), + graph.strides_ubo(input), + graph.sizes_ubo(scale_out), + graph.strides_ubo(scale_out), + graph.sizes_ubo(zero_point_out), + graph.strides_ubo(zero_point_out)}; + } else { + param_ubos = { + graph.logical_limits_ubo(input), + graph.logical_limits_ubo(scale_out), + graph.logical_limits_ubo(zero_point_out)}; + } + + std::vector push_constants; + push_constants = { + PushConstantDataInfo(&num_tokens_val, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + choose_qparams_per_token_pick_global_wg_size, + choose_qparams_per_token_pick_local_wg_size, + // Inputs and Outputs + {{scale_out, vkapi::kWrite}, + {zero_point_out, vkapi::kWrite}, + {input, vkapi::kRead}}, + // Shader param buffers + param_ubos, + // Push Constants + push_constants, + // Specialization Constants + {}, + // Resize Args + {}, + // Resizing Logic + resize_choose_qparams_per_token_output)); +} + +void choose_qparams_tensor_impl( + ComputeGraph& graph, + const std::vector& args) { + int arg_idx = 0; + const ValueRef input = args[arg_idx++]; + const ValueRef quant_min = args[arg_idx++]; + const ValueRef quant_max = args[arg_idx++]; + const ValueRef scale_out = args[arg_idx++]; + const ValueRef zero_point_out = args[arg_idx++]; + + // Check tensor types + VK_CHECK_COND(graph.val_is_tensor(input)); + VK_CHECK_COND(graph.val_is_tensor(scale_out)); + VK_CHECK_COND(graph.val_is_tensor(zero_point_out)); + + // Verify input is a floating point type + VK_CHECK_COND( + graph.dtype_of(input) == vkapi::kFloat || + graph.dtype_of(input) == vkapi::kHalf || + graph.dtype_of(input) == vkapi::kDouble); + + // Verify output types - accept CPU types but convert to GPU types + VK_CHECK_COND( + graph.dtype_of(scale_out) == vkapi::kFloat || + graph.dtype_of(scale_out) == vkapi::kDouble); + VK_CHECK_COND( + graph.dtype_of(zero_point_out) == vkapi::kInt || + graph.dtype_of(zero_point_out) == vkapi::kLong); + + // Check that texture storage is width packed + if (!graph.is_buffer_storage(input)) { + VK_CHECK_COND(graph.packed_dim_of(input) == WHCN::kWidthDim); + } + + add_choose_qparams_tensor_node( + graph, input, quant_min, quant_max, scale_out, zero_point_out); +} + +void choose_qparams_per_token_asymmetric_impl( + ComputeGraph& graph, + const std::vector& args) { + int arg_idx = 0; + const ValueRef input = args[arg_idx++]; + const ValueRef scale_out = args[arg_idx++]; + const ValueRef zero_point_out = args[arg_idx++]; + + // Check tensor types + VK_CHECK_COND(graph.val_is_tensor(input)); + VK_CHECK_COND(graph.val_is_tensor(scale_out)); + VK_CHECK_COND(graph.val_is_tensor(zero_point_out)); + + // Verify input is a floating point type + VK_CHECK_COND( + graph.dtype_of(input) == vkapi::kFloat || + graph.dtype_of(input) == vkapi::kHalf || + graph.dtype_of(input) == vkapi::kDouble); + + // Verify output types - accept CPU types but convert to GPU types + VK_CHECK_COND( + graph.dtype_of(scale_out) == vkapi::kFloat || + graph.dtype_of(scale_out) == vkapi::kDouble); + VK_CHECK_COND( + graph.dtype_of(zero_point_out) == vkapi::kInt || + graph.dtype_of(zero_point_out) == vkapi::kLong); + + add_choose_qparams_per_token_asymmetric_node( + graph, input, scale_out, zero_point_out); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(choose_qparams.tensor, choose_qparams_tensor_impl); + VK_REGISTER_OP( + choose_qparams_per_token_asymmetric.default, + choose_qparams_per_token_asymmetric_impl); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/choose_qparams_test.cpp b/backends/vulkan/test/op_tests/choose_qparams_test.cpp index 24c856e9d46..55e96151387 100644 --- a/backends/vulkan/test/op_tests/choose_qparams_test.cpp +++ b/backends/vulkan/test/op_tests/choose_qparams_test.cpp @@ -516,6 +516,58 @@ TEST(VulkanChooseQparamsTest, test_reference_choose_qparams_tensor_int8) { at::kChar); } +TEST(VulkanChooseQparamsTest, test_vulkan_choose_qparams_tensor_uint8_4D) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_choose_qparams_tensor( + {5, 3, 2, 4}, // input sizes + 0, // quant_min + 255, // quant_max + at::kByte); +} + +TEST(VulkanChooseQparamsTest, test_vulkan_choose_qparams_tensor_int8_2D) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_choose_qparams_tensor( + {5, 5}, // input sizes + -128, // quant_min + 127, // quant_max + at::kChar); +} + +TEST(VulkanChooseQparamsTest, test_vulkan_choose_qparams_tensor_int8_3D) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_choose_qparams_tensor( + {12, 8, 2}, // input sizes + -128, // quant_min + 127, // quant_max + at::kChar); +} + +TEST(VulkanChooseQparamsTest, test_vulkan_choose_qparams_tensor_int8_4D) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_choose_qparams_tensor( + {10, 10, 6, 4}, // input sizes + -128, // quant_min + 127, // quant_max + at::kChar); +} + void test_reference_choose_qparams_per_token_asymmetric( const std::vector& input_sizes, at::ScalarType dtype) { @@ -673,3 +725,47 @@ TEST( {2, 3, 4}, // input sizes (2*3=6 tokens) at::kChar); } + +TEST( + VulkanChooseQparamsTest, + test_vulkan_choose_qparams_per_token_asymmetric_int8_1D) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_choose_qparams_per_token_asymmetric({7}, at::kChar); +} + +TEST( + VulkanChooseQparamsTest, + test_vulkan_choose_qparams_per_token_asymmetric_int8_2D) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_choose_qparams_per_token_asymmetric({2, 2}, at::kChar); +} + +TEST( + VulkanChooseQparamsTest, + test_vulkan_choose_qparams_per_token_asymmetric_int8_3D) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_choose_qparams_per_token_asymmetric({3, 6, 4}, at::kChar); +} + +TEST( + VulkanChooseQparamsTest, + test_vulkan_choose_qparams_per_token_asymmetric_int8_4D) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_choose_qparams_per_token_asymmetric({128, 2, 16, 3}, at::kChar); +}