From 6cb1faccdf3e62fa677f9ea7598e9e7d357cf141 Mon Sep 17 00:00:00 2001 From: morelos Date: Fri, 13 Jun 2025 15:49:22 -0700 Subject: [PATCH] [ET] enabling half dtype output for dequantization and making logic consistent Pull Request resolved: https://github.com/pytorch/executorch/pull/11552 # Context Currently the cpu implementation for the dequantization operator (which includes `dequantize_per_token`, `dequantize_per_tensor`, and `dequantize_per_channel`), does not inherently support half (fp16) input scalar types. In order to align with the PyTorch implementation that accepts fp16 and bfp16 inputs, this diff aims to enable half input dtype support for the quantization operators. We will be comparing this implementation against the vulkan operators. Furthermore, there is a casting bug when applying the zero_point, as only in the `dequantize_per_tensor` implementation does it cast the zero_point to int32, while comparitively for `dequantize_per_channel` and `dequantize_per_token` they do not cast the zero_point. In an environment that only supports 32bit integers, understandbly there will be some inconsistencies in dequantization logic as per_tensor will contain different overflow logic compared to its respective per_token and per_channel partner since the latter eliminates the overflow by utilizing a 64bit value. # Changes As defined in ExecuTorch [scalar_type_util.h](https://github.com/pytorch/executorch/blob/053686242c1687f0d51b3bb8befd14b047d7b025/runtime/core/exec_aten/util/scalar_type_util.h), the changes in this diff include adding a new macro `ET_FORALL_FLOATH_TYPES_WITH` to `util/scalar_type_util.h`, updating the `CALCULATE_INT_TYPE` macro to handle the new dtype. This enables support for Half (fp16), Float (fp32), and Double (fp64). I have also included more comprehensive testing against the input dtypes, including adding double testing since it didn't already exist before. Instead of just confirming that all the output dtypes are supported, we also have a check that all input dtypes are supported now as well. In order to align both dequantization implementations, we cast the zero_point to 32bit for both to maintain the overflow logic carried over from `dequantize_per_tensor`. ghstack-source-id: 290376483 @exported-using-ghexport Differential Revision: [D76289181](https://our.internmc.facebook.com/intern/diff/D76289181/) --- kernels/quantized/cpu/op_dequantize.cpp | 46 +++++----- kernels/quantized/test/op_dequantize_test.cpp | 90 +++++++++++++++++++ .../core/exec_aten/util/scalar_type_util.h | 5 ++ 3 files changed, 119 insertions(+), 22 deletions(-) diff --git a/kernels/quantized/cpu/op_dequantize.cpp b/kernels/quantized/cpu/op_dequantize.cpp index c1f2770d3d6..876099598dc 100644 --- a/kernels/quantized/cpu/op_dequantize.cpp +++ b/kernels/quantized/cpu/op_dequantize.cpp @@ -288,16 +288,16 @@ Tensor& dequantize_per_tensor_out( static_cast(scale)); \ } \ } break; -#define CALCULATE_INT_TYPE(IN_CTYPE, in_dtype) \ - case ScalarType::in_dtype: \ - switch (out.scalar_type()) { \ - ET_FORALL_FLOAT_TYPES_WITH(IN_CTYPE, DEQUANTIZE_IMPL); \ - default: \ - ET_CHECK_MSG( \ - false, \ - "Unhandled output dtype %" PRId8, \ - static_cast(out.scalar_type())); \ - } \ +#define CALCULATE_INT_TYPE(IN_CTYPE, in_dtype) \ + case ScalarType::in_dtype: \ + switch (out.scalar_type()) { \ + ET_FORALL_FLOATH_TYPES_WITH(IN_CTYPE, DEQUANTIZE_IMPL); \ + default: \ + ET_CHECK_MSG( \ + false, \ + "Unhandled output dtype %" PRId8, \ + static_cast(out.scalar_type())); \ + } \ break; switch (input.scalar_type()) { @@ -459,7 +459,8 @@ Tensor& dequantize_per_channel_out( } \ out_data_ptr[current_ix] = \ static_cast( \ - input_data_ptr[current_ix] - zero_point) * \ + input_data_ptr[current_ix] - \ + static_cast(zero_point)) * \ _scale; \ } \ }, \ @@ -478,23 +479,24 @@ Tensor& dequantize_per_channel_out( apply_over_dim_list( \ [input_data_ptr, out_data_ptr, _scale, _zero_point](size_t in_ix) { \ out_data_ptr[in_ix] = static_cast( \ - (input_data_ptr[in_ix] - _zero_point) * _scale); \ + (input_data_ptr[in_ix] - static_cast(_zero_point)) * \ + _scale); \ }, \ input, \ optional_dim_list, \ channel_ix); \ } \ break; -#define CALCULATE_FLOAT_TYPE(CTYPE_IN, in_dtype) \ - case ScalarType::in_dtype: \ - switch (out.scalar_type()) { \ - ET_FORALL_FLOAT_TYPES_WITH(CTYPE_IN, DEQUANTIZE_IMPL); \ - default: \ - ET_CHECK_MSG( \ - false, \ - "Unhandled output dtype %" PRId8, \ - static_cast(out.scalar_type())); \ - } \ +#define CALCULATE_FLOAT_TYPE(CTYPE_IN, in_dtype) \ + case ScalarType::in_dtype: \ + switch (out.scalar_type()) { \ + ET_FORALL_FLOATH_TYPES_WITH(CTYPE_IN, DEQUANTIZE_IMPL); \ + default: \ + ET_CHECK_MSG( \ + false, \ + "Unhandled output dtype %" PRId8, \ + static_cast(out.scalar_type())); \ + } \ break; switch (input.scalar_type()) { diff --git a/kernels/quantized/test/op_dequantize_test.cpp b/kernels/quantized/test/op_dequantize_test.cpp index bbda1590a10..4a0c195e3ab 100644 --- a/kernels/quantized/test/op_dequantize_test.cpp +++ b/kernels/quantized/test/op_dequantize_test.cpp @@ -67,6 +67,96 @@ TEST(OpDequantizeOutTest, AllDtypesSupported) { test_dtype(); } +/// Test all supported output dtypes for dequantization +template +void test_output_dtype() { + TensorFactory tf; + + Tensor input = tf.full({3, 5}, 100); + double scale = 0.5; + int64_t zero_point = 30; + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory tfo; + Tensor out = tfo.zeros({3, 5}); + // (100 - 30) * 0.5 = 35 + Tensor expected = tfo.full({3, 5}, 35); + dequantize_per_tensor_out( + input, + scale, + zero_point, + quant_min, + quant_max, + ScalarType::Byte, + optional(OUT_DTYPE), + out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpDequantizeOutTest, AllOutputDtypesSupported) { + et_pal_init(); + test_output_dtype(); + test_output_dtype(); + test_output_dtype(); +} + +TEST(OpDequantizeOutTest, HalfOutput) { + et_pal_init(); + TensorFactory tf; + + Tensor input = tf.full({3, 5}, 10); + double scale = 0.5; + int64_t zero_point = 100000; + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory tfo; + Tensor out = tfo.zeros({3, 5}); + // (10 - 100000) * 0.5 = -49995 + dequantize_per_tensor_out( + input, + scale, + zero_point, + quant_min, + quant_max, + ScalarType::Byte, + optional(ScalarType::Half), + out); + + // The expected result should be (10 - 100000) * 0.5 = -49995 + Tensor expected = tfo.full({3, 5}, -49995); + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpDequantizeOutTest, DoubleOutput) { + et_pal_init(); + TensorFactory tf; + + Tensor input = tf.full({3, 5}, 10); + double scale = 0.5; + int64_t zero_point = 100000; + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory tfo; + Tensor out = tfo.zeros({3, 5}); + dequantize_per_tensor_out( + input, + scale, + zero_point, + quant_min, + quant_max, + ScalarType::Byte, + optional(ScalarType::Double), + out); + + // The expected result should be (10 - 100000) * 0.5 = -49995 + Tensor expected = tfo.full({3, 5}, -49995); + EXPECT_TENSOR_EQ(out, expected); +} + TEST(OpDequantizeOutTest, NonWholeNumbers) { et_pal_init(); TensorFactory tf; diff --git a/runtime/core/exec_aten/util/scalar_type_util.h b/runtime/core/exec_aten/util/scalar_type_util.h index 6f81146e925..d81b3ad4d0f 100644 --- a/runtime/core/exec_aten/util/scalar_type_util.h +++ b/runtime/core/exec_aten/util/scalar_type_util.h @@ -199,6 +199,11 @@ ET_FORALL_SCALAR_TYPES(SPECIALIZE_CppTypeToScalarType) _(ANOTHER_INPUT, float, Float) \ _(ANOTHER_INPUT, double, Double) +#define ET_FORALL_FLOATH_TYPES_WITH(ANOTHER_INPUT, _) \ + _(ANOTHER_INPUT, float, Float) \ + _(ANOTHER_INPUT, double, Double) \ + _(ANOTHER_INPUT, ::executorch::aten::Half, Half) + #define ET_FORALL_FLOAT_TYPES_WITH2(ANOTHER_INPUT1, ANOTHER_INPUT2, _) \ _(ANOTHER_INPUT1, ANOTHER_INPUT2, float, Float) \ _(ANOTHER_INPUT1, ANOTHER_INPUT2, double, Double)