From 5e0b30a22d63fb75f25f1e82dcbc265449350f3b Mon Sep 17 00:00:00 2001
From: morelos
Date: Mon, 9 Jun 2025 08:05:03 -0700
Subject: [PATCH] [ET-VK][Ops] dequantization op shaders and impl

Creating the dequantize_per_tensor and dequantize_per_token logic shaders and
impl which are linked with the testing framework.

Differential Revision: [D76267107](https://our.internmc.facebook.com/intern/diff/D76267107/)

[ghstack-poisoned]
---
 .../runtime/graph/ops/glsl/dequantize.glsl    | 234 ++++++++++++++++
 .../runtime/graph/ops/glsl/dequantize.yaml    |  22 ++
 .../runtime/graph/ops/impl/Dequantize.cpp     | 242 ++++++++++++++++++
 .../vulkan/test/op_tests/dequantize_test.cpp  | 199 +++++++++++---
 4 files changed, 656 insertions(+), 41 deletions(-)
 create mode 100644 backends/vulkan/runtime/graph/ops/glsl/dequantize.glsl
 create mode 100644 backends/vulkan/runtime/graph/ops/glsl/dequantize.yaml
 create mode 100644 backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp

diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize.glsl b/backends/vulkan/runtime/graph/ops/glsl/dequantize.glsl
new file mode 100644
index 00000000000..6b90565ca05
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize.glsl
@@ -0,0 +1,234 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+
+#define IN_T ${buffer_scalar_type(IN_DTYPE)}
+#define IVEC4_T ${texel_load_type(IN_DTYPE, STORAGE)}
+
+#define OUT_T ${buffer_scalar_type(OUT_DTYPE)}
+#define FVEC4_T ${texel_load_type(OUT_DTYPE, STORAGE)}
+
+${define_active_storage_type(STORAGE)}
+${define_required_extensions(IN_DTYPE)}
+${define_required_extensions(OUT_DTYPE)}
+
+// 64-bit integers are needed by dequantize_val: the per_token variant takes
+// (qvalue - zero_point) in 64 bits so that it overflows the same way as the
+// CPU logic, which is not consistent between per_tensor and per_token.
+#extension GL_EXT_shader_explicit_arithmetic_types_int64 : require
+
+#extension GL_EXT_control_flow_attributes : require
+
+layout(std430) buffer;
+
+${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, STORAGE)}
+${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, STORAGE)}
+
+$if MODE == "per_tensor":
+  layout(push_constant) uniform restrict Block {
+    float scale;
+    int zero_point;
+    int quant_min;
+    int quant_max;
+  };
+$else:
+  ${layout_declare_tensor(B, "r", "t_scale", "float", STORAGE)}
+  ${layout_declare_tensor(B, "r", "t_zero_point", "int", STORAGE)}
+
+  layout(push_constant) uniform restrict Block {
+    int num_tokens;
+    int quant_min;
+    int quant_max;
+  };
+
+$if STORAGE == "buffer":
+  ${layout_declare_ubo(B, "ivec4", "t_in_sizes")}
+  ${layout_declare_ubo(B, "ivec4", "t_in_strides")}
+  ${layout_declare_ubo(B, "ivec4", "t_out_sizes")}
+  ${layout_declare_ubo(B, "ivec4", "t_out_strides")}
+$else:
+  ${layout_declare_ubo(B, "ivec3", "t_in_limits")}
+  ${layout_declare_ubo(B, "ivec3", "t_out_limits")}
+
+#include "indexing_utils.h"
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+// Maps one quantized value back to floating point as
+// (qvalue - zero_point_val) * scale_val. The integer subtraction is 32-bit in
+// per_tensor mode and 64-bit in per_token mode; per the notes in each branch,
+// this is meant to reproduce how the matching CPU kernel overflows.
+OUT_T dequantize_val(IN_T qvalue, float scale_val, int zero_point_val) {
+  $if MODE == "per_tensor":
+    // CPU reference casts zero_point before subtracting, so the difference
+    // wraps in 32 bits: -2147483648 - 100 -> 2147483548, then * 0.0001.
+    OUT_T value = OUT_T(float(int(qvalue) - zero_point_val) * scale_val);
+
+  $if MODE == "per_token":
+    // CPU reference subtracts zero_point unnarrowed, so the difference is taken
+    // in 64 bits and does not wrap: -2147483648 - 100 -> -2147483748, * 0.0001.
+    OUT_T value = OUT_T(float(int(qvalue) - int64_t(zero_point_val)) * scale_val);
+
+  return value;
+}
+
+#ifdef USING_BUFFER
+
+// Buffer storage: each invocation dequantizes one element.
+void main() {
+$if MODE == "per_tensor":
+  const ivec4 pos = ivec4(
+    gl_GlobalInvocationID.x,
+    gl_GlobalInvocationID.y,
+    gl_GlobalInvocationID.z,
+    0);
+
+  const int t_in_idx = tidx_to_bufi(pos, t_in_strides);
+  const int t_out_idx = tidx_to_bufi(pos, t_out_strides);
+
+  // Skip if out of bounds
+  if (t_in_idx >= t_in_sizes.x * t_in_sizes.y * t_in_sizes.z * t_in_sizes.w) {
+    return;
+  }
+
+  IN_T qvalue = t_in[t_in_idx];
+  OUT_T value;
+
+  value = dequantize_val(qvalue, scale, zero_point);
+
+  t_out[t_out_idx] = value;
+
+$if MODE == "per_token":
+  const ivec4 pos = ivec4(
+    gl_GlobalInvocationID.x,
+    gl_GlobalInvocationID.y,
+    gl_GlobalInvocationID.z,
+    0);
+
+  const int t_in_idx = tidx_to_bufi(pos, t_in_strides);
+  const int t_out_idx = tidx_to_bufi(pos, t_out_strides);
+
+  // Skip if out of bounds
+  if (t_in_idx >= t_in_sizes.x * t_in_sizes.y * t_in_sizes.z * t_in_sizes.w) {
+    return;
+  }
+
+  IN_T qvalue = t_in[t_in_idx];
+  OUT_T value;
+
+  // Calculate logical position from linear index and strides
+  ivec4 logical_pos;
+  int remaining = t_in_idx;
+
+  logical_pos.x = remaining % t_in_sizes.x;
+  remaining /= t_in_sizes.x;
+
+  logical_pos.y = remaining % t_in_sizes.y;
+  remaining /= t_in_sizes.y;
+
+  logical_pos.z = remaining % t_in_sizes.z;
+  remaining /= t_in_sizes.z;
+
+  logical_pos.w = remaining;
+
+  // Calculate token index based on logical position
+  int token_idx = 0;
+
+  // Check dimensions to determine how to calculate token_idx
+  if (t_in_sizes.w > 1) {
+    // 4D tensor
+    token_idx = logical_pos.w * (t_in_sizes.z * t_in_sizes.y) + logical_pos.z * t_in_sizes.y + logical_pos.y;
+  } else if (t_in_sizes.z > 1) {
+    // 3D tensor
+    token_idx = logical_pos.z * t_in_sizes.y + logical_pos.y;
+  } else if (t_in_sizes.y > 1) {
+    // 2D tensor
+    token_idx = logical_pos.y;
+  }
+  // For 1D tensor, token_idx remains 0
+
+  // Make sure token_idx is within bounds
+  token_idx = min(token_idx, num_tokens - 1);
+
+  value = dequantize_val(qvalue, t_scale[token_idx], t_zero_point[token_idx]);
+
+  t_out[t_out_idx] = value;
+}
+
+#else
+
+// Texture storage: each invocation dequantizes one texel (four values).
+void main() {
+$if MODE == "per_tensor":
+  const ivec3 pos = ivec3(gl_GlobalInvocationID);
+
+  // Skip if out of bounds
+  if (any(greaterThanEqual(pos, t_in_limits))) {
+    return;
+  }
+
+  IVEC4_T intex = load_texel(t_in, pos);
+  FVEC4_T outtex;
+
+  [[unroll]] for (int i = 0; i < 4; ++i) {
+    IN_T qvalue = IN_T(intex[i]);
+    OUT_T value = dequantize_val(qvalue, scale, zero_point);
+    outtex[i] = value;
+  }
+  write_texel(t_out, pos, outtex);
+
+$if MODE == "per_token":
+  const ivec3 pos = ivec3(gl_GlobalInvocationID);
+
+  // Skip if out of bounds
+  if (any(greaterThanEqual(pos, t_in_limits))) {
+    return;
+  }
+
+  IVEC4_T intex = load_texel(t_in, pos);
+
+  int token_idx = 0;
+  ivec3 dims = t_in_limits;
+
+  if (dims.z > 1) {
+    // 3D tensor
+    token_idx = pos.z * dims.y + pos.y;
+  } else if (dims.y > 1) {
+    // 2D tensor
+    token_idx = pos.y;
+  }
+  // For 1D tensor, token_idx remains 0
+
+  // Make sure token_idx is within bounds
+  token_idx = min(token_idx, num_tokens - 1);
+
+  // For texture storage, we need to calculate the texel position and component index
+  int texel_idx = token_idx / 4;
+  int comp_idx = token_idx % 4;
+
+  vec4 scale_vals = load_texel(t_scale, ivec3(texel_idx, 0, 0));
+  ivec4 zp_vals = load_texel(t_zero_point, ivec3(texel_idx, 0, 0));
+
+  float scale_val = scale_vals[comp_idx];
+  int zero_point_val = zp_vals[comp_idx];
+
+  FVEC4_T outtex;
+  [[unroll]] for (int i = 0; i < 4; ++i) {
+    IN_T qvalue = IN_T(intex[i]);
+    OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val);
+    outtex[i] = value;
+  }
+
+  write_texel(t_out, pos, outtex);
+
+}
+
+#endif
diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize.yaml b/backends/vulkan/runtime/graph/ops/glsl/dequantize.yaml
new file mode 100644
index 00000000000..0b6f3f10d1e
---
/dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize.yaml
@@ -0,0 +1,22 @@
+dequantize:
+  parameter_names_with_default_values:
+    IN_DTYPE: int
+    OUT_DTYPE: float
+    STORAGE: buffer
+    MODE: per_tensor
+  generate_variant_forall:
+    STORAGE:
+      - VALUE: buffer
+      - VALUE: texture3d
+    IN_DTYPE:
+      - VALUE: uint8
+      - VALUE: int8
+      - VALUE: int32
+    OUT_DTYPE:
+      - VALUE: half
+      - VALUE: float
+  shader_variants:
+    - NAME: dequantize_per_tensor
+      MODE: per_tensor
+    - NAME: dequantize_per_token
+      MODE: per_token
diff --git a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp
new file mode 100644
index 00000000000..41db473154e
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp
@@ -0,0 +1,265 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+#include
+#include
+
+namespace vkcompute {
+
+namespace {
+
+// Resize callback shared by both dequantize nodes: gives the output tensor the
+// sizes of the input tensor.
+//
+// Both nodes list their arguments as {input, output, ...}, so the input group
+// is args[0] and the output group is args[1].
+void resize_dequantize_output(
+    ComputeGraph* graph,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& extra_args) {
+  vTensorPtr input = graph->get_tensor(args[0].refs[0]);
+  vTensorPtr out = graph->get_tensor(args[1].refs[0]);
+  out->virtual_resize(input->sizes());
+}
+
+} // namespace
+
+// Appends a dispatch of the dequantize_per_tensor shader to the graph.
+//
+// scale, zero_point, quant_min and quant_max are read from the graph as scalar
+// values and handed to the shader as push constants. The shader variant is
+// chosen from the input's storage type and the input/output dtypes.
+void add_dequantize_per_tensor_node(
+    ComputeGraph& graph,
+    const ValueRef& input,
+    const ValueRef& scale,
+    const ValueRef& zero_point,
+    const ValueRef& quant_min,
+    const ValueRef& quant_max,
+    const ValueRef& output) {
+  std::string kernel_name("dequantize_per_tensor");
+  add_storage_type_suffix(kernel_name, graph.storage_type_of(input));
+  add_dtype_suffix(kernel_name, graph.dtype_of(input));
+  add_dtype_suffix(kernel_name, graph.dtype_of(output));
+
+  float scale_val = static_cast<float>(graph.get_double(scale));
+  int zero_point_val = static_cast<int>(graph.get_int(zero_point));
+  int quant_min_val = static_cast<int>(graph.get_int(quant_min));
+  int quant_max_val = static_cast<int>(graph.get_int(quant_max));
+
+  utils::uvec3 global_size;
+  vkapi::ParamsBindList param_ubos;
+
+  // Buffer storage binds sizes/strides UBOs for both tensors; otherwise the
+  // logical limits of both tensors are bound.
+  if (graph.is_buffer_storage(input)) {
+    global_size = graph.create_global_wg_size(input);
+
+    param_ubos = {
+        graph.sizes_ubo(input),
+        graph.strides_ubo(input),
+        graph.sizes_ubo(output),
+        graph.strides_ubo(output)};
+  } else {
+    global_size = graph.logical_limits_of(input);
+
+    param_ubos = {
+        graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)};
+  }
+
+  const utils::uvec3 local_size = graph.create_local_wg_size(global_size);
+
+  std::vector<PushConstantDataInfo> push_constants = {
+      PushConstantDataInfo(&scale_val, sizeof(float)),
+      PushConstantDataInfo(&zero_point_val, sizeof(int)),
+      PushConstantDataInfo(&quant_min_val, sizeof(int)),
+      PushConstantDataInfo(&quant_max_val, sizeof(int)),
+  };
+
+  graph.execute_nodes().emplace_back(new DispatchNode(
+      graph,
+      VK_KERNEL_FROM_STR(kernel_name),
+      global_size,
+      local_size,
+      // Inputs and Outputs
+      {{input, vkapi::kRead}, {output, vkapi::kWrite}},
+      // Shader param buffers
+      param_ubos,
+      // Push Constants
+      push_constants,
+      // Specialization Constants
+      {},
+      // Resize Args
+      {},
+      // Resizing Logic
+      resize_dequantize_output));
+}
+
+// Appends a dispatch of the dequantize_per_token shader to the graph.
+//
+// scale and zero_point are bound as read-only tensors; the push constants carry
+// the first dimension of scale (as the token count) together with quant_min
+// and quant_max.
+void add_dequantize_per_token_node(
+    ComputeGraph& graph,
+    const ValueRef& input,
+    const ValueRef& scale,
+    const ValueRef& zero_point,
+    const ValueRef& quant_min,
+    const ValueRef& quant_max,
+    const ValueRef& output) {
+  std::string kernel_name("dequantize_per_token");
+  add_storage_type_suffix(kernel_name, graph.storage_type_of(input));
+  add_dtype_suffix(kernel_name, graph.dtype_of(input));
+  add_dtype_suffix(kernel_name, graph.dtype_of(output));
+
+  int quant_min_val = static_cast<int>(graph.get_int(quant_min));
+  int quant_max_val = static_cast<int>(graph.get_int(quant_max));
+
+  int num_tokens = static_cast<int>(graph.sizes_of(scale)[0]);
+
+  utils::uvec3 global_size;
+  vkapi::ParamsBindList param_ubos;
+
+  // Buffer storage binds sizes/strides UBOs for both tensors; otherwise the
+  // logical limits of both tensors are bound.
+  if (graph.is_buffer_storage(input)) {
+    global_size = graph.create_global_wg_size(input);
+
+    param_ubos = {
+        graph.sizes_ubo(input),
+        graph.strides_ubo(input),
+        graph.sizes_ubo(output),
+        graph.strides_ubo(output),
+    };
+  } else {
+    global_size = graph.logical_limits_of(input);
+
+    param_ubos = {
+        graph.logical_limits_ubo(input),
+        graph.logical_limits_ubo(output),
+    };
+  }
+
+  const utils::uvec3 local_size = graph.create_local_wg_size(global_size);
+
+  std::vector<PushConstantDataInfo> push_constants = {
+      PushConstantDataInfo(&num_tokens, sizeof(int)),
+      PushConstantDataInfo(&quant_min_val, sizeof(int)),
+      PushConstantDataInfo(&quant_max_val, sizeof(int)),
+  };
+
+  graph.execute_nodes().emplace_back(new DispatchNode(
+      graph,
+      VK_KERNEL_FROM_STR(kernel_name),
+      global_size,
+      local_size,
+      // Inputs and Outputs
+      {{input, vkapi::kRead},
+       {output, vkapi::kWrite},
+       {{scale, zero_point}, vkapi::kRead}},
+      // Shader param buffers
+      param_ubos,
+      // Push Constants
+      push_constants,
+      // Specialization Constants
+      {},
+      // Resize Args
+      {},
+      // Resizing Logic
+      resize_dequantize_output));
+}
+
+// Operator entry point registered as dequantize_per_tensor.default.
+// args, in order: input, scale, zero_point, quant_min, quant_max, output.
+void dequantize_per_tensor_impl(
+    ComputeGraph& graph,
+    const std::vector<ValueRef>& args) {
+  int arg_idx = 0;
+  const ValueRef input = args[arg_idx++];
+  const ValueRef scale = args[arg_idx++];
+  const ValueRef zero_point = args[arg_idx++];
+  const ValueRef quant_min = args[arg_idx++];
+  const ValueRef quant_max = args[arg_idx++];
+  const ValueRef output = args[arg_idx++];
+
+  // Verify input is an integer type
+  VK_CHECK_COND(
+      graph.dtype_of(input) == vkapi::kByte ||
+      graph.dtype_of(input) == vkapi::kChar ||
+      graph.dtype_of(input) == vkapi::kShort ||
+      graph.dtype_of(input) == vkapi::kInt);
+
+  // Verify output is a floating point type
+  VK_CHECK_COND(
+      graph.dtype_of(output) == vkapi::kFloat ||
+      graph.dtype_of(output) == vkapi::kDouble);
+
+  // Resize output tensor to match input tensor shape
+  graph.get_tensor(output)->virtual_resize(graph.sizes_of(input));
+
+  add_dequantize_per_tensor_node(
+      graph, input, scale, zero_point, quant_min, quant_max, output);
+}
+
+// Operator entry point registered as dequantize_per_token.default.
+// args, in order: input, scale, zero_point, quant_min, quant_max, output.
+void dequantize_per_token_impl(
+    ComputeGraph& graph,
+    const std::vector<ValueRef>& args) {
+  int arg_idx = 0;
+  const ValueRef input = args[arg_idx++];
+  const ValueRef scale = args[arg_idx++];
+  const ValueRef zero_point = args[arg_idx++];
+  const ValueRef quant_min = args[arg_idx++];
+  const ValueRef quant_max = args[arg_idx++];
+  const ValueRef output = args[arg_idx++];
+
+  // Verify input is an integer type
+  VK_CHECK_COND(
+      graph.dtype_of(input) == vkapi::kByte ||
+      graph.dtype_of(input) == vkapi::kChar ||
+      graph.dtype_of(input) == vkapi::kShort ||
+      graph.dtype_of(input) == vkapi::kInt);
+
+  // Verify output is a floating point type
+  VK_CHECK_COND(
+      graph.dtype_of(output) == vkapi::kFloat ||
+      graph.dtype_of(output) == vkapi::kDouble);
+
+  // Calculate number of tokens (product of all dimensions except the last
+  // one). Written as `i + 1 < size()` so an input with no dimensions does not
+  // wrap the unsigned loop bound.
+  int64_t num_tokens = 1;
+  const auto input_sizes = graph.sizes_of(input);
+  for (size_t i = 0; i + 1 < input_sizes.size(); i++) {
+    num_tokens *= input_sizes[i];
+  }
+
+  const auto scale_sizes = graph.sizes_of(scale);
+  const auto zero_point_sizes = graph.sizes_of(zero_point);
+
+  VK_CHECK_COND(scale_sizes.size() == 1);
+  VK_CHECK_COND(zero_point_sizes.size() == 1);
+  VK_CHECK_COND(scale_sizes[0] == num_tokens);
+  VK_CHECK_COND(zero_point_sizes[0] == num_tokens);
+
+  // Resize output tensor to match input tensor shape
+  graph.get_tensor(output)->virtual_resize(graph.sizes_of(input));
+
+  add_dequantize_per_token_node(
+      graph, input, scale, zero_point, quant_min, quant_max, output);
+}
+
+REGISTER_OPERATORS {
+  VK_REGISTER_OP(dequantize_per_tensor.default, dequantize_per_tensor_impl);
+  VK_REGISTER_OP(dequantize_per_token.default, dequantize_per_token_impl);
+}
+
+} // namespace vkcompute
diff --git a/backends/vulkan/test/op_tests/dequantize_test.cpp b/backends/vulkan/test/op_tests/dequantize_test.cpp index eb0d430ccc3..0762b00414e 100644 --- a/backends/vulkan/test/op_tests/dequantize_test.cpp +++ b/backends/vulkan/test/op_tests/dequantize_test.cpp @@ -121,7 +121,14 @@ at::Tensor dequantize_per_tensor_aten( executorch::aten::optional opt_et_out_dtype(et_out_dtype); WRAP_TO_ATEN(dequantize_per_tensor_out_no_context, 7) - (input, scale, zero_point, quant_min, quant_max, et_dtype, opt_et_out_dtype, out); + (input, + scale, + zero_point, + quant_min, + quant_max, + et_dtype, + opt_et_out_dtype, + out); return out; } @@ -171,7 +178,14 @@ at::Tensor dequantize_per_token_aten( } WRAP_TO_ATEN(dequantize_per_token_out_no_context, 7) - (input, scale, zero_points, quant_min, quant_max, et_dtype, et_out_dtype, out); + (input, + scale, + zero_points, + quant_min, + quant_max, + et_dtype, + et_out_dtype, + out); return out; } @@ -179,7 +193,6 @@ at::Tensor dequantize_per_token_aten( } // namespace executor } // namespace torch - // // Test functions // @@ -352,7 +365,8 @@ at::Tensor dequantize_per_token_reference_impl( num_tokens *= input.size(i); } - // Verify that
the number of tokens matches the size of scale and zero_point tensors + // Verify that the number of tokens matches the size of scale and zero_point + // tensors assert(num_tokens == scale.numel()); assert(num_tokens == zero_point.numel()); @@ -507,7 +521,8 @@ void test_reference_dequantize_per_tensor( } else if (dtype == at::kChar) { input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar)); } else if (dtype == at::kShort) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort)); + input = + at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort)); } else if (dtype == at::kInt) { input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt)); } else { @@ -556,6 +571,8 @@ void test_reference_dequantize_per_tensor( std::cout << " zero_point: " << zero_point << std::endl; std::cout << " quant_min: " << quant_min << std::endl; std::cout << " quant_max: " << quant_max << std::endl; + std::cout << " input dtype: " << dtype << std::endl; + std::cout << " output dtype: " << out_dtype << std::endl; std::cout << "input:" << std::endl; std::cout << input << std::endl; @@ -589,7 +606,8 @@ void test_vulkan_dequantize_per_tensor_impl( } else if (dtype == at::kChar) { input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar)); } else if (dtype == at::kShort) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort)); + input = + at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort)); } else if (dtype == at::kInt) { input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt)); } else { @@ -622,8 +640,9 @@ void test_vulkan_dequantize_per_tensor_impl( input = flat_input.reshape(input_sizes_int64); // Get reference output - at::Tensor reference_out = torch::executor::native::dequantize_per_tensor_aten( - input, scale, zero_point, quant_min, quant_max, dtype, out_dtype); + at::Tensor reference_out = + torch::executor::native::dequantize_per_tensor_aten( 
+ input, scale, zero_point, quant_min, quant_max, dtype, out_dtype); // Build Vulkan dequantize_per_tensor graph using namespace vkcompute; @@ -684,6 +703,8 @@ void test_vulkan_dequantize_per_tensor_impl( << (in_storage == vkcompute::utils::kBuffer ? "buffer" : "texture") << std::endl; + std::cout << " input dtype: " << dtype << std::endl; + std::cout << " output dtype: " << out_dtype << std::endl; std::cout << "input:" << std::endl; std::cout << input << std::endl; @@ -697,7 +718,9 @@ void test_vulkan_dequantize_per_tensor_impl( } // Test cases for dequantize_per_tensor -TEST(VulkanDequantizePerTensorTest, test_reference_dequantize_per_tensor_uint8_to_float) { +TEST( + VulkanDequantizePerTensorTest, + test_reference_dequantize_per_tensor_uint8_to_float) { test_reference_dequantize_per_tensor( {2, 3, 4}, // input sizes 0.1, // scale @@ -708,7 +731,9 @@ TEST(VulkanDequantizePerTensorTest, test_reference_dequantize_per_tensor_uint8_t at::kFloat); // output dtype } -TEST(VulkanDequantizePerTensorTest, test_reference_dequantize_per_tensor_int8_to_float) { +TEST( + VulkanDequantizePerTensorTest, + test_reference_dequantize_per_tensor_int8_to_float) { test_reference_dequantize_per_tensor( {3, 4, 5}, // input sizes 0.05, // scale @@ -719,14 +744,43 @@ TEST(VulkanDequantizePerTensorTest, test_reference_dequantize_per_tensor_int8_to at::kFloat); // output dtype } -TEST(VulkanDequantizePerTensorTest, test_reference_dequantize_per_tensor_int16_to_float) { - test_reference_dequantize_per_tensor( - {2, 2, 3}, // input sizes - 0.001, // scale - -10, // zero_point - -32768, // quant_min - 32767, // quant_max - at::kShort, // input dtype +// Vulkan test cases for dequantize_per_tensor +TEST( + VulkanDequantizePerTensorTest, + test_vulkan_dequantize_per_tensor_uint8_to_float) { + test_vulkan_dequantize_per_tensor( + {2, 3, 4}, // input sizes + 0.1, // scale + 5, // zero_point + 0, // quant_min + 255, // quant_max + at::kByte, // input dtype + at::kFloat); // output dtype +} + 
+TEST( + VulkanDequantizePerTensorTest, + test_vulkan_dequantize_per_tensor_int8_to_float) { + test_vulkan_dequantize_per_tensor( + {3, 4, 5}, // input sizes + 0.05, // scale + 0, // zero_point + -128, // quant_min + 127, // quant_max + at::kChar, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTensorTest, + test_vulkan_dequantize_per_tensor_int32_to_float) { + test_vulkan_dequantize_per_tensor( + {2, 4, 3}, // input sizes + 0.0001, // scale + 100, // zero_point + -2147483648, // quant_min + 2147483647, // quant_max + at::kInt, // input dtype at::kFloat); // output dtype } @@ -756,7 +810,8 @@ void test_reference_dequantize_per_token( } else if (dtype == at::kChar) { input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar)); } else if (dtype == at::kShort) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort)); + input = + at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort)); } else if (dtype == at::kInt) { input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt)); } else { @@ -798,11 +853,23 @@ void test_reference_dequantize_per_token( // Get reference output at::Tensor reference_out = dequantize_per_token_reference_impl( - input, scale_tensor, zero_point_tensor, quant_min, quant_max, dtype, out_dtype); + input, + scale_tensor, + zero_point_tensor, + quant_min, + quant_max, + dtype, + out_dtype); // Get implementation output at::Tensor impl_out = torch::executor::native::dequantize_per_token_aten( - input, scale_tensor, zero_point_tensor, quant_min, quant_max, dtype, out_dtype); + input, + scale_tensor, + zero_point_tensor, + quant_min, + quant_max, + dtype, + out_dtype); // Compare outputs const bool output_correct = at::allclose(reference_out, impl_out, 1e-5, 1e-5); @@ -821,6 +888,8 @@ void test_reference_dequantize_per_token( std::cout << "" << std::endl; std::cout << " quant_min: " << quant_min << std::endl; std::cout << " quant_max: " << quant_max << 
std::endl; + std::cout << " input dtype: " << dtype << std::endl; + std::cout << " output dtype: " << out_dtype << std::endl; std::cout << "input:" << std::endl; std::cout << input << std::endl; @@ -861,7 +930,8 @@ void test_vulkan_dequantize_per_token_impl( } else if (dtype == at::kChar) { input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar)); } else if (dtype == at::kShort) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort)); + input = + at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort)); } else if (dtype == at::kInt) { input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt)); } else { @@ -903,7 +973,13 @@ void test_vulkan_dequantize_per_token_impl( // Get reference output at::Tensor reference_out = torch::executor::native::dequantize_per_token_aten( - input, scale_tensor, zero_point_tensor, quant_min, quant_max, dtype, out_dtype); + input, + scale_tensor, + zero_point_tensor, + quant_min, + quant_max, + dtype, + out_dtype); // Build Vulkan dequantize_per_token graph using namespace vkcompute; @@ -927,14 +1003,14 @@ void test_vulkan_dequantize_per_token_impl( VK_GET_OP_FN("dequantize_per_token.default") (graph, - { - r_input.value, - r_scale.value, - r_zero_point.value, - r_quant_min, - r_quant_max, - r_out, - }); + { + r_input.value, + r_scale.value, + r_zero_point.value, + r_quant_min, + r_quant_max, + r_out, + }); ValueRef staging_out = graph.set_output_tensor(r_out); @@ -988,6 +1064,8 @@ void test_vulkan_dequantize_per_token_impl( << (in_storage == vkcompute::utils::kBuffer ? 
"buffer" : "texture") << std::endl; + std::cout << " input dtype: " << dtype << std::endl; + std::cout << " output dtype: " << out_dtype << std::endl; std::cout << "input:" << std::endl; std::cout << input << std::endl; @@ -1001,7 +1079,9 @@ void test_vulkan_dequantize_per_token_impl( } // Test cases for dequantize_per_token -TEST(VulkanDequantizePerTokenTest, test_reference_dequantize_per_token_uint8_to_float) { +TEST( + VulkanDequantizePerTokenTest, + test_reference_dequantize_per_token_uint8_to_float) { std::vector scales = {0.1, 0.2, 0.3, 0.4, 0.5, 0.6}; std::vector zero_points = {5, 10, 15, 20, 25, 30}; @@ -1015,7 +1095,9 @@ TEST(VulkanDequantizePerTokenTest, test_reference_dequantize_per_token_uint8_to_ at::kFloat); // output dtype } -TEST(VulkanDequantizePerTokenTest, test_reference_dequantize_per_token_int8_to_float) { +TEST( + VulkanDequantizePerTokenTest, + test_reference_dequantize_per_token_int8_to_float) { std::vector scales = {0.05, 0.1, 0.15, 0.2}; std::vector zero_points = {0, -5, 5, 10}; @@ -1029,16 +1111,51 @@ TEST(VulkanDequantizePerTokenTest, test_reference_dequantize_per_token_int8_to_f at::kFloat); // output dtype } -TEST(VulkanDequantizePerTokenTest, test_reference_dequantize_per_token_int16_to_float) { - std::vector scales = {0.001, 0.002, 0.003, 0.004, 0.005, 0.006, 0.007, 0.008}; - std::vector zero_points = {-10, 0, 10, 20, -20, -15, 15, 25}; +// Vulkan test cases for dequantize_per_token +TEST( + VulkanDequantizePerTokenTest, + test_vulkan_dequantize_per_token_uint8_to_float) { + std::vector scales = {0.1, 0.2, 0.3, 0.4, 0.5, 0.6}; + std::vector zero_points = {5, 10, 15, 20, 25, 30}; - test_reference_dequantize_per_token( - {2, 4, 6}, // input sizes (2*4=8 tokens) + test_vulkan_dequantize_per_token( + {2, 3, 4}, // input sizes (2*3=6 tokens) + scales, + zero_points, + 0, // quant_min + 255, // quant_max + at::kByte, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTokenTest, + 
test_vulkan_dequantize_per_token_int8_to_float) { + std::vector scales = {0.05, 0.1, 0.15, 0.0}; + std::vector zero_points = {0, -5, 5, 10}; + + test_vulkan_dequantize_per_token( + {2, 2, 5}, // input sizes (2*2=4 tokens) + scales, + zero_points, + -128, // quant_min + 127, // quant_max + at::kChar, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTokenTest, + test_vulkan_dequantize_per_token_int32_to_float) { + std::vector scales = {0.0001, 0.0002, 0.0003, 0.0}; + std::vector zero_points = {100, -100, 50, -50}; + + test_vulkan_dequantize_per_token( + {2, 2, 8}, // input sizes (2*2=4 tokens) scales, zero_points, - -32768, // quant_min - 32767, // quant_max - at::kShort, // input dtype + -2147483648, // quant_min + 2147483647, // quant_max + at::kInt, // input dtype at::kFloat); // output dtype }