diff --git a/AI/varlen_blockscaled_sf_layout.md b/AI/varlen_blockscaled_sf_layout.md index 7208c895..ec8ab597 100644 --- a/AI/varlen_blockscaled_sf_layout.md +++ b/AI/varlen_blockscaled_sf_layout.md @@ -152,7 +152,7 @@ a tile-unit offset integer. ## Tests -`tests/test_gemm_sm100_blockscaled.py`: +`tests/test_gemm_blockscaled.py`: - `test_blockscaled_mxfp8_varlen_m_nonaligned` — 4 seqlen patterns × 2 B-majors = 8 cases. Patterns include `[128, 128, 128]`, `[100, 200, 150]`, `[30, 300, 64, 200]`, `[1, 128, 127, 129]`. diff --git a/benchmarks/benchmark_gemm.py b/benchmarks/benchmark_gemm.py index e71a923d..ffb9acd7 100644 --- a/benchmarks/benchmark_gemm.py +++ b/benchmarks/benchmark_gemm.py @@ -1,4 +1,5 @@ import argparse +import os import time import torch @@ -7,8 +8,9 @@ from quack.gemm import gemm as quack_gemm """ -GEMM benchmark using quack.gemm.gemm() (dense path) or the SM100 blockscaled -path (MXFP8 / MXFP4 / NVFP4) via --blockscaled. +GEMM benchmark using quack.gemm.gemm() (dense path) or the blockscaled +path (MXFP8 / MXFP4 / NVFP4). The blockscaled path is selected by passing +--sf_dtype and/or --sf_vec_size. Usage (dense): python benchmarks/benchmark_gemm.py --mnkl 512,7168,2048,256 \ @@ -17,18 +19,24 @@ Usage (blockscaled MXFP8, with cuBLAS comparison): python benchmarks/benchmark_gemm.py --mnkl 4096,4096,4096,1 \ - --blockscaled --ab_dtype Float8E4M3FN --sf_dtype Float8E8M0FNU \ - --sf_vec_size 32 --init quant --compare_cublas + --ab_dtype Float8E4M3FN --sf_dtype Float8E8M0FNU --sf_vec_size 32 Usage (blockscaled MXFP4): python benchmarks/benchmark_gemm.py --mnkl 4096,4096,4096,1 \ - --blockscaled --ab_dtype Float4E2M1FN --sf_dtype Float8E8M0FNU \ - --sf_vec_size 32 --d_dtype Float32 + --ab_dtype Float4E2M1FN --sf_dtype Float8E8M0FNU \ + --sf_vec_size 32 --d_dtype BFloat16 Usage (blockscaled NVFP4): python benchmarks/benchmark_gemm.py --mnkl 4096,4096,4096,1 \ - --blockscaled --ab_dtype Float4E2M1FN --sf_dtype Float8E4M3FN \ - --sf_vec_size 16 --d_dtype Float32 + --ab_dtype Float4E2M1FN --sf_dtype Float8E4M3FN \ + --sf_vec_size 16 --d_dtype BFloat16 + +Usage (SM120 packed-LDSM NVFP4 performance path): + QUACK_SM120_BLOCKSCALED_PACKED_LDSM=1 CUTE_DSL_ARCH=sm_120a \ + python benchmarks/benchmark_gemm.py --mnkl 4096,4096,4096,1 \ + --tile_shape_mnk 64,64,128 --cluster_shape_mnk 1,1,1 \ + --ab_dtype Float4E2M1FN --sf_dtype Float8E4M3FN \ + --sf_vec_size 16 --d_dtype BFloat16 --skip_ref_check """ @@ -194,19 +202,21 @@ def _run_blockscaled(args): compile_blockscaled_gemm_tvm_ffi, create_blockscaled_operand_quantized, create_blockscaled_operand_tensor, + create_sm120_blockscaled_scale_tensor, create_blockscaled_varlen_m_operands, scale_blocked_for_cublas, torch_dtype_for_cutlass, ) from quack.cute_dsl_utils import get_device_capacity - from quack.gemm_default_epi import GemmDefaultSm100 + from quack.gemm_default_epi import GemmDefaultSm100, GemmDefaultSm120 sm_major = get_device_capacity(torch.device("cuda"))[0] - assert sm_major in (10, 11), ( - f"Blockscaled GEMM requires SM100 (B200/B300) or SM110; got SM{sm_major}x. " - "MXFP8/MXFP4/NVFP4 use tcgen05 UMMA which is SM100+." + assert sm_major in (10, 11, 12), ( + f"Blockscaled GEMM requires SM100/SM110 or SM120; got SM{sm_major}x." ) + if sm_major == 12 and (args.varlen_m or args.varlen_k): + raise NotImplementedError("SM120 blockscaled benchmark path does not support varlen") if args.varlen_k or args.gather_A or args.pingpong: raise NotImplementedError( "blockscaled + varlen_k/gather/pingpong is not wired up yet. " @@ -255,12 +265,28 @@ def _run_blockscaled(args): raise ValueError( f"MXFP4/NVFP4 require K-major for both A and B; got a_major={a_major}, b_major={b_major}" ) - if not GemmDefaultSm100.can_implement_blockscaled( + GemmBlockscaledCls = GemmDefaultSm120 if sm_major == 12 else GemmDefaultSm100 + mma_tiler_for_validation = ( + tuple(mma_tiler_mnk) if len(mma_tiler_mnk) == 3 or sm_major != 12 else (*mma_tiler_mnk, 64) + ) + if ( + sm_major == 12 + and len(mma_tiler_for_validation) == 3 + and mma_tiler_for_validation[2] == 128 + and os.environ.get("QUACK_SM120_BLOCKSCALED_PACKED_LDSM") != "1" + ): + raise NotImplementedError( + "SM120 blockscaled tile_K=128 requires the packed ldmatrix path. " + "Set QUACK_SM120_BLOCKSCALED_PACKED_LDSM=1, for example:\n" + " QUACK_SM120_BLOCKSCALED_PACKED_LDSM=1 CUTE_DSL_ARCH=sm_120a " + "python benchmarks/benchmark_gemm.py --tile_shape_mnk 64,64,128 ..." + ) + if not GemmBlockscaledCls.can_implement_blockscaled( ab_dtype, sf_dtype, sf_vec_size, d_dtype, - mma_tiler_mnk, + mma_tiler_for_validation, cluster_shape_mn, m, n, @@ -320,28 +346,43 @@ def _run_blockscaled(args): def fn(): runner(mA, mB, mD, mSFA, mSFB, cu_seqlens_m) else: - a_ref, mA, a_sc_contig = create_blockscaled_operand_quantized( - l, - m, - k, - a_major == "m", - sf_vec_size, - ab_dtype, - sf_dtype, - ) - b_ref, mB, b_sc_contig = create_blockscaled_operand_quantized( - l, - n, - k, - b_major == "n", - sf_vec_size, - ab_dtype, - sf_dtype, - ) - # (l, rm, rk, 512) contig scale — consumed directly by the kernel. - mSFA, mSFB = a_sc_contig, b_sc_contig - sfa_ref = torch.ones_like(a_ref) - sfb_ref = torch.ones_like(b_ref) + if sm_major == 12: + if ab_dtype is not cutlass.Float4E2M1FN or d_dtype is not cutlass.BFloat16: + raise TypeError( + "SM120 blockscaled benchmark currently supports FP4 inputs and BF16 D" + ) + _, mA = create_blockscaled_operand_tensor(l, m, k, False, ab_dtype, init="empty") + _, mB = create_blockscaled_operand_tensor(l, n, k, False, ab_dtype, init="empty") + mA.view(torch.uint8).fill_(0x22) + mB.view(torch.uint8).fill_(0x22) + a_ref = torch.ones((m, k, l), device="cuda", dtype=torch.float32) + b_ref = torch.ones((n, k, l), device="cuda", dtype=torch.float32) + sfa_ref, mSFA = create_sm120_blockscaled_scale_tensor(l, m, k, sf_vec_size, sf_dtype) + sfb_ref, mSFB = create_sm120_blockscaled_scale_tensor(l, n, k, sf_vec_size, sf_dtype) + a_sc_contig = b_sc_contig = None + else: + a_ref, mA, a_sc_contig = create_blockscaled_operand_quantized( + l, + m, + k, + a_major == "m", + sf_vec_size, + ab_dtype, + sf_dtype, + ) + b_ref, mB, b_sc_contig = create_blockscaled_operand_quantized( + l, + n, + k, + b_major == "n", + sf_vec_size, + ab_dtype, + sf_dtype, + ) + # (l, rm, rk, 512) contig scale — consumed directly by the kernel. + mSFA, mSFB = a_sc_contig, b_sc_contig + sfa_ref = torch.ones_like(a_ref) + sfb_ref = torch.ones_like(b_ref) _, mD = create_blockscaled_operand_tensor(l, m, n, False, d_dtype, init="empty") runner = compile_blockscaled_gemm_tvm_ffi( ab_dtype, @@ -372,10 +413,12 @@ def fn(): torch.testing.assert_close(mD.float(), ref, atol=tol, rtol=1e-3) else: ref = blockscaled_gemm_reference(a_ref, b_ref, sfa_ref, sfb_ref) + if d_dtype != cutlass.Float32: + ref = ref.to(torch_dtype_for_cutlass(d_dtype)).float() torch.testing.assert_close(mD.float(), ref, atol=tol, rtol=1e-3) print("Ref check PASSED") - print("Running SM100 Blockscaled GEMM with:") + print(f"Running SM{sm_major}0 Blockscaled GEMM with:") print(f"mnkl: {args.mnkl}") print(f"tile_shape_mnk: {mma_tiler_mnk}, cluster_shape_mnk: {cluster_shape_mnk}") print( @@ -395,6 +438,12 @@ def fn(): # batch would be an unfair comparison (hides batching potential), so skip. print("(skipping cuBLAS: batched blockscaled mm not supported via a single call)") return + if sm_major == 12: + print( + "(skipping cuBLAS comparison: SM120 benchmark currently builds QuACK's " + "padded row-major scale tensors, not the cuBLAS/PyTorch scaled_mm scale layout)" + ) + return if a_major != "k" or b_major != "k": # F.scaled_mm requires A (M,K) row-major and B (K,N) col-major — # i.e. both operands K-contiguous. Skip for m/n-major to avoid an @@ -630,5 +679,8 @@ def fn(): if __name__ == "__main__": args = parse_arguments() - run(args) + try: + run(args) + except (NotImplementedError, TypeError, ValueError) as exc: + raise SystemExit(f"benchmark_gemm.py: error: {exc}") from None print("PASS") diff --git a/docs/sm120_blockscaled_perf.md b/docs/sm120_blockscaled_perf.md new file mode 100644 index 00000000..0891d0f4 --- /dev/null +++ b/docs/sm120_blockscaled_perf.md @@ -0,0 +1,116 @@ +# SM120 Blockscaled Performance Notes + +This note explains the first SM120 blockscaled performance path in +`GemmSm120`. It is intentionally narrower than the experimental branch: the PR +keeps the proven per-atom TMA mechanism and adds only the packed shared-memory +consumer path plus a `64x64x128` tile. + +## Supported Scope + +The performance path is opt-in: + +```bash +QUACK_SM120_BLOCKSCALED_PACKED_LDSM=1 +``` + +Current intended benchmark shape: + +```bash +QUACK_SM120_BLOCKSCALED_PACKED_LDSM=1 CUTE_DSL_ARCH=sm_120a \ +python benchmarks/benchmark_gemm.py \ + --mnkl 4096,4096,4096,1 \ + --tile_shape_mnk 64,64,128 \ + --cluster_shape_mnk 1,1,1 \ + --ab_dtype Float4E2M1FN \ + --sf_dtype Float8E4M3FN \ + --sf_vec_size 16 \ + --d_dtype BFloat16 \ + --warmup_iterations 5 \ + --iterations 30 \ + --skip_ref_check +``` + +The benchmark is a launch and timing harness. Numerical coverage lives in +`tests/test_gemm_blockscaled.py`, including asymmetric FP4 values, poisoned +scale padding, K-page crossing, and PTX checks. + +Targeted correctness gate for this path: + +```bash +QUACK_SM120_BLOCKSCALED_PACKED_LDSM=1 CUTE_DSL_ARCH=sm_120a \ +pytest -q tests/test_gemm_blockscaled.py -k "sm120 and packed_ldsm" -n 16 -s -rs +``` + +Supported formats for this SM120 path are: + +- NVFP4: `Float4E2M1FN` A/B, `Float8E4M3FN` scales, `sf_vec_size=16` +- MXFP4: `Float4E2M1FN` A/B, `Float8E8M0FNU` scales, `sf_vec_size=32` +- BF16 output, `C is None`, `beta=0` +- cluster shape `(1, 1, 1)` + +The packed performance path supports `64x64x64` and `64x64x128` CTA tiles. For +`tile_K=128`, logical K must be divisible by 128. + +## Why Packed LDSM + +The correctness-first SM120 path expanded compact FP4 bytes into the padded +`.b4x16_p64` ldmatrix shared-memory format. That path was useful for proving +the tuple MMA, scale mapping, and padded scale TMA, but profiling showed a large +shared-memory load bottleneck around the generated `b4x16_p64` ldmatrix +instruction. + +The packed path instead stages FP4 into a swizzled packed shared-memory layout +and consumes it with: + +```text +ldmatrix.sync.aligned.m8n8.x4.shared.b16 +mma.sync.aligned.m16n8k64.kind::mxf4nvf4 +``` + +This direction is based on the local CUTLASS GeForce reference in +`examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu`, +while this PR keeps QuACK's narrower per-atom TMA producer path. + +Tests assert that the packed path does not regress back to `b4x16_p64`, +`m8n16`, multicast TMA, or `shared::cluster`. + +## Why 64x64x128 First + +The packed `64x64x64` path removed most of the original shared-load wavefront +excess, but the kernel was still producer/barrier heavy. Moving to +`64x64x128` keeps the same accumulator ownership and correctness surface while +doubling the K work per producer/barrier cycle. Local Nsight Compute runs on a +noisy workstation showed the expected direction: + +- shared-load excessive wavefronts stayed near the packed-path level +- tensor pipe active increased materially +- barrier and MIO stall samples per issued instruction dropped +- runtime improved over `64x64x64` + +Treat these numbers as direction, not a stable benchmark claim. The following +benchmark runs were taken on an RTX 5060 workstation with reference checking +skipped because the pytest suite owns numerical validation: + +```text +base sm120-blockscaled, correctness-first path: + 4096x4096x4096 NVFP4 -> BF16, 128x128x64: 60.571 ms, 2.3 TFLOP/s + +this PR, packed-LDSM path: + 4096x4096x4096 NVFP4 -> BF16, 64x64x64: 2.988 ms, 46.0 TFLOP/s + 4096x4096x4096 NVFP4 -> BF16, 64x64x128: 1.329 ms, 103.4 TFLOP/s +``` + +## Why Not Full-Tile TMA In This PR + +The natural next architecture is full-tile or grouped A/B TMA into the final +packed/swizzled shared-memory layout. Local experiments were not clean enough +for this PR: + +- raw subbyte swizzled full-tile TMA failed CuTe DSL legalization +- byte-addressable recast layouts compiled but produced many tiny static TMA + sites and could hang at runtime +- nested grouped raw FP4 layouts hit compile/codegen timeouts + +Those findings point to a separate minimal layout-lowering repro/upstream issue. +This PR keeps production on the proven per-atom TMA path and uses `tile_K=128` +as the low-risk amortization step. diff --git a/quack/blockscaled_gemm_utils.py b/quack/blockscaled_gemm_utils.py index 479c78ff..ed9580ea 100644 --- a/quack/blockscaled_gemm_utils.py +++ b/quack/blockscaled_gemm_utils.py @@ -1,6 +1,7 @@ # Copyright (c) 2026, Tri Dao. import itertools +import os from functools import partial from typing import Callable, Optional, Type, Tuple @@ -11,7 +12,7 @@ from quack.compile_utils import make_fake_tensor as fake_tensor from quack.cute_dsl_utils import get_device_capacity, get_max_active_clusters -from quack.gemm_default_epi import GemmDefaultSm100 +from quack.gemm_default_epi import GemmDefaultSm100, GemmDefaultSm120 from quack.gemm_tvm_ffi_utils import div_for_dtype, make_scheduler_args from quack.mx_utils import ( to_mx_compiled, @@ -235,6 +236,53 @@ def create_blockscaled_scale_tensor( return ref, packed +def create_sm120_blockscaled_scale_tensor( + l: int, + mn: int, + k: int, + sf_vec_size: int, + dtype: Type[cutlass.Numeric], +) -> Tuple[torch.Tensor, torch.Tensor]: + """Create row-major padded SM120 blockscaled scale tensors. + + SM120 TMA copies scale factors as 2D row-major pages. Physical scale columns + are padded to a multiple of 16; columns beyond ``ceil(k / sf_vec_size)`` are + padding and ignored by the kernel. + """ + if k % 64 != 0: + raise ValueError("SM120 blockscaled scale helper requires K divisible by 64") + if sf_vec_size not in (16, 32): + raise ValueError("SM120 blockscaled scale helper supports sf_vec_size 16 or 32") + expected_dtype = cutlass.Float8E4M3FN if sf_vec_size == 16 else cutlass.Float8E8M0FNU + if dtype is not expected_dtype: + raise ValueError(f"SM120 sf_vec_size={sf_vec_size} requires {expected_dtype}, got {dtype}") + sf_k = ceil_div(k, sf_vec_size) + physical_sf_k = ceil_div(sf_k, 16) * 16 + if dtype == cutlass.Float8E8M0FNU: + exponents = torch.randint(0, 2, (mn, sf_k, l), device="cuda", dtype=torch.int32) + ref_blocks = torch.pow(2.0, exponents.float()) + else: + ref_blocks = torch.randint(1, 4, (mn, sf_k, l), device="cuda", dtype=torch.int32).float() + packed = torch.empty( + (mn, physical_sf_k, l), device="cuda", dtype=torch_dtype_for_cutlass(dtype) + ) + packed[:, :sf_k, :].copy_(ref_blocks.to(packed.dtype)) + if physical_sf_k > sf_k: + poison = torch.tensor([0.5, 1.5, 2.0, 3.0], device="cuda", dtype=torch.float32).to( + packed.dtype + ) + cols = torch.arange(physical_sf_k - sf_k, device="cuda") + packed[:, sf_k:, :] = poison[cols % poison.numel()][None, :, None] + ref = ( + ref_blocks.permute(2, 0, 1) + .unsqueeze(-1) + .expand(l, mn, sf_k, sf_vec_size) + .reshape(l, mn, sf_k * sf_vec_size) + .permute(1, 2, 0) + )[:, :k, :] + return ref, packed + + def pack_scale_2d_to_blocked_contig(scale_2d: torch.Tensor) -> torch.Tensor: """Rearrange a (l, mn, sf_k) or (mn, sf_k) e8m0 scale tensor into the contiguous (l, rm, rk, 512) blocked layout shared by the quack kernel and @@ -608,8 +656,9 @@ def compile_blockscaled_gemm_tvm_ffi( use_clc_persistence: bool = True, varlen_m: bool = False, varlen_k: bool = False, + compile_options: str = "--enable-tvm-ffi", ) -> Callable: - """Compile the SM100 blockscaled GEMM. + """Compile the blockscaled GEMM. When varlen_m: mA is (total_m, k) K-major, mD is (total_m, n) N-major, mB is (n, k, l); run(...) takes an extra cu_seqlens_m tensor. @@ -617,15 +666,67 @@ def compile_blockscaled_gemm_tvm_ffi( run(...) takes an extra cu_seqlens_k tensor. """ device_capacity = get_device_capacity(mA.device) - if device_capacity[0] not in (10, 11): - raise RuntimeError("Blockscaled SM100 GEMM requires SM100/SM110") + if device_capacity[0] not in (10, 11, 12): + raise RuntimeError("Blockscaled GEMM requires SM100/SM110 or SM120") assert not (varlen_m and varlen_k), "Only one of varlen_m / varlen_k" - gemm = partial( - GemmDefaultSm100, - sf_vec_size=sf_vec_size, - use_clc_persistence=use_clc_persistence, - )(cutlass.Float32, ab_dtype, mma_tiler_mn, (*cluster_shape_mn, 1)) + mma_tiler_mn_only = mma_tiler_mn[:2] + mma_tiler_k = mma_tiler_mn[2] if len(mma_tiler_mn) == 3 else 64 + + if device_capacity[0] == 12: + if varlen_m or varlen_k: + raise NotImplementedError("SM120 blockscaled GEMM does not support varlen") + if mma_tiler_k not in (64, 128): + raise NotImplementedError("SM120 blockscaled GEMM requires tile_K in {64,128}") + if mma_tiler_k == 128 and os.environ.get("QUACK_SM120_BLOCKSCALED_PACKED_LDSM") != "1": + raise NotImplementedError( + "SM120 blockscaled tile_K=128 requires QUACK_SM120_BLOCKSCALED_PACKED_LDSM=1" + ) + if len(mA.shape) != 3 or len(mB.shape) != 3 or len(mD.shape) != 3: + raise ValueError("SM120 blockscaled GEMM requires rank-3 A/B/D tensors") + logical_k = mA.shape[1] * (2 if ab_dtype is cutlass.Float4E2M1FN else 1) + physical_scale_cols = ceil_div(ceil_div(logical_k, sf_vec_size), 16) * 16 + expected_sfa_shape = (mA.shape[0], physical_scale_cols, mA.shape[2]) + expected_sfb_shape = (mB.shape[0], physical_scale_cols, mB.shape[2]) + if tuple(mSFA.shape) != expected_sfa_shape: + raise ValueError(f"SFA shape must be {expected_sfa_shape}, got {tuple(mSFA.shape)}") + if tuple(mSFB.shape) != expected_sfb_shape: + raise ValueError(f"SFB shape must be {expected_sfb_shape}, got {tuple(mSFB.shape)}") + if not GemmDefaultSm120.can_implement_blockscaled( + ab_dtype, + sf_dtype, + sf_vec_size, + d_dtype, + (*mma_tiler_mn_only, mma_tiler_k), + cluster_shape_mn, + mA.shape[0], + mB.shape[0], + logical_k, + mA.shape[2], + "k", + "k", + "n", + ): + raise RuntimeError( + "Unsupported SM120 blockscaled config; supported formats are MXFP4 " + "(Float4E2M1FN + Float8E8M0FNU + vec32) and NVFP4 " + "(Float4E2M1FN + Float8E4M3FN + vec16)" + ) + gemm = GemmDefaultSm120( + cutlass.Float32, + ab_dtype, + (*mma_tiler_mn_only, mma_tiler_k), + (*cluster_shape_mn, 1), + is_persistent=False, + sf_vec_size=sf_vec_size, + sf_dtype=sf_dtype, + ) + else: + gemm = partial( + GemmDefaultSm100, + sf_vec_size=sf_vec_size, + use_clc_persistence=use_clc_persistence, + )(cutlass.Float32, ab_dtype, mma_tiler_mn, (*cluster_shape_mn, 1)) compile_epi_args = gemm.EpilogueArguments() scheduler_args = make_scheduler_args( get_max_active_clusters(cluster_shape_mn[0] * cluster_shape_mn[1]), @@ -709,7 +810,16 @@ def runner( varlen_args, stream, ): - gemm(a, b, d, None, compile_epi_args, scheduler_args, varlen_args, stream, sfa, sfb, None) + if cutlass.const_expr(device_capacity[0] == 12): + # SM120 FFI has already validated packed torch storage and compiles + # logical fake CuTe tensor views. Do not route through + # blockscaled_call, whose host validation expects logical class-call + # tensors rather than packed torch.float4_e2m1fn_x2 storage. + gemm._blockscaled_call_jit(a, b, d, varlen_args, stream, sfa, sfb, None) + else: + gemm( + a, b, d, None, compile_epi_args, scheduler_args, varlen_args, stream, sfa, sfb, None + ) compiled = cute.compile( runner, @@ -720,7 +830,7 @@ def runner( _make_compile_tensor_like(mSFB, sf_dtype, dynamic_layout=True), varlen_args_fake, stream, - options="--enable-tvm-ffi", + options=compile_options, ) if varlen_m or varlen_k: @@ -736,6 +846,10 @@ def run(a, b, d, sfa, sfb, cu_seqlens): def run(a, b, d, sfa, sfb): compiled(a, b, d, sfa, sfb, VarlenArguments()) + for attr in ("__ptx__", "__cubin__"): + if hasattr(compiled, attr): + setattr(run, attr, getattr(compiled, attr)) + return run diff --git a/quack/gemm_sm120.py b/quack/gemm_sm120.py index 1c15f3f4..a279fe00 100644 --- a/quack/gemm_sm120.py +++ b/quack/gemm_sm120.py @@ -8,6 +8,7 @@ # This is a work in progress and not very optimized. import math +import os from typing import Tuple, Type, Callable, Optional from functools import partial @@ -17,14 +18,100 @@ from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait from cutlass.cute.nvgpu import cpasync, warp from cutlass import Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum -from quack.varlen_utils import VarlenManager +from quack.varlen_utils import VarlenArguments, VarlenManager from quack.pipeline import make_pipeline_state from quack import copy_utils from quack.gemm_sm90 import GemmSm90, NamedBarrierGemm from quack import sm80_utils +def _round_up(x: int, multiple: int) -> int: + return ((x + multiple - 1) // multiple) * multiple + + +@cute.jit +def _sm120_blockscaled_scale_fragment( + dtype: cutlass.Constexpr[Type[cutlass.Numeric]], + sf_vec_size: cutlass.Constexpr[int], + is_sfa: cutlass.Constexpr[bool], +): + if const_expr(sf_vec_size == 16): + if const_expr(is_sfa): + return warp.make_mxf4nvf4_sfa_fragment(dtype) + return warp.make_mxf4nvf4_sfb_fragment(dtype) + return cute.make_rmem_tensor(cute.make_layout(((32, 2),), stride=((0, 1),)), dtype) + + +@cute.jit +def _load_sm120_blockscaled_selector0_scale_fragments( + sSFA: cute.Tensor, + sSFB: cute.Tensor, + stage: cutlass.Int32, + m_atom_base: cutlass.Int32, + n_atom_base: cutlass.Int32, + k_scale_base: cutlass.Int32, + sf_vec_size: cutlass.Constexpr[int], + sf_dtype: cutlass.Constexpr[Type[cutlass.Numeric]], +): + """Load selector-zero SM120 FP4 blockscaled scale packets from SMEM. + + The tuple-lowered SM120 blockscaled MMA uses byte-id-a=byte-id-b=0. For + selector zero, SFA provider lanes map tid 0 to row group and tid 1 to row + group+8; SFB provider lanes map group 0..7 to logical N columns 0..7. + """ + lane = cute.arch.lane_idx() + group = lane >> 2 + tid = lane & 3 + sfa_row = m_atom_base + group + 8 * (tid & 1) + sfb_col = n_atom_base + group + sfa = _sm120_blockscaled_scale_fragment(sf_dtype, sf_vec_size, True) + sfb = _sm120_blockscaled_scale_fragment(sf_dtype, sf_vec_size, False) + compact_sfa = cute.filter_zeros(sfa) + compact_sfb = cute.filter_zeros(sfb) + for kb in cutlass.range_constexpr(64 // sf_vec_size): + compact_sfa[kb] = sSFA[sfa_row, k_scale_base + kb, stage] + compact_sfb[kb] = sSFB[sfb_col, k_scale_base + kb, stage] + return sfa, sfb + + +@cute.jit +def _make_sm120_fp4_ldmatrix_smem_view( + smem: cute.Tensor, + mn: cutlass.Constexpr[int], +): + return cute.make_tensor( + smem.iterator, + cute.make_layout((mn, (8, 4)), stride=(64, (1, 16))), + ) + + +@cute.jit +def _expand_compact_fp4_to_sm120_ldmatrix_smem( + compact: cute.Tensor, + padded: cute.Tensor, + mn: cutlass.Constexpr[int], + thread_count: cutlass.Constexpr[int] = 32, +): + """Expand packed FP4 bytes into the padded SM120 ldmatrix SMEM layout.""" + tidx, _, _ = cute.arch.thread_idx() + + for i in cutlass.range((mn * 64 + thread_count - 1) // thread_count, unroll_full=True): + flat = tidx + i * thread_count + if flat < mn * 64: + padded[flat // 64, flat % 64] = cutlass.Int8(0) + + for i in cutlass.range((mn * 32 + thread_count - 1) // thread_count, unroll_full=True): + flat = tidx + i * thread_count + if flat < mn * 32: + row = flat // 32 + packed_k = flat - row * 32 + group = packed_k // 8 + in_group = packed_k - group * 8 + padded[row, group * 16 + in_group] = compact[row, packed_k] + + class GemmSm120(GemmSm90): """SM120-style GEMM using warp-level MMA instead of WGMMA. @@ -49,6 +136,8 @@ def __init__( gather_A: bool = False, concat_layout: tuple | None = None, use_pdl: bool = True, + sf_vec_size: int | None = None, + sf_dtype: Type[cutlass.Numeric] | None = None, ): # Don't call super().__init__ — we set up our own config self.acc_dtype = acc_dtype @@ -59,6 +148,27 @@ def __init__( self.fp8_slow_accum = False self.gather_A = gather_A self.concat_layout = concat_layout or () + self.blockscaled = sf_vec_size is not None + self.sf_vec_size = sf_vec_size + self.sf_dtype = sf_dtype + if self.blockscaled: + if a_dtype is not cutlass.Float4E2M1FN: + raise ValueError("SM120 blockscaled path currently supports Float4E2M1FN A/B only") + if acc_dtype is not cutlass.Float32: + raise ValueError("SM120 blockscaled path requires Float32 accumulation") + if sf_vec_size not in (16, 32): + raise ValueError("SM120 blockscaled path supports sf_vec_size 16 or 32") + expected_sf_dtype = cutlass.Float8E4M3FN if sf_vec_size == 16 else cutlass.Float8E8M0FNU + if sf_dtype is not expected_sf_dtype: + raise ValueError( + f"SM120 blockscaled sf_vec_size={sf_vec_size} requires {expected_sf_dtype}" + ) + if pingpong: + raise NotImplementedError("SM120 blockscaled pingpong is not implemented") + if gather_A: + raise NotImplementedError("SM120 blockscaled gather_A is not implemented") + if cluster_shape_mnk != (1, 1, 1): + raise ValueError("SM120 blockscaled path requires cluster_shape_mnk=(1,1,1)") if self.pingpong: assert self.is_persistent, "Pingpong gemm requires persistent scheduler" if gather_A: @@ -74,7 +184,7 @@ def __init__( # Pingpong: 2 warp groups each with (2,2,1) atom layout # Non-pingpong: 1 group of 8 warps with (4,2,1) atom layout - self.mma_inst_mnk = (16, 8, 16) + self.mma_inst_mnk = (16, 8, 64) if self.blockscaled else (16, 8, 16) self.atom_layout_mnk = (4, 2, 1) if not self.pingpong else (2, 2, 1) # num_mma_warps = total warps doing MMA (both warp groups in pingpong) self.num_mma_warps = math.prod(self.atom_layout_mnk) * (1 if not self.pingpong else 2) @@ -119,6 +229,8 @@ def __init__( self.epi_m_major = True self.a_smem_layout_staged = None self.b_smem_layout_staged = None + self.a_smem_store_layout_staged = None + self.b_smem_store_layout_staged = None self.epi_smem_layout_staged = None self.epi_tile = None self.shared_storage = None @@ -129,7 +241,14 @@ def epi_smem_warp_shape_mnk(self): def _setup_tiled_mma(self): """Set up warp-level MMA (MmaF16BF16Op) and tile K dimension.""" - op = warp.MmaF16BF16Op(self.a_dtype, self.acc_dtype, self.mma_inst_mnk) + if const_expr(self.blockscaled): + if self.sf_vec_size == 16: + op = warp.MmaMXF4NVF4Op(self.a_dtype, self.acc_dtype, self.sf_dtype) + else: + op = warp.MmaMXF4Op(self.a_dtype, self.acc_dtype, self.sf_dtype) + self.mma_inst_mnk = (16, 8, 64) + else: + op = warp.MmaF16BF16Op(self.a_dtype, self.acc_dtype, self.mma_inst_mnk) tC = cute.make_layout(self.atom_layout_mnk) atom_m, atom_n, atom_k = self.atom_layout_mnk # We want each warp to have 16 consecutive elements in the N direction, for STSM @@ -152,8 +271,1743 @@ def _setup_tiled_mma(self): ) self.cta_tile_shape_mnk = (self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1], tile_k) - # __call__, _setup_attributes, make_ab_pipeline, make_epi_store_pipeline, - # make_sched_pipeline, epilogue are all inherited from GemmSm90. + # Dense __call__, _setup_attributes, make_ab_pipeline, make_epi_store_pipeline, + # make_sched_pipeline, epilogue are inherited from GemmSm90. + + @staticmethod + def padded_blockscale_cols(k: int, sf_vec_size: int) -> int: + if k % 64 != 0: + raise ValueError("SM120 blockscaled GEMM requires K divisible by 64") + return _round_up(k // sf_vec_size, 16) + + @staticmethod + def can_implement_blockscaled( + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + d_dtype: Type[cutlass.Numeric], + mma_tiler_mnk: Tuple[int, int] | Tuple[int, int, int], + cluster_shape_mn: Tuple[int, int], + m: int, + n: int, + k: int, + l: int, + a_major: str, + b_major: str, + d_major: str, + ) -> bool: + del d_major + if ab_dtype is not cutlass.Float4E2M1FN: + return False + if sf_vec_size == 16: + if sf_dtype is not cutlass.Float8E4M3FN: + return False + elif sf_vec_size == 32: + if sf_dtype is not cutlass.Float8E8M0FNU: + return False + else: + return False + if d_dtype is not cutlass.BFloat16: + return False + if a_major != "k" or b_major != "k": + return False + if cluster_shape_mn != (1, 1): + return False + if len(mma_tiler_mnk) == 3 and mma_tiler_mnk[2] not in (64, 128): + return False + tile_m, tile_n = mma_tiler_mnk[:2] + tile_k = mma_tiler_mnk[2] if len(mma_tiler_mnk) == 3 else 64 + supported_tiles = {(128, 128, 64), (64, 64, 64)} + if os.environ.get("QUACK_SM120_BLOCKSCALED_PACKED_LDSM") == "1": + supported_tiles.add((64, 64, 128)) + if (tile_m, tile_n, tile_k) not in supported_tiles: + return False + return m % tile_m == 0 and n % tile_n == 0 and k % tile_k == 0 and l == 1 + + @staticmethod + def _shape_tuple(tensor: cute.Tensor) -> tuple[int, ...]: + return tuple(int(dim) for dim in tensor.shape) + + @staticmethod + def _is_empty_varlen(varlen_args: VarlenArguments) -> bool: + return ( + varlen_args.mCuSeqlensM is None + and varlen_args.mCuSeqlensK is None + and varlen_args.mAIdx is None + ) + + def blockscaled_call( + self, + mA: cute.Tensor, + mB: cute.Tensor, + mD: Optional[cute.Tensor], + mC: Optional[cute.Tensor], + epilogue_args: tuple, + scheduler_args, + varlen_args: Optional[VarlenArguments], + stream, + mSFA: Optional[cute.Tensor] = None, + mSFB: Optional[cute.Tensor] = None, + trace_ptr: Optional[cutlass.Int64] = None, + ): + varlen_args = self._validate_blockscaled_call( + mA, mB, mD, mC, mSFA, mSFB, epilogue_args, scheduler_args, varlen_args, trace_ptr + ) + return self._blockscaled_call_jit( + mA, + mB, + mD, + varlen_args, + stream, + mSFA, + mSFB, + trace_ptr, + ) + + @cute.jit + def _blockscaled_call_jit( + self, + mA: cute.Tensor, + mB: cute.Tensor, + mD: cute.Tensor, + varlen_args: Optional[VarlenArguments], + stream, + mSFA: cute.Tensor, + mSFB: cute.Tensor, + trace_ptr: Optional[cutlass.Int64] = None, + ): + if const_expr(varlen_args is None): + varlen_args = VarlenArguments() + return self._call_blockscaled( + mA, + mB, + mD, + varlen_args, + stream, + mSFA, + mSFB, + trace_ptr, + ) + + def _validate_blockscaled_call( + self, + mA: cute.Tensor, + mB: cute.Tensor, + mD: cute.Tensor, + mC: Optional[cute.Tensor], + mSFA: Optional[cute.Tensor], + mSFB: Optional[cute.Tensor], + epilogue_args: tuple, + scheduler_args, + varlen_args: Optional[VarlenArguments], + trace_ptr: Optional[cutlass.Int64], + ) -> VarlenArguments: + if mSFA is None or mSFB is None: + raise ValueError("SM120 blockscaled GEMM requires SFA and SFB scale tensors") + if mD is None: + raise ValueError("SM120 blockscaled GEMM requires an output tensor") + if mC is not None: + raise NotImplementedError("SM120 blockscaled C/beta path is not implemented") + if epilogue_args not in (None, ()): + beta = getattr(epilogue_args, "beta", None) + add_to_output = getattr(epilogue_args, "add_to_output", False) + if beta not in (None, 0) or add_to_output: + raise NotImplementedError("SM120 blockscaled C/beta path is not implemented") + del scheduler_args + if trace_ptr is not None: + raise NotImplementedError("SM120 blockscaled trace path is not implemented") + tile_m, tile_n, tile_k = self.cta_tile_shape_mnk + supported_tiles = {(128, 128, 64), (64, 64, 64)} + if os.environ.get("QUACK_SM120_BLOCKSCALED_PACKED_LDSM") == "1": + supported_tiles.add((64, 64, 128)) + if (tile_m, tile_n, tile_k) not in supported_tiles: + raise NotImplementedError( + "SM120 blockscaled path currently supports tile shapes " + "(128,128,64), (64,64,64), and opt-in (64,64,128)" + ) + if self.cluster_shape_mnk != (1, 1, 1): + raise NotImplementedError("SM120 blockscaled path requires cluster_shape_mnk=(1,1,1)") + if ( + mA.element_type is not cutlass.Float4E2M1FN + or mB.element_type is not cutlass.Float4E2M1FN + ): + raise TypeError("SM120 blockscaled path requires Float4E2M1FN A/B") + if mSFA.element_type is not self.sf_dtype or mSFB.element_type is not self.sf_dtype: + raise TypeError(f"SM120 blockscaled path requires {self.sf_dtype} SFA/SFB") + if mD.element_type is not cutlass.BFloat16: + raise NotImplementedError("SM120 blockscaled path currently supports only BF16 D") + + if varlen_args is None: + varlen_args = VarlenArguments() + if not self._is_empty_varlen(varlen_args): + raise NotImplementedError("SM120 blockscaled varlen GEMM is not implemented") + + a_shape = self._shape_tuple(mA) + b_shape = self._shape_tuple(mB) + d_shape = self._shape_tuple(mD) + if len(a_shape) != 3 or len(b_shape) != 3 or len(d_shape) != 3: + raise ValueError("SM120 blockscaled tensors must use logical rank-3 shapes") + m, k, l = a_shape + n, kb, lb = b_shape + if k != kb or l != lb: + raise ValueError("SM120 blockscaled A/B K and L extents must match") + if k % tile_k != 0: + if k * 2 % tile_k == 0: + raise ValueError( + "SM120 blockscaled class call expects logical Float4E2M1FN K extent; " + "use compile_blockscaled_gemm_tvm_ffi for packed torch.float4_e2m1fn_x2 " + "storage" + ) + raise ValueError("SM120 blockscaled path requires logical K to be divisible by tile_K") + if d_shape != (m, n, l): + raise ValueError(f"SM120 blockscaled D shape must be {(m, n, l)}, got {d_shape}") + if m % tile_m != 0 or n % tile_n != 0: + raise NotImplementedError("SM120 blockscaled path requires M/N multiples of CTA M/N") + if l != 1: + raise NotImplementedError("SM120 blockscaled path currently supports L=1") + scale_cols = self.padded_blockscale_cols(k, self.sf_vec_size) + if self._shape_tuple(mSFA) != (m, scale_cols, l): + raise ValueError( + f"SFA shape must be {(m, scale_cols, l)}, got {self._shape_tuple(mSFA)}" + ) + if self._shape_tuple(mSFB) != (n, scale_cols, l): + raise ValueError( + f"SFB shape must be {(n, scale_cols, l)}, got {self._shape_tuple(mSFB)}" + ) + return varlen_args + + @cute.jit + def _call_blockscaled( + self, + mA: cute.Tensor, + mB: cute.Tensor, + mD: cute.Tensor, + varlen_args: Optional[VarlenArguments], + stream, + mSFA: cute.Tensor, + mSFB: cute.Tensor, + trace_ptr: Optional[cutlass.Int64] = None, + ): + del stream, trace_ptr + + self.a_dtype = mA.element_type + self.b_dtype = mB.element_type + self.d_dtype = mD.element_type + self.c_dtype = None + self.sf_dtype = mSFA.element_type + self.a_layout = LayoutEnum.from_tensor(mA) + self.b_layout = LayoutEnum.from_tensor(mB) + self.d_layout = LayoutEnum.from_tensor(mD) + self.c_layout = None + self._setup_attributes(()) + self.ab_stage = 2 if self.cta_tile_shape_mnk in ((64, 64, 64), (64, 64, 128)) else 1 + # Split independent 16x8 output atoms across four consumer warps. + # Warp 4 owns TMA production; warp 0 owns pipeline wait/release. + self.blockscaled_consumer_warps = 4 + self.blockscaled_producer_warp = 4 + packed_ldsm_override = os.environ.get("QUACK_SM120_BLOCKSCALED_PACKED_LDSM") + if const_expr(packed_ldsm_override is not None and packed_ldsm_override not in ("0", "1")): + raise ValueError("QUACK_SM120_BLOCKSCALED_PACKED_LDSM must be 0 or 1") + tile_m, tile_n, tile_k = self.cta_tile_shape_mnk + self.blockscaled_packed_ldsm = packed_ldsm_override == "1" and (tile_m, tile_n, tile_k) in ( + (64, 64, 64), + (64, 64, 128), + ) + if const_expr(not self.blockscaled_packed_ldsm and tile_k != 64): + raise NotImplementedError( + "SM120 blockscaled tile_K=128 currently requires " + "QUACK_SM120_BLOCKSCALED_PACKED_LDSM=1" + ) + if const_expr( + self.blockscaled_packed_ldsm + and (tile_m, tile_n, tile_k) == (64, 64, 128) + and self.blockscaled_consumer_warps != 4 + ): + raise NotImplementedError( + "SM120 packed-subbyte ldmatrix tile 64x64x128 currently requires " + "four consumer warps" + ) + + if const_expr(self.blockscaled_packed_ldsm): + # Direct TMA uses a visible 16x64 / 8x64 FP4 tile composed with + # the same swizzle that the packed ldmatrix consumer partitions. + # The separate 8x128 atom view below is only for the consumer. + self.a_smem_store_layout_staged = cute.make_layout( + (8, 128, tile_m // 16, self.ab_stage), + stride=(128, 1, 8 * 128, (tile_m // 16) * 8 * 128), + ) + self.b_smem_store_layout_staged = cute.make_layout( + (8, 128, tile_n // 8, self.ab_stage), + stride=(128, 1, 8 * 128, (tile_n // 8) * 8 * 128), + ) + # The consumer view is typed as 4-bit elements, so the outer atom + # and stage strides are expressed in FP4 element units. They are + # twice the raw Uint8 store strides above to advance by the same + # physical byte distance. + self.a_smem_layout_staged = cute.make_layout( + (8, 128, tile_m // 16, self.ab_stage), + stride=(128, 1, 2 * 8 * 128, (tile_m // 16) * 2 * 8 * 128), + ) + self.b_smem_layout_staged = cute.make_layout( + (8, 128, tile_n // 8, self.ab_stage), + stride=(128, 1, 2 * 8 * 128, (tile_n // 8) * 2 * 8 * 128), + ) + else: + self.a_smem_layout_staged = cute.make_layout( + (tile_m, tile_k, self.ab_stage), stride=(tile_k, 1, tile_m * tile_k) + ) + self.b_smem_layout_staged = cute.make_layout( + (tile_n, tile_k, self.ab_stage), stride=(tile_k, 1, tile_n * tile_k) + ) + self.a_smem_store_layout_staged = self.a_smem_layout_staged + self.b_smem_store_layout_staged = self.b_smem_layout_staged + scale_tile_k = 16 + if const_expr(self.blockscaled_packed_ldsm): + a_compact_smem_layout_staged = cute.make_layout((1, 1, self.ab_stage), stride=(1, 1, 1)) + b_compact_smem_layout_staged = cute.make_layout((1, 1, self.ab_stage), stride=(1, 1, 1)) + else: + a_compact_smem_layout_staged = cute.make_layout( + (tile_m, tile_k // 2, self.ab_stage), + stride=(tile_k // 2, 1, tile_m * (tile_k // 2)), + ) + b_compact_smem_layout_staged = cute.make_layout( + (tile_n, tile_k // 2, self.ab_stage), + stride=(tile_k // 2, 1, tile_n * (tile_k // 2)), + ) + self.sfa_smem_layout_staged = cute.make_layout( + (tile_m, scale_tile_k, self.ab_stage), + stride=(scale_tile_k, 1, tile_m * scale_tile_k), + ) + self.sfb_smem_layout_staged = cute.make_layout( + (tile_n, scale_tile_k, self.ab_stage), + stride=(scale_tile_k, 1, tile_n * scale_tile_k), + ) + acc_smem_layout = cute.make_layout((tile_m, tile_n), stride=(tile_n, 1)) + + if const_expr(self.blockscaled_packed_ldsm): + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + else: + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, 0)) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, 0)) + a_compact_smem_layout = cute.slice_(a_compact_smem_layout_staged, (None, None, 0)) + b_compact_smem_layout = cute.slice_(b_compact_smem_layout_staged, (None, None, 0)) + sfa_smem_layout = cute.slice_(self.sfa_smem_layout_staged, (None, None, 0)) + sfb_smem_layout = cute.slice_(self.sfb_smem_layout_staged, (None, None, 0)) + + op = cpasync.CopyBulkTensorTileG2SOp() + m_extent = cute.size(mA, mode=[0]) + k_extent = cute.size(mA, mode=[1]) + n_extent = cute.size(mB, mode=[0]) + l_extent = cute.size(mA, mode=[2]) + if const_expr(self.blockscaled_packed_ldsm): + a_tma_payload_layout = cute.make_layout((16, tile_k), stride=(tile_k, 1)) + b_tma_payload_layout = cute.make_layout((8, tile_k), stride=(tile_k, 1)) + a_tma_smem_layout = cute.make_composed_layout( + cute.make_swizzle(2, 4, 3), + 0, + a_tma_payload_layout, + ) + b_tma_smem_layout = cute.make_composed_layout( + cute.make_swizzle(2, 4, 3), + 0, + b_tma_payload_layout, + ) + tma_atom_a, tma_tensor_a = cpasync.make_tiled_tma_atom( + op, mA, a_tma_smem_layout, (16, tile_k) + ) + tma_atom_b, tma_tensor_b = cpasync.make_tiled_tma_atom( + op, mB, b_tma_smem_layout, (8, tile_k) + ) + else: + packed_k_extent = k_extent // 2 + mA_u8 = cute.make_tensor( + cute.recast_ptr(mA.iterator, dtype=cutlass.Uint8), + cute.make_layout( + (m_extent, packed_k_extent, l_extent), + stride=(packed_k_extent, 1, m_extent * packed_k_extent), + ), + ) + mB_u8 = cute.make_tensor( + cute.recast_ptr(mB.iterator, dtype=cutlass.Uint8), + cute.make_layout( + (n_extent, packed_k_extent, l_extent), + stride=(packed_k_extent, 1, n_extent * packed_k_extent), + ), + ) + tma_atom_a, tma_tensor_a = cpasync.make_tiled_tma_atom( + op, mA_u8, a_compact_smem_layout, (tile_m, tile_k // 2) + ) + tma_atom_b, tma_tensor_b = cpasync.make_tiled_tma_atom( + op, mB_u8, b_compact_smem_layout, (tile_n, tile_k // 2) + ) + tma_atom_sfa, tma_tensor_sfa = self._make_tma_atoms_and_tensors( + mSFA, sfa_smem_layout, (tile_m, scale_tile_k), 1 + ) + tma_atom_sfb, tma_tensor_sfb = self._make_tma_atoms_and_tensors( + mSFB, sfb_smem_layout, (tile_n, scale_tile_k), 1 + ) + if const_expr(self.blockscaled_packed_ldsm): + self.num_tma_load_bytes = ( + (tile_m // 16) * cute.size_in_bytes(cutlass.Float4E2M1FN, a_tma_payload_layout) + + (tile_n // 8) * cute.size_in_bytes(cutlass.Float4E2M1FN, b_tma_payload_layout) + + cute.size_in_bytes(self.sf_dtype, sfa_smem_layout) + + cute.size_in_bytes(self.sf_dtype, sfb_smem_layout) + ) + else: + self.num_tma_load_bytes = ( + cute.size_in_bytes(cutlass.Uint8, a_compact_smem_layout) + + cute.size_in_bytes(cutlass.Uint8, b_compact_smem_layout) + + cute.size_in_bytes(self.sf_dtype, sfa_smem_layout) + + cute.size_in_bytes(self.sf_dtype, sfb_smem_layout) + ) + + a_compact_smem_size = cute.cosize(a_compact_smem_layout_staged) + b_compact_smem_size = cute.cosize(b_compact_smem_layout_staged) + sfa_smem_size = cute.cosize(self.sfa_smem_layout_staged) + sfb_smem_size = cute.cosize(self.sfb_smem_layout_staged) + uses_64x64_register_accum = ( + self.cta_tile_shape_mnk[0] == 64 + and self.cta_tile_shape_mnk[1] == 64 + and self.blockscaled_consumer_warps == 4 + ) + acc_smem_size = 1 if const_expr(uses_64x64_register_accum) else cute.cosize(acc_smem_layout) + + @cute.struct + class BlockscaledSharedStorage: + ab_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.ab_stage * 2] + sA: cute.struct.Align[ + cute.struct.MemRange[ + cutlass.Uint8 if self.blockscaled_packed_ldsm else cutlass.Int8, + cute.cosize(self.a_smem_store_layout_staged), + ], + self.buffer_align_bytes, + ] + sB: cute.struct.Align[ + cute.struct.MemRange[ + cutlass.Uint8 if self.blockscaled_packed_ldsm else cutlass.Int8, + cute.cosize(self.b_smem_store_layout_staged), + ], + self.buffer_align_bytes, + ] + sACompact: cute.struct.Align[ + cute.struct.MemRange[cutlass.Uint8, a_compact_smem_size], + self.buffer_align_bytes, + ] + sBCompact: cute.struct.Align[ + cute.struct.MemRange[cutlass.Uint8, b_compact_smem_size], + self.buffer_align_bytes, + ] + sSFA: cute.struct.Align[ + cute.struct.MemRange[self.sf_dtype, sfa_smem_size], + self.buffer_align_bytes, + ] + sSFB: cute.struct.Align[ + cute.struct.MemRange[self.sf_dtype, sfb_smem_size], + self.buffer_align_bytes, + ] + # Correctness scratch: for K > 64, K64 partials stay in FP32 here + # until the final K tile writes BF16 D exactly once. + sAcc: cute.struct.Align[ + cute.struct.MemRange[cutlass.Float32, acc_smem_size], + self.buffer_align_bytes, + ] + + self.shared_storage = BlockscaledSharedStorage + + if const_expr(self.sf_vec_size == 16): + mma_op = warp.MmaMXF4NVF4Op(cutlass.Float4E2M1FN, cutlass.Float32, self.sf_dtype) + else: + mma_op = warp.MmaMXF4Op(cutlass.Float4E2M1FN, cutlass.Float32, self.sf_dtype) + one_warp_mma = cute.make_tiled_mma(mma_op) + varlen_params = VarlenManager.to_underlying_arguments(varlen_args) + self.blockscaled_kernel( + one_warp_mma, + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_sfa, + tma_tensor_sfa, + tma_atom_sfb, + tma_tensor_sfb, + mD, + varlen_params, + cute.make_layout((1, 1, 1)), + self.a_smem_store_layout_staged, + self.b_smem_store_layout_staged, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + a_compact_smem_layout_staged, + b_compact_smem_layout_staged, + self.sfa_smem_layout_staged, + self.sfb_smem_layout_staged, + acc_smem_layout, + ).launch( + grid=[cute.ceil_div(m_extent, tile_m), cute.ceil_div(n_extent, tile_n), l_extent], + block=[(self.blockscaled_consumer_warps + 1) * cute.arch.WARP_SIZE, 1, 1], + cluster=(1, 1, 1), + ) + + @cute.kernel + def blockscaled_kernel( + self, + tiled_mma: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_sfa: cute.CopyAtom, + mSFA_mkl16: cute.Tensor, + tma_atom_sfb: cute.CopyAtom, + mSFB_nkl16: cute.Tensor, + mD_mnl: cute.Tensor, + varlen_params: VarlenManager.Params, + cluster_layout_mnk: cute.Layout, + a_smem_store_layout: cute.Layout, + b_smem_store_layout: cute.Layout, + a_smem_layout: cute.Layout, + b_smem_layout: cute.Layout, + a_compact_smem_layout: cute.Layout, + b_compact_smem_layout: cute.Layout, + sfa_smem_layout: cute.Layout, + sfb_smem_layout: cute.Layout, + acc_smem_layout: cute.Layout, + ): + del varlen_params + + tidx, _, _ = cute.arch.thread_idx() + cta_m, cta_n, cta_l = cute.arch.block_idx() + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + + if warp_idx == self.blockscaled_producer_warp: + cpasync.prefetch_descriptor(tma_atom_a) + cpasync.prefetch_descriptor(tma_atom_b) + cpasync.prefetch_descriptor(tma_atom_sfa) + cpasync.prefetch_descriptor(tma_atom_sfb) + + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + ab_pipeline = self.make_ab_pipeline( + tiled_mma=tiled_mma, + cluster_layout_vmnk=cute.make_layout((1, *cluster_layout_mnk.shape)), + ab_pipeline_mbar_ptr=storage.ab_pipeline_array_ptr.data_ptr(), + ) + + pipeline_init_arrive(cluster_shape_mn=(1, 1), is_relaxed=True) + sACompact = storage.sACompact.get_tensor(a_compact_smem_layout) + sBCompact = storage.sBCompact.get_tensor(b_compact_smem_layout) + if const_expr(self.blockscaled_packed_ldsm): + sAStore = storage.sA.get_tensor(a_smem_store_layout) + sBStore = storage.sB.get_tensor(b_smem_store_layout) + a_tma_smem_layout = cute.make_layout( + ((16, self.cta_tile_shape_mnk[0] // 16), self.cta_tile_shape_mnk[2], self.ab_stage), + stride=( + (self.cta_tile_shape_mnk[2], 2 * 8 * 128), + 1, + (self.cta_tile_shape_mnk[0] // 16) * 2 * 8 * 128, + ), + ) + b_tma_smem_layout = cute.make_layout( + ((8, self.cta_tile_shape_mnk[1] // 8), self.cta_tile_shape_mnk[2], self.ab_stage), + stride=( + (self.cta_tile_shape_mnk[2], 2 * 8 * 128), + 1, + (self.cta_tile_shape_mnk[1] // 8) * 2 * 8 * 128, + ), + ) + sATma = storage.sA.get_tensor( + a_tma_smem_layout, + swizzle=cute.make_swizzle(2, 4, 3), + dtype=cutlass.Float4E2M1FN, + ) + sBTma = storage.sB.get_tensor( + b_tma_smem_layout, + swizzle=cute.make_swizzle(2, 4, 3), + dtype=cutlass.Float4E2M1FN, + ) + sA = storage.sA.get_tensor( + a_smem_layout, swizzle=cute.make_swizzle(2, 4, 3), dtype=cutlass.Float4E2M1FN + ) + sB = storage.sB.get_tensor( + b_smem_layout, swizzle=cute.make_swizzle(2, 4, 3), dtype=cutlass.Float4E2M1FN + ) + else: + sA = storage.sA.get_tensor(a_smem_layout) + sB = storage.sB.get_tensor(b_smem_layout) + sAStore = sA + sBStore = sB + sATma = sACompact + sBTma = sBCompact + sSFA = storage.sSFA.get_tensor(sfa_smem_layout) + sSFB = storage.sSFB.get_tensor(sfb_smem_layout) + sAcc = storage.sAcc.get_tensor(acc_smem_layout) + pipeline_init_wait(cluster_shape_mn=(1, 1)) + + k_tile_count = cute.size(mA_mkl, mode=[1]) // ( + self.cta_tile_shape_mnk[2] + if const_expr(self.blockscaled_packed_ldsm) + else self.cta_tile_shape_mnk[2] // 2 + ) + scales_per_k_tile = self.cta_tile_shape_mnk[2] // self.sf_vec_size + + if warp_idx == self.blockscaled_producer_warp: + producer_state = make_pipeline_state(pipeline.PipelineUserType.Producer, self.ab_stage) + gSFA_mk16 = cute.local_tile( + mSFA_mkl16[None, None, cta_l], + (self.cta_tile_shape_mnk[0], 16), + (cta_m, None), + ) + gSFB_nk16 = cute.local_tile( + mSFB_nkl16[None, None, cta_l], + (self.cta_tile_shape_mnk[1], 16), + (cta_n, None), + ) + if const_expr(self.blockscaled_packed_ldsm): + if const_expr(k_tile_count == 1): + producer_state = self.load_blockscaled_tma_tile_packed_direct( + ab_pipeline, + producer_state, + cutlass.Int32(0), + cutlass.Int32(0), + tma_atom_a, + mA_mkl, + tma_atom_b, + mB_nkl, + tma_atom_sfa, + gSFA_mk16, + tma_atom_sfb, + gSFB_nk16, + sATma, + sBTma, + sSFA, + sSFB, + cta_m, + cta_n, + cta_l, + ) + else: + for k_tile in cutlass.range(k_tile_count, unroll=1): + scale_base = k_tile * scales_per_k_tile + scale_page = scale_base // 16 + producer_state = self.load_blockscaled_tma_tile_packed_direct( + ab_pipeline, + producer_state, + k_tile, + scale_page, + tma_atom_a, + mA_mkl, + tma_atom_b, + mB_nkl, + tma_atom_sfa, + gSFA_mk16, + tma_atom_sfb, + gSFB_nk16, + sATma, + sBTma, + sSFA, + sSFB, + cta_m, + cta_n, + cta_l, + ) + else: + gA_mk = cute.local_tile( + mA_mkl[None, None, cta_l], + (self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[2] // 2), + (cta_m, None), + ) + gB_nk = cute.local_tile( + mB_nkl[None, None, cta_l], + (self.cta_tile_shape_mnk[1], self.cta_tile_shape_mnk[2] // 2), + (cta_n, None), + ) + if const_expr(k_tile_count == 1): + producer_state = self.load_blockscaled_tma_tile( + ab_pipeline, + producer_state, + tma_atom_a, + gA_mk, + tma_atom_b, + gB_nk, + tma_atom_sfa, + gSFA_mk16, + tma_atom_sfb, + gSFB_nk16, + sATma, + sBTma, + sSFA, + sSFB, + ) + else: + for k_tile in cutlass.range(k_tile_count, unroll=1): + scale_base = k_tile * scales_per_k_tile + scale_page = scale_base // 16 + producer_state = self.load_blockscaled_tma_tile_indexed( + ab_pipeline, + producer_state, + k_tile, + scale_page, + tma_atom_a, + gA_mk, + tma_atom_b, + gB_nk, + tma_atom_sfa, + gSFA_mk16, + tma_atom_sfb, + gSFB_nk16, + sACompact, + sBCompact, + sSFA, + sSFB, + ) + ab_pipeline.producer_tail(producer_state) + + if warp_idx < self.blockscaled_consumer_warps: + lane_idx = cute.arch.lane_idx() + gD_mn = cute.local_tile( + mD_mnl[None, None, cta_l], + (self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1]), + (cta_m, cta_n), + ) + read_state = make_pipeline_state(pipeline.PipelineUserType.Consumer, self.ab_stage) + self.mma_blockscaled_kloop_store_bf16( + ab_pipeline, + read_state, + tiled_mma, + sA, + sB, + sAStore, + sBStore, + sACompact, + sBCompact, + sSFA, + sSFB, + sAcc, + gD_mn, + k_tile_count, + lane_idx, + warp_idx, + ) + + @cute.jit + def load_blockscaled_tma_tile_packed_direct( + self, + ab_pipeline: pipeline.PipelineAsync, + producer_state: pipeline.PipelineState, + k_tile: cutlass.Int32, + scale_page: cutlass.Int32, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_sfa: cute.CopyAtom, + gSFA_mk16: cute.Tensor, + tma_atom_sfb: cute.CopyAtom, + gSFB_nk16: cute.Tensor, + sATma: cute.Tensor, + sBTma: cute.Tensor, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + cta_m: cutlass.Int32, + cta_n: cutlass.Int32, + cta_l: cutlass.Int32, + ) -> pipeline.PipelineState: + copy_SFA, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_sfa, + cta_coord=0, + cta_layout=cute.make_layout(1), + src_tensor=gSFA_mk16, + dst_tensor=sSFA, + mcast_mask=0, + ) + copy_SFB, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_sfb, + cta_coord=0, + cta_layout=cute.make_layout(1), + src_tensor=gSFB_nk16, + dst_tensor=sSFB, + mcast_mask=0, + ) + + peek_empty_status = ab_pipeline.producer_try_acquire(producer_state) + ab_pipeline.producer_acquire(producer_state, peek_empty_status) + tma_bar_ptr = ab_pipeline.producer_get_barrier(producer_state) + smem_idx = producer_state.index + + for m_atom in cutlass.range_constexpr(self.cta_tile_shape_mnk[0] // 16): + gA_atom = cute.local_tile( + mA_mkl[None, None, cta_l], + (16, self.cta_tile_shape_mnk[2]), + (cta_m * (self.cta_tile_shape_mnk[0] // 16) + m_atom, None), + ) + copy_A, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_a, + cta_coord=0, + cta_layout=cute.make_layout(1), + src_tensor=gA_atom, + dst_tensor=sATma[(None, m_atom), None, None], + mcast_mask=0, + ) + copy_A(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr) + + for n_atom in cutlass.range_constexpr(self.cta_tile_shape_mnk[1] // 8): + gB_atom = cute.local_tile( + mB_nkl[None, None, cta_l], + (8, self.cta_tile_shape_mnk[2]), + (cta_n * (self.cta_tile_shape_mnk[1] // 8) + n_atom, None), + ) + copy_B, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_b, + cta_coord=0, + cta_layout=cute.make_layout(1), + src_tensor=gB_atom, + dst_tensor=sBTma[(None, n_atom), None, None], + mcast_mask=0, + ) + copy_B(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr) + + copy_SFA(scale_page, smem_idx, tma_bar_ptr=tma_bar_ptr) + copy_SFB(scale_page, smem_idx, tma_bar_ptr=tma_bar_ptr) + ab_pipeline.producer_commit(producer_state) + producer_state.advance() + return producer_state + + @cute.jit + def load_blockscaled_tma_tile( + self, + ab_pipeline: pipeline.PipelineAsync, + producer_state: pipeline.PipelineState, + tma_atom_a: cute.CopyAtom, + gA_mk: cute.Tensor, + tma_atom_b: cute.CopyAtom, + gB_nk: cute.Tensor, + tma_atom_sfa: cute.CopyAtom, + gSFA_mk16: cute.Tensor, + tma_atom_sfb: cute.CopyAtom, + gSFB_nk16: cute.Tensor, + sACompact: cute.Tensor, + sBCompact: cute.Tensor, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + ) -> pipeline.PipelineState: + copy_A, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_a, + cta_coord=0, + cta_layout=cute.make_layout(1), + src_tensor=gA_mk, + dst_tensor=sACompact, + mcast_mask=0, + ) + copy_B, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_b, + cta_coord=0, + cta_layout=cute.make_layout(1), + src_tensor=gB_nk, + dst_tensor=sBCompact, + mcast_mask=0, + ) + copy_SFA, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_sfa, + cta_coord=0, + cta_layout=cute.make_layout(1), + src_tensor=gSFA_mk16, + dst_tensor=sSFA, + mcast_mask=0, + ) + copy_SFB, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_sfb, + cta_coord=0, + cta_layout=cute.make_layout(1), + src_tensor=gSFB_nk16, + dst_tensor=sSFB, + mcast_mask=0, + ) + return self.load_tma( + ab_pipeline, + producer_state, + [copy_A, copy_B, copy_SFA, copy_SFB], + Int32(1), + ) + + @cute.jit + def load_blockscaled_tma_tile_indexed( + self, + ab_pipeline: pipeline.PipelineAsync, + producer_state: pipeline.PipelineState, + k_tile: cutlass.Int32, + scale_page: cutlass.Int32, + tma_atom_a: cute.CopyAtom, + gA_mk: cute.Tensor, + tma_atom_b: cute.CopyAtom, + gB_nk: cute.Tensor, + tma_atom_sfa: cute.CopyAtom, + gSFA_mk16: cute.Tensor, + tma_atom_sfb: cute.CopyAtom, + gSFB_nk16: cute.Tensor, + sACompact: cute.Tensor, + sBCompact: cute.Tensor, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + ) -> pipeline.PipelineState: + copy_A, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_a, + cta_coord=0, + cta_layout=cute.make_layout(1), + src_tensor=gA_mk, + dst_tensor=sACompact, + mcast_mask=0, + ) + copy_B, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_b, + cta_coord=0, + cta_layout=cute.make_layout(1), + src_tensor=gB_nk, + dst_tensor=sBCompact, + mcast_mask=0, + ) + copy_SFA, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_sfa, + cta_coord=0, + cta_layout=cute.make_layout(1), + src_tensor=gSFA_mk16, + dst_tensor=sSFA, + mcast_mask=0, + ) + copy_SFB, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_sfb, + cta_coord=0, + cta_layout=cute.make_layout(1), + src_tensor=gSFB_nk16, + dst_tensor=sSFB, + mcast_mask=0, + ) + peek_empty_status = ab_pipeline.producer_try_acquire(producer_state) + ab_pipeline.producer_acquire(producer_state, peek_empty_status) + tma_bar_ptr = ab_pipeline.producer_get_barrier(producer_state) + smem_idx = producer_state.index + copy_A(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr) + copy_B(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr) + copy_SFA(scale_page, smem_idx, tma_bar_ptr=tma_bar_ptr) + copy_SFB(scale_page, smem_idx, tma_bar_ptr=tma_bar_ptr) + ab_pipeline.producer_commit(producer_state) + producer_state.advance() + return producer_state + + @cute.jit + def mma_blockscaled_kloop_store_bf16( + self, + ab_pipeline: pipeline.PipelineAsync, + read_state: pipeline.PipelineState, + tiled_mma: cute.TiledMma, + sA: cute.Tensor, + sB: cute.Tensor, + sAStore: cute.Tensor, + sBStore: cute.Tensor, + sACompact: cute.Tensor, + sBCompact: cute.Tensor, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + sAcc: cute.Tensor, + gD_mn: cute.Tensor, + k_tile_count: cutlass.Int32, + tidx: cutlass.Int32, + warp_idx: cutlass.Int32, + ) -> None: + thr_mma = tiled_mma.get_slice(tidx) + a_shape = tiled_mma.partition_shape_A((16, 64)) + b_shape = tiled_mma.partition_shape_B((8, 64)) + acc_shape = tiled_mma.partition_shape_C((16, 8)) + accum_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), cutlass.Float32) + store_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gD_mn.element_type) + scales_per_k_tile = self.cta_tile_shape_mnk[2] // self.sf_vec_size + if const_expr( + self.cta_tile_shape_mnk in ((64, 64, 64), (64, 64, 128)) + and self.blockscaled_consumer_warps == 4 + ): + self.mma_blockscaled_64x64_kloop_store_bf16( + ab_pipeline, + read_state, + tiled_mma, + sA, + sB, + sAStore, + sBStore, + sACompact, + sBCompact, + sSFA, + sSFB, + sAcc, + gD_mn, + k_tile_count, + tidx, + thr_mma, + a_shape, + b_shape, + acc_shape, + accum_atom, + store_atom, + scales_per_k_tile, + warp_idx, + ) + elif const_expr(k_tile_count == 1): + read_state = self.mma_blockscaled_one_k_tile_accumulate_smem( + ab_pipeline, + read_state, + tiled_mma, + sA, + sB, + sAStore, + sBStore, + sACompact, + sBCompact, + sSFA, + sSFB, + sAcc, + gD_mn, + tidx, + thr_mma, + a_shape, + b_shape, + acc_shape, + accum_atom, + store_atom, + scales_per_k_tile, + cutlass.Int32(0), + False, + True, + warp_idx, + ) + else: + read_state = self.mma_blockscaled_one_k_tile_accumulate_smem( + ab_pipeline, + read_state, + tiled_mma, + sA, + sB, + sAStore, + sBStore, + sACompact, + sBCompact, + sSFA, + sSFB, + sAcc, + gD_mn, + tidx, + thr_mma, + a_shape, + b_shape, + acc_shape, + accum_atom, + store_atom, + scales_per_k_tile, + cutlass.Int32(0), + False, + False, + warp_idx, + ) + for k_iter in cutlass.range(k_tile_count - 2, unroll=1): + k_tile = k_iter + 1 + read_state = self.mma_blockscaled_one_k_tile_accumulate_smem( + ab_pipeline, + read_state, + tiled_mma, + sA, + sB, + sAStore, + sBStore, + sACompact, + sBCompact, + sSFA, + sSFB, + sAcc, + gD_mn, + tidx, + thr_mma, + a_shape, + b_shape, + acc_shape, + accum_atom, + store_atom, + scales_per_k_tile, + k_tile, + True, + False, + warp_idx, + ) + read_state = self.mma_blockscaled_one_k_tile_accumulate_smem( + ab_pipeline, + read_state, + tiled_mma, + sA, + sB, + sAStore, + sBStore, + sACompact, + sBCompact, + sSFA, + sSFB, + sAcc, + gD_mn, + tidx, + thr_mma, + a_shape, + b_shape, + acc_shape, + accum_atom, + store_atom, + scales_per_k_tile, + k_tile_count - 1, + True, + True, + warp_idx, + ) + + @cute.jit + def mma_blockscaled_64x64_kloop_store_bf16( + self, + ab_pipeline: pipeline.PipelineAsync, + read_state: pipeline.PipelineState, + tiled_mma: cute.TiledMma, + sA: cute.Tensor, + sB: cute.Tensor, + sAStore: cute.Tensor, + sBStore: cute.Tensor, + sACompact: cute.Tensor, + sBCompact: cute.Tensor, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + sAcc: cute.Tensor, + gD_mn: cute.Tensor, + k_tile_count: cutlass.Int32, + tidx: cutlass.Int32, + thr_mma: cute.TiledMmaThrVal, + a_shape: cute.Shape, + b_shape: cute.Shape, + acc_shape: cute.Shape, + accum_atom: cute.CopyAtom, + store_atom: cute.CopyAtom, + scales_per_k_tile: cutlass.Constexpr[int], + warp_idx: cutlass.Int32, + ) -> None: + # 64x64 has four 16-row bands and eight 8-column atoms. With four + # consumer warps, each warp owns one row band and keeps all eight atom + # accumulators live across K. This preserves FP32 accumulation while + # avoiding the LDS/STS traffic from the generic sAcc scratch path. + if warp_idx == 0: + self.mma_blockscaled_64x64_mrow_kloop_store_bf16( + ab_pipeline, + read_state, + tiled_mma, + sA, + sB, + sAStore, + sBStore, + sACompact, + sBCompact, + sSFA, + sSFB, + sAcc, + gD_mn, + k_tile_count, + tidx, + thr_mma, + a_shape, + b_shape, + acc_shape, + accum_atom, + store_atom, + scales_per_k_tile, + 0, + True, + ) + elif warp_idx == 1: + self.mma_blockscaled_64x64_mrow_kloop_store_bf16( + ab_pipeline, + read_state, + tiled_mma, + sA, + sB, + sAStore, + sBStore, + sACompact, + sBCompact, + sSFA, + sSFB, + sAcc, + gD_mn, + k_tile_count, + tidx, + thr_mma, + a_shape, + b_shape, + acc_shape, + accum_atom, + store_atom, + scales_per_k_tile, + 1, + False, + ) + elif warp_idx == 2: + self.mma_blockscaled_64x64_mrow_kloop_store_bf16( + ab_pipeline, + read_state, + tiled_mma, + sA, + sB, + sAStore, + sBStore, + sACompact, + sBCompact, + sSFA, + sSFB, + sAcc, + gD_mn, + k_tile_count, + tidx, + thr_mma, + a_shape, + b_shape, + acc_shape, + accum_atom, + store_atom, + scales_per_k_tile, + 2, + False, + ) + elif warp_idx == 3: + self.mma_blockscaled_64x64_mrow_kloop_store_bf16( + ab_pipeline, + read_state, + tiled_mma, + sA, + sB, + sAStore, + sBStore, + sACompact, + sBCompact, + sSFA, + sSFB, + sAcc, + gD_mn, + k_tile_count, + tidx, + thr_mma, + a_shape, + b_shape, + acc_shape, + accum_atom, + store_atom, + scales_per_k_tile, + 3, + False, + ) + + @cute.jit + def mma_blockscaled_64x64_mrow_kloop_store_bf16( + self, + ab_pipeline: pipeline.PipelineAsync, + read_state: pipeline.PipelineState, + tiled_mma: cute.TiledMma, + sA: cute.Tensor, + sB: cute.Tensor, + sAStore: cute.Tensor, + sBStore: cute.Tensor, + sACompact: cute.Tensor, + sBCompact: cute.Tensor, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + sAcc: cute.Tensor, + gD_mn: cute.Tensor, + k_tile_count: cutlass.Int32, + tidx: cutlass.Int32, + thr_mma: cute.TiledMmaThrVal, + a_shape: cute.Shape, + b_shape: cute.Shape, + acc_shape: cute.Shape, + accum_atom: cute.CopyAtom, + store_atom: cute.CopyAtom, + scales_per_k_tile: cutlass.Constexpr[int], + m_atom: cutlass.Constexpr[int], + releases_pipeline: cutlass.Constexpr[bool], + ) -> None: + acc0 = cute.make_rmem_tensor(acc_shape, cutlass.Float32) + acc1 = cute.make_rmem_tensor(acc_shape, cutlass.Float32) + acc2 = cute.make_rmem_tensor(acc_shape, cutlass.Float32) + acc3 = cute.make_rmem_tensor(acc_shape, cutlass.Float32) + acc4 = cute.make_rmem_tensor(acc_shape, cutlass.Float32) + acc5 = cute.make_rmem_tensor(acc_shape, cutlass.Float32) + acc6 = cute.make_rmem_tensor(acc_shape, cutlass.Float32) + acc7 = cute.make_rmem_tensor(acc_shape, cutlass.Float32) + acc0.fill(0.0) + acc1.fill(0.0) + acc2.fill(0.0) + acc3.fill(0.0) + acc4.fill(0.0) + acc5.fill(0.0) + acc6.fill(0.0) + acc7.fill(0.0) + + for k_tile in cutlass.range(k_tile_count, unroll=1): + if const_expr(releases_pipeline): + peek_ab_full_status = ab_pipeline.consumer_try_wait(read_state) + ab_pipeline.consumer_wait(read_state, peek_ab_full_status) + cute.arch.barrier( + barrier_id=int(NamedBarrierGemm.MmaWG0), + number_of_threads=self.blockscaled_consumer_warps * cute.arch.WARP_SIZE, + ) + if const_expr(not self.blockscaled_packed_ldsm): + _expand_compact_fp4_to_sm120_ldmatrix_smem( + sACompact[None, None, read_state.index], + sA[None, None, read_state.index], + self.cta_tile_shape_mnk[0], + self.blockscaled_consumer_warps * cute.arch.WARP_SIZE, + ) + _expand_compact_fp4_to_sm120_ldmatrix_smem( + sBCompact[None, None, read_state.index], + sB[None, None, read_state.index], + self.cta_tile_shape_mnk[1], + self.blockscaled_consumer_warps * cute.arch.WARP_SIZE, + ) + cute.arch.fence_view_async_shared() + cute.arch.barrier( + barrier_id=int(NamedBarrierGemm.MmaWG0), + number_of_threads=self.blockscaled_consumer_warps * cute.arch.WARP_SIZE, + ) + + for k_block in cutlass.range_constexpr(self.cta_tile_shape_mnk[2] // 64): + scale_base = k_tile * scales_per_k_tile + k_block * (64 // self.sf_vec_size) + scale_page = scale_base // 16 + scale_page_offset = scale_base - scale_page * 16 + del scale_page + self.mma_blockscaled_tile_k64_accumulate( + tiled_mma, + acc0, + sA, + sB, + sSFA, + sSFB, + read_state.index, + scale_page_offset, + cutlass.Int32(m_atom * 16), + cutlass.Int32(0), + tidx, + a_shape, + b_shape, + k_block, + ) + self.mma_blockscaled_tile_k64_accumulate( + tiled_mma, + acc1, + sA, + sB, + sSFA, + sSFB, + read_state.index, + scale_page_offset, + cutlass.Int32(m_atom * 16), + cutlass.Int32(8), + tidx, + a_shape, + b_shape, + k_block, + ) + self.mma_blockscaled_tile_k64_accumulate( + tiled_mma, + acc2, + sA, + sB, + sSFA, + sSFB, + read_state.index, + scale_page_offset, + cutlass.Int32(m_atom * 16), + cutlass.Int32(16), + tidx, + a_shape, + b_shape, + k_block, + ) + self.mma_blockscaled_tile_k64_accumulate( + tiled_mma, + acc3, + sA, + sB, + sSFA, + sSFB, + read_state.index, + scale_page_offset, + cutlass.Int32(m_atom * 16), + cutlass.Int32(24), + tidx, + a_shape, + b_shape, + k_block, + ) + self.mma_blockscaled_tile_k64_accumulate( + tiled_mma, + acc4, + sA, + sB, + sSFA, + sSFB, + read_state.index, + scale_page_offset, + cutlass.Int32(m_atom * 16), + cutlass.Int32(32), + tidx, + a_shape, + b_shape, + k_block, + ) + self.mma_blockscaled_tile_k64_accumulate( + tiled_mma, + acc5, + sA, + sB, + sSFA, + sSFB, + read_state.index, + scale_page_offset, + cutlass.Int32(m_atom * 16), + cutlass.Int32(40), + tidx, + a_shape, + b_shape, + k_block, + ) + self.mma_blockscaled_tile_k64_accumulate( + tiled_mma, + acc6, + sA, + sB, + sSFA, + sSFB, + read_state.index, + scale_page_offset, + cutlass.Int32(m_atom * 16), + cutlass.Int32(48), + tidx, + a_shape, + b_shape, + k_block, + ) + self.mma_blockscaled_tile_k64_accumulate( + tiled_mma, + acc7, + sA, + sB, + sSFA, + sSFB, + read_state.index, + scale_page_offset, + cutlass.Int32(m_atom * 16), + cutlass.Int32(56), + tidx, + a_shape, + b_shape, + k_block, + ) + + cute.arch.fence_view_async_shared() + cute.arch.barrier( + barrier_id=int(NamedBarrierGemm.MmaWG0), + number_of_threads=self.blockscaled_consumer_warps * cute.arch.WARP_SIZE, + ) + if const_expr(releases_pipeline): + ab_pipeline.consumer_release(read_state) + read_state.advance() + + self.store_blockscaled_accum_direct_atom(thr_mma, store_atom, acc0, gD_mn, m_atom, 0) + self.store_blockscaled_accum_direct_atom(thr_mma, store_atom, acc1, gD_mn, m_atom, 1) + self.store_blockscaled_accum_direct_atom(thr_mma, store_atom, acc2, gD_mn, m_atom, 2) + self.store_blockscaled_accum_direct_atom(thr_mma, store_atom, acc3, gD_mn, m_atom, 3) + self.store_blockscaled_accum_direct_atom(thr_mma, store_atom, acc4, gD_mn, m_atom, 4) + self.store_blockscaled_accum_direct_atom(thr_mma, store_atom, acc5, gD_mn, m_atom, 5) + self.store_blockscaled_accum_direct_atom(thr_mma, store_atom, acc6, gD_mn, m_atom, 6) + self.store_blockscaled_accum_direct_atom(thr_mma, store_atom, acc7, gD_mn, m_atom, 7) + + @cute.jit + def mma_blockscaled_one_k_tile_accumulate_smem( + self, + ab_pipeline: pipeline.PipelineAsync, + read_state: pipeline.PipelineState, + tiled_mma: cute.TiledMma, + sA: cute.Tensor, + sB: cute.Tensor, + sAStore: cute.Tensor, + sBStore: cute.Tensor, + sACompact: cute.Tensor, + sBCompact: cute.Tensor, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + sAcc: cute.Tensor, + gD_mn: cute.Tensor, + tidx: cutlass.Int32, + thr_mma: cute.TiledMmaThrVal, + a_shape: cute.Shape, + b_shape: cute.Shape, + acc_shape: cute.Shape, + accum_atom: cute.CopyAtom, + store_atom: cute.CopyAtom, + scales_per_k_tile: cutlass.Constexpr[int], + k_tile: cutlass.Int32, + add_existing: cutlass.Constexpr[bool], + store_final: cutlass.Constexpr[bool], + warp_idx: cutlass.Int32, + ) -> pipeline.PipelineState: + if warp_idx == 0: + peek_ab_full_status = ab_pipeline.consumer_try_wait(read_state) + ab_pipeline.consumer_wait(read_state, peek_ab_full_status) + if const_expr(self.blockscaled_consumer_warps > 1): + cute.arch.barrier( + barrier_id=int(NamedBarrierGemm.MmaWG0), + number_of_threads=self.blockscaled_consumer_warps * cute.arch.WARP_SIZE, + ) + if warp_idx == 0: + if const_expr(self.blockscaled_packed_ldsm): + pass + else: + _expand_compact_fp4_to_sm120_ldmatrix_smem( + sACompact[None, None, read_state.index], + sA[None, None, read_state.index], + self.cta_tile_shape_mnk[0], + ) + _expand_compact_fp4_to_sm120_ldmatrix_smem( + sBCompact[None, None, read_state.index], + sB[None, None, read_state.index], + self.cta_tile_shape_mnk[1], + ) + cute.arch.fence_view_async_shared() + if const_expr(self.blockscaled_consumer_warps > 1): + cute.arch.barrier( + barrier_id=int(NamedBarrierGemm.MmaWG0), + number_of_threads=self.blockscaled_consumer_warps * cute.arch.WARP_SIZE, + ) + else: + cute.arch.sync_warp() + scale_base = k_tile * scales_per_k_tile + scale_page = scale_base // 16 + scale_page_offset = scale_base - scale_page * 16 + for m_atom in cutlass.range_constexpr(self.cta_tile_shape_mnk[0] // 16): + for n_atom in cutlass.range_constexpr(self.cta_tile_shape_mnk[1] // 8): + atom_owner = (m_atom * (self.cta_tile_shape_mnk[1] // 8) + n_atom) % ( + self.blockscaled_consumer_warps + ) + if warp_idx == atom_owner: + acc = cute.make_rmem_tensor(acc_shape, cutlass.Float32) + acc.fill(0.0) + self.mma_blockscaled_tile_k64_accumulate( + tiled_mma, + acc, + sA, + sB, + sSFA, + sSFB, + read_state.index, + scale_page_offset, + cutlass.Int32(m_atom * 16), + cutlass.Int32(n_atom * 8), + tidx, + a_shape, + b_shape, + 0, + ) + self.store_blockscaled_accum_smem_atom( + thr_mma, + accum_atom, + acc, + sAcc, + store_atom, + gD_mn, + m_atom, + n_atom, + add_existing, + store_final, + ) + cute.arch.fence_view_async_shared() + if const_expr(self.blockscaled_consumer_warps > 1): + cute.arch.barrier( + barrier_id=int(NamedBarrierGemm.MmaWG0), + number_of_threads=self.blockscaled_consumer_warps * cute.arch.WARP_SIZE, + ) + else: + cute.arch.sync_warp() + if warp_idx == 0: + ab_pipeline.consumer_release(read_state) + read_state.advance() + return read_state + + @cute.jit + def mma_blockscaled( + self, + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + sA: cute.Tensor, + sB: cute.Tensor, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + stage: cutlass.Int32, + m_atom_base: cutlass.Int32, + n_atom_base: cutlass.Int32, + k_scale_base: cutlass.Int32, + tidx: cutlass.Int32, + a_shape: cute.Shape, + b_shape: cute.Shape, + k_block: cutlass.Constexpr[int], + ) -> None: + a = cute.make_rmem_tensor(a_shape, cutlass.Float4E2M1FN) + b = cute.make_rmem_tensor(b_shape, cutlass.Float4E2M1FN) + if const_expr(self.blockscaled_packed_ldsm): + # Clear through the byte view: this packed-subbyte fragment layout + # is not fully covered by the Int32 view on all CuTe DSL builds. + cute.recast_tensor(a, cutlass.Uint8).fill(0) + cute.recast_tensor(b, cutlass.Uint8).fill(0) + copy_atom = cute.make_copy_atom( + warp.LdMatrix8x8x16bOp(num_matrices=4), + cutlass.Int4, + ) + tiled_copy_a = cute.make_tiled_copy_A(copy_atom, tiled_mma) + tiled_copy_b = cute.make_tiled_copy_B(copy_atom, tiled_mma) + lane = cute.arch.lane_idx() + thr_copy_a = tiled_copy_a.get_slice(lane) + thr_copy_b = tiled_copy_b.get_slice(lane) + a_atom = m_atom_base // 16 + b_atom = n_atom_base // 8 + tCsA = thr_copy_a.partition_S(sA[None, None, a_atom, stage]) + tCsB = thr_copy_b.partition_S(sB[None, None, b_atom, stage]) + a_copy_view = thr_copy_a.retile(a) + b_copy_view = thr_copy_b.retile(b) + cute.copy(tiled_copy_a, tCsA[None, None, k_block], a_copy_view[None, None, 0]) + cute.copy(tiled_copy_b, tCsB[None, None, k_block], b_copy_view[None, None, 0]) + else: + sA_atom = cute.domain_offset((m_atom_base, None), sA[None, None, stage]) + sB_atom = cute.domain_offset((n_atom_base, None), sB[None, None, stage]) + copy_atom = cute.make_copy_atom( + warp.LdMatrix8x16x8bOp(num_matrices=1, unpack_bits=4), + cutlass.Int8, + ) + tiled_copy_a = cute.make_tiled_copy_A(copy_atom, tiled_mma) + tiled_copy_b = cute.make_tiled_copy_B(copy_atom, tiled_mma) + a0, a1, a2, a3 = warp.sm120_mxf4nvf4_ldmatrix_A_regs( + tiled_copy_a, + tidx, + _make_sm120_fp4_ldmatrix_smem_view(sA_atom, 16), + ) + b0, b1 = warp.sm120_mxf4nvf4_ldmatrix_B_regs( + tiled_copy_b, + tidx, + _make_sm120_fp4_ldmatrix_smem_view(sB_atom, 8), + ) + a_i32 = cute.recast_tensor(a, cutlass.Int32) + b_i32 = cute.recast_tensor(b, cutlass.Int32) + # The asymmetric SM120 blockscaled tests catch this placement. + a_i32[0] = a0 + a_i32[1] = a2 + a_i32[2] = a1 + a_i32[3] = a3 + b_i32[0] = b0 + b_i32[1] = b1 + sfa, sfb = _load_sm120_blockscaled_selector0_scale_fragments( + sSFA, + sSFB, + stage, + m_atom_base, + n_atom_base, + k_scale_base, + self.sf_vec_size, + self.sf_dtype, + ) + if const_expr(self.blockscaled_packed_ldsm): + cute.gemm(tiled_mma, acc, (a[None, None, 0], sfa), (b[None, None, 0], sfb), acc) + else: + cute.gemm(tiled_mma, acc, (a, sfa), (b, sfb), acc) + + @cute.jit + def mma_blockscaled_tile_k64_accumulate( + self, + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + sA: cute.Tensor, + sB: cute.Tensor, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + stage: cutlass.Int32, + scale_page_offset: cutlass.Int32, + m_atom_base: cutlass.Int32, + n_atom_base: cutlass.Int32, + tidx: cutlass.Int32, + a_shape: cute.Shape, + b_shape: cute.Shape, + k_block: cutlass.Constexpr[int], + ) -> None: + del tiled_mma + if const_expr(self.sf_vec_size == 16): + mma_op = warp.MmaMXF4NVF4Op(cutlass.Float4E2M1FN, cutlass.Float32, self.sf_dtype) + else: + mma_op = warp.MmaMXF4Op(cutlass.Float4E2M1FN, cutlass.Float32, self.sf_dtype) + local_tiled_mma = cute.make_tiled_mma(mma_op) + self.mma_blockscaled( + local_tiled_mma, + acc, + sA, + sB, + sSFA, + sSFB, + stage, + m_atom_base, + n_atom_base, + scale_page_offset, + tidx, + a_shape, + b_shape, + k_block, + ) + + @cute.jit + def store_blockscaled_accum_smem_atom( + self, + thr_mma: cute.TiledMmaThrVal, + accum_atom: cute.CopyAtom, + acc: cute.Tensor, + sAcc: cute.Tensor, + store_atom: cute.CopyAtom, + mD_mn: cute.Tensor, + m_atom: cutlass.Constexpr[int], + n_atom: cutlass.Constexpr[int], + add_existing: cutlass.Constexpr[bool], + store_final: cutlass.Constexpr[bool], + ) -> None: + sAcc_atom = cute.local_tile(sAcc, (16, 8), (m_atom, n_atom)) + tCsAcc = thr_mma.partition_C(sAcc_atom) + if const_expr(add_existing): + tCrPrev = cute.make_rmem_tensor(acc.layout, cutlass.Float32) + cute.copy(accum_atom, tCsAcc, tCrPrev) + acc.store(acc.load() + tCrPrev.load()) + if const_expr(store_final): + gD_atom = cute.local_tile(mD_mn, (16, 8), (m_atom, n_atom)) + tCgD = thr_mma.partition_C(gD_atom) + tCrD = cute.make_rmem_tensor(acc.layout, mD_mn.element_type) + tCrD.store(acc.load().to(mD_mn.element_type)) + cute.copy(store_atom, tCrD, tCgD) + else: + cute.copy(accum_atom, acc, tCsAcc) + + @cute.jit + def store_blockscaled_accum_direct_atom( + self, + thr_mma: cute.TiledMmaThrVal, + store_atom: cute.CopyAtom, + acc: cute.Tensor, + mD_mn: cute.Tensor, + m_atom: cutlass.Constexpr[int], + n_atom: cutlass.Constexpr[int], + ) -> None: + gD_atom = cute.local_tile(mD_mn, (16, 8), (m_atom, n_atom)) + tCgD = thr_mma.partition_C(gD_atom) + tCrD = cute.make_rmem_tensor(acc.layout, mD_mn.element_type) + tCrD.store(acc.load().to(mD_mn.element_type)) + cute.copy(store_atom, tCrD, tCgD) @cute.kernel def kernel( diff --git a/tests/test_gemm_blockscaled.py b/tests/test_gemm_blockscaled.py new file mode 100644 index 00000000..4f086422 --- /dev/null +++ b/tests/test_gemm_blockscaled.py @@ -0,0 +1,1840 @@ +from pathlib import Path + +import pytest +import torch + +import cutlass + +from quack.blockscaled_gemm_utils import ( + FP4_E2M1FN_VALUES, + blockscaled_gemm_reference, + compile_blockscaled_gemm_tvm_ffi, + create_blockscaled_operand_quantized, + create_blockscaled_operand_tensor, + create_blockscaled_scale_tensor, + create_sm120_blockscaled_scale_tensor, + create_blockscaled_varlen_k_operands, + create_blockscaled_varlen_m_operands, + scale_blocked_for_cublas, + scale_view_for_kernel, +) +from quack.compile_utils import make_fake_tensor as fake_tensor +from quack.gemm_default_epi import GemmDefaultSm100, GemmDefaultSm120 +from quack.varlen_utils import VarlenArguments +from quack.mx_utils import to_blocked + + +def _skip_if_not_sm100(): + if not torch.cuda.is_available(): + pytest.skip("CUDA required") + major = torch.cuda.get_device_properties(0).major + if major not in (10, 11): + pytest.skip("SM100/SM110 required") + + +def _skip_if_not_sm120(): + if not torch.cuda.is_available(): + pytest.skip("CUDA required") + major = torch.cuda.get_device_properties(0).major + if major != 12: + pytest.skip("SM120 required") + + +def _compile_blockscaled_gemm( + ab_dtype, + sf_dtype, + sf_vec_size, + d_dtype, + mma_tiler_mn, + cluster_shape_mn, + m, + n, + k, + l, +): + a_ref, mA = create_blockscaled_operand_tensor(l, m, k, False, ab_dtype) + b_ref, mB = create_blockscaled_operand_tensor(l, n, k, False, ab_dtype) + _, mD = create_blockscaled_operand_tensor(l, m, n, False, d_dtype, init="empty") + sfa_ref, mSFA = create_blockscaled_scale_tensor(l, m, k, sf_vec_size, sf_dtype) + sfb_ref, mSFB = create_blockscaled_scale_tensor(l, n, k, sf_vec_size, sf_dtype) + compiled = compile_blockscaled_gemm_tvm_ffi( + ab_dtype, + sf_dtype, + sf_vec_size, + d_dtype, + mma_tiler_mn, + cluster_shape_mn, + mA, + mB, + mD, + mSFA, + mSFB, + ) + return ( + compiled, + (mA, mB, mD, mSFA, mSFB), + (a_ref, b_ref, sfa_ref, sfb_ref, mD), + ) + + +def _run_blockscaled_gemm(compiled, args): + mA, mB, mD, mSFA, mSFB = args + compiled(mA, mB, mD, mSFA, mSFB) + torch.cuda.synchronize() + + +def test_blockscaled_validation(): + assert GemmDefaultSm100.can_implement_blockscaled( + cutlass.Float8E4M3FN, + cutlass.Float8E8M0FNU, + 32, + cutlass.BFloat16, + (128, 64), + (1, 1), + 256, + 64, + 256, + 1, + "k", + "k", + "n", + ) + assert GemmDefaultSm100.can_implement_blockscaled( + cutlass.Float8E4M3FN, + cutlass.Float8E8M0FNU, + 32, + cutlass.BFloat16, + (128, 192), + (1, 1), + 256, + 192, + 256, + 1, + "k", + "k", + "n", + ) + assert GemmDefaultSm100.can_implement_blockscaled( + cutlass.Float8E4M3FN, + cutlass.Float8E8M0FNU, + 32, + cutlass.BFloat16, + (128, 128), + (1, 1), + 256, + 256, + 256, + 1, + "k", + "k", + "n", + ) + assert GemmDefaultSm100.can_implement_blockscaled( + cutlass.Float4E2M1FN, + cutlass.Float8E8M0FNU, + 32, + cutlass.Float32, + (128, 128), + (1, 1), + 256, + 256, + 256, + 1, + "k", + "k", + "n", + ) + assert GemmDefaultSm100.can_implement_blockscaled( + cutlass.Float4E2M1FN, + cutlass.Float8E4M3FN, + 16, + cutlass.Float32, + (128, 192), + (1, 1), + 256, + 192, + 256, + 1, + "k", + "k", + "n", + ) + assert not GemmDefaultSm100.can_implement_blockscaled( + cutlass.Float8E4M3FN, + cutlass.Float8E8M0FNU, + 32, + cutlass.BFloat16, + (256, 384), + (2, 1), + 256, + 512, + 256, + 1, + "k", + "k", + "n", + ) + + +def test_sm120_blockscaled_validation(monkeypatch): + assert GemmDefaultSm120.can_implement_blockscaled( + cutlass.Float4E2M1FN, + cutlass.Float8E4M3FN, + 16, + cutlass.BFloat16, + (128, 128, 64), + (1, 1), + 128, + 128, + 64, + 1, + "k", + "k", + "n", + ) + assert not GemmDefaultSm120.can_implement_blockscaled( + cutlass.Float4E2M1FN, + cutlass.Float8E4M3FN, + 16, + cutlass.BFloat16, + (64, 64, 128), + (1, 1), + 128, + 128, + 128, + 1, + "k", + "k", + "n", + ) + monkeypatch.setenv("QUACK_SM120_BLOCKSCALED_PACKED_LDSM", "1") + assert GemmDefaultSm120.can_implement_blockscaled( + cutlass.Float4E2M1FN, + cutlass.Float8E4M3FN, + 16, + cutlass.BFloat16, + (64, 64, 64), + (1, 1), + 128, + 128, + 64, + 1, + "k", + "k", + "n", + ) + assert GemmDefaultSm120.can_implement_blockscaled( + cutlass.Float4E2M1FN, + cutlass.Float8E4M3FN, + 16, + cutlass.BFloat16, + (64, 64, 128), + (1, 1), + 128, + 128, + 128, + 1, + "k", + "k", + "n", + ) + for tile_shape in ((128, 128, 128), (64, 128, 128), (128, 64, 128)): + assert not GemmDefaultSm120.can_implement_blockscaled( + cutlass.Float4E2M1FN, + cutlass.Float8E4M3FN, + 16, + cutlass.BFloat16, + tile_shape, + (1, 1), + 256, + 256, + 128, + 1, + "k", + "k", + "n", + ) + for tile_shape in ((64, 128, 64), (128, 64, 64)): + assert not GemmDefaultSm120.can_implement_blockscaled( + cutlass.Float4E2M1FN, + cutlass.Float8E4M3FN, + 16, + cutlass.BFloat16, + tile_shape, + (1, 1), + 256, + 256, + 64, + 1, + "k", + "k", + "n", + ) + assert GemmDefaultSm120.can_implement_blockscaled( + cutlass.Float4E2M1FN, + cutlass.Float8E8M0FNU, + 32, + cutlass.BFloat16, + (128, 128, 64), + (1, 1), + 256, + 256, + 128, + 1, + "k", + "k", + "n", + ) + assert not GemmDefaultSm120.can_implement_blockscaled( + cutlass.Float4E2M1FN, + cutlass.Float8E4M3FN, + 16, + cutlass.BFloat16, + (32, 64, 64), + (1, 1), + 128, + 128, + 64, + 1, + "k", + "k", + "n", + ) + assert not GemmDefaultSm120.can_implement_blockscaled( + cutlass.Float8E4M3FN, + cutlass.Float8E8M0FNU, + 32, + cutlass.BFloat16, + (128, 128, 64), + (1, 1), + 128, + 128, + 64, + 1, + "k", + "k", + "n", + ) + assert not GemmDefaultSm120.can_implement_blockscaled( + cutlass.Int8, + cutlass.Float8E8M0FNU, + 32, + cutlass.BFloat16, + (128, 128, 64), + (1, 1), + 128, + 128, + 64, + 1, + "k", + "k", + "n", + ) + assert not GemmDefaultSm120.can_implement_blockscaled( + cutlass.Float4E2M1FN, + cutlass.Float8E4M3FN, + 32, + cutlass.BFloat16, + (128, 128, 64), + (1, 1), + 128, + 128, + 64, + 1, + "k", + "k", + "n", + ) + assert not GemmDefaultSm120.can_implement_blockscaled( + cutlass.Float4E2M1FN, + cutlass.Float8E4M3FN, + 16, + cutlass.BFloat16, + (128, 128, 64), + (1, 1), + 128, + 128, + 96, + 1, + "k", + "k", + "n", + ) + assert not GemmDefaultSm100.can_implement_blockscaled( + cutlass.Float8E4M3FN, + cutlass.Float8E8M0FNU, + 32, + cutlass.Float32, + (256, 224), + (2, 1), + 256, + 448, + 256, + 1, + "k", + "k", + "n", + ) + assert not GemmDefaultSm100.can_implement_blockscaled( + cutlass.Float4E2M1FN, + cutlass.Float8E8M0FNU, + 32, + cutlass.Float32, + (256, 384), + (2, 1), + 256, + 512, + 256, + 1, + "k", + "k", + "n", + ) + assert not GemmDefaultSm100.can_implement_blockscaled( + cutlass.Float8E4M3FN, + cutlass.Float8E8M0FNU, + 32, + cutlass.BFloat16, + (64, 128), + (1, 1), + 256, + 256, + 256, + 1, + "k", + "k", + "n", + ) + assert not GemmDefaultSm100.can_implement_blockscaled( + cutlass.Float4E2M1FN, + cutlass.Float8E4M3FN, + 32, + cutlass.Float32, + (128, 128), + (1, 1), + 256, + 256, + 256, + 1, + "k", + "k", + "n", + ) + assert not GemmDefaultSm100.can_implement_blockscaled( + cutlass.Float8E4M3FN, + cutlass.Float8E8M0FNU, + 32, + cutlass.BFloat16, + (256, 128), + (1, 1), + 512, + 256, + 256, + 1, + "k", + "k", + "n", + ) + + +def test_sm120_blockscaled_class_call_validation(): + m = n = 128 + k = 64 + l = 1 + gemm = GemmDefaultSm120( + cutlass.Float32, + cutlass.Float4E2M1FN, + (128, 128, 64), + (1, 1, 1), + is_persistent=False, + sf_vec_size=16, + sf_dtype=cutlass.Float8E4M3FN, + ) + mA = fake_tensor(cutlass.Float4E2M1FN, (m, k, l), leading_dim=1, divisibility=4) + mB = fake_tensor(cutlass.Float4E2M1FN, (n, k, l), leading_dim=1, divisibility=4) + mD = fake_tensor(cutlass.BFloat16, (m, n, l), leading_dim=1, divisibility=8) + mSFA = fake_tensor(cutlass.Float8E4M3FN, (m, 16, l), leading_dim=1, divisibility=4) + mSFB = fake_tensor(cutlass.Float8E4M3FN, (n, 16, l), leading_dim=1, divisibility=4) + + assert ( + gemm._validate_blockscaled_call( + mA, + mB, + mD, + None, + mSFA, + mSFB, + gemm.EpilogueArguments(), + None, + None, + None, + ) + == VarlenArguments() + ) + with pytest.raises(ValueError, match="requires SFA and SFB"): + gemm._validate_blockscaled_call( + mA, mB, mD, None, None, mSFB, gemm.EpilogueArguments(), None, None, None + ) + with pytest.raises(NotImplementedError, match="C/beta"): + gemm._validate_blockscaled_call( + mA, mB, mD, mD, mSFA, mSFB, gemm.EpilogueArguments(), None, None, None + ) + packed_k_a = fake_tensor(cutlass.Float4E2M1FN, (m, k // 2, l), leading_dim=1, divisibility=4) + packed_k_b = fake_tensor(cutlass.Float4E2M1FN, (n, k // 2, l), leading_dim=1, divisibility=4) + with pytest.raises(ValueError, match="expects logical Float4E2M1FN K extent"): + gemm._validate_blockscaled_call( + packed_k_a, + packed_k_b, + mD, + None, + mSFA, + mSFB, + gemm.EpilogueArguments(), + None, + None, + None, + ) + + +@pytest.mark.parametrize("tile_shape", [(64, 128, 64), (128, 64, 64)]) +def test_sm120_blockscaled_class_call_rejects_unadvertised_tile_shapes(tile_shape): + m = n = 128 + k = 64 + l = 1 + gemm = GemmDefaultSm120( + cutlass.Float32, + cutlass.Float4E2M1FN, + tile_shape, + (1, 1, 1), + is_persistent=False, + sf_vec_size=16, + sf_dtype=cutlass.Float8E4M3FN, + ) + mA = fake_tensor(cutlass.Float4E2M1FN, (m, k, l), leading_dim=1, divisibility=4) + mB = fake_tensor(cutlass.Float4E2M1FN, (n, k, l), leading_dim=1, divisibility=4) + mD = fake_tensor(cutlass.BFloat16, (m, n, l), leading_dim=1, divisibility=8) + mSFA = fake_tensor(cutlass.Float8E4M3FN, (m, 16, l), leading_dim=1, divisibility=4) + mSFB = fake_tensor(cutlass.Float8E4M3FN, (n, 16, l), leading_dim=1, divisibility=4) + + with pytest.raises(NotImplementedError, match="supports tile shapes"): + gemm._validate_blockscaled_call( + mA, + mB, + mD, + None, + mSFA, + mSFB, + gemm.EpilogueArguments(), + None, + None, + None, + ) + + +def test_sm120_blockscaled_packed_env_keeps_128x128x64_class_call(monkeypatch): + monkeypatch.setenv("QUACK_SM120_BLOCKSCALED_PACKED_LDSM", "1") + m = n = 128 + k = 64 + l = 1 + gemm = GemmDefaultSm120( + cutlass.Float32, + cutlass.Float4E2M1FN, + (128, 128, 64), + (1, 1, 1), + is_persistent=False, + sf_vec_size=16, + sf_dtype=cutlass.Float8E4M3FN, + ) + mA = fake_tensor(cutlass.Float4E2M1FN, (m, k, l), leading_dim=1, divisibility=4) + mB = fake_tensor(cutlass.Float4E2M1FN, (n, k, l), leading_dim=1, divisibility=4) + mD = fake_tensor(cutlass.BFloat16, (m, n, l), leading_dim=1, divisibility=8) + mSFA = fake_tensor(cutlass.Float8E4M3FN, (m, 16, l), leading_dim=1, divisibility=4) + mSFB = fake_tensor(cutlass.Float8E4M3FN, (n, 16, l), leading_dim=1, divisibility=4) + + assert ( + gemm._validate_blockscaled_call( + mA, + mB, + mD, + None, + mSFA, + mSFB, + gemm.EpilogueArguments(), + None, + None, + None, + ) + == VarlenArguments() + ) + + +@pytest.mark.parametrize( + "ab_dtype,sf_dtype,sf_vec_size,d_dtype,mma_tiler_mn,cluster_shape_mn,m,n,k,l", + [ + ( + cutlass.Float8E4M3FN, + cutlass.Float8E8M0FNU, + 32, + cutlass.BFloat16, + (128, 64), + (1, 1), + 256, + 64, + 256, + 1, + ), + ( + cutlass.Float8E4M3FN, + cutlass.Float8E8M0FNU, + 32, + cutlass.BFloat16, + (128, 192), + (1, 1), + 256, + 192, + 256, + 1, + ), + ( + cutlass.Float8E4M3FN, + cutlass.Float8E8M0FNU, + 32, + cutlass.BFloat16, + (128, 128), + (1, 1), + 256, + 256, + 256, + 1, + ), + ( + cutlass.Float8E5M2, + cutlass.Float8E8M0FNU, + 32, + cutlass.BFloat16, + (256, 64), + (2, 1), + 512, + 64, + 256, + 1, + ), + ( + cutlass.Float8E5M2, + cutlass.Float8E8M0FNU, + 32, + cutlass.BFloat16, + (256, 192), + (2, 1), + 512, + 192, + 256, + 1, + ), + ( + cutlass.Float8E5M2, + cutlass.Float8E8M0FNU, + 32, + cutlass.BFloat16, + (256, 128), + (2, 1), + 512, + 256, + 256, + 1, + ), + ( + cutlass.Float8E4M3FN, + cutlass.Float8E8M0FNU, + 32, + cutlass.Float32, + (256, 192), + (2, 1), + 256, + 192, + 256, + 1, + ), + ( + cutlass.Float8E4M3FN, + cutlass.Float8E8M0FNU, + 32, + cutlass.Float32, + (256, 224), + (2, 1), + 256, + 224, + 256, + 1, + ), + ( + cutlass.Float4E2M1FN, + cutlass.Float8E8M0FNU, + 32, + cutlass.Float32, + (128, 128), + (1, 1), + 256, + 256, + 256, + 1, + ), + ( + cutlass.Float4E2M1FN, + cutlass.Float8E8M0FNU, + 32, + cutlass.Float32, + (256, 224), + (2, 1), + 256, + 224, + 256, + 1, + ), + ( + cutlass.Float4E2M1FN, + cutlass.Float8E8M0FNU, + 16, + cutlass.Float32, + (128, 64), + (1, 1), + 256, + 64, + 256, + 1, + ), + ( + cutlass.Float4E2M1FN, + cutlass.Float8E4M3FN, + 16, + cutlass.Float32, + (256, 192), + (2, 1), + 256, + 192, + 256, + 1, + ), + ( + cutlass.Float4E2M1FN, + cutlass.Float8E4M3FN, + 16, + cutlass.Float32, + (128, 192), + (1, 1), + 256, + 192, + 256, + 1, + ), + ( + cutlass.Float4E2M1FN, + cutlass.Float8E4M3FN, + 16, + cutlass.Float32, + (256, 224), + (2, 1), + 256, + 224, + 256, + 1, + ), + ], +) +def test_blockscaled_correctness( + ab_dtype, sf_dtype, sf_vec_size, d_dtype, mma_tiler_mn, cluster_shape_mn, m, n, k, l +): + _skip_if_not_sm100() + + ( + compiled, + args, + (a_ref, b_ref, sfa_ref, sfb_ref, _), + ) = _compile_blockscaled_gemm( + ab_dtype, + sf_dtype, + sf_vec_size, + d_dtype, + mma_tiler_mn, + cluster_shape_mn, + m, + n, + k, + l, + ) + _run_blockscaled_gemm(compiled, args) + + _, _, d_torch, _, _ = args + ref = blockscaled_gemm_reference(a_ref, b_ref, sfa_ref, sfb_ref) + err = (d_torch.float() - ref).abs().max().item() + tol = 5e-3 if d_dtype != cutlass.Float32 else 5e-4 + assert err < tol, f"max_err={err}" + + +# --------------------------------------------------------------------------- +# Scale layout invariants +# --------------------------------------------------------------------------- +@pytest.mark.parametrize("mn,sf_k,l", [(128, 4, 1), (256, 16, 1), (384, 12, 2), (512, 8, 1)]) +def test_scale_layout_matches_cublas(mn, sf_k, l): + """The quack kernel scale-view and cuBLAS's to_blocked must share the + same underlying byte layout (they both represent the PTX + tcgen05 scale-factor atom, tiled in the same outer order).""" + torch.manual_seed(0) + # a 2D uint8 scale slice per batch + scale_2d = torch.randint(0, 255, (l, mn, sf_k), device="cuda", dtype=torch.uint8) + + # Build our contiguous scale storage via create_blockscaled_operand_quantized's + # rearrangement logic: pad + (l, rm, 128, rk, 4) -> (l, rm, rk, 512) + rm = (mn + 127) // 128 + rk = (sf_k + 3) // 4 + mn_pad = rm * 128 + sf_k_pad = rk * 4 + padded = torch.zeros(l, mn_pad, sf_k_pad, device="cuda", dtype=torch.uint8) + padded[:, :mn, :sf_k] = scale_2d + blocks = padded.view(l, rm, 128, rk, 4).permute(0, 1, 3, 2, 4) + blocks = blocks.reshape(l, rm, rk, 4, 32, 4).transpose(3, 4).contiguous() + scale_contig = blocks.view(l, rm, rk, 512) # (l, rm, rk, 512) + + # kernel view indexing: byte offset within tile = (m%32)*16 + ((m//32)%4)*4 + (k%4) + kv = scale_view_for_kernel(scale_contig.view(torch.float8_e8m0fnu), mn, sf_k, l).view( + torch.uint8 + ) + m_positions = sorted({0, 1, 17, 31, 33, 127, min(128, mn - 1), mn - 1} & set(range(mn))) + k_positions = sorted({0, 1, 3, 4, 7, sf_k - 1} & set(range(sf_k))) + for li in range(l): + for mi in m_positions: + for ki in k_positions: + byte_off = (mi % 32) * 16 + ((mi // 32) % 4) * 4 + (ki % 4) + assert kv[li, mi // 128, ki // 4, byte_off].item() == scale_2d[li, mi, ki].item(), ( + f"mismatch at l={li} m={mi} k={ki}" + ) + + # cuBLAS slice must equal to_blocked(scale_2d[l]) + for li in range(l): + ours = scale_blocked_for_cublas(scale_contig.view(torch.float8_e8m0fnu), mn, sf_k, li).view( + torch.uint8 + ) + ref = to_blocked(scale_2d[li]) + assert torch.equal(ours, ref), f"to_blocked mismatch at l={li}" + + +@pytest.mark.parametrize( + "k,sf_vec_size,expected_cols", + [ + (64, 16, 16), + (128, 16, 16), + (256, 16, 16), + (384, 16, 32), + (64, 32, 16), + (256, 32, 16), + (576, 32, 32), + ], +) +def test_sm120_blockscaled_padded_scale_layout(k, sf_vec_size, expected_cols): + _skip_if_not_sm120() + mn, l = 128, 1 + sf_dtype = cutlass.Float8E4M3FN if sf_vec_size == 16 else cutlass.Float8E8M0FNU + ref, physical = create_sm120_blockscaled_scale_tensor(l, mn, k, sf_vec_size, sf_dtype) + assert tuple(physical.shape) == (mn, expected_cols, l) + assert tuple(ref.shape) == (mn, k, l) + + logical_cols = (k + sf_vec_size - 1) // sf_vec_size + if expected_cols > logical_cols: + padding = physical[:, logical_cols:, :].view(torch.uint8) + assert torch.any(padding != 0) + + +def test_sm120_blockscaled_scale_helper_validation(): + _skip_if_not_sm120() + with pytest.raises(ValueError, match="K divisible by 64"): + create_sm120_blockscaled_scale_tensor(1, 128, 96, 16, cutlass.Float8E4M3FN) + with pytest.raises(ValueError, match="sf_vec_size 16 or 32"): + create_sm120_blockscaled_scale_tensor(1, 128, 64, 8, cutlass.Float8E4M3FN) + with pytest.raises(ValueError, match="sf_vec_size=16 requires"): + create_sm120_blockscaled_scale_tensor(1, 128, 64, 16, cutlass.Float8E8M0FNU) + with pytest.raises(ValueError, match="sf_vec_size=32 requires"): + create_sm120_blockscaled_scale_tensor(1, 128, 64, 32, cutlass.Float8E4M3FN) + + +def _pack_sm120_fp4_codes(codes: torch.Tensor) -> torch.Tensor: + packed = torch.empty( + (codes.shape[0], codes.shape[1] // 2, 1), + device=codes.device, + dtype=torch.float4_e2m1fn_x2, + ) + packed.view(torch.uint8).copy_(codes[:, 0::2, None] | (codes[:, 1::2, None] << 4)) + return packed + + +def _sm120_fp4_blockscaled_reference( + a_codes: torch.Tensor, + b_codes: torch.Tensor, + sfa: torch.Tensor, + sfb: torch.Tensor, + sf_vec_size: int, +) -> torch.Tensor: + table = torch.tensor(FP4_E2M1FN_VALUES, dtype=torch.float32, device=a_codes.device) + scale_k = torch.arange(a_codes.shape[1], device=a_codes.device) // sf_vec_size + a = table[a_codes.long()] * sfa.float()[:, scale_k, 0] + b = table[b_codes.long()] * sfb.float()[:, scale_k, 0] + return torch.einsum("mk,nk->mn", a, b).unsqueeze(-1) + + +def _make_sm120_scales(mn, k, sf_vec_size, sf_dtype, row_or_col_sensitive=True): + _, scales = create_sm120_blockscaled_scale_tensor(1, mn, k, sf_vec_size, sf_dtype) + logical_cols = (k + sf_vec_size - 1) // sf_vec_size + if sf_dtype == cutlass.Float8E8M0FNU: + base = torch.tensor([1.0, 2.0], device="cuda", dtype=torch.float32) + else: + base = torch.tensor([1.0, 2.0, 0.5, 1.5], device="cuda", dtype=torch.float32) + for idx in range(mn): + values = base[torch.arange(logical_cols, device="cuda") % base.numel()] + if row_or_col_sensitive: + values = values * (1.0 + 0.125 * (idx % 4)) + scales[idx, :logical_cols, 0] = values.to(scales.dtype) + return scales + + +def _compile_sm120_blockscaled_gemm( + ab_dtype, + sf_dtype, + sf_vec_size, + m, + n, + k, + mA, + mB, + tile_shape_mn=(128, 128), + compile_options="--enable-tvm-ffi", +): + l = 1 + _, mD = create_blockscaled_operand_tensor(l, m, n, False, cutlass.BFloat16, init="empty") + mSFA = _make_sm120_scales(m, k, sf_vec_size, sf_dtype) + mSFB = _make_sm120_scales(n, k, sf_vec_size, sf_dtype) + compiled = compile_blockscaled_gemm_tvm_ffi( + ab_dtype, + sf_dtype, + sf_vec_size, + cutlass.BFloat16, + tile_shape_mn, + (1, 1), + mA, + mB, + mD, + mSFA, + mSFB, + compile_options=compile_options, + ) + return compiled, (mA, mB, mD, mSFA, mSFB) + + +def _compiled_ptx_text(compiled) -> str: + ptx = getattr(compiled, "__ptx__", None) + if isinstance(ptx, bytes): + return ptx.decode("utf-8", errors="replace") + if isinstance(ptx, str): + if "\n" not in ptx and len(ptx) < 4096: + path = Path(ptx) + if path.exists(): + return path.read_text(errors="replace") + return ptx + raise AssertionError("compiled kernel did not expose PTX") + + +@pytest.mark.parametrize( + "sf_dtype,sf_vec_size,m,n,k", + [ + (cutlass.Float8E4M3FN, 16, 128, 128, 64), + (cutlass.Float8E4M3FN, 16, 256, 128, 128), + (cutlass.Float8E4M3FN, 16, 128, 128, 320), + (cutlass.Float8E8M0FNU, 32, 128, 128, 64), + ], +) +def test_sm120_blockscaled_scale_correctness(sf_dtype, sf_vec_size, m, n, k): + _skip_if_not_sm120() + a_codes = torch.full((m, k), 0x2, device="cuda", dtype=torch.uint8) + b_codes = torch.full((n, k), 0x2, device="cuda", dtype=torch.uint8) + mA = _pack_sm120_fp4_codes(a_codes) + mB = _pack_sm120_fp4_codes(b_codes) + compiled, args = _compile_sm120_blockscaled_gemm( + cutlass.Float4E2M1FN, sf_dtype, sf_vec_size, m, n, k, mA, mB + ) + _run_blockscaled_gemm(compiled, args) + + _, _, d_torch, mSFA, mSFB = args + ref = _sm120_fp4_blockscaled_reference(a_codes, b_codes, mSFA, mSFB, sf_vec_size).to( + torch.bfloat16 + ) + err = (d_torch.float() - ref).abs().max().item() + assert err < 1e-1, f"max_err={err}" + + +def test_sm120_blockscaled_scale_correctness_64x64_tile(): + _skip_if_not_sm120() + m = n = 128 + k = 128 + sf_vec_size = 16 + sf_dtype = cutlass.Float8E4M3FN + a_codes = torch.full((m, k), 0x2, device="cuda", dtype=torch.uint8) + b_codes = torch.full((n, k), 0x2, device="cuda", dtype=torch.uint8) + mA = _pack_sm120_fp4_codes(a_codes) + mB = _pack_sm120_fp4_codes(b_codes) + compiled, args = _compile_sm120_blockscaled_gemm( + cutlass.Float4E2M1FN, + sf_dtype, + sf_vec_size, + m, + n, + k, + mA, + mB, + tile_shape_mn=(64, 64), + ) + _run_blockscaled_gemm(compiled, args) + + _, _, d_torch, mSFA, mSFB = args + ref = _sm120_fp4_blockscaled_reference(a_codes, b_codes, mSFA, mSFB, sf_vec_size).to( + torch.bfloat16 + ) + torch.testing.assert_close(d_torch.float(), ref.float(), rtol=0, atol=0) + + +def test_sm120_blockscaled_k_loop_accumulates_before_bf16_store(): + _skip_if_not_sm120() + m = n = 128 + k = 384 + sf_vec_size = 16 + sf_dtype = cutlass.Float8E4M3FN + a_codes = torch.full((m, k), 0x2, device="cuda", dtype=torch.uint8) + b_codes = torch.full((n, k), 0x2, device="cuda", dtype=torch.uint8) + # Make the first K64 tile very large and the remaining K64 tiles small. + # Storing BF16 after each K tile loses the later small partials; true FP32 + # accumulation keeps them until the final BF16 conversion. + a_codes[:, :64] = 0x7 + b_codes[:, :64] = 0x7 + mA = _pack_sm120_fp4_codes(a_codes) + mB = _pack_sm120_fp4_codes(b_codes) + compiled, args = _compile_sm120_blockscaled_gemm( + cutlass.Float4E2M1FN, sf_dtype, sf_vec_size, m, n, k, mA, mB + ) + _, _, _, mSFA, mSFB = args + logical_cols = k // sf_vec_size + mSFA[:, :logical_cols, 0] = torch.tensor(1.0, device="cuda", dtype=mSFA.dtype) + mSFB[:, :logical_cols, 0] = torch.tensor(1.0, device="cuda", dtype=mSFB.dtype) + mSFA[:, :4, 0] = torch.tensor(3.0, device="cuda", dtype=mSFA.dtype) + mSFB[:, :4, 0] = torch.tensor(3.0, device="cuda", dtype=mSFB.dtype) + _run_blockscaled_gemm(compiled, args) + + _, _, d_torch, mSFA, mSFB = args + ref = _sm120_fp4_blockscaled_reference(a_codes, b_codes, mSFA, mSFB, sf_vec_size).to( + torch.bfloat16 + ) + torch.testing.assert_close(d_torch.float(), ref.float(), rtol=0, atol=0) + + +def test_sm120_blockscaled_asymmetric_fp4_and_scale_page_crossing(): + _skip_if_not_sm120() + m = n = 128 + k = 320 + sf_vec_size = 16 + sf_dtype = cutlass.Float8E4M3FN + ks = torch.arange(k, device="cuda")[None, :] + a_codes = torch.where( + (ks % 4) < 2, + torch.tensor(0x2, device="cuda", dtype=torch.uint8), + torch.tensor(0x4, device="cuda", dtype=torch.uint8), + ).expand(m, k) + b_codes = torch.where( + (ks % 8) < 4, + torch.tensor(0x3, device="cuda", dtype=torch.uint8), + torch.tensor(0x5, device="cuda", dtype=torch.uint8), + ).expand(n, k) + mA = _pack_sm120_fp4_codes(a_codes) + mB = _pack_sm120_fp4_codes(b_codes) + compiled, args = _compile_sm120_blockscaled_gemm( + cutlass.Float4E2M1FN, sf_dtype, sf_vec_size, m, n, k, mA, mB + ) + _run_blockscaled_gemm(compiled, args) + + _, _, d_torch, mSFA, mSFB = args + logical_cols = k // sf_vec_size + assert mSFA.shape[1] == 32 + assert torch.any(mSFA[:, logical_cols:, :].float() != 1.0) + ref = _sm120_fp4_blockscaled_reference(a_codes, b_codes, mSFA, mSFB, sf_vec_size).to( + torch.bfloat16 + ) + err = (d_torch.float() - ref.float()).abs().max().item() + assert err < 1e-1, f"max_err={err}" + + +def test_sm120_blockscaled_packed_ldsm_scale_correctness_64x64_tile(monkeypatch): + _skip_if_not_sm120() + monkeypatch.setenv("QUACK_SM120_BLOCKSCALED_PACKED_LDSM", "1") + m = n = 128 + k = 128 + sf_vec_size = 16 + sf_dtype = cutlass.Float8E4M3FN + a_codes = torch.full((m, k), 0x2, device="cuda", dtype=torch.uint8) + b_codes = torch.full((n, k), 0x2, device="cuda", dtype=torch.uint8) + mA = _pack_sm120_fp4_codes(a_codes) + mB = _pack_sm120_fp4_codes(b_codes) + compiled, args = _compile_sm120_blockscaled_gemm( + cutlass.Float4E2M1FN, + sf_dtype, + sf_vec_size, + m, + n, + k, + mA, + mB, + tile_shape_mn=(64, 64), + ) + _run_blockscaled_gemm(compiled, args) + + _, _, d_torch, mSFA, mSFB = args + ref = _sm120_fp4_blockscaled_reference(a_codes, b_codes, mSFA, mSFB, sf_vec_size).to( + torch.bfloat16 + ) + torch.testing.assert_close(d_torch.float(), ref.float(), rtol=0, atol=0) + + +def test_sm120_blockscaled_packed_ldsm_asymmetric_fp4_and_scale_page_crossing(monkeypatch): + _skip_if_not_sm120() + monkeypatch.setenv("QUACK_SM120_BLOCKSCALED_PACKED_LDSM", "1") + m = n = 128 + k = 320 + sf_vec_size = 16 + sf_dtype = cutlass.Float8E4M3FN + ks = torch.arange(k, device="cuda")[None, :] + a_codes = torch.where( + (ks % 4) < 2, + torch.tensor(0x2, device="cuda", dtype=torch.uint8), + torch.tensor(0x4, device="cuda", dtype=torch.uint8), + ).expand(m, k) + b_codes = torch.where( + (ks % 8) < 4, + torch.tensor(0x3, device="cuda", dtype=torch.uint8), + torch.tensor(0x5, device="cuda", dtype=torch.uint8), + ).expand(n, k) + mA = _pack_sm120_fp4_codes(a_codes) + mB = _pack_sm120_fp4_codes(b_codes) + compiled, args = _compile_sm120_blockscaled_gemm( + cutlass.Float4E2M1FN, + sf_dtype, + sf_vec_size, + m, + n, + k, + mA, + mB, + tile_shape_mn=(64, 64), + ) + _run_blockscaled_gemm(compiled, args) + + _, _, d_torch, mSFA, mSFB = args + ref = _sm120_fp4_blockscaled_reference(a_codes, b_codes, mSFA, mSFB, sf_vec_size).to( + torch.bfloat16 + ) + torch.testing.assert_close(d_torch.float(), ref.float(), rtol=0, atol=0) + + +def test_sm120_blockscaled_packed_ldsm_ptx_regression(monkeypatch): + _skip_if_not_sm120() + monkeypatch.setenv("QUACK_SM120_BLOCKSCALED_PACKED_LDSM", "1") + m = n = 128 + k = 128 + sf_vec_size = 16 + sf_dtype = cutlass.Float8E4M3FN + a_codes = torch.full((m, k), 0x2, device="cuda", dtype=torch.uint8) + b_codes = torch.full((n, k), 0x2, device="cuda", dtype=torch.uint8) + compiled, _ = _compile_sm120_blockscaled_gemm( + cutlass.Float4E2M1FN, + sf_dtype, + sf_vec_size, + m, + n, + k, + _pack_sm120_fp4_codes(a_codes), + _pack_sm120_fp4_codes(b_codes), + tile_shape_mn=(64, 64), + compile_options="--enable-tvm-ffi --keep-ptx", + ) + ptx = _compiled_ptx_text(compiled) + assert "ldmatrix.sync.aligned.m8n8.x4.shared.b16" in ptx + assert "mma.sync.aligned" in ptx + assert "m16n8k64" in ptx + assert "kind::mxf4nvf4" in ptx + assert "cp.async.bulk.tensor.2d.shared::cta.global.tile" in ptx + assert "b4x16_p64" not in ptx + assert "ldmatrix.sync.aligned.m8n16" not in ptx + assert ".multicast" not in ptx + assert "shared::cluster" not in ptx + + +def test_sm120_blockscaled_packed_ldsm_scale_correctness_64x64x128_tile(monkeypatch): + _skip_if_not_sm120() + monkeypatch.setenv("QUACK_SM120_BLOCKSCALED_PACKED_LDSM", "1") + m = n = 128 + k = 128 + sf_vec_size = 16 + sf_dtype = cutlass.Float8E4M3FN + a_codes = torch.full((m, k), 0x2, device="cuda", dtype=torch.uint8) + b_codes = torch.full((n, k), 0x2, device="cuda", dtype=torch.uint8) + mA = _pack_sm120_fp4_codes(a_codes) + mB = _pack_sm120_fp4_codes(b_codes) + compiled, args = _compile_sm120_blockscaled_gemm( + cutlass.Float4E2M1FN, + sf_dtype, + sf_vec_size, + m, + n, + k, + mA, + mB, + tile_shape_mn=(64, 64, 128), + ) + _run_blockscaled_gemm(compiled, args) + + _, _, d_torch, mSFA, mSFB = args + ref = _sm120_fp4_blockscaled_reference(a_codes, b_codes, mSFA, mSFB, sf_vec_size).to( + torch.bfloat16 + ) + torch.testing.assert_close(d_torch.float(), ref.float(), rtol=0, atol=0) + + +def test_sm120_blockscaled_packed_ldsm_k128_uses_second_scale_offset(monkeypatch): + _skip_if_not_sm120() + monkeypatch.setenv("QUACK_SM120_BLOCKSCALED_PACKED_LDSM", "1") + m = n = 128 + k = 128 + sf_vec_size = 16 + sf_dtype = cutlass.Float8E4M3FN + a_codes = torch.full((m, k), 0x2, device="cuda", dtype=torch.uint8) + b_codes = torch.full((n, k), 0x2, device="cuda", dtype=torch.uint8) + mA = _pack_sm120_fp4_codes(a_codes) + mB = _pack_sm120_fp4_codes(b_codes) + compiled, args = _compile_sm120_blockscaled_gemm( + cutlass.Float4E2M1FN, + sf_dtype, + sf_vec_size, + m, + n, + k, + mA, + mB, + tile_shape_mn=(64, 64, 128), + ) + _, _, _, mSFA, mSFB = args + mSFA[:, :4, 0] = torch.tensor(1.0, device="cuda", dtype=mSFA.dtype) + mSFA[:, 4:8, 0] = torch.tensor(2.0, device="cuda", dtype=mSFA.dtype) + mSFB[:, :8, 0] = torch.tensor(1.0, device="cuda", dtype=mSFB.dtype) + _run_blockscaled_gemm(compiled, args) + + _, _, d_torch, mSFA, mSFB = args + ref = _sm120_fp4_blockscaled_reference(a_codes, b_codes, mSFA, mSFB, sf_vec_size).to( + torch.bfloat16 + ) + torch.testing.assert_close(d_torch.float(), ref.float(), rtol=0, atol=0) + assert torch.all(d_torch.float() == torch.tensor(192.0, device="cuda")) + + +@pytest.mark.parametrize( + "tile_shape_mn,k", + [ + ((64, 64), 128), + ((64, 64, 128), 640), + ], +) +def test_sm120_blockscaled_packed_ldsm_mxfp4_scale_correctness(monkeypatch, tile_shape_mn, k): + _skip_if_not_sm120() + monkeypatch.setenv("QUACK_SM120_BLOCKSCALED_PACKED_LDSM", "1") + m = n = 128 + sf_vec_size = 32 + sf_dtype = cutlass.Float8E8M0FNU + rows = torch.arange(m, device="cuda")[:, None] + cols = torch.arange(n, device="cuda")[:, None] + ks = torch.arange(k, device="cuda")[None, :] + a_codes = torch.where( + ((rows + ks) % 4) < 2, + torch.tensor(0x2, device="cuda", dtype=torch.uint8), + torch.tensor(0x4, device="cuda", dtype=torch.uint8), + ) + b_codes = torch.where( + ((cols + ks * 3) % 8) < 4, + torch.tensor(0x3, device="cuda", dtype=torch.uint8), + torch.tensor(0x5, device="cuda", dtype=torch.uint8), + ) + mA = _pack_sm120_fp4_codes(a_codes) + mB = _pack_sm120_fp4_codes(b_codes) + compiled, args = _compile_sm120_blockscaled_gemm( + cutlass.Float4E2M1FN, + sf_dtype, + sf_vec_size, + m, + n, + k, + mA, + mB, + tile_shape_mn=tile_shape_mn, + ) + _run_blockscaled_gemm(compiled, args) + + _, _, d_torch, mSFA, mSFB = args + logical_cols = k // sf_vec_size + if k == 640: + assert mSFA.shape[1] == 32 + assert torch.any(mSFA[:, logical_cols:, :].float() != 1.0) + assert torch.any(mSFB[:, logical_cols:, :].float() != 1.0) + ref = _sm120_fp4_blockscaled_reference(a_codes, b_codes, mSFA, mSFB, sf_vec_size).to( + torch.bfloat16 + ) + torch.testing.assert_close(d_torch.float(), ref.float(), rtol=0, atol=0) + + +def test_sm120_blockscaled_packed_ldsm_asymmetric_fp4_k128_page_crossing(monkeypatch): + _skip_if_not_sm120() + monkeypatch.setenv("QUACK_SM120_BLOCKSCALED_PACKED_LDSM", "1") + m = n = 128 + k = 384 + sf_vec_size = 16 + sf_dtype = cutlass.Float8E4M3FN + rows = torch.arange(m, device="cuda")[:, None] + cols = torch.arange(n, device="cuda")[:, None] + ks = torch.arange(k, device="cuda")[None, :] + a_codes = torch.where( + ((rows + ks) % 4) < 2, + torch.tensor(0x2, device="cuda", dtype=torch.uint8), + torch.tensor(0x4, device="cuda", dtype=torch.uint8), + ) + b_codes = torch.where( + ((cols * 3 + ks) % 8) < 4, + torch.tensor(0x3, device="cuda", dtype=torch.uint8), + torch.tensor(0x5, device="cuda", dtype=torch.uint8), + ) + mA = _pack_sm120_fp4_codes(a_codes) + mB = _pack_sm120_fp4_codes(b_codes) + compiled, args = _compile_sm120_blockscaled_gemm( + cutlass.Float4E2M1FN, + sf_dtype, + sf_vec_size, + m, + n, + k, + mA, + mB, + tile_shape_mn=(64, 64, 128), + ) + _run_blockscaled_gemm(compiled, args) + + _, _, d_torch, mSFA, mSFB = args + logical_cols = k // sf_vec_size + assert mSFA.shape[1] == 32 + assert torch.any(mSFA[:, logical_cols:, :].float() != 1.0) + ref = _sm120_fp4_blockscaled_reference(a_codes, b_codes, mSFA, mSFB, sf_vec_size).to( + torch.bfloat16 + ) + torch.testing.assert_close(d_torch.float(), ref.float(), rtol=0, atol=0) + + +def test_sm120_blockscaled_packed_ldsm_ptx_regression_64x64x128(monkeypatch): + _skip_if_not_sm120() + monkeypatch.setenv("QUACK_SM120_BLOCKSCALED_PACKED_LDSM", "1") + m = n = 128 + k = 128 + sf_vec_size = 16 + sf_dtype = cutlass.Float8E4M3FN + a_codes = torch.full((m, k), 0x2, device="cuda", dtype=torch.uint8) + b_codes = torch.full((n, k), 0x2, device="cuda", dtype=torch.uint8) + compiled, _ = _compile_sm120_blockscaled_gemm( + cutlass.Float4E2M1FN, + sf_dtype, + sf_vec_size, + m, + n, + k, + _pack_sm120_fp4_codes(a_codes), + _pack_sm120_fp4_codes(b_codes), + tile_shape_mn=(64, 64, 128), + compile_options="--enable-tvm-ffi --keep-ptx", + ) + ptx = _compiled_ptx_text(compiled) + assert "ldmatrix.sync.aligned.m8n8.x4.shared.b16" in ptx + assert "mma.sync.aligned" in ptx + assert "m16n8k64" in ptx + assert "kind::mxf4nvf4" in ptx + assert "cp.async.bulk.tensor.2d.shared::cta.global.tile" in ptx + assert "b4x16_p64" not in ptx + assert "ldmatrix.sync.aligned.m8n16" not in ptx + assert ".multicast" not in ptx + assert "shared::cluster" not in ptx + + +def test_sm120_blockscaled_tile_k128_requires_packed_ldsm(monkeypatch): + _skip_if_not_sm120() + monkeypatch.delenv("QUACK_SM120_BLOCKSCALED_PACKED_LDSM", raising=False) + m = n = 128 + k = 128 + sf_vec_size = 16 + sf_dtype = cutlass.Float8E4M3FN + a_codes = torch.full((m, k), 0x2, device="cuda", dtype=torch.uint8) + b_codes = torch.full((n, k), 0x2, device="cuda", dtype=torch.uint8) + with pytest.raises(NotImplementedError, match="tile_K=128"): + _compile_sm120_blockscaled_gemm( + cutlass.Float4E2M1FN, + sf_dtype, + sf_vec_size, + m, + n, + k, + _pack_sm120_fp4_codes(a_codes), + _pack_sm120_fp4_codes(b_codes), + tile_shape_mn=(64, 64, 128), + ) + + +def test_sm120_blockscaled_rejects_compact_scale_layout(): + _skip_if_not_sm120() + l, m, n, k, sf_vec_size = 1, 128, 128, 64, 16 + ab_dtype = cutlass.Float4E2M1FN + sf_dtype = cutlass.Float8E4M3FN + _, mA = create_blockscaled_operand_tensor(l, m, k, False, ab_dtype) + _, mB = create_blockscaled_operand_tensor(l, n, k, False, ab_dtype) + _, mD = create_blockscaled_operand_tensor(l, m, n, False, cutlass.BFloat16, init="empty") + mSFA = torch.empty((m, k // sf_vec_size, l), device="cuda", dtype=torch.float8_e4m3fn) + mSFB = torch.empty((n, k // sf_vec_size, l), device="cuda", dtype=torch.float8_e4m3fn) + + with pytest.raises(ValueError, match="SFA shape"): + runner = compile_blockscaled_gemm_tvm_ffi( + ab_dtype, + sf_dtype, + sf_vec_size, + cutlass.BFloat16, + (128, 128), + (1, 1), + mA, + mB, + mD, + mSFA, + mSFB, + ) + runner(mA, mB, mD, mSFA, mSFB) + + +# --------------------------------------------------------------------------- +# End-to-end: quantized MXFP8 inputs through quack kernel vs cuBLAS vs dequant ref +# --------------------------------------------------------------------------- +@pytest.mark.parametrize( + "mma_tiler_mn,cluster_shape_mn,m,n,k", + [ + # All 5 supported blockscaled tile_n values (64, 128, 192, 224, 256). + ((128, 64), (1, 1), 256, 64, 512), + ((128, 128), (1, 1), 256, 256, 256), + ((128, 128), (1, 1), 512, 512, 512), + ((128, 192), (1, 1), 256, 192, 256), + ((128, 256), (1, 1), 256, 256, 256), + ((256, 128), (2, 1), 512, 256, 512), + ((256, 192), (2, 1), 256, 192, 256), + ((256, 192), (2, 1), 256, 384, 256), + ((256, 192), (2, 1), 512, 192, 512), + ((256, 224), (2, 1), 256, 224, 256), + ((256, 224), (2, 1), 512, 224, 512), + ((256, 256), (2, 1), 512, 256, 512), + ], +) +def test_blockscaled_mxfp8_quantized(mma_tiler_mn, cluster_shape_mn, m, n, k): + _skip_if_not_sm100() + l, sf_vec = 1, 32 + + torch.manual_seed(0) + a_ref, mA, a_sc = create_blockscaled_operand_quantized(l, m, k, False, sf_vec) + b_ref, mB, b_sc = create_blockscaled_operand_quantized(l, n, k, False, sf_vec) + _, mD = create_blockscaled_operand_tensor(l, m, n, False, cutlass.BFloat16, init="empty") + + mSFA = scale_view_for_kernel(a_sc, m, k // sf_vec, l) + mSFB = scale_view_for_kernel(b_sc, n, k // sf_vec, l) + + runner = compile_blockscaled_gemm_tvm_ffi( + cutlass.Float8E4M3FN, + cutlass.Float8E8M0FNU, + sf_vec, + cutlass.BFloat16, + mma_tiler_mn, + cluster_shape_mn, + mA, + mB, + mD, + mSFA, + mSFB, + ) + runner(mA, mB, mD, mSFA, mSFB) + torch.cuda.synchronize() + + # Reference: dequant matmul (a_ref/b_ref are already dequantized) + d_ref = torch.einsum("mkl,nkl->mnl", a_ref, b_ref) + err = (mD.float() - d_ref).abs().max().item() + assert err < 5e-3, f"quack vs dequant max_err={err}" + + # cuBLAS: bit-exact match expected (same operand bits, same scale bytes, same hw MMA) + from torch.nn.functional import scaled_mm as F_scaled_mm, ScalingType, SwizzleType + + a_cub = mA[:, :, 0].contiguous() + b_cub = mB[:, :, 0].contiguous() + a_sc_cub = scale_blocked_for_cublas(a_sc, m, k // sf_vec, 0) + b_sc_cub = scale_blocked_for_cublas(b_sc, n, k // sf_vec, 0) + out_cublas = F_scaled_mm( + a_cub, + b_cub.t(), + scale_a=a_sc_cub, + scale_recipe_a=ScalingType.BlockWise1x32, + scale_b=b_sc_cub, + scale_recipe_b=ScalingType.BlockWise1x32, + swizzle_a=SwizzleType.SWIZZLE_32_4_4, + swizzle_b=SwizzleType.SWIZZLE_32_4_4, + output_dtype=torch.bfloat16, + ) + assert torch.equal(mD.squeeze(-1), out_cublas), ( + f"quack != cuBLAS: max_err={(mD.squeeze(-1).float() - out_cublas.float()).abs().max().item()}" + ) + + +# --------------------------------------------------------------------------- +# High-level PyTorch interface +# --------------------------------------------------------------------------- +@pytest.mark.parametrize("shape_mnk", [(256, 256, 256), (512, 256, 256), (128, 128, 256)]) +@pytest.mark.parametrize("batched", [False, True]) +def test_mxfp8_interface(shape_mnk, batched): + _skip_if_not_sm100() + from quack.gemm_blockscaled_interface import ( + mxfp8_gemm, + mxfp8_gemm_cublas, + mxfp8_gemm_ref, + mxfp8_gemm_quantize, + mxfp8_quantize, + ) + + M, N, K = shape_mnk + L = 2 if batched else 1 + torch.manual_seed(0) + shape_A = (L, M, K) if batched else (M, K) + # Weight stored nn.Linear-style (N, K) row-major; pass .mT to get K-contig (K, N) + shape_W = (L, N, K) if batched else (N, K) + A_hp = torch.randn(*shape_A, device="cuda", dtype=torch.bfloat16) * K**-0.5 + W_hp = torch.randn(*shape_W, device="cuda", dtype=torch.bfloat16) * K**-0.5 + + A_q, A_sc = mxfp8_quantize(A_hp) + W_q, W_sc = mxfp8_quantize(W_hp) # (..., N, K), (..., N, K/32) + assert A_q.dtype == torch.float8_e4m3fn and A_sc.dtype == torch.float8_e8m0fnu + + B_q = W_q.mT # (..., K, N) K-contig view + B_sc = W_sc.mT # (..., K/32, N) K-contig view + + out = mxfp8_gemm(A_q, B_q, A_sc, B_sc) + assert out.shape == ((L, M, N) if batched else (M, N)) + assert out.dtype == torch.bfloat16 + + ref = mxfp8_gemm_ref(A_q, B_q, A_sc, B_sc) + err = (out.float() - ref.float()).abs().max().item() + assert err < 5e-3, f"quack vs ref max_err={err}" + + # cuBLAS comparison only for 2D / L=1 + if not batched: + out_cublas = mxfp8_gemm_cublas(A_q, B_q, A_sc, B_sc) + assert torch.equal(out, out_cublas), "quack interface != cuBLAS" + + # High-level quantize+gemm convenience fn + out2 = mxfp8_gemm_quantize(A_hp, W_hp) + assert torch.equal(out, out2) + + +@pytest.mark.parametrize("a_major", ["k", "m"]) +@pytest.mark.parametrize("b_major", ["k", "n"]) +def test_blockscaled_mxfp8_major_modes(a_major, b_major): + """MXFP8 with A in {k,m}-major × B in {k,n}-major. The SF tensor layout + stays K-major (hardware convention); only A/B operand strides differ.""" + _skip_if_not_sm100() + from quack.mx_utils import to_mx + + m, n, k, l = 256, 256, 256, 1 + sf_vec = 32 + + def _make_operand(mn, major): + hp = (torch.randn(l, mn, k, device="cuda", dtype=torch.bfloat16) * k**-0.5).contiguous() + q_flat, sc_flat = to_mx(hp.view(l * mn, k), sf_vec) + ref_mkl = ( + ( + q_flat.float().view(l, mn, k) + * sc_flat.float().view(l, mn, k // sf_vec).repeat_interleave(sf_vec, dim=-1) + ) + .permute(1, 2, 0) + .contiguous() + ) + if major == "k": + # (l, mn, k) contig → permute to (mn, k, l) → stride (k, 1, mn*k) + q_mkl = q_flat.view(l, mn, k).contiguous().permute(1, 2, 0) + else: + # (l, mn, k) contig → permute to (mn, k, l) with mn fastest → stride (1, mn, mn*k) + q_mkl = ( + q_flat.view(l, mn, k).contiguous().permute(0, 2, 1).contiguous().permute(2, 1, 0) + ) + return ref_mkl, q_mkl, sc_flat.view(l, mn, k // sf_vec) + + a_ref, mA, sa_2d = _make_operand(m, a_major) + b_ref, mB, sb_2d = _make_operand(n, b_major) + # Sanity: stride(0) == 1 iff mn-major. + assert (mA.stride(0) == 1) == (a_major == "m"), f"mA stride: {mA.stride()}" + assert (mB.stride(0) == 1) == (b_major == "n"), f"mB stride: {mB.stride()}" + from quack.blockscaled_gemm_utils import pack_scale_2d_to_blocked_contig + + a_sc = pack_scale_2d_to_blocked_contig(sa_2d) + b_sc = pack_scale_2d_to_blocked_contig(sb_2d) + _, mD = create_blockscaled_operand_tensor(l, m, n, False, cutlass.BFloat16, init="empty") + + assert GemmDefaultSm100.can_implement_blockscaled( + cutlass.Float8E4M3FN, + cutlass.Float8E8M0FNU, + sf_vec, + cutlass.BFloat16, + (128, 128), + (1, 1), + m, + n, + k, + l, + a_major, + b_major, + "n", + ) + runner = compile_blockscaled_gemm_tvm_ffi( + cutlass.Float8E4M3FN, + cutlass.Float8E8M0FNU, + sf_vec, + cutlass.BFloat16, + (128, 128), + (1, 1), + mA, + mB, + mD, + a_sc, + b_sc, + ) + runner(mA, mB, mD, a_sc, b_sc) + torch.cuda.synchronize() + + ref = torch.einsum("mkl,nkl->mnl", a_ref, b_ref) + err = (mD.float() - ref).abs().max().item() + assert err < 5e-3, f"A={a_major} B={b_major} max_err={err}" + + +@pytest.mark.parametrize("b_major", ["k", "n"]) +@pytest.mark.parametrize( + "seqlens_m", + [ + [128, 128, 128], # baseline: all aligned + [100, 200, 150], # none aligned to 128 + [30, 300, 64, 200], # mix small + non-aligned + [1, 128, 127, 129], # boundary conditions + ], +) +def test_blockscaled_mxfp8_varlen_m_nonaligned(seqlens_m, b_major): + """varlen_m with per-expert seqlens not divisible by 128, plus k/n-major B. + SFA is stored in dQaccum-style padded format; kernel reads it via + offset_batch_SFA.""" + _skip_if_not_sm100() + num_experts = len(seqlens_m) + n, k = 256, 256 + sf_vec = 32 + mma_tiler_mn = (128, 128) + cluster_shape_mn = (1, 1) + + torch.manual_seed(0) + a_ref_dq, b_ref_dq, mA, mB, a_sc_contig, b_sc_contig, cu_seqlens_m = ( + create_blockscaled_varlen_m_operands( + num_experts, + 0, + n, + k, + sf_vec, + seqlens_m=seqlens_m, + b_major=b_major, + ) + ) + expected_b_stride0 = 1 if b_major == "n" else k + assert mB.stride(0) == expected_b_stride0, ( + f"b_major={b_major} → mB.stride(0) should be {expected_b_stride0}, got {mB.stride()}" + ) + total_m = int(sum(seqlens_m)) + mSFA = a_sc_contig # (1, total_padded_rm, rk, 512) + mSFB = b_sc_contig # (L, rn, rk, 512) + + mD = torch.empty(total_m, n, dtype=torch.bfloat16, device="cuda") + runner = compile_blockscaled_gemm_tvm_ffi( + cutlass.Float8E4M3FN, + cutlass.Float8E8M0FNU, + sf_vec, + cutlass.BFloat16, + mma_tiler_mn, + cluster_shape_mn, + mA, + mB, + mD, + mSFA, + mSFB, + varlen_m=True, + ) + runner(mA, mB, mD, mSFA, mSFB, cu_seqlens_m) + torch.cuda.synchronize() + + # Per-expert reference matmul on dequantized operands. + cu = cu_seqlens_m.tolist() + ref = torch.cat([a_ref_dq[cu[i] : cu[i + 1]] @ b_ref_dq[i].T for i in range(num_experts)]) + err = (mD.float() - ref).abs().max().item() + assert err < 5e-3, f"varlen_m non-aligned seqlens_m={seqlens_m} max_err={err}" + + +@pytest.mark.parametrize( + "seqlens_k", + [ + [128, 128, 128], # all aligned to 128 + [128, 256, 128], # 128-aligned mixed sizes + [96, 160, 128], # not 128-aligned (but all sf_vec-aligned) + [32, 256, 64, 128], # small + varied + ], +) +def test_blockscaled_mxfp8_varlen_k(seqlens_k): + """varlen_k blockscaled: per-expert k_i (must be sf_vec-aligned; 128-alignment + is NOT required). SFA/SFB use dQaccum-style K-padded storage and the kernel + reads them via offset_batch_SFA/offset_batch_SFB padded-K formula.""" + _skip_if_not_sm100() + num_experts = len(seqlens_k) + m, n = 256, 256 + sf_vec = 32 + mma_tiler_mn = (128, 128) + cluster_shape_mn = (1, 1) + + torch.manual_seed(0) + a_ref_list, b_ref_list, mA, mB, a_sc_contig, b_sc_contig, cu_seqlens_k = ( + create_blockscaled_varlen_k_operands(num_experts, 0, m, n, sf_vec, seqlens_k=seqlens_k) + ) + # (m, n, L) with stride 1 on N dim (compile expects leading_dim=1 on mD). + mD = torch.empty(num_experts, m, n, dtype=torch.bfloat16, device="cuda").permute(1, 2, 0) + runner = compile_blockscaled_gemm_tvm_ffi( + cutlass.Float8E4M3FN, + cutlass.Float8E8M0FNU, + sf_vec, + cutlass.BFloat16, + mma_tiler_mn, + cluster_shape_mn, + mA, + mB, + mD, + a_sc_contig, + b_sc_contig, + varlen_k=True, + ) + runner(mA, mB, mD, a_sc_contig, b_sc_contig, cu_seqlens_k) + torch.cuda.synchronize() + + # Per-expert reference: for expert i, result = a_ref[i] @ b_ref[i].T. + # mD has shape (m, n, L) N-major; each mD[:, :, i] is one expert's output. + for i in range(num_experts): + ref_i = a_ref_list[i] @ b_ref_list[i].T + out_i = mD[:, :, i].float() + err = (out_i - ref_i).abs().max().item() + assert err < 5e-3, f"varlen_k seqlens_k={seqlens_k} expert={i} max_err={err}" + + +@pytest.mark.parametrize("rk_pad", [1, 3, 5]) +def test_blockscaled_mxfp8_strided_sf(rk_pad): + """Verify the kernel honors mSFA/mSFB's actual outer strides (doesn't + require the full scale tensor to be contig — only the innermost 512-B + tile). Allocates a larger scale buffer with extra rk padding and slices + back to the valid rk, producing a non-packed rm stride.""" + _skip_if_not_sm100() + m, n, k = 256, 256, 512 # k=512 → sf_k=16 → rk=4 (meaningful stride change) + l, sf_vec = 1, 32 + + torch.manual_seed(0) + a_ref, mA, a_sc = create_blockscaled_operand_quantized(l, m, k, False, sf_vec) + b_ref, mB, b_sc = create_blockscaled_operand_quantized(l, n, k, False, sf_vec) + + rm = (m + 127) // 128 + rn = (n + 127) // 128 + rk = ((k // sf_vec) + 3) // 4 + + # Allocate padded scale buffers (rk + rk_pad along K-blocks), copy valid + # tiles into the prefix, slice back to rk. The slice is non-contig: + # stride(1) = (rk + rk_pad) * 512 instead of rk * 512. + a_sc_big = torch.zeros(l, rm, rk + rk_pad, 512, dtype=torch.float8_e8m0fnu, device="cuda") + b_sc_big = torch.zeros(l, rn, rk + rk_pad, 512, dtype=torch.float8_e8m0fnu, device="cuda") + a_sc_big[:, :, :rk, :] = a_sc + b_sc_big[:, :, :rk, :] = b_sc + mSFA_strided = a_sc_big[:, :, :rk, :] + mSFB_strided = b_sc_big[:, :, :rk, :] + assert not mSFA_strided.is_contiguous() + assert mSFA_strided.stride(-1) == 1 + assert mSFA_strided.stride(1) == (rk + rk_pad) * 512, ( + f"expected non-packed rm stride {(rk + rk_pad) * 512}, got {mSFA_strided.stride(1)}" + ) + + # Validate our helper accepts the non-contig layout + _ = scale_view_for_kernel(mSFA_strided, m, k // sf_vec, l) + _ = scale_view_for_kernel(mSFB_strided, n, k // sf_vec, l) + + _, mD_strided = create_blockscaled_operand_tensor( + l, m, n, False, cutlass.BFloat16, init="empty" + ) + runner = compile_blockscaled_gemm_tvm_ffi( + cutlass.Float8E4M3FN, + cutlass.Float8E8M0FNU, + sf_vec, + cutlass.BFloat16, + (128, 128), + (1, 1), + mA, + mB, + mD_strided, + mSFA_strided, + mSFB_strided, + ) + runner(mA, mB, mD_strided, mSFA_strided, mSFB_strided) + + # Compare bit-exactly against the same matmul with contig scales. + _, mD_contig = create_blockscaled_operand_tensor(l, m, n, False, cutlass.BFloat16, init="empty") + runner_contig = compile_blockscaled_gemm_tvm_ffi( + cutlass.Float8E4M3FN, + cutlass.Float8E8M0FNU, + sf_vec, + cutlass.BFloat16, + (128, 128), + (1, 1), + mA, + mB, + mD_contig, + a_sc, + b_sc, + ) + runner_contig(mA, mB, mD_contig, a_sc, b_sc) + torch.cuda.synchronize() + + assert torch.equal(mD_strided, mD_contig), ( + f"strided-SF output differs from contig-SF: " + f"max_abs_err={(mD_strided.float() - mD_contig.float()).abs().max().item()}" + ) + + +def test_mxfp8_interface_preallocated_out(): + _skip_if_not_sm100() + from quack.gemm_blockscaled_interface import mxfp8_gemm, mxfp8_quantize + + M, N, K = 256, 256, 256 + torch.manual_seed(0) + A_hp = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) * K**-0.5 + W_hp = torch.randn(N, K, device="cuda", dtype=torch.bfloat16) * K**-0.5 + A_q, A_sc = mxfp8_quantize(A_hp) + W_q, W_sc = mxfp8_quantize(W_hp) + B_q, B_sc = W_q.mT, W_sc.mT + + out_alloc = mxfp8_gemm(A_q, B_q, A_sc, B_sc) + out_pre = torch.empty(M, N, device="cuda", dtype=torch.bfloat16) + mxfp8_gemm(A_q, B_q, A_sc, B_sc, out=out_pre) + assert torch.equal(out_alloc, out_pre) diff --git a/tests/test_gemm_sm100_blockscaled.py b/tests/test_gemm_sm100_blockscaled.py deleted file mode 100644 index 315697d8..00000000 --- a/tests/test_gemm_sm100_blockscaled.py +++ /dev/null @@ -1,905 +0,0 @@ -import pytest -import torch - -import cutlass - -from quack.blockscaled_gemm_utils import ( - blockscaled_gemm_reference, - compile_blockscaled_gemm_tvm_ffi, - create_blockscaled_operand_quantized, - create_blockscaled_operand_tensor, - create_blockscaled_scale_tensor, - create_blockscaled_varlen_k_operands, - create_blockscaled_varlen_m_operands, - scale_blocked_for_cublas, - scale_view_for_kernel, -) -from quack.gemm_default_epi import GemmDefaultSm100 -from quack.mx_utils import to_blocked - - -def _skip_if_not_sm100(): - major = torch.cuda.get_device_properties(0).major - if major < 10: - pytest.skip("SM100+ required") - - -def _compile_blockscaled_gemm( - ab_dtype, - sf_dtype, - sf_vec_size, - d_dtype, - mma_tiler_mn, - cluster_shape_mn, - m, - n, - k, - l, -): - a_ref, mA = create_blockscaled_operand_tensor(l, m, k, False, ab_dtype) - b_ref, mB = create_blockscaled_operand_tensor(l, n, k, False, ab_dtype) - _, mD = create_blockscaled_operand_tensor(l, m, n, False, d_dtype, init="empty") - sfa_ref, mSFA = create_blockscaled_scale_tensor(l, m, k, sf_vec_size, sf_dtype) - sfb_ref, mSFB = create_blockscaled_scale_tensor(l, n, k, sf_vec_size, sf_dtype) - compiled = compile_blockscaled_gemm_tvm_ffi( - ab_dtype, - sf_dtype, - sf_vec_size, - d_dtype, - mma_tiler_mn, - cluster_shape_mn, - mA, - mB, - mD, - mSFA, - mSFB, - ) - return ( - compiled, - (mA, mB, mD, mSFA, mSFB), - (a_ref, b_ref, sfa_ref, sfb_ref, mD), - ) - - -def _run_blockscaled_gemm(compiled, args): - mA, mB, mD, mSFA, mSFB = args - compiled(mA, mB, mD, mSFA, mSFB) - torch.cuda.synchronize() - - -def test_blockscaled_validation(): - assert GemmDefaultSm100.can_implement_blockscaled( - cutlass.Float8E4M3FN, - cutlass.Float8E8M0FNU, - 32, - cutlass.BFloat16, - (128, 64), - (1, 1), - 256, - 64, - 256, - 1, - "k", - "k", - "n", - ) - assert GemmDefaultSm100.can_implement_blockscaled( - cutlass.Float8E4M3FN, - cutlass.Float8E8M0FNU, - 32, - cutlass.BFloat16, - (128, 192), - (1, 1), - 256, - 192, - 256, - 1, - "k", - "k", - "n", - ) - assert GemmDefaultSm100.can_implement_blockscaled( - cutlass.Float8E4M3FN, - cutlass.Float8E8M0FNU, - 32, - cutlass.BFloat16, - (128, 128), - (1, 1), - 256, - 256, - 256, - 1, - "k", - "k", - "n", - ) - assert GemmDefaultSm100.can_implement_blockscaled( - cutlass.Float4E2M1FN, - cutlass.Float8E8M0FNU, - 32, - cutlass.Float32, - (128, 128), - (1, 1), - 256, - 256, - 256, - 1, - "k", - "k", - "n", - ) - assert GemmDefaultSm100.can_implement_blockscaled( - cutlass.Float4E2M1FN, - cutlass.Float8E4M3FN, - 16, - cutlass.Float32, - (128, 192), - (1, 1), - 256, - 192, - 256, - 1, - "k", - "k", - "n", - ) - assert not GemmDefaultSm100.can_implement_blockscaled( - cutlass.Float8E4M3FN, - cutlass.Float8E8M0FNU, - 32, - cutlass.BFloat16, - (256, 384), - (2, 1), - 256, - 512, - 256, - 1, - "k", - "k", - "n", - ) - assert not GemmDefaultSm100.can_implement_blockscaled( - cutlass.Float8E4M3FN, - cutlass.Float8E8M0FNU, - 32, - cutlass.Float32, - (256, 224), - (2, 1), - 256, - 448, - 256, - 1, - "k", - "k", - "n", - ) - assert not GemmDefaultSm100.can_implement_blockscaled( - cutlass.Float4E2M1FN, - cutlass.Float8E8M0FNU, - 32, - cutlass.Float32, - (256, 384), - (2, 1), - 256, - 512, - 256, - 1, - "k", - "k", - "n", - ) - assert not GemmDefaultSm100.can_implement_blockscaled( - cutlass.Float8E4M3FN, - cutlass.Float8E8M0FNU, - 32, - cutlass.BFloat16, - (64, 128), - (1, 1), - 256, - 256, - 256, - 1, - "k", - "k", - "n", - ) - assert not GemmDefaultSm100.can_implement_blockscaled( - cutlass.Float4E2M1FN, - cutlass.Float8E4M3FN, - 32, - cutlass.Float32, - (128, 128), - (1, 1), - 256, - 256, - 256, - 1, - "k", - "k", - "n", - ) - assert not GemmDefaultSm100.can_implement_blockscaled( - cutlass.Float8E4M3FN, - cutlass.Float8E8M0FNU, - 32, - cutlass.BFloat16, - (256, 128), - (1, 1), - 512, - 256, - 256, - 1, - "k", - "k", - "n", - ) - - -@pytest.mark.parametrize( - "ab_dtype,sf_dtype,sf_vec_size,d_dtype,mma_tiler_mn,cluster_shape_mn,m,n,k,l", - [ - ( - cutlass.Float8E4M3FN, - cutlass.Float8E8M0FNU, - 32, - cutlass.BFloat16, - (128, 64), - (1, 1), - 256, - 64, - 256, - 1, - ), - ( - cutlass.Float8E4M3FN, - cutlass.Float8E8M0FNU, - 32, - cutlass.BFloat16, - (128, 192), - (1, 1), - 256, - 192, - 256, - 1, - ), - ( - cutlass.Float8E4M3FN, - cutlass.Float8E8M0FNU, - 32, - cutlass.BFloat16, - (128, 128), - (1, 1), - 256, - 256, - 256, - 1, - ), - ( - cutlass.Float8E5M2, - cutlass.Float8E8M0FNU, - 32, - cutlass.BFloat16, - (256, 64), - (2, 1), - 512, - 64, - 256, - 1, - ), - ( - cutlass.Float8E5M2, - cutlass.Float8E8M0FNU, - 32, - cutlass.BFloat16, - (256, 192), - (2, 1), - 512, - 192, - 256, - 1, - ), - ( - cutlass.Float8E5M2, - cutlass.Float8E8M0FNU, - 32, - cutlass.BFloat16, - (256, 128), - (2, 1), - 512, - 256, - 256, - 1, - ), - ( - cutlass.Float8E4M3FN, - cutlass.Float8E8M0FNU, - 32, - cutlass.Float32, - (256, 192), - (2, 1), - 256, - 192, - 256, - 1, - ), - ( - cutlass.Float8E4M3FN, - cutlass.Float8E8M0FNU, - 32, - cutlass.Float32, - (256, 224), - (2, 1), - 256, - 224, - 256, - 1, - ), - ( - cutlass.Float4E2M1FN, - cutlass.Float8E8M0FNU, - 32, - cutlass.Float32, - (128, 128), - (1, 1), - 256, - 256, - 256, - 1, - ), - ( - cutlass.Float4E2M1FN, - cutlass.Float8E8M0FNU, - 32, - cutlass.Float32, - (256, 224), - (2, 1), - 256, - 224, - 256, - 1, - ), - ( - cutlass.Float4E2M1FN, - cutlass.Float8E8M0FNU, - 16, - cutlass.Float32, - (128, 64), - (1, 1), - 256, - 64, - 256, - 1, - ), - ( - cutlass.Float4E2M1FN, - cutlass.Float8E4M3FN, - 16, - cutlass.Float32, - (256, 192), - (2, 1), - 256, - 192, - 256, - 1, - ), - ( - cutlass.Float4E2M1FN, - cutlass.Float8E4M3FN, - 16, - cutlass.Float32, - (128, 192), - (1, 1), - 256, - 192, - 256, - 1, - ), - ( - cutlass.Float4E2M1FN, - cutlass.Float8E4M3FN, - 16, - cutlass.Float32, - (256, 224), - (2, 1), - 256, - 224, - 256, - 1, - ), - ], -) -def test_blockscaled_correctness( - ab_dtype, sf_dtype, sf_vec_size, d_dtype, mma_tiler_mn, cluster_shape_mn, m, n, k, l -): - _skip_if_not_sm100() - - ( - compiled, - args, - (a_ref, b_ref, sfa_ref, sfb_ref, _), - ) = _compile_blockscaled_gemm( - ab_dtype, - sf_dtype, - sf_vec_size, - d_dtype, - mma_tiler_mn, - cluster_shape_mn, - m, - n, - k, - l, - ) - _run_blockscaled_gemm(compiled, args) - - _, _, d_torch, _, _ = args - ref = blockscaled_gemm_reference(a_ref, b_ref, sfa_ref, sfb_ref) - err = (d_torch.float() - ref).abs().max().item() - tol = 5e-3 if d_dtype != cutlass.Float32 else 5e-4 - assert err < tol, f"max_err={err}" - - -# --------------------------------------------------------------------------- -# Scale layout invariants -# --------------------------------------------------------------------------- -@pytest.mark.parametrize("mn,sf_k,l", [(128, 4, 1), (256, 16, 1), (384, 12, 2), (512, 8, 1)]) -def test_scale_layout_matches_cublas(mn, sf_k, l): - """The quack kernel scale-view and cuBLAS's to_blocked must share the - same underlying byte layout (they both represent the PTX - tcgen05 scale-factor atom, tiled in the same outer order).""" - torch.manual_seed(0) - # a 2D uint8 scale slice per batch - scale_2d = torch.randint(0, 255, (l, mn, sf_k), device="cuda", dtype=torch.uint8) - - # Build our contiguous scale storage via create_blockscaled_operand_quantized's - # rearrangement logic: pad + (l, rm, 128, rk, 4) -> (l, rm, rk, 512) - rm = (mn + 127) // 128 - rk = (sf_k + 3) // 4 - mn_pad = rm * 128 - sf_k_pad = rk * 4 - padded = torch.zeros(l, mn_pad, sf_k_pad, device="cuda", dtype=torch.uint8) - padded[:, :mn, :sf_k] = scale_2d - blocks = padded.view(l, rm, 128, rk, 4).permute(0, 1, 3, 2, 4) - blocks = blocks.reshape(l, rm, rk, 4, 32, 4).transpose(3, 4).contiguous() - scale_contig = blocks.view(l, rm, rk, 512) # (l, rm, rk, 512) - - # kernel view indexing: byte offset within tile = (m%32)*16 + ((m//32)%4)*4 + (k%4) - kv = scale_view_for_kernel(scale_contig.view(torch.float8_e8m0fnu), mn, sf_k, l).view( - torch.uint8 - ) - m_positions = sorted({0, 1, 17, 31, 33, 127, min(128, mn - 1), mn - 1} & set(range(mn))) - k_positions = sorted({0, 1, 3, 4, 7, sf_k - 1} & set(range(sf_k))) - for li in range(l): - for mi in m_positions: - for ki in k_positions: - byte_off = (mi % 32) * 16 + ((mi // 32) % 4) * 4 + (ki % 4) - assert kv[li, mi // 128, ki // 4, byte_off].item() == scale_2d[li, mi, ki].item(), ( - f"mismatch at l={li} m={mi} k={ki}" - ) - - # cuBLAS slice must equal to_blocked(scale_2d[l]) - for li in range(l): - ours = scale_blocked_for_cublas(scale_contig.view(torch.float8_e8m0fnu), mn, sf_k, li).view( - torch.uint8 - ) - ref = to_blocked(scale_2d[li]) - assert torch.equal(ours, ref), f"to_blocked mismatch at l={li}" - - -# --------------------------------------------------------------------------- -# End-to-end: quantized MXFP8 inputs through quack kernel vs cuBLAS vs dequant ref -# --------------------------------------------------------------------------- -@pytest.mark.parametrize( - "mma_tiler_mn,cluster_shape_mn,m,n,k", - [ - # All 5 supported blockscaled tile_n values (64, 128, 192, 224, 256). - ((128, 64), (1, 1), 256, 64, 512), - ((128, 128), (1, 1), 256, 256, 256), - ((128, 128), (1, 1), 512, 512, 512), - ((128, 192), (1, 1), 256, 192, 256), - ((128, 256), (1, 1), 256, 256, 256), - ((256, 128), (2, 1), 512, 256, 512), - ((256, 192), (2, 1), 256, 192, 256), - ((256, 192), (2, 1), 256, 384, 256), - ((256, 192), (2, 1), 512, 192, 512), - ((256, 224), (2, 1), 256, 224, 256), - ((256, 224), (2, 1), 512, 224, 512), - ((256, 256), (2, 1), 512, 256, 512), - ], -) -def test_blockscaled_mxfp8_quantized(mma_tiler_mn, cluster_shape_mn, m, n, k): - _skip_if_not_sm100() - l, sf_vec = 1, 32 - - torch.manual_seed(0) - a_ref, mA, a_sc = create_blockscaled_operand_quantized(l, m, k, False, sf_vec) - b_ref, mB, b_sc = create_blockscaled_operand_quantized(l, n, k, False, sf_vec) - _, mD = create_blockscaled_operand_tensor(l, m, n, False, cutlass.BFloat16, init="empty") - - mSFA = scale_view_for_kernel(a_sc, m, k // sf_vec, l) - mSFB = scale_view_for_kernel(b_sc, n, k // sf_vec, l) - - runner = compile_blockscaled_gemm_tvm_ffi( - cutlass.Float8E4M3FN, - cutlass.Float8E8M0FNU, - sf_vec, - cutlass.BFloat16, - mma_tiler_mn, - cluster_shape_mn, - mA, - mB, - mD, - mSFA, - mSFB, - ) - runner(mA, mB, mD, mSFA, mSFB) - torch.cuda.synchronize() - - # Reference: dequant matmul (a_ref/b_ref are already dequantized) - d_ref = torch.einsum("mkl,nkl->mnl", a_ref, b_ref) - err = (mD.float() - d_ref).abs().max().item() - assert err < 5e-3, f"quack vs dequant max_err={err}" - - # cuBLAS: bit-exact match expected (same operand bits, same scale bytes, same hw MMA) - from torch.nn.functional import scaled_mm as F_scaled_mm, ScalingType, SwizzleType - - a_cub = mA[:, :, 0].contiguous() - b_cub = mB[:, :, 0].contiguous() - a_sc_cub = scale_blocked_for_cublas(a_sc, m, k // sf_vec, 0) - b_sc_cub = scale_blocked_for_cublas(b_sc, n, k // sf_vec, 0) - out_cublas = F_scaled_mm( - a_cub, - b_cub.t(), - scale_a=a_sc_cub, - scale_recipe_a=ScalingType.BlockWise1x32, - scale_b=b_sc_cub, - scale_recipe_b=ScalingType.BlockWise1x32, - swizzle_a=SwizzleType.SWIZZLE_32_4_4, - swizzle_b=SwizzleType.SWIZZLE_32_4_4, - output_dtype=torch.bfloat16, - ) - assert torch.equal(mD.squeeze(-1), out_cublas), ( - f"quack != cuBLAS: max_err={(mD.squeeze(-1).float() - out_cublas.float()).abs().max().item()}" - ) - - -# --------------------------------------------------------------------------- -# High-level PyTorch interface -# --------------------------------------------------------------------------- -@pytest.mark.parametrize("shape_mnk", [(256, 256, 256), (512, 256, 256), (128, 128, 256)]) -@pytest.mark.parametrize("batched", [False, True]) -def test_mxfp8_interface(shape_mnk, batched): - _skip_if_not_sm100() - from quack.gemm_blockscaled_interface import ( - mxfp8_gemm, - mxfp8_gemm_cublas, - mxfp8_gemm_ref, - mxfp8_gemm_quantize, - mxfp8_quantize, - ) - - M, N, K = shape_mnk - L = 2 if batched else 1 - torch.manual_seed(0) - shape_A = (L, M, K) if batched else (M, K) - # Weight stored nn.Linear-style (N, K) row-major; pass .mT to get K-contig (K, N) - shape_W = (L, N, K) if batched else (N, K) - A_hp = torch.randn(*shape_A, device="cuda", dtype=torch.bfloat16) * K**-0.5 - W_hp = torch.randn(*shape_W, device="cuda", dtype=torch.bfloat16) * K**-0.5 - - A_q, A_sc = mxfp8_quantize(A_hp) - W_q, W_sc = mxfp8_quantize(W_hp) # (..., N, K), (..., N, K/32) - assert A_q.dtype == torch.float8_e4m3fn and A_sc.dtype == torch.float8_e8m0fnu - - B_q = W_q.mT # (..., K, N) K-contig view - B_sc = W_sc.mT # (..., K/32, N) K-contig view - - out = mxfp8_gemm(A_q, B_q, A_sc, B_sc) - assert out.shape == ((L, M, N) if batched else (M, N)) - assert out.dtype == torch.bfloat16 - - ref = mxfp8_gemm_ref(A_q, B_q, A_sc, B_sc) - err = (out.float() - ref.float()).abs().max().item() - assert err < 5e-3, f"quack vs ref max_err={err}" - - # cuBLAS comparison only for 2D / L=1 - if not batched: - out_cublas = mxfp8_gemm_cublas(A_q, B_q, A_sc, B_sc) - assert torch.equal(out, out_cublas), "quack interface != cuBLAS" - - # High-level quantize+gemm convenience fn - out2 = mxfp8_gemm_quantize(A_hp, W_hp) - assert torch.equal(out, out2) - - -@pytest.mark.parametrize("a_major", ["k", "m"]) -@pytest.mark.parametrize("b_major", ["k", "n"]) -def test_blockscaled_mxfp8_major_modes(a_major, b_major): - """MXFP8 with A in {k,m}-major × B in {k,n}-major. The SF tensor layout - stays K-major (hardware convention); only A/B operand strides differ.""" - _skip_if_not_sm100() - from quack.mx_utils import to_mx - - m, n, k, l = 256, 256, 256, 1 - sf_vec = 32 - - def _make_operand(mn, major): - hp = (torch.randn(l, mn, k, device="cuda", dtype=torch.bfloat16) * k**-0.5).contiguous() - q_flat, sc_flat = to_mx(hp.view(l * mn, k), sf_vec) - ref_mkl = ( - ( - q_flat.float().view(l, mn, k) - * sc_flat.float().view(l, mn, k // sf_vec).repeat_interleave(sf_vec, dim=-1) - ) - .permute(1, 2, 0) - .contiguous() - ) - if major == "k": - # (l, mn, k) contig → permute to (mn, k, l) → stride (k, 1, mn*k) - q_mkl = q_flat.view(l, mn, k).contiguous().permute(1, 2, 0) - else: - # (l, mn, k) contig → permute to (mn, k, l) with mn fastest → stride (1, mn, mn*k) - q_mkl = ( - q_flat.view(l, mn, k).contiguous().permute(0, 2, 1).contiguous().permute(2, 1, 0) - ) - return ref_mkl, q_mkl, sc_flat.view(l, mn, k // sf_vec) - - a_ref, mA, sa_2d = _make_operand(m, a_major) - b_ref, mB, sb_2d = _make_operand(n, b_major) - # Sanity: stride(0) == 1 iff mn-major. - assert (mA.stride(0) == 1) == (a_major == "m"), f"mA stride: {mA.stride()}" - assert (mB.stride(0) == 1) == (b_major == "n"), f"mB stride: {mB.stride()}" - from quack.blockscaled_gemm_utils import pack_scale_2d_to_blocked_contig - - a_sc = pack_scale_2d_to_blocked_contig(sa_2d) - b_sc = pack_scale_2d_to_blocked_contig(sb_2d) - _, mD = create_blockscaled_operand_tensor(l, m, n, False, cutlass.BFloat16, init="empty") - - assert GemmDefaultSm100.can_implement_blockscaled( - cutlass.Float8E4M3FN, - cutlass.Float8E8M0FNU, - sf_vec, - cutlass.BFloat16, - (128, 128), - (1, 1), - m, - n, - k, - l, - a_major, - b_major, - "n", - ) - runner = compile_blockscaled_gemm_tvm_ffi( - cutlass.Float8E4M3FN, - cutlass.Float8E8M0FNU, - sf_vec, - cutlass.BFloat16, - (128, 128), - (1, 1), - mA, - mB, - mD, - a_sc, - b_sc, - ) - runner(mA, mB, mD, a_sc, b_sc) - torch.cuda.synchronize() - - ref = torch.einsum("mkl,nkl->mnl", a_ref, b_ref) - err = (mD.float() - ref).abs().max().item() - assert err < 5e-3, f"A={a_major} B={b_major} max_err={err}" - - -@pytest.mark.parametrize("b_major", ["k", "n"]) -@pytest.mark.parametrize( - "seqlens_m", - [ - [128, 128, 128], # baseline: all aligned - [100, 200, 150], # none aligned to 128 - [30, 300, 64, 200], # mix small + non-aligned - [1, 128, 127, 129], # boundary conditions - ], -) -def test_blockscaled_mxfp8_varlen_m_nonaligned(seqlens_m, b_major): - """varlen_m with per-expert seqlens not divisible by 128, plus k/n-major B. - SFA is stored in dQaccum-style padded format; kernel reads it via - offset_batch_SFA.""" - _skip_if_not_sm100() - num_experts = len(seqlens_m) - n, k = 256, 256 - sf_vec = 32 - mma_tiler_mn = (128, 128) - cluster_shape_mn = (1, 1) - - torch.manual_seed(0) - a_ref_dq, b_ref_dq, mA, mB, a_sc_contig, b_sc_contig, cu_seqlens_m = ( - create_blockscaled_varlen_m_operands( - num_experts, - 0, - n, - k, - sf_vec, - seqlens_m=seqlens_m, - b_major=b_major, - ) - ) - expected_b_stride0 = 1 if b_major == "n" else k - assert mB.stride(0) == expected_b_stride0, ( - f"b_major={b_major} → mB.stride(0) should be {expected_b_stride0}, got {mB.stride()}" - ) - total_m = int(sum(seqlens_m)) - mSFA = a_sc_contig # (1, total_padded_rm, rk, 512) - mSFB = b_sc_contig # (L, rn, rk, 512) - - mD = torch.empty(total_m, n, dtype=torch.bfloat16, device="cuda") - runner = compile_blockscaled_gemm_tvm_ffi( - cutlass.Float8E4M3FN, - cutlass.Float8E8M0FNU, - sf_vec, - cutlass.BFloat16, - mma_tiler_mn, - cluster_shape_mn, - mA, - mB, - mD, - mSFA, - mSFB, - varlen_m=True, - ) - runner(mA, mB, mD, mSFA, mSFB, cu_seqlens_m) - torch.cuda.synchronize() - - # Per-expert reference matmul on dequantized operands. - cu = cu_seqlens_m.tolist() - ref = torch.cat([a_ref_dq[cu[i] : cu[i + 1]] @ b_ref_dq[i].T for i in range(num_experts)]) - err = (mD.float() - ref).abs().max().item() - assert err < 5e-3, f"varlen_m non-aligned seqlens_m={seqlens_m} max_err={err}" - - -@pytest.mark.parametrize( - "seqlens_k", - [ - [128, 128, 128], # all aligned to 128 - [128, 256, 128], # 128-aligned mixed sizes - [96, 160, 128], # not 128-aligned (but all sf_vec-aligned) - [32, 256, 64, 128], # small + varied - ], -) -def test_blockscaled_mxfp8_varlen_k(seqlens_k): - """varlen_k blockscaled: per-expert k_i (must be sf_vec-aligned; 128-alignment - is NOT required). SFA/SFB use dQaccum-style K-padded storage and the kernel - reads them via offset_batch_SFA/offset_batch_SFB padded-K formula.""" - _skip_if_not_sm100() - num_experts = len(seqlens_k) - m, n = 256, 256 - sf_vec = 32 - mma_tiler_mn = (128, 128) - cluster_shape_mn = (1, 1) - - torch.manual_seed(0) - a_ref_list, b_ref_list, mA, mB, a_sc_contig, b_sc_contig, cu_seqlens_k = ( - create_blockscaled_varlen_k_operands(num_experts, 0, m, n, sf_vec, seqlens_k=seqlens_k) - ) - # (m, n, L) with stride 1 on N dim (compile expects leading_dim=1 on mD). - mD = torch.empty(num_experts, m, n, dtype=torch.bfloat16, device="cuda").permute(1, 2, 0) - runner = compile_blockscaled_gemm_tvm_ffi( - cutlass.Float8E4M3FN, - cutlass.Float8E8M0FNU, - sf_vec, - cutlass.BFloat16, - mma_tiler_mn, - cluster_shape_mn, - mA, - mB, - mD, - a_sc_contig, - b_sc_contig, - varlen_k=True, - ) - runner(mA, mB, mD, a_sc_contig, b_sc_contig, cu_seqlens_k) - torch.cuda.synchronize() - - # Per-expert reference: for expert i, result = a_ref[i] @ b_ref[i].T. - # mD has shape (m, n, L) N-major; each mD[:, :, i] is one expert's output. - for i in range(num_experts): - ref_i = a_ref_list[i] @ b_ref_list[i].T - out_i = mD[:, :, i].float() - err = (out_i - ref_i).abs().max().item() - assert err < 5e-3, f"varlen_k seqlens_k={seqlens_k} expert={i} max_err={err}" - - -@pytest.mark.parametrize("rk_pad", [1, 3, 5]) -def test_blockscaled_mxfp8_strided_sf(rk_pad): - """Verify the kernel honors mSFA/mSFB's actual outer strides (doesn't - require the full scale tensor to be contig — only the innermost 512-B - tile). Allocates a larger scale buffer with extra rk padding and slices - back to the valid rk, producing a non-packed rm stride.""" - _skip_if_not_sm100() - m, n, k = 256, 256, 512 # k=512 → sf_k=16 → rk=4 (meaningful stride change) - l, sf_vec = 1, 32 - - torch.manual_seed(0) - a_ref, mA, a_sc = create_blockscaled_operand_quantized(l, m, k, False, sf_vec) - b_ref, mB, b_sc = create_blockscaled_operand_quantized(l, n, k, False, sf_vec) - - rm = (m + 127) // 128 - rn = (n + 127) // 128 - rk = ((k // sf_vec) + 3) // 4 - - # Allocate padded scale buffers (rk + rk_pad along K-blocks), copy valid - # tiles into the prefix, slice back to rk. The slice is non-contig: - # stride(1) = (rk + rk_pad) * 512 instead of rk * 512. - a_sc_big = torch.zeros(l, rm, rk + rk_pad, 512, dtype=torch.float8_e8m0fnu, device="cuda") - b_sc_big = torch.zeros(l, rn, rk + rk_pad, 512, dtype=torch.float8_e8m0fnu, device="cuda") - a_sc_big[:, :, :rk, :] = a_sc - b_sc_big[:, :, :rk, :] = b_sc - mSFA_strided = a_sc_big[:, :, :rk, :] - mSFB_strided = b_sc_big[:, :, :rk, :] - assert not mSFA_strided.is_contiguous() - assert mSFA_strided.stride(-1) == 1 - assert mSFA_strided.stride(1) == (rk + rk_pad) * 512, ( - f"expected non-packed rm stride {(rk + rk_pad) * 512}, got {mSFA_strided.stride(1)}" - ) - - # Validate our helper accepts the non-contig layout - _ = scale_view_for_kernel(mSFA_strided, m, k // sf_vec, l) - _ = scale_view_for_kernel(mSFB_strided, n, k // sf_vec, l) - - _, mD_strided = create_blockscaled_operand_tensor( - l, m, n, False, cutlass.BFloat16, init="empty" - ) - runner = compile_blockscaled_gemm_tvm_ffi( - cutlass.Float8E4M3FN, - cutlass.Float8E8M0FNU, - sf_vec, - cutlass.BFloat16, - (128, 128), - (1, 1), - mA, - mB, - mD_strided, - mSFA_strided, - mSFB_strided, - ) - runner(mA, mB, mD_strided, mSFA_strided, mSFB_strided) - - # Compare bit-exactly against the same matmul with contig scales. - _, mD_contig = create_blockscaled_operand_tensor(l, m, n, False, cutlass.BFloat16, init="empty") - runner_contig = compile_blockscaled_gemm_tvm_ffi( - cutlass.Float8E4M3FN, - cutlass.Float8E8M0FNU, - sf_vec, - cutlass.BFloat16, - (128, 128), - (1, 1), - mA, - mB, - mD_contig, - a_sc, - b_sc, - ) - runner_contig(mA, mB, mD_contig, a_sc, b_sc) - torch.cuda.synchronize() - - assert torch.equal(mD_strided, mD_contig), ( - f"strided-SF output differs from contig-SF: " - f"max_abs_err={(mD_strided.float() - mD_contig.float()).abs().max().item()}" - ) - - -def test_mxfp8_interface_preallocated_out(): - _skip_if_not_sm100() - from quack.gemm_blockscaled_interface import mxfp8_gemm, mxfp8_quantize - - M, N, K = 256, 256, 256 - torch.manual_seed(0) - A_hp = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) * K**-0.5 - W_hp = torch.randn(N, K, device="cuda", dtype=torch.bfloat16) * K**-0.5 - A_q, A_sc = mxfp8_quantize(A_hp) - W_q, W_sc = mxfp8_quantize(W_hp) - B_q, B_sc = W_q.mT, W_sc.mT - - out_alloc = mxfp8_gemm(A_q, B_q, A_sc, B_sc) - out_pre = torch.empty(M, N, device="cuda", dtype=torch.bfloat16) - mxfp8_gemm(A_q, B_q, A_sc, B_sc, out=out_pre) - assert torch.equal(out_alloc, out_pre)