From 3ed3dd8593482f2702d239241181ba3d94b0dbcb Mon Sep 17 00:00:00 2001 From: morelos Date: Thu, 3 Jul 2025 11:17:13 -0700 Subject: [PATCH] [ET] correcting cpu ref quantize_per_channel logic to align with ATen # Context The quantize_per_channel was not perfectly aligned with the ATen implementation, and demonstrated errors when specifying different axis. This bug wasn't distinctly acknowledged given that the test cases only has one test for the whole operator. In order to align more closely with ATen this change simply does a single loop imlpementation with direct channel index calculation over the old `apply_over_dim_list` approach. # Changes We change the core logic for quantize_per_channel to more properly align with ATen's implementation, and we also change it from `apply_over_dim_list` approach to a single loop implementation with direct channel index calculation. This also adds more comprehensive testing for quantize_per_channel so that a bug isn't missed again. Differential Revision: [D77746130](https://our.internmc.facebook.com/intern/diff/D77746130/) [ghstack-poisoned] --- kernels/quantized/cpu/op_quantize.cpp | 75 +++--- kernels/quantized/test/op_quantize_test.cpp | 240 ++++++++++++++++++++ 2 files changed, 271 insertions(+), 44 deletions(-) diff --git a/kernels/quantized/cpu/op_quantize.cpp b/kernels/quantized/cpu/op_quantize.cpp index d0b7c882f8e..b78348986c4 100644 --- a/kernels/quantized/cpu/op_quantize.cpp +++ b/kernels/quantized/cpu/op_quantize.cpp @@ -282,55 +282,42 @@ Tensor& quantize_per_channel_out( check_quantize_per_tensor_args(input, quant_min, quant_max, dtype, out); - // a list contains all dimensions except axis - int64_t dims[kTensorDimensionLimit]; - for (int64_t i = 0; i < input.dim() - 1; i++) { - if (i < axis) { - dims[i] = i; - } else { - dims[i] = i - 1; - } - } const double* scale_data = scale.const_data_ptr(); const int64_t* zero_point_data = zero_point.const_data_ptr(); - std::optional> optional_dim_list{ - executorch::aten::ArrayRef{dims, size_t(input.dim() - 1)}}; - - // Actual quantization logic - // input, out are the input and output tensors - // channel_ix is the index along the axis dimension. 0 <= channel_ix < - // input.size(axis). - // i.e. if the tensor has shape (N,C,H,W), axis being 1, then channel_ix - // will be 0, 1, 2, ... C-1 - // in_ix is the flat index of the element you are quantizing. - // in other words you are quantizing in_data[in_ix] + // High-performance single loop with direct channel calculation #define QUANTIZE_IMPL(CTYPE_IN, CTYPE_OUT, out_dtype) \ - case ScalarType::out_dtype: \ - for (size_t channel_ix = 0; channel_ix < input.size(axis); ++channel_ix) { \ - double _scale = scale_data[channel_ix]; \ - int64_t _zero_point = zero_point_data[channel_ix]; \ - auto* out_data_ptr = out.mutable_data_ptr(); \ - const auto* input_data_ptr = input.const_data_ptr(); \ - apply_over_dim_list( \ - [input_data_ptr, \ - out_data_ptr, \ - _scale, \ - _zero_point, \ - quant_min, \ - quant_max](size_t in_ix) { \ - out_data_ptr[in_ix] = quantize_val( \ - _scale, \ - _zero_point, \ - input_data_ptr[in_ix], \ - quant_min, \ - quant_max); \ - }, \ - input, \ - optional_dim_list, \ - channel_ix); \ + case ScalarType::out_dtype: { \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + const int64_t input_numel = input.numel(); \ + const int64_t axis_size = input.size(axis); \ + \ + /* Calculate the stride pattern for efficient channel index calculation */ \ + int64_t axis_block_size = 1; \ + for (int64_t i = axis + 1; i < input.dim(); i++) { \ + axis_block_size *= input.size(i); \ } \ - break; + \ + /* Single loop over all elements */ \ + for (int64_t i = 0; i < input_numel; i++) { \ + /* Calculate which channel this element belongs to */ \ + int64_t channel_idx = (i / axis_block_size) % axis_size; \ + \ + /* Get quantization parameters for this channel */ \ + double _scale = scale_data[channel_idx]; \ + int64_t _zero_point = zero_point_data[channel_idx]; \ + \ + /* Apply quantization */ \ + out_data_ptr[i] = quantize_val( \ + _scale, \ + _zero_point, \ + input_data_ptr[i], \ + quant_min, \ + quant_max); \ + } \ + } break; + #define CALCULATE_FLOAT_TYPE(CTYPE_IN, in_dtype) \ case ScalarType::in_dtype: \ switch (out.scalar_type()) { \ diff --git a/kernels/quantized/test/op_quantize_test.cpp b/kernels/quantized/test/op_quantize_test.cpp index 5cd17223d80..4ac835c24ce 100644 --- a/kernels/quantized/test/op_quantize_test.cpp +++ b/kernels/quantized/test/op_quantize_test.cpp @@ -206,3 +206,243 @@ TEST(OpQuantizeOutTest, QuantizePerChannel) { EXPECT_TENSOR_EQ(out, expected); } + +TEST(OpQuantizeOutTest, QuantizePerChannelAxis0) { + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + Tensor input = tf_float.full({3, 2}, 4); + Tensor scale = tf_double.make({3}, {0.5, 1.0, 2.0}); + Tensor zero_point = tf_long.make({3}, {100, 50, 25}); + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory tfo; + Tensor out = tfo.zeros({3, 2}); + // Channel 0: 4 / 0.5 + 100 = 108 + // Channel 1: 4 / 1.0 + 50 = 54 + // Channel 2: 4 / 2.0 + 25 = 27 + Tensor expected = tfo.make({3, 2}, {108, 108, 54, 54, 27, 27}); + quantize_per_channel_out( + input, scale, zero_point, 0, quant_min, quant_max, ScalarType::Byte, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, QuantizePerChannel3D) { + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + // Test 3D tensor with axis=1 (middle dimension) + Tensor input = tf_float.full({2, 3, 4}, 6); + Tensor scale = tf_double.make({3}, {0.5, 1.0, 1.5}); + Tensor zero_point = tf_long.make({3}, {10, 20, 30}); + int64_t quant_min = -128; + int64_t quant_max = 127; + + TensorFactory tfo; + Tensor out = tfo.zeros({2, 3, 4}); + // Channel 0: 6 / 0.5 + 10 = 22 + // Channel 1: 6 / 1.0 + 20 = 26 + // Channel 2: 6 / 1.5 + 30 = 34 + Tensor expected = tfo.make( + {2, 3, 4}, + { + 22, 22, 22, 22, // First batch, channel 0 + 26, 26, 26, 26, // First batch, channel 1 + 34, 34, 34, 34, // First batch, channel 2 + 22, 22, 22, 22, // Second batch, channel 0 + 26, 26, 26, 26, // Second batch, channel 1 + 34, 34, 34, 34 // Second batch, channel 2 + }); + quantize_per_channel_out( + input, scale, zero_point, 1, quant_min, quant_max, ScalarType::Char, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, QuantizePerChannel4D) { + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + // Test 4D tensor with axis=2 (typical conv weight layout: N,C,H,W) + Tensor input = tf_float.full({2, 2, 3, 2}, 8); + Tensor scale = tf_double.make({3}, {0.25, 0.5, 1.0}); + Tensor zero_point = tf_long.make({3}, {0, 10, 20}); + int64_t quant_min = -128; + int64_t quant_max = 127; + + TensorFactory tfo; + Tensor out = tfo.zeros({2, 2, 3, 2}); + // Channel 0: 8 / 0.25 + 0 = 32 + // Channel 1: 8 / 0.5 + 10 = 26 + // Channel 2: 8 / 1.0 + 20 = 28 + std::vector expected_data; + for (int n = 0; n < 2; n++) { + for (int c = 0; c < 2; c++) { + for (int h = 0; h < 3; h++) { + for (int w = 0; w < 2; w++) { + int8_t val = (h == 0) ? 32 : (h == 1) ? 26 : 28; + expected_data.push_back(val); + } + } + } + } + Tensor expected = tfo.make({2, 2, 3, 2}, expected_data); + quantize_per_channel_out( + input, scale, zero_point, 2, quant_min, quant_max, ScalarType::Char, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, QuantizePerChannelNegativeAxis) { + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + Tensor input = tf_float.full({2, 3}, 5); + Tensor scale = tf_double.make({3}, {0.5, 1.0, 2.0}); + Tensor zero_point = tf_long.make({3}, {0, 10, 20}); + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory tfo; + Tensor out = tfo.zeros({2, 3}); + // Using axis=-1 should be equivalent to axis=1 for 2D tensor + // Channel 0: 5 / 0.5 + 0 = 10 + // Channel 1: 5 / 1.0 + 10 = 15 + // Channel 2: 5 / 2.0 + 20 = 22 (rounded from 22.5) + Tensor expected = tfo.make({2, 3}, {10, 15, 22, 10, 15, 22}); + quantize_per_channel_out( + input, + scale, + zero_point, + -1, + quant_min, + quant_max, + ScalarType::Byte, + out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, QuantizePerChannelSingleChannel) { + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + Tensor input = tf_float.full({3, 1, 4}, 7); + Tensor scale = tf_double.make({1}, {0.5}); + Tensor zero_point = tf_long.make({1}, {128}); + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory tfo; + Tensor out = tfo.zeros({3, 1, 4}); + // Single channel: 7 / 0.5 + 128 = 142 + Tensor expected = tfo.full({3, 1, 4}, 142); + quantize_per_channel_out( + input, scale, zero_point, 1, quant_min, quant_max, ScalarType::Byte, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, QuantizePerChannelDifferentInputTypes) { + TensorFactory tf_double_input; + TensorFactory tf_double; + TensorFactory tf_long; + + Tensor input = tf_double_input.full({2, 2}, 3.14159); + Tensor scale = tf_double.make({2}, {0.01, 0.02}); + Tensor zero_point = tf_long.make({2}, {0, 100}); + int64_t quant_min = -128; + int64_t quant_max = 127; + + TensorFactory tfo; + Tensor out = tfo.zeros({2, 2}); + // Channel 0: 3.14159 / 0.01 + 0 = 314 -> clamped to 127 + // Channel 1: 3.14159 / 0.02 + 100 = 257 -> clamped to 127 + Tensor expected = tfo.make({2, 2}, {127, 127, 127, 127}); + quantize_per_channel_out( + input, scale, zero_point, 1, quant_min, quant_max, ScalarType::Char, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, QuantizePerChannelDifferentOutputTypes) { + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + Tensor input = tf_float.full({2, 2}, 10); + Tensor scale = tf_double.make({2}, {1.0, 2.0}); + Tensor zero_point = tf_long.make({2}, {1000, 2000}); + int64_t quant_min = -32768; + int64_t quant_max = 32767; + + // Test with 16-bit output + TensorFactory tfo; + Tensor out = tfo.zeros({2, 2}); + // Channel 0: 10 / 1.0 + 1000 = 1010 + // Channel 1: 10 / 2.0 + 2000 = 2005 + Tensor expected = tfo.make({2, 2}, {1010, 2005, 1010, 2005}); + quantize_per_channel_out( + input, + scale, + zero_point, + 1, + quant_min, + quant_max, + ScalarType::Short, + out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, QuantizePerChannelMixedValues) { + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + // Test with different input values per position + Tensor input = tf_float.make({2, 3}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}); + Tensor scale = tf_double.make({3}, {0.5, 1.0, 1.5}); + Tensor zero_point = tf_long.make({3}, {10, 20, 30}); + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory tfo; + Tensor out = tfo.zeros({2, 3}); + // Row 0: [1.0/0.5+10, 2.0/1.0+20, 3.0/1.5+30] = [12, 22, 32] + // Row 1: [4.0/0.5+10, 5.0/1.0+20, 6.0/1.5+30] = [18, 25, 34] + Tensor expected = tfo.make({2, 3}, {12, 22, 32, 18, 25, 34}); + quantize_per_channel_out( + input, scale, zero_point, 1, quant_min, quant_max, ScalarType::Byte, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, QuantizePerChannelClampingBehavior) { + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + // Test values that will exceed quant_min/quant_max bounds + Tensor input = tf_float.make({1, 3}, {-100.0, 0.0, 100.0}); + Tensor scale = tf_double.make({3}, {1.0, 1.0, 1.0}); + Tensor zero_point = tf_long.make({3}, {0, 0, 0}); + int64_t quant_min = -10; + int64_t quant_max = 10; + + TensorFactory tfo; + Tensor out = tfo.zeros({1, 3}); + // Values: [-100, 0, 100] should be clamped to [-10, 0, 10] + Tensor expected = tfo.make({1, 3}, {-10, 0, 10}); + quantize_per_channel_out( + input, scale, zero_point, 1, quant_min, quant_max, ScalarType::Char, out); + + EXPECT_TENSOR_EQ(out, expected); +}