Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion AI/varlen_blockscaled_sf_layout.md
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ a tile-unit offset integer.

## Tests

`tests/test_gemm_sm100_blockscaled.py`:
`tests/test_gemm_blockscaled.py`:
- `test_blockscaled_mxfp8_varlen_m_nonaligned` — 4 seqlen patterns × 2 B-majors = 8 cases.
Patterns include `[128, 128, 128]`, `[100, 200, 150]`, `[30, 300, 64, 200]`,
`[1, 128, 127, 129]`.
Expand Down
128 changes: 90 additions & 38 deletions benchmarks/benchmark_gemm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import os
import time

import torch
Expand All @@ -7,8 +8,9 @@
from quack.gemm import gemm as quack_gemm

"""
GEMM benchmark using quack.gemm.gemm() (dense path) or the SM100 blockscaled
path (MXFP8 / MXFP4 / NVFP4) via --blockscaled.
GEMM benchmark using quack.gemm.gemm() (dense path) or the blockscaled
path (MXFP8 / MXFP4 / NVFP4). The blockscaled path is selected by passing
--sf_dtype and/or --sf_vec_size.

Usage (dense):
python benchmarks/benchmark_gemm.py --mnkl 512,7168,2048,256 \
Expand All @@ -17,18 +19,24 @@

Usage (blockscaled MXFP8, with cuBLAS comparison):
python benchmarks/benchmark_gemm.py --mnkl 4096,4096,4096,1 \
--blockscaled --ab_dtype Float8E4M3FN --sf_dtype Float8E8M0FNU \
--sf_vec_size 32 --init quant --compare_cublas
--ab_dtype Float8E4M3FN --sf_dtype Float8E8M0FNU --sf_vec_size 32

Usage (blockscaled MXFP4):
python benchmarks/benchmark_gemm.py --mnkl 4096,4096,4096,1 \
--blockscaled --ab_dtype Float4E2M1FN --sf_dtype Float8E8M0FNU \
--sf_vec_size 32 --d_dtype Float32
--ab_dtype Float4E2M1FN --sf_dtype Float8E8M0FNU \
--sf_vec_size 32 --d_dtype BFloat16

Usage (blockscaled NVFP4):
python benchmarks/benchmark_gemm.py --mnkl 4096,4096,4096,1 \
--blockscaled --ab_dtype Float4E2M1FN --sf_dtype Float8E4M3FN \
--sf_vec_size 16 --d_dtype Float32
--ab_dtype Float4E2M1FN --sf_dtype Float8E4M3FN \
--sf_vec_size 16 --d_dtype BFloat16

Usage (SM120 packed-LDSM NVFP4 performance path):
QUACK_SM120_BLOCKSCALED_PACKED_LDSM=1 CUTE_DSL_ARCH=sm_120a \
python benchmarks/benchmark_gemm.py --mnkl 4096,4096,4096,1 \
--tile_shape_mnk 64,64,128 --cluster_shape_mnk 1,1,1 \
--ab_dtype Float4E2M1FN --sf_dtype Float8E4M3FN \
--sf_vec_size 16 --d_dtype BFloat16 --skip_ref_check
"""


Expand Down Expand Up @@ -194,19 +202,21 @@ def _run_blockscaled(args):
compile_blockscaled_gemm_tvm_ffi,
create_blockscaled_operand_quantized,
create_blockscaled_operand_tensor,
create_sm120_blockscaled_scale_tensor,
create_blockscaled_varlen_m_operands,
scale_blocked_for_cublas,
torch_dtype_for_cutlass,
)
from quack.cute_dsl_utils import get_device_capacity
from quack.gemm_default_epi import GemmDefaultSm100
from quack.gemm_default_epi import GemmDefaultSm100, GemmDefaultSm120

