diff --git a/benchmarks/benchmark_gemm_autotuned.py b/benchmarks/benchmark_gemm_autotuned.py index 36d2480f..5f8b8f94 100644 --- a/benchmarks/benchmark_gemm_autotuned.py +++ b/benchmarks/benchmark_gemm_autotuned.py @@ -34,6 +34,11 @@ from quack.autotuner import default_cache_dir from quack.cache_utils import get_cache_path +from quack.gemm_blockscaled_interface import ( + mxfp8_gemm_act, + quantize_act_sm90, + quantize_weight_sm90, +) from quack.gemm_config import GemmConfig from quack.gemm_interface import ( act_to_pytorch_fn_map, @@ -245,6 +250,90 @@ def benchmark_gemm_dgated( return ms, tf +def benchmark_mxfp8_gemm_act( + m, + n, + k, + activation="swiglu", + dtype=torch.bfloat16, + repeats=30, + trace_path=None, +): + """Benchmark fused MXFP8 GEMM + gated activation (SM90 blockscaled path). + + Quantizes A (bf16 -> fp8_e4m3fn + 1x128 scales) and W (bf16 -> fp8_e4m3fn + + 128x128 scales) once outside the timed loop, then measures the fused + GEMM+gated-activation kernel. + + Baseline matches benchmark_gemm_act: torch.compile(F.linear + gated_activation) + on bf16. The reported speedup conflates fusion gains with the lower-precision + MMA throughput, so it overstates pure fusion benefit relative to a hypothetical + bf16 fused kernel. + """ + is_gated = activation in gated_to_pytorch_fn_map + if not is_gated: + raise ValueError(f"benchmark_mxfp8_gemm_act expects a gated activation; got {activation!r}") + + a_bf16 = torch.randn(m, k, device="cuda", dtype=dtype) + # W: (2*N, K) for gated; quantize then build a (K, 2*N) K-contig view for B. 
+ b_n = 2 * n + w_bf16 = torch.randn(b_n, k, device="cuda", dtype=dtype) / math.sqrt(k) + + a_q, a_sc = quantize_act_sm90(a_bf16) + w_q, w_sc = quantize_weight_sm90(w_bf16) + b_q, b_sc = w_q.mT, w_sc.mT + + nflops = 2 * m * b_n * k + + fn = lambda: mxfp8_gemm_act( + a_q, + b_q, + a_sc, + b_sc, + activation=activation, + out_dtype=dtype, + postact_dtype=dtype, + store_preact=False, + tuned=False, + ) + fn() # warmup + + if trace_path is not None: + for _ in range(3): + fn() + torch.cuda.synchronize() + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + ) as prof: + for _ in range(5): + fn() + torch.cuda.synchronize() + prof.export_chrome_trace(trace_path) + print(f" saved kineto trace: {trace_path}") + + time.sleep(0.5) + ms = do_bench(fn, warmup=5, rep=repeats) + tf = tflops(nflops, ms) + + # Baseline: torch.compile(GEMM + gated activation) on bf16. + ref_fn = torch.compile( + lambda: _torch_gated_act(gated_to_pytorch_fn_map[activation], a_bf16, w_bf16) + ) + ref_fn() # compile warmup + ref_fn() + time.sleep(0.5) + ms_pt = do_bench(ref_fn, warmup=5, rep=repeats) + tf_pt = tflops(nflops, ms_pt) + + print(f" quack mxfp8: {ms:.3f}ms {tf:.1f} TFLOPS") + print(f" cuBLAS bf16 + torch.compile: {ms_pt:.3f}ms {tf_pt:.1f} TFLOPS") + print(f" speedup vs bf16 baseline: {ms_pt / ms:.2f}x") + return ms, tf + + def forced_config_from_args(args): if args.config_tile_n is None: return None @@ -293,6 +382,17 @@ def main(): default=None, help="Restrict the FFN gated backward benchmark to one activation", ) + parser.add_argument( + "--only-mxfp8-gated", + action="store_true", + help="Only run the SM90 MXFP8 FFN gated GEMM benchmark", + ) + parser.add_argument( + "--mxfp8-gated-activation", + choices=sorted(gated_to_pytorch_fn_map), + default=None, + help="Restrict the MXFP8 FFN gated benchmark to one activation", + ) parser.add_argument( "--untuned", action="store_true", @@ -307,6 +407,12 @@ def main(): 
parser.add_argument("--config-pingpong", action=argparse.BooleanOptionalAction, default=True) parser.add_argument("--config-swap-ab", action=argparse.BooleanOptionalAction, default=True) parser.add_argument("--cold", action="store_true", help="Clear .so and autotuning caches first") + parser.add_argument( + "--trace", + type=str, + default=None, + help="Export a kineto Chrome trace of the quack kernel to this path (mxfp8-gated only)", + ) args = parser.parse_args() if args.cold: @@ -332,11 +438,22 @@ def main(): "swiglu-tanh", ] ) + mxfp8_gated_activations = ( + [args.mxfp8_gated_activation] + if args.mxfp8_gated_activation + else [ + "swiglu", + "geglu", + ] + ) forced_config = forced_config_from_args(args) ffn = int(args.dim * 3.5) # Llama-3 ratio - if args.only_gated and args.only_dgated: - raise ValueError("--only-gated and --only-dgated are mutually exclusive") + only_flags = [args.only_gated, args.only_dgated, args.only_mxfp8_gated] + if sum(only_flags) > 1: + raise ValueError( + "--only-gated, --only-dgated, and --only-mxfp8-gated are mutually exclusive" + ) if args.only_gated: print( @@ -384,6 +501,32 @@ def main(): ) return + if args.only_mxfp8_gated: + if torch.cuda.get_device_properties(0).major != 9: + raise RuntimeError("--only-mxfp8-gated requires SM90") + print( + f"MXFP8 GEMM gated activation benchmark (workers={os.environ.get('QUACK_COMPILE_WORKERS', '4')})" + ) + print(f" batch={args.batch}, dim={args.dim}, ffn={ffn}, dtype={args.dtype}") + for activation in mxfp8_gated_activations: + print( + f"\n FFN up + {activation} (mxfp8): ({args.batch}, {args.dim}) x ({args.dim}, {2 * ffn})" + ) + trace_path = args.trace + if trace_path is not None and len(mxfp8_gated_activations) > 1: + root, ext = os.path.splitext(trace_path) + trace_path = f"{root}.{activation}{ext or '.json'}" + benchmark_mxfp8_gemm_act( + args.batch, + ffn, + args.dim, + activation, + dtype, + repeats=args.repeats, + trace_path=trace_path, + ) + return + print(f"GEMM autotuning demo 
(workers={os.environ.get('QUACK_COMPILE_WORKERS', '4')})") print(f" M={M}, N={N}, K={K}, dtype={args.dtype}") if forced_config is not None: diff --git a/benchmarks/benchmark_quack_vs_deepgemm_fp8.py b/benchmarks/benchmark_quack_vs_deepgemm_fp8.py new file mode 100644 index 00000000..bf6ce8d1 --- /dev/null +++ b/benchmarks/benchmark_quack_vs_deepgemm_fp8.py @@ -0,0 +1,368 @@ +"""Benchmark quack SM90 mxfp8 GEMM vs deep_gemm fp8 GEMM across modes. + +Modes (selectable via --mode): + dense : standard GEMM (FFN-up shape). Quack vs deep_gemm.fp8_gemm_nt. + batched : grouped GEMM with G experts × M_per_expert tokens (same M per expert). + Quack 3D batched vs deep_gemm.m_grouped_fp8_gemm_nt_masked. + varlen : grouped GEMM with G experts × variable M per expert. + Quack varlen (cu_seqlens_m) vs deep_gemm.m_grouped_fp8_gemm_nt_contiguous. + all : run all three (default). + +Models are MoE configurations from real systems; per-expert GEMM shape is +(M_per_expert, hidden, expert_dim). FFN gating is omitted (single matmul only), +so reported TFLOPS counts the up-proj only. + +Run on H100: + python benchmarks/benchmark_quack_vs_deepgemm_fp8.py + python benchmarks/benchmark_quack_vs_deepgemm_fp8.py --mode varlen --batch 4096 +""" + +import argparse +import math +from typing import List, Tuple + +import torch +from triton.testing import do_bench + +import deep_gemm +from deep_gemm.utils import get_m_alignment_for_contiguous_layout +from deep_gemm.utils.math import per_token_cast_to_fp8, per_block_cast_to_fp8 + +from quack.gemm_blockscaled_interface import ( + mxfp8_gemm_act, + quantize_act_sm90, + quantize_weight_sm90, +) + + +# MoE model shapes (name, hidden, expert_dim, num_experts, active_per_token) +# Active counts are typical top-k values for the model family. 
+MOE_MODELS: List[Tuple[str, int, int, int, int]] = [ + ("Qwen3 3a30b", 2048, 768, 128, 8), + ("Qwen3 22a235b", 4096, 1536, 128, 8), + ("Qwen3.5 3a35b", 2048, 512, 256, 8), + ("Qwen3.5 17a397b", 4096, 1024, 512, 8), + # DeepSeek-V4: hidden=7168, expert_dim=3072, num_experts=384. Disabled by + # default because the (384, 3072, 7168) weight stack needs ~35 GB peak + # during quantization (bf16 + contiguous copy) and doesn't fit on a single + # 80 GB H100 alongside other allocations. Re-enable once we have a + # per-expert quantization path that avoids the bf16 working copy. + # ("DeepSeek-V4", 7168, 3072, 384, 8), +] + +# Standalone dense shapes (FFN up/down at common sizes). +DENSE_SHAPES: List[Tuple[int, int, int]] = [ + (8192, 4096, 14336), + (8192, 14336, 4096), +] + + +def _tflops(m_total: int, n: int, k: int, ms: float) -> float: + return 2.0 * m_total * n * k / (ms * 1e-3) / 1e12 + + +def _pt_cast_3d(x_3d: torch.Tensor): + """deep_gemm's per_token_cast_to_fp8 asserts 2D; loop over batch dim.""" + qs, sfs = zip(*[ + per_token_cast_to_fp8(x_3d[i].contiguous(), use_ue8m0=True, gran_k=128) + for i in range(x_3d.shape[0]) + ]) + return torch.stack(qs), torch.stack(sfs) + + +def _pb_cast_3d(x_3d: torch.Tensor): + """deep_gemm's per_block_cast_to_fp8 asserts 2D; loop over batch dim.""" + qs, sfs = zip(*[ + per_block_cast_to_fp8(x_3d[i].contiguous(), use_ue8m0=True, gran_k=128) + for i in range(x_3d.shape[0]) + ]) + return torch.stack(qs), torch.stack(sfs) + + +# --------------------------------------------------------------------------- +# Dense GEMM (single matmul) +# --------------------------------------------------------------------------- + + +def bench_dense_quack(M: int, K: int, N: int, repeats: int) -> Tuple[float, float]: + A = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + W = torch.randn(N, K, device="cuda", dtype=torch.bfloat16) / math.sqrt(K) + A_q, A_sc = quantize_act_sm90(A) + W_q, W_sc = quantize_weight_sm90(W) + B_q, B_sc = W_q.mT, W_sc.mT 
+ out = torch.empty(M, N, dtype=torch.bfloat16, device="cuda") + + def fn(): + mxfp8_gemm_act( + A_q, B_q, A_sc, B_sc, + activation=None, + preact_out=None, + postact_out=out, + store_preact=False, + tuned=False, + ) + + fn() + ms = do_bench(fn, warmup=5, rep=repeats) + return ms, _tflops(M, N, K, ms) + + +def bench_dense_dgemm(M: int, K: int, N: int, repeats: int) -> Tuple[float, float]: + A = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + W = torch.randn(N, K, device="cuda", dtype=torch.bfloat16) / math.sqrt(K) + A_fp8, A_sf = per_token_cast_to_fp8(A, use_ue8m0=True, gran_k=128) + B_fp8, B_sf = per_block_cast_to_fp8(W, use_ue8m0=True, gran_k=128) + out = torch.empty(M, N, dtype=torch.bfloat16, device="cuda") + + def fn(): + deep_gemm.fp8_gemm_nt((A_fp8, A_sf), (B_fp8, B_sf), out) + + fn() + ms = do_bench(fn, warmup=5, rep=repeats) + return ms, _tflops(M, N, K, ms) + + +# --------------------------------------------------------------------------- +# Batched-grouped GEMM (G experts, same M per expert) +# --------------------------------------------------------------------------- + + +def bench_batched_quack(G: int, M_per_expert: int, K: int, N: int, repeats: int) -> Tuple[float, float]: + # 3D (L=G, M, K) and (L=G, N, K). 
+ A = torch.randn(G, M_per_expert, K, device="cuda", dtype=torch.bfloat16) + W = torch.randn(G, N, K, device="cuda", dtype=torch.bfloat16) / math.sqrt(K) + A_q, A_sc = quantize_act_sm90(A) + W_q, W_sc = quantize_weight_sm90(W) + del A, W # free bf16 — quantization is done + B_q, B_sc = W_q.mT, W_sc.mT + out = torch.empty(G, M_per_expert, N, dtype=torch.bfloat16, device="cuda") + + def fn(): + mxfp8_gemm_act( + A_q, B_q, A_sc, B_sc, + activation=None, + preact_out=None, + postact_out=out, + store_preact=False, + tuned=False, + ) + + fn() + ms = do_bench(fn, warmup=5, rep=repeats) + M_total = G * M_per_expert + return ms, _tflops(M_total, N, K, ms) + + +def bench_batched_dgemm(G: int, M_per_expert: int, K: int, N: int, repeats: int) -> Tuple[float, float]: + # deep_gemm masked: (G, M_max, K) tokens with per-group valid count. + A = torch.randn(G, M_per_expert, K, device="cuda", dtype=torch.bfloat16) + W = torch.randn(G, N, K, device="cuda", dtype=torch.bfloat16) / math.sqrt(K) + A_fp8, A_sf = _pt_cast_3d(A) + B_fp8, B_sf = _pb_cast_3d(W) + del A, W # free bf16; only fp8 needed from here + masked_m = torch.full((G,), M_per_expert, dtype=torch.int32, device="cuda") + out = torch.empty(G, M_per_expert, N, dtype=torch.bfloat16, device="cuda") + + def fn(): + deep_gemm.m_grouped_fp8_gemm_nt_masked( + (A_fp8, A_sf), (B_fp8, B_sf), out, masked_m, M_per_expert, + ) + + fn() + ms = do_bench(fn, warmup=5, rep=repeats) + M_total = G * M_per_expert + return ms, _tflops(M_total, N, K, ms) + + +# --------------------------------------------------------------------------- +# Varlen-grouped GEMM (G experts, variable M per expert) +# --------------------------------------------------------------------------- + + +def _make_seqlens(G: int, M_total_target: int, seed: int = 0) -> List[int]: + """Generate G per-expert M values summing to ~M_total_target. 
+ + Uses a mildly imbalanced distribution: ~30% variance around the mean, + aligned to multiples of 8 (quack constraint per-segment) and 128 + (deep_gemm padding alignment) — we pick 128 as the common multiple. + """ + gen = torch.Generator(device="cpu").manual_seed(seed) + mean = M_total_target / G + raw = torch.empty(G).uniform_(0.7, 1.3, generator=gen) * mean + align = 128 + seqlens = [max(align, (int(x) // align) * align) for x in raw.tolist()] + return seqlens + + +def bench_varlen_quack(seqlens: List[int], K: int, N: int, repeats: int) -> Tuple[float, float]: + G = len(seqlens) + M_total = sum(seqlens) + A = torch.randn(M_total, K, device="cuda", dtype=torch.bfloat16) + W = torch.randn(G, N, K, device="cuda", dtype=torch.bfloat16) / math.sqrt(K) + cu = torch.tensor([0] + list(torch.tensor(seqlens).cumsum(0).tolist()), + dtype=torch.int32, device="cuda") + A_q, A_sc = quantize_act_sm90(A) + W_q, W_sc = quantize_weight_sm90(W) + del A, W + B_q, B_sc = W_q.mT, W_sc.mT + out = torch.empty(M_total, N, dtype=torch.bfloat16, device="cuda") + + def fn(): + mxfp8_gemm_act( + A_q, B_q, A_sc, B_sc, + activation=None, + preact_out=None, + postact_out=out, + store_preact=False, + cu_seqlens_m=cu, + tuned=False, + ) + + fn() + ms = do_bench(fn, warmup=5, rep=repeats) + return ms, _tflops(M_total, N, K, ms) + + +def bench_varlen_dgemm(seqlens: List[int], K: int, N: int, repeats: int) -> Tuple[float, float]: + """deep_gemm contiguous: flat 128-row-padded A, per-row grouped_layout with -1 padding.""" + G = len(seqlens) + align = get_m_alignment_for_contiguous_layout() + aligned = [((m + align - 1) // align) * align for m in seqlens] + M_total_padded = sum(aligned) + # Construct padded A in expert-permuted order; the unused rows hold zeros. 
+ A_padded = torch.zeros(M_total_padded, K, device="cuda", dtype=torch.bfloat16) + grouped_layout = torch.empty(M_total_padded, device="cuda", dtype=torch.int32) + row = 0 + for g, (a_m, p_m) in enumerate(zip(seqlens, aligned)): + A_padded[row : row + a_m] = torch.randn(a_m, K, device="cuda", dtype=torch.bfloat16) + grouped_layout[row : row + a_m] = g + grouped_layout[row + a_m : row + p_m] = -1 + row += p_m + + W = torch.randn(G, N, K, device="cuda", dtype=torch.bfloat16) / math.sqrt(K) + A_fp8, A_sf = per_token_cast_to_fp8(A_padded, use_ue8m0=True, gran_k=128) + B_fp8, B_sf = _pb_cast_3d(W) + del A_padded, W + out = torch.empty(M_total_padded, N, dtype=torch.bfloat16, device="cuda") + + def fn(): + deep_gemm.m_grouped_fp8_gemm_nt_contiguous( + (A_fp8, A_sf), (B_fp8, B_sf), out, grouped_layout, + ) + + fn() + ms = do_bench(fn, warmup=5, rep=repeats) + # Report TFLOPS on the *unpadded* work — the padded rows do compute but aren't useful. + M_total = sum(seqlens) + return ms, _tflops(M_total, N, K, ms) + + +# --------------------------------------------------------------------------- +# Main / printing +# --------------------------------------------------------------------------- + + +def _print_header(extra: str) -> None: + print(f" {extra:<40} {'quack ms':>10} {'quack TF':>9} " + f"{'dgemm ms':>10} {'dgemm TF':>9} {'speedup':>8}") + print("-" * 100) + + +def _print_row(label: str, q_ms, q_tf, d_ms, d_tf) -> None: + speedup = d_ms / q_ms if q_ms > 0 else 0.0 + print(f" {label:<40} {q_ms:>10.4f} {q_tf:>9.1f} " + f"{d_ms:>10.4f} {d_tf:>9.1f} {speedup:>7.2f}x") + + +def run_dense(repeats: int) -> None: + print("\n=== Dense GEMM (no grouping) ===") + _print_header(f"{'M':>6} {'K':>6} {'N':>6}") + for M, K, N in DENSE_SHAPES: + q_ms, q_tf = bench_dense_quack(M, K, N, repeats) + d_ms, d_tf = bench_dense_dgemm(M, K, N, repeats) + _print_row(f"{M:>6} {K:>6} {N:>6}", q_ms, q_tf, d_ms, d_tf) + + +def _safe_run(label: str, fn): + """Run a bench fn; return (ms, tflops) or 
(None, None) on OOM/error, after logging.""" + try: + return fn() + except torch.cuda.OutOfMemoryError: + print(f" {label} SKIP: OOM") + torch.cuda.empty_cache() + return None, None + except Exception as e: + print(f" {label} SKIP: {type(e).__name__}: {e}") + torch.cuda.empty_cache() + return None, None + + +def run_batched(repeats: int, m_per_expert: int) -> None: + print(f"\n=== Batched-grouped GEMM (G experts × M_per_expert={m_per_expert}) ===") + _print_header(f"{'model':<25}{'G':>5} {'M_per_expert':>13}") + for name, hidden, expert_dim, G, _active in MOE_MODELS: + M_per_expert = m_per_expert + K, N = hidden, expert_dim + label = f"{name:<25}{G:>5} {M_per_expert:>13}" + q_ms, q_tf = _safe_run(label + " (quack)", lambda: bench_batched_quack(G, M_per_expert, K, N, repeats)) + if q_ms is None: + torch.cuda.empty_cache() + continue + d_ms, d_tf = _safe_run(label + " (dgemm)", lambda: bench_batched_dgemm(G, M_per_expert, K, N, repeats)) + if d_ms is None: + torch.cuda.empty_cache() + continue + _print_row(label, q_ms, q_tf, d_ms, d_tf) + torch.cuda.empty_cache() + + +def run_varlen(repeats: int, m_per_expert: int) -> None: + print(f"\n=== Varlen-grouped GEMM (G experts × variable M, mean={m_per_expert}) ===") + _print_header(f"{'model':<25}{'G':>5} {'M_tot':>8}") + for name, hidden, expert_dim, G, _active in MOE_MODELS: + M_total_target = G * m_per_expert + seqlens = _make_seqlens(G, M_total_target) + K, N = hidden, expert_dim + label = f"{name:<25}{G:>5} {sum(seqlens):>8}" + q_ms, q_tf = _safe_run(label + " (quack)", lambda: bench_varlen_quack(seqlens, K, N, repeats)) + if q_ms is None: + torch.cuda.empty_cache() + continue + d_ms, d_tf = _safe_run(label + " (dgemm)", lambda: bench_varlen_dgemm(seqlens, K, N, repeats)) + if d_ms is None: + torch.cuda.empty_cache() + continue + _print_row(label, q_ms, q_tf, d_ms, d_tf) + torch.cuda.empty_cache() + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--repeats", type=int, 
default=30) + parser.add_argument( + "--mode", + type=str, + default="all", + choices=["dense", "batched", "varlen", "all"], + ) + parser.add_argument( + "--m-per-expert", type=int, default=1024, + help="Per-expert M (tokens routed to each expert) for grouped GEMM benches", + ) + args = parser.parse_args() + + cap = torch.cuda.get_device_properties(0).major + if cap != 9: + raise SystemExit(f"requires SM90 (H100); current device is sm_{cap}0") + + torch.manual_seed(0) + if args.mode in ("dense", "all"): + run_dense(args.repeats) + if args.mode in ("batched", "all"): + run_batched(args.repeats, args.m_per_expert) + if args.mode in ("varlen", "all"): + run_varlen(args.repeats, args.m_per_expert) + + +if __name__ == "__main__": + main() diff --git a/quack/cute_dsl_utils.py b/quack/cute_dsl_utils.py index 2090fb97..3b999439 100644 --- a/quack/cute_dsl_utils.py +++ b/quack/cute_dsl_utils.py @@ -59,6 +59,8 @@ def _patched_convert_single_arg(arg, arg_name, arg_type, ctx): torch.float32: Float32, torch.int32: Int32, torch.int64: Int64, + torch.float8_e4m3fn: cutlass.Float8E4M3FN, + torch.float8_e8m0fnu: cutlass.Float8E8M0FNU, } diff --git a/quack/gemm_blockscaled_interface.py b/quack/gemm_blockscaled_interface.py index 60714cb4..4c8b44ff 100644 --- a/quack/gemm_blockscaled_interface.py +++ b/quack/gemm_blockscaled_interface.py @@ -1,56 +1,184 @@ # Copyright (c) 2026, Tri Dao. -"""PyTorch-friendly interface for the SM100 MXFP8 blockscaled GEMM. -Shape / layout conventions (matches torch.matmul, torch._scaled_mm, cuBLAS): +"""PyTorch-friendly interface for the MXFP8 blockscaled GEMM (SM90 + SM100). 
+ +Layout overview: A: (M, K) or (L, M, K) dtype float8_e4m3fn, K-contiguous (row-major) B: (K, N) or (L, K, N) dtype float8_e4m3fn, K-contiguous (col-major) - A_scale: (M, K/32) or (L, M, K/32) dtype float8_e8m0fnu, K-contiguous - B_scale: (K/32, N) or (L, K/32, N) dtype float8_e8m0fnu, K-contiguous + A_scale: (M, K/V) or (L, M, K/V) dtype float8_e8m0fnu/float32, K-contiguous + B_scale: (K/V, N) or (L, K/V, N) dtype float8_e8m0fnu/float32, K-contiguous out: (M, N) or (L, M, N) dtype bfloat16/float16, contiguous +K-block size V differs by target: + - SM100 (Blackwell): V = 32 (true MX granularity) + - SM90 (Hopper): V = 128 (coarser, faster on H100) + "K-contiguous" means stride 1 on the K axis. This matches how torchao/cuBLAS -use `torch._scaled_mm(a, b.t(), ...)`: - - you store a weight as nn.Linear-style `W` of shape `(N, K)` row-major - - you pass `W.mT` (a zero-copy view of shape (K, N) with K-contig) as B -The interface applies `.mT` internally to reach the `(N, K) K-major` layout -the quack kernel consumes. No data is copied. +use `torch._scaled_mm(a, b.t(), ...)`: weight stored as `(N, K)` row-major, +pass `W.mT` (zero-copy view of shape `(K, N)` with K-contig) as B. The +interface applies `.mT` internally to reach the `(N, K) K-major` layout the +kernels consume; no data is copied. + +File organization (top to bottom): + 1. Helpers + 2. Quantization (`quantize_act`, `quantize_*_sm{90,100}`, `quantize_weight_sm90`) + 3. SM100 GEMM (`mxfp8_gemm_sm100`, etc.) + 4. SM90 GEMM (`mxfp8_gemm_act_sm90`, etc.) + 5. Unified dispatch (`mxfp8_gemm`, `mxfp8_gemm_act`) + 6. 
Reference / utility """ -from functools import lru_cache +from functools import lru_cache, partial from typing import Optional, Tuple import torch from torch import Tensor import cutlass +import cutlass.cute as cute +from quack.activation import act_fn_map, gate_fn_map from quack.blockscaled_gemm_utils import ( + _make_compile_tensor_like, ceil_div, compile_blockscaled_gemm_tvm_ffi, pack_scale_2d_to_blocked_contig, scale_blocked_for_cublas, scale_view_for_kernel, ) +from quack.cache_utils import COMPILE_ONLY, jit_cache +from quack.compile_utils import make_fake_tensor as fake_tensor +from quack.cute_dsl_utils import get_device_capacity, get_max_active_clusters, torch2cute_dtype_map +from quack.gemm_act import GemmActMixin, GemmActSm90, GemmActSm100, GemmGatedSm90, GemmGatedSm100 +from quack.gemm_config import GemmConfig from quack.gemm_default_epi import GemmDefaultSm100 -from quack.mx_utils import to_mx +from quack.gemm_interface import ( + Activation, + GatedActivation, + _concat_interleave_bias, + _empty_k_matmul_into, + gated_to_pytorch_fn_map, +) +from quack.gemm_tvm_ffi_utils import ( + compile_gemm_kernel, + div_for_dtype, + get_major, + make_fake_gemm_tensors, + make_fake_scheduler_args, + make_fake_varlen_args, + make_scheduler_args, + make_varlen_args, + perm3d_single, +) +from quack.mx_utils import to_mx, to_mx_2d -_SF_VEC_SIZE = 32 -_TORCH_TO_CUTLASS_D = { - torch.bfloat16: cutlass.BFloat16, - torch.float16: cutlass.Float16, - torch.float32: cutlass.Float32, -} +_SF_VEC_SIZE_SM100 = 32 # SM100 K-block size +_SF_VEC_SIZE_SM90 = 128 # SM90 K-block size (activations and weights) +_WEIGHT_BLOCK_N_SM90 = 128 # SM90 N-block size for weight scales -def _default_tiler_cluster(m: int, n: int) -> Tuple[Tuple[int, int], Tuple[int, int]]: +def default_config(device): + cap = get_device_capacity(device)[0] + if cap == 9: + return GemmConfig( + tile_m=256, + tile_n=128, + cluster_m=1, + cluster_n=2, + pingpong=False, + is_dynamic_persistent=False, + ) + else: + raise 
NotImplementedError("Currently only Hopper is supported") + + +def _default_tiler_cluster_sm100(m: int, n: int) -> Tuple[Tuple[int, int], Tuple[int, int]]: """Pick a reasonable default (mma_tiler_mn, cluster_shape_mn).""" if m >= 512 and n >= 128: return (256, 128), (2, 1) return (128, 128), (1, 1) +# --------------------------------------------------------------------------- +# Quantization +# --------------------------------------------------------------------------- + + +def _f32_to_e8m0(scale_f32: torch.Tensor) -> torch.Tensor: + """Convert float32 power-of-2 scales to E8M0 bytes. + + Extracts the biased exponent byte: (f32_bits >> 23) & 0xFF. + """ + e8m0_byte = ((scale_f32.contiguous().view(torch.int32) >> 23) & 0xFF).to(torch.uint8) + return e8m0_byte.view(torch.float8_e8m0fnu) + + +def _e8m0_to_f32(scale_e8m0: torch.Tensor) -> torch.Tensor: + """E8M0 (float8_e8m0fnu viewed as uint8) → float32 power-of-2 scale.""" + bits = scale_e8m0.contiguous().view(torch.uint8).to(torch.int32) << 23 + return (bits & 0x7F800000).view(torch.float32) + + +def quantize_act_sm100(x: Tensor) -> Tuple[Tensor, Tensor]: + """SM100 activation quantization: 32-element K-blocks, per-row. + Returns (qdata: float8_e4m3fn, scale: float8_e8m0fnu) in torchao layout.""" + assert x.shape[-1] % _SF_VEC_SIZE_SM100 == 0, ( + f"last dim ({x.shape[-1]}) must be divisible by {_SF_VEC_SIZE_SM100}" + ) + return to_mx(x.contiguous(), _SF_VEC_SIZE_SM100) + + +def quantize_act_sm90(x: Tensor) -> Tuple[Tensor, Tensor]: + """SM90 activation quantization: 128-element K-blocks, per-row. Scales as float32.""" + assert x.shape[-1] % _SF_VEC_SIZE_SM90 == 0, ( + f"last dim ({x.shape[-1]}) must be divisible by {_SF_VEC_SIZE_SM90}" + ) + qdata, scale_e8m0 = to_mx(x.contiguous(), _SF_VEC_SIZE_SM90) + return qdata, _e8m0_to_f32(scale_e8m0).mT.contiguous().mT + + +def quantize_weight_sm90(w: Tensor) -> Tuple[Tensor, Tensor]: + """SM90 weight quantization: 128×128 block size (per-block, not per-row). 
+ + Args: + w: (..., N, K) bf16/fp32, N % 128 == 0, K % 128 == 0. + Returns: + qdata: float8_e4m3fn, same shape as w. + scale: float32, shape (..., N // 128, K // 128). One scale per 128×128 tile. + """ + assert w.shape[-1] % _SF_VEC_SIZE_SM90 == 0, ( + f"last dim K ({w.shape[-1]}) must be divisible by {_SF_VEC_SIZE_SM90}" + ) + assert w.shape[-2] % _WEIGHT_BLOCK_N_SM90 == 0, ( + f"second-to-last dim N ({w.shape[-2]}) must be divisible by {_WEIGHT_BLOCK_N_SM90}" + ) + *batch_shape, N, K = w.shape + qdata_2d, scale_2d = to_mx_2d( + w.reshape(-1, K).contiguous(), _WEIGHT_BLOCK_N_SM90, _SF_VEC_SIZE_SM90 + ) + qdata = qdata_2d.reshape(*batch_shape, N, K) + scale = scale_2d.reshape(*batch_shape, N // _WEIGHT_BLOCK_N_SM90, K // _SF_VEC_SIZE_SM90) + return qdata, scale + + +def quantize_act(x: Tensor) -> Tuple[Tensor, Tensor]: + """Auto-dispatch activation quantization based on device capability.""" + cap = torch.cuda.get_device_capability(torch.cuda.current_device())[0] + if cap == 9: + return quantize_act_sm90(x) + elif cap == 10: + return quantize_act_sm100(x) + else: + raise NotImplementedError(f"sm_{cap}0 not supported") + + +# --------------------------------------------------------------------------- +# SM100 GEMM (Blackwell) +# --------------------------------------------------------------------------- + + @lru_cache(maxsize=64) -def _compile_cached( +def _compile_cached_sm100( m: int, n: int, k: int, @@ -65,7 +193,7 @@ def _compile_cached( dev = torch.device("cuda") rm = ceil_div(m, 128) rn = ceil_div(n, 128) - rk = ceil_div(k // _SF_VEC_SIZE, 4) + rk = ceil_div(k // _SF_VEC_SIZE_SM100, 4) # K-major: (l, m, k) contiguous, viewed as (m, k, l) strides (k, 1, m*k) fake_mA = torch.empty(l, m, k, dtype=torch.float8_e4m3fn, device=dev).permute(1, 2, 0) fake_mB = torch.empty(l, n, k, dtype=torch.float8_e4m3fn, device=dev).permute(1, 2, 0) @@ -73,13 +201,13 @@ def _compile_cached( fake_mD = torch.empty(l, m, n, dtype=out_torch_dtype, device=dev).permute(1, 2, 0) 
fake_sc_A = torch.empty(l, rm, rk, 512, dtype=torch.float8_e8m0fnu, device=dev) fake_sc_B = torch.empty(l, rn, rk, 512, dtype=torch.float8_e8m0fnu, device=dev) - fake_mSFA = scale_view_for_kernel(fake_sc_A, m, k // _SF_VEC_SIZE, l) - fake_mSFB = scale_view_for_kernel(fake_sc_B, n, k // _SF_VEC_SIZE, l) + fake_mSFA = scale_view_for_kernel(fake_sc_A, m, k // _SF_VEC_SIZE_SM100, l) + fake_mSFB = scale_view_for_kernel(fake_sc_B, n, k // _SF_VEC_SIZE_SM100, l) return compile_blockscaled_gemm_tvm_ffi( ab_dtype_cutlass, sf_dtype_cutlass, - _SF_VEC_SIZE, - _TORCH_TO_CUTLASS_D[out_torch_dtype], + _SF_VEC_SIZE_SM100, + torch2cute_dtype_map[out_torch_dtype], mma_tiler_mn, cluster_shape_mn, fake_mA, @@ -97,7 +225,7 @@ def _as_3d(x: Tensor, ndim_in: int) -> Tensor: return x -def _to_kernel_layout( +def _to_kernel_layout_sm100( A: Tensor, B: Tensor, A_scale: Tensor, @@ -111,8 +239,12 @@ def _to_kernel_layout( """ assert A.dtype == torch.float8_e4m3fn, f"A dtype must be float8_e4m3fn, got {A.dtype}" assert B.dtype == torch.float8_e4m3fn, f"B dtype must be float8_e4m3fn, got {B.dtype}" - assert A_scale.dtype == torch.float8_e8m0fnu - assert B_scale.dtype == torch.float8_e8m0fnu + assert A_scale.dtype in (torch.float8_e8m0fnu, torch.float32), ( + f"A_scale dtype must be float8_e8m0fnu or float32, got {A_scale.dtype}" + ) + assert B_scale.dtype in (torch.float8_e8m0fnu, torch.float32), ( + f"B_scale dtype must be float8_e8m0fnu or float32, got {B_scale.dtype}" + ) was_2d = A.dim() == 2 # Flip B from (K,N) to (N,K) via .mT (zero-copy). User's B K-contig → .mT K-contig. 
A3 = _as_3d(A, A.dim()) # (l, m, k) K-contig row-major expected @@ -121,12 +253,12 @@ def _to_kernel_layout( l2, n, k2 = B3.shape assert l == l2, f"batch mismatch: A={l}, B={l2}" assert k == k2, f"K mismatch: A K={k}, B K={k2}" - assert k % _SF_VEC_SIZE == 0, f"K ({k}) must be divisible by {_SF_VEC_SIZE}" + assert k % _SF_VEC_SIZE_SM100 == 0, f"K ({k}) must be divisible by {_SF_VEC_SIZE_SM100}" assert A3.stride(-1) == 1, "A must be K-contiguous (stride 1 on K)" assert B3.stride(-1) == 1, ( "B must be K-contiguous on its K axis (pass .mT of an (N,K) row-major tensor)" ) - sf_k = k // _SF_VEC_SIZE + sf_k = k // _SF_VEC_SIZE_SM100 as3 = _as_3d(A_scale, A_scale.dim()) # expected (l, m, sf_k) K-contig row-major bs3 = _as_3d(B_scale, B_scale.dim()).mT # (l, n, sf_k) K-contig (view) from (l, sf_k, n) assert as3.stride(-1) == 1, "A_scale must be K-contiguous" @@ -153,7 +285,7 @@ def _to_kernel_layout( return m, n, k, l, mA_mkl, mB_nkl, sc_contig_A, sc_contig_B, sfa_view, sfb_view, was_2d -def mxfp8_gemm_out( +def mxfp8_gemm_out_sm100( A: Tensor, B: Tensor, A_scale: Tensor, @@ -164,9 +296,11 @@ def mxfp8_gemm_out( cluster_shape_mn: Optional[Tuple[int, int]] = None, ) -> None: """MXFP8 blockscaled GEMM with pre-allocated output. 
See module doc for shape conventions.""" - m, n, k, l, mA, mB, _scA, _scB, sfa, sfb, was_2d = _to_kernel_layout(A, B, A_scale, B_scale) + m, n, k, l, mA, mB, _scA, _scB, sfa, sfb, was_2d = _to_kernel_layout_sm100( + A, B, A_scale, B_scale + ) out_dtype = out.dtype - assert out_dtype in _TORCH_TO_CUTLASS_D, f"unsupported out dtype: {out_dtype}" + assert out_dtype in torch2cute_dtype_map, f"unsupported out dtype: {out_dtype}" expected_out_shape = (m, n) if was_2d else (l, m, n) assert tuple(out.shape) == expected_out_shape, ( f"out shape {tuple(out.shape)} != expected {expected_out_shape}" @@ -176,14 +310,14 @@ def mxfp8_gemm_out( out_3d = out.unsqueeze(0) if was_2d else out # (l, m, n) mD = out_3d.permute(1, 2, 0) # (m, n, l), strides (n, 1, m*n) if mma_tiler_mn is None or cluster_shape_mn is None: - tlr, clu = _default_tiler_cluster(m, n) + tlr, clu = _default_tiler_cluster_sm100(m, n) mma_tiler_mn = mma_tiler_mn or tlr cluster_shape_mn = cluster_shape_mn or clu if not GemmDefaultSm100.can_implement_blockscaled( cutlass.Float8E4M3FN, cutlass.Float8E8M0FNU, - _SF_VEC_SIZE, - _TORCH_TO_CUTLASS_D[out_dtype], + _SF_VEC_SIZE_SM100, + torch2cute_dtype_map[out_dtype], mma_tiler_mn, cluster_shape_mn, m, @@ -198,7 +332,7 @@ def mxfp8_gemm_out( f"unsupported config: m={m}, n={n}, k={k}, l={l}, " f"tiler={mma_tiler_mn}, cluster={cluster_shape_mn}" ) - runner = _compile_cached( + runner = _compile_cached_sm100( m, n, k, @@ -212,7 +346,7 @@ def mxfp8_gemm_out( runner(mA, mB, mD, sfa, sfb) -def mxfp8_gemm( +def mxfp8_gemm_sm100( A: Tensor, B: Tensor, A_scale: Tensor, @@ -231,7 +365,7 @@ def mxfp8_gemm( else: out_shape = (A.shape[0], A.shape[1], B.shape[2]) out = torch.empty(out_shape, dtype=out_dtype, device=A.device) - mxfp8_gemm_out( + mxfp8_gemm_out_sm100( A, B, A_scale, @@ -243,16 +377,751 @@ def mxfp8_gemm( return out -def mxfp8_quantize(x: Tensor) -> Tuple[Tensor, Tensor]: - """Quantize a (..., K) bf16/fp32 tensor to MXFP8. 
Returns (qdata, scale_2d) - in torchao-convention layout. Last dim (K) must be divisible by 32.""" - assert x.shape[-1] % _SF_VEC_SIZE == 0, ( - f"last dim ({x.shape[-1]}) must be divisible by {_SF_VEC_SIZE}" +# --------------------------------------------------------------------------- +# SM90 GEMM (Hopper) +# --------------------------------------------------------------------------- + + +@jit_cache +def _compile_mxfp8_gemm_act_sm90( + a_dtype, + b_dtype, + d_dtype, + c_dtype, + postact_dtype, + a_major, + b_major, + d_major, + c_major, + postact_major, + tile_shape_mn, + cluster_shape_mnk, + pingpong, + persistent, + is_dynamic_persistent, + activation, + rowvec_dtype, + colvec_dtype, + colvec_ndim, + varlen_m, + gather_A, + concat_layout, + device_capacity, + sr_seed_mode=0, + use_tma_gather=False, +): + is_gated = activation in gate_fn_map + sm = device_capacity[0] + if sm == 9: + GemmCls = GemmGatedSm90 if is_gated else GemmActSm90 + else: + GemmCls = GemmGatedSm100 if is_gated else GemmActSm100 + + mA, mB, mD, mC, m, n, k, l = make_fake_gemm_tensors( + a_dtype, + b_dtype, + d_dtype, + c_dtype, + a_major, + b_major, + d_major, + c_major, + varlen_m=varlen_m, + gather_A=gather_A, + ) + + pa_leading = 1 if postact_major == "n" else 0 + pa_n = cute.sym_int() if is_gated else n + div_pa = div_for_dtype(postact_dtype) + pa_leading_dim = 1 if is_gated else pa_leading + pa_shape = (m, pa_n) if varlen_m else (m, pa_n, l) + mAuxOut = fake_tensor(postact_dtype, pa_shape, leading_dim=pa_leading_dim, divisibility=div_pa) + + mRowVec = ( + fake_tensor(rowvec_dtype, (l, n), leading_dim=1, divisibility=4) if rowvec_dtype else None + ) + if colvec_ndim == 2: + mColVec = ( + fake_tensor(colvec_dtype, (l, m), leading_dim=1, divisibility=4) + if colvec_dtype + else None + ) + elif colvec_ndim == 1: + mColVec = ( + fake_tensor(colvec_dtype, (m,), leading_dim=0, divisibility=4) if colvec_dtype else None + ) + else: + mColVec = None + + from cutlass import Int32 + from 
cutlass.cute.runtime import make_ptr + + act_fn = gate_fn_map[activation] if is_gated else act_fn_map[activation] + + def fake_scalar(mode): + if mode == 0: + return None + elif mode == 1: + return Int32(0) + else: + return make_ptr(Int32, 0, cute.AddressSpace.gmem, assumed_align=4) + + epi_args = GemmCls.EpilogueArguments( + mAuxOut, + act_fn, + mRowVecBroadcast=mRowVec, + mColVecBroadcast=mColVec, + rounding_mode=0, # RoundingMode.RN, Constexpr baked at compile time + sr_seed=fake_scalar(sr_seed_mode), + ) + scheduler_args = make_fake_scheduler_args((is_dynamic_persistent and sm == 9), False, l) + varlen_args = make_fake_varlen_args(varlen_m, False, gather_A, m if varlen_m else None) + + if sm == 9: + # SM90 blockscaled: float32 scales. + # A scale: dispatch produces (m, sf_k, l) (or (m, sf_k) varlen) with M innermost + # so the TMA atom can do a single contiguous (BLOCK_M, 1) burst per K-stage. + # B scale: (l, n_blocks, sf_k) — read directly from gmem in math warps (no TMA). + sf_k_sym = cute.sym_int() + n_blocks_sym = cute.sym_int() + if varlen_m: + fake_sfa = fake_tensor(cutlass.Float32, (m, sf_k_sym), leading_dim=0, divisibility=1) + else: + fake_sfa = fake_tensor(cutlass.Float32, (m, sf_k_sym, l), leading_dim=0, divisibility=1) + fake_sfb = fake_tensor( + cutlass.Float32, (l, n_blocks_sym, sf_k_sym), leading_dim=2, divisibility=1 + ) + return compile_gemm_kernel( + partial(GemmCls, sf_vec_size=_SF_VEC_SIZE_SM90, weight_n_block=_WEIGHT_BLOCK_N_SM90), + a_dtype, + tile_shape_mn, + cluster_shape_mnk, + pingpong, + persistent, + gather_A, + is_dynamic_persistent, + device_capacity, + mA, + mB, + mD, + mC, + epi_args, + scheduler_args, + varlen_args, + mSFA=fake_sfa, + mSFB=fake_sfb, + use_tma_gather=use_tma_gather, + concat_layout=concat_layout or None, + ) + + # SM100/SM110: blockscaled path — inject sf_vec_size and pass fake scale tensors. + # Layout is (l, rm, rk, 512) contiguous; dynamic_layout=True lets TVM FFI + # accept any concrete shape at runtime. 
+ sc_fake = torch.empty(1, 1, 1, 512, dtype=torch.float8_e8m0fnu, device="cuda") + mSFA = _make_compile_tensor_like(sc_fake, cutlass.Float8E8M0FNU, dynamic_layout=True) + mSFB = _make_compile_tensor_like(sc_fake, cutlass.Float8E8M0FNU, dynamic_layout=True) + return compile_gemm_kernel( + partial(GemmCls, sf_vec_size=_SF_VEC_SIZE_SM100), + a_dtype, + tile_shape_mn, + cluster_shape_mnk, + pingpong, + persistent, + gather_A, + is_dynamic_persistent, + device_capacity, + mA, + mB, + mD, + mC, + epi_args, + scheduler_args, + varlen_args, + mSFA=mSFA, + mSFB=mSFB, + use_tma_gather=use_tma_gather, + concat_layout=concat_layout or None, + ) + + +def mxfp8_gemm_act_dispatch_sm90( + A: Tensor, # (l, m, k) K-contig + B: Tensor, # (l, n, k) K-contig + A_scale: Tensor, # (l, m, k/32) K-contig + B_scale: Tensor, # (l, n, k/32) K-contig + D: Optional[Tensor], # (l, m, n) or None (preact_out) + C: Optional[Tensor], # (l, m, n) or None + PostAct: Tensor, # (l, m, n//2) for gated + tile_count_semaphore: Optional[Tensor], + activation: str, + tile_M: int, + tile_N: int, + cluster_M: int, + cluster_N: int, + tile_K: int | None = None, + pingpong: bool = False, + persistent: bool = True, + is_dynamic_persistent: bool = False, + max_swizzle_size: int = 8, + rowvec_bias: Optional[Tensor] = None, + colvec_bias: Optional[Tensor] = None, + cu_seqlens_m: Optional[Tensor] = None, + A_idx: Optional[Tensor] = None, + use_tma_gather: bool = False, + concat_layout: tuple | None = None, +) -> None: + varlen_m = cu_seqlens_m is not None + gather_A = A_idx is not None + + A_p = perm3d_single(A, varlen_m) + B_p = perm3d_single(B) + D_p = perm3d_single(D, varlen_m) if D is not None else None + C_p = perm3d_single(C, varlen_m) if C is not None else None + PostAct_p = perm3d_single(PostAct, varlen_m) + + a_major = get_major(A_p, "m", "k") + b_major = get_major(B_p, "n", "k") + d_major = get_major(D_p, "m", "n") if D_p is not None else None + c_major = get_major(C_p, "m", "n") if C_p is not None else 
None + postact_major = get_major(PostAct_p, "m", "n") + + a_dtype = torch2cute_dtype_map[A.dtype] + b_dtype = torch2cute_dtype_map[B.dtype] + d_dtype = torch2cute_dtype_map[D.dtype] if D is not None else None + c_dtype = torch2cute_dtype_map[C.dtype] if C is not None else None + postact_dtype = torch2cute_dtype_map[PostAct.dtype] + colvec_ndim = colvec_bias.ndim if colvec_bias is not None else 0 + + device_capacity = get_device_capacity(A.device) + sm = device_capacity[0] + assert sm in (9, 10, 11), "mxfp8_gemm_act_dispatch_sm90 requires SM90, SM100, or SM110" + + if sm == 9 and not GemmActSm90.is_valid_dtypes( + a_dtype, b_dtype, cutlass.Float32, d_dtype, a_major, b_major + ): + raise ValueError( + f"unsupported SM90 mxfp8 config: a_dtype={a_dtype}, b_dtype={b_dtype}, " + f"d_dtype={d_dtype}, a_major={a_major}, b_major={b_major} " + f"(SM90 fp8 requires K-major A and B)" + ) + + concat_layout_key = tuple(sorted(concat_layout)) if concat_layout else () + compiled_fn = _compile_mxfp8_gemm_act_sm90( + a_dtype, + b_dtype, + d_dtype, + c_dtype, + postact_dtype, + a_major, + b_major, + d_major, + c_major, + postact_major, + (tile_M, tile_N, tile_K) if tile_K is not None else (tile_M, tile_N), + (cluster_M, cluster_N, 1), + pingpong, + persistent, + is_dynamic_persistent, + activation, + torch2cute_dtype_map[rowvec_bias.dtype] if rowvec_bias is not None else None, + torch2cute_dtype_map[colvec_bias.dtype] if colvec_bias is not None else None, + colvec_ndim, + varlen_m, + gather_A, + concat_layout_key, + device_capacity, + use_tma_gather=use_tma_gather, + ) + + if COMPILE_ONLY: + return + + max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0 + epi_args = GemmActMixin.EpilogueArguments( + PostAct_p, + None, # act_fn is Constexpr, baked at compile time + mRowVecBroadcast=rowvec_bias, + mColVecBroadcast=colvec_bias, + rounding_mode=None, # Constexpr, baked at compile time + sr_seed=None, + ) + scheduler_args = make_scheduler_args( + 
max_active_clusters, max_swizzle_size, tile_count_semaphore ) - return to_mx(x.contiguous(), _SF_VEC_SIZE) + varlen_args = make_varlen_args(cu_seqlens_m, None, A_idx) + + if sm == 9: + # Scales are float32. + # SFA: kernel sees logical shape (m, sf_k, l) (or (m, sf_k) for varlen) with + # *M as the innermost contiguous dim* — matches the DeepGEMM a.cu reference + # convention. TMA loads (BLOCK_M, 1) per K-stage as a single 512B burst from + # M-contiguous memory; a K-major (sf_k stride 1) layout would force TMA to + # do a strided gather and reads wrong values. The transpose+contiguous below + # rematerializes A_scale as (..., sf_k, m) then transposes back to a + # (..., m, sf_k) view with M innermost. + if varlen_m: + sfa_sm90 = A_scale + else: + # sfa_2d_to_3d = A_scale.transpose(-2, -1).contiguous() # (l, sf_k, m) + sfa_sm90 = A_scale.permute(1, 2, 0) # view: (m, sf_k, l), M innermost + sfb_sm90 = B_scale.contiguous() # (l, n_blocks, sf_k); read directly from gmem + compiled_fn( + A_p, + B_p, + D_p, + C_p, + epi_args, + scheduler_args, + varlen_args, + sfa_sm90, + sfb_sm90, + None, + ) + else: + # SM100/SM110: pack scales and pass to blockscaled kernel. + # Scales may be float32 (from quantize_act_*) — convert to E8M0 first. + k = A.shape[-1] + l = B.shape[0] + n = B.shape[1] + sf_k = k // _SF_VEC_SIZE_SM100 + a_scale_e8m0 = _f32_to_e8m0(A_scale) if A_scale.dtype == torch.float32 else A_scale + b_scale_e8m0 = _f32_to_e8m0(B_scale) if B_scale.dtype == torch.float32 else B_scale + if varlen_m: + # A_scale: (total_m, sf_k) — dQaccum-padded layout. Each expert's rows + # start at the next 128-row tile boundary after the previous expert, with + # one extra tile of slack per expert boundary so the kernel's + # VarlenManager.offset_batch_SFA decodes offsets correctly. 
+ total_m = A.shape[0] + seqlens_m = (cu_seqlens_m[1:] - cu_seqlens_m[:-1]).cpu().tolist() + tile = 128 + total_padded_rm = (total_m + tile - 1) // tile + (l - 1) + total_padded_m = total_padded_rm * tile + sa_padded = torch.zeros( + total_padded_m, sf_k, dtype=torch.float8_e8m0fnu, device=A_scale.device + ) + row = 0 + for i, m_i in enumerate(seqlens_m): + row_padded = (row // tile + i) * tile + sa_padded[row_padded : row_padded + m_i] = a_scale_e8m0[row : row + m_i] + row += m_i + sc_contig_A = pack_scale_2d_to_blocked_contig(sa_padded.view(1, total_padded_m, sf_k)) + sfa = scale_view_for_kernel(sc_contig_A, total_padded_m, sf_k, 1) + else: + m = A.shape[1] + sc_contig_A = pack_scale_2d_to_blocked_contig(a_scale_e8m0.contiguous()) + sfa = scale_view_for_kernel(sc_contig_A, m, sf_k, l) + sc_contig_B = pack_scale_2d_to_blocked_contig(b_scale_e8m0.contiguous()) + sfb = scale_view_for_kernel(sc_contig_B, n, sf_k, l) + compiled_fn( + A_p, + B_p, + D_p, + C_p, + epi_args, + scheduler_args, + varlen_args, + sfa, + sfb, + None, + ) -def mxfp8_gemm_quantize( +# @autotune( +# configs=[AutotuneConfig(config=c) for c in get_all_configs("gated")], +# key=["activation", "dynamic_scheduler"], +# prune_configs_by={"early_config_prune": prune_invalid_gemm_configs}, +# ) +def mxfp8_gemm_gated_tuned_sm90( + # (M, K) or or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m + A: Tensor, + B: Tensor, # (K, N) or (L, K, N) + A_scale: Tensor, + B_scale: Tensor, + # (M, N) or (L, M, N) or (total_M, N) if varlen_m - None if not storing preact + preact_out: Optional[Tensor], + postact_out: Tensor, # (M, N//2) or (L, M, N//2) or (total_M, N//2) if varlen_m + C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + bias: Optional[Tensor] = None, # (N,) or (L, N) + activation: GatedActivation = "swiglu", + cu_seqlens_m: Optional[Tensor] = None, # (L+1), int32 + A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m + 
dynamic_scheduler: bool = False, + config: Optional[GemmConfig] = None, + concat_layout: tuple | None = None, # tensors whose non-contiguous dim is concat [gate; up] +) -> None: + if config is None: + config = default_config(A.device) + varlen_m = cu_seqlens_m is not None + if varlen_m: + assert not config.swap_ab, "Variable-length sequences not supported with swap_ab" + if A.ndim == 2 and not varlen_m: + A = A.unsqueeze(0) # (1, M, K) + A_scale = A_scale.unsqueeze(0) + B, B_scale = B.mT, B_scale.mT # (N, K) or (L, N, K) + if B.ndim == 2: + B = B.unsqueeze(0) # (1, N, K) + B_scale = B_scale.unsqueeze(0) + if C is not None and C.ndim == 2 and not varlen_m: + C = C.unsqueeze(0) # (1, M, N) + if preact_out is not None and preact_out.ndim == 2 and not varlen_m: + D = preact_out.unsqueeze(0) + else: + D = preact_out + if postact_out.ndim == 2 and not varlen_m: + PostAct = postact_out.unsqueeze(0) + else: + PostAct = postact_out + if bias is not None and bias.ndim == 1: + bias = bias.unsqueeze(0) # (L, N) + if concat_layout and "bias" in concat_layout: + if bias is not None and bias.dtype.itemsize >= 4: + bias_key = "mColVecBroadcast" if config.swap_ab else "mRowVecBroadcast" + concat_layout = tuple(bias_key if k == "bias" else k for k in concat_layout) + else: + concat_layout = tuple(k for k in concat_layout if k != "bias") + if bias is not None: + bias = _concat_interleave_bias(bias) + dynamic_scheduler = dynamic_scheduler or config.is_dynamic_persistent + tile_count_semaphore = ( + torch.zeros(1, dtype=torch.int32, device=A.device) + if dynamic_scheduler and get_device_capacity(A.device)[0] == 9 + else None + ) + mxfp8_gemm_act_dispatch_sm90( + A if not config.swap_ab else B, + B if not config.swap_ab else A, + A_scale if not config.swap_ab else B_scale, + B_scale if not config.swap_ab else A_scale, + (D if not config.swap_ab else D.mT) if D is not None else None, + (C if not config.swap_ab else C.mT) if C is not None else None, + PostAct if not config.swap_ab else 
PostAct.mT, + tile_count_semaphore, + activation, + config.tile_m, + config.tile_n, + config.cluster_m, + config.cluster_n, + tile_K=config.tile_k, + pingpong=config.pingpong, + persistent=True, + is_dynamic_persistent=dynamic_scheduler, + max_swizzle_size=config.max_swizzle_size, + rowvec_bias=bias if not config.swap_ab else None, + colvec_bias=bias if config.swap_ab else None, + cu_seqlens_m=cu_seqlens_m, + A_idx=A_idx, + use_tma_gather=config.use_tma_gather, + concat_layout=concat_layout, + ) + + +def mxfp8_gemm_act_tuned_sm90( + A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m + B: Tensor, # (K, N) or (L, K, N) + A_scale: Tensor, + B_scale: Tensor, + preact_out: Optional[Tensor], # (M, N) or (L, M, N) or (total_M, N) if varlen_m + postact_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + bias: Optional[Tensor] = None, # (N,) or (L, N) + activation: Activation = None, + cu_seqlens_m: Optional[Tensor] = None, # (L+1), int32 + A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m + dynamic_scheduler: bool = False, + config: Optional[GemmConfig] = None, +) -> None: + if config is None: + config = default_config(A.device) + varlen_m = cu_seqlens_m is not None + if varlen_m: + assert not config.swap_ab, "Variable-length sequences not supported with swap_ab" + if A.ndim == 2 and not varlen_m: + A = A.unsqueeze(0) # (1, M, K) + A_scale = A_scale.unsqueeze(0) + B, B_scale = B.mT, B_scale.mT # (N, K) or (L, N, K) + if B.ndim == 2: + B = B.unsqueeze(0) # (1, N, K) + B_scale = B_scale.unsqueeze(0) + if C is not None and C.ndim == 2 and not varlen_m: + C = C.unsqueeze(0) # (1, M, N) + if preact_out is not None and preact_out.ndim == 2 and not varlen_m: + D = preact_out.unsqueeze(0) + else: + D = preact_out + if postact_out.ndim == 2 and not varlen_m: + PostAct = postact_out.unsqueeze(0) + else: + PostAct = 
postact_out + if bias is not None and bias.ndim == 1: + bias = bias.unsqueeze(0) # (L, N) + dynamic_scheduler = dynamic_scheduler or config.is_dynamic_persistent + tile_count_semaphore = ( + torch.zeros(1, dtype=torch.int32, device=A.device) + if dynamic_scheduler and get_device_capacity(A.device)[0] == 9 + else None + ) + mxfp8_gemm_act_dispatch_sm90( + A if not config.swap_ab else B, + B if not config.swap_ab else A, + A_scale if not config.swap_ab else B_scale, + B_scale if not config.swap_ab else A_scale, + (D if not config.swap_ab else D.mT) if D is not None else None, + (C if not config.swap_ab else C.mT) if C is not None else None, + PostAct if not config.swap_ab else PostAct.mT, + tile_count_semaphore, + activation, + config.tile_m, + config.tile_n, + config.cluster_m, + config.cluster_n, + tile_K=config.tile_k, + pingpong=config.pingpong, + persistent=True, + is_dynamic_persistent=dynamic_scheduler, + max_swizzle_size=config.max_swizzle_size, + rowvec_bias=bias if not config.swap_ab else None, + colvec_bias=bias if config.swap_ab else None, + cu_seqlens_m=cu_seqlens_m, + A_idx=A_idx, + use_tma_gather=config.use_tma_gather, + ) + + +def mxfp8_gemm_gated_out_sm90( + A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m + B: Tensor, # (K, N) or (L, K, N) + A_scale: Tensor, + B_scale: Tensor, + preact_out: Optional[Tensor], # (M, N) or (L, M, N) or (total_M, N) if varlen_m + postact_out: Tensor, # (M, N//2) or (L, M, N//2) or (total_M, N//2) if varlen_m + C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + bias: Optional[Tensor] = None, # (N,) or (L, N) + activation: GatedActivation = "swiglu", + cu_seqlens_m: Optional[Tensor] = None, + A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m + dynamic_scheduler: bool = False, + tuned: bool = True, + concat_layout: Optional[str] = None, +) -> None: + """GEMM with gated activation and pre-allocated output tensors.""" + # 
TODO: add tuning + assert not tuned, "currently tuning is not available" + fn = mxfp8_gemm_gated_tuned_sm90 if tuned else partial(mxfp8_gemm_gated_tuned_sm90, config=None) + fn( + A, + B, + A_scale, + B_scale, + preact_out, + postact_out, + C, + bias, + activation, + cu_seqlens_m, + A_idx, + dynamic_scheduler, + concat_layout=tuple(concat_layout.split(",")) if concat_layout else None, + ) + + +def mxfp8_gemm_act_out_sm90( + A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m + B: Tensor, # (K, N) or (L, K, N) + A_scale: Tensor, + B_scale: Tensor, + preact_out: Optional[Tensor], # (M, N) or (L, M, N) or (total_M, N) if varlen_m + postact_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + bias: Optional[Tensor] = None, # (N,) or (L, N) + activation: Activation = None, + cu_seqlens_m: Optional[Tensor] = None, + A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m + dynamic_scheduler: bool = False, + tuned: bool = True, +) -> None: + """GEMM with activation and pre-allocated output tensors.""" + # TODO: add tuning + tuned = False + fn = mxfp8_gemm_act_tuned_sm90 if tuned else partial(mxfp8_gemm_act_tuned_sm90, config=None) + fn( + A, + B, + A_scale, + B_scale, + preact_out, + postact_out, + C, + bias, + activation, + cu_seqlens_m, + A_idx, + dynamic_scheduler, + ) + + +def mxfp8_gemm_act_sm90( + A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m + B: Tensor, # (K, N) or (L, K, N) + A_scale: Tensor, + B_scale: Tensor, + C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + bias: Optional[Tensor] = None, # (N,) or (L, N) + activation: Activation = None, + preact_out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + postact_out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + 
out_dtype: Optional[torch.dtype] = None, + postact_dtype: Optional[torch.dtype] = None, + cu_seqlens_m: Optional[Tensor] = None, + A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m + store_preact: bool = True, + dynamic_scheduler: bool = False, + tuned: bool = True, + concat_layout: tuple | None = None, # tensors whose non-contiguous dim is concat [gate; up] +) -> Tuple[Optional[Tensor], Tensor]: + """GEMM with activation (or gated activation) and optional output tensors.""" + is_gated = activation in gated_to_pytorch_fn_map + out_dtype = A.dtype if out_dtype is None else out_dtype + postact_dtype = A.dtype if postact_dtype is None else postact_dtype + varlen_m = cu_seqlens_m is not None + if not varlen_m: + # SM90 constraints (non-varlen): M % 8, and the "1d2d" block-scaled quant scheme: + # A_scale: (..., M, K // 128) from quantize_act_sm90 (1 × 128) + # B_scale: (..., K // 128, N // 128) from quantize_weight_sm90 (128 × 128), passed as .mT + m_dim = A.shape[-2] # works for both 2D (M, K) and 3D (L, M, K) + k_dim = A.shape[-1] + n_dim = B.shape[-1] + assert m_dim % 8 == 0, f"SM90 mxfp8 GEMM requires M % 8 == 0; got M={m_dim}" + assert A_scale.shape[-2:] == (m_dim, k_dim // 128), ( + f"SM90 expects A_scale from quantize_act_sm90 (1x128): " + f"shape (..., M={m_dim}, K/128={k_dim // 128}); got {tuple(A_scale.shape)}" + ) + assert B_scale.shape[-2:] == (k_dim // 128, n_dim // 128), ( + f"SM90 expects B_scale from quantize_weight_sm90 (128x128, passed as .mT): " + f"shape (..., K/128={k_dim // 128}, N/128={n_dim // 128}); got {tuple(B_scale.shape)}" + ) + # Determine output shape based on gather_A + if varlen_m: + total_m = A_idx.shape[0] if A_idx is not None else A.shape[0] + out_shape = (total_m, B.shape[-1]) + elif A.ndim == 2: + out_shape = (A.shape[0], B.shape[-1]) + else: + out_shape = (A.shape[0], A.shape[-2], B.shape[-1]) + postact_shape = (*out_shape[:-1], out_shape[-1] // 2) if is_gated else out_shape + if preact_out is None and 
store_preact: + preact_out = torch.empty(out_shape, dtype=out_dtype, device=A.device) + if postact_out is None: + postact_out = torch.empty(postact_shape, dtype=postact_dtype, device=A.device) + # Empty-input fast path. For M=0 or N=0 the outputs are empty; for K=0 + # (A@B == 0) the no-bias / no-C surface yields preact=0 and act(0)=0 for + # every supported activation, so both outputs are zero. + if postact_out.numel() == 0 or A.numel() == 0: + if preact_out is not None: + _empty_k_matmul_into(preact_out) + _empty_k_matmul_into(postact_out) + return preact_out, postact_out + concat_str = ",".join(concat_layout) if concat_layout else None + if is_gated: + mxfp8_gemm_gated_out_sm90( + A, + B, + A_scale, + B_scale, + preact_out, + postact_out, + C, + bias, + activation, + cu_seqlens_m, + A_idx, + dynamic_scheduler, + tuned, + concat_layout=concat_str, + ) + else: + mxfp8_gemm_act_out_sm90( + A, + B, + A_scale, + B_scale, + preact_out, + postact_out, + C, + bias, + activation, + cu_seqlens_m, + A_idx, + dynamic_scheduler, + tuned, + ) + return preact_out, postact_out + + +# --------------------------------------------------------------------------- +# Unified dispatch entry points +# --------------------------------------------------------------------------- +# Top-level public API. These detect the device capability at runtime and +# route to the SM-specific implementation. Callers that don't care about the +# underlying hardware should use these; callers that need to pin a specific +# implementation can still call the *_sm90 / *_sm100 functions directly. + + +def mxfp8_gemm( + A: Tensor, + B: Tensor, + A_scale: Tensor, + B_scale: Tensor, + out: Optional[Tensor] = None, + out_dtype: torch.dtype = torch.bfloat16, + **kwargs, +) -> Tensor: + """MXFP8 blockscaled GEMM. Dispatches to SM90 or SM100 by device capability. + + SM100 path expects 32-element K-block scales (`quantize_act_sm100`). 
+ SM90 path expects 128-element K-block scales (`quantize_act_sm90`, + `quantize_weight_sm90`). Caller is responsible for using the matching + quantization for the target device. + """ + cap = torch.cuda.get_device_capability(torch.cuda.current_device())[0] + if cap == 10: + return mxfp8_gemm_sm100(A, B, A_scale, B_scale, out=out, out_dtype=out_dtype, **kwargs) + if cap == 9: + # SM90 has no plain-GEMM entry; use the act path with activation=None. + if out is None: + out_shape = (A.shape[0], B.shape[-1]) if A.dim() == 2 else (*A.shape[:-1], B.shape[-1]) + out = torch.empty(out_shape, dtype=out_dtype, device=A.device) + mxfp8_gemm_act_sm90( + A, + B, + A_scale, + B_scale, + activation=None, + preact_out=None, + postact_out=out, + store_preact=False, + **kwargs, + ) + return out + raise NotImplementedError(f"mxfp8_gemm: sm_{cap}0 not supported") + + +def mxfp8_gemm_act(*args, **kwargs) -> Tuple[Optional[Tensor], Tensor]: + """GEMM + (optionally gated) activation. SM90-only at present. + + See `mxfp8_gemm_act_sm90` for the full signature. + """ + cap = torch.cuda.get_device_capability(torch.cuda.current_device())[0] + if cap == 9: + return mxfp8_gemm_act_sm90(*args, **kwargs) + raise NotImplementedError(f"mxfp8_gemm_act: sm_{cap}0 not supported (SM90 only)") + + +# --------------------------------------------------------------------------- +# Reference / utility +# --------------------------------------------------------------------------- + + +def mxfp8_gemm_quantize_sm100( A: Tensor, B: Tensor, out: Optional[Tensor] = None, @@ -264,11 +1133,11 @@ def mxfp8_gemm_quantize( """High-level: quantize bf16 A, B_as_NK to MXFP8, then run C = A @ B_as_NK.mT. Inputs: A=(M,K)/(L,M,K), B_as_NK=(N,K)/(L,N,K) bf16/fp32. Quantization scales along the last (K) dim. Returned output has shape (M,N)/(L,M,N).""" - A_q, A_sc = mxfp8_quantize(A) - B_q, B_sc = mxfp8_quantize(B) + A_q, A_sc = quantize_act(A) + B_q, B_sc = quantize_act(B) # B_q, B_sc are (..., N, K) / (..., N, K/32). 
Flip to (..., K, N) / (..., K/32, N) # K-contig zero-copy views to match the interface convention. - return mxfp8_gemm( + return mxfp8_gemm_sm100( A_q, B_q.mT, A_sc, @@ -280,7 +1149,7 @@ def mxfp8_gemm_quantize( ) -def mxfp8_gemm_cublas( +def mxfp8_gemm_cublas_sm100( A: Tensor, B: Tensor, A_scale: Tensor, @@ -288,13 +1157,15 @@ def mxfp8_gemm_cublas( out_dtype: torch.dtype = torch.bfloat16, ) -> Tensor: """Reference path via torch._scaled_mm. Requires l=1 (or 2D inputs).""" - m, n, k, l, _mA, _mB, sc_A, sc_B, _sfa, _sfb, was_2d = _to_kernel_layout(A, B, A_scale, B_scale) + m, n, k, l, _mA, _mB, sc_A, sc_B, _sfa, _sfb, was_2d = _to_kernel_layout_sm100( + A, B, A_scale, B_scale + ) assert l == 1, "torch._scaled_mm MXFP8 path is 2D only; pass 2D inputs or l=1" # torch._scaled_mm: A=(M,K) row-major, B=(K,N) col-major (both K-contig) -- same layout user gave us. a2d = A if A.dim() == 2 else A.squeeze(0) b2d = B if B.dim() == 2 else B.squeeze(0) - sca = scale_blocked_for_cublas(sc_A, m, k // _SF_VEC_SIZE, 0) - scb = scale_blocked_for_cublas(sc_B, n, k // _SF_VEC_SIZE, 0) + sca = scale_blocked_for_cublas(sc_A, m, k // _SF_VEC_SIZE_SM100, 0) + scb = scale_blocked_for_cublas(sc_B, n, k // _SF_VEC_SIZE_SM100, 0) out = torch._scaled_mm( a2d, b2d, @@ -305,7 +1176,7 @@ def mxfp8_gemm_cublas( return out if was_2d else out.unsqueeze(0) -def mxfp8_gemm_ref( +def mxfp8_gemm_ref_sm100( A: Tensor, B: Tensor, A_scale: Tensor, @@ -320,7 +1191,7 @@ def mxfp8_gemm_ref( B3 = _as_3d(B, B.dim()).mT.contiguous().float() as3 = _as_3d(A_scale, A_scale.dim()).float() bs3 = _as_3d(B_scale, B_scale.dim()).mT.contiguous().float() - a_dq = A3 * as3.repeat_interleave(_SF_VEC_SIZE, dim=-1) - b_dq = B3 * bs3.repeat_interleave(_SF_VEC_SIZE, dim=-1) + a_dq = A3 * as3.repeat_interleave(_SF_VEC_SIZE_SM100, dim=-1) + b_dq = B3 * bs3.repeat_interleave(_SF_VEC_SIZE_SM100, dim=-1) out3 = torch.einsum("lmk,lnk->lmn", a_dq, b_dq).to(out_dtype) return out3.squeeze(0) if was_2d else out3 diff --git 
a/quack/gemm_sm90.py b/quack/gemm_sm90.py index 423e5d9a..4bc26cd0 100644 --- a/quack/gemm_sm90.py +++ b/quack/gemm_sm90.py @@ -123,6 +123,8 @@ def __init__( use_clc_persistence: bool = False, concat_layout: tuple | None = None, use_pdl: bool = True, + sf_vec_size: Optional[int] = None, + weight_n_block: Optional[int] = None, ): """ Initializes the configuration for a Hopper dense GEMM kernel. @@ -140,6 +142,9 @@ def __init__( self.acc_dtype = acc_dtype self.pingpong = pingpong + self.sf_vec_size = sf_vec_size + self.weight_n_block = weight_n_block + self.blockscaled = sf_vec_size is not None self.is_persistent = is_persistent self.use_clc_persistence = use_clc_persistence if self.use_clc_persistence: @@ -147,7 +152,7 @@ def __init__( self.use_pdl = use_pdl if self.pingpong: assert self.is_persistent, "Pingpong gemm requires persistent scheduler" - self.fp8_slow_accum = not fp8_fast_accum and a_dtype.width == 8 + self.fp8_slow_accum = not fp8_fast_accum and a_dtype.width == 8 or self.blockscaled self.gather_A = gather_A self.concat_layout = concat_layout or () if gather_A: @@ -228,8 +233,16 @@ def __init__( regs_per_thread = math.prod(self.cta_tile_shape_mnk[:2]) // ( math.prod(self.atom_layout_mnk) * self.num_threads_per_warp_group ) + self.fp8_partial_acc = False if self.fp8_slow_accum: - regs_per_thread *= 2 + if self.blockscaled: + assert tile_M % 128 == 0 + if tile_M >= 256: + assert math.prod(self.atom_layout_mnk[1:]) == 1 + regs_per_thread += regs_per_thread // 2 + self.fp8_partial_acc = True + else: + regs_per_thread *= 2 if not self.gather_A: if self.mma_warp_groups == 3: self.num_regs_load, self.num_regs_mma = 32, 160 @@ -290,6 +303,8 @@ def _setup_tiled_mma(self): tile_k = ( self.cta_tile_shape_mnk[2] if self.cta_tile_shape_mnk[2] > 0 else mma_inst_shape_k * 4 ) + if self.blockscaled: + assert self.sf_vec_size == tile_k == self.cta_tile_shape_mnk[1] assert tile_k > 0, "CTA tile K must be positive" assert tile_k % mma_inst_shape_k == 0, ( f"CTA tile K 
({tile_k}) must be divisible by MMA instruction K ({mma_inst_shape_k})" @@ -325,7 +340,11 @@ def _setup_attributes(self, epilogue_args: EpilogueArguments): ) self.epi_tile_shape = cute.ceil_div(self.cta_tile_shape_mnk[:2], self.epi_tile) - # Compute stage before compute smem layout + # Compute stage before compute smem layout. SFA staging (mxfp8) needs + # BLOCK_M * 4 extra bytes per stage so reduce ab_stage accordingly. + sfa_bytes_per_stage = ( + self.cta_tile_shape_mnk[0] * 4 if (self.blockscaled and not self.gather_A) else 0 + ) self.ab_stage, self.epi_stage, self.epi_c_stage = self._compute_stages( self.cta_tile_shape_mnk, self.epi_tile, @@ -337,6 +356,7 @@ def _setup_attributes(self, epilogue_args: EpilogueArguments): cutlass.utils.get_smem_capacity_in_bytes(f"sm_{self.arch}"), # smem_capacity self.occupancy, self.epi_smem_warp_shape_mnk(), + sfa_bytes_per_stage=sfa_bytes_per_stage, ) self.sched_stage = 2 if self.pingpong else 1 @@ -345,6 +365,7 @@ def _setup_attributes(self, epilogue_args: EpilogueArguments): self.b_smem_layout_staged, self.epi_smem_layout_staged, self.epi_c_smem_layout_staged, + self.sfa_smem_layout_staged, ) = self._make_smem_layouts( self.cta_tile_shape_mnk, self.epi_tile, @@ -359,6 +380,7 @@ def _setup_attributes(self, epilogue_args: EpilogueArguments): self.c_dtype, self.c_layout, self.epi_c_stage, + sfa_staged=self.blockscaled and not self.gather_A, ) @cute.jit @@ -372,6 +394,8 @@ def __call__( scheduler_args: TileSchedulerOptions, varlen_args: Optional[VarlenArguments], stream: cuda.CUstream, + mSFA: Optional[cute.Tensor] = None, + mSFB: Optional[cute.Tensor] = None, trace_ptr: Optional[cutlass.Int64] = None, ): """Execute the GEMM operation in steps: @@ -434,6 +458,22 @@ def __call__( if const_expr(not self.gather_A): self.num_tma_load_bytes += cute.size_in_bytes(self.a_dtype, a_smem_layout) + # SFA TMA atom: stage activation scales into smem alongside A/B per K-block. 
+ # mSFA arrives as (M, SF_K, L) (non-varlen, permuted in dispatch) or + # (total_M, SF_K) (varlen). The atom's smem_tile is (BLOCK_M, 1) — one column + # of sf_k per K-stage. mcast=1 since per-CTA SFA is small. + tma_atom_sfa, tma_tensor_sfa = None, None + sfa_smem_layout = None + if const_expr(self.blockscaled and not self.gather_A): + sfa_smem_layout = cute.slice_(self.sfa_smem_layout_staged, (None, None, 0)) + tma_atom_sfa, tma_tensor_sfa = self._make_tma_atoms_and_tensors( + mSFA, + sfa_smem_layout, + (self.cta_tile_shape_mnk[0], 1), + self.cluster_shape_mnk[1], + ) + self.num_tma_load_bytes += cute.size_in_bytes(Float32, sfa_smem_layout) + tma_atom_d, tma_tensor_d, tma_atom_c, tma_tensor_c = ( self.make_tma_epilogue_atoms_and_tensors(mD, mC, epilogue_args, varlen_m) ) @@ -462,6 +502,16 @@ def __call__( epi_smem_size = cute.cosize(self.epi_smem_layout_staged) if mD is not None else 0 epi_c_smem_size = cute.cosize(self.epi_c_smem_layout_staged) if mC is not None else 0 + sfa_smem_size = ( + cute.cosize(self.sfa_smem_layout_staged) + if (self.blockscaled and not self.gather_A) + else 0 + ) + # Fixed upper bound: 512 floats = 2KB. Covers K up to 64K with BLOCK_K=128. + # Used to stage mSFB into smem once per output block (DeepGEMM-style), + # eliminating the per-k_tile gmem dependency on the wgmma issue path. 
+ SFB_SMEM_MAX = 512 + sfb_smem_size = SFB_SMEM_MAX if (self.blockscaled and not self.pingpong) else 0 @cute.struct class SharedStorage: @@ -490,6 +540,14 @@ class SharedStorage: cute.struct.MemRange[self.b_dtype, cute.cosize(self.b_smem_layout_staged)], self.buffer_align_bytes, ] + sSFA: cute.struct.Align[ + cute.struct.MemRange[Float32, sfa_smem_size], + self.buffer_align_bytes, + ] + sSFB: cute.struct.Align[ + cute.struct.MemRange[Float32, sfb_smem_size], + self.buffer_align_bytes, + ] self.shared_storage = SharedStorage @@ -511,8 +569,12 @@ class SharedStorage: self.b_smem_layout_staged, self.epi_smem_layout_staged, self.epi_c_smem_layout_staged, + self.sfa_smem_layout_staged, tile_sched_params, TileSchedulerCls, + tma_atom_sfa, + tma_tensor_sfa if (self.blockscaled) else mSFA, + mSFB, trace_ptr, ).launch( grid=grid, @@ -544,8 +606,12 @@ def kernel( b_smem_layout: cute.ComposedLayout, epi_smem_layout: cute.ComposedLayout, epi_c_smem_layout: cute.ComposedLayout, + sfa_smem_layout: Optional[cute.Layout], tile_sched_params, TileSchedulerCls: cutlass.Constexpr[Callable], + tma_atom_sfa: Optional[cute.CopyAtom] = None, + mSFA: Optional[cute.Tensor] = None, + mSFB: Optional[cute.Tensor] = None, trace_ptr: Optional[cutlass.Int64] = None, ): """ @@ -627,6 +693,14 @@ def kernel( # Generate smem tensor A/B sA = storage.sA.get_tensor(a_smem_layout.outer, swizzle=a_smem_layout.inner) sB = storage.sB.get_tensor(b_smem_layout.outer, swizzle=b_smem_layout.inner) + sSFA = None + if const_expr(sfa_smem_layout is not None): + # Plain layout (no swizzle) — get_tensor without the swizzle kwarg. + sSFA = storage.sSFA.get_tensor(sfa_smem_layout) + # Per-output-block scale_b staged into smem (non-pingpong only for now). 
+ sSFB = None + if const_expr(self.blockscaled and not self.pingpong): + sSFB = storage.sSFB.get_tensor(cute.make_layout((512,))) sD = None if const_expr(has_D): sD = storage.sD.get_tensor(epi_smem_layout.outer, swizzle=epi_smem_layout.inner) @@ -729,11 +803,40 @@ def kernel( dst_tensor=sB, mcast_mask=b_mcast_mask, ) + # SFA: TMA-load activation scales staged alongside A/B per K-block. + # mSFA is (M, SF_K, L) for non-varlen (permuted in dispatch) or + # (total_M, SF_K) for varlen. After offset_batch_SFA we have (M, SF_K). + copy_SFA = None + if const_expr(sSFA is not None): + if const_expr(varlen_m): + mSFA_mk = cute.domain_offset( + (varlen_params.cu_seqlens_m[batch_idx], 0), mSFA + ) + else: + mSFA_mk = mSFA[None, None, batch_idx] + gSFA_mk = cute.local_tile( + mSFA_mk, + (self.cta_tile_shape_mnk[0], 1), + (tile_coord_mnkl[0], None), + ) + copy_SFA, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_sfa, + cta_coord=block_in_cluster_coord_mnk[1], + cta_layout=cute.make_layout( + cute.slice_(cluster_layout_mnk, (0, None, 0)).shape + ), + src_tensor=gSFA_mk, + dst_tensor=sSFA, + mcast_mask=a_mcast_mask, + ) len_k = varlen_manager.len_k(batch_idx) k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2]) if const_expr(not self.gather_A): + copy_fns = [copy_A, copy_B] + if const_expr(copy_SFA is not None): + copy_fns.append(copy_SFA) ab_producer_state = self.load_tma( - ab_pipeline, ab_producer_state, [copy_A, copy_B], k_tile_cnt + ab_pipeline, ab_producer_state, copy_fns, k_tile_cnt ) else: ab_producer_state = self.load_AB_gather_A( @@ -785,7 +888,9 @@ def kernel( acc_slow = None if const_expr(self.fp8_slow_accum): acc_slow = cute.make_rmem_tensor(acc.shape, self.acc_dtype) - mma_fn = partial(quack_sm90_utils.gemm_w_idx, tiled_mma, acc, tCrA, tCrB) + if const_expr(self.fp8_partial_acc): + partial_acc_shape = (acc.shape[0], acc.shape[1] // 2, acc.shape[2]) + acc_slow = cute.make_rmem_tensor(partial_acc_shape, self.acc_dtype) if const_expr(self.pingpong): if 
warp_group_idx == 0: @@ -829,10 +934,44 @@ def kernel( k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2]) if const_expr(self.pingpong): self.pingpong_barrier_sync(warp_group_idx, stage="mma") + # Stage scale_b for this output block into smem (non-pingpong only). + # Math threads cooperatively load mSFB[batch, n_tile, 0..k_tile_cnt-1] + # into sSFB, then sync the math WGs via the epilogue_barrier (256 threads). + if const_expr(self.blockscaled and not self.pingpong): + threads_per_math = self.mma_warp_groups * self.num_threads_per_warp_group + # Each pass covers `threads_per_math` scales; 512 // threads_per_math passes (e.g. 2 when threads_per_math == 256) cover k_tile_cnt up to 512. NOTE(review): k_tile_cnt > 512 is not guarded in this loop; confirm the host enforces it. + NUM_SFB_PASSES = const_expr(512 // threads_per_math) + for pass_idx in cutlass.range_constexpr(NUM_SFB_PASSES): + kk = tidx + pass_idx * threads_per_math + if kk < k_tile_cnt: + sSFB[kk] = mSFB[batch_idx, tile_coord_mnkl[1], kk] + self.epilogue_barrier.arrive_and_wait() tctx.b("mma") - ab_read_state = self.mma( - ab_pipeline, ab_read_state, mma_fn, acc, acc_slow, k_tile_cnt, warp_group_idx - ) + if const_expr(self.blockscaled): + ab_read_state = self.mma_blockscaled( + ab_pipeline, + ab_read_state, + partial(quack_sm90_utils.gemm_w_idx, tiled_mma, acc_slow, tCrB=tCrB), + acc, + acc_slow, + tCrA, + tile_coord_mnkl[1], + k_tile_cnt, + warp_group_idx, + sSFA=sSFA, + sSFB=sSFB, + ) + else: + mma_fn = partial(quack_sm90_utils.gemm_w_idx, tiled_mma, acc, tCrA, tCrB) + ab_read_state = self.mma( + ab_pipeline, + ab_read_state, + mma_fn, + acc, + acc_slow, + k_tile_cnt, + warp_group_idx, + ) if const_expr(varlen_k): if k_tile_cnt == 0: acc.fill(0.0) @@ -1130,6 +1269,85 @@ def mma( acc.store(acc_slow.load()) return ab_read_state + @cute.jit + def mma_blockscaled( + self, + ab_pipeline: cutlass.pipeline.PipelineAsync, + ab_read_state: cutlass.pipeline.PipelineState, + mma_fn: Callable, + acc: cute.Tensor, + acc_slow: cute.Tensor, + tCrA: cute.Tensor, + n_tile_coord: Int32, + k_tile_cnt: Int32, + warp_group_idx: Int32, + sSFA: 
Optional[cute.Tensor] = None, + sSFB: Optional[cute.Tensor] = None, + ) -> cutlass.pipeline.PipelineState: + tidx, _, _ = cute.arch.thread_idx() + tidx_wg = tidx % Int32(self.num_threads_per_warp_group) + + # acc contains the output of a single warp group + # each warp is responsible for 16 rows + # See: https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-register-fragment-wgmma-64n32 + thread_m_offset = Int32(16) * (tidx_wg // Int32(32)) + ((tidx_wg % Int32(32)) // Int32(4)) + + MMA_M = const_expr(acc.shape[1]) + # Sanity check that MMA_M of acc is the same as the MMA_M of tCrA + assert MMA_M == tCrA.shape[1] + wgmma_m = const_expr(cute.size(self.tiled_mma.shape_mnk, mode=[0])) + if const_expr(self.atom_layout_mnk[0] > 1 and not self.pingpong): + wg_m_offset = warp_group_idx * Int32(wgmma_m) + else: + wg_m_offset = Int32(0) + m0 = wg_m_offset + thread_m_offset + m1 = wg_m_offset + thread_m_offset + Int32(8) + + ab_release_state = ab_read_state.clone() + acc.fill(0.0) + + scales = cute.make_rmem_tensor(cute.make_layout((2,)), acc.dtype) + for k_tile in cutlass.range(k_tile_cnt, unroll=8): + peek_full = ab_pipeline.consumer_try_wait(ab_read_state) + scale_b = sSFB[k_tile] + ab_pipeline.consumer_wait(ab_read_state, peek_full) + for m_idx in cutlass.range_constexpr(MMA_M): + curr_tCrA = layout_utils.expand(tCrA[None, m_idx, None, None], dim=1, size=1) + curr_acc = layout_utils.expand(acc[None, m_idx, None], dim=1, size=1) + m_off = Int32(m_idx * self.atom_layout_mnk[0] * wgmma_m) + stage = ab_read_state.index + scales[0] = sSFA[m0 + m_off, 0, stage] * scale_b + scales[1] = sSFA[m1 + m_off, 0, stage] * scale_b + mma_fn( + tCrA=curr_tCrA, + A_idx=stage, + B_idx=stage, + zero_init=True, + ) + # Each m_idx iteration is one tiled-MMA step further down M. 
+ # Stride = atom_layout_m * atom_m (= 2*64 = 128 for tile_m=256, atom_layout_m=2). + + warpgroup.wait_group(0) + # Release the stage only after wait_group(0): the wgmma issued above may still be reading this stage's smem until then. + if const_expr(m_idx == MMA_M - 1): + ab_pipeline.consumer_release(ab_release_state) + + # Broadcast layout over acc_slow.shape with stride ((0, 1, 0), 0, 0) -> offset = b, + # where b (the second atom mode) picks scales[0] (row m0) vs scales[1] (row m1 = m0 + 8). + scales_bcast = cute.make_tensor( + scales.iterator, + cute.make_layout( + acc_slow.shape, + stride=((0, 1, 0), 0, 0), + ), + ) + curr_acc.store(curr_acc.load() + acc_slow.load() * scales_bcast.load()) + ab_read_state.advance() + + ab_release_state.advance() + + return ab_read_state + def epi_retile_acc(self, acc, tRS_rD, tiled_copy_r2s): """Retile accumulator for epilogue subtile access.""" acc_reshaped = layout_utils.reshape_acc_to_frgA(acc) # ((2, 2, 2), MMA_M, MMA_N) @@ -1254,6 +1472,7 @@ def _compute_stages( smem_capacity: int, occupancy: int, warp_shape_mnk: Tuple[int, int, int] | None = None, + sfa_bytes_per_stage: int = 0, ) -> Tuple[int, int]: """Computes the number of stages for A/B/C operands based on heuristics. @@ -1293,7 +1512,9 @@ def _compute_stages( a_shape = cute.slice_(cta_tile_shape_mnk, (None, 0, None)) b_shape = cute.slice_(cta_tile_shape_mnk, (0, None, None)) ab_bytes_per_stage = ( - cute.size(a_shape) * a_dtype.width // 8 + cute.size(b_shape) * b_dtype.width // 8 + cute.size(a_shape) * a_dtype.width // 8 + + cute.size(b_shape) * b_dtype.width // 8 + + sfa_bytes_per_stage # mxfp8: BLOCK_M floats of activation scale per stage ) mbar_helpers_bytes = 1024 @@ -1363,8 +1584,13 @@ def _make_smem_layouts( c_dtype: Optional[Type[cutlass.Numeric]], c_layout: Optional[LayoutEnum], epi_c_stage: int, + sfa_staged: bool = False, ) -> Tuple[ - cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout, Optional[cute.ComposedLayout] + cute.ComposedLayout, + cute.ComposedLayout, + cute.ComposedLayout, + Optional[cute.ComposedLayout], + Optional[cute.ComposedLayout], ]: """Create shared memory layouts for A, B, and C tensors. 
@@ -1433,11 +1659,20 @@ def _make_smem_layouts( c_dtype, c_layout, epi_tile, epi_c_stage ) + sfa_smem_layout_staged = None + if sfa_staged: + BM = cta_tile_shape_mnk[0] + sfa_smem_layout_staged = cute.make_layout( + (BM, 1, ab_stage), + stride=(1, BM, BM), + ) + return ( a_smem_layout_staged, b_smem_layout_staged, epi_smem_layout_staged, epi_c_smem_layout_staged, + sfa_smem_layout_staged, ) @staticmethod diff --git a/quack/gemm_tvm_ffi_utils.py b/quack/gemm_tvm_ffi_utils.py index 8f687df4..061ffbb5 100644 --- a/quack/gemm_tvm_ffi_utils.py +++ b/quack/gemm_tvm_ffi_utils.py @@ -214,7 +214,7 @@ def compile_gemm_kernel( if post_init: post_init(gemm_obj) stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) - sf_args = () if device_capacity[0] in (8, 9, 12) else (mSFA, mSFB) + sf_args = () if device_capacity[0] in (8, 12) else (mSFA, mSFB) # Trace pointer: Optional[Int64]. Compile with Int64(0) when tracing is # requested, None otherwise. TVM-FFI caches each variant separately. trace_ptr = Int64(0) if has_trace_ptr else None diff --git a/quack/mx_utils.py b/quack/mx_utils.py index 5184bc92..f06949bb 100644 --- a/quack/mx_utils.py +++ b/quack/mx_utils.py @@ -232,11 +232,49 @@ def to_nvfp4(x: torch.Tensor, block_size: int = 16, per_tensor_scale=None): return data_lp, block_scale_fp8, returned_pts +def to_mx_2d(data_hp: torch.Tensor, block_rows: int = 128, block_cols: int = 128): + """MXFP8-e4m3 quantization with 2D (block_rows × block_cols) FLOOR scaling. + + Each (block_rows, block_cols) tile shares one E8M0 scale. + + Args: + data_hp: (N, K) bf16 or fp32, contiguous. N % block_rows == 0, K % block_cols == 0. 
+ Returns: + qdata: (N, K) float8_e4m3fn + scale: (N // block_rows, K // block_cols) float32 + """ + assert data_hp.dtype in (torch.bfloat16, torch.float32) + assert data_hp.ndim == 2, "to_mx_2d requires a 2D (N, K) input" + N, K = data_hp.shape + assert N % block_rows == 0 and K % block_cols == 0 + assert data_hp.is_contiguous() + + # Reshape to (N//block_rows, block_rows, K//block_cols, block_cols) and + # reduce max over the two inner (tile) dims. + blocked = data_hp.float().reshape(N // block_rows, block_rows, K // block_cols, block_cols) + max_abs = torch.amax(torch.abs(blocked), dim=(1, 3), keepdim=True) # (nr, 1, nk, 1) + + scale_e8m0_biased = _compute_e8m0_scale_floor(max_abs, F8E4M3_MAX_POW2) # (nr, 1, nk, 1) + scale_fp32 = (torch.bitwise_left_shift(scale_e8m0_biased.to(torch.int32), MBITS_F32)).view( + torch.float32 + ) + scale_fp32 = torch.clamp(scale_fp32, min=F32_MIN_NORMAL) + + data_lp = blocked / scale_fp32 + if not torch.compiler.is_compiling(): + data_lp = torch.clamp(data_lp, min=-F8E4M3_MAX, max=F8E4M3_MAX) + + qdata = data_lp.to(torch.float8_e4m3fn).reshape(N, K) + scale = scale_fp32.squeeze(1).squeeze(-1) # (nr, nk) + return qdata, scale + + # --------------------------------------------------------------------------- # torch.compile-wrapped fast paths. Generates fused Triton quant kernels via # Inductor. dynamic=True avoids recompilation on shape changes. 
# --------------------------------------------------------------------------- to_mx_compiled = torch.compile(to_mx, dynamic=True) +to_mx_2d_compiled = torch.compile(to_mx_2d, dynamic=True) to_mxfp4_compiled = torch.compile(to_mxfp4, dynamic=True) to_nvfp4_compiled = torch.compile(to_nvfp4, dynamic=True) diff --git a/tests/test_gemm_sm100_blockscaled.py b/tests/test_gemm_sm100_blockscaled.py index 315697d8..49d74814 100644 --- a/tests/test_gemm_sm100_blockscaled.py +++ b/tests/test_gemm_sm100_blockscaled.py @@ -571,10 +571,10 @@ def test_mxfp8_interface(shape_mnk, batched): _skip_if_not_sm100() from quack.gemm_blockscaled_interface import ( mxfp8_gemm, - mxfp8_gemm_cublas, - mxfp8_gemm_ref, - mxfp8_gemm_quantize, - mxfp8_quantize, + mxfp8_gemm_cublas_sm100, + mxfp8_gemm_ref_sm100, + mxfp8_gemm_quantize_sm100, + quantize_act_sm100, ) M, N, K = shape_mnk @@ -586,8 +586,8 @@ def test_mxfp8_interface(shape_mnk, batched): A_hp = torch.randn(*shape_A, device="cuda", dtype=torch.bfloat16) * K**-0.5 W_hp = torch.randn(*shape_W, device="cuda", dtype=torch.bfloat16) * K**-0.5 - A_q, A_sc = mxfp8_quantize(A_hp) - W_q, W_sc = mxfp8_quantize(W_hp) # (..., N, K), (..., N, K/32) + A_q, A_sc = quantize_act_sm100(A_hp) + W_q, W_sc = quantize_act_sm100(W_hp) # (..., N, K), (..., N, K/32) assert A_q.dtype == torch.float8_e4m3fn and A_sc.dtype == torch.float8_e8m0fnu B_q = W_q.mT # (..., K, N) K-contig view @@ -597,17 +597,17 @@ def test_mxfp8_interface(shape_mnk, batched): assert out.shape == ((L, M, N) if batched else (M, N)) assert out.dtype == torch.bfloat16 - ref = mxfp8_gemm_ref(A_q, B_q, A_sc, B_sc) + ref = mxfp8_gemm_ref_sm100(A_q, B_q, A_sc, B_sc) err = (out.float() - ref.float()).abs().max().item() assert err < 5e-3, f"quack vs ref max_err={err}" # cuBLAS comparison only for 2D / L=1 if not batched: - out_cublas = mxfp8_gemm_cublas(A_q, B_q, A_sc, B_sc) + out_cublas = mxfp8_gemm_cublas_sm100(A_q, B_q, A_sc, B_sc) assert torch.equal(out, out_cublas), "quack interface != 
cuBLAS" # High-level quantize+gemm convenience fn - out2 = mxfp8_gemm_quantize(A_hp, W_hp) + out2 = mxfp8_gemm_quantize_sm100(A_hp, W_hp) assert torch.equal(out, out2) @@ -889,14 +889,14 @@ def test_blockscaled_mxfp8_strided_sf(rk_pad): def test_mxfp8_interface_preallocated_out(): _skip_if_not_sm100() - from quack.gemm_blockscaled_interface import mxfp8_gemm, mxfp8_quantize + from quack.gemm_blockscaled_interface import mxfp8_gemm, quantize_act_sm100 M, N, K = 256, 256, 256 torch.manual_seed(0) A_hp = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) * K**-0.5 W_hp = torch.randn(N, K, device="cuda", dtype=torch.bfloat16) * K**-0.5 - A_q, A_sc = mxfp8_quantize(A_hp) - W_q, W_sc = mxfp8_quantize(W_hp) + A_q, A_sc = quantize_act_sm100(A_hp) + W_q, W_sc = quantize_act_sm100(W_hp) B_q, B_sc = W_q.mT, W_sc.mT out_alloc = mxfp8_gemm(A_q, B_q, A_sc, B_sc) diff --git a/tests/test_gemm_sm90_mxfp8.py b/tests/test_gemm_sm90_mxfp8.py new file mode 100644 index 00000000..e4b01b66 --- /dev/null +++ b/tests/test_gemm_sm90_mxfp8.py @@ -0,0 +1,212 @@ +import math + +import pytest +import torch + +from quack.gemm_blockscaled_interface import ( + _SF_VEC_SIZE_SM90 as SF, + _WEIGHT_BLOCK_N_SM90 as BN, + mxfp8_gemm_act, + quantize_act_sm90, + quantize_weight_sm90, +) +from quack.gemm_interface import gemm_gated_ref + + +def _skip_if_not_sm90(): + if torch.cuda.get_device_properties(0).major != 9: + pytest.skip("SM90 required") + + +def deepseek_calc_diff(x: torch.Tensor, y: torch.Tensor) -> float: + """Cosine-style difference 1 - 2<x,y>/(|x|^2 + |y|^2); 0.0 means identical (not a similarity). 
Copied from DeepGEMM + https://github.com/deepseek-ai/DeepGEMM/blob/891d57b4db1071624b5c8fa0d1e51cb317fa709f/deep_gemm/testing/numeric.py#L5 + """ + x, y = x.double(), y.double() + denom = (x * x + y * y).sum() + return 0.0 if denom == 0 else float(1 - 2 * (x * y).sum() / denom) + + +def _fp8_dequant_ref(A_q, A_sc, W_q, W_sc): + """Dequantize FP8 operands to float32; returns (A_dq, transposed W_dq) and performs no matmul itself.""" + A_dq = A_q.float() * A_sc.float().repeat_interleave(SF, dim=-1) + W_sc_exp = W_sc.float().repeat_interleave(BN, dim=-2).repeat_interleave(SF, dim=-1) + W_dq = W_q.float() * W_sc_exp + return A_dq, W_dq.mT if W_q.ndim > 2 else W_dq.T + + +def _assert_cos_close(kernel_out, ref, tag, tol=0.001): + """Cosine-complement check only (DeepGEMM style).""" + cos = deepseek_calc_diff(kernel_out.float(), ref) + assert cos < tol, f"{tag}: cosine_diff={cos:.6f} >= {tol}" + + +def _make_varied_scale_inputs(M, K, N, *, dtype=torch.bfloat16, device="cuda"): + """Build (A, W) bf16 tensors whose MXFP8 scales differ across every relevant + indexing axis. Plain randn quantizes to nearly-uniform power-of-2 scales, masking + indexing bugs (wrong-row / wrong-K-block / wrong-m_block reads still produce + numerically plausible results). + + Scaling scheme — designed to defeat *every* indexing aliasing pattern: + - Per-row factor: 2^(i % 4) cycles every 4 rows (catches per-row off-by-N bugs) + - Per-64-row-chunk factor: 2^((i // 64) % 4) (catches off-by-BLOCK_M aliasing, + since row 0 and row 128 land in different chunks → different factors) + - Per-K-block factor: 2^((k * 3) % 4) (catches wrong-stage / wrong-K reads; + stride 3 coprime with 4 spreads scales) + - Same scheme for W = (2*N, K) at the 128-row N-block granularity that + matches `_WEIGHT_BLOCK_N_SM90`. + + Power-of-2 ranges chosen so the dequantized inputs stay inside fp8_e4m3fn + dynamic range and the K-summed outputs stay bf16-representable. 
+ """ + sf_k = K // SF + base_a = torch.randn(M, K, device=device, dtype=dtype) / math.sqrt(K) + base_w = torch.randn(2 * N, K, device=device, dtype=dtype) / math.sqrt(K) + + rows = torch.arange(M, device=device) + # Combined exponent = (row % 4) + (row // 64) % 4: 4*4 = 16 combinations, + # max factor 2^6 = 64 → max input magnitude ~64 * 1/sqrt(K), still safe for fp8. + a_row = (2.0 ** ((rows % 4) + ((rows // 64) % 4))).to(dtype) + ks = torch.arange(sf_k, device=device) + a_kb = (2.0 ** ((ks * 3) % 4)).to(dtype) + n_blk = (2 * N) // BN + n_idx = torch.arange(n_blk, device=device) + # Per-N-block factor uses BLOCK_N=128 granularity; combine outer chunk to + # break aliasing across N-tiles when N > BLOCK_N. + w_nb = (2.0 ** ((n_idx % 4) + ((n_idx // 4) % 4))).to(dtype) + w_kb = (2.0 ** ((ks * 5) % 4)).to(dtype) + + a = base_a.view(M, sf_k, SF) * a_kb.view(1, sf_k, 1) + a = a * a_row.view(M, 1, 1) + a = a.reshape(M, K) + + w = base_w.view(n_blk, BN, sf_k, SF) * w_kb.view(1, 1, sf_k, 1) + w = w * w_nb.view(n_blk, 1, 1, 1) + w = w.reshape(2 * N, K) + return a, w + + +# --------------------------------------------------------------------------- +# Batched (no varlen) +# --------------------------------------------------------------------------- +@pytest.mark.parametrize("store_preact", [True, False]) +@pytest.mark.parametrize("activation", ["swiglu", "geglu"]) +@pytest.mark.parametrize( + "M, K, N", + [ + (512, 2048, 1024), + # TODO: M=1 fails with varied scales (cosine_diff ~0.87). Real kernel bug + # at M