From 8e404c79f8c0f310fccf780b2e0ada1d87e47964 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Mon, 8 Jun 2026 22:15:45 -0700 Subject: [PATCH 01/15] [cuda] int4: coalesced scale/zero layout baked at pack time for W4A8 decode Coalesce int4 W4A8 decode-matvec scale/zero loads by baking the [N, n_groups] layout into the weight constant at pack time. Introduces CudaCoalescedInt4Tensor (an ExecuTorch-internal subclass) that owns the [n_groups, N] -> [N, n_groups] transpose, registers the int4_plain_mm dispatch on it by type, and adds the coalesced dp4a matvec kernel that reads scale/zero row-for-row with qdata (single coalesced load vs 32 stride-N cache lines). ~29.2 -> 37.4 tok/s on gemma group_size=32. Rebased onto main; INT8 dp4a decode op and the floor_div pass from this branch landed separately and now live in quantize_op_dispatch/. --- backends/cuda/coalesced_int4_tensor.py | 119 +++++++++++++++++ .../cuda/quantize_op_dispatch/__init__.py | 4 +- .../quantize_op_dispatch/int4_dispatch.py | 42 ++++-- backends/cuda/runtime/shims/int4_plain_mm.cu | 37 ++++- backends/cuda/runtime/shims/int4_plain_mm.cuh | 59 ++++---- .../test_aoti_torch_cuda_int4_plain_mm.cpp | 104 +++++++++++---- backends/cuda/tests/test_int4_dispatch.py | 126 +++++++++++++++++- examples/models/gemma4_31b/quant/pack_cuda.py | 21 ++- 8 files changed, 436 insertions(+), 76 deletions(-) create mode 100644 backends/cuda/coalesced_int4_tensor.py diff --git a/backends/cuda/coalesced_int4_tensor.py b/backends/cuda/coalesced_int4_tensor.py new file mode 100644 index 00000000000..a623f7f41c4 --- /dev/null +++ b/backends/cuda/coalesced_int4_tensor.py @@ -0,0 +1,119 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""ExecuTorch-internal INT4 tensor for the CUDA W4A8 dp4a decode kernel. 
+ +``CudaCoalescedInt4Tensor`` is an ExecuTorch-internal tensor subclass. It is +**NOT** torchao's ``Int4Tensor`` and is intentionally not a subclass of it, so +torchao's ``Int4Tensor`` F.linear handlers never match it via the method +resolution order. The CUDA decode/prefill dispatch (``int4_dispatch.py``) is +selected by *type* — it is registered on this class only — so stock +``Int4Tensor`` weights keep falling back to torchao's default (mslk/tinygemm) +path. + +Layout difference from torchao ``Int4Tensor``: + qdata : packed int4 weight (N, K/2), nibble-packed (same as Int4Tensor) + scale : (N, n_groups) — the *coalesced* layout, transposed from + torchao's documented (n_groups, N) + zero_point : (N, n_groups) — coalesced, transposed from (n_groups, N) + +The coalesced [N, n_groups] layout is exactly what the W4A8 dp4a matvec kernel +(``executorch_cuda::int4_plain_mm`` / ``int4_plain_mm.cuh``) reads row-for-row +with qdata, so the exported decode graph carries no per-step transpose. The +transpose is owned by :meth:`from_int4_tensor` so it is baked into the +serialized weight constant once at pack time. +""" + +from typing import List, Optional + +import torch +from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor +from torchao.utils import TorchAOBaseTensor + +__all__ = [ + "CudaCoalescedInt4Tensor", +] + + +class CudaCoalescedInt4Tensor(TorchAOBaseTensor): + """INT4 weight with scale/zero_point in the coalesced [N, n_groups] layout. + + ExecuTorch-internal; see the module docstring. Mirrors torchao + ``Int4Tensor``'s data/attribute layout (so the common tensor utilities and + serialization work) but owns the [n_groups, N] -> [N, n_groups] transpose + of scale/zero_point via :meth:`from_int4_tensor`. 
+ """ + + tensor_data_names = ["qdata", "scale", "zero_point"] + tensor_attribute_names = ["block_size", "shape"] + optional_tensor_data_names = ["act_pre_scale"] + optional_tensor_attribute_names = ["activation_dtype"] + + def __new__( + cls, + qdata: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + block_size: List[int], + shape: torch.Size, + act_pre_scale: Optional[torch.Tensor] = None, + activation_dtype: Optional[torch.dtype] = None, + ): + kwargs = {} + kwargs["device"] = qdata.device + kwargs["dtype"] = scale.dtype + kwargs["requires_grad"] = False + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + qdata: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + block_size: List[int], + shape: torch.Size, + act_pre_scale: Optional[torch.Tensor] = None, + activation_dtype: Optional[torch.dtype] = None, + ): + super().__init__() + self.qdata = qdata + self.scale = scale + self.zero_point = zero_point + self.block_size = block_size + self.activation_dtype = ( + activation_dtype if activation_dtype is not None else torch.bfloat16 + ) + self.act_pre_scale = act_pre_scale + + def _quantization_type(self): + s = f"shape={self.shape}, block_size={self.block_size}, device={self.device}, activation_dtype={self.activation_dtype}" + if self.act_pre_scale is not None: + s += f", act_pre_scale.shape={self.act_pre_scale.shape}" + return s + + @classmethod + def from_int4_tensor(cls, t: Int4Tensor) -> "CudaCoalescedInt4Tensor": + """Build a coalesced tensor from a torchao ``Int4Tensor``. + + Owns the transpose: torchao stores scale/zero_point as (n_groups, N); + the CUDA decode kernel reads (N, n_groups). The ``.t().contiguous()`` + here is baked into the serialized weight constant so the exported + decode graph has no per-step transpose/clone. 
+ """ + return cls( + t.qdata, + t.scale.t().contiguous(), + t.zero_point.t().contiguous(), + t.block_size, + t.shape, + t.act_pre_scale, + t.activation_dtype, + ) + + +# Allow a model with CudaCoalescedInt4Tensor weights to be loaded with +# `weights_only=True` (mirrors torchao Int4Tensor). +torch.serialization.add_safe_globals([CudaCoalescedInt4Tensor]) diff --git a/backends/cuda/quantize_op_dispatch/__init__.py b/backends/cuda/quantize_op_dispatch/__init__.py index 2248ef0b5c1..005c2b6e7c7 100644 --- a/backends/cuda/quantize_op_dispatch/__init__.py +++ b/backends/cuda/quantize_op_dispatch/__init__.py @@ -10,8 +10,8 @@ weight tensors so that torch.export traces through ExecuTorch's custom ops and dequant logic instead of torchao's defaults. It registers: - * INT4 (``Int4Tensor``) → ``executorch_cuda::int4_plain_mm`` - * INT8 (``IntxUnpackedToInt8Tensor``) → ``executorch_cuda::int8_plain_mm`` + * INT4 (``CudaCoalescedInt4Tensor``) → ``executorch_cuda::int4_plain_mm`` + * INT8 (``IntxUnpackedToInt8Tensor``) → ``executorch_cuda::int8_plain_mm`` See ``int4_dispatch`` and ``int8_dispatch`` for the per-dtype details. diff --git a/backends/cuda/quantize_op_dispatch/int4_dispatch.py b/backends/cuda/quantize_op_dispatch/int4_dispatch.py index 27f491fef06..c3b8921e2fe 100644 --- a/backends/cuda/quantize_op_dispatch/int4_dispatch.py +++ b/backends/cuda/quantize_op_dispatch/int4_dispatch.py @@ -4,12 +4,14 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -"""Int4Tensor F.linear dispatch for CUDA — runs at eager / export trace time. +"""CudaCoalescedInt4Tensor F.linear dispatch for CUDA — runs at eager / export trace time. -This module overrides Int4Tensor's F.linear dispatch so that torch.export -traces through our custom op and dequant logic instead of torchao's default -(mslk/tinygemm). 
The code here executes during eager inference and during -AOTI export tracing — it does NOT run at .pte runtime. +This module registers an F.linear dispatch on ``CudaCoalescedInt4Tensor`` (an +ExecuTorch-internal subclass, see ``coalesced_int4_tensor.py``) so that +torch.export traces through our custom op and dequant logic. Routing is by +*type*: stock torchao ``Int4Tensor`` weights are left untouched and keep using +torchao's default (mslk/tinygemm) path. The code here executes during eager +inference and during AOTI export tracing — it does NOT run at .pte runtime. At .pte runtime, the captured graph is executed by the AOTI-generated .so: - The custom op ``executorch_cuda::int4_plain_mm`` maps to a C shim that @@ -22,17 +24,17 @@ Prefill (M>4): Inline dequant + F.linear (standard PyTorch ops) Importing the parent ``quantize_op_dispatch`` package registers this dispatch -override (along with the INT8 one) before using nn.Linear with Int4Tensor -weights:: +override (along with the INT8 one) before using nn.Linear with +CudaCoalescedInt4Tensor weights:: import executorch.backends.cuda.quantize_op_dispatch # noqa: F401 """ import torch import torch.nn.functional as F +from executorch.backends.cuda.coalesced_int4_tensor import CudaCoalescedInt4Tensor from executorch.backends.cuda.quantize_op_dispatch._library import lib as _lib from torch.library import impl -from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor # --------------------------------------------------------------------------- # Custom op for decode (M=1): dp4a matvec in C shim, dequant+F.linear in eager @@ -52,11 +54,18 @@ def _meta(self, qdata, scale, zero, group_size): @impl(_lib, "int4_plain_mm", "CUDA") def _cuda(self, qdata, scale, zero, group_size): + # scale/zero are stored in the coalesced [N, n_groups] layout (transposed + # at pack time, see pack_cuda.pack_linear_for_cuda), which is exactly what + # _dequant_matmul expects. 
return _dequant_matmul(self, qdata, scale, zero, group_size) def _dequant_matmul(x, qdata, scale, zero, group_size): - """Dequant INT4 weights to input dtype and call F.linear.""" + """Dequant INT4 weights to input dtype and call F.linear. + + scale/zero are in the coalesced [N, n_groups] layout (baked into the + weight constant at pack time), aligned row-for-row with qdata's [N, *]. + """ N, K_half = qdata.shape K = K_half * 2 n_groups = K // group_size @@ -68,20 +77,20 @@ def _dequant_matmul(x, qdata, scale, zero, group_size): high = ((p >> 4) & 0x0F).to(dtype) data = torch.stack([low, high], dim=-1).reshape(N, n_groups, group_size) - s = scale.to(dtype).t().unsqueeze(-1) - z = zero.to(dtype).t().unsqueeze(-1) + s = scale.to(dtype).unsqueeze(-1) + z = zero.to(dtype).unsqueeze(-1) w_deq = ((data - z) * s).reshape(N, K) return F.linear(x, w_deq) # --------------------------------------------------------------------------- -# Int4Tensor F.linear dispatch +# CudaCoalescedInt4Tensor F.linear dispatch # --------------------------------------------------------------------------- aten = torch.ops.aten -_implements = Int4Tensor.implements -_implements_torch_function = Int4Tensor.implements_torch_function +_implements = CudaCoalescedInt4Tensor.implements +_implements_torch_function = CudaCoalescedInt4Tensor.implements_torch_function @_implements([aten.linear.default]) @@ -101,6 +110,11 @@ def _(func, types, args, kwargs): M = x_2d.shape[0] if M <= 4: + # scale/zero are already in the coalesced [N, n_groups] layout the + # decode kernel reads directly (baked into the weight constant at pack + # time). Passing them straight through keeps the export graph free of + # any per-step transpose/clone, so the coalesced layout is realized + # without recomputing it every decode step. 
out = torch.ops.executorch_cuda.int4_plain_mm(x_2d, qdata, scale, zero, gs) else: out = _dequant_matmul(x_2d, qdata, scale, zero, gs) diff --git a/backends/cuda/runtime/shims/int4_plain_mm.cu b/backends/cuda/runtime/shims/int4_plain_mm.cu index fd8fe3b0c3b..7cda801c348 100644 --- a/backends/cuda/runtime/shims/int4_plain_mm.cu +++ b/backends/cuda/runtime/shims/int4_plain_mm.cu @@ -52,8 +52,43 @@ AOTITorchError aoti_torch_cuda_int4_plain_mm( InvalidArgument, "aoti_torch_cuda_int4_plain_mm: ret0 is null"); + // Validate the coalesced scale/zero layout [N, K/group_size] + + const int64_t N = qdata->size(0); + const int64_t K = qdata->size(1) * 2; + + ET_CHECK_OR_RETURN_ERROR( + group_size > 0 && (group_size & (group_size - 1)) == 0, + InvalidArgument, + "aoti_torch_cuda_int4_plain_mm: group_size=%lld must be a positive power of 2", + static_cast(group_size)); + + const int64_t n_groups = K / group_size; + + ET_CHECK_OR_RETURN_ERROR( + scale->dim() == 2 && zero->dim() == 2, + InvalidArgument, + "aoti_torch_cuda_int4_plain_mm: scale/zero must be 2D (got scale.dim()=%lld, zero.dim()=%lld)", + static_cast(scale->dim()), + static_cast(zero->dim())); + + ET_CHECK_OR_RETURN_ERROR( + scale->size(0) == N && zero->size(0) == N, + InvalidArgument, + "aoti_torch_cuda_int4_plain_mm: scale/zero must be coalesced [N, K/group_size] (AOT layout); native [n_groups, N] is not supported - repack via pack_linear_for_cuda. Expected size(0)=N=%lld, got scale.size(0)=%lld, zero.size(0)=%lld", + static_cast(N), + static_cast(scale->size(0)), + static_cast(zero->size(0))); + + ET_CHECK_OR_RETURN_ERROR( + scale->size(1) == n_groups && zero->size(1) == n_groups, + InvalidArgument, + "aoti_torch_cuda_int4_plain_mm: scale/zero must be coalesced [N, K/group_size] (AOT layout); native [n_groups, N] is not supported - repack via pack_linear_for_cuda. 
Expected size(1)=K/group_size=%lld, got scale.size(1)=%lld, zero.size(1)=%lld", + static_cast(n_groups), + static_cast(scale->size(1)), + static_cast(zero->size(1))); + int32_t M = self->size(0); - int32_t N = qdata->size(0); Tensor* C = nullptr; std::array c_shape = {M, N}; std::array c_stride = {N, 1}; diff --git a/backends/cuda/runtime/shims/int4_plain_mm.cuh b/backends/cuda/runtime/shims/int4_plain_mm.cuh index 42700969fa4..31214bc0bf6 100644 --- a/backends/cuda/runtime/shims/int4_plain_mm.cuh +++ b/backends/cuda/runtime/shims/int4_plain_mm.cuh @@ -9,7 +9,7 @@ // W4A8 dp4a matvec for INT4 decode (M <= 4). // // Reads plain nibble-packed [N, K//2] weights (Int4Tensor format). -// Scale/zero layout: [K//gs, N] (Int4Tensor's native layout). +// Scale/zero layout: [N, K//gs] (transposed AOT for coalesced loads). // // Dynamically quantizes bf16 activations to INT8 (per-32-element blocks), // then uses dp4a for fused int4×int8 dot products with 16-byte vectorized @@ -98,18 +98,28 @@ __global__ void quantize_activations_q8_kernel( } // --------------------------------------------------------------------------- -// W4A8 dp4a matvec kernel +// Coalesced-scale W4A8 dp4a matvec +// +// Reads scale/zero in the transposed [N, n_groups] layout (transposed ahead of +// time, at pack time). With group_size >= 32, one uint4 (32 weights) maps to exactly +// one activation block and one weight group, so within a warp the 32 lanes +// touch 32 consecutive groups. In [N, n_groups] layout those 32 group scales +// are contiguous => a single coalesced load, vs 32 stride-N cache lines in the +// native layout. For the gemma group_size=32 weights this is the dominant +// decode-matvec cost. 
// --------------------------------------------------------------------------- -__global__ void __launch_bounds__(MV_THREADS) int4_w4a8_matvec_kernel( - const uint8_t* __restrict__ qdata, - const __nv_bfloat16* __restrict__ w_scale, - const __nv_bfloat16* __restrict__ w_zero, - const Q8Block* __restrict__ q8, - __nv_bfloat16* __restrict__ out, - int32_t N, - int32_t K, - int32_t gs_shift) { +__global__ void __launch_bounds__(MV_THREADS) + int4_w4a8_matvec_coalesced_kernel( + const uint8_t* __restrict__ qdata, + const __nv_bfloat16* __restrict__ w_scale_t, // [N, n_groups] + const __nv_bfloat16* __restrict__ w_zero_t, // [N, n_groups] + const Q8Block* __restrict__ q8, + __nv_bfloat16* __restrict__ out, + int32_t N, + int32_t K, + int32_t gs_shift, + int32_t n_groups) { const int32_t n = blockIdx.x * MV_NWARPS + threadIdx.y; const int32_t m = blockIdx.y; if (n >= N) @@ -120,9 +130,10 @@ __global__ void __launch_bounds__(MV_THREADS) int4_w4a8_matvec_kernel( const int32_t n_q8_blocks = K / Q8_BLOCK_SIZE; const uint8_t* qrow = qdata + static_cast(n) * K_half; - const __nv_bfloat16* scale_base = w_scale + n; - const __nv_bfloat16* zero_base = w_zero + n; - const int32_t scale_stride = N; + const __nv_bfloat16* scale_row = + w_scale_t + static_cast(n) * n_groups; + const __nv_bfloat16* zero_row = + w_zero_t + static_cast(n) * n_groups; const Q8Block* q8_row = q8 + static_cast(m) * n_q8_blocks; const uint4* qrow16 = reinterpret_cast(qrow); @@ -145,8 +156,8 @@ __global__ void __launch_bounds__(MV_THREADS) int4_w4a8_matvec_kernel( int32_t g = k_word >> gs_shift; if (g != prev_g) { - ws = __bfloat162float(__ldg(&scale_base[g * scale_stride])); - wz = __bfloat162float(__ldg(&zero_base[g * scale_stride])); + ws = __bfloat162float(__ldg(&scale_row[g])); + wz = __bfloat162float(__ldg(&zero_row[g])); prev_g = g; } @@ -227,8 +238,8 @@ static Q8Block* get_q8_buffer(size_t needed) { void _int4_plain_mm_cuda( const Tensor& A, // [M, K] bf16 const Tensor& qdata, // [N, K//2] uint8 - 
const Tensor& scale, // [K//gs, N] bf16 - const Tensor& zero, // [K//gs, N] bf16 + const Tensor& scale, // [N, K//gs] bf16 + const Tensor& zero, // [N, K//gs] bf16 int64_t group_size, Tensor* output) { // [M, N] bf16, pre-allocated int32_t M = A.size(0); @@ -245,9 +256,9 @@ void _int4_plain_mm_cuda( ET_CHECK(qdata.dim() == 2); ET_CHECK(qdata.size(1) == K / 2); ET_CHECK(scale.dim() == 2); - ET_CHECK(scale.size(1) == N); + ET_CHECK(scale.size(0) == N); ET_CHECK(zero.dim() == 2); - ET_CHECK(zero.size(1) == N); + ET_CHECK(zero.size(0) == N); int32_t gs = static_cast(group_size); ET_CHECK_MSG( @@ -279,15 +290,15 @@ void _int4_plain_mm_cuda( // dp4a matvec dim3 grid((N + MV_NWARPS - 1) / MV_NWARPS, M); dim3 block(MV_WARP_SIZE, MV_NWARPS); - int4_w4a8_matvec_kernel<<>>( + + int32_t n_groups = static_cast(scale.size(1)); + int4_w4a8_matvec_coalesced_kernel<<>>( reinterpret_cast(qdata.data_ptr()), reinterpret_cast(scale.data_ptr()), reinterpret_cast(zero.data_ptr()), q8_buf, reinterpret_cast<__nv_bfloat16*>(output->data_ptr()), - N, - K, - gs_shift); + N, K, gs_shift, n_groups); } } // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_int4_plain_mm.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_int4_plain_mm.cpp index ab18e33c713..de5fd9774e0 100644 --- a/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_int4_plain_mm.cpp +++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_int4_plain_mm.cpp @@ -70,6 +70,18 @@ class AOTITorchInt4PlainMMTest : public ::testing::Test { cudaMemcpy(host_data, t->data_ptr(), bytes, cudaMemcpyDeviceToHost); } + // Transpose a uint16 [rows, cols] row-major buffer into [cols, rows]. + // Used to convert native [n_groups, N] scale/zero literals into the + // [N, n_groups] layout the shim now expects (transposed AOT at export). 
+ static std::vector + transpose_u16(const uint16_t* src, int rows, int cols) { + std::vector dst(static_cast(rows) * cols); + for (int r = 0; r < rows; r++) + for (int c = 0; c < cols; c++) + dst[static_cast(c) * rows + r] = src[r * cols + c]; + return dst; + } + // Run the shim and return the output tensor (asserts success). Tensor* run( Tensor* A, @@ -111,7 +123,7 @@ class AOTITorchInt4PlainMMTest : public ::testing::Test { }; // MultiGroupRandom: M=1, N=4, K=32, gs=16 -// scale/zero layout: [K//gs=2, N=4] +// scale/zero layout: [N=4, K//gs=2] (transposed AOT) TEST_F(AOTITorchInt4PlainMMTest, MultiGroupRandom) { int64_t M = 1, K = 32, N = 4, gs = 16; @@ -132,14 +144,17 @@ TEST_F(AOTITorchInt4PlainMMTest, MultiGroupRandom) { uint16_t expected[] = {0xBFCC, 0x3FB5, 0x4046, 0xC01E}; // clang-format on + int64_t ng = K / gs; Tensor* A = create_bf16({M, K}); Tensor* qdata = create_uint8({N, K / 2}); - Tensor* scale = create_bf16({K / gs, N}); - Tensor* zero = create_bf16({K / gs, N}); + Tensor* scale = create_bf16({N, ng}); + Tensor* zero = create_bf16({N, ng}); + auto scale_t = transpose_u16(scale_host, ng, N); + auto zero_t = transpose_u16(zero_host, ng, N); upload(A, A_host, sizeof(A_host)); upload(qdata, qdata_host, sizeof(qdata_host)); - upload(scale, scale_host, sizeof(scale_host)); - upload(zero, zero_host, sizeof(zero_host)); + upload(scale, scale_t.data(), scale_t.size() * sizeof(uint16_t)); + upload(zero, zero_t.data(), zero_t.size() * sizeof(uint16_t)); Tensor* output = run(A, qdata, scale, zero, gs); ASSERT_NE(output, nullptr); @@ -149,7 +164,7 @@ TEST_F(AOTITorchInt4PlainMMTest, MultiGroupRandom) { } // SingleGroup: M=1, N=8, K=32, gs=32 -// scale/zero layout: [K//gs=1, N=8] +// scale/zero layout: [N=8, K//gs=1] (transposed AOT) TEST_F(AOTITorchInt4PlainMMTest, SingleGroup) { int64_t M = 1, K = 32, N = 8, gs = 32; @@ -178,14 +193,17 @@ TEST_F(AOTITorchInt4PlainMMTest, SingleGroup) { uint16_t expected[] = {0xC031, 0x3BF8, 0x3E81, 0xBF19, 0x3FCB, 0xBF56, 
0x4076, 0x3F20}; // clang-format on + int64_t ng = K / gs; Tensor* A = create_bf16({M, K}); Tensor* qdata = create_uint8({N, K / 2}); - Tensor* scale = create_bf16({K / gs, N}); - Tensor* zero = create_bf16({K / gs, N}); + Tensor* scale = create_bf16({N, ng}); + Tensor* zero = create_bf16({N, ng}); + auto scale_t = transpose_u16(scale_host, ng, N); + auto zero_t = transpose_u16(zero_host, ng, N); upload(A, A_host, sizeof(A_host)); upload(qdata, qdata_host, sizeof(qdata_host)); - upload(scale, scale_host, sizeof(scale_host)); - upload(zero, zero_host, sizeof(zero_host)); + upload(scale, scale_t.data(), scale_t.size() * sizeof(uint16_t)); + upload(zero, zero_t.data(), zero_t.size() * sizeof(uint16_t)); Tensor* output = run(A, qdata, scale, zero, gs); ASSERT_NE(output, nullptr); @@ -195,7 +213,7 @@ TEST_F(AOTITorchInt4PlainMMTest, SingleGroup) { } // PrefillBatch: M=8, N=4, K=64, gs=32 -// scale/zero layout: [K//gs=2, N=4] +// scale/zero layout: [N=4, K//gs=2] (transposed AOT) TEST_F(AOTITorchInt4PlainMMTest, PrefillBatch) { int64_t M = 8, K = 64, N = 4, gs = 32; @@ -224,14 +242,17 @@ TEST_F(AOTITorchInt4PlainMMTest, PrefillBatch) { uint16_t expected[] = {0x40BD, 0xC0E3, 0x4037, 0x40A9, 0x406F, 0x4116, 0x3F8D, 0xC01F, 0xC039, 0xC043, 0x3F86, 0x410A, 0x3F07, 0xC100, 0x4019, 0x40D7, 0x40A9, 0x40F1, 0xBF89, 0x406F, 0x40FE, 0xBFB8, 0xBF88, 0x406A, 0x4004, 0x3EDE, 0x3E17, 0x4102, 0xC081, 0xC0BA, 0xBFFB, 0x3F25}; // clang-format on + int64_t ng = K / gs; Tensor* A = create_bf16({M, K}); Tensor* qdata = create_uint8({N, K / 2}); - Tensor* scale = create_bf16({K / gs, N}); - Tensor* zero = create_bf16({K / gs, N}); + Tensor* scale = create_bf16({N, ng}); + Tensor* zero = create_bf16({N, ng}); + auto scale_t = transpose_u16(scale_host, ng, N); + auto zero_t = transpose_u16(zero_host, ng, N); upload(A, A_host, sizeof(A_host)); upload(qdata, qdata_host, sizeof(qdata_host)); - upload(scale, scale_host, sizeof(scale_host)); - upload(zero, zero_host, sizeof(zero_host)); + 
upload(scale, scale_t.data(), scale_t.size() * sizeof(uint16_t)); + upload(zero, zero_t.data(), zero_t.size() * sizeof(uint16_t)); Tensor* output = run(A, qdata, scale, zero, gs); ASSERT_NE(output, nullptr); @@ -241,7 +262,7 @@ TEST_F(AOTITorchInt4PlainMMTest, PrefillBatch) { } // GroupSize128: M=1, N=2, K=256, gs=128 -// scale/zero layout: [K//gs=2, N=2] +// scale/zero layout: [N=2, K//gs=2] (transposed AOT) TEST_F(AOTITorchInt4PlainMMTest, GroupSize128) { int64_t M = 1, K = 256, N = 2, gs = 128; @@ -286,14 +307,17 @@ TEST_F(AOTITorchInt4PlainMMTest, GroupSize128) { uint16_t expected[] = {0xC013, 0xBF05}; // clang-format on + int64_t ng = K / gs; Tensor* A = create_bf16({M, K}); Tensor* qdata = create_uint8({N, K / 2}); - Tensor* scale = create_bf16({K / gs, N}); - Tensor* zero = create_bf16({K / gs, N}); + Tensor* scale = create_bf16({N, ng}); + Tensor* zero = create_bf16({N, ng}); + auto scale_t = transpose_u16(scale_host, ng, N); + auto zero_t = transpose_u16(zero_host, ng, N); upload(A, A_host, sizeof(A_host)); upload(qdata, qdata_host, sizeof(qdata_host)); - upload(scale, scale_host, sizeof(scale_host)); - upload(zero, zero_host, sizeof(zero_host)); + upload(scale, scale_t.data(), scale_t.size() * sizeof(uint16_t)); + upload(zero, zero_t.data(), zero_t.size() * sizeof(uint16_t)); Tensor* output = run(A, qdata, scale, zero, gs); ASSERT_NE(output, nullptr); @@ -307,8 +331,8 @@ TEST_F(AOTITorchInt4PlainMMTest, NullInputHandling) { Tensor* A = create_bf16({M, K}); Tensor* qdata = create_uint8({N, K / 2}); - Tensor* scale = create_bf16({K / gs, N}); - Tensor* zero = create_bf16({K / gs, N}); + Tensor* scale = create_bf16({N, K / gs}); + Tensor* zero = create_bf16({N, K / gs}); Tensor* output = nullptr; EXPECT_EQ( @@ -357,7 +381,7 @@ TEST_F(AOTITorchInt4PlainMMTest, RealInt4TensorLayout) { 0x63, 0x9A, 0x95, 0x78, 0x95, 0x69, 0xF8, 0x58, 0x65, 0x0A, 0x6B, 0x47, 0x9C, 0x5C, 0x6A, 0x35, 0xA2, 0x8A, 0x74, 0x93, 0x28, 0x6D, 0xF0, 0xAB, 0x23, 0xA6, 0xA6, 0x3A}; - // 
scale/zero are [K//gs, N] = [2, 8] — Int4Tensor's native layout + // scale/zero are [N, K//gs] = [8, 2] — transposed AOT for the coalesced kernel uint16_t scale_host[] = { 0x3E46, 0x3E94, 0x3E8F, 0x3E94, 0x3E94, 0x3E8D, 0x3EA5, 0x3EA5, 0x3E9F, 0x3EAD, 0x3E91, 0x3EA0, 0x3E88, 0x3EB7, 0x3E89, 0x3E92}; @@ -380,13 +404,15 @@ TEST_F(AOTITorchInt4PlainMMTest, RealInt4TensorLayout) { Tensor* A = create_bf16({M, K}); Tensor* qdata = create_uint8({N, K / 2}); - // Note: scale/zero shape is [n_groups, N], NOT [N, n_groups] - Tensor* scale = create_bf16({n_groups, N}); - Tensor* zero = create_bf16({n_groups, N}); + // scale/zero shape is [N, n_groups] (transposed AOT) + Tensor* scale = create_bf16({N, n_groups}); + Tensor* zero = create_bf16({N, n_groups}); + auto scale_t = transpose_u16(scale_host, n_groups, N); + auto zero_t = transpose_u16(zero_host, n_groups, N); upload(A, A_host, sizeof(A_host)); upload(qdata, qdata_host, sizeof(qdata_host)); - upload(scale, scale_host, sizeof(scale_host)); - upload(zero, zero_host, sizeof(zero_host)); + upload(scale, scale_t.data(), scale_t.size() * sizeof(uint16_t)); + upload(zero, zero_t.data(), zero_t.size() * sizeof(uint16_t)); Tensor* output = run(A, qdata, scale, zero, gs); ASSERT_NE(output, nullptr); @@ -395,3 +421,25 @@ TEST_F(AOTITorchInt4PlainMMTest, RealInt4TensorLayout) { // W4A8 adds quantization noise vs bf16 reference — use wider tolerance check_bf16_output(output, expected, M * N, 0.5f); } + +// RejectsNativeLayout: scale/zero passed in the un-transposed native +// [n_groups, N] layout (instead of the coalesced [N, n_groups] AOT layout) +// must be rejected gracefully with Error::InvalidArgument, not crash. +// K=64, gs=32 -> n_groups=2, N=8; native scale is [2, 8] while the shim +// expects coalesced [8, 2]. n_groups != N so the shape guard can catch it. 
+TEST_F(AOTITorchInt4PlainMMTest, RejectsNativeLayout) { + int64_t M = 1, K = 64, N = 8, gs = 32; + int64_t n_groups = K / gs; // 2 + + Tensor* A = create_bf16({M, K}); + Tensor* qdata = create_uint8({N, K / 2}); + // Native torchao layout [n_groups, N] = [2, 8], NOT the coalesced + // [N, n_groups] = [8, 2] the shim expects. + Tensor* scale = create_bf16({n_groups, N}); + Tensor* zero = create_bf16({n_groups, N}); + Tensor* output = nullptr; + + EXPECT_EQ( + aoti_torch_cuda_int4_plain_mm(A, qdata, scale, zero, gs, &output), + Error::InvalidArgument); +} diff --git a/backends/cuda/tests/test_int4_dispatch.py b/backends/cuda/tests/test_int4_dispatch.py index 51d573d33a3..fd748ae8584 100644 --- a/backends/cuda/tests/test_int4_dispatch.py +++ b/backends/cuda/tests/test_int4_dispatch.py @@ -24,13 +24,21 @@ python -m pytest backends/cuda/tests/test_int4_dispatch.py -v """ +import contextlib import unittest +from unittest import mock import executorch.backends.cuda.quantize_op_dispatch.int4_dispatch # noqa: F401 import torch import torch.nn as nn import torch.nn.functional as F -from executorch.examples.models.gemma4_31b.quant.quantize import quantize_weight +from executorch.backends.cuda.coalesced_int4_tensor import CudaCoalescedInt4Tensor +from executorch.backends.cuda.quantize_op_dispatch.int4_dispatch import _dequant_matmul +from executorch.examples.models.gemma4_31b.quant.pack_cuda import pack_linear_for_cuda +from executorch.examples.models.gemma4_31b.quant.quantize import ( + dequantize_weight, + quantize_weight, +) from executorch.examples.models.gemma4_31b.quant.recipe import QuantConfig @@ -51,8 +59,9 @@ def _make_int4_linear(N, K, group_size=128, symmetric=False, bias=False): ) int4_w = quantize_weight(w_bf16, config) - module = nn.Linear(K, N, bias=bias, dtype=torch.bfloat16, device="cuda") - module.weight = nn.Parameter(int4_w.cuda(), requires_grad=False) + module = nn.Linear(K, N, bias=bias, dtype=torch.bfloat16) + pack_linear_for_cuda(module, {"weight": 
int4_w}) + module.cuda() return module, w_bf16.cuda() @@ -174,7 +183,7 @@ def test_to_cuda(self): config = QuantConfig(bits=4, group_size=128, symmetric=False, method="min_max") int4_w = quantize_weight(w_bf16, config) module = nn.Linear(512, 256, bias=False) - module.weight = nn.Parameter(int4_w, requires_grad=False) + pack_linear_for_cuda(module, {"weight": int4_w}) module = module.to("cuda") x = torch.randn(1, 512, dtype=torch.bfloat16, device="cuda") self._check(module(x), F.linear(x, w_bf16.cuda())) @@ -207,5 +216,114 @@ def test_21504x5376_prefill(self): self._check(module(x), F.linear(x, w_ref)) +def _make_int4_tensor(N, K, group_size=128, symmetric=False): + """Build a stock torchao ``Int4Tensor`` (NOT packed/coalesced) on CPU.""" + w = torch.randn(N, K, dtype=torch.bfloat16) + config = QuantConfig( + bits=4, group_size=group_size, symmetric=symmetric, method="min_max" + ) + return quantize_weight(w, config), w + + +@contextlib.contextmanager +def _record_int4_plain_mm(): + """Record calls to the decode custom op without needing a GPU. + + Replaces ``torch.ops.executorch_cuda.int4_plain_mm`` (whose real impl is the + CUDA C shim) with a recorder that computes the result via the eager CPU + dequant, so the dispatch handler still returns a valid tensor. + """ + calls = [] + + def _fake(self, qdata, scale, zero, group_size): + calls.append((tuple(self.shape), group_size)) + return _dequant_matmul(self, qdata, scale, zero, group_size) + + with mock.patch.object(torch.ops.executorch_cuda, "int4_plain_mm", _fake): + yield calls + + +class TestDispatchRouting(unittest.TestCase): + """Type-based routing: only CudaCoalescedInt4Tensor reaches int4_plain_mm. + + These tests run without a GPU by recording calls to the decode custom op + and computing the result with the eager CPU dequant. They guard the + comment-8 refactor: the CUDA decode path must be selected by weight *type*, + not by globally overriding torchao ``Int4Tensor``'s F.linear. 
+ """ + + def setUp(self): + torch.manual_seed(0) + + def _rel_err(self, out, ref): + return ( + (out.float() - ref.float()).abs().mean() / ref.float().abs().mean() + ).item() + + def test_stock_int4tensor_does_not_route_to_int4_plain_mm(self): + """A plain torchao Int4Tensor must fall back to torchao's default path.""" + t, _ = _make_int4_tensor(16, 64, group_size=32) + x = torch.randn(1, 64, dtype=torch.bfloat16) # M=1 (decode regime) + with _record_int4_plain_mm() as calls: + # torchao's default path uses mslk/CUDA and is not exercised on CPU; + # we only assert that our decode op is NOT reached. + with contextlib.suppress(Exception): + F.linear(x, t) + self.assertEqual(calls, []) + + def test_coalesced_tensor_routes_to_int4_plain_mm(self): + """CudaCoalescedInt4Tensor with M<=4 routes to the decode custom op.""" + t, _ = _make_int4_tensor(16, 64, group_size=32) + c = CudaCoalescedInt4Tensor.from_int4_tensor(t) + x = torch.randn(1, 64, dtype=torch.bfloat16) # M=1 (decode regime) + with _record_int4_plain_mm() as calls: + out = F.linear(x, c) + self.assertEqual(len(calls), 1) + self.assertEqual(out.shape, (1, 16)) + + def test_coalesced_tensor_prefill_uses_dequant(self): + """M>4 uses inline dequant (no custom op) and is numerically correct.""" + t, _ = _make_int4_tensor(16, 64, group_size=32) + c = CudaCoalescedInt4Tensor.from_int4_tensor(t) + x = torch.randn(8, 64, dtype=torch.bfloat16) # M=8 > 4 (prefill regime) + with _record_int4_plain_mm() as calls: + out = F.linear(x, c) + self.assertEqual(calls, []) + ref = F.linear(x, dequantize_weight(t, torch.bfloat16)) + self.assertLess(self._rel_err(out, ref), 0.02) + + def test_square_shape_not_misrouted(self): + """N == n_groups (square scale) stock tensor is still not routed. + + K = group_size * N makes scale square (n_groups == N); the old shape + heuristic could not distinguish this coalesced-looking case. Type-based + routing makes the scale shape irrelevant. 
+ """ + t, _ = _make_int4_tensor(4, 128, group_size=32) + self.assertEqual(tuple(t.scale.shape), (4, 4)) # (n_groups, N), square + x = torch.randn(1, 128, dtype=torch.bfloat16) + with _record_int4_plain_mm() as calls: + with contextlib.suppress(Exception): + F.linear(x, t) + self.assertEqual(calls, []) + + def test_from_int4_tensor_transpose_correct(self): + """from_int4_tensor owns the (n_groups, N) -> (N, n_groups) transpose.""" + t, _ = _make_int4_tensor(24, 192, group_size=64) + c = CudaCoalescedInt4Tensor.from_int4_tensor(t) + n_groups = 192 // 64 + self.assertEqual(tuple(t.scale.shape), (n_groups, 24)) # torchao layout + self.assertEqual(tuple(c.scale.shape), (24, n_groups)) # coalesced layout + self.assertTrue(torch.equal(c.scale, t.scale.t().contiguous())) + self.assertTrue(torch.equal(c.zero_point, t.zero_point.t().contiguous())) + # End-to-end decode result matches a reference dequant of the original. + x = torch.randn(2, 192, dtype=torch.bfloat16) + with _record_int4_plain_mm() as calls: + out = F.linear(x, c) + self.assertEqual(len(calls), 1) + ref = F.linear(x, dequantize_weight(t, torch.bfloat16)) + self.assertLess(self._rel_err(out, ref), 0.02) + + if __name__ == "__main__": unittest.main() diff --git a/examples/models/gemma4_31b/quant/pack_cuda.py b/examples/models/gemma4_31b/quant/pack_cuda.py index 037c3bd8310..655d773e7b3 100644 --- a/examples/models/gemma4_31b/quant/pack_cuda.py +++ b/examples/models/gemma4_31b/quant/pack_cuda.py @@ -6,8 +6,10 @@ """CUDA packer: assign quantized weights to model modules. -Passes ``Int4Tensor`` and ``IntxUnpackedToInt8Tensor`` through as -``nn.Parameter`` without conversion. The quantize_op_dispatch package +Converts ``Int4Tensor`` weights to the ExecuTorch-internal +``CudaCoalescedInt4Tensor`` (which owns the scale/zero transpose to the +coalesced [N, n_groups] layout) and passes ``IntxUnpackedToInt8Tensor`` through +as ``nn.Parameter`` without conversion. 
The quantize_op_dispatch package (``int4_dispatch`` / ``int8_dispatch``) handles F.linear at runtime. No CUDA is required for packing. The backend-agnostic ``pack_model`` @@ -28,11 +30,24 @@ def pack_linear_for_cuda(module: nn.Module, weights: dict[str, torch.Tensor]) -> None: """Assign a quantized weight to an ``nn.Linear`` module.""" + from executorch.backends.cuda.coalesced_int4_tensor import CudaCoalescedInt4Tensor from torchao.quantization import IntxUnpackedToInt8Tensor from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor w = weights["weight"] - if isinstance(w, (Int4Tensor, IntxUnpackedToInt8Tensor)): + if isinstance(w, Int4Tensor): + # Convert to the ExecuTorch-internal CudaCoalescedInt4Tensor, which + # repacks scale/zero from torchao's native [n_groups, N] layout into the + # coalesced [N, n_groups] layout the CUDA decode kernel reads (see + # int4_dispatch.py / int4_plain_mm.cuh). The transpose lives in + # CudaCoalescedInt4Tensor.from_int4_tensor, so it is baked into the + # serialized weight constant and the exported decode graph carries NO + # per-step transpose/clone — AOTInductor (freezing=False) does not + # constant-fold ops on parameters, so the transpose must already live in + # the constant for the coalesced layout to pay off. 
+ w = CudaCoalescedInt4Tensor.from_int4_tensor(w) + module.weight = nn.Parameter(w, requires_grad=False) + elif isinstance(w, IntxUnpackedToInt8Tensor): module.weight = nn.Parameter(w, requires_grad=False) else: raise ValueError(f"Unsupported weight type: {type(w).__name__}") From 888e1fa6ccfbc8257f9b93745ba27f889ba1aed3 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Sat, 6 Jun 2026 00:11:51 -0700 Subject: [PATCH 02/15] [cuda] decode SDPA: route to split-K via L_kv>=256 (cuda-graph-justified) + benchmark rework Summary: At decode (L_q==1) the standard pack-GQA SDPA kernel's grid collapses to CTA = batch * n_kv_heads, which under-occupies the SMs; split-K flash-decoding partitions the KV sequence across many more CTAs to fill the GPU. In ReplaceEdgeOpWithTritonOpPass._pick_sdpa_kernel, route decode to split-K when L_q==1 and L_kv >= 256 (power-of-2 head dim required; prefill and non-pow2 head dims keep the standard kernel). The 256 crossover was measured under CUDA-graph timing (capture+replay, faithful to the deployed --cuda_graph runtime). The earlier 2048 boundary was overfit to a plain (non-cuda-graph) microbenchmark, which charged split-K a ~140us per-call partial-buffer alloc + extra-launch overhead that the graph runtime eliminates; under faithful timing split-K wins ~1.2-20x from L_kv ~= 256 upward. benchmark_sdpa.py reworked: deleted run_sweep and all CSV/sentinel machinery; run_benchmark now compares all six backends (ET-standard, ET-split-K, PyTorch, Flash, Efficient, Math) with the PyTorch correctness check, across several decode configs (gemma D256/CTA16, qwen D256/CTA2, D128/CTA16) over the L_kv range, with a cuda-graph on/off toggle (--mode {cudagraph,plain,both}) timing every backend through a small self-contained cuda-graph primitive; terminal-only output. Each reported cell is the mean+/-std over the last 6 of 10 runs (first 4 discarded as warmup; N_RUNS=10, N_WARMUP=4). 
Test Plan: Exercised against the repo (PYTHONPATH) since the conda env's installed executorch is stale; a lib reinstall is required for the routing to take effect in a real export. backends/cuda/tests/test_sdpa_splitk_replacement.py - L_kv=128 -> standard; L_kv=256 -> split-K; L_kv=4096 -> split-K; non-pow2 D=96 -> standard. backends/cuda/tests/test_triton_sdpa_splitk.py (14) and backends/cuda/tests/test_triton_sdpa_nan.py (3) pass. 21 tests total. gemma4_31b long-context decode (2401-tok prompt, 256 new tokens, temp 0, --cuda_graph, 10 runs middle-6) with split-K routing: decode 37.91 -> 43.98 tok/s (+16.0%); prefill within noise. python backends/cuda/benchmarks/benchmark_sdpa.py --mode cudagraph (gemma D256/CTA16, mean+/-std us): L_kv=2048 ET-std 102.4+/-0.0 / ET-split-K 24.6+/-0.2 / PyTorch 475.1+/-0.3 / Flash 56.5+/-0.0; L_kv=16384 ET-std 785.5+/-0.0 / ET-split-K 179.8+/-0.1 / PyTorch 3447+/-2.6. Plain-timing mode shows split-K's per-call overhead (the artifact behind the old 2048). --- backends/cuda/benchmarks/benchmark_sdpa.py | 372 +++++++++--------- .../tests/test_sdpa_splitk_replacement.py | 50 ++- backends/cuda/triton/replacement_pass.py | 23 +- 3 files changed, 235 insertions(+), 210 deletions(-) diff --git a/backends/cuda/benchmarks/benchmark_sdpa.py b/backends/cuda/benchmarks/benchmark_sdpa.py index 3c117f4574f..0b95f736102 100644 --- a/backends/cuda/benchmarks/benchmark_sdpa.py +++ b/backends/cuda/benchmarks/benchmark_sdpa.py @@ -6,16 +6,27 @@ # LICENSE file in the root directory of this source tree. """ -Benchmark the Triton SDPA kernel against PyTorch SDPA backends. - -Measures latency across decode shapes matching the Qwen3.5 MoE model -(B=1, H_q=16, H_kv=2, D=256). The ET Triton kernel uses native GQA -(2 KV heads), while Flash/Efficient/Math require pre-expanded KV -(16 heads) since they lack native GQA support. - +Benchmark the Triton SDPA kernels against PyTorch SDPA backends at decode. 
+ +Cross-backend latency comparison ("is our kernel competitive vs PyTorch / +Flash?") across a few representative decode configs and the L_kv range, in BOTH +CUDA-graph and plain timing modes. The ET Triton kernels use native GQA; the +Flash/Efficient/Math backends require pre-expanded KV (no native GQA), matching +the test reference. PyTorch (default) is the correctness reference. + +Timing: CUDA-graph mode (capture+replay) is faithful to the deployed +``--cuda_graph`` runtime; plain ``do_bench`` charges each kernel its full +per-call launch/alloc overhead. Run both to see the effect (it is large for ET +split-K, which allocates partial buffers per call). + +Usage: + python benchmark_sdpa.py # both timing modes + python benchmark_sdpa.py --mode cudagraph + python benchmark_sdpa.py --mode plain """ import argparse +import statistics import warnings from functools import partial @@ -23,17 +34,67 @@ import torch.nn.functional as F from executorch.backends.cuda.triton.kernels.sdpa import ( - sdpa as triton_sdpa, - sdpa_decode_splitk as triton_splitk, + sdpa as _triton_sdpa, + sdpa_decode_splitk as _triton_splitk, ) from torch.nn.attention import sdpa_kernel, SDPBackend -from triton.testing import do_bench +from triton.testing import do_bench, do_bench_cudagraph + + +# -- Timing primitive + ET kernel runners (self-contained) ------------------- +# do_bench budgets are millisecond windows (NOT iteration counts). +_WARMUP_MS = 10 +_REP_MS = 50 +# Warmup calls before graph capture so the Triton autotuner has cached a config +# (autotuning cannot run inside graph capture). +_GRAPH_WARMUP_CALLS = 20 + + +def run_standard(q, k, v, attn_mask, enable_gqa): + return _triton_sdpa(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa) + + +def run_splitk(q, k, v, attn_mask, enable_gqa): + return _triton_splitk(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa) + + +def time_us(fn, cudagraph: bool = True) -> float: + """Median latency (us). 
cudagraph=True is faithful to the --cuda_graph path. + + Under CUDA-graph the op is captured once (its split-K partial/LSE workspace + is allocated once into the graph's private pool and reused across replays) + and only replay() is timed, so the per-call buffer alloc + launch overhead + is excluded -- exactly as the deployed runtime eliminates it. We warm up + first so the Triton autotuner has cached a config before capture. + """ + if cudagraph: + for _ in range(_GRAPH_WARMUP_CALLS): + fn() + torch.cuda.synchronize() + ms = do_bench_cudagraph(fn, rep=_REP_MS, return_mode="median") + else: + ms = do_bench(fn, warmup=_WARMUP_MS, rep=_REP_MS, return_mode="median") + return ms * 1000.0 + + +# Each reported number repeats the timing primitive N_RUNS times, discards the +# first N_WARMUP as warmup, and reports mean +/- std over the remaining runs. +N_RUNS = 10 +N_WARMUP = 4 + + +def measure_us(fn, cudagraph: bool): + """Repeat time_us N_RUNS times; return (mean, std) over runs[N_WARMUP:].""" + samples = [time_us(fn, cudagraph=cudagraph) for _ in range(N_RUNS)] + kept = samples[N_WARMUP:] + mean = statistics.fmean(kept) + std = statistics.stdev(kept) if len(kept) > 1 else 0.0 + return mean, std # PyTorch's Flash/Efficient backends don't support GQA (H_q != H_kv) directly. -# We expand KV heads via repeat_interleave so they can run, matching what -# the test reference does. This is fair: it measures the kernel itself, not -# the GQA dispatch overhead. +# We expand KV heads via repeat_interleave so they can run, matching what the +# test reference does. This measures the kernel itself, not GQA dispatch. 
def _expand_kv(k, v, num_groups): @@ -49,21 +110,9 @@ def _expand_mask(mask, H_q): return mask -def _run_triton(q, k, v, attn_mask, enable_gqa): - return triton_sdpa(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa) - - -def _run_splitk(q, k, v, attn_mask, enable_gqa): - return triton_splitk(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa) - - def _run_pytorch_default(q, k, v, attn_mask, enable_gqa): return F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=attn_mask, - enable_gqa=enable_gqa, + q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa ) @@ -75,50 +124,40 @@ def run(q, k, v, attn_mask, enable_gqa): return run -# Flash doesn't support attn_mask at all, only is_causal. -# Our benchmark mask is all-ones, so no mask is equivalent. +# Flash doesn't support attn_mask at all, only is_causal. Our benchmark mask is +# all-ones, so no mask is equivalent. def _run_flash(q, k, v, attn_mask, enable_gqa): with sdpa_kernel(SDPBackend.FLASH_ATTENTION): return F.scaled_dot_product_attention(q, k, v) +# ET Triton kernels reuse the shared helper runners (the real lowered kernels). BACKENDS = { - "triton": ("ET Triton (GQA)", _run_triton), - "splitk": ("ET Split-K (GQA)", _run_splitk), + "triton": ("ET Triton (GQA)", run_standard), + "splitk": ("ET Split-K (GQA)", run_splitk), "pytorch": ("PyTorch", _run_pytorch_default), - "flash": ("Flash (expanded KV)", _run_flash), + "flash": ("Flash (exp KV)", _run_flash), "efficient": ( - "Efficient (expanded KV)", + "Efficient (exp KV)", _make_pytorch_runner(SDPBackend.EFFICIENT_ATTENTION), ), - "math": ("Math (expanded KV)", _make_pytorch_runner(SDPBackend.MATH)), + "math": ("Math (exp KV)", _make_pytorch_runner(SDPBackend.MATH)), } -# Backends that need KV heads expanded before calling (no native GQA support) +# Backends that need KV heads expanded before calling (no native GQA support). 
_NEEDS_KV_EXPAND = {"flash", "efficient", "math"} -# -- Shapes ------------------------------------------------------------------ - -# Qwen3.5 MoE: B=1, H_q=16, H_kv=2, D=256 -QWEN35_BASE = {"B": 1, "H_q": 16, "H_kv": 2, "D": 256} - -DECODE_SHAPES = [ - dict(**QWEN35_BASE, Lq=1, Lk=64), - dict(**QWEN35_BASE, Lq=1, Lk=128), - dict(**QWEN35_BASE, Lq=1, Lk=256), - dict(**QWEN35_BASE, Lq=1, Lk=512), - dict(**QWEN35_BASE, Lq=1, Lk=1024), - dict(**QWEN35_BASE, Lq=1, Lk=2048), - dict(**QWEN35_BASE, Lq=1, Lk=4096), - dict(**QWEN35_BASE, Lq=1, Lk=8192), - dict(**QWEN35_BASE, Lq=1, Lk=16384), +# Representative decode configs (label, B, H_q, H_kv, D). CTA = B * H_kv. +CONFIGS = [ + ("gemma sliding (D=256, CTA=16)", 1, 32, 16, 256), + ("qwen (D=256, CTA=2)", 1, 16, 2, 256), + ("head_dim=128 (D=128, CTA=16)", 1, 32, 16, 128), ] -SCENARIOS = { - "decode": DECODE_SHAPES, -} +L_KV_RANGE = [64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384] -# -- Helpers ----------------------------------------------------------------- +# Cross-backend validation tolerance (bf16 vs bf16). +MAX_ABS_TOL = 1e-2 def _make_tensors(B, H_q, H_kv, Lq, Lk, D, device="cuda", dtype=torch.bfloat16): @@ -128,7 +167,6 @@ def _make_tensors(B, H_q, H_kv, Lq, Lk, D, device="cuda", dtype=torch.bfloat16): mask = torch.ones(B, 1, Lq, Lk, dtype=torch.bool, device=device) enable_gqa = H_q != H_kv num_groups = H_q // H_kv - # Pre-expanded versions for backends without native GQA k_exp, v_exp = _expand_kv(k, v, num_groups) mask_exp = _expand_mask(mask, H_q) return q, k, v, k_exp, v_exp, mask, mask_exp, enable_gqa @@ -138,170 +176,132 @@ def _max_abs_error(out, ref): return (out.float() - ref.float()).abs().max().item() -# Cross-backend validation tolerance (bf16 vs bf16). 
-MAX_ABS_TOL = 1e-2 - - -def _bench_us(fn, num_warmup, num_iters): - """Return median latency in microseconds using triton.testing.do_bench.""" - ms = do_bench(fn, warmup=num_warmup, rep=num_iters, return_mode="median") - return ms * 1000.0 - - def _try_run(run_fn, q, k, v, mask, enable_gqa): - """Run a backend, returning output or None on failure.""" try: return run_fn(q, k, v, mask, enable_gqa) - except RuntimeError: + except Exception: return None -def _try_bench(run_fn, q, k, v, mask, enable_gqa, num_warmup, num_iters): - """Benchmark a backend, returning median us or None on failure.""" +def _try_bench(run_fn, q, k, v, mask, enable_gqa, cudagraph): + """Benchmark one backend, returning (mean_us, std_us) or None on failure.""" fn = partial(run_fn, q, k, v, mask, enable_gqa) try: run_fn(q, k, v, mask, enable_gqa) - return _bench_us(fn, num_warmup, num_iters) - except RuntimeError: + return measure_us(fn, cudagraph=cudagraph) + except Exception: return None -# -- Main -------------------------------------------------------------------- - - -def _shape_label(shape): - return ( - f"B={shape['B']} Hq={shape['H_q']} Hkv={shape['H_kv']} " - f"D={shape['D']} Lq={shape['Lq']} Lk={shape['Lk']}" - ) - - -def _short_label(shape, scenario="decode"): - return f"Lq={shape['Lq']},Lk={shape['Lk']}" +def _bench_inputs(name, q, k, v, k_exp, v_exp, mask, mask_exp): + """Return the (k, v, mask) a backend should use (expanded or native).""" + if name in _NEEDS_KV_EXPAND: + return k_exp, v_exp, mask_exp + return k, v, mask @torch.inference_mode() -def run_benchmark( - scenario: str = "decode", - num_warmup: int = 25, - num_iters: int = 100, -): - shapes = SCENARIOS[scenario] +def run_benchmark(cudagraph: bool): + """Print a cross-backend decode latency table for each config.""" backends = [(name, *BACKENDS[name]) for name in BACKENDS] + mode = "CUDA-graph (capture+replay)" if cudagraph else "plain do_bench" + device = torch.cuda.get_device_name() + n_sm = 
torch.cuda.get_device_properties(0).multi_processor_count - device_name = torch.cuda.get_device_name() print() - print("=" * 100) - print(f"SDPA Benchmark Qwen3.5-35B-A3B — {scenario}") - print(f" Device: {device_name}") - print(f" Warmup: {num_warmup}, Iters: {num_iters}") - print(f" Backends: {', '.join(label for _, label, _ in backends)}") - print("=" * 100) - - # Build column specs: (header_text, unit_text, min_width) - # Each column gets width = max(len(header), len(unit), min_width) - max_label = max(len(_short_label(s, scenario)) for s in shapes) - col_specs = [("Shape", "", max(8, max_label))] - for _, label, _ in backends: - col_specs.append((label, "(us)", 8)) - - col_widths = [max(len(h), len(u), mw) for h, u, mw in col_specs] - - header = " | ".join( - f"{h:<{w}}" if i == 0 else f"{h:>{w}}" - for i, ((h, _, _), w) in enumerate(zip(col_specs, col_widths)) + print("=" * 124) + print(f"SDPA decode cross-backend benchmark | timing: {mode}") + print(f" device: {device} (n_SM={n_sm}) L_q=1, bf16, all-ones mask") + print(f" backends: {', '.join(label for _, label, _ in backends)}") + print( + f" each cell = mean+/-std us over last {N_RUNS - N_WARMUP} of {N_RUNS} " + f"runs ({N_WARMUP} warmup)" ) - units = " | ".join( - f"{'':>{w}}" if i == 0 else f"{u:>{w}}" - for i, ((_, u, _), w) in enumerate(zip(col_specs, col_widths)) - ) - print(header) - print(units) - print("-" * len(header)) - - for shape in shapes: - q, k, v, k_exp, v_exp, mask, mask_exp, enable_gqa = _make_tensors(**shape) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - - # Validate outputs across backends before benchmarking - outputs = {} - for name, _label, run_fn in backends: - if name in _NEEDS_KV_EXPAND: - bk, bv, bmask = k_exp, v_exp, mask_exp - else: - bk, bv, bmask = k, v, mask - outputs[name] = _try_run(run_fn, q, bk, bv, bmask, enable_gqa) - - # Use PyTorch F.sdpa as the trusted reference — never validate - # against our own Triton kernels. 
- ref_name, ref_out = None, None - if outputs.get("pytorch") is not None: - ref_name, ref_out = "pytorch", outputs["pytorch"] - - if ref_out is not None: - for name, label, _ in backends: - if name == ref_name or outputs[name] is None: - continue - err = _max_abs_error(outputs[name], ref_out) - assert err < MAX_ABS_TOL, ( - f"Output mismatch for {_shape_label(shape)}: " - f"{label} vs {BACKENDS[ref_name][0]}, " - f"max abs error {err:.3e} >= 1e-2" + print("=" * 124) + + for label, B, H_q, H_kv, D in CONFIGS: + print(f"\n{label} [B={B} H_q={H_q} H_kv={H_kv} D={D}]") + col_specs = [("L_kv", "", 6)] + [(lbl, "(us)", 13) for _, lbl, _ in backends] + widths = [max(len(h), len(u), mw) for h, u, mw in col_specs] + header = " | ".join( + f"{h:<{w}}" if i == 0 else f"{h:>{w}}" + for i, ((h, _, _), w) in enumerate(zip(col_specs, widths)) + ) + units = " | ".join( + f"{'':>{w}}" if i == 0 else f"{u:>{w}}" + for i, ((_, u, _), w) in enumerate(zip(col_specs, widths)) + ) + print(" " + header) + print(" " + units) + print(" " + "-" * len(header)) + + for Lk in L_KV_RANGE: + q, k, v, k_exp, v_exp, mask, mask_exp, enable_gqa = _make_tensors( + B, H_q, H_kv, 1, Lk, D + ) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + + # Correctness: validate every backend against PyTorch (default). 
+ outputs = {} + for name, _lbl, run_fn in backends: + bk, bv, bmask = _bench_inputs( + name, q, k, v, k_exp, v_exp, mask, mask_exp + ) + outputs[name] = _try_run(run_fn, q, bk, bv, bmask, enable_gqa) + ref = outputs.get("pytorch") + if ref is not None: + for name, lbl, _ in backends: + if name == "pytorch" or outputs[name] is None: + continue + err = _max_abs_error(outputs[name], ref) + assert err < MAX_ABS_TOL, ( + f"Output mismatch {label} L_kv={Lk}: {lbl} vs PyTorch, " + f"max abs error {err:.3e} >= {MAX_ABS_TOL}" + ) + del outputs + + times = {} + for name, _lbl, run_fn in backends: + bk, bv, bmask = _bench_inputs( + name, q, k, v, k_exp, v_exp, mask, mask_exp + ) + times[name] = _try_bench( + run_fn, q, bk, bv, bmask, enable_gqa, cudagraph ) - del outputs - # Benchmark all backends - times = {} - for name, _label, run_fn in backends: - if name in _NEEDS_KV_EXPAND: - bk, bv, bmask = k_exp, v_exp, mask_exp + row = [f"{Lk:<{widths[0]}}"] + for ci, (name, _, _) in enumerate(backends, start=1): + t = times[name] + if t is not None: + cell = f"{t[0]:.1f}\u00b1{t[1]:.1f}" else: - bk, bv, bmask = k, v, mask - times[name] = _try_bench( - run_fn, q, bk, bv, bmask, enable_gqa, num_warmup, num_iters - ) - - # Format row using col_widths - ci = 0 - row_parts = [f"{_short_label(shape, scenario):<{col_widths[ci]}}"] - ci += 1 - for name, _, _ in backends: - t = times[name] - w = col_widths[ci] - row_parts.append(f"{t:>{w}.1f}" if t is not None else f"{'N/A':>{w}}") - ci += 1 - print(" | ".join(row_parts)) - - del q, k, v, k_exp, v_exp, mask, mask_exp - torch.cuda.empty_cache() - - print("-" * len(header)) + cell = "N/A" + row.append(f"{cell:>{widths[ci]}}") + print(" " + " | ".join(row)) + + del q, k, v, k_exp, v_exp, mask, mask_exp + torch.cuda.empty_cache() print() def main(): parser = argparse.ArgumentParser( - description="Benchmark Triton SDPA vs PyTorch backends" + description="Benchmark Triton SDPA vs PyTorch backends (decode)" ) parser.add_argument( - "--scenario", 
- choices=list(SCENARIOS.keys()) + ["all"], - default="all", - help="Which shape set to benchmark (default: all)", + "--mode", + choices=["cudagraph", "plain", "both"], + default="both", + help="Timing mode(s) to run (default: both).", ) - parser.add_argument("--num_warmup", type=int, default=25) - parser.add_argument("--num_iters", type=int, default=100) args = parser.parse_args() - scenarios = list(SCENARIOS.keys()) if args.scenario == "all" else [args.scenario] - for s in scenarios: - run_benchmark( - scenario=s, - num_warmup=args.num_warmup, - num_iters=args.num_iters, - ) + if args.mode in ("cudagraph", "both"): + run_benchmark(cudagraph=True) + if args.mode in ("plain", "both"): + run_benchmark(cudagraph=False) if __name__ == "__main__": diff --git a/backends/cuda/tests/test_sdpa_splitk_replacement.py b/backends/cuda/tests/test_sdpa_splitk_replacement.py index 414a1308777..465b0b7ecf4 100644 --- a/backends/cuda/tests/test_sdpa_splitk_replacement.py +++ b/backends/cuda/tests/test_sdpa_splitk_replacement.py @@ -6,9 +6,9 @@ """Test ReplaceEdgeOpWithTritonOpPass split-K SDPA kernel selection. -Exports a minimal model containing F.scaled_dot_product_attention through -the CUDA backend and verifies that the pass routes to split-K for decode -(L_q=1, large L_kv) and standard SDPA otherwise. +Exports a minimal model containing F.scaled_dot_product_attention through the +CUDA backend and verifies that the pass routes to split-K for decode +(L_q==1, L_kv >= 256) and standard SDPA otherwise. 
""" import logging @@ -106,9 +106,9 @@ class TestSplitKReplacement(unittest.TestCase): def setUp(self): _require_cuda(self) - def test_large_kv_cache_uses_splitk(self): - """L_kv=4096 > threshold → split-K selected for decode.""" - model = SDPAModule(n_heads=4, n_kv_heads=2, head_dim=64, kv_len=4096).to( + def test_below_threshold_uses_standard(self): + """L_kv=128 < threshold (256) -> standard SDPA, no split-K.""" + model = SDPAModule(n_heads=4, n_kv_heads=2, head_dim=64, kv_len=128).to( torch.bfloat16 ) args = ( @@ -119,12 +119,17 @@ def test_large_kv_cache_uses_splitk(self): _, msgs = _capture_pass_logs(lambda: _export_through_cuda_backend(model, args)) splitk = [m for m in msgs if "split-K" in m] - self.assertEqual(len(splitk), 1, f"Expected 1 split-K selection. Log: {msgs}") - self.assertIn("L_kv=4096", splitk[0]) + self.assertEqual(len(splitk), 0, f"Expected no split-K. Got: {splitk}") - def test_small_kv_cache_uses_standard(self): - """L_kv=512 <= threshold → standard SDPA, no split-K.""" - model = SDPAModule(n_heads=4, n_kv_heads=2, head_dim=64, kv_len=512).to( + replaced = [m for m in msgs if "Replaced" in m] + self.assertTrue( + any("1 nodes" in m for m in replaced), + f"Expected 1 SDPA replaced with standard kernel. Log: {msgs}", + ) + + def test_at_threshold_uses_splitk(self): + """L_kv=256 == threshold -> split-K selected (boundary, inclusive).""" + model = SDPAModule(n_heads=4, n_kv_heads=2, head_dim=64, kv_len=256).to( torch.bfloat16 ) args = ( @@ -135,16 +140,27 @@ def test_small_kv_cache_uses_standard(self): _, msgs = _capture_pass_logs(lambda: _export_through_cuda_backend(model, args)) splitk = [m for m in msgs if "split-K" in m] - self.assertEqual(len(splitk), 0, f"Expected no split-K. Got: {splitk}") + self.assertEqual(len(splitk), 1, f"Expected 1 split-K selection. 
Log: {msgs}") + self.assertIn("L_kv=256", splitk[0]) - replaced = [m for m in msgs if "Replaced" in m] - self.assertTrue( - any("1 nodes" in m for m in replaced), - f"Expected 1 SDPA replaced with standard kernel. Log: {msgs}", + def test_large_kv_cache_uses_splitk(self): + """L_kv=4096 > threshold -> split-K selected for decode.""" + model = SDPAModule(n_heads=4, n_kv_heads=2, head_dim=64, kv_len=4096).to( + torch.bfloat16 ) + args = ( + torch.zeros(1, 1, 256, dtype=torch.bfloat16), + torch.tensor([0], dtype=torch.long), + ) + + _, msgs = _capture_pass_logs(lambda: _export_through_cuda_backend(model, args)) + + splitk = [m for m in msgs if "split-K" in m] + self.assertEqual(len(splitk), 1, f"Expected 1 split-K selection. Log: {msgs}") + self.assertIn("L_kv=4096", splitk[0]) def test_non_pow2_head_dim_uses_standard(self): - """Non-power-of-2 head_dim → standard SDPA even with large L_kv.""" + """Non-power-of-2 head_dim -> standard SDPA even with large L_kv.""" model = SDPAModule(n_heads=4, n_kv_heads=2, head_dim=96, kv_len=8192).to( torch.bfloat16 ) diff --git a/backends/cuda/triton/replacement_pass.py b/backends/cuda/triton/replacement_pass.py index 628222e46f7..54c0377dccc 100644 --- a/backends/cuda/triton/replacement_pass.py +++ b/backends/cuda/triton/replacement_pass.py @@ -27,7 +27,14 @@ exir_ops.edge.aten.topk.default: triton.topk, } -_SPLITK_LKV_THRESHOLD = 2048 +# Decode (L_q==1) routes to split-K flash-decoding once L_kv >= this threshold. +# At decode, pack-GQA collapses the standard kernel grid to CTA = batch * +# n_kv_heads, which under-occupies the SMs; split-K partitions the KV sequence +# across many more CTAs to fill them. Under faithful CUDA-graph timing (the +# deployed --cuda_graph path) split-K wins ~1.2-20x for L_kv >= 256. The earlier +# 2048 value was overfit to a non-cuda-graph microbenchmark, which charged +# split-K a ~140us per-call alloc+launch overhead that cuda-graph removes. 
+_SPLITK_LKV_THRESHOLD = 256 class ReplaceEdgeOpWithTritonOpPass(PassBase): @@ -89,11 +96,13 @@ def call(self, graph_module: GraphModule) -> PassResult: def _pick_sdpa_kernel(node: Node): """Choose between standard SDPA and split-K flash-decoding. - Split-K partitions the KV sequence across many CTAs for better GPU - utilization at decode time (L_q=1). It wins when L_kv is large - (full-attention KV caches) but loses to the standard kernel for - small L_kv (sliding-window ring buffers) due to the overhead of - allocating partial buffers and running the reduction kernel. + At decode (L_q==1) the standard pack-GQA kernel's grid collapses to + CTA = batch * n_kv_heads, under-occupying the SMs. Split-K partitions + the KV sequence across many CTAs to fill the GPU. Under CUDA-graph + timing (the deployed --cuda_graph path) split-K wins ~1.2-20x for + L_kv >= 256, so we route decode to split-K whenever + L_kv >= _SPLITK_LKV_THRESHOLD. Prefill (L_q>1) and non-power-of-2 head + dims always use the standard kernel. 
""" q_shape = node.args[0].meta["val"].shape k_shape = node.args[1].meta["val"].shape @@ -104,7 +113,7 @@ def _pick_sdpa_kernel(node: Node): isinstance(L_q, int) and L_q == 1 and isinstance(L_kv, int) - and L_kv > _SPLITK_LKV_THRESHOLD + and L_kv >= _SPLITK_LKV_THRESHOLD and D > 0 and (D & (D - 1)) == 0 # power of 2 ): From 32e87c2c1d75c1d1674570ee22c04e2363d653fd Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Mon, 8 Jun 2026 21:39:21 -0700 Subject: [PATCH 03/15] update comment --- backends/cuda/triton/replacement_pass.py | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/backends/cuda/triton/replacement_pass.py b/backends/cuda/triton/replacement_pass.py index 54c0377dccc..c55965a00e1 100644 --- a/backends/cuda/triton/replacement_pass.py +++ b/backends/cuda/triton/replacement_pass.py @@ -27,13 +27,7 @@ exir_ops.edge.aten.topk.default: triton.topk, } -# Decode (L_q==1) routes to split-K flash-decoding once L_kv >= this threshold. -# At decode, pack-GQA collapses the standard kernel grid to CTA = batch * -# n_kv_heads, which under-occupies the SMs; split-K partitions the KV sequence -# across many more CTAs to fill them. Under faithful CUDA-graph timing (the -# deployed --cuda_graph path) split-K wins ~1.2-20x for L_kv >= 256. The earlier -# 2048 value was overfit to a non-cuda-graph microbenchmark, which charged -# split-K a ~140us per-call alloc+launch overhead that cuda-graph removes. + _SPLITK_LKV_THRESHOLD = 256 @@ -96,13 +90,14 @@ def call(self, graph_module: GraphModule) -> PassResult: def _pick_sdpa_kernel(node: Node): """Choose between standard SDPA and split-K flash-decoding. - At decode (L_q==1) the standard pack-GQA kernel's grid collapses to - CTA = batch * n_kv_heads, under-occupying the SMs. Split-K partitions - the KV sequence across many CTAs to fill the GPU. 
Under CUDA-graph - timing (the deployed --cuda_graph path) split-K wins ~1.2-20x for - L_kv >= 256, so we route decode to split-K whenever - L_kv >= _SPLITK_LKV_THRESHOLD. Prefill (L_q>1) and non-power-of-2 head - dims always use the standard kernel. + Split-K partitions the KV sequence across many CTAs for better GPU + utilization at decode time (L_q=1). It wins when L_kv is large + (full-attention KV caches) but loses to the standard kernel for + small L_kv (sliding-window ring buffers) due to the overhead of + allocating partial buffers and running the reduction kernel. + + TODO(gasoonjia): Benchmarking to determine the optimal + implementation for each shape. """ q_shape = node.args[0].meta["val"].shape k_shape = node.args[1].meta["val"].shape From 83f94ce673a57ba1c5385303423455229d99fbec Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Sat, 6 Jun 2026 00:58:58 -0700 Subject: [PATCH 04/15] [cuda] int4 W4A8 matvec: vectorized activation load (16B-aligned Q8Block) The decode-only int4_plain_mm matvec was bound by activation load-instruction throughput, not DRAM bandwidth (already ~64% peak) or latency. Each inner iteration issued ~15 loads per 16-byte weight chunk: 8 scalar int32 activation loads + the same per-block scale d reloaded 4x. Align Q8Block to 16 bytes (sizeof 36->48) so each block's qs_even/qs_odd 16B halves are 16B-aligned, then load a whole activation block with two vectorized uint4 loads + one d load (~4x fewer activation loads). dp4a math and accumulation order are bit-identical; the int8 activation values and scale are unchanged. gemma4_31b decode (long-ctx harness, stacked on optimize_1): decode 43.98 -> 46.79 tok/s (+6.4%) prefill 1193 -> 1186 (noise; int4_plain_mm is decode-only) nsys: int4 matvec avg 38.4 -> 34.75 us (-9.5%); quant kernel unchanged. Unit tests test_aoti_torch_cuda_int4_plain_mm: 6/6 pass (M=1/8, gs=16/32/128).
--- backends/cuda/runtime/shims/int4_plain_mm.cuh | 37 +++++++++++-------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/backends/cuda/runtime/shims/int4_plain_mm.cuh b/backends/cuda/runtime/shims/int4_plain_mm.cuh index 31214bc0bf6..db54da91687 100644 --- a/backends/cuda/runtime/shims/int4_plain_mm.cuh +++ b/backends/cuda/runtime/shims/int4_plain_mm.cuh @@ -55,7 +55,11 @@ __host__ __forceinline__ int32_t log2_pow2(int32_t v) { // blocks) // --------------------------------------------------------------------------- -struct Q8Block { +// alignas(16) pads sizeof(Q8Block) to 48 so each block (and its qs_even/qs_odd +// 16-byte halves) is 16-byte aligned. This lets the matvec load a whole block's +// int8 activations with two vectorized uint4 loads instead of eight scalar +// int32 loads, cutting activation load instructions ~4x. +struct alignas(16) Q8Block { int8_t qs_even[Q8_BLOCK_SIZE / 2]; int8_t qs_odd[Q8_BLOCK_SIZE / 2]; float d; // scale @@ -149,6 +153,18 @@ __global__ void __launch_bounds__(MV_THREADS) int32_t k_base = i * 32; uint32_t words[4] = {packed16.x, packed16.y, packed16.z, packed16.w}; + // One uint4 (32 weights) maps to exactly one Q8 activation block (32 + // activations), i.e. q8_block_idx == i. Load the whole block with two + // vectorized uint4 loads (+ one scale load) instead of eight scalar int32 + // loads. ae.{x,y,z,w} == qs_even[0:4],[4:8],[8:12],[12:16] == a_even for + // w=0..3 (same for ao/qs_odd) -> bit-identical to the scalar path. 
+ const Q8Block* qb = &q8_row[i]; + uint4 ae = *reinterpret_cast(qb->qs_even); + uint4 ao = *reinterpret_cast(qb->qs_odd); + float a_scale = qb->d; + const uint32_t a_even[4] = {ae.x, ae.y, ae.z, ae.w}; + const uint32_t a_odd[4] = {ao.x, ao.y, ao.z, ao.w}; + #pragma unroll for (int32_t w = 0; w < 4; w++) { uint32_t packed = words[w]; @@ -164,22 +180,11 @@ __global__ void __launch_bounds__(MV_THREADS) int32_t vi_lo = packed & 0x0F0F0F0F; int32_t vi_hi = (packed >> 4) & 0x0F0F0F0F; - int32_t q8_block_idx = k_word / Q8_BLOCK_SIZE; - int32_t q8_half_offset = (k_word % Q8_BLOCK_SIZE) / 2; - const Q8Block* qb = &q8_row[q8_block_idx]; - - int32_t a_even = - *reinterpret_cast(qb->qs_even + q8_half_offset); - int32_t a_odd = - *reinterpret_cast(qb->qs_odd + q8_half_offset); - - int32_t dp = __dp4a(vi_lo, a_even, 0); - dp = __dp4a(vi_hi, a_odd, dp); - - float a_scale = qb->d; + int32_t dp = __dp4a(vi_lo, static_cast(a_even[w]), 0); + dp = __dp4a(vi_hi, static_cast(a_odd[w]), dp); - int32_t a_sum8 = __dp4a(0x01010101, a_even, 0); - a_sum8 = __dp4a(0x01010101, a_odd, a_sum8); + int32_t a_sum8 = __dp4a(0x01010101, static_cast(a_even[w]), 0); + a_sum8 = __dp4a(0x01010101, static_cast(a_odd[w]), a_sum8); sum += ws * a_scale * (static_cast(dp) - wz * static_cast(a_sum8)); From 457a316ba9600247f128c501b17d5ee11d2e4244 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 9 Jun 2026 09:42:09 -0700 Subject: [PATCH 05/15] int8 vec support --- backends/cuda/runtime/shims/int8_plain_mm.cuh | 24 +++++++++++++------ 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/backends/cuda/runtime/shims/int8_plain_mm.cuh b/backends/cuda/runtime/shims/int8_plain_mm.cuh index 2c478854644..8458c7680b5 100644 --- a/backends/cuda/runtime/shims/int8_plain_mm.cuh +++ b/backends/cuda/runtime/shims/int8_plain_mm.cuh @@ -58,7 +58,11 @@ __host__ __forceinline__ int32_t log2_pow2_i8(int32_t v) { // blocks, NATURAL order — qs[k] holds the quantized value for element k). 
// --------------------------------------------------------------------------- -struct Q8BlockNat { +// alignas(16) pads sizeof(Q8BlockNat) 36->48 so each block (and its two 16-byte +// qs halves) is 16-byte aligned. This lets the matvec load 16 int8 activations +// with one vectorized uint4 load instead of four scalar int32 loads, cutting +// activation load instructions ~4x. +struct alignas(16) Q8BlockNat { int8_t qs[Q8_NAT_BLOCK_SIZE]; float d; // scale }; @@ -135,6 +139,17 @@ __global__ void __launch_bounds__(MV8_THREADS) int8_w8a8_matvec_kernel( int32_t k_base = i * 16; uint32_t words[4] = {packed16.x, packed16.y, packed16.z, packed16.w}; + // One uint4 (16 int8 weights) maps to exactly one 16-byte half of a Q8 + // activation block (16 activations): block i>>1, byte offset 0 (i even) or + // 16 (i odd). Load those 16 int8 activations with a single vectorized uint4 + // load (+ one scale load) instead of four scalar int32 loads + four scale + // reloads. av.{x,y,z,w} == qs[off+0:4],[4:8],[8:12],[12:16] == a_word for + // w=0..3 -> bit-identical to the scalar path. + const Q8BlockNat* qb = &q8_row[i >> 1]; + uint4 av = *reinterpret_cast(qb->qs + ((i & 1) ? 
16 : 0)); + float a_scale = qb->d; + const uint32_t a_words[4] = {av.x, av.y, av.z, av.w}; + #pragma unroll for (int32_t w = 0; w < 4; w++) { int32_t k_word = k_base + w * 4; // 4 int8 weights start here @@ -147,15 +162,10 @@ __global__ void __launch_bounds__(MV8_THREADS) int8_w8a8_matvec_kernel( } int32_t w_word = static_cast(words[w]); - - int32_t q8_block_idx = k_word / Q8_NAT_BLOCK_SIZE; - int32_t q8_offset = k_word % Q8_NAT_BLOCK_SIZE; - const Q8BlockNat* qb = &q8_row[q8_block_idx]; - int32_t a_word = *reinterpret_cast(qb->qs + q8_offset); + int32_t a_word = static_cast(a_words[w]); int32_t dp = __dp4a(w_word, a_word, 0); int32_t a_sum = __dp4a(0x01010101, a_word, 0); - float a_scale = qb->d; sum += ws * a_scale * (static_cast(dp) - wz * static_cast(a_sum)); From ec3863c9da98e5d61c763cd935fdf8e37d5fa759 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 9 Jun 2026 22:03:05 -0700 Subject: [PATCH 06/15] [cuda][prefill] window-aware SDPA: skip fully-masked KV blocks (idea #1) Block-sparse early-exit in _sdpa_fwd_kernel_body: skip KV blocks that are entirely masked (sliding-window via HAS_MASK sum==0, causal via start_n>max_seq_pos). Exact (skipped blocks are x1,+0 no-ops). Prefill +46-88% all lengths; decode safe; SDPA nsys 58.1%->18.5%. Numerically bf16-exact vs dense+mask (unit test). --- backends/cuda/triton/kernels/sdpa.py | 102 ++++++++++++++++----------- 1 file changed, 62 insertions(+), 40 deletions(-) diff --git a/backends/cuda/triton/kernels/sdpa.py b/backends/cuda/triton/kernels/sdpa.py index 9f42a474b36..fb665e538bf 100644 --- a/backends/cuda/triton/kernels/sdpa.py +++ b/backends/cuda/triton/kernels/sdpa.py @@ -422,21 +422,22 @@ def _sdpa_fwd_kernel_body( offs_n_init = tl.arange(0, BLOCK_N) + # Window-aware early-exit. A KV block that is fully masked (sliding-window + # or causal) contributes nothing to the online softmax — every entry is + # -inf, so p=0 and m_i/l_i/acc are left unchanged. 
We detect such blocks up + # front and skip their K/V loads and both matmuls. This is exact: it only + # skips work the mask would have zeroed out anyway. At seq=2048 the 50 + # sliding-window(1024) layers and the 10 causal layers each leave roughly + # half (or more) of their KV blocks fully masked, so this is a large cut to + # the dominant prefill cost. The skip condition is a CTA-wide reduction, so + # the branch is uniform and turns into a real skip (not predication). + if IS_CAUSAL: + max_seq_pos = tl.max(seq_pos) + for start_n in tl.range(0, Lk, BLOCK_N): offs_n = start_n + offs_n_init - # K load: uniform (single KV head, shared across all Q heads in tile) - k_ptrs = K_ptr + ( - b * stride_kb - + h_kv * stride_kh - + (offs_n[:, None] * stride_kn) - + (offs_d[None, :] * stride_kd) - ) - k_mask = (offs_n[:, None] < Lk) & (offs_d[None, :] < HEAD_DIM) - k = tl.load(k_ptrs, mask=k_mask, other=0.0).to(tl.bfloat16) - - qk = (tl.dot(q, tl.trans(k)).to(tl.float32) * sm_scale).to(tl.float32) - + # Decide whether any row in this tile actually attends to this KV block. if HAS_MASK: mask_ptrs = Mask_ptr + ( b * stride_mb @@ -445,39 +446,60 @@ def _sdpa_fwd_kernel_body( ) mn_mask = row_valid[:, None] & (offs_n[None, :] < Lk) mask_block = tl.load(mask_ptrs, mask=mn_mask, other=False) - qk = tl.where( - mask_block, qk, tl.full(qk.shape, -float("inf"), dtype=tl.float32) + block_active = tl.sum(mask_block.to(tl.int32)) > 0 + elif IS_CAUSAL: + # Block is entirely in the future for every row -> skip. 
+ block_active = start_n <= max_seq_pos + else: + block_active = True + + if block_active: + # K load: uniform (single KV head, shared across Q heads in tile) + k_ptrs = K_ptr + ( + b * stride_kb + + h_kv * stride_kh + + (offs_n[:, None] * stride_kn) + + (offs_d[None, :] * stride_kd) ) + k_mask = (offs_n[:, None] < Lk) & (offs_d[None, :] < HEAD_DIM) + k = tl.load(k_ptrs, mask=k_mask, other=0.0).to(tl.bfloat16) - if IS_CAUSAL: - causal = offs_n[None, :] > seq_pos[:, None] - qk = tl.where( - causal, tl.full(qk.shape, -float("inf"), dtype=tl.float32), qk - ) + qk = (tl.dot(q, tl.trans(k)).to(tl.float32) * sm_scale).to(tl.float32) - m_ij = tl.maximum(m_i, tl.max(qk, axis=1).to(tl.float32)) - safe_diff = tl.where( - m_ij[:, None] > -float("inf"), qk - m_ij[:, None], -float("inf") - ) - p_f32 = tl.exp(safe_diff).to(tl.float32) - l_ij = tl.sum(p_f32, axis=1).to(tl.float32) - safe_alpha_diff = tl.where(m_ij > -float("inf"), m_i - m_ij, 0.0) - alpha = tl.exp(safe_alpha_diff).to(tl.float32) + if HAS_MASK: + qk = tl.where( + mask_block, qk, tl.full(qk.shape, -float("inf"), dtype=tl.float32) + ) - # V load: uniform (single KV head) - v_ptrs = V_ptr + ( - b * stride_vb - + h_kv * stride_vh - + (offs_n[:, None] * stride_vn) - + (offs_d[None, :] * stride_vd) - ) - v_mask = (offs_n[:, None] < Lk) & (offs_d[None, :] < HEAD_DIM) - v = tl.load(v_ptrs, mask=v_mask, other=0.0).to(tl.bfloat16) + if IS_CAUSAL: + causal = offs_n[None, :] > seq_pos[:, None] + qk = tl.where( + causal, tl.full(qk.shape, -float("inf"), dtype=tl.float32), qk + ) - p_bf16 = p_f32.to(tl.bfloat16) - acc = (acc * alpha[:, None] + tl.dot(p_bf16, v)).to(tl.float32) - l_i = (l_i * alpha + l_ij).to(tl.float32) - m_i = m_ij + m_ij = tl.maximum(m_i, tl.max(qk, axis=1).to(tl.float32)) + safe_diff = tl.where( + m_ij[:, None] > -float("inf"), qk - m_ij[:, None], -float("inf") + ) + p_f32 = tl.exp(safe_diff).to(tl.float32) + l_ij = tl.sum(p_f32, axis=1).to(tl.float32) + safe_alpha_diff = tl.where(m_ij > -float("inf"), m_i - 
m_ij, 0.0) + alpha = tl.exp(safe_alpha_diff).to(tl.float32) + + # V load: uniform (single KV head) + v_ptrs = V_ptr + ( + b * stride_vb + + h_kv * stride_vh + + (offs_n[:, None] * stride_vn) + + (offs_d[None, :] * stride_vd) + ) + v_mask = (offs_n[:, None] < Lk) & (offs_d[None, :] < HEAD_DIM) + v = tl.load(v_ptrs, mask=v_mask, other=0.0).to(tl.bfloat16) + + p_bf16 = p_f32.to(tl.bfloat16) + acc = (acc * alpha[:, None] + tl.dot(p_bf16, v)).to(tl.float32) + l_i = (l_i * alpha + l_ij).to(tl.float32) + m_i = m_ij inv_l_i = tl.where(l_i > 0, 1.0 / l_i, 0.0) acc = acc * inv_l_i[:, None] From 79d5cdfbbaf4e8bdb9bdb3e3289ead89a3eeaceb Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Thu, 11 Jun 2026 21:30:16 -0700 Subject: [PATCH 07/15] [cuda] GGUF Q6_K real packed INT6 (W6A8 dp4a) + GGUF CI export Add a genuine 6-bit packed weight path for GGUF Q6_K on the CUDA backend, parallel to the int4/int8 plain_mm paths: - int6_plain_mm CUDA shim (W6A8 dp4a; ql/qh planes; spread2; -32 symmetric offset) - CudaPackedInt6Tensor (ql/qh + per-group bf16 scale; symmetric, no zero tensor) - int6_dispatch: F.linear routing (M<=4 -> executorch_cuda::int6_plain_mm op, M>4 -> dequant) - backend fallback-kernel + custom_ops_to_c_shims registration; CMake build - GGUF Q6_K: gguf_loader returns the native torchao IntxUnpackedToInt8Tensor and the backend packer (pack_cuda.pack_linear_for_cuda) repacks a symmetric Q6_K weight into CudaPackedInt6Tensor -- mirroring Int4Tensor -> CudaCoalescedInt4Tensor, so the loader stays backend-agnostic; dequantize_weight handles the tied embedding - tests: int6 gtest, test_int6_dispatch.py, pack round-trip; fix stale int4/int6 type asserts CI (export_model_artifact.sh, gemma4_31b): download the Q4_K_M GGUF from unsloth/gemma-4-31B-it-GGUF (tokenizer from unsloth/gemma-4-31B-it) and run the inference sanity check + export via the GGUF loader (--gguf) instead of the prequantized HF checkpoint. 
Signed-off-by: gasoonjia --- .ci/scripts/export_model_artifact.sh | 20 +- backends/cuda/CMakeLists.txt | 1 + backends/cuda/cuda_backend.py | 7 + backends/cuda/packed_int6_tensor.py | 209 +++++++++++ .../cuda/quantize_op_dispatch/__init__.py | 5 +- .../cuda/quantize_op_dispatch/_library.py | 7 +- .../quantize_op_dispatch/int6_dispatch.py | 116 ++++++ backends/cuda/runtime/shims/int6_plain_mm.cu | 81 ++++ backends/cuda/runtime/shims/int6_plain_mm.cuh | 353 ++++++++++++++++++ backends/cuda/runtime/shims/int6_plain_mm.h | 61 +++ .../cuda/runtime/shims/tests/CMakeLists.txt | 5 +- .../test_aoti_torch_cuda_int6_plain_mm.cpp | 306 +++++++++++++++ backends/cuda/tests/test_int6_dispatch.py | 226 +++++++++++ examples/models/gemma4_31b/gguf_loader.py | 12 +- examples/models/gemma4_31b/quant/pack_cuda.py | 39 +- examples/models/gemma4_31b/quant/quantize.py | 7 + .../gemma4_31b/quant/tests/test_pack_cuda.py | 104 ++++++ .../gemma4_31b/tests/test_cuda_pipeline.py | 11 +- 18 files changed, 1545 insertions(+), 25 deletions(-) create mode 100644 backends/cuda/packed_int6_tensor.py create mode 100644 backends/cuda/quantize_op_dispatch/int6_dispatch.py create mode 100644 backends/cuda/runtime/shims/int6_plain_mm.cu create mode 100644 backends/cuda/runtime/shims/int6_plain_mm.cuh create mode 100644 backends/cuda/runtime/shims/int6_plain_mm.h create mode 100644 backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_int6_plain_mm.cpp create mode 100644 backends/cuda/tests/test_int6_dispatch.py diff --git a/.ci/scripts/export_model_artifact.sh b/.ci/scripts/export_model_artifact.sh index db447bb907f..e9218dce625 100755 --- a/.ci/scripts/export_model_artifact.sh +++ b/.ci/scripts/export_model_artifact.sh @@ -467,21 +467,27 @@ if [ "$MODEL_NAME" = "qwen3_5_moe" ]; then exit 0 fi -# Gemma 4 31B uses a prequantized checkpoint and custom export script +# Gemma 4 31B: download the Q4_K_M GGUF and export via the GGUF loader if [ "$MODEL_NAME" = "gemma4_31b" ]; then pip install safetensors 
huggingface_hub gguf - # Download prequantized model outside OUTPUT_DIR to avoid uploading on failure + # Download GGUF + tokenizer outside OUTPUT_DIR to avoid uploading on failure. + # The unsloth GGUF repo ships the .gguf but no tokenizer.json, so the tokenizer + # is fetched from the (non-GGUF) unsloth/gemma-4-31B-it repo. LOCAL_MODEL_DIR=$(mktemp -d) INDUCTOR_CACHE=$(mktemp -d) trap 'rm -rf "$LOCAL_MODEL_DIR" "$INDUCTOR_CACHE"' EXIT - python -c "from huggingface_hub import snapshot_download; snapshot_download('${HF_MODEL}', local_dir='${LOCAL_MODEL_DIR}')" + GGUF_FILE="gemma-4-31B-it-Q4_K_M.gguf" + python -c "from huggingface_hub import hf_hub_download; hf_hub_download('unsloth/gemma-4-31B-it-GGUF', '${GGUF_FILE}', local_dir='${LOCAL_MODEL_DIR}')" + python -c "from huggingface_hub import hf_hub_download; hf_hub_download('unsloth/gemma-4-31B-it', 'tokenizer.json', local_dir='${LOCAL_MODEL_DIR}')" + GGUF_PATH="${LOCAL_MODEL_DIR}/${GGUF_FILE}" - # Sanity check: run inference on the prequantized model + # Sanity check: run inference on the GGUF model echo "::group::Inference sanity check" INFERENCE_OUTPUT=$(python -m executorch.examples.models.gemma4_31b.inference \ - --prequantized "$LOCAL_MODEL_DIR" \ + --gguf "$GGUF_PATH" \ + --tokenizer-path "${LOCAL_MODEL_DIR}/tokenizer.json" \ --prompt "What is the capital of France?" 
\ --max-new-tokens 32 \ --temperature 0 \ @@ -494,13 +500,13 @@ if [ "$MODEL_NAME" = "gemma4_31b" ]; then echo "::endgroup::" # Copy tokenizer for the runner - cp "$LOCAL_MODEL_DIR/tokenizer.json" "${OUTPUT_DIR}/tokenizer.json" + cp "${LOCAL_MODEL_DIR}/tokenizer.json" "${OUTPUT_DIR}/tokenizer.json" # Export to .pte/.ptd (short cache dir avoids objcopy symbol length issues) echo "::group::Export" TORCHINDUCTOR_CACHE_DIR="$INDUCTOR_CACHE" \ python -m executorch.examples.models.gemma4_31b.export \ - --prequantized "$LOCAL_MODEL_DIR" \ + --gguf "$GGUF_PATH" \ --output-dir "${OUTPUT_DIR}" echo "::endgroup::" diff --git a/backends/cuda/CMakeLists.txt b/backends/cuda/CMakeLists.txt index 4668e48b91e..2d522f33e28 100644 --- a/backends/cuda/CMakeLists.txt +++ b/backends/cuda/CMakeLists.txt @@ -114,6 +114,7 @@ if(CMAKE_CUDA_COMPILER) _aoti_cuda_shim_sources runtime/shims/int4mm.cu runtime/shims/int4_plain_mm.cu + runtime/shims/int6_plain_mm.cu runtime/shims/int8_plain_mm.cu runtime/shims/sort.cu runtime/shims/rand.cu diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py index c07cc29b102..f9f23a842f9 100644 --- a/backends/cuda/cuda_backend.py +++ b/backends/cuda/cuda_backend.py @@ -231,6 +231,8 @@ def get_supported_fallback_kernels(cls) -> Dict[str, Any]: "aoti_torch_cuda_randint_low_out": None, "executorch_cuda::int4_plain_mm": None, "aoti_torch_cuda_int4_plain_mm": None, + "executorch_cuda::int6_plain_mm": None, + "aoti_torch_cuda_int6_plain_mm": None, "executorch_cuda::int8_plain_mm": None, "aoti_torch_cuda_int8_plain_mm": None, } @@ -314,6 +316,11 @@ def get_aoti_compile_options( "AtenTensorHandle, AtenTensorHandle, AtenTensorHandle, " "AtenTensorHandle, int64_t, AtenTensorHandle*)" ], + torch.ops.executorch_cuda.int6_plain_mm.default: [ + "AOTITorchError aoti_torch_cuda_int6_plain_mm(" + "AtenTensorHandle, AtenTensorHandle, AtenTensorHandle, " + "AtenTensorHandle, int64_t, AtenTensorHandle*)" + ], torch.ops.executorch_cuda.int8_plain_mm.default: [ 
"AOTITorchError aoti_torch_cuda_int8_plain_mm(" "AtenTensorHandle, AtenTensorHandle, AtenTensorHandle, " diff --git a/backends/cuda/packed_int6_tensor.py b/backends/cuda/packed_int6_tensor.py new file mode 100644 index 00000000000..104ed5bbfa0 --- /dev/null +++ b/backends/cuda/packed_int6_tensor.py @@ -0,0 +1,209 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""ExecuTorch-internal packed-INT6 tensor for the CUDA W6A8 dp4a decode kernel. + +``CudaPackedInt6Tensor`` is an ExecuTorch-internal tensor subclass that stores a +genuine 6-bit packed weight (0.75 B/elem), used for GGUF Q6_K weights. Unlike +the int8 path (``IntxUnpackedToInt8Tensor``, one int8 per 6-bit value), this +format wastes no bits and carries no zero tensor — Q6_K is symmetric. + +The stored value is ``u = q + 32`` in ``[0, 63]`` (``q`` in ``[-32, 31]``); the +constant ``-32`` offset is applied in the decode kernel. The 6 bits are split +into two planes that mirror the INT4 nibble layout so the kernel can reuse the +INT4 even/odd extraction verbatim: + + ql : (N, K/2) uint8 — low-nibble plane, nibble-packed even/odd + (``ql[:, j] = lo[:, 2j] | (lo[:, 2j+1] << 4)``, ``lo = u & 0xF``). + qh : (N, K/4) uint8 — high-2-bit plane, 4 values/byte, arranged per + 32-weight chunk as ``hi_even_packed[4]`` then ``hi_odd_packed[4]``; + each byte holds the four 2-bit highs (``hi = (u >> 4) & 0x3``) of one + 8-weight dp4a word, bit field ``j`` (bits ``2j..2j+1``) = the high 2 + bits of that word's ``j``-th even/odd weight. + scale : (N, K/gs) bf16 — per-group scales, row-major (already coalesced; the + decode kernel reads it row-for-row, no transpose). 
+ +The pack/unpack helpers (:func:`pack_int6`, :func:`unpack_int6`) must stay in +lockstep with ``int6_plain_mm.cuh`` (the decode kernel) — the per-32-weight +``hi_even``/``hi_odd`` byte order is the single most error-prone detail and is +covered by the pack round-trip and the C++ gtest. +""" + +from typing import List, Optional, Tuple + +import torch +from torchao.utils import TorchAOBaseTensor + +__all__ = [ + "CudaPackedInt6Tensor", + "pack_int6", + "unpack_int6", +] + + +def pack_int6(q: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Pack symmetric Q6_K int values into the (ql, qh) planes. + + Args: + q: (N, K) integer tensor with values in ``[-32, 31]``. + + Returns: + ``(ql, qh)`` where ``ql`` is ``(N, K/2)`` uint8 and ``qh`` is + ``(N, K/4)`` uint8 (see the module docstring for the layout). + """ + if q.dim() != 2: + raise ValueError(f"pack_int6 expects a 2-D tensor, got shape {tuple(q.shape)}") + N, K = int(q.shape[0]), int(q.shape[1]) + if K % 32 != 0: + raise ValueError(f"K={K} must be a multiple of 32 for INT6 packing") + + # All intermediates are uint8 (values fit in a byte) to keep peak memory low + # — important for the ~1.4B-element tied token embedding. + u = (q.to(torch.int16) + 32).to(torch.uint8) # [0, 63] + lo = u & 0xF # low nibble (uint8) + hi = (u >> 4) & 0x3 # high 2 bits (uint8) + + # ql: nibble-pack the low plane even/odd, exactly like the INT4 path. + ql = lo[:, 0::2] | (lo[:, 1::2] << 4) # (N, K/2) uint8 + + # qh: per 32-weight chunk -> [hi_even_packed[4], hi_odd_packed[4]]; each byte + # packs the four 2-bit highs of one 8-weight dp4a word, field j at bits 2j. + chunks = K // 32 + hw = hi.reshape(N, chunks, 4, 8) # (N, chunk, word, pos-in-word) + even = hw[..., 0::2] # (N, chunk, 4, 4) positions 0,2,4,6 + odd = hw[..., 1::2] # (N, chunk, 4, 4) positions 1,3,5,7 + # Explicit OR (not sum) keeps the result uint8 (torch.sum would promote). 
+ hi_even_byte = ( + even[..., 0] | (even[..., 1] << 2) | (even[..., 2] << 4) | (even[..., 3] << 6) + ) # (N, chunk, 4) uint8 + hi_odd_byte = ( + odd[..., 0] | (odd[..., 1] << 2) | (odd[..., 2] << 4) | (odd[..., 3] << 6) + ) + qh = torch.cat([hi_even_byte, hi_odd_byte], dim=-1) # (N, chunk, 8) uint8 + qh = qh.reshape(N, K // 4) + return ql.contiguous(), qh.contiguous() + + +def unpack_int6(ql: torch.Tensor, qh: torch.Tensor, N: int, K: int) -> torch.Tensor: + """Inverse of :func:`pack_int6`. Returns ``(N, K)`` int16 q in ``[-32, 31]``. + + Intermediates are uint8 to keep peak memory low; only the final ``- 32`` shift + (which produces negatives) widens to int16. + """ + qlu = ql.to(torch.uint8) + lo_even = qlu & 0xF # low nibble -> even weights + lo_odd = (qlu >> 4) & 0xF # high nibble -> odd weights + lo = torch.stack([lo_even, lo_odd], dim=-1).reshape(N, K) # uint8 + + chunks = K // 32 + qhu = qh.to(torch.uint8).reshape(N, chunks, 8) + hi_even_byte = qhu[:, :, 0:4] # (N, chunk, 4) word w + hi_odd_byte = qhu[:, :, 4:8] # (N, chunk, 4) + hi_even = torch.stack( + [(hi_even_byte >> s) & 0x3 for s in (0, 2, 4, 6)], dim=-1 + ) # (N, chunk, 4, 4) uint8 + hi_odd = torch.stack( + [(hi_odd_byte >> s) & 0x3 for s in (0, 2, 4, 6)], dim=-1 + ) + hi = torch.empty(N, chunks, 4, 8, dtype=torch.uint8, device=ql.device) + hi[..., 0::2] = hi_even + hi[..., 1::2] = hi_odd + hi = hi.reshape(N, K) + + u = lo | (hi << 4) # [0, 63] uint8 + return u.to(torch.int16) - 32 + + +class CudaPackedInt6Tensor(TorchAOBaseTensor): + """Packed 6-bit weight (ql/qh planes + per-group scale), symmetric. + + ExecuTorch-internal; see the module docstring. The CUDA decode/prefill + dispatch (``int6_dispatch.py``) is selected by *type* — it is registered on + this class only. 
+ """ + + tensor_data_names = ["ql", "qh", "scale"] + tensor_attribute_names = ["block_size", "shape"] + + def __new__( + cls, + ql: torch.Tensor, + qh: torch.Tensor, + scale: torch.Tensor, + block_size: List[int], + shape: torch.Size, + ): + kwargs = {} + kwargs["device"] = ql.device + kwargs["dtype"] = scale.dtype + kwargs["requires_grad"] = False + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + ql: torch.Tensor, + qh: torch.Tensor, + scale: torch.Tensor, + block_size: List[int], + shape: torch.Size, + ): + super().__init__() + self.ql = ql + self.qh = qh + self.scale = scale + self.block_size = block_size + + def _quantization_type(self): + return ( + f"shape={self.shape}, block_size={self.block_size}, " + f"device={self.device}" + ) + + @classmethod + def from_intx_int8(cls, t: torch.Tensor) -> "CudaPackedInt6Tensor": + """Build from a torchao ``IntxUnpackedToInt8Tensor`` decoded from Q6_K. + + The source is symmetric (zero_point == 0), ``qdata`` is int8 in + ``[-32, 31]`` and ``scale`` is ``(N, K/16)``. The ql/qh bit-pack is baked + into the serialized weight constant here, once at pack time. + """ + q = t.qdata + if not bool(torch.all(t.zero_point == 0)): + raise ValueError( + "CudaPackedInt6Tensor.from_intx_int8 requires symmetric Q6_K " + "weights (zero_point == 0)" + ) + q_min, q_max = int(q.min()), int(q.max()) + if q_min < -32 or q_max > 31: + raise ValueError( + f"Q6_K values must be in [-32, 31], got [{q_min}, {q_max}]" + ) + ql, qh = pack_int6(q) + return cls( + ql, + qh, + t.scale.contiguous(), + list(t.block_size), + t.shape, + ) + + def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: + """Dequantize to a dense tensor (symmetric: ``w = q * scale``). + + Used for the tied lm_head / token embedding (which can't gather a packed + tensor) and as the numerical reference. 
+ """ + dtype = output_dtype if output_dtype is not None else self.scale.dtype + N, K = int(self.shape[0]), int(self.shape[1]) + gs = self.block_size[-1] + q = unpack_int6(self.ql, self.qh, N, K).to(dtype) + scale = self.scale.to(dtype).repeat_interleave(gs, dim=-1) + return (q * scale).to(dtype) + + +# Allow a model with CudaPackedInt6Tensor weights to be loaded with +# `weights_only=True` (mirrors torchao quantized tensors). +torch.serialization.add_safe_globals([CudaPackedInt6Tensor]) diff --git a/backends/cuda/quantize_op_dispatch/__init__.py b/backends/cuda/quantize_op_dispatch/__init__.py index 005c2b6e7c7..bc45b3906f9 100644 --- a/backends/cuda/quantize_op_dispatch/__init__.py +++ b/backends/cuda/quantize_op_dispatch/__init__.py @@ -11,9 +11,11 @@ dequant logic instead of torchao's defaults. It registers: * INT4 (``CudaCoalescedInt4Tensor``) → ``executorch_cuda::int4_plain_mm`` + * INT6 (``CudaPackedInt6Tensor``) → ``executorch_cuda::int6_plain_mm`` * INT8 (``IntxUnpackedToInt8Tensor``) → ``executorch_cuda::int8_plain_mm`` -See ``int4_dispatch`` and ``int8_dispatch`` for the per-dtype details. +See ``int4_dispatch``, ``int6_dispatch`` and ``int8_dispatch`` for the per-dtype +details. Import this package before using nn.Linear with quantized weights:: @@ -22,5 +24,6 @@ from executorch.backends.cuda.quantize_op_dispatch import ( # noqa: F401 int4_dispatch, + int6_dispatch, int8_dispatch, ) diff --git a/backends/cuda/quantize_op_dispatch/_library.py b/backends/cuda/quantize_op_dispatch/_library.py index c256e856c2c..2308ecf7102 100644 --- a/backends/cuda/quantize_op_dispatch/_library.py +++ b/backends/cuda/quantize_op_dispatch/_library.py @@ -6,9 +6,10 @@ """Shared torch.library handle for the ``executorch_cuda`` op namespace. -``int4_dispatch`` and ``int8_dispatch`` both register custom ops into the same -``executorch_cuda`` namespace, so they must share a single ``DEF`` library -instance — PyTorch allows only one ``DEF`` per namespace per process. 
+``int4_dispatch``, ``int6_dispatch`` and ``int8_dispatch`` all register custom +ops into the same ``executorch_cuda`` namespace, so they must share a single +``DEF`` library instance — PyTorch allows only one ``DEF`` per namespace per +process. """ from torch.library import Library diff --git a/backends/cuda/quantize_op_dispatch/int6_dispatch.py b/backends/cuda/quantize_op_dispatch/int6_dispatch.py new file mode 100644 index 00000000000..a26814ded1e --- /dev/null +++ b/backends/cuda/quantize_op_dispatch/int6_dispatch.py @@ -0,0 +1,116 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""CudaPackedInt6Tensor F.linear dispatch for CUDA — eager / export trace time. + +This module registers an F.linear dispatch on ``CudaPackedInt6Tensor`` (an +ExecuTorch-internal subclass, see ``packed_int6_tensor.py``) so that +torch.export traces through our custom op and dequant logic. Routing is by +*type*: only GGUF Q6_K weights (converted to ``CudaPackedInt6Tensor``) take the +packed-int6 path; genuine INT8 weights stay on the int8 path. The code here runs +during eager inference and AOTI export tracing — it does NOT run at .pte runtime. + +At .pte runtime, the captured graph is executed by the AOTI-generated .so: + - The custom op ``executorch_cuda::int6_plain_mm`` maps to a C shim that runs + the W6A8 dp4a matvec kernel (backends/cuda/runtime/shims/int6_plain_mm.*). + - The inline dequant + F.linear is compiled by inductor into fused Triton + dequant + cuBLAS matmul kernels. + +Dispatch strategy (determines what gets captured in the export graph): + Decode (M<=4): Custom op ``executorch_cuda::int6_plain_mm`` + Prefill (M>4): Inline dequant + F.linear (standard PyTorch ops) + +The packed-int6 weight is symmetric (no zero point): ``w = q * scale`` with +``q`` in ``[-32, 31]`` stored as the ql/qh planes. 
The op signature mirrors +int4_plain_mm / int8_plain_mm but takes two weight planes (ql, qh) instead of +one, and no zero tensor. + +Importing the parent ``quantize_op_dispatch`` package registers this dispatch +override (along with the INT4 / INT8 ones):: + + import executorch.backends.cuda.quantize_op_dispatch # noqa: F401 +""" + +import torch +import torch.nn.functional as F +from executorch.backends.cuda.packed_int6_tensor import ( + CudaPackedInt6Tensor, + unpack_int6, +) +from executorch.backends.cuda.quantize_op_dispatch._library import lib as _lib +from torch.library import impl + +# --------------------------------------------------------------------------- +# Custom op for INT6 decode (M<=4): W6A8 dp4a matvec in C shim. +# --------------------------------------------------------------------------- + +_lib.define( + "int6_plain_mm(Tensor self, Tensor ql, Tensor qh, Tensor scale, int group_size) -> Tensor" +) + + +@impl(_lib, "int6_plain_mm", "Meta") +def _meta_int6(self, ql, qh, scale, group_size): + return torch.empty(self.shape[0], ql.shape[0], dtype=self.dtype, device=self.device) + + +@impl(_lib, "int6_plain_mm", "CUDA") +def _cuda_int6(self, ql, qh, scale, group_size): + return _dequant_matmul_int6(self, ql, qh, scale, group_size) + + +def _dequant_matmul_int6(x, ql, qh, scale, group_size): + """Dequant packed-INT6 weights to input dtype and call F.linear. + + ql [N, K/2] / qh [N, K/4] pack symmetric Q6_K values q in [-32, 31]; + scale [N, K//gs]. Dequant: w[n, k] = q[n, k] * scale[n, k//gs]. 
+ """ + N = ql.shape[0] + K = ql.shape[1] * 2 + n_groups = K // group_size + dtype = x.dtype + + q = unpack_int6(ql, qh, N, K).to(dtype).reshape(N, n_groups, group_size) + s = scale.to(dtype).reshape(N, n_groups, 1) + w_deq = (q * s).reshape(N, K) + + return F.linear(x, w_deq) + + +# --------------------------------------------------------------------------- +# CudaPackedInt6Tensor F.linear dispatch (W6A8 dp4a for decode) +# --------------------------------------------------------------------------- + +aten = torch.ops.aten +_implements_i6 = CudaPackedInt6Tensor.implements +_implements_torch_function_i6 = CudaPackedInt6Tensor.implements_torch_function + + +@_implements_i6([aten.linear.default]) +@_implements_torch_function_i6([F.linear]) +def _(func, types, args, kwargs): + input_tensor = args[0] + weight_tensor = args[1] + bias = args[2] if len(args) > 2 else None + + orig_shape = input_tensor.shape + x_2d = input_tensor.reshape(-1, orig_shape[-1]) + + ql = weight_tensor.ql + qh = weight_tensor.qh + scale = weight_tensor.scale + gs = weight_tensor.block_size[-1] + + M = x_2d.shape[0] + if M <= 4: + out = torch.ops.executorch_cuda.int6_plain_mm(x_2d, ql, qh, scale, gs) + else: + out = _dequant_matmul_int6(x_2d, ql, qh, scale, gs) + + out = out.reshape(*orig_shape[:-1], -1) + if bias is not None: + out = out + bias + return out diff --git a/backends/cuda/runtime/shims/int6_plain_mm.cu b/backends/cuda/runtime/shims/int6_plain_mm.cu new file mode 100644 index 00000000000..dd068a5766b --- /dev/null +++ b/backends/cuda/runtime/shims/int6_plain_mm.cu @@ -0,0 +1,81 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include + +#include +#include +#include +#include +#include + +namespace executorch::backends::cuda { +#ifdef __cplusplus +extern "C" { +#endif + +AOTITorchError aoti_torch_cuda_int6_plain_mm( + Tensor* self, + Tensor* ql, + Tensor* qh, + Tensor* scale, + int64_t group_size, + Tensor** ret0) { + ET_CHECK_OR_RETURN_ERROR( + self != nullptr, + InvalidArgument, + "aoti_torch_cuda_int6_plain_mm: self is null"); + + ET_CHECK_OR_RETURN_ERROR( + ql != nullptr, + InvalidArgument, + "aoti_torch_cuda_int6_plain_mm: ql is null"); + + ET_CHECK_OR_RETURN_ERROR( + qh != nullptr, + InvalidArgument, + "aoti_torch_cuda_int6_plain_mm: qh is null"); + + ET_CHECK_OR_RETURN_ERROR( + scale != nullptr, + InvalidArgument, + "aoti_torch_cuda_int6_plain_mm: scale is null"); + + ET_CHECK_OR_RETURN_ERROR( + ret0 != nullptr, + InvalidArgument, + "aoti_torch_cuda_int6_plain_mm: ret0 is null"); + + int32_t M = self->size(0); + int32_t N = ql->size(0); + Tensor* C = nullptr; + std::array c_shape = {M, N}; + std::array c_stride = {N, 1}; + aoti_torch_empty_strided( + 2, + c_shape.data(), + c_stride.data(), + static_cast( + executorch::backends::aoti::slim::c10::ScalarType::BFloat16), + static_cast( + executorch::backends::aoti::slim::c10::DeviceType::CUDA), + 0, + &C); + + _int6_plain_mm_cuda(*self, *ql, *qh, *scale, group_size, C); + ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR(); + + *ret0 = C; + return Error::Ok; +} + +#ifdef __cplusplus +} +#endif +} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/int6_plain_mm.cuh b/backends/cuda/runtime/shims/int6_plain_mm.cuh new file mode 100644 index 00000000000..a1c7206e6a7 --- /dev/null +++ b/backends/cuda/runtime/shims/int6_plain_mm.cuh @@ -0,0 +1,353 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +// W6A8 dp4a matvec for packed INT6 decode (M <= 4), used for GGUF Q6_K weights. +// +// Reads a genuine 6-bit packed weight (CudaPackedInt6Tensor format), split into +// two planes: +// ql : [N, K/2] uint8 — low-nibble plane, nibble-packed even/odd exactly +// like the INT4 path (ql[:,j] = lo[:,2j] | (lo[:,2j+1] << 4)). +// qh : [N, K/4] uint8 — high-2-bit plane, 4 values/byte, arranged per +// 32-weight chunk as hi_even_packed[4] then hi_odd_packed[4] (each +// byte holds the four 2-bit highs of one dp4a word in even/odd order). +// scale : [N, K/gs] bf16 — per-group scales, row-major (coalesced; no zero). +// The stored 6-bit value is u = q + 32 in [0, 63] (q in [-32, 31]); the constant +// -32 offset is applied in the kernel, so Q6_K's symmetry means NO zero tensor. +// +// Dynamically quantizes bf16 activations to INT8 (per-32-element blocks, even/odd +// order, identical to the INT4 path), reconstructs full 6-bit weight bytes per +// dp4a word (vfull = vi_lo | (spread2(hi_byte) << 4)), and uses dp4a for fused +// int6xint8 dot products with vectorized weight loads and warp-cooperative +// quantization. +// +// Symbol names are suffixed _i6 / distinct from int4_plain_mm.cuh and +// int8_plain_mm.cuh so all three translation units can be linked together +// without ODR conflicts. 
+ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace executorch::backends::cuda { + +using executorch::backends::aoti::Tensor; +namespace c10 = executorch::backends::aoti::slim::c10; + +// --------------------------------------------------------------------------- +// Constants +// --------------------------------------------------------------------------- + +constexpr int32_t MV6_NWARPS = 8; +constexpr int32_t MV6_WARP_SIZE = 32; +constexpr int32_t MV6_THREADS = MV6_NWARPS * MV6_WARP_SIZE; +constexpr int32_t Q8_BLOCK_SIZE_I6 = 32; + +__host__ __forceinline__ int32_t log2_pow2_i6(int32_t v) { + int32_t r = 0; + while (v > 1) { + v >>= 1; + r++; + } + return r; +} + +// Expand a byte's four 2-bit fields into four byte lanes (each in bits 0-1): +// in : b = [.. b7 b6 | b5 b4 | b3 b2 | b1 b0] +// out : lane0=[b1 b0], lane1=[b3 b2], lane2=[b5 b4], lane3=[b7 b6] +// ~6 ALU ops; verified by truth-table. Used to place the high 2 bits of each +// weight into bits 4-5 of the corresponding dp4a byte lane. +__device__ __forceinline__ uint32_t spread2_i6(uint32_t b) { + uint32_t t = (b | (b << 12)) & 0x000F000F; + uint32_t r = (t | (t << 6)) & 0x03030303; + return r; +} + +// --------------------------------------------------------------------------- +// Activation quantization: bf16 -> int8 (warp-cooperative, per-32-element +// blocks, EVEN/ODD order — identical to the INT4 path's Q8Block). +// --------------------------------------------------------------------------- + +// alignas(16) pads sizeof(Q8Block_i6) to 48 so each block (and its qs_even/qs_odd +// 16-byte halves) is 16-byte aligned, allowing two vectorized uint4 loads of a +// block's int8 activations instead of eight scalar int32 loads. 
+struct alignas(16) Q8Block_i6 { + int8_t qs_even[Q8_BLOCK_SIZE_I6 / 2]; + int8_t qs_odd[Q8_BLOCK_SIZE_I6 / 2]; + float d; // scale +}; + +__global__ void quantize_activations_q8_i6_kernel( + const __nv_bfloat16* __restrict__ A, + Q8Block_i6* __restrict__ q8, + int32_t K) { + const int32_t m = blockIdx.y; + const int32_t block_id = blockIdx.x * blockDim.y + threadIdx.y; + const int32_t n_blocks = K / Q8_BLOCK_SIZE_I6; + if (block_id >= n_blocks) + return; + + const int32_t lane = threadIdx.x; + const __nv_bfloat16* src = + A + static_cast(m) * K + block_id * Q8_BLOCK_SIZE_I6; + Q8Block_i6* dst = q8 + static_cast(m) * n_blocks + block_id; + + float val = __bfloat162float(src[lane]); + + float amax = fabsf(val); + for (int offset = 16; offset > 0; offset >>= 1) + amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, offset)); + + float d = amax / 127.0f; + float id = (d > 0.0f) ? 1.0f / d : 0.0f; + int32_t q = __float2int_rn(val * id); + q = max(-128, min(127, q)); + + if (lane % 2 == 0) + dst->qs_even[lane / 2] = static_cast(q); + else + dst->qs_odd[lane / 2] = static_cast(q); + + if (lane == 0) + dst->d = d; +} + +// --------------------------------------------------------------------------- +// W6A8 dp4a matvec kernel +// +// dp4a is linear, so reconstructing v = lo + (hi<<4) and dotting once is +// equivalent to two separate dp4a passes. We reconstruct the full 6-bit byte +// (vfull = vi_lo | (spread2(hi_byte) << 4)) so a single dp4a per even/odd half +// covers the whole weight. The per-group zero is the constant 32 (in u-space), +// applied as out += scale * a_scale * (dp - 32 * a_sum) — no zero load. 
+// --------------------------------------------------------------------------- + +__global__ void __launch_bounds__(MV6_THREADS) int6_w6a8_matvec_kernel( + const uint8_t* __restrict__ ql, // [N, K/2] + const uint8_t* __restrict__ qh, // [N, K/4] + const __nv_bfloat16* __restrict__ w_scale, // [N, n_groups] + const Q8Block_i6* __restrict__ q8, + __nv_bfloat16* __restrict__ out, + int32_t N, + int32_t K, + int32_t gs_shift, + int32_t n_groups) { + const int32_t n = blockIdx.x * MV6_NWARPS + threadIdx.y; + const int32_t m = blockIdx.y; + if (n >= N) + return; + + const int32_t K_half = K / 2; + const int32_t K_quarter = K / 4; + const int32_t lane_id = threadIdx.x; + const int32_t n_q8_blocks = K / Q8_BLOCK_SIZE_I6; + + const uint8_t* qlrow = ql + static_cast(n) * K_half; + const uint8_t* qhrow = qh + static_cast(n) * K_quarter; + const __nv_bfloat16* scale_row = w_scale + static_cast(n) * n_groups; + const Q8Block_i6* q8_row = q8 + static_cast(m) * n_q8_blocks; + + // Vectorized loads: one uint4 of ql (32 weights) + one uint2 of qh (the + // 8 high-bit bytes for the same 32-weight chunk) per iteration. + const uint4* qlrow16 = reinterpret_cast(qlrow); + const uint2* qhrow8 = reinterpret_cast(qhrow); + const int32_t K_half_16 = K_half / 16; + + float sum = 0.0f; + + int32_t prev_g = -1; + float ws = 0.0f; + + for (int32_t i = lane_id; i < K_half_16; i += MV6_WARP_SIZE) { + uint4 packed16 = __ldg(&qlrow16[i]); + uint2 qh_chunk = __ldg(&qhrow8[i]); + int32_t k_base = i * 32; + uint32_t words[4] = {packed16.x, packed16.y, packed16.z, packed16.w}; + // qh_chunk.x bytes = hi_even_packed[0..3], qh_chunk.y = hi_odd_packed[0..3]. + uint32_t hi_even_word = qh_chunk.x; + uint32_t hi_odd_word = qh_chunk.y; + + // One uint4 (32 weights) maps to exactly one Q8 activation block (32 + // activations), i.e. q8_block_idx == i. Load the whole block with two + // vectorized uint4 loads (+ one scale load). 
+ const Q8Block_i6* qb = &q8_row[i]; + uint4 ae = *reinterpret_cast(qb->qs_even); + uint4 ao = *reinterpret_cast(qb->qs_odd); + float a_scale = qb->d; + const uint32_t a_even[4] = {ae.x, ae.y, ae.z, ae.w}; + const uint32_t a_odd[4] = {ao.x, ao.y, ao.z, ao.w}; + +#pragma unroll + for (int32_t w = 0; w < 4; w++) { + uint32_t packed = words[w]; + int32_t k_word = k_base + w * 8; + int32_t g = k_word >> gs_shift; + + if (g != prev_g) { + ws = __bfloat162float(__ldg(&scale_row[g])); + prev_g = g; + } + + int32_t vi_lo = static_cast(packed & 0x0F0F0F0F); + int32_t vi_hi = static_cast((packed >> 4) & 0x0F0F0F0F); + + uint32_t hi_even_byte = (hi_even_word >> (w * 8)) & 0xFF; + uint32_t hi_odd_byte = (hi_odd_word >> (w * 8)) & 0xFF; + + // Reconstruct full 6-bit weight bytes (u in [0, 63]). + int32_t vfull_even = + vi_lo | static_cast(spread2_i6(hi_even_byte) << 4); + int32_t vfull_odd = + vi_hi | static_cast(spread2_i6(hi_odd_byte) << 4); + + int32_t dp = __dp4a(vfull_even, static_cast(a_even[w]), 0); + dp = __dp4a(vfull_odd, static_cast(a_odd[w]), dp); + + int32_t a_sum = __dp4a(0x01010101, static_cast(a_even[w]), 0); + a_sum = __dp4a(0x01010101, static_cast(a_odd[w]), a_sum); + + // q = u - 32, so the -32 offset replaces the per-group zero point. + sum += ws * a_scale * + (static_cast(dp) - 32.0f * static_cast(a_sum)); + } + } + + for (int offset = MV6_WARP_SIZE / 2; offset > 0; offset >>= 1) + sum += __shfl_xor_sync(0xffffffff, sum, offset); + + if (lane_id == 0) + out[static_cast(m) * N + n] = __float2bfloat16(sum); +} + +// --------------------------------------------------------------------------- +// Persistent Q8 buffer (lazy init, not thread-safe — single-stream only). +// Freed at process exit via a static guard so leak detectors stay quiet; the +// CUDA runtime would otherwise reclaim it on teardown anyway. 
+// --------------------------------------------------------------------------- + +static Q8Block_i6* g_q8_buf_i6 = nullptr; +static size_t g_q8_buf_i6_size = 0; + +namespace { +struct Q8BufferGuardI6 { + ~Q8BufferGuardI6() { + if (g_q8_buf_i6) { + // Ignore errors: during process teardown the CUDA context may already be + // gone (cudaErrorCudartUnloading), which is harmless here. + cudaFree(g_q8_buf_i6); + g_q8_buf_i6 = nullptr; + g_q8_buf_i6_size = 0; + } + } +}; +Q8BufferGuardI6 g_q8_buf_i6_guard; +} // namespace + +static Q8Block_i6* get_q8_buffer_i6(size_t needed) { + if (g_q8_buf_i6_size < needed) { + if (g_q8_buf_i6) + cudaFree(g_q8_buf_i6); + cudaError_t err = cudaMalloc(&g_q8_buf_i6, needed); + ET_CHECK_MSG( + err == cudaSuccess, + "cudaMalloc failed for Q8 buffer (int6): %s", + cudaGetErrorString(err)); + g_q8_buf_i6_size = needed; + } + return g_q8_buf_i6; +} + +// --------------------------------------------------------------------------- +// Main entry point +// --------------------------------------------------------------------------- + +inline void _int6_plain_mm_cuda( + const Tensor& A, // [M, K] bf16 + const Tensor& ql, // [N, K/2] uint8 + const Tensor& qh, // [N, K/4] uint8 + const Tensor& scale, // [N, K/gs] bf16 + int64_t group_size, + Tensor* output) { // [M, N] bf16, pre-allocated + int32_t M = A.size(0); + int32_t K = A.size(1); + int32_t N = ql.size(0); + + ET_CHECK(A.dtype() == c10::ScalarType::BFloat16); + ET_CHECK( + ql.dtype() == c10::ScalarType::Byte || + ql.dtype() == c10::ScalarType::Char); + ET_CHECK( + qh.dtype() == c10::ScalarType::Byte || + qh.dtype() == c10::ScalarType::Char); + ET_CHECK(scale.dtype() == c10::ScalarType::BFloat16); + ET_CHECK(A.dim() == 2); + ET_CHECK(ql.dim() == 2); + ET_CHECK(ql.size(1) == K / 2); + ET_CHECK(qh.dim() == 2); + ET_CHECK(qh.size(1) == K / 4); + ET_CHECK(scale.dim() == 2); + ET_CHECK(scale.size(0) == N); + + int32_t gs = static_cast(group_size); + ET_CHECK_MSG( + gs > 0 && (gs & (gs - 1)) == 0, 
"group_size=%d must be a power of 2", gs); + // group_size must be a multiple of 8 (the dp4a word stride) so a word never + // straddles a group boundary; gs=16 covers GGUF Q6_K. + ET_CHECK_MSG( + gs % 8 == 0, + "group_size=%d must be a multiple of 8 (e.g. 16 for GGUF Q6_K)", + gs); + ET_CHECK_MSG( + K >= Q8_BLOCK_SIZE_I6 && K % Q8_BLOCK_SIZE_I6 == 0, + "K=%d must be a positive multiple of %d for dp4a int6 kernel", + K, + Q8_BLOCK_SIZE_I6); + + auto stream_result = getCurrentCUDAStream(0); + ET_CHECK_MSG(stream_result.ok(), "Failed to get CUDA stream"); + cudaStream_t stream = stream_result.get(); + + int32_t gs_shift = log2_pow2_i6(gs); + + // Quantize activations to INT8 (even/odd order) + int32_t n_q8_blocks = K / Q8_BLOCK_SIZE_I6; + size_t q8_bytes = static_cast(M) * n_q8_blocks * sizeof(Q8Block_i6); + Q8Block_i6* q8_buf = get_q8_buffer_i6(q8_bytes); + + constexpr int32_t Q8_WARPS = 8; + int32_t blocks_per_m = (n_q8_blocks + Q8_WARPS - 1) / Q8_WARPS; + dim3 q8_grid(blocks_per_m, M); + dim3 q8_block(MV6_WARP_SIZE, Q8_WARPS); + quantize_activations_q8_i6_kernel<<>>( + reinterpret_cast(A.data_ptr()), q8_buf, K); + + // dp4a matvec + dim3 grid((N + MV6_NWARPS - 1) / MV6_NWARPS, M); + dim3 block(MV6_WARP_SIZE, MV6_NWARPS); + + int32_t n_groups = static_cast(scale.size(1)); + int6_w6a8_matvec_kernel<<>>( + reinterpret_cast(ql.data_ptr()), + reinterpret_cast(qh.data_ptr()), + reinterpret_cast(scale.data_ptr()), + q8_buf, + reinterpret_cast<__nv_bfloat16*>(output->data_ptr()), + N, + K, + gs_shift, + n_groups); +} + +} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/int6_plain_mm.h b/backends/cuda/runtime/shims/int6_plain_mm.h new file mode 100644 index 00000000000..e093fb9f055 --- /dev/null +++ b/backends/cuda/runtime/shims/int6_plain_mm.h @@ -0,0 +1,61 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +namespace executorch::backends::cuda { + +using executorch::backends::aoti::AOTITorchError; +using executorch::backends::aoti::Tensor; + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * Packed INT6 matrix multiplication for GGUF Q6_K weights (symmetric). + * + * The 6-bit weight is split into two planes plus a per-group scale; there is + * NO zero tensor — Q6_K is symmetric and the stored value is u = q + 32 in + * [0, 63] (q in [-32, 31]), with the constant -32 offset applied in the kernel. + * + * Weight format: + * ql : [N, K/2] uint8 — low-nibble plane, nibble-packed even/odd + * (ql[:,j] = (u[:,2j] & 0xF) | ((u[:,2j+1] & 0xF) << 4)). + * qh : [N, K/4] uint8 — high-2-bit plane, 4 values/byte, arranged per + * 32-weight chunk as hi_even_packed[4] then hi_odd_packed[4]; each + * byte holds the four 2-bit highs of one dp4a word, bit field j + * (bits 2j..2j+1) = high 2 bits of that word's j-th even/odd weight. + * scale : [N, K//group_size] bf16 per-group scales (row-major). + * W6A8 dp4a matvec: dynamically quantizes activations to INT8, reconstructs + * full 6-bit weight bytes, then uses dp4a for fused int6xint8 dot products. + * + * @param self Input activation [M, K] bf16 + * @param ql Low-nibble plane [N, K/2] uint8 + * @param qh High-2-bit plane [N, K/4] uint8 + * @param scale Per-group scales [N, K//group_size] bf16 + * @param group_size Quantization group size (multiple of 8; e.g. 
16 for Q6_K) + * @param ret0 Output [M, N] bf16 + */ +AOTI_SHIM_EXPORT AOTITorchError aoti_torch_cuda_int6_plain_mm( + Tensor* self, + Tensor* ql, + Tensor* qh, + Tensor* scale, + int64_t group_size, + Tensor** ret0); + +#ifdef __cplusplus +} +#endif + +} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/tests/CMakeLists.txt b/backends/cuda/runtime/shims/tests/CMakeLists.txt index 62e9180d603..072e4effad4 100644 --- a/backends/cuda/runtime/shims/tests/CMakeLists.txt +++ b/backends/cuda/runtime/shims/tests/CMakeLists.txt @@ -49,8 +49,9 @@ set(CUDA_SHIM_TESTS ) # CUDA-specific tests requiring GPU kernels -set(CUDA_KERNEL_TESTS test_aoti_torch_cuda__weight_int4pack_mm - test_aoti_torch_cuda_int4_plain_mm +set(CUDA_KERNEL_TESTS + test_aoti_torch_cuda__weight_int4pack_mm test_aoti_torch_cuda_int4_plain_mm + test_aoti_torch_cuda_int6_plain_mm ) enable_testing() diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_int6_plain_mm.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_int6_plain_mm.cpp new file mode 100644 index 00000000000..43d3946294a --- /dev/null +++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_int6_plain_mm.cpp @@ -0,0 +1,306 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using executorch::backends::cuda::aoti_torch_cuda_int6_plain_mm; +using executorch::backends::cuda::aoti_torch_empty_strided; +using executorch::backends::cuda::AOTITorchError; +using executorch::runtime::Error; +namespace slim_c10 = executorch::backends::aoti::slim::c10; + +using Tensor = executorch::backends::aoti::slim::SlimTensor; + +// W6A8 dp4a matvec shim for packed-INT6 decode (CudaPackedInt6Tensor layout, +// GGUF Q6_K). 
The 6-bit weight is split into two planes plus a per-group scale; +// there is NO zero tensor (Q6_K is symmetric, the -32 offset is applied in the +// kernel): +// ql : [N, K/2] uint8 — low-nibble plane, nibble-packed even/odd +// qh : [N, K/4] uint8 — high-2-bit plane, 4 values/byte (per 32-weight +// chunk: hi_even_packed[4] then hi_odd_packed[4]) +// scale : [N, K//gs] bf16 — per-group scales (row-major) +// +// Expected outputs are generated from the export-path reference +// _dequant_matmul_int6 (backends/cuda/quantize_op_dispatch/int6_dispatch.py): +// w[n, k] = q[n, k] * scale[n, k//gs]; out = A @ w^T (q symmetric, in +// [-32,31]). +// The kernel runs W6A8 (it also quantizes activations to int8), so a 0.5 atol +// absorbs the activation-quant noise (same tolerance as the int4/int8 tests). +class AOTITorchInt6PlainMMTest : public ::testing::Test { + protected: + void SetUp() override { + et_pal_init(); + + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "CUDA not available"; + } + } + + Tensor* create_tensor( + const std::vector& sizes, + slim_c10::ScalarType dtype) { + Tensor* tensor; + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + nullptr, + static_cast(dtype), + static_cast(slim_c10::DeviceType::CUDA), + 0, + &tensor); + return (error == Error::Ok) ? tensor : nullptr; + } + + Tensor* create_bf16(const std::vector& sizes) { + return create_tensor(sizes, slim_c10::ScalarType::BFloat16); + } + + // ql / qh are uint8 (ScalarType::Byte) packed planes. + Tensor* create_uint8(const std::vector& sizes) { + return create_tensor(sizes, slim_c10::ScalarType::Byte); + } + + // Upload raw bytes to a CUDA tensor. + void upload(Tensor* t, const void* host_data, size_t bytes) { + cudaMemcpy(t->data_ptr(), host_data, bytes, cudaMemcpyHostToDevice); + } + + // Download CUDA tensor to host buffer. 
+ void download(const Tensor* t, void* host_data, size_t bytes) { + cudaMemcpy(host_data, t->data_ptr(), bytes, cudaMemcpyDeviceToHost); + } + + // Run the shim and return the output tensor (asserts success). + Tensor* + run(Tensor* A, Tensor* ql, Tensor* qh, Tensor* scale, int64_t group_size) { + Tensor* output = nullptr; + AOTITorchError error = + aoti_torch_cuda_int6_plain_mm(A, ql, qh, scale, group_size, &output); + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(output, nullptr); + return output; + } + + // Check output bf16 values against expected, with absolute tolerance. + void check_bf16_output( + Tensor* output, + const uint16_t* expected_data, + int64_t count, + float atol = 0.5f) { + std::vector actual(count); + download(output, actual.data(), count * sizeof(uint16_t)); + cudaDeviceSynchronize(); + + for (int64_t i = 0; i < count; i++) { + // Convert bf16 raw bits to float: bf16 is the upper 16 bits of float32. + uint32_t actual_bits = static_cast(actual[i]) << 16; + uint32_t expected_bits = static_cast(expected_data[i]) << 16; + float actual_f, expected_f; + memcpy(&actual_f, &actual_bits, sizeof(float)); + memcpy(&expected_f, &expected_bits, sizeof(float)); + + EXPECT_NEAR(actual_f, expected_f, atol) + << "Mismatch at index " << i << ": actual=" << actual_f + << " expected=" << expected_f; + } + } + + // Upload data and run the shim. ql/qh are uint8; scale/A are bf16. 
+ Tensor* setup_and_run( + int64_t M, + int64_t N, + int64_t K, + int64_t gs, + const uint8_t* ql_host, + const uint8_t* qh_host, + const uint16_t* scale_host, + const uint16_t* A_host) { + int64_t ng = K / gs; + Tensor* A = create_bf16({M, K}); + Tensor* ql = create_uint8({N, K / 2}); + Tensor* qh = create_uint8({N, K / 4}); + Tensor* scale = create_bf16({N, ng}); + EXPECT_NE(A, nullptr); + EXPECT_NE(ql, nullptr); + EXPECT_NE(qh, nullptr); + EXPECT_NE(scale, nullptr); + + upload(A, A_host, static_cast(M) * K * sizeof(uint16_t)); + upload(ql, ql_host, static_cast(N) * (K / 2) * sizeof(uint8_t)); + upload(qh, qh_host, static_cast(N) * (K / 4) * sizeof(uint8_t)); + upload(scale, scale_host, static_cast(N) * ng * sizeof(uint16_t)); + + return run(A, ql, qh, scale, gs); + } +}; + +// Q6KGroupSize16: M=2, N=4, K=64, gs=16, symmetric (no zero), q in [-32,31]. +// The canonical GGUF Q6_K shape (group_size=16). +TEST_F(AOTITorchInt6PlainMMTest, Q6KGroupSize16) { + int64_t M = 2, N = 4, K = 64, gs = 16; + + // clang-format off + uint8_t ql_host[] = { + 249, 176, 113, 205, 113, 130, 205, 208, 208, 220, 36, 28, 90, 117, 20, 139, + 24, 99, 43, 2, 253, 112, 107, 185, 154, 203, 229, 119, 15, 8, 139, 95, + 117, 50, 27, 48, 120, 65, 40, 224, 147, 165, 182, 177, 210, 160, 239, 192, + 136, 20, 241, 201, 43, 56, 64, 34, 219, 104, 39, 103, 79, 70, 196, 157, + 193, 90, 70, 26, 31, 78, 234, 55, 53, 19, 198, 24, 26, 71, 88, 181, + 205, 210, 95, 167, 16, 80, 183, 76, 106, 66, 44, 124, 17, 197, 49, 227, + 46, 51, 2, 185, 46, 243, 128, 59, 39, 121, 45, 252, 221, 98, 155, 170, + 27, 31, 108, 91, 235, 129, 177, 104, 44, 22, 110, 142, 169, 226, 255, 217 + }; + uint8_t qh_host[] = { + 21, 230, 10, 92, 55, 212, 46, 90, 227, 91, 52, 88, 49, 132, 203, 60, + 255, 132, 109, 173, 8, 49, 181, 163, 130, 224, 227, 13, 216, 86, 234, 219, + 180, 142, 137, 139, 87, 161, 244, 72, 109, 20, 107, 165, 31, 47, 99, 59, + 215, 173, 1, 159, 180, 83, 227, 190, 15, 222, 95, 108, 117, 157, 225, 105 + }; + uint16_t 
scale_host[] = { + 0x3C6E, 0xBCF6, 0x3CC3, 0xBB88, 0xBD0C, 0x3D5A, 0x3B40, 0x3D43, 0xBB71, 0xBD6A, 0x3D16, 0xBCC3, + 0xBC1E, 0x3D2A, 0xBCC3, 0xBD37 + }; + uint16_t A_host[] = { + 0x3F5C, 0xBF3E, 0x0000, 0xBC33, 0xBE9A, 0x3CAA, 0x3F7A, 0xBF94, 0xC016, 0xBFF6, 0x0000, 0x3E71, + 0xBFD3, 0x3F5E, 0xBF96, 0x3E2A, 0x4023, 0x3EC0, 0x3E90, 0xC00C, 0x3F84, 0xBEEA, 0xBE32, 0x3F71, + 0x0000, 0x3EC9, 0xBEE2, 0x3EE8, 0x3F30, 0xBECB, 0x3F1F, 0xBF2F, 0xBF2A, 0x3F01, 0x3F11, 0x3F88, + 0xBF6A, 0x3FD4, 0xBDD5, 0x3F8F, 0xBF5F, 0xBEBA, 0xBF24, 0xBF45, 0xBF3F, 0x3E51, 0xBE7D, 0xBF35, + 0x3E73, 0x3F1B, 0x3F34, 0x3EA2, 0xBF13, 0xBF4F, 0xBEE2, 0x4006, 0x3F37, 0x3EC5, 0x3F9F, 0xBD79, + 0x3F21, 0xBF0C, 0xBEA9, 0x3FF2, 0x3F55, 0x3FD6, 0x3FAB, 0x3F89, 0xBDA1, 0x3EDD, 0xBF8D, 0xBE4F, + 0xC005, 0xBFBD, 0xBF59, 0x3CD7, 0x3E07, 0xBEEA, 0x3EAC, 0x4038, 0x3F7E, 0xBE4B, 0xBE3A, 0xBF99, + 0xBFCC, 0x3EF0, 0xBF84, 0xBEE8, 0xBF6E, 0xBC97, 0xBF57, 0xBF3F, 0x3FD7, 0xBFB5, 0x3F0C, 0x3E3F, + 0x3F77, 0xBE45, 0x3FAA, 0x3FE1, 0x3D9C, 0x3F8F, 0xBF38, 0xBF1F, 0xBF07, 0xBE94, 0xBF58, 0xBF85, + 0x3FCE, 0x3F2A, 0x3EAC, 0xBF45, 0x3DC4, 0x3E9E, 0xBF9C, 0x3F0A, 0x3E8F, 0x3EA7, 0xBEFB, 0xBE65, + 0xBFB1, 0xBF58, 0xBF88, 0x3EC2, 0xC008, 0x3F7C, 0xBFFC, 0xBF66 + }; + uint16_t expected[] = { + 0x3F46, 0xC02B, 0x40C5, 0xBED9, 0xBECA, 0x4098, 0x3F96, 0x3F19 + }; + // clang-format on + + Tensor* output = + setup_and_run(M, N, K, gs, ql_host, qh_host, scale_host, A_host); + ASSERT_NE(output, nullptr); + EXPECT_EQ(output->size(0), M); + EXPECT_EQ(output->size(1), N); + check_bf16_output(output, expected, M * N, 0.5f); +} + +// LargeKGroupSize16: M=1, N=2, K=256, gs=16, symmetric — a larger-K decode case +// (16 groups) exercising the multi-iteration warp loop on the gs=16 path. 
+TEST_F(AOTITorchInt6PlainMMTest, LargeKGroupSize16) { + int64_t M = 1, N = 2, K = 256, gs = 16; + + // clang-format off + uint8_t ql_host[] = { + 69, 12, 100, 182, 132, 79, 45, 206, 141, 218, 39, 249, 136, 245, 75, 210, + 18, 150, 51, 178, 183, 119, 174, 151, 235, 77, 75, 247, 29, 241, 55, 154, + 12, 189, 29, 93, 92, 153, 20, 52, 67, 219, 12, 178, 99, 207, 12, 151, + 5, 133, 30, 141, 56, 234, 26, 101, 93, 150, 46, 101, 80, 30, 33, 153, + 240, 83, 103, 193, 72, 152, 248, 85, 69, 52, 240, 168, 4, 81, 134, 98, + 101, 106, 122, 199, 212, 244, 190, 139, 33, 62, 6, 147, 243, 106, 105, 196, + 120, 49, 123, 17, 38, 205, 200, 90, 10, 248, 177, 182, 9, 195, 90, 9, + 127, 194, 250, 109, 105, 141, 182, 53, 35, 162, 151, 192, 134, 134, 246, 198, + 202, 191, 86, 93, 221, 185, 60, 230, 242, 167, 247, 189, 35, 210, 188, 146, + 8, 218, 95, 120, 119, 39, 177, 110, 158, 144, 0, 36, 69, 219, 134, 94, + 29, 25, 81, 213, 207, 185, 206, 89, 113, 1, 50, 59, 238, 29, 69, 128, + 97, 97, 229, 181, 211, 253, 157, 118, 71, 232, 63, 21, 171, 62, 115, 78, + 3, 109, 188, 187, 172, 5, 144, 190, 60, 214, 171, 194, 232, 6, 192, 189, + 136, 45, 201, 26, 110, 239, 63, 229, 197, 85, 25, 121, 147, 63, 227, 20, + 30, 66, 228, 231, 197, 90, 65, 116, 255, 50, 51, 88, 142, 60, 112, 10, + 18, 192, 52, 144, 148, 19, 197, 32, 3, 157, 152, 52, 176, 31, 38, 242 + }; + uint8_t qh_host[] = { + 235, 21, 174, 144, 160, 216, 229, 90, 25, 104, 128, 211, 93, 165, 189, 219, + 87, 210, 115, 144, 79, 31, 166, 108, 199, 41, 50, 92, 21, 45, 124, 158, + 142, 126, 0, 139, 23, 77, 180, 181, 218, 246, 98, 252, 50, 141, 10, 82, + 82, 31, 128, 233, 230, 216, 156, 120, 193, 161, 94, 122, 62, 85, 233, 8, + 199, 237, 102, 124, 105, 252, 43, 58, 34, 218, 77, 242, 219, 85, 16, 221, + 102, 49, 77, 226, 23, 30, 142, 36, 110, 63, 97, 59, 164, 214, 221, 103, + 253, 67, 106, 140, 18, 75, 207, 144, 21, 18, 108, 84, 110, 217, 45, 114, + 180, 170, 6, 111, 131, 171, 200, 246, 55, 206, 40, 185, 16, 114, 54, 62 + }; + uint16_t scale_host[] = { 
+ 0x3CF1, 0x3C5B, 0x3B89, 0x3B53, 0x3865, 0xBD3E, 0x3D2F, 0x3AD1, 0x3CC6, 0x3D06, 0xBCFE, 0x3BDD, + 0x3D60, 0x3BD0, 0xBD1A, 0x3D1F, 0xBBBA, 0x3D58, 0x3CD5, 0xBCD3, 0x3BB7, 0x3CF3, 0x3D05, 0x3D0B, + 0x3D42, 0xBBF0, 0x3CC5, 0xBC17, 0xBD73, 0xBC09, 0xBC01, 0xBD24 + }; + uint16_t A_host[] = { + 0xBF33, 0xBF48, 0xBE27, 0x3F25, 0xBFF5, 0x3F5C, 0xBFCE, 0xBF36, 0x3DFA, 0x3EE3, 0xBF64, 0x3E14, + 0xBF41, 0x3E5C, 0x3ED3, 0xBF93, 0xBF45, 0x3BC7, 0xBEF0, 0x3D95, 0xBF20, 0x3E4D, 0xBEA8, 0xBF49, + 0x3F65, 0xBF75, 0xBEA2, 0x3F35, 0x3DE0, 0xBDB1, 0xBEA7, 0xBF5B, 0x3F7F, 0x3F47, 0x3FA4, 0x3FB6, + 0xBE20, 0xBFDE, 0xBD38, 0xBFC6, 0x3F22, 0xBF91, 0xBEA8, 0xBFEA, 0x3FA0, 0xBFAB, 0x3F78, 0xBFAC, + 0x3EA4, 0x3FB3, 0xBF88, 0xBF3B, 0xBEA4, 0x3EDF, 0x3F01, 0x3E7A, 0xBF5F, 0xBD3E, 0x3FA3, 0xBF68, + 0xBF32, 0x3EC0, 0xBF59, 0x3EE9, 0xBEB9, 0xBEC4, 0x3F1E, 0xBE8A, 0x3FBE, 0x3F19, 0x3FC2, 0xC00B, + 0xBEF4, 0xBF45, 0xBEC8, 0x3FC7, 0x3F09, 0x3F97, 0x3F43, 0xBF47, 0x3FCF, 0x3E26, 0x3E10, 0xBEA9, + 0x3EA2, 0x3FAE, 0x3F3F, 0x3E93, 0xBFB6, 0x3FCA, 0x3F70, 0x3FD6, 0x3E58, 0xBF17, 0x3FB2, 0xBE16, + 0x4006, 0x3FC1, 0x3F7D, 0x3F3E, 0xBE03, 0x3ED5, 0x3F0A, 0xBE95, 0xBE89, 0x3F8E, 0x3EF0, 0x3FBB, + 0x3F83, 0xBFCB, 0x3E18, 0x3FA8, 0x3F60, 0x3F1D, 0xBFB4, 0x3FB8, 0xBDB3, 0xBF77, 0xBEBC, 0x3E68, + 0x3EAC, 0x3F54, 0x3F72, 0xC01B, 0x3E4C, 0x3FA9, 0xBDCC, 0xBE59, 0xBF8D, 0xBE29, 0x3E80, 0x3FB9, + 0xBFD0, 0x3E11, 0xBF42, 0xBECE, 0xBE42, 0x4016, 0x3C98, 0x3E5B, 0x3F43, 0x3FB1, 0x3F30, 0xBE69, + 0x3F2C, 0x3F4A, 0x3F43, 0x3FAB, 0x3E4C, 0xBF9C, 0xBEF7, 0xBF87, 0x3DA9, 0x3F2E, 0xBEA8, 0xBF4A, + 0x3F80, 0xBF1E, 0xBE81, 0x3EA5, 0x3F0E, 0xBF50, 0x3EA4, 0x3FD3, 0xBE3C, 0x3F8D, 0xBF38, 0xBEB3, + 0x3E86, 0x3F79, 0xBF77, 0x3E26, 0x3F6E, 0x3DDF, 0xBCB2, 0x3F92, 0xBE11, 0xBF0E, 0xBFFE, 0xBF6A, + 0x3FA0, 0x0000, 0xBF84, 0x3FA7, 0x3F23, 0x3F8F, 0xBF90, 0xBF2F, 0x3F8A, 0x0000, 0xBDA4, 0x3F6A, + 0x3E9D, 0x3FAB, 0xBEDB, 0x3F06, 0x3EFB, 0xBF86, 0x3DAD, 0xBE1C, 0xBF85, 0x3F65, 0xBF5C, 0xBE89, + 0x3EC4, 0x3F85, 0x3EF7, 0x3C47, 0x3E98, 
0x3EFB, 0x3DC9, 0x3D1B, 0xBECD, 0x4007, 0x3ED0, 0xBF28, + 0x3F99, 0x3E9F, 0xBF7A, 0x3EBD, 0xBEEE, 0xBF1C, 0xBED0, 0xBF01, 0x3F76, 0xBE8A, 0xBF8C, 0x3EDD, + 0x3FE6, 0x3ECA, 0x3F45, 0xBF64, 0xBE8F, 0x3FC7, 0x3FD4, 0xBF2D, 0x3F0C, 0x3F58, 0x3F45, 0x3E8B, + 0x3A08, 0x3F9E, 0x4004, 0x3F9D, 0xBFDE, 0xBF69, 0xBF8E, 0xBF0B, 0x3F89, 0x3DFA, 0xBF91, 0xC019, + 0x3DAA, 0x3F09, 0x3F69, 0x3F3E + }; + uint16_t expected[] = { + 0xC196, 0x40F7 + }; + // clang-format on + + Tensor* output = + setup_and_run(M, N, K, gs, ql_host, qh_host, scale_host, A_host); + ASSERT_NE(output, nullptr); + EXPECT_EQ(output->size(0), M); + EXPECT_EQ(output->size(1), N); + check_bf16_output(output, expected, M * N, 0.5f); +} + +TEST_F(AOTITorchInt6PlainMMTest, NullInputHandling) { + int64_t M = 2, K = 128, N = 64, gs = 16; + + Tensor* A = create_bf16({M, K}); + Tensor* ql = create_uint8({N, K / 2}); + Tensor* qh = create_uint8({N, K / 4}); + Tensor* scale = create_bf16({N, K / gs}); + Tensor* output = nullptr; + + EXPECT_EQ( + aoti_torch_cuda_int6_plain_mm(nullptr, ql, qh, scale, gs, &output), + Error::InvalidArgument); + EXPECT_EQ( + aoti_torch_cuda_int6_plain_mm(A, nullptr, qh, scale, gs, &output), + Error::InvalidArgument); + EXPECT_EQ( + aoti_torch_cuda_int6_plain_mm(A, ql, nullptr, scale, gs, &output), + Error::InvalidArgument); + EXPECT_EQ( + aoti_torch_cuda_int6_plain_mm(A, ql, qh, nullptr, gs, &output), + Error::InvalidArgument); + EXPECT_EQ( + aoti_torch_cuda_int6_plain_mm(A, ql, qh, scale, gs, nullptr), + Error::InvalidArgument); +} diff --git a/backends/cuda/tests/test_int6_dispatch.py b/backends/cuda/tests/test_int6_dispatch.py new file mode 100644 index 00000000000..63602618b3a --- /dev/null +++ b/backends/cuda/tests/test_int6_dispatch.py @@ -0,0 +1,226 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +"""Tests for CudaPackedInt6Tensor F.linear dispatch via int6_dispatch. + +These tests validate the eager / trace-time dispatch path — the same code that +torch.export traces through when building the AOTI graph. They do NOT test the +.pte runtime C shim (W6A8 dp4a kernel); that is covered by +test_aoti_torch_cuda_int6_plain_mm.cpp (C++ unit tests). + +The API contract: after importing int6_dispatch, F.linear / nn.Linear with a +CudaPackedInt6Tensor weight produce numerically correct results, routed by +batch size (decode M<=4 -> custom op, prefill M>4 -> inline dequant). Routing +tests run without a GPU by recording calls to the decode custom op. + +Usage: + python -m pytest backends/cuda/tests/test_int6_dispatch.py -v +""" + +import contextlib +import unittest +from unittest import mock + +import executorch.backends.cuda.quantize_op_dispatch.int6_dispatch # noqa: F401 +import torch +import torch.nn as nn +import torch.nn.functional as F +from executorch.backends.cuda.packed_int6_tensor import CudaPackedInt6Tensor, pack_int6 +from executorch.backends.cuda.quantize_op_dispatch.int6_dispatch import ( + _dequant_matmul_int6, +) + + +def _require_cuda(tc: unittest.TestCase) -> None: + if not torch.cuda.is_available(): + tc.skipTest("CUDA required") + + +def _make_int6_tensor(N, K, group_size=16): + """Build a CudaPackedInt6Tensor (symmetric Q6_K) and return (tensor, q, scale). + + ``q`` (int8 in [-32, 31]) and ``scale`` are the originals, so tests can + measure against the exact dequant reference ``w = q * scale``. 
+ """ + q = torch.randint(-32, 32, (N, K), dtype=torch.int8) + scale = (torch.rand(N, K // group_size) * 0.1 + 0.01).to(torch.bfloat16) + ql, qh = pack_int6(q) + t = CudaPackedInt6Tensor(ql, qh, scale, [1, group_size], torch.Size([N, K])) + return t, q, scale + + +def _ref_weight(q, scale, group_size, dtype=torch.bfloat16): + """Exact dequant reference: w[n, k] = q[n, k] * scale[n, k//gs].""" + N, K = q.shape + ng = K // group_size + w = q.to(dtype).reshape(N, ng, group_size) * scale.to(dtype).reshape(N, ng, 1) + return w.reshape(N, K) + + +@contextlib.contextmanager +def _record_int6_plain_mm(): + """Record calls to the decode custom op without needing a GPU. + + Replaces ``torch.ops.executorch_cuda.int6_plain_mm`` (whose real impl is the + CUDA C shim) with a recorder that computes the result via the eager CPU + dequant, so the dispatch handler still returns a valid tensor. + """ + calls = [] + + def _fake(self, ql, qh, scale, group_size): + calls.append((tuple(self.shape), group_size)) + return _dequant_matmul_int6(self, ql, qh, scale, group_size) + + with mock.patch.object(torch.ops.executorch_cuda, "int6_plain_mm", _fake): + yield calls + + +class TestDispatchRouting(unittest.TestCase): + """Type-based routing: M<=4 -> int6_plain_mm op, M>4 -> inline dequant. + + Runs without a GPU by recording calls to the decode custom op and computing + the result with the eager CPU dequant. 
+ """ + + def setUp(self): + torch.manual_seed(0) + + def _rel_err(self, out, ref): + return ( + (out.float() - ref.float()).abs().mean() / ref.float().abs().mean() + ).item() + + def test_decode_routes_to_int6_plain_mm(self): + """M<=4 routes to the decode custom op.""" + t, _, _ = _make_int6_tensor(16, 64) + x = torch.randn(1, 64, dtype=torch.bfloat16) # M=1 (decode regime) + with _record_int6_plain_mm() as calls: + out = F.linear(x, t) + self.assertEqual(len(calls), 1) + self.assertEqual(out.shape, (1, 16)) + + def test_prefill_uses_dequant(self): + """M>4 uses inline dequant (no custom op) and is numerically correct.""" + t, q, scale = _make_int6_tensor(16, 64) + x = torch.randn(8, 64, dtype=torch.bfloat16) # M=8 > 4 (prefill regime) + with _record_int6_plain_mm() as calls: + out = F.linear(x, t) + self.assertEqual(calls, []) + ref = F.linear(x, _ref_weight(q, scale, 16)) + self.assertLess(self._rel_err(out, ref), 0.02) + + def test_decode_result_matches_reference(self): + """The decode op (eager -> dequant) is numerically correct.""" + t, q, scale = _make_int6_tensor(24, 128) + x = torch.randn(2, 128, dtype=torch.bfloat16) + with _record_int6_plain_mm() as calls: + out = F.linear(x, t) + self.assertEqual(len(calls), 1) + ref = F.linear(x, _ref_weight(q, scale, 16)) + self.assertLess(self._rel_err(out, ref), 0.02) + + def test_with_bias(self): + """Bias is added after the matmul on the decode path.""" + t, q, scale = _make_int6_tensor(16, 64) + bias = torch.randn(16, dtype=torch.bfloat16) + x = torch.randn(1, 64, dtype=torch.bfloat16) + with _record_int6_plain_mm(): + out = F.linear(x, t, bias) + ref = F.linear(x, _ref_weight(q, scale, 16), bias) + self.assertLess(self._rel_err(out, ref), 0.02) + + def test_3d_batched_input(self): + """3D input is flattened and the output shape is restored.""" + t, q, scale = _make_int6_tensor(16, 64) + x = torch.randn(2, 8, 64, dtype=torch.bfloat16) # flattened M=16 > 4 + with _record_int6_plain_mm() as calls: + out = 
F.linear(x, t) + self.assertEqual(calls, []) # prefill regime + self.assertEqual(out.shape, (2, 8, 16)) + ref = F.linear(x, _ref_weight(q, scale, 16)) + self.assertLess(self._rel_err(out, ref), 0.02) + + def test_from_intx_int8_roundtrip(self): + """from_intx_int8 packs a symmetric int8 tensor and dispatch is correct.""" + from torchao.quantization import IntxUnpackedToInt8Tensor + + N, K, gs = 16, 64, 16 + q = torch.randint(-32, 32, (N, K), dtype=torch.int8) + scale = (torch.rand(N, K // gs) * 0.1 + 0.01).to(torch.bfloat16) + intx = IntxUnpackedToInt8Tensor( + qdata=q, + scale=scale, + zero_point=torch.zeros_like(scale, dtype=torch.int8), + target_dtype=torch.int8, + block_size=(1, gs), + dtype=torch.bfloat16, + activation_quantization=None, + ) + t = CudaPackedInt6Tensor.from_intx_int8(intx) + x = torch.randn(1, K, dtype=torch.bfloat16) + with _record_int6_plain_mm() as calls: + out = F.linear(x, t) + self.assertEqual(len(calls), 1) + ref = F.linear(x, _ref_weight(q, scale, gs)) + self.assertLess(self._rel_err(out, ref), 0.02) + + def test_from_intx_int8_rejects_asymmetric(self): + """A non-zero zero_point (not Q6_K) is rejected.""" + from torchao.quantization import IntxUnpackedToInt8Tensor + + N, K, gs = 8, 64, 16 + q = torch.randint(-32, 32, (N, K), dtype=torch.int8) + scale = (torch.rand(N, K // gs) * 0.1 + 0.01).to(torch.bfloat16) + intx = IntxUnpackedToInt8Tensor( + qdata=q, + scale=scale, + zero_point=torch.ones_like(scale, dtype=torch.int8), + target_dtype=torch.int8, + block_size=(1, gs), + dtype=torch.bfloat16, + activation_quantization=None, + ) + with self.assertRaises(ValueError): + CudaPackedInt6Tensor.from_intx_int8(intx) + + +class TestFLinearDispatchCuda(unittest.TestCase): + """F.linear with a CudaPackedInt6Tensor weight on CUDA (eager -> dequant).""" + + def setUp(self): + _require_cuda(self) + torch.manual_seed(42) + + def _check(self, out, ref, tol=0.02): + rel_err = (out.float() - ref.float()).abs().mean() / ref.float().abs().mean() + 
self.assertLess(rel_err.item(), tol) + + def _linear(self, N, K, gs=16): + t, q, scale = _make_int6_tensor(N, K, gs) + module = nn.Linear(K, N, bias=False, dtype=torch.bfloat16) + module.weight = nn.Parameter(t, requires_grad=False) + module.cuda() + return module, _ref_weight(q, scale, gs).cuda() + + def test_decode_m1(self): + module, w_ref = self._linear(256, 512) + x = torch.randn(1, 512, dtype=torch.bfloat16, device="cuda") + self._check(module(x), F.linear(x, w_ref)) + + def test_prefill_m64(self): + module, w_ref = self._linear(256, 512) + x = torch.randn(64, 512, dtype=torch.bfloat16, device="cuda") + self._check(module(x), F.linear(x, w_ref)) + + def test_dequantize_matches_reference(self): + t, q, scale = _make_int6_tensor(32, 128) + ref = _ref_weight(q, scale, 16) + self.assertTrue(torch.equal(t.dequantize().cpu(), ref)) + + +if __name__ == "__main__": + unittest.main() diff --git a/examples/models/gemma4_31b/gguf_loader.py b/examples/models/gemma4_31b/gguf_loader.py index 5d7c5ec540d..1cd9c0db8b0 100644 --- a/examples/models/gemma4_31b/gguf_loader.py +++ b/examples/models/gemma4_31b/gguf_loader.py @@ -16,9 +16,11 @@ by the MLX GGUF pattern (Q6_K custom kernels, Q4_K native affine ops) for both linear and embedding. ``embed_tokens`` and ``lm_head`` stay tied -- they share the one quantized tensor. -* **CUDA**: Q4_K -> ``Int4Tensor``, Q6_K -> ``IntxUnpackedToInt8Tensor``; +* **CUDA**: Q4_K -> ``Int4Tensor``, Q6_K -> ``IntxUnpackedToInt8Tensor`` (native + torchao tensors; the backend packer in ``quant/pack_cuda.py`` repacks them into + ``CudaCoalescedInt4Tensor`` / the genuine 6-bit ``CudaPackedInt6Tensor``). ``lm_head`` keeps the quantized tensor but the token embedding is dequantized to - bf16 (``Int4Tensor`` can't gather), so they are untied. + bf16 (the packed tensors can't gather), so they are untied. 
Usage: model, config = load_gguf_model("model.gguf", backend="cuda") @@ -91,7 +93,11 @@ def _convert_weight(model, model_key: str, gtensor, backend: str): """Convert an ``ExportableGGUFTensor`` to the per-backend module weight.""" if backend == "mlx": return gtensor - # CUDA: native torchao quantized tensors. + # CUDA: native torchao quantized tensors. Q4_K -> Int4Tensor; Q6_K (and any + # other quant type) -> IntxUnpackedToInt8Tensor. The backend packer in + # quant/pack_cuda.py repacks these into the ExecuTorch-internal CUDA layouts + # (CudaCoalescedInt4Tensor / CudaPackedInt6Tensor), so the loader itself stays + # backend-agnostic and carries no backends/cuda dependency. if gtensor.ggml_type == "q4_k": return gtensor.to_int4_tensor() return gtensor.to_intx_unpacked_to_int8_tensor() diff --git a/examples/models/gemma4_31b/quant/pack_cuda.py b/examples/models/gemma4_31b/quant/pack_cuda.py index 655d773e7b3..e22e99789b6 100644 --- a/examples/models/gemma4_31b/quant/pack_cuda.py +++ b/examples/models/gemma4_31b/quant/pack_cuda.py @@ -6,11 +6,17 @@ """CUDA packer: assign quantized weights to model modules. -Converts ``Int4Tensor`` weights to the ExecuTorch-internal -``CudaCoalescedInt4Tensor`` (which owns the scale/zero transpose to the -coalesced [N, n_groups] layout) and passes ``IntxUnpackedToInt8Tensor`` through -as ``nn.Parameter`` without conversion. The quantize_op_dispatch package -(``int4_dispatch`` / ``int8_dispatch``) handles F.linear at runtime. +Repacks native torchao quantized tensors into the ExecuTorch-internal CUDA +layouts read by the decode kernels: + + * ``Int4Tensor`` -> ``CudaCoalescedInt4Tensor`` (bakes the scale/zero transpose + into the coalesced [N, n_groups] layout). + * symmetric Q6_K ``IntxUnpackedToInt8Tensor`` -> ``CudaPackedInt6Tensor`` (the + genuine 6-bit ql/qh planes). + +A genuine INT8 ``IntxUnpackedToInt8Tensor`` is left unchanged for the int8 path. 
+The quantize_op_dispatch package (``int4_dispatch`` / ``int6_dispatch`` / +``int8_dispatch``) handles F.linear at runtime. No CUDA is required for packing. The backend-agnostic ``pack_model`` dispatcher lives in ``pack.py``. @@ -28,9 +34,26 @@ # Per-module packers +def _is_symmetric_q6k(w) -> bool: + """True if ``w`` is a symmetric Q6_K ``IntxUnpackedToInt8Tensor``. + + GGUF Q6_K decodes (``gguf.to_intx_unpacked_to_int8_tensor``) to a symmetric + int8 tensor with 16-wide groups and values in ``[-32, 31]``. Those three + properties together distinguish it from a genuine INT8 weight (wider groups + and/or the full int8 range), so the int8 path is never misrouted into the + 6-bit packer. + """ + if tuple(int(b) for b in w.block_size) != (1, 16): + return False + if not bool(torch.all(w.zero_point == 0)): + return False + return int(w.qdata.min()) >= -32 and int(w.qdata.max()) <= 31 + + def pack_linear_for_cuda(module: nn.Module, weights: dict[str, torch.Tensor]) -> None: """Assign a quantized weight to an ``nn.Linear`` module.""" from executorch.backends.cuda.coalesced_int4_tensor import CudaCoalescedInt4Tensor + from executorch.backends.cuda.packed_int6_tensor import CudaPackedInt6Tensor from torchao.quantization import IntxUnpackedToInt8Tensor from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor @@ -48,6 +71,12 @@ def pack_linear_for_cuda(module: nn.Module, weights: dict[str, torch.Tensor]) -> w = CudaCoalescedInt4Tensor.from_int4_tensor(w) module.weight = nn.Parameter(w, requires_grad=False) elif isinstance(w, IntxUnpackedToInt8Tensor): + # GGUF Q6_K decodes to a symmetric int8 tensor; repack it into the genuine + # 6-bit CudaPackedInt6Tensor (ql/qh planes, 0.75 B/elem) for the W6A8 dp4a + # decode kernel — the bit-pack is baked into the weight constant here, + # once. A genuine INT8 weight is left unchanged for the int8 path. 
+ if _is_symmetric_q6k(w): + w = CudaPackedInt6Tensor.from_intx_int8(w) module.weight = nn.Parameter(w, requires_grad=False) else: raise ValueError(f"Unsupported weight type: {type(w).__name__}") diff --git a/examples/models/gemma4_31b/quant/quantize.py b/examples/models/gemma4_31b/quant/quantize.py index ade85efd788..1baf65a1c3e 100644 --- a/examples/models/gemma4_31b/quant/quantize.py +++ b/examples/models/gemma4_31b/quant/quantize.py @@ -272,6 +272,13 @@ def dequantize_weight( zero = weight.zero_point.float().repeat_interleave(gs, dim=-1) return ((weight.qdata.float() - zero) * scale).to(dtype) + # CudaPackedInt6Tensor (GGUF Q6_K on CUDA) carries its own dequant (symmetric, + # ql/qh planes). Imported lazily to avoid a hard backends/cuda dependency. + from executorch.backends.cuda.packed_int6_tensor import CudaPackedInt6Tensor + + if isinstance(weight, CudaPackedInt6Tensor): + return weight.dequantize(dtype) + raise TypeError(f"Cannot dequantize {type(weight).__name__}") diff --git a/examples/models/gemma4_31b/quant/tests/test_pack_cuda.py b/examples/models/gemma4_31b/quant/tests/test_pack_cuda.py index e4f68fce43c..38eca18f5b8 100644 --- a/examples/models/gemma4_31b/quant/tests/test_pack_cuda.py +++ b/examples/models/gemma4_31b/quant/tests/test_pack_cuda.py @@ -18,6 +18,11 @@ import executorch.backends.cuda.quantize_op_dispatch # noqa: F401 import torch import torch.nn as nn +from executorch.backends.cuda.packed_int6_tensor import ( + CudaPackedInt6Tensor, + pack_int6, + unpack_int6, +) from executorch.examples.models.gemma4_31b.quant.pack import pack_one from executorch.examples.models.gemma4_31b.quant.pack_cuda import ( DEFAULT_CUDA_PACKERS, @@ -124,6 +129,105 @@ def test_unsupported_type_raises(self): pack_linear_for_cuda(module, {"weight": torch.randn(32, 64)}) +class TestPackLinearInt6(unittest.TestCase): + """pack_linear_for_cuda converts a symmetric Q6_K IntxUnpackedToInt8Tensor + (the gguf_loader output) into a CudaPackedInt6Tensor. 
+ + The pack/unpack round-trip is lossless and dequantize() == q * scale (no + CUDA required); the F.linear correctness check is CUDA-only. A genuine INT8 + weight is left on the int8 path. + """ + + def setUp(self): + torch.manual_seed(0) + + def _make_int6(self, N, K, gs=16): + q = torch.randint(-32, 32, (N, K), dtype=torch.int8) + scale = (torch.rand(N, K // gs) * 0.1 + 0.01).to(torch.bfloat16) + ql, qh = pack_int6(q) + t = CudaPackedInt6Tensor(ql, qh, scale, [1, gs], torch.Size([N, K])) + return t, q, scale + + def _make_q6k_intx(self, N, K, gs=16): + """Build a symmetric Q6_K IntxUnpackedToInt8Tensor (mirrors gguf.py).""" + from torchao.quantization import IntxUnpackedToInt8Tensor + + q = torch.randint(-32, 32, (N, K), dtype=torch.int8) + scale = (torch.rand(N, K // gs) * 0.1 + 0.01).to(torch.bfloat16) + zero = torch.zeros(N, K // gs, dtype=torch.int8) + t = IntxUnpackedToInt8Tensor( + qdata=q, + scale=scale, + zero_point=zero, + target_dtype=torch.int8, + block_size=(1, gs), + dtype=torch.bfloat16, + activation_quantization=None, + ) + return t, q, scale + + def test_pack_unpack_roundtrip(self): + q = torch.randint(-32, 32, (64, 128), dtype=torch.int8) + ql, qh = pack_int6(q) + self.assertEqual(tuple(ql.shape), (64, 64)) # [N, K/2] + self.assertEqual(tuple(qh.shape), (64, 32)) # [N, K/4] + q_rt = unpack_int6(ql, qh, 64, 128).to(torch.int8) + self.assertTrue(torch.equal(q_rt, q)) + + def test_dequantize_equals_q_scale(self): + t, q, scale = self._make_int6(32, 128, gs=16) + ref = q.to(torch.bfloat16) * scale.to(torch.bfloat16).repeat_interleave( + 16, dim=-1 + ) + self.assertTrue(torch.equal(t.dequantize(), ref)) + + def test_pack_linear_converts_q6k(self): + t, _, _ = self._make_q6k_intx(32, 128) + with torch.device("meta"): + module = nn.Linear(128, 32, bias=False) + pack_linear_for_cuda(module, {"weight": t}) + self.assertIsInstance(module.weight.data, CudaPackedInt6Tensor) + self.assertEqual(module.weight.shape, torch.Size([32, 128])) + + def 
test_pack_linear_real_int8_passthrough(self): + """A genuine INT8 weight (wide groups, full range) is NOT repacked.""" + from torchao.quantization import IntxUnpackedToInt8Tensor + + q = torch.randint(-128, 128, (32, 128), dtype=torch.int8) + scale = (torch.rand(32, 128 // 32) * 0.1 + 0.01).to(torch.bfloat16) + zero = torch.zeros(32, 128 // 32, dtype=torch.int8) + t = IntxUnpackedToInt8Tensor( + qdata=q, + scale=scale, + zero_point=zero, + target_dtype=torch.int8, + block_size=(1, 32), + dtype=torch.bfloat16, + activation_quantization=None, + ) + with torch.device("meta"): + module = nn.Linear(128, 32, bias=False) + pack_linear_for_cuda(module, {"weight": t}) + self.assertIsInstance(module.weight.data, IntxUnpackedToInt8Tensor) + + def test_matmul_correct(self): + _require_cuda(self) + t, q, scale = self._make_q6k_intx(256, 128, gs=16) + module = nn.Linear(128, 256, bias=False) + pack_linear_for_cuda(module, {"weight": t}) + self.assertIsInstance(module.weight.data, CudaPackedInt6Tensor) + module.cuda() + x = torch.randn(1, 128, dtype=torch.bfloat16, device="cuda") + w_ref = ( + q.to(torch.bfloat16) + * scale.to(torch.bfloat16).repeat_interleave(16, dim=-1) + ).cuda() + ref = torch.nn.functional.linear(x, w_ref) + out = module(x) + rel_error = (out.float() - ref.float()).abs().mean() / ref.float().abs().mean() + self.assertLess(rel_error.item(), 0.02) + + class TestPackEmbedding(unittest.TestCase): """pack_embedding_for_cuda with INT8 per-axis weights.""" diff --git a/examples/models/gemma4_31b/tests/test_cuda_pipeline.py b/examples/models/gemma4_31b/tests/test_cuda_pipeline.py index 0e31a50f37b..4cee363a123 100644 --- a/examples/models/gemma4_31b/tests/test_cuda_pipeline.py +++ b/examples/models/gemma4_31b/tests/test_cuda_pipeline.py @@ -245,12 +245,12 @@ def _load(self, tmp): return load_gguf_model(path, backend="cuda", config=GGUF_CONFIG) def test_load_converts_weights(self): - """GGUF -> CUDA: Q4_K -> CudaCoalescedInt4Tensor, Q6_K -> IntxUnpacked, + """GGUF -> 
CUDA: Q4_K -> CudaCoalescedInt4Tensor, Q6_K -> CudaPackedInt6Tensor, embedding bf16.""" from executorch.backends.cuda.coalesced_int4_tensor import ( CudaCoalescedInt4Tensor, ) - from torchao.quantization import IntxUnpackedToInt8Tensor + from executorch.backends.cuda.packed_int6_tensor import CudaPackedInt6Tensor with tempfile.TemporaryDirectory() as tmp: model, _ = self._load(tmp) @@ -259,9 +259,12 @@ def test_load_converts_weights(self): model.layers[0].self_attn.q_proj.weight.data, CudaCoalescedInt4Tensor ) self.assertIsInstance( - model.layers[0].mlp.down_proj.weight.data, IntxUnpackedToInt8Tensor + model.layers[0].mlp.down_proj.weight.data, CudaPackedInt6Tensor ) - # Token embedding is dequantized to bf16 (Int4/Intx can't gather). + # Tied lm_head is repacked to int6 by pack_cuda (it keeps quantization, + # unlike the token embedding which is dequantized for the gather). + self.assertIsInstance(model.lm_head.weight.data, CudaPackedInt6Tensor) + # Token embedding is dequantized to bf16 (Int4/packed-int6 can't gather). self.assertEqual(model.embed_tokens.weight.dtype, torch.bfloat16) def test_generate(self): From 3ce6acc079fd8edd1ad861b24ee332a7d9cb99ec Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Fri, 12 Jun 2026 13:41:07 -0700 Subject: [PATCH 08/15] remove comment --- examples/models/gemma4_31b/gguf_loader.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/examples/models/gemma4_31b/gguf_loader.py b/examples/models/gemma4_31b/gguf_loader.py index 1cd9c0db8b0..6606ccaa524 100644 --- a/examples/models/gemma4_31b/gguf_loader.py +++ b/examples/models/gemma4_31b/gguf_loader.py @@ -16,11 +16,10 @@ by the MLX GGUF pattern (Q6_K custom kernels, Q4_K native affine ops) for both linear and embedding. ``embed_tokens`` and ``lm_head`` stay tied -- they share the one quantized tensor. 
-* **CUDA**: Q4_K -> ``Int4Tensor``, Q6_K -> ``IntxUnpackedToInt8Tensor`` (native - torchao tensors; the backend packer in ``quant/pack_cuda.py`` repacks them into - ``CudaCoalescedInt4Tensor`` / the genuine 6-bit ``CudaPackedInt6Tensor``). - ``lm_head`` keeps the quantized tensor but the token embedding is dequantized to - bf16 (the packed tensors can't gather), so they are untied. +* **CUDA**: Q4_K -> ``Int4Tensor``, Q6_K -> ``CudaPackedInt6Tensor`` (a genuine + 6-bit packed weight, lossless, symmetric); ``lm_head`` keeps the quantized + tensor but the token embedding is dequantized to bf16 (the packed tensors can't + gather), so they are untied. Usage: model, config = load_gguf_model("model.gguf", backend="cuda") @@ -93,11 +92,7 @@ def _convert_weight(model, model_key: str, gtensor, backend: str): """Convert an ``ExportableGGUFTensor`` to the per-backend module weight.""" if backend == "mlx": return gtensor - # CUDA: native torchao quantized tensors. Q4_K -> Int4Tensor; Q6_K (and any - # other quant type) -> IntxUnpackedToInt8Tensor. The backend packer in - # quant/pack_cuda.py repacks these into the ExecuTorch-internal CUDA layouts - # (CudaCoalescedInt4Tensor / CudaPackedInt6Tensor), so the loader itself stays - # backend-agnostic and carries no backends/cuda dependency. + # CUDA: native torchao quantized tensors. 
if gtensor.ggml_type == "q4_k": return gtensor.to_int4_tensor() return gtensor.to_intx_unpacked_to_int8_tensor() From a2f90a7a1e84ccf5b96470822fa3c6500310f3a3 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Fri, 12 Jun 2026 13:45:45 -0700 Subject: [PATCH 09/15] lin --- backends/cuda/packed_int6_tensor.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/backends/cuda/packed_int6_tensor.py b/backends/cuda/packed_int6_tensor.py index 104ed5bbfa0..06582df197f 100644 --- a/backends/cuda/packed_int6_tensor.py +++ b/backends/cuda/packed_int6_tensor.py @@ -105,9 +105,7 @@ def unpack_int6(ql: torch.Tensor, qh: torch.Tensor, N: int, K: int) -> torch.Ten hi_even = torch.stack( [(hi_even_byte >> s) & 0x3 for s in (0, 2, 4, 6)], dim=-1 ) # (N, chunk, 4, 4) uint8 - hi_odd = torch.stack( - [(hi_odd_byte >> s) & 0x3 for s in (0, 2, 4, 6)], dim=-1 - ) + hi_odd = torch.stack([(hi_odd_byte >> s) & 0x3 for s in (0, 2, 4, 6)], dim=-1) hi = torch.empty(N, chunks, 4, 8, dtype=torch.uint8, device=ql.device) hi[..., 0::2] = hi_even hi[..., 1::2] = hi_odd From a8ad1d914b939e3e114f2861f8d75a8356bd6b19 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Mon, 15 Jun 2026 15:09:15 -0700 Subject: [PATCH 10/15] [cuda] int6: pack from native ExportableGGUFTensor (drop _is_symmetric_q6k heuristic) Route the Q6_K CUDA path on the native ExportableGGUFTensor (ggml_type == "q6_k") instead of an int8 intermediate. pack_linear_for_cuda now repacks the raw GGUF tensor via CudaPackedInt6Tensor.from_exportable_gguf, which REUSES the shared Q6_K block decode in gguf.py (to_intx_unpacked_to_int8_tensor) then bakes the ql/qh bit-pack -- the decode is not duplicated. This removes the brittle _is_symmetric_q6k heuristic and makes the int8 passthrough unambiguous. 
- packed_int6_tensor: add from_exportable_gguf (keeps from_intx_int8 low-level packer) - gguf_loader._convert_weight: q6_k returns the raw ExportableGGUFTensor (like MLX); q4_k unchanged - quantize.dequantize_weight: add ExportableGGUFTensor branch (tied token embedding -> bf16) - pack_cuda.pack_linear_for_cuda: route Int4Tensor / ExportableGGUFTensor(q6_k) / IntxUnpackedToInt8Tensor; drop heuristic - tests: feed synthetic q6_k ExportableGGUFTensor; cover from_exportable_gguf Python-only refactor; .cu kernel and serialized CudaPackedInt6Tensor unchanged. Can be squashed into the int6 commit (390238ee00) later. --- backends/cuda/packed_int6_tensor.py | 23 ++++++++ backends/cuda/tests/test_int6_dispatch.py | 48 ++++++++++++++- examples/models/gemma4_31b/gguf_loader.py | 7 ++- examples/models/gemma4_31b/quant/pack_cuda.py | 47 +++++++-------- examples/models/gemma4_31b/quant/quantize.py | 9 +++ .../gemma4_31b/quant/tests/test_pack_cuda.py | 58 ++++++++++--------- 6 files changed, 136 insertions(+), 56 deletions(-) diff --git a/backends/cuda/packed_int6_tensor.py b/backends/cuda/packed_int6_tensor.py index 06582df197f..adef6984ed9 100644 --- a/backends/cuda/packed_int6_tensor.py +++ b/backends/cuda/packed_int6_tensor.py @@ -11,6 +11,12 @@ the int8 path (``IntxUnpackedToInt8Tensor``, one int8 per 6-bit value), this format wastes no bits and carries no zero tensor — Q6_K is symmetric. +Build one with :meth:`from_exportable_gguf` (from a native Q6_K +``ExportableGGUFTensor`` — it reuses the shared Q6_K block decode in +``extension/llm/export/gguf.py``) or with the low-level :meth:`from_intx_int8` +(from an already-decoded symmetric int8 tensor). Both feed the same ql/qh packer; +this class owns only the 6-bit pack, never the Q6_K block decode. + The stored value is ``u = q + 32`` in ``[0, 63]`` (``q`` in ``[-32, 31]``); the constant ``-32`` offset is applied in the decode kernel. 
The 6 bits are split into two planes that mirror the INT4 nibble layout so the kernel can reuse the @@ -188,6 +194,23 @@ def from_intx_int8(cls, t: torch.Tensor) -> "CudaPackedInt6Tensor": t.shape, ) + @classmethod + def from_exportable_gguf(cls, gt) -> "CudaPackedInt6Tensor": + """Build from a native Q6_K ``ExportableGGUFTensor``. + + Reuses the shared Q6_K block decode in + ``extension/llm/export/gguf.py`` (``to_intx_unpacked_to_int8_tensor`` -> + a symmetric int8 tensor in ``[-32, 31]``), then bit-packs into the ql/qh + planes via :meth:`from_intx_int8`. The Q6_K decode lives in one place; + this class only owns the 6-bit pack. + """ + if gt.ggml_type != "q6_k": + raise ValueError( + "CudaPackedInt6Tensor.from_exportable_gguf requires a q6_k " + f"ExportableGGUFTensor, got {gt.ggml_type!r}" + ) + return cls.from_intx_int8(gt.to_intx_unpacked_to_int8_tensor()) + def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: """Dequantize to a dense tensor (symmetric: ``w = q * scale``). 
diff --git a/backends/cuda/tests/test_int6_dispatch.py b/backends/cuda/tests/test_int6_dispatch.py index 63602618b3a..0914810a4d8 100644 --- a/backends/cuda/tests/test_int6_dispatch.py +++ b/backends/cuda/tests/test_int6_dispatch.py @@ -29,7 +29,11 @@ import torch import torch.nn as nn import torch.nn.functional as F -from executorch.backends.cuda.packed_int6_tensor import CudaPackedInt6Tensor, pack_int6 +from executorch.backends.cuda.packed_int6_tensor import ( + CudaPackedInt6Tensor, + pack_int6, + unpack_int6, +) from executorch.backends.cuda.quantize_op_dispatch.int6_dispatch import ( _dequant_matmul_int6, ) @@ -187,6 +191,48 @@ def test_from_intx_int8_rejects_asymmetric(self): with self.assertRaises(ValueError): CudaPackedInt6Tensor.from_intx_int8(intx) + def test_from_exportable_gguf(self): + """from_exportable_gguf reuses the gguf.py Q6_K decode then packs losslessly.""" + from executorch.extension.llm.export.gguf import ( + _Q6_K_BLOCK_BYTES, + ExportableGGUFTensor, + ) + + N, nb = 8, 1 # K = nb * 256 + g = torch.Generator().manual_seed(0) + blk = torch.randint( + 0, 256, (N * nb, _Q6_K_BLOCK_BYTES), dtype=torch.uint8, generator=g + ) + blk[:, 192:208] = 0x10 # fixed non-zero int8 sub-scales + blk[:, 208:210] = torch.tensor([0.01], dtype=torch.float16).view( + torch.uint8 + ) # super-block scale d + raw = blk.reshape(N, nb * _Q6_K_BLOCK_BYTES) + gt = ExportableGGUFTensor.from_raw(raw, "q6_k") + + t = CudaPackedInt6Tensor.from_exportable_gguf(gt) + self.assertIsInstance(t, CudaPackedInt6Tensor) + self.assertEqual(tuple(t.shape), (N, nb * 256)) + + # The packer must reuse the shared Q6_K int8 decode (no duplication) and + # bit-pack it losslessly: the unpacked q and the scale match the int8 path. 
+ intx = gt.to_intx_unpacked_to_int8_tensor() + q_rt = unpack_int6(t.ql, t.qh, N, nb * 256).to(torch.int8) + self.assertTrue(torch.equal(q_rt, intx.qdata)) + self.assertTrue(torch.equal(t.scale, intx.scale)) + + def test_from_exportable_gguf_rejects_non_q6k(self): + """A non-q6_k ExportableGGUFTensor is rejected before any decode.""" + from executorch.extension.llm.export.gguf import ( + _Q4_K_BLOCK_BYTES, + ExportableGGUFTensor, + ) + + raw = torch.zeros(4, _Q4_K_BLOCK_BYTES, dtype=torch.uint8) + gt = ExportableGGUFTensor.from_raw(raw, "q4_k") + with self.assertRaises(ValueError): + CudaPackedInt6Tensor.from_exportable_gguf(gt) + class TestFLinearDispatchCuda(unittest.TestCase): """F.linear with a CudaPackedInt6Tensor weight on CUDA (eager -> dequant).""" diff --git a/examples/models/gemma4_31b/gguf_loader.py b/examples/models/gemma4_31b/gguf_loader.py index 6606ccaa524..150e4fd0241 100644 --- a/examples/models/gemma4_31b/gguf_loader.py +++ b/examples/models/gemma4_31b/gguf_loader.py @@ -92,10 +92,13 @@ def _convert_weight(model, model_key: str, gtensor, backend: str): """Convert an ``ExportableGGUFTensor`` to the per-backend module weight.""" if backend == "mlx": return gtensor - # CUDA: native torchao quantized tensors. + # CUDA: Q4_K -> torchao Int4Tensor. Q6_K stays the raw ExportableGGUFTensor + # (like MLX) -- the CUDA packer repacks it into CudaPackedInt6Tensor via + # CudaPackedInt6Tensor.from_exportable_gguf, so the Q6_K block decode is + # owned by gguf.py and reused, not duplicated here. 
if gtensor.ggml_type == "q4_k": return gtensor.to_int4_tensor() - return gtensor.to_intx_unpacked_to_int8_tensor() + return gtensor def _resolve_tied_lm_head(model, lm_head_weight, packers): diff --git a/examples/models/gemma4_31b/quant/pack_cuda.py b/examples/models/gemma4_31b/quant/pack_cuda.py index e22e99789b6..c884f2b957c 100644 --- a/examples/models/gemma4_31b/quant/pack_cuda.py +++ b/examples/models/gemma4_31b/quant/pack_cuda.py @@ -11,10 +11,11 @@ * ``Int4Tensor`` -> ``CudaCoalescedInt4Tensor`` (bakes the scale/zero transpose into the coalesced [N, n_groups] layout). - * symmetric Q6_K ``IntxUnpackedToInt8Tensor`` -> ``CudaPackedInt6Tensor`` (the - genuine 6-bit ql/qh planes). + * Q6_K ``ExportableGGUFTensor`` -> ``CudaPackedInt6Tensor`` (the genuine 6-bit + ql/qh planes; the Q6_K block decode is reused from gguf.py, not duplicated). -A genuine INT8 ``IntxUnpackedToInt8Tensor`` is left unchanged for the int8 path. +A genuine INT8 ``IntxUnpackedToInt8Tensor`` is left unchanged for the int8 path +(Q6_K no longer arrives as an int8 tensor, so the routing is unambiguous). The quantize_op_dispatch package (``int4_dispatch`` / ``int6_dispatch`` / ``int8_dispatch``) handles F.linear at runtime. @@ -34,26 +35,16 @@ # Per-module packers -def _is_symmetric_q6k(w) -> bool: - """True if ``w`` is a symmetric Q6_K ``IntxUnpackedToInt8Tensor``. +def pack_linear_for_cuda(module: nn.Module, weights: dict[str, torch.Tensor]) -> None: + """Assign a quantized weight to an ``nn.Linear`` module. - GGUF Q6_K decodes (``gguf.to_intx_unpacked_to_int8_tensor``) to a symmetric - int8 tensor with 16-wide groups and values in ``[-32, 31]``. Those three - properties together distinguish it from a genuine INT8 weight (wider groups - and/or the full int8 range), so the int8 path is never misrouted into the - 6-bit packer. + Routes by weight type: ``Int4Tensor`` -> coalesced INT4, Q6_K + ``ExportableGGUFTensor`` -> packed INT6, genuine ``IntxUnpackedToInt8Tensor`` + -> int8 passthrough. 
""" - if tuple(int(b) for b in w.block_size) != (1, 16): - return False - if not bool(torch.all(w.zero_point == 0)): - return False - return int(w.qdata.min()) >= -32 and int(w.qdata.max()) <= 31 - - -def pack_linear_for_cuda(module: nn.Module, weights: dict[str, torch.Tensor]) -> None: - """Assign a quantized weight to an ``nn.Linear`` module.""" from executorch.backends.cuda.coalesced_int4_tensor import CudaCoalescedInt4Tensor from executorch.backends.cuda.packed_int6_tensor import CudaPackedInt6Tensor + from executorch.extension.llm.export.gguf import ExportableGGUFTensor from torchao.quantization import IntxUnpackedToInt8Tensor from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor @@ -69,17 +60,19 @@ def pack_linear_for_cuda(module: nn.Module, weights: dict[str, torch.Tensor]) -> # constant-fold ops on parameters, so the transpose must already live in # the constant for the coalesced layout to pay off. w = CudaCoalescedInt4Tensor.from_int4_tensor(w) - module.weight = nn.Parameter(w, requires_grad=False) + elif isinstance(w, ExportableGGUFTensor) and w.ggml_type == "q6_k": + # GGUF Q6_K: repack the native ExportableGGUFTensor into the genuine 6-bit + # CudaPackedInt6Tensor (ql/qh planes, 0.75 B/elem) for the W6A8 dp4a decode + # kernel. from_exportable_gguf reuses the shared Q6_K decode (gguf.py) then + # bakes the bit-pack into the weight constant, once. + w = CudaPackedInt6Tensor.from_exportable_gguf(w) elif isinstance(w, IntxUnpackedToInt8Tensor): - # GGUF Q6_K decodes to a symmetric int8 tensor; repack it into the genuine - # 6-bit CudaPackedInt6Tensor (ql/qh planes, 0.75 B/elem) for the W6A8 dp4a - # decode kernel — the bit-pack is baked into the weight constant here, - # once. A genuine INT8 weight is left unchanged for the int8 path. - if _is_symmetric_q6k(w): - w = CudaPackedInt6Tensor.from_intx_int8(w) - module.weight = nn.Parameter(w, requires_grad=False) + # Genuine INT8 weight: left unchanged for the int8 path. 
Q6_K never reaches + # here (it arrives as an ExportableGGUFTensor), so this is unambiguous. + pass else: raise ValueError(f"Unsupported weight type: {type(w).__name__}") + module.weight = nn.Parameter(w, requires_grad=False) def pack_embedding_for_cuda( diff --git a/examples/models/gemma4_31b/quant/quantize.py b/examples/models/gemma4_31b/quant/quantize.py index 1baf65a1c3e..b04a1891323 100644 --- a/examples/models/gemma4_31b/quant/quantize.py +++ b/examples/models/gemma4_31b/quant/quantize.py @@ -272,6 +272,15 @@ def dequantize_weight( zero = weight.zero_point.float().repeat_interleave(gs, dim=-1) return ((weight.qdata.float() - zero) * scale).to(dtype) + # ExportableGGUFTensor (native GGUF Q4_K/Q6_K) carries its own gguf-package + # dequant. The tied CUDA token embedding keeps the raw GGUF tensor and is + # dequantized to bf16 here for the gather. Imported lazily to avoid a hard + # extension/llm dependency. + from executorch.extension.llm.export.gguf import ExportableGGUFTensor + + if isinstance(weight, ExportableGGUFTensor): + return weight.dequantize(dtype) + # CudaPackedInt6Tensor (GGUF Q6_K on CUDA) carries its own dequant (symmetric, # ql/qh planes). Imported lazily to avoid a hard backends/cuda dependency. from executorch.backends.cuda.packed_int6_tensor import CudaPackedInt6Tensor diff --git a/examples/models/gemma4_31b/quant/tests/test_pack_cuda.py b/examples/models/gemma4_31b/quant/tests/test_pack_cuda.py index 38eca18f5b8..f562cfa642e 100644 --- a/examples/models/gemma4_31b/quant/tests/test_pack_cuda.py +++ b/examples/models/gemma4_31b/quant/tests/test_pack_cuda.py @@ -130,12 +130,12 @@ def test_unsupported_type_raises(self): class TestPackLinearInt6(unittest.TestCase): - """pack_linear_for_cuda converts a symmetric Q6_K IntxUnpackedToInt8Tensor - (the gguf_loader output) into a CudaPackedInt6Tensor. + """pack_linear_for_cuda converts a native Q6_K ExportableGGUFTensor (the + gguf_loader output) into a CudaPackedInt6Tensor. 
The pack/unpack round-trip is lossless and dequantize() == q * scale (no CUDA required); the F.linear correctness check is CUDA-only. A genuine INT8 - weight is left on the int8 path. + IntxUnpackedToInt8Tensor is left on the int8 path. """ def setUp(self): @@ -148,23 +148,27 @@ def _make_int6(self, N, K, gs=16): t = CudaPackedInt6Tensor(ql, qh, scale, [1, gs], torch.Size([N, K])) return t, q, scale - def _make_q6k_intx(self, N, K, gs=16): - """Build a symmetric Q6_K IntxUnpackedToInt8Tensor (mirrors gguf.py).""" - from torchao.quantization import IntxUnpackedToInt8Tensor + def _make_q6k_gguf(self, N, nb=1): + """Build a synthetic q6_k ExportableGGUFTensor (see test_gguf.py). - q = torch.randint(-32, 32, (N, K), dtype=torch.int8) - scale = (torch.rand(N, K // gs) * 0.1 + 0.01).to(torch.bfloat16) - zero = torch.zeros(N, K // gs, dtype=torch.int8) - t = IntxUnpackedToInt8Tensor( - qdata=q, - scale=scale, - zero_point=zero, - target_dtype=torch.int8, - block_size=(1, gs), - dtype=torch.bfloat16, - activation_quantization=None, + ``K = nb * 256``; fixed non-zero sub-scales + a small super-block scale + keep dequantized magnitudes O(1). 
+ """ + from executorch.extension.llm.export.gguf import ( + _Q6_K_BLOCK_BYTES, + ExportableGGUFTensor, ) - return t, q, scale + + g = torch.Generator().manual_seed(0) + blk = torch.randint( + 0, 256, (N * nb, _Q6_K_BLOCK_BYTES), dtype=torch.uint8, generator=g + ) + blk[:, 192:208] = 0x10 # fixed non-zero int8 sub-scales + blk[:, 208:210] = torch.tensor([0.01], dtype=torch.float16).view( + torch.uint8 + ) # super-block scale d + raw = blk.reshape(N, nb * _Q6_K_BLOCK_BYTES) + return ExportableGGUFTensor.from_raw(raw, "q6_k") def test_pack_unpack_roundtrip(self): q = torch.randint(-32, 32, (64, 128), dtype=torch.int8) @@ -182,12 +186,12 @@ def test_dequantize_equals_q_scale(self): self.assertTrue(torch.equal(t.dequantize(), ref)) def test_pack_linear_converts_q6k(self): - t, _, _ = self._make_q6k_intx(32, 128) + gt = self._make_q6k_gguf(32, nb=1) # (32, 256) with torch.device("meta"): - module = nn.Linear(128, 32, bias=False) - pack_linear_for_cuda(module, {"weight": t}) + module = nn.Linear(256, 32, bias=False) + pack_linear_for_cuda(module, {"weight": gt}) self.assertIsInstance(module.weight.data, CudaPackedInt6Tensor) - self.assertEqual(module.weight.shape, torch.Size([32, 128])) + self.assertEqual(module.weight.shape, torch.Size([32, 256])) def test_pack_linear_real_int8_passthrough(self): """A genuine INT8 weight (wide groups, full range) is NOT repacked.""" @@ -212,12 +216,14 @@ def test_pack_linear_real_int8_passthrough(self): def test_matmul_correct(self): _require_cuda(self) - t, q, scale = self._make_q6k_intx(256, 128, gs=16) - module = nn.Linear(128, 256, bias=False) - pack_linear_for_cuda(module, {"weight": t}) + gt = self._make_q6k_gguf(256, nb=1) # (256, 256) + intx = gt.to_intx_unpacked_to_int8_tensor() + q, scale = intx.qdata, intx.scale + module = nn.Linear(256, 256, bias=False) + pack_linear_for_cuda(module, {"weight": gt}) self.assertIsInstance(module.weight.data, CudaPackedInt6Tensor) module.cuda() - x = torch.randn(1, 128, dtype=torch.bfloat16, 
device="cuda") + x = torch.randn(1, 256, dtype=torch.bfloat16, device="cuda") w_ref = ( q.to(torch.bfloat16) * scale.to(torch.bfloat16).repeat_interleave(16, dim=-1) From 5082a309035fa7cf367b82b6c044a0124c7065d6 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Mon, 15 Jun 2026 16:04:11 -0700 Subject: [PATCH 11/15] [cuda] rename CudaPackedInt6Tensor -> CudaDp4aPlanarInt6Tensor Rename the int6 tensor subclass and its module file to reflect the dp4a-planar (ql/qh split bit-plane) layout: backends/cuda/packed_int6_tensor.py -> backends/cuda/dp4a_planar_int6_tensor.py class CudaPackedInt6Tensor -> CudaDp4aPlanarInt6Tensor Update all references (imports, type dispatch, CUDA packer, quantize dequant branch, gguf_loader, tests, kernel comments) and the torch.serialization.add_safe_globals registration so exported models round-trip under the new qualified name. Classmethods from_exportable_gguf/from_intx_int8 and helpers pack_int6/unpack_int6 are unchanged; the runtime op int6_plain_mm and the .cu/.cuh kernel are untouched. 
--- ...6_tensor.py => dp4a_planar_int6_tensor.py} | 26 ++++++++++--------- .../cuda/quantize_op_dispatch/__init__.py | 2 +- .../quantize_op_dispatch/int6_dispatch.py | 18 ++++++------- backends/cuda/runtime/shims/int6_plain_mm.cuh | 2 +- .../test_aoti_torch_cuda_int6_plain_mm.cpp | 8 +++--- backends/cuda/tests/test_int6_dispatch.py | 24 ++++++++--------- examples/models/gemma4_31b/gguf_loader.py | 6 ++--- examples/models/gemma4_31b/quant/pack_cuda.py | 17 +++++++----- examples/models/gemma4_31b/quant/quantize.py | 11 +++++--- .../gemma4_31b/quant/tests/test_pack_cuda.py | 12 ++++----- .../gemma4_31b/tests/test_cuda_pipeline.py | 10 ++++--- 11 files changed, 73 insertions(+), 63 deletions(-) rename backends/cuda/{packed_int6_tensor.py => dp4a_planar_int6_tensor.py} (89%) diff --git a/backends/cuda/packed_int6_tensor.py b/backends/cuda/dp4a_planar_int6_tensor.py similarity index 89% rename from backends/cuda/packed_int6_tensor.py rename to backends/cuda/dp4a_planar_int6_tensor.py index adef6984ed9..4daeef7efcd 100644 --- a/backends/cuda/packed_int6_tensor.py +++ b/backends/cuda/dp4a_planar_int6_tensor.py @@ -4,10 +4,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -"""ExecuTorch-internal packed-INT6 tensor for the CUDA W6A8 dp4a decode kernel. +"""ExecuTorch-internal dp4a-planar INT6 tensor for the CUDA W6A8 dp4a decode kernel. -``CudaPackedInt6Tensor`` is an ExecuTorch-internal tensor subclass that stores a -genuine 6-bit packed weight (0.75 B/elem), used for GGUF Q6_K weights. Unlike +``CudaDp4aPlanarInt6Tensor`` is an ExecuTorch-internal tensor subclass that stores a +genuine 6-bit packed weight (0.75 B/elem) in a dp4a-friendly *planar* layout: the +6 bits are split into two bit-planes (``ql``/``qh``) so the decode kernel can run +the W6A8 dp4a matvec directly. Used for GGUF Q6_K weights. 
Unlike the int8 path (``IntxUnpackedToInt8Tensor``, one int8 per 6-bit value), this format wastes no bits and carries no zero tensor — Q6_K is symmetric. @@ -44,7 +46,7 @@ from torchao.utils import TorchAOBaseTensor __all__ = [ - "CudaPackedInt6Tensor", + "CudaDp4aPlanarInt6Tensor", "pack_int6", "unpack_int6", ] @@ -121,8 +123,8 @@ def unpack_int6(ql: torch.Tensor, qh: torch.Tensor, N: int, K: int) -> torch.Ten return u.to(torch.int16) - 32 -class CudaPackedInt6Tensor(TorchAOBaseTensor): - """Packed 6-bit weight (ql/qh planes + per-group scale), symmetric. +class CudaDp4aPlanarInt6Tensor(TorchAOBaseTensor): + """Dp4a-planar 6-bit weight (ql/qh split bit-planes + per-group scale), symmetric. ExecuTorch-internal; see the module docstring. The CUDA decode/prefill dispatch (``int6_dispatch.py``) is selected by *type* — it is registered on @@ -167,7 +169,7 @@ def _quantization_type(self): ) @classmethod - def from_intx_int8(cls, t: torch.Tensor) -> "CudaPackedInt6Tensor": + def from_intx_int8(cls, t: torch.Tensor) -> "CudaDp4aPlanarInt6Tensor": """Build from a torchao ``IntxUnpackedToInt8Tensor`` decoded from Q6_K. The source is symmetric (zero_point == 0), ``qdata`` is int8 in @@ -177,7 +179,7 @@ def from_intx_int8(cls, t: torch.Tensor) -> "CudaPackedInt6Tensor": q = t.qdata if not bool(torch.all(t.zero_point == 0)): raise ValueError( - "CudaPackedInt6Tensor.from_intx_int8 requires symmetric Q6_K " + "CudaDp4aPlanarInt6Tensor.from_intx_int8 requires symmetric Q6_K " "weights (zero_point == 0)" ) q_min, q_max = int(q.min()), int(q.max()) @@ -195,7 +197,7 @@ def from_intx_int8(cls, t: torch.Tensor) -> "CudaPackedInt6Tensor": ) @classmethod - def from_exportable_gguf(cls, gt) -> "CudaPackedInt6Tensor": + def from_exportable_gguf(cls, gt) -> "CudaDp4aPlanarInt6Tensor": """Build from a native Q6_K ``ExportableGGUFTensor``. 
Reuses the shared Q6_K block decode in @@ -206,7 +208,7 @@ def from_exportable_gguf(cls, gt) -> "CudaPackedInt6Tensor": """ if gt.ggml_type != "q6_k": raise ValueError( - "CudaPackedInt6Tensor.from_exportable_gguf requires a q6_k " + "CudaDp4aPlanarInt6Tensor.from_exportable_gguf requires a q6_k " f"ExportableGGUFTensor, got {gt.ggml_type!r}" ) return cls.from_intx_int8(gt.to_intx_unpacked_to_int8_tensor()) @@ -225,6 +227,6 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor return (q * scale).to(dtype) -# Allow a model with CudaPackedInt6Tensor weights to be loaded with +# Allow a model with CudaDp4aPlanarInt6Tensor weights to be loaded with # `weights_only=True` (mirrors torchao quantized tensors). -torch.serialization.add_safe_globals([CudaPackedInt6Tensor]) +torch.serialization.add_safe_globals([CudaDp4aPlanarInt6Tensor]) diff --git a/backends/cuda/quantize_op_dispatch/__init__.py b/backends/cuda/quantize_op_dispatch/__init__.py index bc45b3906f9..7ba444ebd0c 100644 --- a/backends/cuda/quantize_op_dispatch/__init__.py +++ b/backends/cuda/quantize_op_dispatch/__init__.py @@ -11,7 +11,7 @@ dequant logic instead of torchao's defaults. 
It registers: * INT4 (``CudaCoalescedInt4Tensor``) → ``executorch_cuda::int4_plain_mm`` - * INT6 (``CudaPackedInt6Tensor``) → ``executorch_cuda::int6_plain_mm`` + * INT6 (``CudaDp4aPlanarInt6Tensor``) → ``executorch_cuda::int6_plain_mm`` * INT8 (``IntxUnpackedToInt8Tensor``) → ``executorch_cuda::int8_plain_mm`` See ``int4_dispatch``, ``int6_dispatch`` and ``int8_dispatch`` for the per-dtype diff --git a/backends/cuda/quantize_op_dispatch/int6_dispatch.py b/backends/cuda/quantize_op_dispatch/int6_dispatch.py index a26814ded1e..b98a2c3ab80 100644 --- a/backends/cuda/quantize_op_dispatch/int6_dispatch.py +++ b/backends/cuda/quantize_op_dispatch/int6_dispatch.py @@ -4,12 +4,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -"""CudaPackedInt6Tensor F.linear dispatch for CUDA — eager / export trace time. +"""CudaDp4aPlanarInt6Tensor F.linear dispatch for CUDA — eager / export trace time. -This module registers an F.linear dispatch on ``CudaPackedInt6Tensor`` (an -ExecuTorch-internal subclass, see ``packed_int6_tensor.py``) so that +This module registers an F.linear dispatch on ``CudaDp4aPlanarInt6Tensor`` (an +ExecuTorch-internal subclass, see ``dp4a_planar_int6_tensor.py``) so that torch.export traces through our custom op and dequant logic. Routing is by -*type*: only GGUF Q6_K weights (converted to ``CudaPackedInt6Tensor``) take the +*type*: only GGUF Q6_K weights (converted to ``CudaDp4aPlanarInt6Tensor``) take the packed-int6 path; genuine INT8 weights stay on the int8 path. The code here runs during eager inference and AOTI export tracing — it does NOT run at .pte runtime. 
@@ -36,8 +36,8 @@ import torch import torch.nn.functional as F -from executorch.backends.cuda.packed_int6_tensor import ( - CudaPackedInt6Tensor, +from executorch.backends.cuda.dp4a_planar_int6_tensor import ( + CudaDp4aPlanarInt6Tensor, unpack_int6, ) from executorch.backends.cuda.quantize_op_dispatch._library import lib as _lib @@ -81,12 +81,12 @@ def _dequant_matmul_int6(x, ql, qh, scale, group_size): # --------------------------------------------------------------------------- -# CudaPackedInt6Tensor F.linear dispatch (W6A8 dp4a for decode) +# CudaDp4aPlanarInt6Tensor F.linear dispatch (W6A8 dp4a for decode) # --------------------------------------------------------------------------- aten = torch.ops.aten -_implements_i6 = CudaPackedInt6Tensor.implements -_implements_torch_function_i6 = CudaPackedInt6Tensor.implements_torch_function +_implements_i6 = CudaDp4aPlanarInt6Tensor.implements +_implements_torch_function_i6 = CudaDp4aPlanarInt6Tensor.implements_torch_function @_implements_i6([aten.linear.default]) diff --git a/backends/cuda/runtime/shims/int6_plain_mm.cuh b/backends/cuda/runtime/shims/int6_plain_mm.cuh index a1c7206e6a7..69007561422 100644 --- a/backends/cuda/runtime/shims/int6_plain_mm.cuh +++ b/backends/cuda/runtime/shims/int6_plain_mm.cuh @@ -8,7 +8,7 @@ // W6A8 dp4a matvec for packed INT6 decode (M <= 4), used for GGUF Q6_K weights. // -// Reads a genuine 6-bit packed weight (CudaPackedInt6Tensor format), split into +// Reads a genuine 6-bit packed weight (CudaDp4aPlanarInt6Tensor format), split into // two planes: // ql : [N, K/2] uint8 — low-nibble plane, nibble-packed even/odd exactly // like the INT4 path (ql[:,j] = lo[:,2j] | (lo[:,2j+1] << 4)). 
diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_int6_plain_mm.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_int6_plain_mm.cpp index 43d3946294a..1359118b997 100644 --- a/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_int6_plain_mm.cpp +++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda_int6_plain_mm.cpp @@ -26,10 +26,10 @@ namespace slim_c10 = executorch::backends::aoti::slim::c10; using Tensor = executorch::backends::aoti::slim::SlimTensor; -// W6A8 dp4a matvec shim for packed-INT6 decode (CudaPackedInt6Tensor layout, -// GGUF Q6_K). The 6-bit weight is split into two planes plus a per-group scale; -// there is NO zero tensor (Q6_K is symmetric, the -32 offset is applied in the -// kernel): +// W6A8 dp4a matvec shim for packed-INT6 decode (CudaDp4aPlanarInt6Tensor +// layout, GGUF Q6_K). The 6-bit weight is split into two planes plus a +// per-group scale; there is NO zero tensor (Q6_K is symmetric, the -32 offset +// is applied in the kernel): // ql : [N, K/2] uint8 — low-nibble plane, nibble-packed even/odd // qh : [N, K/4] uint8 — high-2-bit plane, 4 values/byte (per 32-weight // chunk: hi_even_packed[4] then hi_odd_packed[4]) diff --git a/backends/cuda/tests/test_int6_dispatch.py b/backends/cuda/tests/test_int6_dispatch.py index 0914810a4d8..1d105212242 100644 --- a/backends/cuda/tests/test_int6_dispatch.py +++ b/backends/cuda/tests/test_int6_dispatch.py @@ -5,7 +5,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -"""Tests for CudaPackedInt6Tensor F.linear dispatch via int6_dispatch. +"""Tests for CudaDp4aPlanarInt6Tensor F.linear dispatch via int6_dispatch. These tests validate the eager / trace-time dispatch path — the same code that torch.export traces through when building the AOTI graph. They do NOT test the @@ -13,7 +13,7 @@ test_aoti_torch_cuda_int6_plain_mm.cpp (C++ unit tests). 
The API contract: after importing int6_dispatch, F.linear / nn.Linear with a -CudaPackedInt6Tensor weight produce numerically correct results, routed by +CudaDp4aPlanarInt6Tensor weight produce numerically correct results, routed by batch size (decode M<=4 -> custom op, prefill M>4 -> inline dequant). Routing tests run without a GPU by recording calls to the decode custom op. @@ -29,8 +29,8 @@ import torch import torch.nn as nn import torch.nn.functional as F -from executorch.backends.cuda.packed_int6_tensor import ( - CudaPackedInt6Tensor, +from executorch.backends.cuda.dp4a_planar_int6_tensor import ( + CudaDp4aPlanarInt6Tensor, pack_int6, unpack_int6, ) @@ -45,7 +45,7 @@ def _require_cuda(tc: unittest.TestCase) -> None: def _make_int6_tensor(N, K, group_size=16): - """Build a CudaPackedInt6Tensor (symmetric Q6_K) and return (tensor, q, scale). + """Build a CudaDp4aPlanarInt6Tensor (symmetric Q6_K) and return (tensor, q, scale). ``q`` (int8 in [-32, 31]) and ``scale`` are the originals, so tests can measure against the exact dequant reference ``w = q * scale``. 
@@ -53,7 +53,7 @@ def _make_int6_tensor(N, K, group_size=16): q = torch.randint(-32, 32, (N, K), dtype=torch.int8) scale = (torch.rand(N, K // group_size) * 0.1 + 0.01).to(torch.bfloat16) ql, qh = pack_int6(q) - t = CudaPackedInt6Tensor(ql, qh, scale, [1, group_size], torch.Size([N, K])) + t = CudaDp4aPlanarInt6Tensor(ql, qh, scale, [1, group_size], torch.Size([N, K])) return t, q, scale @@ -164,7 +164,7 @@ def test_from_intx_int8_roundtrip(self): dtype=torch.bfloat16, activation_quantization=None, ) - t = CudaPackedInt6Tensor.from_intx_int8(intx) + t = CudaDp4aPlanarInt6Tensor.from_intx_int8(intx) x = torch.randn(1, K, dtype=torch.bfloat16) with _record_int6_plain_mm() as calls: out = F.linear(x, t) @@ -189,7 +189,7 @@ def test_from_intx_int8_rejects_asymmetric(self): activation_quantization=None, ) with self.assertRaises(ValueError): - CudaPackedInt6Tensor.from_intx_int8(intx) + CudaDp4aPlanarInt6Tensor.from_intx_int8(intx) def test_from_exportable_gguf(self): """from_exportable_gguf reuses the gguf.py Q6_K decode then packs losslessly.""" @@ -210,8 +210,8 @@ def test_from_exportable_gguf(self): raw = blk.reshape(N, nb * _Q6_K_BLOCK_BYTES) gt = ExportableGGUFTensor.from_raw(raw, "q6_k") - t = CudaPackedInt6Tensor.from_exportable_gguf(gt) - self.assertIsInstance(t, CudaPackedInt6Tensor) + t = CudaDp4aPlanarInt6Tensor.from_exportable_gguf(gt) + self.assertIsInstance(t, CudaDp4aPlanarInt6Tensor) self.assertEqual(tuple(t.shape), (N, nb * 256)) # The packer must reuse the shared Q6_K int8 decode (no duplication) and @@ -231,11 +231,11 @@ def test_from_exportable_gguf_rejects_non_q6k(self): raw = torch.zeros(4, _Q4_K_BLOCK_BYTES, dtype=torch.uint8) gt = ExportableGGUFTensor.from_raw(raw, "q4_k") with self.assertRaises(ValueError): - CudaPackedInt6Tensor.from_exportable_gguf(gt) + CudaDp4aPlanarInt6Tensor.from_exportable_gguf(gt) class TestFLinearDispatchCuda(unittest.TestCase): - """F.linear with a CudaPackedInt6Tensor weight on CUDA (eager -> dequant).""" + 
"""F.linear with a CudaDp4aPlanarInt6Tensor weight on CUDA (eager -> dequant).""" def setUp(self): _require_cuda(self) diff --git a/examples/models/gemma4_31b/gguf_loader.py b/examples/models/gemma4_31b/gguf_loader.py index 150e4fd0241..e95581dc95d 100644 --- a/examples/models/gemma4_31b/gguf_loader.py +++ b/examples/models/gemma4_31b/gguf_loader.py @@ -16,7 +16,7 @@ by the MLX GGUF pattern (Q6_K custom kernels, Q4_K native affine ops) for both linear and embedding. ``embed_tokens`` and ``lm_head`` stay tied -- they share the one quantized tensor. -* **CUDA**: Q4_K -> ``Int4Tensor``, Q6_K -> ``CudaPackedInt6Tensor`` (a genuine +* **CUDA**: Q4_K -> ``Int4Tensor``, Q6_K -> ``CudaDp4aPlanarInt6Tensor`` (a genuine 6-bit packed weight, lossless, symmetric); ``lm_head`` keeps the quantized tensor but the token embedding is dequantized to bf16 (the packed tensors can't gather), so they are untied. @@ -93,8 +93,8 @@ def _convert_weight(model, model_key: str, gtensor, backend: str): if backend == "mlx": return gtensor # CUDA: Q4_K -> torchao Int4Tensor. Q6_K stays the raw ExportableGGUFTensor - # (like MLX) -- the CUDA packer repacks it into CudaPackedInt6Tensor via - # CudaPackedInt6Tensor.from_exportable_gguf, so the Q6_K block decode is + # (like MLX) -- the CUDA packer repacks it into CudaDp4aPlanarInt6Tensor via + # CudaDp4aPlanarInt6Tensor.from_exportable_gguf, so the Q6_K block decode is # owned by gguf.py and reused, not duplicated here. if gtensor.ggml_type == "q4_k": return gtensor.to_int4_tensor() diff --git a/examples/models/gemma4_31b/quant/pack_cuda.py b/examples/models/gemma4_31b/quant/pack_cuda.py index c884f2b957c..a96d585d655 100644 --- a/examples/models/gemma4_31b/quant/pack_cuda.py +++ b/examples/models/gemma4_31b/quant/pack_cuda.py @@ -11,8 +11,9 @@ * ``Int4Tensor`` -> ``CudaCoalescedInt4Tensor`` (bakes the scale/zero transpose into the coalesced [N, n_groups] layout). 
- * Q6_K ``ExportableGGUFTensor`` -> ``CudaPackedInt6Tensor`` (the genuine 6-bit - ql/qh planes; the Q6_K block decode is reused from gguf.py, not duplicated). + * Q6_K ``ExportableGGUFTensor`` -> ``CudaDp4aPlanarInt6Tensor`` (the genuine 6-bit + ql/qh split bit-planes; the Q6_K block decode is reused from gguf.py, not + duplicated). A genuine INT8 ``IntxUnpackedToInt8Tensor`` is left unchanged for the int8 path (Q6_K no longer arrives as an int8 tensor, so the routing is unambiguous). @@ -43,7 +44,9 @@ def pack_linear_for_cuda(module: nn.Module, weights: dict[str, torch.Tensor]) -> -> int8 passthrough. """ from executorch.backends.cuda.coalesced_int4_tensor import CudaCoalescedInt4Tensor - from executorch.backends.cuda.packed_int6_tensor import CudaPackedInt6Tensor + from executorch.backends.cuda.dp4a_planar_int6_tensor import ( + CudaDp4aPlanarInt6Tensor, + ) from executorch.extension.llm.export.gguf import ExportableGGUFTensor from torchao.quantization import IntxUnpackedToInt8Tensor from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor @@ -62,10 +65,10 @@ def pack_linear_for_cuda(module: nn.Module, weights: dict[str, torch.Tensor]) -> w = CudaCoalescedInt4Tensor.from_int4_tensor(w) elif isinstance(w, ExportableGGUFTensor) and w.ggml_type == "q6_k": # GGUF Q6_K: repack the native ExportableGGUFTensor into the genuine 6-bit - # CudaPackedInt6Tensor (ql/qh planes, 0.75 B/elem) for the W6A8 dp4a decode - # kernel. from_exportable_gguf reuses the shared Q6_K decode (gguf.py) then - # bakes the bit-pack into the weight constant, once. - w = CudaPackedInt6Tensor.from_exportable_gguf(w) + # CudaDp4aPlanarInt6Tensor (ql/qh split bit-planes, 0.75 B/elem) for the + # W6A8 dp4a decode kernel. from_exportable_gguf reuses the shared Q6_K + # decode (gguf.py) then bakes the bit-pack into the weight constant, once. 
+ w = CudaDp4aPlanarInt6Tensor.from_exportable_gguf(w) elif isinstance(w, IntxUnpackedToInt8Tensor): # Genuine INT8 weight: left unchanged for the int8 path. Q6_K never reaches # here (it arrives as an ExportableGGUFTensor), so this is unambiguous. diff --git a/examples/models/gemma4_31b/quant/quantize.py b/examples/models/gemma4_31b/quant/quantize.py index b04a1891323..09829daf1e8 100644 --- a/examples/models/gemma4_31b/quant/quantize.py +++ b/examples/models/gemma4_31b/quant/quantize.py @@ -281,11 +281,14 @@ def dequantize_weight( if isinstance(weight, ExportableGGUFTensor): return weight.dequantize(dtype) - # CudaPackedInt6Tensor (GGUF Q6_K on CUDA) carries its own dequant (symmetric, - # ql/qh planes). Imported lazily to avoid a hard backends/cuda dependency. - from executorch.backends.cuda.packed_int6_tensor import CudaPackedInt6Tensor + # CudaDp4aPlanarInt6Tensor (GGUF Q6_K on CUDA) carries its own dequant + # (symmetric, ql/qh split bit-planes). Imported lazily to avoid a hard + # backends/cuda dependency. 
+ from executorch.backends.cuda.dp4a_planar_int6_tensor import ( + CudaDp4aPlanarInt6Tensor, + ) - if isinstance(weight, CudaPackedInt6Tensor): + if isinstance(weight, CudaDp4aPlanarInt6Tensor): return weight.dequantize(dtype) raise TypeError(f"Cannot dequantize {type(weight).__name__}") diff --git a/examples/models/gemma4_31b/quant/tests/test_pack_cuda.py b/examples/models/gemma4_31b/quant/tests/test_pack_cuda.py index f562cfa642e..3bcc77808ba 100644 --- a/examples/models/gemma4_31b/quant/tests/test_pack_cuda.py +++ b/examples/models/gemma4_31b/quant/tests/test_pack_cuda.py @@ -18,8 +18,8 @@ import executorch.backends.cuda.quantize_op_dispatch # noqa: F401 import torch import torch.nn as nn -from executorch.backends.cuda.packed_int6_tensor import ( - CudaPackedInt6Tensor, +from executorch.backends.cuda.dp4a_planar_int6_tensor import ( + CudaDp4aPlanarInt6Tensor, pack_int6, unpack_int6, ) @@ -131,7 +131,7 @@ def test_unsupported_type_raises(self): class TestPackLinearInt6(unittest.TestCase): """pack_linear_for_cuda converts a native Q6_K ExportableGGUFTensor (the - gguf_loader output) into a CudaPackedInt6Tensor. + gguf_loader output) into a CudaDp4aPlanarInt6Tensor. The pack/unpack round-trip is lossless and dequantize() == q * scale (no CUDA required); the F.linear correctness check is CUDA-only. 
A genuine INT8 @@ -145,7 +145,7 @@ def _make_int6(self, N, K, gs=16): q = torch.randint(-32, 32, (N, K), dtype=torch.int8) scale = (torch.rand(N, K // gs) * 0.1 + 0.01).to(torch.bfloat16) ql, qh = pack_int6(q) - t = CudaPackedInt6Tensor(ql, qh, scale, [1, gs], torch.Size([N, K])) + t = CudaDp4aPlanarInt6Tensor(ql, qh, scale, [1, gs], torch.Size([N, K])) return t, q, scale def _make_q6k_gguf(self, N, nb=1): @@ -190,7 +190,7 @@ def test_pack_linear_converts_q6k(self): with torch.device("meta"): module = nn.Linear(256, 32, bias=False) pack_linear_for_cuda(module, {"weight": gt}) - self.assertIsInstance(module.weight.data, CudaPackedInt6Tensor) + self.assertIsInstance(module.weight.data, CudaDp4aPlanarInt6Tensor) self.assertEqual(module.weight.shape, torch.Size([32, 256])) def test_pack_linear_real_int8_passthrough(self): @@ -221,7 +221,7 @@ def test_matmul_correct(self): q, scale = intx.qdata, intx.scale module = nn.Linear(256, 256, bias=False) pack_linear_for_cuda(module, {"weight": gt}) - self.assertIsInstance(module.weight.data, CudaPackedInt6Tensor) + self.assertIsInstance(module.weight.data, CudaDp4aPlanarInt6Tensor) module.cuda() x = torch.randn(1, 256, dtype=torch.bfloat16, device="cuda") w_ref = ( diff --git a/examples/models/gemma4_31b/tests/test_cuda_pipeline.py b/examples/models/gemma4_31b/tests/test_cuda_pipeline.py index 4cee363a123..caf0a44e03b 100644 --- a/examples/models/gemma4_31b/tests/test_cuda_pipeline.py +++ b/examples/models/gemma4_31b/tests/test_cuda_pipeline.py @@ -245,12 +245,14 @@ def _load(self, tmp): return load_gguf_model(path, backend="cuda", config=GGUF_CONFIG) def test_load_converts_weights(self): - """GGUF -> CUDA: Q4_K -> CudaCoalescedInt4Tensor, Q6_K -> CudaPackedInt6Tensor, + """GGUF -> CUDA: Q4_K -> CudaCoalescedInt4Tensor, Q6_K -> CudaDp4aPlanarInt6Tensor, embedding bf16.""" from executorch.backends.cuda.coalesced_int4_tensor import ( CudaCoalescedInt4Tensor, ) - from executorch.backends.cuda.packed_int6_tensor import 
CudaPackedInt6Tensor + from executorch.backends.cuda.dp4a_planar_int6_tensor import ( + CudaDp4aPlanarInt6Tensor, + ) with tempfile.TemporaryDirectory() as tmp: model, _ = self._load(tmp) @@ -259,11 +261,11 @@ def test_load_converts_weights(self): model.layers[0].self_attn.q_proj.weight.data, CudaCoalescedInt4Tensor ) self.assertIsInstance( - model.layers[0].mlp.down_proj.weight.data, CudaPackedInt6Tensor + model.layers[0].mlp.down_proj.weight.data, CudaDp4aPlanarInt6Tensor ) # Tied lm_head is repacked to int6 by pack_cuda (it keeps quantization, # unlike the token embedding which is dequantized for the gather). - self.assertIsInstance(model.lm_head.weight.data, CudaPackedInt6Tensor) + self.assertIsInstance(model.lm_head.weight.data, CudaDp4aPlanarInt6Tensor) # Token embedding is dequantized to bf16 (Int4/packed-int6 can't gather). self.assertEqual(model.embed_tokens.weight.dtype, torch.bfloat16) From 650ddeb2dfdeacd78c951669a43d32f939e5d96f Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Mon, 15 Jun 2026 16:19:32 -0700 Subject: [PATCH 12/15] [ci] gemma4_31b: rename CI selector to the unsloth GGUF source The gemma4_31b CUDA CI selector still keyed off the stale prequant HF repo SocialLocalMobile/gemma-4-31B-it-HQQ-INT4, but export_model_artifact.sh already downloads the weights from unsloth/gemma-4-31B-it-GGUF. Rename the CI identifier to match the actual source: SocialLocalMobile/gemma-4-31B-it-HQQ-INT4 -> unsloth/gemma-4-31B-it-GGUF Updated the case selectors + help text in export_model_artifact.sh and test_model_e2e.sh, and the matrix entries, exclude rules, and the A100 runner-selection conditionals in both the export and e2e jobs of cuda.yml. The executorch registry MODEL_NAME stays gemma4_31b; qwen3_5_moe's SocialLocalMobile HQQ entry is left unchanged. 
--- .ci/scripts/export_model_artifact.sh | 4 ++-- .ci/scripts/test_model_e2e.sh | 4 ++-- .github/workflows/cuda.yml | 28 ++++++++++++++-------------- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/.ci/scripts/export_model_artifact.sh b/.ci/scripts/export_model_artifact.sh index e9218dce625..1e7b9646cfd 100755 --- a/.ci/scripts/export_model_artifact.sh +++ b/.ci/scripts/export_model_artifact.sh @@ -195,7 +195,7 @@ case "$HF_MODEL" in PREPROCESSOR_FEATURE_SIZE="" PREPROCESSOR_OUTPUT="" ;; - SocialLocalMobile/gemma-4-31B-it-HQQ-INT4) + unsloth/gemma-4-31B-it-GGUF) MODEL_NAME="gemma4_31b" TASK="" MAX_SEQ_LEN="" @@ -205,7 +205,7 @@ case "$HF_MODEL" in ;; *) echo "Error: Unsupported model '$HF_MODEL'" - echo "Supported models: mistralai/Voxtral-Mini-3B-2507, mistralai/Voxtral-Mini-4B-Realtime-2602, openai/whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}, google/gemma-3-4b-it, Qwen/Qwen3-0.6B, nvidia/diar_streaming_sortformer_4spk-v2, nvidia/parakeet-tdt, facebook/dinov2-small-imagenet1k-1-layer, SocialLocalMobile/Qwen3.5-35B-A3B-HQQ-INT4, SocialLocalMobile/gemma-4-31B-it-HQQ-INT4" + echo "Supported models: mistralai/Voxtral-Mini-3B-2507, mistralai/Voxtral-Mini-4B-Realtime-2602, openai/whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}, google/gemma-3-4b-it, Qwen/Qwen3-0.6B, nvidia/diar_streaming_sortformer_4spk-v2, nvidia/parakeet-tdt, facebook/dinov2-small-imagenet1k-1-layer, SocialLocalMobile/Qwen3.5-35B-A3B-HQQ-INT4, unsloth/gemma-4-31B-it-GGUF" exit 1 ;; esac diff --git a/.ci/scripts/test_model_e2e.sh b/.ci/scripts/test_model_e2e.sh index e1ba976b0cc..503bd381a8d 100755 --- a/.ci/scripts/test_model_e2e.sh +++ b/.ci/scripts/test_model_e2e.sh @@ -228,7 +228,7 @@ case "$HF_MODEL" in AUDIO_FILE="" IMAGE_PATH="" ;; - SocialLocalMobile/gemma-4-31B-it-HQQ-INT4) + unsloth/gemma-4-31B-it-GGUF) MODEL_NAME="gemma4_31b" RUNNER_TARGET="gemma4_31b_runner" RUNNER_PATH="gemma4_31b" @@ -242,7 +242,7 @@ case "$HF_MODEL" in ;; *) echo 
"Error: Unsupported model '$HF_MODEL'" - echo "Supported models: mistralai/Voxtral-Mini-3B-2507, mistralai/Voxtral-Mini-4B-Realtime-2602, nvidia/diar_streaming_sortformer_4spk-v2, openai/whisper series (whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}), google/gemma-3-4b-it, Qwen/Qwen3-0.6B, nvidia/parakeet-tdt, facebook/dinov2-small-imagenet1k-1-layer, SocialLocalMobile/Qwen3.5-35B-A3B-HQQ-INT4, SocialLocalMobile/gemma-4-31B-it-HQQ-INT4" + echo "Supported models: mistralai/Voxtral-Mini-3B-2507, mistralai/Voxtral-Mini-4B-Realtime-2602, nvidia/diar_streaming_sortformer_4spk-v2, openai/whisper series (whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}), google/gemma-3-4b-it, Qwen/Qwen3-0.6B, nvidia/parakeet-tdt, facebook/dinov2-small-imagenet1k-1-layer, SocialLocalMobile/Qwen3.5-35B-A3B-HQQ-INT4, unsloth/gemma-4-31B-it-GGUF" exit 1 ;; esac diff --git a/.github/workflows/cuda.yml b/.github/workflows/cuda.yml index d0da13e5733..948e58d389e 100644 --- a/.github/workflows/cuda.yml +++ b/.github/workflows/cuda.yml @@ -258,8 +258,8 @@ jobs: name: "dinov2-small-imagenet1k-1-layer" - repo: "SocialLocalMobile" name: "Qwen3.5-35B-A3B-HQQ-INT4" - - repo: "SocialLocalMobile" - name: "gemma-4-31B-it-HQQ-INT4" + - repo: "unsloth" + name: "gemma-4-31B-it-GGUF" quant: - "non-quantized" - "quantized-int4-tile-packed" @@ -281,12 +281,12 @@ jobs: quant: "quantized-int4-weight-only" # Gemma 4 31B uses a prequantized checkpoint, only tile-packed - model: - repo: "SocialLocalMobile" - name: "gemma-4-31B-it-HQQ-INT4" + repo: "unsloth" + name: "gemma-4-31B-it-GGUF" quant: "non-quantized" - model: - repo: "SocialLocalMobile" - name: "gemma-4-31B-it-HQQ-INT4" + repo: "unsloth" + name: "gemma-4-31B-it-GGUF" quant: "quantized-int4-weight-only" # Voxtral Realtime only supports int4-tile-packed on CUDA - model: @@ -342,7 +342,7 @@ jobs: with: timeout: 150 secrets-env: EXECUTORCH_HF_TOKEN - runner: ${{ (matrix.model.name == 'Qwen3.5-35B-A3B-HQQ-INT4' || 
matrix.model.name == 'gemma-4-31B-it-HQQ-INT4') && 'mt-l-x86iavx512-11-125-a100' || 'mt-l-x86aavx2-29-113-a10g' }} + runner: ${{ (matrix.model.name == 'Qwen3.5-35B-A3B-HQQ-INT4' || matrix.model.name == 'gemma-4-31B-it-GGUF') && 'mt-l-x86iavx512-11-125-a100' || 'mt-l-x86aavx2-29-113-a10g' }} gpu-arch-type: cuda gpu-arch-version: "13.0" use-custom-docker-registry: false @@ -424,8 +424,8 @@ jobs: name: "dinov2-small-imagenet1k-1-layer" - repo: "SocialLocalMobile" name: "Qwen3.5-35B-A3B-HQQ-INT4" - - repo: "SocialLocalMobile" - name: "gemma-4-31B-it-HQQ-INT4" + - repo: "unsloth" + name: "gemma-4-31B-it-GGUF" quant: - "non-quantized" - "quantized-int4-tile-packed" @@ -447,12 +447,12 @@ jobs: quant: "quantized-int4-weight-only" # Gemma 4 31B uses a prequantized checkpoint, only tile-packed - model: - repo: "SocialLocalMobile" - name: "gemma-4-31B-it-HQQ-INT4" + repo: "unsloth" + name: "gemma-4-31B-it-GGUF" quant: "non-quantized" - model: - repo: "SocialLocalMobile" - name: "gemma-4-31B-it-HQQ-INT4" + repo: "unsloth" + name: "gemma-4-31B-it-GGUF" quant: "quantized-int4-weight-only" # Voxtral Realtime only supports int4-tile-packed on CUDA - model: @@ -502,7 +502,7 @@ jobs: quant: "non-quantized" with: timeout: 90 - runner: ${{ (matrix.model.name == 'Qwen3.5-35B-A3B-HQQ-INT4' || matrix.model.name == 'gemma-4-31B-it-HQQ-INT4') && 'mt-l-x86iavx512-11-125-a100' || 'mt-l-x86aavx2-29-113-a10g' }} + runner: ${{ (matrix.model.name == 'Qwen3.5-35B-A3B-HQQ-INT4' || matrix.model.name == 'gemma-4-31B-it-GGUF') && 'mt-l-x86iavx512-11-125-a100' || 'mt-l-x86aavx2-29-113-a10g' }} gpu-arch-type: cuda gpu-arch-version: "13.0" use-custom-docker-registry: false From f1bb112cf83e375a1b3d492133ae5437f0e0bef0 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Mon, 15 Jun 2026 17:16:29 -0700 Subject: [PATCH 13/15] [cuda] int6: make the low-level packer private (from_intx_int8 -> _from_intx_int8) It has no external production caller (from_exportable_gguf is the only entry; pack_cuda routes Q6_K 
there); keep it as the internal, unit-tested ql/qh packer. Signed-off-by: gasoonjia --- backends/cuda/dp4a_planar_int6_tensor.py | 16 ++++++++-------- backends/cuda/tests/test_int6_dispatch.py | 6 +++--- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/backends/cuda/dp4a_planar_int6_tensor.py b/backends/cuda/dp4a_planar_int6_tensor.py index 4daeef7efcd..8ecbb9283b1 100644 --- a/backends/cuda/dp4a_planar_int6_tensor.py +++ b/backends/cuda/dp4a_planar_int6_tensor.py @@ -15,9 +15,9 @@ Build one with :meth:`from_exportable_gguf` (from a native Q6_K ``ExportableGGUFTensor`` — it reuses the shared Q6_K block decode in -``extension/llm/export/gguf.py``) or with the low-level :meth:`from_intx_int8` -(from an already-decoded symmetric int8 tensor). Both feed the same ql/qh packer; -this class owns only the 6-bit pack, never the Q6_K block decode. +``extension/llm/export/gguf.py`` and then feeds the internal ql/qh packer +:meth:`_from_intx_int8`). This class owns only the 6-bit pack, never the Q6_K +block decode. The stored value is ``u = q + 32`` in ``[0, 63]`` (``q`` in ``[-32, 31]``); the constant ``-32`` offset is applied in the decode kernel. The 6 bits are split @@ -169,8 +169,8 @@ def _quantization_type(self): ) @classmethod - def from_intx_int8(cls, t: torch.Tensor) -> "CudaDp4aPlanarInt6Tensor": - """Build from a torchao ``IntxUnpackedToInt8Tensor`` decoded from Q6_K. + def _from_intx_int8(cls, t: torch.Tensor) -> "CudaDp4aPlanarInt6Tensor": + """Internal ql/qh packer: build from a symmetric int8 ``IntxUnpackedToInt8Tensor`` decoded from Q6_K. The source is symmetric (zero_point == 0), ``qdata`` is int8 in ``[-32, 31]`` and ``scale`` is ``(N, K/16)``. 
The ql/qh bit-pack is baked @@ -179,7 +179,7 @@ def from_intx_int8(cls, t: torch.Tensor) -> "CudaDp4aPlanarInt6Tensor": q = t.qdata if not bool(torch.all(t.zero_point == 0)): raise ValueError( - "CudaDp4aPlanarInt6Tensor.from_intx_int8 requires symmetric Q6_K " + "CudaDp4aPlanarInt6Tensor._from_intx_int8 requires symmetric Q6_K " "weights (zero_point == 0)" ) q_min, q_max = int(q.min()), int(q.max()) @@ -203,7 +203,7 @@ def from_exportable_gguf(cls, gt) -> "CudaDp4aPlanarInt6Tensor": Reuses the shared Q6_K block decode in ``extension/llm/export/gguf.py`` (``to_intx_unpacked_to_int8_tensor`` -> a symmetric int8 tensor in ``[-32, 31]``), then bit-packs into the ql/qh - planes via :meth:`from_intx_int8`. The Q6_K decode lives in one place; + planes via :meth:`_from_intx_int8`. The Q6_K decode lives in one place; this class only owns the 6-bit pack. """ if gt.ggml_type != "q6_k": @@ -211,7 +211,7 @@ def from_exportable_gguf(cls, gt) -> "CudaDp4aPlanarInt6Tensor": "CudaDp4aPlanarInt6Tensor.from_exportable_gguf requires a q6_k " f"ExportableGGUFTensor, got {gt.ggml_type!r}" ) - return cls.from_intx_int8(gt.to_intx_unpacked_to_int8_tensor()) + return cls._from_intx_int8(gt.to_intx_unpacked_to_int8_tensor()) def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: """Dequantize to a dense tensor (symmetric: ``w = q * scale``). 
diff --git a/backends/cuda/tests/test_int6_dispatch.py b/backends/cuda/tests/test_int6_dispatch.py index 1d105212242..d7de974f1d7 100644 --- a/backends/cuda/tests/test_int6_dispatch.py +++ b/backends/cuda/tests/test_int6_dispatch.py @@ -149,7 +149,7 @@ def test_3d_batched_input(self): self.assertLess(self._rel_err(out, ref), 0.02) def test_from_intx_int8_roundtrip(self): - """from_intx_int8 packs a symmetric int8 tensor and dispatch is correct.""" + """_from_intx_int8 packs a symmetric int8 tensor and dispatch is correct.""" from torchao.quantization import IntxUnpackedToInt8Tensor N, K, gs = 16, 64, 16 @@ -164,7 +164,7 @@ def test_from_intx_int8_roundtrip(self): dtype=torch.bfloat16, activation_quantization=None, ) - t = CudaDp4aPlanarInt6Tensor.from_intx_int8(intx) + t = CudaDp4aPlanarInt6Tensor._from_intx_int8(intx) x = torch.randn(1, K, dtype=torch.bfloat16) with _record_int6_plain_mm() as calls: out = F.linear(x, t) @@ -189,7 +189,7 @@ def test_from_intx_int8_rejects_asymmetric(self): activation_quantization=None, ) with self.assertRaises(ValueError): - CudaDp4aPlanarInt6Tensor.from_intx_int8(intx) + CudaDp4aPlanarInt6Tensor._from_intx_int8(intx) def test_from_exportable_gguf(self): """from_exportable_gguf reuses the gguf.py Q6_K decode then packs losslessly.""" From 8f39a33ed497389773c463a8fb0b3ece93a721fa Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Mon, 15 Jun 2026 20:10:50 -0700 Subject: [PATCH 14/15] lint --- backends/cuda/runtime/shims/int6_plain_mm.cuh | 32 +++++++++++-------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/backends/cuda/runtime/shims/int6_plain_mm.cuh b/backends/cuda/runtime/shims/int6_plain_mm.cuh index 69007561422..61554424d05 100644 --- a/backends/cuda/runtime/shims/int6_plain_mm.cuh +++ b/backends/cuda/runtime/shims/int6_plain_mm.cuh @@ -8,22 +8,24 @@ // W6A8 dp4a matvec for packed INT6 decode (M <= 4), used for GGUF Q6_K weights. 
// -// Reads a genuine 6-bit packed weight (CudaDp4aPlanarInt6Tensor format), split into -// two planes: +// Reads a genuine 6-bit packed weight (CudaDp4aPlanarInt6Tensor format), split +// into two planes: // ql : [N, K/2] uint8 — low-nibble plane, nibble-packed even/odd exactly // like the INT4 path (ql[:,j] = lo[:,2j] | (lo[:,2j+1] << 4)). // qh : [N, K/4] uint8 — high-2-bit plane, 4 values/byte, arranged per // 32-weight chunk as hi_even_packed[4] then hi_odd_packed[4] (each -// byte holds the four 2-bit highs of one dp4a word in even/odd order). +// byte holds the four 2-bit highs of one dp4a word in even/odd +// order). // scale : [N, K/gs] bf16 — per-group scales, row-major (coalesced; no zero). -// The stored 6-bit value is u = q + 32 in [0, 63] (q in [-32, 31]); the constant -// -32 offset is applied in the kernel, so Q6_K's symmetry means NO zero tensor. +// The stored 6-bit value is u = q + 32 in [0, 63] (q in [-32, 31]); the +// constant -32 offset is applied in the kernel, so Q6_K's symmetry means NO +// zero tensor. // -// Dynamically quantizes bf16 activations to INT8 (per-32-element blocks, even/odd -// order, identical to the INT4 path), reconstructs full 6-bit weight bytes per -// dp4a word (vfull = vi_lo | (spread2(hi_byte) << 4)), and uses dp4a for fused -// int6xint8 dot products with vectorized weight loads and warp-cooperative -// quantization. +// Dynamically quantizes bf16 activations to INT8 (per-32-element blocks, +// even/odd order, identical to the INT4 path), reconstructs full 6-bit weight +// bytes per dp4a word (vfull = vi_lo | (spread2(hi_byte) << 4)), and uses dp4a +// for fused int6xint8 dot products with vectorized weight loads and +// warp-cooperative quantization. 
// // Symbol names are suffixed _i6 / distinct from int4_plain_mm.cuh and // int8_plain_mm.cuh so all three translation units can be linked together @@ -80,9 +82,10 @@ __device__ __forceinline__ uint32_t spread2_i6(uint32_t b) { // blocks, EVEN/ODD order — identical to the INT4 path's Q8Block). // --------------------------------------------------------------------------- -// alignas(16) pads sizeof(Q8Block_i6) to 48 so each block (and its qs_even/qs_odd -// 16-byte halves) is 16-byte aligned, allowing two vectorized uint4 loads of a -// block's int8 activations instead of eight scalar int32 loads. +// alignas(16) pads sizeof(Q8Block_i6) to 48 so each block (and its +// qs_even/qs_odd 16-byte halves) is 16-byte aligned, allowing two vectorized +// uint4 loads of a block's int8 activations instead of eight scalar int32 +// loads. struct alignas(16) Q8Block_i6 { int8_t qs_even[Q8_BLOCK_SIZE_I6 / 2]; int8_t qs_odd[Q8_BLOCK_SIZE_I6 / 2]; @@ -175,7 +178,8 @@ __global__ void __launch_bounds__(MV6_THREADS) int6_w6a8_matvec_kernel( uint2 qh_chunk = __ldg(&qhrow8[i]); int32_t k_base = i * 32; uint32_t words[4] = {packed16.x, packed16.y, packed16.z, packed16.w}; - // qh_chunk.x bytes = hi_even_packed[0..3], qh_chunk.y = hi_odd_packed[0..3]. + // qh_chunk.x bytes = hi_even_packed[0..3], qh_chunk.y = + // hi_odd_packed[0..3]. 
uint32_t hi_even_word = qh_chunk.x; uint32_t hi_odd_word = qh_chunk.y; From f1c6087ec4faaf0c57c2bec0b6c2d802a9853df7 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 16 Jun 2026 17:49:50 -0700 Subject: [PATCH 15/15] lint --- .../quantize_op_dispatch/int6_dispatch.py | 2 +- backends/cuda/tests/test_int6_dispatch.py | 18 +++++ examples/models/gemma4_31b/quant/pack_cuda.py | 9 ++- .../gemma4_31b/quant/tests/test_pack_cuda.py | 74 +++++++++++++++++++ 4 files changed, 100 insertions(+), 3 deletions(-) diff --git a/backends/cuda/quantize_op_dispatch/int6_dispatch.py b/backends/cuda/quantize_op_dispatch/int6_dispatch.py index b98a2c3ab80..df373212d1c 100644 --- a/backends/cuda/quantize_op_dispatch/int6_dispatch.py +++ b/backends/cuda/quantize_op_dispatch/int6_dispatch.py @@ -94,7 +94,7 @@ def _dequant_matmul_int6(x, ql, qh, scale, group_size): def _(func, types, args, kwargs): input_tensor = args[0] weight_tensor = args[1] - bias = args[2] if len(args) > 2 else None + bias = args[2] if len(args) > 2 else kwargs.get("bias", None) orig_shape = input_tensor.shape x_2d = input_tensor.reshape(-1, orig_shape[-1]) diff --git a/backends/cuda/tests/test_int6_dispatch.py b/backends/cuda/tests/test_int6_dispatch.py index d7de974f1d7..7db34099b48 100644 --- a/backends/cuda/tests/test_int6_dispatch.py +++ b/backends/cuda/tests/test_int6_dispatch.py @@ -137,6 +137,24 @@ def test_with_bias(self): ref = F.linear(x, _ref_weight(q, scale, 16), bias) self.assertLess(self._rel_err(out, ref), 0.02) + def test_with_bias_kwarg(self): + """Bias passed as a keyword (F.linear(x, w, bias=b)) is applied, not dropped.""" + t, q, scale = _make_int6_tensor(16, 64) + bias = torch.randn(16, dtype=torch.bfloat16) + x = torch.randn(1, 64, dtype=torch.bfloat16) + with _record_int6_plain_mm(): + out = F.linear(x, t, bias=bias) + ref = F.linear(x, _ref_weight(q, scale, 16), bias) + self.assertLess(self._rel_err(out, ref), 0.02) + # Guard against a regression to dropping the keyword bias: the no-bias + # 
result must differ from the bias result by exactly the bias. + with _record_int6_plain_mm(): + out_no_bias = F.linear(x, t) + self.assertTrue( + torch.allclose(out, out_no_bias + bias, atol=1e-2), + "keyword bias was not applied", + ) + def test_3d_batched_input(self): """3D input is flattened and the output shape is restored.""" t, q, scale = _make_int6_tensor(16, 64) diff --git a/examples/models/gemma4_31b/quant/pack_cuda.py b/examples/models/gemma4_31b/quant/pack_cuda.py index a96d585d655..079553337a2 100644 --- a/examples/models/gemma4_31b/quant/pack_cuda.py +++ b/examples/models/gemma4_31b/quant/pack_cuda.py @@ -70,8 +70,13 @@ def pack_linear_for_cuda(module: nn.Module, weights: dict[str, torch.Tensor]) -> # decode (gguf.py) then bakes the bit-pack into the weight constant, once. w = CudaDp4aPlanarInt6Tensor.from_exportable_gguf(w) elif isinstance(w, IntxUnpackedToInt8Tensor): - # Genuine INT8 weight: left unchanged for the int8 path. Q6_K never reaches - # here (it arrives as an ExportableGGUFTensor), so this is unambiguous. + # Genuine INT8 weight: left unchanged for the int8 dp4a path. The + # mixed-precision HQQ-INT4 ("sensitive") checkpoint reaches this branch + # for its int8 tensors — edge-layer v_proj/down_proj are quantized to + # INT8 while the rest is INT4 (see GEMMA4_31B_SENSITIVE_RECIPE in + # quantize_and_save.py). Q6_K never reaches here (it arrives as an + # ExportableGGUFTensor, handled above), so int4 vs int6 vs int8 routing + # stays unambiguous. 
pass else: raise ValueError(f"Unsupported weight type: {type(w).__name__}") diff --git a/examples/models/gemma4_31b/quant/tests/test_pack_cuda.py b/examples/models/gemma4_31b/quant/tests/test_pack_cuda.py index 3bcc77808ba..bde61a0b34f 100644 --- a/examples/models/gemma4_31b/quant/tests/test_pack_cuda.py +++ b/examples/models/gemma4_31b/quant/tests/test_pack_cuda.py @@ -233,6 +233,80 @@ def test_matmul_correct(self): rel_error = (out.float() - ref.float()).abs().mean() / ref.float().abs().mean() self.assertLess(rel_error.item(), 0.02) + def test_e2e_q6k_export_lower_decode(self): + """Q6_K -> pack -> export + CUDA lower -> run, vs the Q6_K dequant reference. + + Builds a synthetic Q6_K ExportableGGUFTensor, packs it into a + CudaDp4aPlanarInt6Tensor, exports a decode-shaped (M=1) nn.Linear, and + asserts: + * the exported graph captured ``executorch_cuda.int6_plain_mm`` (the + decode custom op chosen for M<=4), + * lowering through the CUDA backend produces an ``executorch_call_delegate``, + * running the exported graph matches the Q6_K dequant reference. + + The lowered .pte is not executed here (that needs the built C-shim + runtime); the eager exported graph already exercises the int6 decode op + through its registered CUDA impl. + """ + _require_cuda(self) + from executorch.backends.cuda.cuda_backend import CudaBackend + from executorch.backends.cuda.cuda_partitioner import CudaPartitioner + from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower + from torch.export import export + + gt = self._make_q6k_gguf(64, nb=1) # (64, 256) + # Reference: the shared Q6_K int8 decode, dequantized (w = q * scale). 
+ intx = gt.to_intx_unpacked_to_int8_tensor() + w_ref = ( + intx.qdata.to(torch.bfloat16) + * intx.scale.to(torch.bfloat16).repeat_interleave(16, dim=-1) + ).cuda() + + module = nn.Linear(256, 64, bias=False) + pack_linear_for_cuda(module, {"weight": gt}) + self.assertIsInstance(module.weight.data, CudaDp4aPlanarInt6Tensor) + module = module.cuda().eval() + + x = torch.randn(1, 256, dtype=torch.bfloat16, device="cuda") # decode M=1 + with torch.no_grad(): + ep = export(module, (x,), strict=True) + + # The decode (M<=4) path must capture the int6 decode custom op. + targets = [str(n.target) for n in ep.graph.nodes if n.op == "call_function"] + self.assertTrue( + any("int6_plain_mm" in t for t in targets), + f"int6_plain_mm not found in exported graph: {targets}", + ) + + # Run the exported graph and compare against the Q6_K dequant reference. + with torch.no_grad(): + out = ep.module()(x) + ref = torch.nn.functional.linear(x, w_ref) + rel_error = (out.float() - ref.float()).abs().mean() / ref.float().abs().mean() + self.assertLess(rel_error.item(), 0.02) + + # Lower through the CUDA backend: the int6 weight + decode op must land in + # an executorch_call_delegate. + lowered = to_edge_transform_and_lower( + ep, + partitioner=[ + CudaPartitioner( + [CudaBackend.generate_method_name_compile_spec("forward")] + ) + ], + compile_config=EdgeCompileConfig( + _check_ir_validity=False, _skip_dim_order=True + ), + ) + lowered_ep = lowered.exported_program() + self.assertTrue( + any( + n.op == "call_function" and "executorch_call_delegate" in str(n.target) + for n in lowered_ep.graph.nodes + ), + "CUDA lowering produced no delegate call", + ) + class TestPackEmbedding(unittest.TestCase): """pack_embedding_for_cuda with INT8 per-axis weights."""