From 2e1cbf0a959c1dfaf7ff17489ffb773f681c5c3f Mon Sep 17 00:00:00 2001
From: morelos
Date: Thu, 26 Jun 2025 09:56:19 -0700
Subject: [PATCH] [ET-VK][Ops] linear_qta8a_qga4w_qta8o test framework
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

# Context

This test framework establishes the foundation for validating the `linear_qta8a_qga4w_qta8o` operator implementation as part of enabling dynamic quantization. The motivation stems from advancing beyond weight-only quantization to linear operations in which both activations and weights are quantized, enabling true integer arithmetic throughout the matrix multiplication process for improved performance on GPU hardware. The current weight-only quantized linear implementations in ET-VK dequantize weights to floating point before computation, missing the performance benefits of integer arithmetic.

The operator name breaks down as follows:
- **qta8a**: Quantized per-token affine 8-bit activation inputs
- **qga4w**: Quantized per-group affine 4-bit weights
- **qta8o**: Quantized per-token affine 8-bit outputs

# Changes

The reference implementation (`linear_qta8a_qga4w_qta8o_4bit_dequant_impl`) provides a baseline for validating the GPU shader implementation through a deliberately simplified computation path. The quantized int8 input tensor is dequantized using the standard affine transformation `(quantized_input.to(at::kFloat) - input_zero_point) * input_scale`. After dequantization, the implementation performs a standard floating-point linear operation, `at::linear(x_float, weights_dequantized)`, then manually quantizes the result using `std::round(linear_result / output_scale) + output_zero_point` with clamping to the int8 range [-128, 127]. This three-step approach of dequantize → compute → quantize provides a clear reference against which the GPU's integer arithmetic implementation can be validated.