sm_major = get_device_capacity(torch.device("cuda"))[0]
assert sm_major in (10, 11), (
f"Blockscaled GEMM requires SM100 (B200/B300) or SM110; got SM{sm_major}x. "
"MXFP8/MXFP4/NVFP4 use tcgen05 UMMA which is SM100+."
assert sm_major in (10, 11, 12), (
f"Blockscaled GEMM requires SM100/SM110 or SM120; got SM{sm_major}x."
)

if sm_major == 12 and (args.varlen_m or args.varlen_k):
raise NotImplementedError("SM120 blockscaled benchmark path does not support varlen")
if args.varlen_k or args.gather_A or args.pingpong:
raise NotImplementedError(
"blockscaled + varlen_k/gather/pingpong is not wired up yet. "
Expand Down Expand Up @@ -255,12 +265,28 @@ def _run_blockscaled(args):
raise ValueError(
f"MXFP4/NVFP4 require K-major for both A and B; got a_major={a_major}, b_major={b_major}"
)
if not GemmDefaultSm100.can_implement_blockscaled(
GemmBlockscaledCls = GemmDefaultSm120 if sm_major == 12 else GemmDefaultSm100
mma_tiler_for_validation = (
tuple(mma_tiler_mnk) if len(mma_tiler_mnk) == 3 or sm_major != 12 else (*mma_tiler_mnk, 64)
)
if (
sm_major == 12
and len(mma_tiler_for_validation) == 3
and mma_tiler_for_validation[2] == 128
and os.environ.get("QUACK_SM120_BLOCKSCALED_PACKED_LDSM") != "1"
):
raise NotImplementedError(
"SM120 blockscaled tile_K=128 requires the packed ldmatrix path. "
"Set QUACK_SM120_BLOCKSCALED_PACKED_LDSM=1, for example:\n"
" QUACK_SM120_BLOCKSCALED_PACKED_LDSM=1 CUTE_DSL_ARCH=sm_120a "
"python benchmarks/benchmark_gemm.py --tile_shape_mnk 64,64,128 ..."
)
if not GemmBlockscaledCls.can_implement_blockscaled(
ab_dtype,
sf_dtype,
sf_vec_size,
d_dtype,
mma_tiler_mnk,
mma_tiler_for_validation,
cluster_shape_mn,
m,
n,
Expand Down Expand Up @@ -320,28 +346,43 @@ def _run_blockscaled(args):
def fn():
runner(mA, mB, mD, mSFA, mSFB, cu_seqlens_m)
else:
a_ref, mA, a_sc_contig = create_blockscaled_operand_quantized(
l,
m,
k,
a_major == "m",
sf_vec_size,
ab_dtype,
sf_dtype,
)
b_ref, mB, b_sc_contig = create_blockscaled_operand_quantized(
l,
n,
k,
b_major == "n",
sf_vec_size,
ab_dtype,
sf_dtype,
)
# (l, rm, rk, 512) contig scale — consumed directly by the kernel.
mSFA, mSFB = a_sc_contig, b_sc_contig
sfa_ref = torch.ones_like(a_ref)
sfb_ref = torch.ones_like(b_ref)
if sm_major == 12:
if ab_dtype is not cutlass.Float4E2M1FN or d_dtype is not cutlass.BFloat16:
raise TypeError(
"SM120 blockscaled benchmark currently supports FP4 inputs and BF16 D"
)
_, mA = create_blockscaled_operand_tensor(l, m, k, False, ab_dtype, init="empty")
_, mB = create_blockscaled_operand_tensor(l, n, k, False, ab_dtype, init="empty")
mA.view(torch.uint8).fill_(0x22)
mB.view(torch.uint8).fill_(0x22)
a_ref = torch.ones((m, k, l), device="cuda", dtype=torch.float32)
b_ref = torch.ones((n, k, l), device="cuda", dtype=torch.float32)
sfa_ref, mSFA = create_sm120_blockscaled_scale_tensor(l, m, k, sf_vec_size, sf_dtype)
sfb_ref, mSFB = create_sm120_blockscaled_scale_tensor(l, n, k, sf_vec_size, sf_dtype)
a_sc_contig = b_sc_contig = None
else:
a_ref, mA, a_sc_contig = create_blockscaled_operand_quantized(
l,
m,
k,
a_major == "m",
sf_vec_size,
ab_dtype,
sf_dtype,
)
b_ref, mB, b_sc_contig = create_blockscaled_operand_quantized(
l,
n,
k,
b_major == "n",
sf_vec_size,
ab_dtype,
sf_dtype,
)
# (l, rm, rk, 512) contig scale — consumed directly by the kernel.
mSFA, mSFB = a_sc_contig, b_sc_contig
sfa_ref = torch.ones_like(a_ref)
sfb_ref = torch.ones_like(b_ref)
_, mD = create_blockscaled_operand_tensor(l, m, n, False, d_dtype, init="empty")
runner = compile_blockscaled_gemm_tvm_ffi(
ab_dtype,
Expand Down Expand Up @@ -372,10 +413,12 @@ def fn():
torch.testing.assert_close(mD.float(), ref, atol=tol, rtol=1e-3)
else:
ref = blockscaled_gemm_reference(a_ref, b_ref, sfa_ref, sfb_ref)
if d_dtype != cutlass.Float32:
ref = ref.to(torch_dtype_for_cutlass(d_dtype)).float()
torch.testing.assert_close(mD.float(), ref, atol=tol, rtol=1e-3)
print("Ref check PASSED")

print("Running SM100 Blockscaled GEMM with:")
print(f"Running SM{sm_major}0 Blockscaled GEMM with:")
print(f"mnkl: {args.mnkl}")
print(f"tile_shape_mnk: {mma_tiler_mnk}, cluster_shape_mnk: {cluster_shape_mnk}")
print(
Expand All @@ -395,6 +438,12 @@ def fn():
# batch would be an unfair comparison (hides batching potential), so skip.
print("(skipping cuBLAS: batched blockscaled mm not supported via a single call)")
return
if sm_major == 12:
print(
"(skipping cuBLAS comparison: SM120 benchmark currently builds QuACK's "
"padded row-major scale tensors, not the cuBLAS/PyTorch scaled_mm scale layout)"
)
return
if a_major != "k" or b_major != "k":
# F.scaled_mm requires A (M,K) row-major and B (K,N) col-major —
# i.e. both operands K-contiguous. Skip for m/n-major to avoid an
Expand Down Expand Up @@ -630,5 +679,8 @@ def fn():

if __name__ == "__main__":
args = parse_arguments()
run(args)
try:
run(args)
except (NotImplementedError, TypeError, ValueError) as exc:
raise SystemExit(f"benchmark_gemm.py: error: {exc}") from None
print("PASS")
116 changes: 116 additions & 0 deletions docs/sm120_blockscaled_perf.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# SM120 Blockscaled Performance Notes

This note explains the first SM120 blockscaled performance path in
`GemmSm120`. It is intentionally narrower than the experimental branch: the PR
keeps the proven per-atom TMA mechanism and adds only the packed shared-memory
consumer path plus a `64x64x128` tile.

## Supported Scope

The performance path is opt-in:

```bash
QUACK_SM120_BLOCKSCALED_PACKED_LDSM=1
```

Current intended benchmark shape:

```bash
QUACK_SM120_BLOCKSCALED_PACKED_LDSM=1 CUTE_DSL_ARCH=sm_120a \
python benchmarks/benchmark_gemm.py \
--mnkl 4096,4096,4096,1 \
--tile_shape_mnk 64,64,128 \
--cluster_shape_mnk 1,1,1 \
--ab_dtype Float4E2M1FN \
--sf_dtype Float8E4M3FN \
--sf_vec_size 16 \
--d_dtype BFloat16 \
--warmup_iterations 5 \
--iterations 30 \
--skip_ref_check
```

The benchmark is a launch and timing harness. Numerical coverage lives in
`tests/test_gemm_blockscaled.py`, including asymmetric FP4 values, poisoned
scale padding, K-page crossing, and PTX checks.

Targeted correctness gate for this path:

```bash
QUACK_SM120_BLOCKSCALED_PACKED_LDSM=1 CUTE_DSL_ARCH=sm_120a \
pytest -q tests/test_gemm_blockscaled.py -k "sm120 and packed_ldsm" -n 16 -s -rs
```

Supported formats and constraints for this SM120 path are:

- NVFP4: `Float4E2M1FN` A/B, `Float8E4M3FN` scales, `sf_vec_size=16`
- MXFP4: `Float4E2M1FN` A/B, `Float8E8M0FNU` scales, `sf_vec_size=32`
- BF16 output, `C is None`, `beta=0`
- cluster shape `(1, 1, 1)`

The packed performance path supports `64x64x64` and `64x64x128` CTA tiles. For
`tile_K=128`, logical K must be divisible by 128.

## Why Packed LDSM

The correctness-first SM120 path expanded compact FP4 bytes into the padded
`.b4x16_p64` ldmatrix shared-memory format. That path was useful for proving
the tuple MMA, scale mapping, and padded scale TMA, but profiling showed a large
shared-memory load bottleneck around the generated `b4x16_p64` ldmatrix
instruction.

The packed path instead stages FP4 into a swizzled packed shared-memory layout
and consumes it with:

```text
ldmatrix.sync.aligned.m8n8.x4.shared.b16
mma.sync.aligned.m16n8k64.kind::mxf4nvf4
```

This direction is based on the local CUTLASS GeForce reference in
`examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu`,
while this PR keeps QuACK's narrower per-atom TMA producer path.

Tests assert that the packed path does not regress back to `b4x16_p64`,
`m8n16`, multicast TMA, or `shared::cluster`.

## Why 64x64x128 First

The packed `64x64x64` path removed most of the original shared-load wavefront
excess, but the kernel was still producer/barrier heavy. Moving to
`64x64x128` keeps the same accumulator ownership and correctness surface while
doubling the K work per producer/barrier cycle. Local Nsight Compute runs on a
noisy workstation showed the expected direction:

- shared-load excessive wavefronts stayed near the packed-path level
- tensor pipe active increased materially
- barrier and MIO stall samples per issued instruction dropped
- runtime improved over `64x64x64`

Treat these numbers as directional, not as a stable benchmark claim. The
following benchmark runs were taken on an RTX 5060 workstation with reference
checking skipped (`--skip_ref_check`), because the pytest suite owns numerical
validation:

```text
base sm120-blockscaled, correctness-first path:
4096x4096x4096 NVFP4 -> BF16, 128x128x64: 60.571 ms, 2.3 TFLOP/s

this PR, packed-LDSM path:
4096x4096x4096 NVFP4 -> BF16, 64x64x64: 2.988 ms, 46.0 TFLOP/s
4096x4096x4096 NVFP4 -> BF16, 64x64x128: 1.329 ms, 103.4 TFLOP/s
```

## Why Not Full-Tile TMA In This PR

The natural next architecture is full-tile or grouped A/B TMA into the final
packed/swizzled shared-memory layout. Local experiments were not clean enough
for this PR:

- raw subbyte swizzled full-tile TMA failed CuTe DSL legalization
- byte-addressable recast layouts compiled but produced many tiny static TMA
sites and could hang at runtime
- nested grouped raw FP4 layouts hit compile/codegen timeouts

These findings are better handled as a separate, minimal layout-lowering repro
and upstream issue. This PR keeps production on the proven per-atom TMA path
and uses `tile_K=128` as the low-risk amortization step.
Loading
Loading