From 0c9c7a6b3f8882412e25227d051421f793cd97ca Mon Sep 17 00:00:00 2001 From: morelos Date: Wed, 4 Jun 2025 11:03:04 -0700 Subject: [PATCH] [ET-VK][Ops] quantization op shaders and impl Creating the quantize_per_tensor and quantize_per_token logic shaders and impl which are linked with the testing framework. NOTE: Currently the only input types supported are **half** (fp16) and **float** (fp32). The only output types supported are **byte** (uint8), **char** (int8), **short** (int16), **int** (int32). Differential Revision: [D75959064](https://our.internmc.facebook.com/intern/diff/D75959064/) [ghstack-poisoned] --- .../runtime/graph/ops/glsl/quantize.glsl | 236 ++++++++++++++++++ .../runtime/graph/ops/glsl/quantize.yaml | 24 ++ .../runtime/graph/ops/impl/Quantize.cpp | 229 +++++++++++++++++ .../vulkan/test/op_tests/quantize_test.cpp | 118 +++++++++ 4 files changed, 607 insertions(+) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/quantize.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/quantize.yaml create mode 100644 backends/vulkan/runtime/graph/ops/impl/Quantize.cpp diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize.glsl b/backends/vulkan/runtime/graph/ops/glsl/quantize.glsl new file mode 100644 index 00000000000..65faa62bdaa --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize.glsl @@ -0,0 +1,236 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define IN_T ${buffer_scalar_type(IN_DTYPE)} +#define FVEC4_T ${texel_load_type(IN_DTYPE, STORAGE)} + +#define OUT_T ${buffer_scalar_type(OUT_DTYPE)} +#define IVEC4_T ${texel_load_type(OUT_DTYPE, STORAGE)} + +${define_active_storage_type(STORAGE)} +${define_required_extensions(IN_DTYPE)} +${define_required_extensions(OUT_DTYPE)} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, STORAGE)} +${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, STORAGE)} + +$if MODE == "per_tensor": + layout(push_constant) uniform restrict Block { + float scale; + int zero_point; + int quant_min; + int quant_max; + }; +$else: + ${layout_declare_tensor(B, "r", "t_scale", "float", STORAGE)} + ${layout_declare_tensor(B, "r", "t_zero_point", "int", STORAGE)} + + layout(push_constant) uniform restrict Block { + int num_tokens; + int quant_min; + int quant_max; + }; + +$if STORAGE == "buffer": + ${layout_declare_ubo(B, "ivec4", "t_in_sizes")} + ${layout_declare_ubo(B, "ivec4", "t_in_strides")} + ${layout_declare_ubo(B, "ivec4", "t_out_sizes")} + ${layout_declare_ubo(B, "ivec4", "t_out_strides")} +$else: + ${layout_declare_ubo(B, "ivec3", "t_in_limits")} + ${layout_declare_ubo(B, "ivec3", "t_out_limits")} + +#include "indexing_utils.h" + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +OUT_T quantize_val(IN_T value, float scale_val, int zero_point_val) { + // Use int for all intermediate calculations to match CPU implementation + // which uses int64_t/int32_t for all calculations before final casting + int qvalue; + + if (scale_val == 0.0) { + // When scale is 0, CPU implementation would produce a very large value + // that gets clamped to quant_min or quant_max + if (value < 0.0) { + qvalue = quant_min; + } else if (value > 0.0) { + qvalue = quant_max; + } else { + qvalue = zero_point_val; // value is exactly 0 + } + } else { + float inv_scale = 1.0 / scale_val; + + float rounded_float = round(inv_scale * float(value)); + + // Convert to int and add zero point (all in signed integer space) + qvalue = zero_point_val + int(rounded_float); + } + + // Apply clamping in int space before final cast to output type + qvalue = max(qvalue, quant_min); + qvalue = min(qvalue, quant_max); + + // Only cast to output type at the very end + return OUT_T(qvalue); +} + +#ifdef USING_BUFFER + +void main() { +$if MODE == "per_tensor": + const ivec4 pos = ivec4( + gl_GlobalInvocationID.x, + gl_GlobalInvocationID.y, + gl_GlobalInvocationID.z, + 0); + + const int t_in_idx = tidx_to_bufi(pos, t_in_strides); + const int t_out_idx = tidx_to_bufi(pos, t_out_strides); + + IN_T value = t_in[t_in_idx]; + OUT_T qvalue; + + qvalue = quantize_val(value, scale, zero_point); + + t_out[t_out_idx] = qvalue; + +$if MODE == "per_token": + const ivec4 pos = ivec4( + gl_GlobalInvocationID.x, + gl_GlobalInvocationID.y, + gl_GlobalInvocationID.z, + 0); + + const int t_in_idx = tidx_to_bufi(pos, t_in_strides); + const int t_out_idx = tidx_to_bufi(pos, t_out_strides); + + // Skip if out of bounds + if (t_in_idx >= t_in_sizes.x * t_in_sizes.y * t_in_sizes.z * t_in_sizes.w) { + return; + } + + IN_T value = t_in[t_in_idx]; + OUT_T qvalue; + + // Calculate logical position from linear index and strides + ivec4 logical_pos; + int remaining = t_in_idx; + + logical_pos.x = remaining % t_in_sizes.x; + remaining /= t_in_sizes.x; + + logical_pos.y = remaining % t_in_sizes.y; + remaining /= t_in_sizes.y; + + logical_pos.z = remaining % t_in_sizes.z; + remaining /= t_in_sizes.z; + + logical_pos.w = remaining; + + // Calculate token index based on logical position + int token_idx = 0; + + // Check dimensions to determine how to calculate token_idx + if (t_in_sizes.w > 1) { + // 4D tensor + token_idx = logical_pos.w * (t_in_sizes.z * t_in_sizes.y) + logical_pos.z * t_in_sizes.y + logical_pos.y; + } else if (t_in_sizes.z > 1) { + // 3D tensor + token_idx = logical_pos.z * t_in_sizes.y + logical_pos.y; + } else if (t_in_sizes.y > 1) { + // 2D tensor + token_idx = logical_pos.y; + } + // For 1D tensor, token_idx remains 0 + + // Make sure token_idx is within bounds + token_idx = min(token_idx, num_tokens - 1); + + qvalue = quantize_val(value, t_scale[token_idx], t_zero_point[token_idx]); + + t_out[t_out_idx] = qvalue; +} + +#else + +void main() { +$if MODE == "per_tensor": + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + // Skip if out of bounds + if (any(greaterThanEqual(pos, t_in_limits))) { + return; + } + + FVEC4_T intex = load_texel(t_in, pos); + IVEC4_T outtex; + + [[unroll]] for (int i = 0; i < 4; ++i) { + IN_T value = IN_T(intex[i]); + OUT_T qvalue = quantize_val(value, scale, zero_point); + outtex[i] = qvalue; + } + write_texel(t_out, pos, outtex); + +$if MODE == "per_token": + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + // Skip if out of bounds + if (any(greaterThanEqual(pos, t_in_limits))) { + return; + } + + FVEC4_T intex = load_texel(t_in, pos); + + int token_idx = 0; + ivec3 dims = t_in_limits; + + if (dims.z > 1) { + // 3D tensor + token_idx = pos.z * dims.y + pos.y; + } else if (dims.y > 1) { + // 2D tensor + token_idx = pos.y; + } + // For 1D tensor, token_idx remains 0 + + // Make sure token_idx is within bounds + token_idx = min(token_idx, num_tokens - 1); + + // For texture storage, we need to calculate the texel position and component index + int texel_idx = token_idx / 4; + int comp_idx = token_idx % 4; + + vec4 scale_vals = load_texel(t_scale, ivec3(texel_idx, 0, 0)); + ivec4 zp_vals = load_texel(t_zero_point, ivec3(texel_idx, 0, 0)); + + float scale_val = scale_vals[comp_idx]; + int zero_point_val = zp_vals[comp_idx]; + + IVEC4_T outtex; + [[unroll]] for (int i = 0; i < 4; ++i) { + IN_T value = IN_T(intex[i]); + OUT_T qvalue = quantize_val(value, scale_val, zero_point_val); + outtex[i] = qvalue; + } + + write_texel(t_out, pos, outtex); + +} + +#endif diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize.yaml new file mode 100644 index 00000000000..7f5ee475baf --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize.yaml @@ -0,0 +1,24 @@ +quantize: + parameter_names_with_default_values: + IN_DTYPE: float + OUT_DTYPE: int + STORAGE: buffer + MODE: per_tensor + generate_variant_forall: + STORAGE: + - VALUE: buffer + - VALUE: texture3d + IN_DTYPE: + - VALUE: half + - VALUE: float + - VALUE: double + OUT_DTYPE: + - VALUE: uint8 + - VALUE: int8 + - VALUE: short + - VALUE: int + shader_variants: + - NAME: quantize_per_tensor + MODE: per_tensor + - NAME: quantize_per_token + MODE: per_token diff --git a/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp new file mode 100644 index 00000000000..8b3cb50d4e2 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp @@ -0,0 +1,229 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +namespace vkcompute { + +namespace { + +void resize_quantize_output( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + vTensorPtr out = graph->get_tensor(args[0].refs[0]); + vTensorPtr input = graph->get_tensor(args[1].refs[0]); + out->virtual_resize(input->sizes()); +} + +} // namespace + +void add_quantize_per_tensor_node( + ComputeGraph& graph, + const ValueRef& input, + const ValueRef& scale, + const ValueRef& zero_point, + const ValueRef& quant_min, + const ValueRef& quant_max, + const ValueRef& output) { + std::string kernel_name("quantize_per_tensor"); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(output)); + + float scale_val = static_cast(graph.get_double(scale)); + int zero_point_val = static_cast(graph.get_int(zero_point)); + int quant_min_val = static_cast(graph.get_int(quant_min)); + int quant_max_val = static_cast(graph.get_int(quant_max)); + + utils::uvec3 global_size; + vkapi::ParamsBindList param_ubos; + + if (graph.is_buffer_storage(input)) { + global_size = graph.create_global_wg_size(input); + + param_ubos = { + graph.sizes_ubo(input), + graph.strides_ubo(input), + graph.sizes_ubo(output), + graph.strides_ubo(output)}; + } else { + global_size = graph.logical_limits_of(input); + + param_ubos = { + graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)}; + } + + const utils::uvec3 local_size = graph.create_local_wg_size(global_size); + + std::vector push_constants; + push_constants = { + PushConstantDataInfo(&scale_val, sizeof(float)), + PushConstantDataInfo(&zero_point_val, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + + graph.execute_nodes().emplace_back(new DispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + global_size, + local_size, + // Inputs and Outputs + {{input, vkapi::kRead}, {output, vkapi::kReadWrite}}, + // Shader param buffers + param_ubos, + // Push Constants + push_constants, + // Specialization Constants + {}, + // Resize Args + {}, + // Resizing Logic + resize_quantize_output)); +} + +void add_quantize_per_token_node( + ComputeGraph& graph, + const ValueRef& input, + const ValueRef& scale, + const ValueRef& zero_point, + const ValueRef& quant_min, + const ValueRef& quant_max, + const ValueRef& output) { + std::string kernel_name("quantize_per_token"); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(output)); + + int quant_min_val = static_cast(graph.get_int(quant_min)); + int quant_max_val = static_cast(graph.get_int(quant_max)); + + int num_tokens = static_cast(graph.sizes_of(scale)[0]); + + utils::uvec3 global_size; + vkapi::ParamsBindList param_ubos; + + if (graph.is_buffer_storage(input)) { + global_size = graph.create_global_wg_size(input); + + param_ubos = { + graph.sizes_ubo(input), + graph.strides_ubo(input), + graph.sizes_ubo(output), + graph.strides_ubo(output), + }; + } else { + global_size = graph.logical_limits_of(input); + + param_ubos = { + graph.logical_limits_ubo(input), + graph.logical_limits_ubo(output), + }; + } + + const utils::uvec3 local_size = graph.create_local_wg_size(global_size); + + std::vector push_constants; + push_constants = { + PushConstantDataInfo(&num_tokens, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + + graph.execute_nodes().emplace_back(new DispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + global_size, + local_size, + // Inputs and Outputs + {{input, vkapi::kRead}, + {output, vkapi::kWrite}, + {{scale, zero_point}, vkapi::kRead}}, + // Shader param buffers + param_ubos, + // Push Constants + push_constants, + // Specialization Constants + {}, + // Resize Args + {}, + // Resizing Logic + resize_quantize_output)); +} + +void quantize_per_tensor_impl( + ComputeGraph& graph, + const std::vector& args) { + int arg_idx = 0; + const ValueRef input = args[arg_idx++]; + const ValueRef scale = args[arg_idx++]; + const ValueRef zero_point = args[arg_idx++]; + const ValueRef quant_min = args[arg_idx++]; + const ValueRef quant_max = args[arg_idx++]; + const ValueRef output = args[arg_idx++]; + + // Verify input is a floating point type + VK_CHECK_COND( + graph.dtype_of(input) == vkapi::kFloat || + graph.dtype_of(input) == vkapi::kHalf); + + // Resize output tensor to match input tensor shape + graph.get_tensor(output)->virtual_resize(graph.sizes_of(input)); + + add_quantize_per_tensor_node( + graph, input, scale, zero_point, quant_min, quant_max, output); +} + +void quantize_per_token_impl( + ComputeGraph& graph, + const std::vector& args) { + int arg_idx = 0; + const ValueRef input = args[arg_idx++]; + const ValueRef scale = args[arg_idx++]; + const ValueRef zero_point = args[arg_idx++]; + const ValueRef quant_min = args[arg_idx++]; + const ValueRef quant_max = args[arg_idx++]; + const ValueRef output = args[arg_idx++]; + + // Verify input is a floating point type + VK_CHECK_COND( + graph.dtype_of(input) == vkapi::kDouble || + graph.dtype_of(input) == vkapi::kFloat || + graph.dtype_of(input) == vkapi::kHalf); + + // Calculate number of tokens (product of all dimensions except the last one) + int64_t num_tokens = 1; + const auto input_sizes = graph.sizes_of(input); + for (size_t i = 0; i < input_sizes.size() - 1; i++) { + num_tokens *= input_sizes[i]; + } + + const auto scale_sizes = graph.sizes_of(scale); + const auto zero_point_sizes = graph.sizes_of(zero_point); + + VK_CHECK_COND(scale_sizes.size() == 1); + VK_CHECK_COND(zero_point_sizes.size() == 1); + VK_CHECK_COND(scale_sizes[0] == num_tokens); + VK_CHECK_COND(zero_point_sizes[0] == num_tokens); + + // Resize output tensor to match input tensor shape + graph.get_tensor(output)->virtual_resize(graph.sizes_of(input)); + + add_quantize_per_token_node( + graph, input, scale, zero_point, quant_min, quant_max, output); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(quantize_per_tensor.default, quantize_per_tensor_impl); + VK_REGISTER_OP(quantize_per_token.default, quantize_per_token_impl); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/quantize_test.cpp b/backends/vulkan/test/op_tests/quantize_test.cpp index fe591ac2829..dd97c287099 100644 --- a/backends/vulkan/test/op_tests/quantize_test.cpp +++ b/backends/vulkan/test/op_tests/quantize_test.cpp @@ -630,6 +630,47 @@ TEST(VulkanQuantizePerTensorTest, test_reference_quantize_per_tensor_int8) { 127, // quant_max at::kChar); } + +TEST(VulkanQuantizePerTensorTest, test_vulkan_quantize_per_tensor_uint8) { + test_vulkan_quantize_per_tensor( + {5, 3, 2, 4}, // input sizes + 0.01, // scale + 1, // zero_point + 0, // quant_min + 255, // quant_max + at::kByte); +} + +TEST(VulkanQuantizePerTensorTest, test_vulkan_quantize_per_tensor_int8) { + test_vulkan_quantize_per_tensor( + {5, 3, 2, 4}, // input sizes + 0.01, // scale + 1, // zero_point + -128, // quant_min + 127, // quant_max + at::kChar); +} + +TEST(VulkanQuantizePerTensorTest, test_vulkan_quantize_per_tensor_int16) { + test_vulkan_quantize_per_tensor( + {5, 3, 2, 4}, // input sizes + 0.01, // scale + 1, // zero_point + -32768, // quant_min + 32767, // quant_max + at::kShort); +} + +TEST(VulkanQuantizePerTensorTest, test_vulkan_quantize_per_tensor_int32) { + test_vulkan_quantize_per_tensor( + {5, 3, 2, 4}, // input sizes + 0.01, // scale + 1, // zero_point + -2147483648, // quant_min + 2147483647, // quant_max + at::kInt); +} + void test_reference_quantize_per_token( const std::vector& input_sizes, const std::vector& scales, @@ -854,3 +895,80 @@ TEST(VulkanQuantizePerTensorTest, test_reference_quantize_per_token_int8) { 127, // quant_max at::kChar); } + +TEST(VulkanQuantizePerTensorTest, test_vulkan_quantize_per_token_uint8) { + std::vector scales = {-0.5, -0.3, -0.2, 0, 0.1, 0.8, 0.1, 0.2, 0.3, + 0.4}; std::vector zero_points = {-8, 0, 15, 20, 19, 12, 47, 1, -50, + -12}; + + test_vulkan_quantize_per_token( + {5, 2, 4}, // input sizes (5*2=10 tokens) + scales, + zero_points, + 0, // quant_min + 255, // quant_max + at::kByte); +} + +TEST(VulkanQuantizePerTensorTest, test_vulkan_quantize_per_token_int8) { + std::vector scales = {-0.5, -0.3, -0.2, 0, 0.1, 0.8, 0.1, 0.2, 0.3, + 0.4}; std::vector zero_points = {-8, 0, 15, 20, 19, 12, 47, 1, -50, + -12}; + + test_vulkan_quantize_per_token( + {5, 2, 4}, // input sizes (5 tokens) + scales, + zero_points, + -128, // quant_min + 127, // quant_max + at::kChar); +} + +TEST( + VulkanQuantizePerTensorTest, + test_vulkan_quantize_per_token_int16) { +std::vector scales = {-0.5, -0.3, -0.2, 0, 0.1, 0.8, 0.1, 0.2, 0.3, +0.4}; std::vector zero_points = {-8, 0, 15, 20, 19, 12, 47, 1, -50, +-12}; + +test_vulkan_quantize_per_token( + {5, 2, 4}, // input sizes (5*2=10 tokens) + scales, + zero_points, + -32768, // quant_min + 32767, // quant_max + at::kShort); +} + +TEST(VulkanQuantizePerTensorTest, test_vulkan_quantize_per_token_int32) { + std::vector scales = {-0.5, -0.3, -0.2, 0, 0.1, 0.8, 0.1, 0.2, 0.3, + 0.4}; std::vector zero_points = {-8, 0, 15, 20, 19, 12, 47, 1, -50, + -12}; + + test_vulkan_quantize_per_token( + {5, 2, 4}, // input sizes (5*2=10 tokens) + scales, + zero_points, + -2147483648, // quant_min + 2147483647, // quant_max + at::kInt); +} + +TEST(VulkanQuantizePerTensorTest, test_vulkan_quantize_per_token_many_tokens) +{ + std::vector scales(18, 0.1); + std::vector zero_points(18, 5); + + // Alternate scale values + for (size_t i = 0; i < scales.size(); i++) { + scales[i] = (i % 2 == 0) ? 0.3 : -0.5; + } + + test_vulkan_quantize_per_token( + {3, 3, 2, 3}, // input sizes (3*3*2=18 tokens) + scales, + zero_points, + 0, // quant_min + 125, // quant_max + at::kByte); +}