Differential Revision: [D77173442](https://our.internmc.facebook.com/intern/diff/D77173442/)

[ghstack-poisoned]
---
 .../linear_qta8a_qga4w_qta8o_test.cpp         | 164 ++++++++++++++++++
 backends/vulkan/test/op_tests/targets.bzl     |   6 +
 2 files changed, 170 insertions(+)
 create mode 100644 backends/vulkan/test/op_tests/linear_qta8a_qga4w_qta8o_test.cpp

diff --git a/backends/vulkan/test/op_tests/linear_qta8a_qga4w_qta8o_test.cpp b/backends/vulkan/test/op_tests/linear_qta8a_qga4w_qta8o_test.cpp
new file mode 100644
index 00000000000..2496c256fb1
--- /dev/null
+++ b/backends/vulkan/test/op_tests/linear_qta8a_qga4w_qta8o_test.cpp
@@ -0,0 +1,164 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+
+#include
+
+#include
+#include
+#include
+
+#include "test_utils.h"
+
+#include
+
+// Expands packed 4-bit weights into one int32 element per 4-bit value.
+//
+// Every element of `weights_4x2` is read as a uint8_t that carries two 4-bit
+// values: the high nibble is written to column 2 * j and the low nibble to
+// column 2 * j + 1. The returned CPU tensor has the same size as
+// `weights_4x2` along dim 0 and twice its size along dim 1. Values are the raw
+// nibbles in [0, 15]; no zero-point offset is applied here.
+at::Tensor unpack_weights_4x2(const at::Tensor& weights_4x2) {
+  std::vector<int64_t> weights_shape(weights_4x2.sizes().vec());
+  weights_shape[1] *= 2;
+
+  at::Tensor weights_unpacked =
+      at::empty(weights_shape, at::device(at::kCPU).dtype(at::kInt));
+
+  const int64_t N = weights_unpacked.size(0);
+  const int64_t K = weights_unpacked.size(1);
+
+  // int64_t indices match the type of N and K.
+  for (int64_t n = 0; n < N; n++) {
+    for (int64_t k = 0; k < K; k += 2) {
+      const uint8_t packed_val = weights_4x2[n][k / 2].item().to<uint8_t>();
+      const uint8_t second_val = packed_val & 0x0F;
+      const uint8_t first_val = (packed_val & 0xF0) >> 4;
+
+      weights_unpacked[n][k] = int(first_val);
+      weights_unpacked[n][k + 1] = int(second_val);
+    }
+  }
+
+  return weights_unpacked;
+}
+
+// CPU reference for the linear_qta8a_qga4w_qta8o operator, computed as
+// dequantize -> floating point linear -> quantize.
+//
+// Steps:
+//   1. `quantized_input` is read as a 3-D int8 tensor [b][m][k] and
+//      dequantized per token: (q - input_zero_point[t]) * input_scale[t],
+//      with t = b * quantized_input.size(1) + m.
+//   2. `weights_4x2` is unpacked with unpack_weights_4x2() and dequantized
+//      per group: (w - 8) * scale + zero, where scale and zero are
+//      weight_scales_and_zeros[k / group_size][n][0] and [...][n][1].
+//   3. at::linear(x_float, weights_dequantized) is evaluated in float.
+//   4. The result is quantized per token:
+//      clamp(round(y / output_scale[t]) + output_zero_point[t], -128, 127),
+//      using std::round, which rounds halfway cases away from zero.
+//
+// Returns an at::kChar tensor with the shape of the at::linear result.
+// Fails a TORCH_CHECK if `quantized_input` is not 3-D or `group_size` is not
+// positive.
+at::Tensor linear_qta8a_qga4w_qta8o_4bit_dequant_impl(
+    const at::Tensor& quantized_input,
+    const at::Tensor& input_scale,
+    const at::Tensor& input_zero_point,
+    const at::Tensor& weights_4x2,
+    const int64_t group_size,
+    const at::Tensor& weight_scales_and_zeros,
+    const at::Tensor& output_scale,
+    const at::Tensor& output_zero_point) {
+  // The 3-D accessors and the token -> (b, m) mapping below require rank 3.
+  TORCH_CHECK(
+      quantized_input.dim() == 3,
+      "quantized_input must be 3-dimensional, got rank ",
+      quantized_input.dim());
+  // group_size is the divisor used to select a weight group.
+  TORCH_CHECK(group_size > 0, "group_size must be positive, got ", group_size);
+
+  // Per-token dequantization of the int8 activations.
+  at::Tensor x_float = at::zeros_like(quantized_input, at::kFloat);
+  auto input_accessor = quantized_input.accessor<int8_t, 3>();
+  auto x_float_accessor = x_float.accessor<float, 3>();
+
+  const int64_t tokens_per_batch = quantized_input.size(1);
+  const int64_t input_num_tokens = quantized_input.size(0) * tokens_per_batch;
+
+  for (int64_t token_idx = 0; token_idx < input_num_tokens; token_idx++) {
+    const float scale_val = input_scale[token_idx].item<float>();
+    const int zero_point_val = input_zero_point[token_idx].item<int>();
+
+    // Recover the batch and sequence indices of this token.
+    const int64_t b = token_idx / tokens_per_batch;
+    const int64_t m = token_idx % tokens_per_batch;
+
+    for (int64_t k = 0; k < quantized_input.size(-1); k++) {
+      x_float_accessor[b][m][k] =
+          (input_accessor[b][m][k] - zero_point_val) * scale_val;
+    }
+  }
+
+  // Unpack the 4-bit weights once, then apply the per-group parameters.
+  const at::Tensor weights_unpacked = unpack_weights_4x2(weights_4x2);
+  at::Tensor weights_dequantized = at::empty(
+      weights_unpacked.sizes(), at::device(at::kCPU).dtype(at::kFloat));
+  auto unpacked_accessor = weights_unpacked.accessor<int32_t, 2>();
+  auto dequant_accessor = weights_dequantized.accessor<float, 2>();
+
+  const int64_t N = weights_dequantized.size(0);
+  const int64_t K = weights_dequantized.size(1);
+
+  for (int64_t n = 0; n < N; n++) {
+    for (int64_t k = 0; k < K; k++) {
+      // The group is chosen per element, so the two values unpacked from one
+      // byte get the right parameters even when group_size is odd.
+      const int64_t group_idx = k / group_size;
+      const float scale =
+          weight_scales_and_zeros[group_idx][n][0].item().to<float>();
+      const float zero =
+          weight_scales_and_zeros[group_idx][n][1].item().to<float>();
+
+      dequant_accessor[n][k] =
+          (float(unpacked_accessor[n][k]) - 8.0) * scale + zero;
+    }
+  }
+
+  at::Tensor linear_result = at::linear(x_float, weights_dequantized);
+
+  // Per-token quantization of the float result to int8.
+  at::Tensor quantized_result = at::zeros_like(linear_result, at::kChar);
+  auto linear_accessor = linear_result.accessor<float, 3>();
+  auto quant_accessor = quantized_result.accessor<int8_t, 3>();
+
+  const int64_t out_tokens_per_batch = linear_result.size(1);
+  const int64_t output_num_tokens =
+      linear_result.size(0) * out_tokens_per_batch;
+
+  for (int64_t token_idx = 0; token_idx < output_num_tokens; token_idx++) {
+    const float scale_val = output_scale[token_idx].item<float>();
+    const int zero_point_val = output_zero_point[token_idx].item<int>();
+
+    // Recover the batch and sequence indices of this token.
+    const int64_t b = token_idx / out_tokens_per_batch;
+    const int64_t m = token_idx % out_tokens_per_batch;
+
+    for (int64_t n = 0; n < linear_result.size(-1); n++) {
+      float quant_val =
+          std::round(linear_accessor[b][m][n] / scale_val) + zero_point_val;
+      // Saturate to the int8 range before narrowing.
+      quant_val = std::clamp(quant_val, -128.0f, 127.0f);
+      quant_accessor[b][m][n] = static_cast<int8_t>(quant_val);
+    }
+  }
+
+  return quantized_result;
+}
diff --git a/backends/vulkan/test/op_tests/targets.bzl b/backends/vulkan/test/op_tests/targets.bzl
index 0d014c7ef29..76a8d5a95fd 100644
--- a/backends/vulkan/test/op_tests/targets.bzl
+++ b/backends/vulkan/test/op_tests/targets.bzl
@@ -210,6 +210,12 @@ def define_common_targets(is_fbcode = False):
             ":test_utils",
         ]
     )
+    define_test_targets(
+        "linear_qta8a_qga4w_qta8o_test",
+        extra_deps = [
+            ":test_utils",
+        ]
+    )
     define_test_targets(
         "rotary_embedding_test",
         extra_deps = [