Skip to content

Commit db2aaa9

Browse files
kwyss-nvidiatimmoon10pre-commit-ci[bot]
authored
Subchannel Block quantized GEMM (#1545)
* Add GEMM logic for blockwise quantized tensors. GEMM test cases included in pytorch integration. Signed-off-by: Keith Wyss <kwyss@nvidia.com> * Update NVTE_BLOCK_SCALING for GEMM. Signed-off-by: Keith Wyss <kwyss@nvidia.com> * Gate feature on CUDA 12.9 Signed-off-by: Keith Wyss <kwyss@nvidia.com> * Gemm typo. Signed-off-by: Keith Wyss <kwyss@nvidia.com> * Remove unecessary type converter change. Signed-off-by: Keith Wyss <kwyss@nvidia.com> * Reflect epilogue availability and test supported epilogues. Signed-off-by: Keith Wyss <kwyss@nvidia.com> * GEMM simplifications from recipe branch. Signed-off-by: Keith Wyss <kwyss@nvidia.com> * Format py code. Signed-off-by: Keith Wyss <kwyss@nvidia.com> * Update GEMM DGelu tests to match support depending on output dtype. Signed-off-by: Keith Wyss <kwyss@nvidia.com> * Force pow2Scales in GEMM Signed-off-by: Keith Wyss <kwyss@nvidia.com> * Add GEMM test to pytorch test suite. Signed-off-by: Keith Wyss <kwyss@nvidia.com> * Add copyright to GEMM test. Signed-off-by: Keith Wyss <kwyss@nvidia.com> * Update import for GEMM test. Signed-off-by: Keith Wyss <kwyss@nvidia.com> * Add license. Signed-off-by: Keith Wyss <kwyss@nvidia.com> * Update test gemm supported predicate. Signed-off-by: Keith Wyss <kwyss@nvidia.com> * Use sgemm like interfaces and naming. Signed-off-by: Keith Wyss <kwyss@nvidia.com> * Rewrite GEMM comment. Signed-off-by: Keith Wyss <kwyss@nvidia.com> * MR Feedback. Signed-off-by: Keith Wyss <kwyss@nvidia.com> * Refactor GEMM param canonicalization Configure A and B matrices separately. Have separate code path for each scaling mode. Signed-off-by: Tim Moon <tmoon@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Prune number of tests. Signed-off-by: Keith Wyss <kwyss@nvidia.com> --------- Signed-off-by: Keith Wyss <kwyss@nvidia.com> Signed-off-by: Tim Moon <tmoon@nvidia.com> Co-authored-by: Tim Moon <tmoon@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
1 parent b362a6e commit db2aaa9

7 files changed

Lines changed: 1445 additions & 146 deletions

File tree

qa/L0_pytorch_unittest/test.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "
3232
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py"
3333
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
3434
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
35+
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py"
3536
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py"
3637
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py"
3738
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py"
Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# See LICENSE for license information.
4+
5+
from typing import Tuple
6+
7+
import torch
8+
import triton
9+
import triton.language as tl
10+
11+
12+
@triton.jit
13+
def fused_fma_kernel(y_ptr, x_ptr, s_ptr, M, N, y_str0, y_str1, BLOCK: tl.constexpr = 128):
14+
pid = tl.program_id(0)
15+
idx = pid * BLOCK + tl.arange(0, BLOCK)
16+
mask = idx < M * N
17+
18+
row = idx // N
19+
col = idx % N
20+
21+
y_offset = row * y_str0 + col * y_str1
22+
x_offset = row * N + col
23+
s_offset = row * N + col
24+
25+
y = tl.load(y_ptr + y_offset, mask=mask)
26+
x = tl.load(x_ptr + x_offset, mask=mask)
27+
s = tl.load(s_ptr + s_offset, mask=mask)
28+
29+
tl.store(y_ptr + y_offset, tl.fma(x, s, y), mask=mask)
30+
31+
32+
def fused_fma(y, x, s, BLOCK=128):
33+
"""
34+
Fused multiply-add operation (y = y + x * s).
35+
36+
PyTorch does not provide a direct FMA equivalent (torch.addcmul is not bitwise equivalent to this operation).
37+
This function also supports cases where 'y' is non-contiguous in memory.
38+
"""
39+
40+
assert (
41+
y.shape == x.shape == s.shape and y.dim() == 2
42+
), "All tensors must be 2D with the same shape"
43+
assert x.is_contiguous() and s.is_contiguous(), "x and s must be contiguous"
44+
45+
M, N = y.shape
46+
grid = ((M * N + BLOCK - 1) // BLOCK,)
47+
48+
fused_fma_kernel[grid](y, x, s, M, N, *y.stride(), BLOCK)
49+
50+
return y
51+
52+
53+
class CuBLASRefBlockwiseGemm:
54+
"""
55+
A cuBLAS compatible reference implementation of subchannel GEMM.
56+
"""
57+
58+
def qgemm(
59+
self,
60+
qx: torch.Tensor,
61+
qw: torch.Tensor,
62+
out_dtype: torch.dtype,
63+
demunged_sx: torch.Tensor,
64+
demunged_sw: torch.Tensor,
65+
quant_tile_shape_x: Tuple[int, int],
66+
quant_tile_shape_w: Tuple[int, int],
67+
bias: torch.Tensor | None = None,
68+
out: torch.Tensor | None = None,
69+
accumulate: bool = False,
70+
use_split_accumulator: bool = False,
71+
) -> torch.Tensor:
72+
# demunge scale shapes for cuBLAS
73+
is_a_1d_scaled = quant_tile_shape_x[0] == 1
74+
is_b_1d_scaled = quant_tile_shape_w[0] == 1
75+
M, K = qx.shape
76+
N, K = qw.shape
77+
78+
# mm_tile_shape = (tile_m, tile_n, tile_k)
79+
mm_tile_shape = (
80+
quant_tile_shape_x[0],
81+
quant_tile_shape_w[0],
82+
quant_tile_shape_w[1],
83+
)
84+
if bias is not None and bias.numel():
85+
# To match cuBLAS more closely when bias is applied,
86+
# the reference accumulates into float32, and cast to
87+
# bfloat16 is deferred until after the GEMM.
88+
out_dtype_for_ref = torch.float32
89+
else:
90+
out_dtype_for_ref = out_dtype
91+
y = self.qgemm_blockwise_2d(
92+
qx,
93+
qw,
94+
out_dtype_for_ref,
95+
demunged_sx,
96+
demunged_sw,
97+
mm_tile_shape,
98+
use_split_accumulator,
99+
is_a_1d_scaled,
100+
is_b_1d_scaled,
101+
)
102+
if bias is not None and bias.numel():
103+
y += bias
104+
y = y.to(dtype=out_dtype)
105+
# cublas accumulation first convert to output dtype, then accumulate.
106+
if accumulate:
107+
assert out is not None
108+
y = y + out
109+
else:
110+
assert out is None, "Output tensor should be None when accumulate is False."
111+
112+
return y
113+
114+
@classmethod
115+
def qgemm_blockwise_2d(
116+
cls,
117+
qx: torch.Tensor,
118+
qw: torch.Tensor,
119+
out_dtype: torch.dtype,
120+
sx: torch.Tensor,
121+
sw: torch.Tensor,
122+
mm_tile_shape: Tuple[int, int, int],
123+
use_split_accumulator: bool,
124+
is_a_1d_scaled: bool,
125+
is_b_1d_scaled: bool,
126+
) -> torch.Tensor:
127+
"""
128+
Difference between cuBLAS and CUTLASS GEMM implementations:
129+
- cuBLAS accumulation equation: use different equation for each scaling mode.
130+
- For accumulation C in epiloge, it first convert C to output dtype, then accumulate.
131+
"""
132+
133+
M, K = qx.shape
134+
N, K_w = qw.shape
135+
assert K == K_w, "K dimension mismatch between qx and qw"
136+
137+
tile_len = 128
138+
# Calculate grid sizes without padding
139+
grid_m = (M + tile_len - 1) // tile_len
140+
grid_n = (N + tile_len - 1) // tile_len
141+
grid_k = (K + tile_len - 1) // tile_len
142+
143+
block_m, block_n, block_k = mm_tile_shape
144+
scale_m_per_tile = tile_len // block_m
145+
scale_n_per_tile = tile_len // block_n
146+
assert block_k == tile_len, "block_k must be equal to tile_len"
147+
148+
# Notes on making the reference implementation numerically equivalent to Cast Blockwise FP8 GEMM:
149+
# 1) When using split_accumulate in FP8 GEMM, every 4 QMMA partial accumulation results are accumulated into float32 registers.
150+
# 2) Partial accumulation results are accumulated using FMA (Fused Multiply-Add) instructions to apply scaling factors, as in: y += partial_y * scale
151+
y = torch.zeros(M, N, dtype=torch.float32, device=qx.device)
152+
153+
# Validate shapes of sx and sw
154+
scale_m_per_tensor = (M + block_m - 1) // block_m
155+
scale_n_per_tensor = (N + block_n - 1) // block_n
156+
assert sx.shape == (
157+
scale_m_per_tensor,
158+
grid_k,
159+
), f"sx shape mismatch: expected ({scale_m_per_tensor}, {grid_k}), got {sx.shape}"
160+
assert sw.shape == (
161+
scale_n_per_tensor,
162+
grid_k,
163+
), f"sw shape mismatch: expected ({scale_n_per_tensor}, {grid_k}), got {sw.shape}"
164+
165+
for i in range(grid_m):
166+
m_start = i * tile_len
167+
m_end = min(m_start + tile_len, M)
168+
m_size = m_end - m_start
169+
170+
for j in range(grid_n):
171+
n_start = j * tile_len
172+
n_end = min(n_start + tile_len, N)
173+
n_size = n_end - n_start
174+
175+
y_block = y[m_start:m_end, n_start:n_end]
176+
177+
for k in range(grid_k):
178+
k_start = k * tile_len
179+
k_end = min(k_start + tile_len, K)
180+
k_size = k_end - k_start
181+
182+
qx_block = (
183+
qx[m_start:m_end, k_start:k_end].clone().contiguous()
184+
) # Shape: [m_size, k_size]
185+
qw_block = (
186+
qw[n_start:n_end, k_start:k_end].clone().contiguous()
187+
) # Shape: [n_size, k_size]
188+
189+
# Extract scaling factors for the current blocks
190+
sx_block = sx[i * scale_m_per_tile : (i + 1) * scale_m_per_tile, k].unsqueeze(
191+
-1
192+
)
193+
sw_block = sw[j * scale_n_per_tile : (j + 1) * scale_n_per_tile, k].unsqueeze(0)
194+
195+
# Perform qgemm with scaling factors fused in the GEMM
196+
# Accumulate should be in float32 format, which aligns with the split_accumulate in FP8 GEMM
197+
one = torch.tensor(1.0, dtype=torch.float32, device=qx.device)
198+
y_partial = torch._scaled_mm(
199+
qx_block,
200+
qw_block.t(),
201+
scale_a=one,
202+
scale_b=one,
203+
out_dtype=torch.float32,
204+
use_fast_accum=not use_split_accumulator,
205+
)
206+
207+
# Accumulate the partial result
208+
if is_a_1d_scaled and is_b_1d_scaled:
209+
# 1Dx1D
210+
# CuBLAS accumulation equation: y += (y * scale_a) * scale_b
211+
y_partial = y_partial * sx_block
212+
# Fuse multiplication and addition to align with the split_accumulate in FP8 GEMM
213+
# y_block.add_(y_partial, alpha=scale.item())
214+
fused_fma(
215+
y_block,
216+
y_partial,
217+
sw_block.expand_as(y_partial).contiguous(),
218+
)
219+
elif not is_a_1d_scaled and is_b_1d_scaled:
220+
# 2Dx1D
221+
# CuBLAS accumulation equation: y += (y * scale_b) * scale_a
222+
y_partial = y_partial * sw_block
223+
fused_fma(
224+
y_block,
225+
y_partial,
226+
sx_block.expand_as(y_partial).contiguous(),
227+
)
228+
elif is_a_1d_scaled and not is_b_1d_scaled:
229+
# 1Dx2D
230+
# CuBLAS accumulation equation: y += (y * scale_a) * scale_b
231+
y_partial = y_partial * sx_block
232+
fused_fma(
233+
y_block,
234+
y_partial,
235+
sw_block.expand_as(y_partial).contiguous(),
236+
)
237+
else:
238+
scale = sx_block * sw_block
239+
fused_fma(y_block, y_partial, scale.expand_as(y_partial).contiguous())
240+
241+
y = y.to(out_dtype)
242+
return y

tests/pytorch/references/blockwise_quantizer_reference.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def _pad_inner_to_align(s: torch.Tensor, transpose: bool) -> torch.Tensor:
4949
s_t = _pad_inner_to_align(unmunged.scale_t, transpose=tile_shape[0] == 1)
5050
return QuantizeResult(unmunged.data, s, unmunged.data_t, s_t)
5151

52+
@classmethod
5253
def demunge_scale_shape_from_backend(
5354
cls,
5455
qtensor_shape: Tuple[int, int],

0 commit comments

Comments
 (0)