From 95ee132051ed3555abd81d0626cff1585bf258b6 Mon Sep 17 00:00:00 2001 From: morelos Date: Fri, 13 Jun 2025 15:49:29 -0700 Subject: [PATCH] [ET-VK][Ops] dequantize ops skeleton test framework Pull Request resolved: https://github.com/pytorch/executorch/pull/11480 # Context In this diff we plan on creating the skeleton test framework for dequantization. This is necessary as we need a reference to test our vulkan implementation of the dequantization operator against an existing cpu implementation. This test framework is heavily inspired by [sdpa_test.cpp](https://github.com/pytorch/executorch/blob/main/backends/vulkan/test/op_tests/sdpa_test.cpp). We make use of the [op_dequantize.cpp](https://github.com/pytorch/executorch/blob/main/kernels/quantized/cpu/op_dequantize.cpp) cpu implementation of the `dequantize_per_tensor`, and the `dequantize_per_token` operators. An explanation for the operator is included where the actual vulkan implementation is created in a future diff along this stack. # Changes The main thing in this difference is the creation of a new test framework `dequantize_test.cpp`, and also including it in targets.bzl such that we can properly call the test. As this is inspired by sdpa_test.cpp, we also follow a similar format. First we have forward declarations of the functions that we wish to test against (dequantize_per_tensor, and dequantize_per_token). Then we also have wrappers for the functions without context, and finally wrappers for the ATen implementations of the same operators using the `WRAP_TO_ATEN` macro. We don't need context as this is merely for testing. We also have a utility function to test the quantize arguments that will be used when actually using the vulkan implementation. This utility function is just for a sanity check. ghstack-source-id: 290376489 @exported-using-ghexport Differential Revision: [D76267021](https://our.internmc.facebook.com/intern/diff/D76267021/) --- .../vulkan/test/op_tests/dequantize_test.cpp | 182 ++++++++++++++++++ backends/vulkan/test/op_tests/targets.bzl | 9 + 2 files changed, 191 insertions(+) create mode 100644 backends/vulkan/test/op_tests/dequantize_test.cpp diff --git a/backends/vulkan/test/op_tests/dequantize_test.cpp b/backends/vulkan/test/op_tests/dequantize_test.cpp new file mode 100644 index 00000000000..fe9b82f91d9 --- /dev/null +++ b/backends/vulkan/test/op_tests/dequantize_test.cpp @@ -0,0 +1,182 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include +#include +#include + +#include +#include + +#include "test_utils.h" + +#include +#include + +namespace torch { +namespace executor { +namespace native { + +// Forward declarations of the functions we're testing +Tensor& dequantize_per_tensor_out( + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + executorch::aten::optional out_dtype, + Tensor& out); + +Tensor& dequantize_per_token_out( + const Tensor& input, + const Tensor& scale, + const Tensor& zero_points, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + ScalarType out_dtype, + Tensor& out); + +// Wrapper function for dequantize_per_tensor_out without context +Tensor& dequantize_per_tensor_out_no_context( + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + executorch::aten::optional out_dtype, + Tensor& out) { + return torch::executor::native::dequantize_per_tensor_out( + input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out); +} + +// Wrapper function for dequantize_per_token_out without context +Tensor& dequantize_per_token_out_no_context( + const Tensor& input, + const Tensor& scale, + const Tensor& zero_points, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + ScalarType out_dtype, + Tensor& out) { + return torch::executor::native::dequantize_per_token_out( + input, scale, zero_points, quant_min, quant_max, dtype, out_dtype, out); +} + +// ATen wrapper for dequantize_per_tensor +at::Tensor dequantize_per_tensor_aten( + const at::Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + auto out = at::empty_like(input, out_dtype); + // Convert at::ScalarType to executorch::ScalarType + ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); + ScalarType et_out_dtype = at_scalartype_to_et_scalartype(out_dtype); + + executorch::aten::optional opt_et_out_dtype(et_out_dtype); + + WRAP_TO_ATEN(dequantize_per_tensor_out_no_context, 7) + (input, + scale, + zero_point, + quant_min, + quant_max, + et_dtype, + opt_et_out_dtype, + out); + return out; +} + +// ATen wrapper for dequantize_per_token +at::Tensor dequantize_per_token_aten( + const at::Tensor& input, + const at::Tensor& scale, + const at::Tensor& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + auto out = at::empty_like(input, out_dtype); + // Convert at::ScalarType to executorch::ScalarType + ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); + ScalarType et_out_dtype = at_scalartype_to_et_scalartype(out_dtype); + + WRAP_TO_ATEN(dequantize_per_token_out_no_context, 7) + (input, + scale, + zero_points, + quant_min, + quant_max, + et_dtype, + et_out_dtype, + out); + return out; +} + +} // namespace native +} // namespace executor +} // namespace torch + +void check_dequantize_args( + int64_t quant_min, + int64_t quant_max, + c10::ScalarType in_dtype, + c10::ScalarType out_dtype) { + using namespace vkcompute; + + // Check that quant_min <= quant_max + VK_CHECK_COND( + quant_min <= quant_max, + "quant_min must be <= quant_max, got quant_min: ", + quant_min, + " quant_max: ", + quant_max); + + // Check that input dtype is a quantized type + switch (in_dtype) { + case c10::kByte: + case c10::kChar: + case c10::kShort: + case c10::kInt: + case c10::kLong: + break; + default: + VK_THROW( + "Unsupported input dtype: ", + scalar_type_name(in_dtype), + " (", + static_cast(in_dtype), + ")"); + } + + // Check that output dtype is a floating point type + switch (out_dtype) { + case c10::kHalf: + case c10::kFloat: + case c10::kDouble: + break; + default: + VK_THROW( + "Unsupported output dtype: ", + scalar_type_name(out_dtype), + " (", + static_cast(out_dtype), + ")"); + } +} diff --git a/backends/vulkan/test/op_tests/targets.bzl b/backends/vulkan/test/op_tests/targets.bzl index cb5b49c3900..a22f2323896 100644 --- a/backends/vulkan/test/op_tests/targets.bzl +++ b/backends/vulkan/test/op_tests/targets.bzl @@ -186,6 +186,15 @@ def define_common_targets(is_fbcode = False): "//executorch/extension/aten_util:aten_bridge", ] ) + define_test_targets( + "dequantize_test", + extra_deps = [ + ":test_utils", + "//executorch/kernels/quantized/cpu:op_dequantize", + "//executorch/extension/tensor:tensor", + "//executorch/extension/aten_util:aten_bridge", + ] + ) define_test_targets( "linear_weight_int4_test", extra_deps = [