Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 120 additions & 41 deletions benchmarks/benchmark_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
from quack.gemm import gemm as quack_gemm

"""
GEMM benchmark using quack.gemm.gemm() (dense path) or the SM100 blockscaled
path (MXFP8 / MXFP4 / NVFP4) via --blockscaled.
GEMM benchmark using quack.gemm.gemm() (dense path) or the blockscaled
path (MXFP8 / MXFP4 / NVFP4). The blockscaled path is selected by passing
--sf_dtype and/or --sf_vec_size.

Usage (dense):
python benchmarks/benchmark_gemm.py --mnkl 512,7168,2048,256 \
Expand All @@ -17,18 +18,22 @@

Usage (blockscaled MXFP8, with cuBLAS comparison):
python benchmarks/benchmark_gemm.py --mnkl 4096,4096,4096,1 \
--blockscaled --ab_dtype Float8E4M3FN --sf_dtype Float8E8M0FNU \
--sf_vec_size 32 --init quant --compare_cublas
--ab_dtype Float8E4M3FN --sf_dtype Float8E8M0FNU --sf_vec_size 32

Usage (blockscaled MXFP4):
python benchmarks/benchmark_gemm.py --mnkl 4096,4096,4096,1 \
--blockscaled --ab_dtype Float4E2M1FN --sf_dtype Float8E8M0FNU \
--sf_vec_size 32 --d_dtype Float32
--ab_dtype Float4E2M1FN --sf_dtype Float8E8M0FNU \
--sf_vec_size 32 --d_dtype BFloat16

Usage (blockscaled NVFP4):
python benchmarks/benchmark_gemm.py --mnkl 4096,4096,4096,1 \
--blockscaled --ab_dtype Float4E2M1FN --sf_dtype Float8E4M3FN \
--sf_vec_size 16 --d_dtype Float32
--ab_dtype Float4E2M1FN --sf_dtype Float8E4M3FN \
--sf_vec_size 16 --d_dtype BFloat16

Usage (blockscaled NVFP4 fast SM120 path):
python benchmarks/benchmark_gemm.py --mnkl 4096,4096,4096,1 \
--ab_dtype Float4E2M1FN --sf_dtype Float8E4M3FN \
--sf_vec_size 16 --d_dtype BFloat16 --sm120_nvfp4_path fast
"""


Expand Down Expand Up @@ -124,6 +129,22 @@ def parse_arguments() -> argparse.Namespace:
parser.add_argument("--use_tma_gather", action="store_true", help="Use TMA gather4 for A")
parser.add_argument("--max_swizzle_size", type=int, default=8, help="Max swizzle size")
parser.add_argument("--skip_ref_check", action="store_true", help="Skip reference checking")
parser.add_argument(
"--sm120_nvfp4_init",
choices=("tensorfill", "ones"),
default="tensorfill",
help="SM120 NVFP4 input initialization. tensorfill uses bounded random non-zero "
"FP4/scales close to CUTLASS 79a TensorFillRandomUniform; ones preserves the "
"old all-ones microbenchmark.",
)
parser.add_argument(
"--sm120_nvfp4_path",
choices=("validated", "fast"),
default="validated",
help="SM120 NVFP4 kernel policy. validated uses the conservative direct-store "
"static-scheduler path; fast uses the CLC/full-grid scheduler and delayed TMA "
"epilogue path.",
)
# Dtype flags. The blockscaled path is selected automatically when --sf_dtype and/or --sf_vec_size is passed.
parser.add_argument(
"--ab_dtype",
Expand Down Expand Up @@ -202,12 +223,20 @@ def _run_blockscaled(args):
)
from quack.cute_dsl_utils import get_device_capacity
from quack.gemm_default_epi import GemmDefaultSm100
from quack.gemm_sm120 import GemmSm120
from quack.sm120_blockscaled_utils import (
create_sm120_nvfp4_ab_tensor,
create_sm120_nvfp4_scale_tensor,
create_sm120_nvfp4_tensorfill_like_ab_tensor,
create_sm120_nvfp4_tensorfill_like_scale_tensor,
)

