From 57b311aaecc13ba9939f02f5d04ddc91d9624420 Mon Sep 17 00:00:00 2001 From: morelos Date: Mon, 9 Jun 2025 08:04:53 -0700 Subject: [PATCH] [ET-VK][Ops] dequantize ops skeleton test framework Skeleton framework that is needed to build out the dequantize_per_tensor and dequantize_per_token operators based on cpu implementation Differential Revision: [D76267021](https://our.internmc.facebook.com/intern/diff/D76267021/) [ghstack-poisoned] --- .../vulkan/test/op_tests/dequantize_test.cpp | 296 ++++++++++++++++++ backends/vulkan/test/op_tests/targets.bzl | 8 + 2 files changed, 304 insertions(+) create mode 100644 backends/vulkan/test/op_tests/dequantize_test.cpp diff --git a/backends/vulkan/test/op_tests/dequantize_test.cpp b/backends/vulkan/test/op_tests/dequantize_test.cpp new file mode 100644 index 00000000000..c5c1ba5c2e9 --- /dev/null +++ b/backends/vulkan/test/op_tests/dequantize_test.cpp @@ -0,0 +1,296 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include +#include +#include + +#include +#include + +#include +#include + +namespace torch { +namespace executor { +namespace native { + +// Forward declarations of the functions we're testing +Tensor& dequantize_per_tensor_out( + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + executorch::aten::optional out_dtype, + Tensor& out); + +Tensor& dequantize_per_token_out( + const Tensor& input, + const Tensor& scale, + const Tensor& zero_points, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + ScalarType out_dtype, + Tensor& out); + +// Wrapper function for dequantize_per_tensor_out without context +Tensor& dequantize_per_tensor_out_no_context( + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + executorch::aten::optional out_dtype, + Tensor& out) { + return torch::executor::native::dequantize_per_tensor_out( + input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out); +} + +// Wrapper function for dequantize_per_token_out without context +Tensor& dequantize_per_token_out_no_context( + const Tensor& input, + const Tensor& scale, + const Tensor& zero_points, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + ScalarType out_dtype, + Tensor& out) { + return torch::executor::native::dequantize_per_token_out( + input, scale, zero_points, quant_min, quant_max, dtype, out_dtype, out); +} + +// ATen wrapper for dequantize_per_tensor +at::Tensor dequantize_per_tensor_aten( + const at::Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + auto out = at::empty_like(input, out_dtype); + // Convert at::ScalarType to executorch::ScalarType + ScalarType et_dtype; + ScalarType et_out_dtype; + + switch (dtype) { + case at::kByte: + et_dtype = ScalarType::Byte; + break; + case at::kChar: + et_dtype = ScalarType::Char; + break; + case at::kShort: + et_dtype = ScalarType::Short; + break; + case at::kInt: + et_dtype = ScalarType::Int; + break; + case at::kLong: + et_dtype = ScalarType::Long; + break; + default: + throw std::runtime_error("Unsupported dtype"); + } + + switch (out_dtype) { + case at::kFloat: + et_out_dtype = ScalarType::Float; + break; + case at::kDouble: + et_out_dtype = ScalarType::Double; + break; + default: + throw std::runtime_error("Unsupported out_dtype"); + } + + executorch::aten::optional opt_et_out_dtype(et_out_dtype); + + WRAP_TO_ATEN(dequantize_per_tensor_out_no_context, 7) + (input, scale, zero_point, quant_min, quant_max, et_dtype, opt_et_out_dtype, out); + return out; +} + +// ATen wrapper for dequantize_per_token +at::Tensor dequantize_per_token_aten( + const at::Tensor& input, + const at::Tensor& scale, + const at::Tensor& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + auto out = at::empty_like(input, out_dtype); + // Convert at::ScalarType to executorch::ScalarType + ScalarType et_dtype; + ScalarType et_out_dtype; + + switch (dtype) { + case at::kByte: + et_dtype = ScalarType::Byte; + break; + case at::kChar: + et_dtype = ScalarType::Char; + break; + case at::kShort: + et_dtype = ScalarType::Short; + break; + case at::kInt: + et_dtype = ScalarType::Int; + break; + case at::kLong: + et_dtype = ScalarType::Long; + break; + default: + throw std::runtime_error("Unsupported dtype"); + } + + switch (out_dtype) { + case at::kFloat: + et_out_dtype = ScalarType::Float; + break; + case at::kDouble: + et_out_dtype = ScalarType::Double; + break; + default: + throw std::runtime_error("Unsupported out_dtype"); + } + + WRAP_TO_ATEN(dequantize_per_token_out_no_context, 7) + (input, scale, zero_points, quant_min, quant_max, et_dtype, et_out_dtype, out); + return out; +} + +} // namespace native +} // namespace executor +} // namespace torch + + +// +// Test functions +// + +// Helper function to get the name of a ScalarType for better error messages +std::string scalar_type_name(c10::ScalarType dtype) { + switch (dtype) { + case c10::kLong: + return "c10::kLong"; + case c10::kShort: + return "c10::kShort"; + case c10::kComplexHalf: + return "c10::kComplexHalf"; + case c10::kComplexFloat: + return "c10::kComplexFloat"; + case c10::kComplexDouble: + return "c10::kComplexDouble"; + case c10::kBool: + return "c10::kBool"; + case c10::kQInt8: + return "c10::kQInt8"; + case c10::kQUInt8: + return "c10::kQUInt8"; + case c10::kQInt32: + return "c10::kQInt32"; + case c10::kBFloat16: + return "c10::kBFloat16"; + case c10::kQUInt4x2: + return "c10::kQUInt4x2"; + case c10::kQUInt2x4: + return "c10::kQUInt2x4"; + default: + return "Unknown(" + std::to_string(static_cast(dtype)) + ")"; + } +} + +vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) { + using namespace vkcompute; + switch (at_scalartype) { + case c10::kFloat: + return vkapi::kFloat; + case c10::kHalf: + return vkapi::kHalf; + case c10::kInt: + return vkapi::kInt; + case c10::kLong: + // We don't have inherent vkapi::kLong, use kInt instead + return vkapi::kInt; + case c10::kChar: + return vkapi::kChar; + case c10::kByte: + return vkapi::kByte; + case c10::kDouble: + return vkapi::kDouble; + case c10::kShort: + return vkapi::kShort; + case c10::kUInt16: + return vkapi::kUInt16; + default: + VK_THROW( + "Unsupported at::ScalarType: ", + scalar_type_name(at_scalartype), + " (", + static_cast(at_scalartype), + ")"); + } +} + +void check_dequantize_args( + int64_t quant_min, + int64_t quant_max, + c10::ScalarType in_dtype, + c10::ScalarType out_dtype) { + using namespace vkcompute; + + // Check that quant_min <= quant_max + VK_CHECK_COND( + quant_min <= quant_max, + "quant_min must be <= quant_max, got quant_min: ", + quant_min, + " quant_max: ", + quant_max); + + // Check that input dtype is a quantized type + switch (in_dtype) { + case c10::kByte: + case c10::kChar: + case c10::kShort: + case c10::kInt: + case c10::kLong: + break; + default: + VK_THROW( + "Unsupported input dtype: ", + scalar_type_name(in_dtype), + " (", + static_cast(in_dtype), + ")"); + } + + // Check that output dtype is a floating point type + switch (out_dtype) { + case c10::kFloat: + case c10::kDouble: + break; + default: + VK_THROW( + "Unsupported output dtype: ", + scalar_type_name(out_dtype), + " (", + static_cast(out_dtype), + ")"); + } +} diff --git a/backends/vulkan/test/op_tests/targets.bzl b/backends/vulkan/test/op_tests/targets.bzl index aa7a3b68ec2..f8da9b7e3e9 100644 --- a/backends/vulkan/test/op_tests/targets.bzl +++ b/backends/vulkan/test/op_tests/targets.bzl @@ -156,5 +156,13 @@ def define_common_targets(is_fbcode = False): "//executorch/extension/aten_util:aten_bridge", ] ) + define_test_targets( + "dequantize_test", + extra_deps = [ + "//executorch/kernels/quantized/cpu:op_dequantize", + "//executorch/extension/tensor:tensor", + "//executorch/extension/aten_util:aten_bridge", + ] + ) define_test_targets("linear_weight_int4_test") define_test_targets("rotary_embedding_test")