sm_major = get_device_capacity(torch.device("cuda"))[0]
assert sm_major in (10, 11), (
f"Blockscaled GEMM requires SM100 (B200/B300) or SM110; got SM{sm_major}x. "
"MXFP8/MXFP4/NVFP4 use tcgen05 UMMA which is SM100+."
)
if sm_major not in (10, 11, 12):
raise RuntimeError(
f"Blockscaled GEMM requires SM100/SM110 or SM120; got SM{sm_major}x. "
"SM120 currently supports only the narrow NVFP4 path."
)

if args.varlen_k or args.gather_A or args.pingpong:
raise NotImplementedError(
Expand All @@ -217,6 +246,7 @@ def _run_blockscaled(args):

m, n, k, l = args.mnkl
mma_tiler_mnk = args.tile_shape_mnk
mma_tiler_mn = mma_tiler_mnk[:2]
cluster_shape_mnk = args.cluster_shape_mnk
cluster_shape_mn = cluster_shape_mnk[:2]
if cluster_shape_mnk[2] != 1:
Expand Down Expand Up @@ -257,7 +287,23 @@ def _run_blockscaled(args):
raise ValueError(
f"MXFP4/NVFP4 require K-major for both A and B; got a_major={a_major}, b_major={b_major}"
)
if not GemmDefaultSm100.can_implement_blockscaled(
is_sm120_nvfp4 = (
sm_major == 12
and ab_dtype == cutlass.Float4E2M1FN
and sf_dtype == cutlass.Float8E4M3FN
and sf_vec_size == 16
)
can_implement = (
GemmSm120.can_implement_blockscaled
if is_sm120_nvfp4
else GemmDefaultSm100.can_implement_blockscaled
)
if sm_major == 12 and not is_sm120_nvfp4:
raise TypeError(
"SM120 blockscaled benchmark currently supports only NVFP4 "
"(Float4E2M1FN A/B, Float8E4M3FN scales, sf_vec_size=16)"
)
if not can_implement(
ab_dtype,
sf_dtype,
sf_vec_size,
Expand Down Expand Up @@ -309,54 +355,74 @@ def _run_blockscaled(args):
sf_dtype,
sf_vec_size,
d_dtype,
mma_tiler_mnk,
mma_tiler_mn,
cluster_shape_mn,
mA,
mB,
mD,
mSFA,
mSFB,
varlen_m=True,
sm120_nvfp4_path=args.sm120_nvfp4_path,
)

def fn():
runner(mA, mB, mD, mSFA, mSFB, cu_seqlens_m)
else:
a_ref, mA, a_sc_contig = create_blockscaled_operand_quantized(
l,
m,
k,
a_major == "m",
sf_vec_size,
ab_dtype,
sf_dtype,
)
b_ref, mB, b_sc_contig = create_blockscaled_operand_quantized(
l,
n,
k,
b_major == "n",
sf_vec_size,
ab_dtype,
sf_dtype,
)
# (l, rm, rk, 512) contig scale — consumed directly by the kernel.
mSFA, mSFB = a_sc_contig, b_sc_contig
sfa_ref = torch.ones_like(a_ref)
sfb_ref = torch.ones_like(b_ref)
if is_sm120_nvfp4:
if args.sm120_nvfp4_init == "ones":
mA = create_sm120_nvfp4_ab_tensor(l, m, k, fill_byte=0x22)
mB = create_sm120_nvfp4_ab_tensor(l, n, k, fill_byte=0x22)
a_ref = torch.ones((m, k, l), device="cuda", dtype=torch.float32)
b_ref = torch.ones((n, k, l), device="cuda", dtype=torch.float32)
_, mSFA = create_sm120_nvfp4_scale_tensor(l, m, k)
_, mSFB = create_sm120_nvfp4_scale_tensor(l, n, k)
mSFA.fill_(1.0)
mSFB.fill_(1.0)
sfa_ref = torch.ones_like(a_ref)
sfb_ref = torch.ones_like(b_ref)
else:
a_ref, mA = create_sm120_nvfp4_tensorfill_like_ab_tensor(l, m, k)
b_ref, mB = create_sm120_nvfp4_tensorfill_like_ab_tensor(l, n, k)
sfa_ref, mSFA = create_sm120_nvfp4_tensorfill_like_scale_tensor(l, m, k)
sfb_ref, mSFB = create_sm120_nvfp4_tensorfill_like_scale_tensor(l, n, k)
else:
a_ref, mA, a_sc_contig = create_blockscaled_operand_quantized(
l,
m,
k,
a_major == "m",
sf_vec_size,
ab_dtype,
sf_dtype,
)
b_ref, mB, b_sc_contig = create_blockscaled_operand_quantized(
l,
n,
k,
b_major == "n",
sf_vec_size,
ab_dtype,
sf_dtype,
)
# (l, rm, rk, 512) contig scale — consumed directly by the SM100 kernel.
mSFA, mSFB = a_sc_contig, b_sc_contig
sfa_ref = torch.ones_like(a_ref)
sfb_ref = torch.ones_like(b_ref)
_, mD = create_blockscaled_operand_tensor(l, m, n, False, d_dtype, init="empty")
runner = compile_blockscaled_gemm_tvm_ffi(
ab_dtype,
sf_dtype,
sf_vec_size,
d_dtype,
mma_tiler_mnk,
mma_tiler_mn,
cluster_shape_mn,
mA,
mB,
mD,
mSFA,
mSFB,
sm120_nvfp4_path=args.sm120_nvfp4_path,
)

def fn():
Expand All @@ -365,29 +431,42 @@ def fn():
if not args.skip_ref_check:
fn()
torch.cuda.synchronize()
tol = 5e-3 if d_dtype != cutlass.Float32 else 5e-4
tol = (
0.25
if is_sm120_nvfp4 and args.sm120_nvfp4_init == "tensorfill"
else 5e-3
if d_dtype != cutlass.Float32
else 5e-4
)
rtol = 2e-2 if is_sm120_nvfp4 and args.sm120_nvfp4_init == "tensorfill" else 1e-3
if args.varlen_m:
# Per-expert matmul reference using dequantized operands
ref = torch.cat(
[a_ref_dq[cu_seqlens_m[i] : cu_seqlens_m[i + 1]] @ b_ref_dq[i].T for i in range(l)]
)
torch.testing.assert_close(mD.float(), ref, atol=tol, rtol=1e-3)
torch.testing.assert_close(mD.float(), ref, atol=tol, rtol=rtol)
else:
ref = blockscaled_gemm_reference(a_ref, b_ref, sfa_ref, sfb_ref)
torch.testing.assert_close(mD.float(), ref, atol=tol, rtol=1e-3)
torch.testing.assert_close(mD.float(), ref, atol=tol, rtol=rtol)
print("Ref check PASSED")

print("Running SM100 Blockscaled GEMM with:")
print(f"Running SM{sm_major} Blockscaled GEMM with:")
print(f"mnkl: {args.mnkl}")
print(f"tile_shape_mnk: {mma_tiler_mnk}, cluster_shape_mnk: {cluster_shape_mnk}")
print(
f"ab_dtype: {ab_dtype}, sf_dtype: {sf_dtype}, sf_vec_size: {sf_vec_size}, d_dtype: {args.d_dtype}"
)
print(f"a_major: {a_major}, b_major: {b_major}")
if is_sm120_nvfp4:
print(f"sm120_nvfp4_init: {args.sm120_nvfp4_init}")
print(f"sm120_nvfp4_path: {args.sm120_nvfp4_path}")

flops = 2 * m * n * k * l
timing = _bench_and_report("quack ", fn, flops, args.warmup_iterations, args.iterations)

if is_sm120_nvfp4:
print("(skipping cuBLAS: benchmark uses SM120 native NVFP4 scale storage)")
return
if args.varlen_m:
print("(skipping cuBLAS: varlen_m not supported)")
return
Expand Down
Loading
Loading