From 159b93d5425c42a7806a9267cda9d7a3d3cbaf98 Mon Sep 17 00:00:00 2001 From: agent Date: Thu, 30 Apr 2026 12:59:18 +0200 Subject: [PATCH 1/4] [Sm120] add blockscaled FP4 GEMM path Add the first guarded SM120 blockscaled GEMM path to the existing blockscaled GEMM interface instead of adding a separate SM120-specific frontend. The new path is intentionally narrow and mirrors the SM100 blockscaled entry points where the existing abstractions fit. Supported SM120 scope in this commit: - A/B are Float4E2M1FN FP4 operands stored as packed torch.float4_e2m1fn_x2. - Scale tensors are byte-sized FP8 scale factors: - NVFP4: Float8E4M3FN with sf_vec_size=16. - MXFP4: Float8E8M0FNU with sf_vec_size=32. - Accumulation is Float32 and D is BFloat16. - M and N must be multiples of 128, K must be a multiple of 64, L must be 1. - tile_shape_mnk is fixed to 128x128x64 and cluster_shape_mnk is fixed to 1x1x1. - C/beta, varlen, gather_A, sparse/grouped kernels, multicast clusters, output FP4 quantization, and generic autodispatch remain unsupported. The SM120 scale layout is deliberately different from the compact SM100-style scale layout. CuTeDSL TMA copies for compact row-major scale pages such as (M, 4) trap on SM120 because the row stride is smaller than the TMA-friendly 16-byte granularity. The SM120 path therefore requires padded physical scale pages: logical_scale_cols = ceil(K / sf_vec_size) physical_scale_cols = round_up(logical_scale_cols, 16) SFA shape = (M, physical_scale_cols, 1) SFB shape = (N, physical_scale_cols, 1) Only logical columns are consumed; padding columns are ignored. The helper used by tests and benchmarks now validates K divisibility, sf_vec_size, and the matching scale dtype so invalid tensors are rejected before they can reach TMA. The compile-time entry also rejects compact SM120 scale tensors before launch. The SM120 kernel uses non-multicast TMA to stage packed A/B bytes and padded SFA/SFB pages, expands packed FP4 bytes into the padded Int8 shared-memory shape required by SM120 FP4 ldmatrix helpers, and issues the CUTLASS DSL tuple-MMA blockscaled path. It keeps the proven selector-zero scale packet mapping: SFA provider rows are group + 8 * (tid & 1), and SFB provider columns are group. A/B FP4 ldmatrix still uses the local SM120 helpers instead of generic cute.copy because source and destination element widths differ. For K > 64, this commit keeps Float32 accumulation across K64 tiles for each 16x8 atom. Intermediate K64 partial sums are written to an FP32 shared-memory scratch tile, later K64 partials add that FP32 value back into registers, and D is converted to BF16 only once on the final K64 tile. This avoids the previous BF16-chained partial accumulation behavior and preserves the advertised Float32 accumulation semantics for the supported K-multiple cases. The public class-level blockscaled call now runs the SM120 validation path before launching the JIT kernel. The TVM/FFI compile helper still calls a JIT-only internal entry because it creates logical fake CuTe tensors over packed torch storage before compilation; the host validation path documents and rejects raw packed-K class calls with a clear error. Tests are consolidated under tests/test_gemm_blockscaled.py. The SM120 coverage checks: - supported and rejected can_implement_blockscaled cases; - direct class-call validation for missing scales, C/beta, and packed-vs-logical K misuse; - padded scale layout sizing for K=64, 128, 256, 384 and MXFP4 examples; - scale-helper validation for K, sf_vec_size, and scale dtype; - compact scale tensors rejected before launch; - SM120 scale-sensitive runtime correctness for K=64, multi-CTA K=128, and K=320 crossing a 16-column scale page; - a K=384 regression that would fail if K64 partial sums were rounded through BF16 between tiles; - non-constant FP4 K-lane patterns for both A and B with padded-scale poison present beyond the logical scale columns. The benchmark smoke path can create the padded SM120 scale tensors and launch the explicit SM120 blockscaled path, but numerical correctness is covered by the pytest references rather than benchmark timing output. --- AI/varlen_blockscaled_sf_layout.md | 2 +- benchmarks/benchmark_gemm.py | 87 +- quack/blockscaled_gemm_utils.py | 124 +- quack/gemm_sm120.py | 1080 ++++++++++++++++- ...lockscaled.py => test_gemm_blockscaled.py} | 394 +++++- 5 files changed, 1639 insertions(+), 48 deletions(-) rename tests/{test_gemm_sm100_blockscaled.py => test_gemm_blockscaled.py} (67%) diff --git a/AI/varlen_blockscaled_sf_layout.md b/AI/varlen_blockscaled_sf_layout.md index 7208c895..ec8ab597 100644 --- a/AI/varlen_blockscaled_sf_layout.md +++ b/AI/varlen_blockscaled_sf_layout.md @@ -152,7 +152,7 @@ a tile-unit offset integer. ## Tests -`tests/test_gemm_sm100_blockscaled.py`: +`tests/test_gemm_blockscaled.py`: - `test_blockscaled_mxfp8_varlen_m_nonaligned` — 4 seqlen patterns × 2 B-majors = 8 cases. Patterns include `[128, 128, 128]`, `[100, 200, 150]`, `[30, 300, 64, 200]`, `[1, 128, 127, 129]`. diff --git a/benchmarks/benchmark_gemm.py b/benchmarks/benchmark_gemm.py index e71a923d..e5abfe10 100644 --- a/benchmarks/benchmark_gemm.py +++ b/benchmarks/benchmark_gemm.py @@ -194,19 +194,21 @@ def _run_blockscaled(args): compile_blockscaled_gemm_tvm_ffi, create_blockscaled_operand_quantized, create_blockscaled_operand_tensor, + create_sm120_blockscaled_scale_tensor, create_blockscaled_varlen_m_operands, scale_blocked_for_cublas, torch_dtype_for_cutlass, ) from quack.cute_dsl_utils import get_device_capacity - from quack.gemm_default_epi import GemmDefaultSm100 + from quack.gemm_default_epi import GemmDefaultSm100, GemmDefaultSm120 sm_major = get_device_capacity(torch.device("cuda"))[0] - assert sm_major in (10, 11), ( - f"Blockscaled GEMM requires SM100 (B200/B300) or SM110; got SM{sm_major}x. " - "MXFP8/MXFP4/NVFP4 use tcgen05 UMMA which is SM100+." + assert sm_major in (10, 11, 12), ( + f"Blockscaled GEMM requires SM100/SM110 or SM120; got SM{sm_major}x." ) + if sm_major == 12 and (args.varlen_m or args.varlen_k): + raise NotImplementedError("SM120 blockscaled benchmark path does not support varlen") if args.varlen_k or args.gather_A or args.pingpong: raise NotImplementedError( "blockscaled + varlen_k/gather/pingpong is not wired up yet. " @@ -255,12 +257,16 @@ def _run_blockscaled(args): raise ValueError( f"MXFP4/NVFP4 require K-major for both A and B; got a_major={a_major}, b_major={b_major}" ) - if not GemmDefaultSm100.can_implement_blockscaled( + GemmBlockscaledCls = GemmDefaultSm120 if sm_major == 12 else GemmDefaultSm100 + mma_tiler_for_validation = ( + tuple(mma_tiler_mnk) if len(mma_tiler_mnk) == 3 or sm_major != 12 else (*mma_tiler_mnk, 64) + ) + if not GemmBlockscaledCls.can_implement_blockscaled( ab_dtype, sf_dtype, sf_vec_size, d_dtype, - mma_tiler_mnk, + mma_tiler_for_validation, cluster_shape_mn, m, n, @@ -320,28 +326,43 @@ def _run_blockscaled(args): def fn(): runner(mA, mB, mD, mSFA, mSFB, cu_seqlens_m) else: - a_ref, mA, a_sc_contig = create_blockscaled_operand_quantized( - l, - m, - k, - a_major == "m", - sf_vec_size, - ab_dtype, - sf_dtype, - ) - b_ref, mB, b_sc_contig = create_blockscaled_operand_quantized( - l, - n, - k, - b_major == "n", - sf_vec_size, - ab_dtype, - sf_dtype, - ) - # (l, rm, rk, 512) contig scale — consumed directly by the kernel. - mSFA, mSFB = a_sc_contig, b_sc_contig - sfa_ref = torch.ones_like(a_ref) - sfb_ref = torch.ones_like(b_ref) + if sm_major == 12: + if ab_dtype is not cutlass.Float4E2M1FN or d_dtype is not cutlass.BFloat16: + raise TypeError( + "SM120 blockscaled benchmark currently supports FP4 inputs and BF16 D" + ) + _, mA = create_blockscaled_operand_tensor(l, m, k, False, ab_dtype, init="empty") + _, mB = create_blockscaled_operand_tensor(l, n, k, False, ab_dtype, init="empty") + mA.view(torch.uint8).fill_(0x22) + mB.view(torch.uint8).fill_(0x22) + a_ref = torch.ones((m, k, l), device="cuda", dtype=torch.float32) + b_ref = torch.ones((n, k, l), device="cuda", dtype=torch.float32) + sfa_ref, mSFA = create_sm120_blockscaled_scale_tensor(l, m, k, sf_vec_size, sf_dtype) + sfb_ref, mSFB = create_sm120_blockscaled_scale_tensor(l, n, k, sf_vec_size, sf_dtype) + a_sc_contig = b_sc_contig = None + else: + a_ref, mA, a_sc_contig = create_blockscaled_operand_quantized( + l, + m, + k, + a_major == "m", + sf_vec_size, + ab_dtype, + sf_dtype, + ) + b_ref, mB, b_sc_contig = create_blockscaled_operand_quantized( + l, + n, + k, + b_major == "n", + sf_vec_size, + ab_dtype, + sf_dtype, + ) + # (l, rm, rk, 512) contig scale — consumed directly by the kernel. + mSFA, mSFB = a_sc_contig, b_sc_contig + sfa_ref = torch.ones_like(a_ref) + sfb_ref = torch.ones_like(b_ref) _, mD = create_blockscaled_operand_tensor(l, m, n, False, d_dtype, init="empty") runner = compile_blockscaled_gemm_tvm_ffi( ab_dtype, @@ -372,10 +393,12 @@ def fn(): torch.testing.assert_close(mD.float(), ref, atol=tol, rtol=1e-3) else: ref = blockscaled_gemm_reference(a_ref, b_ref, sfa_ref, sfb_ref) + if d_dtype != cutlass.Float32: + ref = ref.to(torch_dtype_for_cutlass(d_dtype)).float() torch.testing.assert_close(mD.float(), ref, atol=tol, rtol=1e-3) print("Ref check PASSED") - print("Running SM100 Blockscaled GEMM with:") + print(f"Running SM{sm_major}0 Blockscaled GEMM with:") print(f"mnkl: {args.mnkl}") print(f"tile_shape_mnk: {mma_tiler_mnk}, cluster_shape_mnk: {cluster_shape_mnk}") print( @@ -395,6 +418,12 @@ def fn(): # batch would be an unfair comparison (hides batching potential), so skip. print("(skipping cuBLAS: batched blockscaled mm not supported via a single call)") return + if sm_major == 12: + print( + "(skipping cuBLAS comparison: SM120 benchmark currently builds QuACK's " + "padded row-major scale tensors, not the cuBLAS/PyTorch scaled_mm scale layout)" + ) + return if a_major != "k" or b_major != "k": # F.scaled_mm requires A (M,K) row-major and B (K,N) col-major — # i.e. both operands K-contiguous. Skip for m/n-major to avoid an diff --git a/quack/blockscaled_gemm_utils.py b/quack/blockscaled_gemm_utils.py index 479c78ff..115f999a 100644 --- a/quack/blockscaled_gemm_utils.py +++ b/quack/blockscaled_gemm_utils.py @@ -11,7 +11,7 @@ from quack.compile_utils import make_fake_tensor as fake_tensor from quack.cute_dsl_utils import get_device_capacity, get_max_active_clusters -from quack.gemm_default_epi import GemmDefaultSm100 +from quack.gemm_default_epi import GemmDefaultSm100, GemmDefaultSm120 from quack.gemm_tvm_ffi_utils import div_for_dtype, make_scheduler_args from quack.mx_utils import ( to_mx_compiled, @@ -235,6 +235,53 @@ def create_blockscaled_scale_tensor( return ref, packed +def create_sm120_blockscaled_scale_tensor( + l: int, + mn: int, + k: int, + sf_vec_size: int, + dtype: Type[cutlass.Numeric], +) -> Tuple[torch.Tensor, torch.Tensor]: + """Create row-major padded SM120 blockscaled scale tensors. + + SM120 TMA copies scale factors as 2D row-major pages. Physical scale columns + are padded to a multiple of 16; columns beyond ``ceil(k / sf_vec_size)`` are + padding and ignored by the kernel. + """ + if k % 64 != 0: + raise ValueError("SM120 blockscaled scale helper requires K divisible by 64") + if sf_vec_size not in (16, 32): + raise ValueError("SM120 blockscaled scale helper supports sf_vec_size 16 or 32") + expected_dtype = cutlass.Float8E4M3FN if sf_vec_size == 16 else cutlass.Float8E8M0FNU + if dtype is not expected_dtype: + raise ValueError(f"SM120 sf_vec_size={sf_vec_size} requires {expected_dtype}, got {dtype}") + sf_k = ceil_div(k, sf_vec_size) + physical_sf_k = ceil_div(sf_k, 16) * 16 + if dtype == cutlass.Float8E8M0FNU: + exponents = torch.randint(0, 2, (mn, sf_k, l), device="cuda", dtype=torch.int32) + ref_blocks = torch.pow(2.0, exponents.float()) + else: + ref_blocks = torch.randint(1, 4, (mn, sf_k, l), device="cuda", dtype=torch.int32).float() + packed = torch.empty( + (mn, physical_sf_k, l), device="cuda", dtype=torch_dtype_for_cutlass(dtype) + ) + packed[:, :sf_k, :].copy_(ref_blocks.to(packed.dtype)) + if physical_sf_k > sf_k: + poison = torch.tensor([0.5, 1.5, 2.0, 3.0], device="cuda", dtype=torch.float32).to( + packed.dtype + ) + cols = torch.arange(physical_sf_k - sf_k, device="cuda") + packed[:, sf_k:, :] = poison[cols % poison.numel()][None, :, None] + ref = ( + ref_blocks.permute(2, 0, 1) + .unsqueeze(-1) + .expand(l, mn, sf_k, sf_vec_size) + .reshape(l, mn, sf_k * sf_vec_size) + .permute(1, 2, 0) + )[:, :k, :] + return ref, packed + + def pack_scale_2d_to_blocked_contig(scale_2d: torch.Tensor) -> torch.Tensor: """Rearrange a (l, mn, sf_k) or (mn, sf_k) e8m0 scale tensor into the contiguous (l, rm, rk, 512) blocked layout shared by the quack kernel and @@ -609,7 +656,7 @@ def compile_blockscaled_gemm_tvm_ffi( varlen_m: bool = False, varlen_k: bool = False, ) -> Callable: - """Compile the SM100 blockscaled GEMM. + """Compile the blockscaled GEMM. When varlen_m: mA is (total_m, k) K-major, mD is (total_m, n) N-major, mB is (n, k, l); run(...) takes an extra cu_seqlens_m tensor. @@ -617,15 +664,63 @@ def compile_blockscaled_gemm_tvm_ffi( run(...) takes an extra cu_seqlens_k tensor. """ device_capacity = get_device_capacity(mA.device) - if device_capacity[0] not in (10, 11): - raise RuntimeError("Blockscaled SM100 GEMM requires SM100/SM110") + if device_capacity[0] not in (10, 11, 12): + raise RuntimeError("Blockscaled GEMM requires SM100/SM110 or SM120") assert not (varlen_m and varlen_k), "Only one of varlen_m / varlen_k" - gemm = partial( - GemmDefaultSm100, - sf_vec_size=sf_vec_size, - use_clc_persistence=use_clc_persistence, - )(cutlass.Float32, ab_dtype, mma_tiler_mn, (*cluster_shape_mn, 1)) + mma_tiler_mn_only = mma_tiler_mn[:2] + mma_tiler_k = mma_tiler_mn[2] if len(mma_tiler_mn) == 3 else 64 + + if device_capacity[0] == 12: + if varlen_m or varlen_k: + raise NotImplementedError("SM120 blockscaled GEMM does not support varlen") + if mma_tiler_k != 64: + raise NotImplementedError("SM120 blockscaled GEMM requires tile_K=64") + if len(mA.shape) != 3 or len(mB.shape) != 3 or len(mD.shape) != 3: + raise ValueError("SM120 blockscaled GEMM requires rank-3 A/B/D tensors") + logical_k = mA.shape[1] * (2 if ab_dtype is cutlass.Float4E2M1FN else 1) + physical_scale_cols = ceil_div(ceil_div(logical_k, sf_vec_size), 16) * 16 + expected_sfa_shape = (mA.shape[0], physical_scale_cols, mA.shape[2]) + expected_sfb_shape = (mB.shape[0], physical_scale_cols, mB.shape[2]) + if tuple(mSFA.shape) != expected_sfa_shape: + raise ValueError(f"SFA shape must be {expected_sfa_shape}, got {tuple(mSFA.shape)}") + if tuple(mSFB.shape) != expected_sfb_shape: + raise ValueError(f"SFB shape must be {expected_sfb_shape}, got {tuple(mSFB.shape)}") + if not GemmDefaultSm120.can_implement_blockscaled( + ab_dtype, + sf_dtype, + sf_vec_size, + d_dtype, + (*mma_tiler_mn_only, 64), + cluster_shape_mn, + mA.shape[0], + mB.shape[0], + logical_k, + mA.shape[2], + "k", + "k", + "n", + ): + raise RuntimeError( + "Unsupported SM120 blockscaled config; supported formats are MXFP4 " + "(Float4E2M1FN + Float8E8M0FNU + vec32) and NVFP4 " + "(Float4E2M1FN + Float8E4M3FN + vec16)" + ) + gemm = GemmDefaultSm120( + cutlass.Float32, + ab_dtype, + (*mma_tiler_mn_only, 64), + (*cluster_shape_mn, 1), + is_persistent=False, + sf_vec_size=sf_vec_size, + sf_dtype=sf_dtype, + ) + else: + gemm = partial( + GemmDefaultSm100, + sf_vec_size=sf_vec_size, + use_clc_persistence=use_clc_persistence, + )(cutlass.Float32, ab_dtype, mma_tiler_mn, (*cluster_shape_mn, 1)) compile_epi_args = gemm.EpilogueArguments() scheduler_args = make_scheduler_args( get_max_active_clusters(cluster_shape_mn[0] * cluster_shape_mn[1]), @@ -709,7 +804,16 @@ def runner( varlen_args, stream, ): - gemm(a, b, d, None, compile_epi_args, scheduler_args, varlen_args, stream, sfa, sfb, None) + if cutlass.const_expr(device_capacity[0] == 12): + # SM120 FFI has already validated packed torch storage and compiles + # logical fake CuTe tensor views. Do not route through + # blockscaled_call, whose host validation expects logical class-call + # tensors rather than packed torch.float4_e2m1fn_x2 storage. + gemm._blockscaled_call_jit(a, b, d, varlen_args, stream, sfa, sfb, None) + else: + gemm( + a, b, d, None, compile_epi_args, scheduler_args, varlen_args, stream, sfa, sfb, None + ) compiled = cute.compile( runner, diff --git a/quack/gemm_sm120.py b/quack/gemm_sm120.py index 1c15f3f4..02f5dff9 100644 --- a/quack/gemm_sm120.py +++ b/quack/gemm_sm120.py @@ -17,14 +17,99 @@ from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait from cutlass.cute.nvgpu import cpasync, warp from cutlass import Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum -from quack.varlen_utils import VarlenManager +from quack.varlen_utils import VarlenArguments, VarlenManager from quack.pipeline import make_pipeline_state from quack import copy_utils from quack.gemm_sm90 import GemmSm90, NamedBarrierGemm from quack import sm80_utils +def _round_up(x: int, multiple: int) -> int: + return ((x + multiple - 1) // multiple) * multiple + + +@cute.jit +def _sm120_blockscaled_scale_fragment( + dtype: cutlass.Constexpr[Type[cutlass.Numeric]], + sf_vec_size: cutlass.Constexpr[int], + is_sfa: cutlass.Constexpr[bool], +): + if const_expr(sf_vec_size == 16): + if const_expr(is_sfa): + return warp.make_mxf4nvf4_sfa_fragment(dtype) + return warp.make_mxf4nvf4_sfb_fragment(dtype) + return cute.make_rmem_tensor(cute.make_layout(((32, 2),), stride=((0, 1),)), dtype) + + +@cute.jit +def _load_sm120_blockscaled_selector0_scale_fragments( + sSFA: cute.Tensor, + sSFB: cute.Tensor, + stage: cutlass.Int32, + m_atom_base: cutlass.Int32, + n_atom_base: cutlass.Int32, + k_scale_base: cutlass.Int32, + sf_vec_size: cutlass.Constexpr[int], + sf_dtype: cutlass.Constexpr[Type[cutlass.Numeric]], +): + """Load selector-zero SM120 FP4 blockscaled scale packets from SMEM. + + The tuple-lowered SM120 blockscaled MMA uses byte-id-a=byte-id-b=0. For + selector zero, SFA provider lanes map tid 0 to row group and tid 1 to row + group+8; SFB provider lanes map group 0..7 to logical N columns 0..7. + """ + lane = cute.arch.lane_idx() + group = lane >> 2 + tid = lane & 3 + sfa_row = m_atom_base + group + 8 * (tid & 1) + sfb_col = n_atom_base + group + sfa = _sm120_blockscaled_scale_fragment(sf_dtype, sf_vec_size, True) + sfb = _sm120_blockscaled_scale_fragment(sf_dtype, sf_vec_size, False) + compact_sfa = cute.filter_zeros(sfa) + compact_sfb = cute.filter_zeros(sfb) + for kb in cutlass.range_constexpr(64 // sf_vec_size): + compact_sfa[kb] = sSFA[sfa_row, k_scale_base + kb, stage] + compact_sfb[kb] = sSFB[sfb_col, k_scale_base + kb, stage] + return sfa, sfb + + +@cute.jit +def _make_sm120_fp4_ldmatrix_smem_view( + smem: cute.Tensor, + mn: cutlass.Constexpr[int], +): + return cute.make_tensor( + smem.iterator, + cute.make_layout((mn, (8, 4)), stride=(64, (1, 16))), + ) + + +@cute.jit +def _expand_compact_fp4_to_sm120_ldmatrix_smem( + compact: cute.Tensor, + padded: cute.Tensor, + mn: cutlass.Constexpr[int], +): + """Expand packed FP4 bytes into the padded SM120 ldmatrix SMEM layout.""" + tidx, _, _ = cute.arch.thread_idx() + + for i in cutlass.range((mn * 64 + 31) // 32, unroll_full=True): + flat = tidx + i * 32 + if flat < mn * 64: + padded[flat // 64, flat % 64] = cutlass.Int8(0) + + for i in cutlass.range((mn * 32 + 31) // 32, unroll_full=True): + flat = tidx + i * 32 + if flat < mn * 32: + row = flat // 32 + packed_k = flat - row * 32 + group = packed_k // 8 + in_group = packed_k - group * 8 + padded[row, group * 16 + in_group] = compact[row, packed_k] + + class GemmSm120(GemmSm90): """SM120-style GEMM using warp-level MMA instead of WGMMA. @@ -49,6 +134,8 @@ def __init__( gather_A: bool = False, concat_layout: tuple | None = None, use_pdl: bool = True, + sf_vec_size: int | None = None, + sf_dtype: Type[cutlass.Numeric] | None = None, ): # Don't call super().__init__ — we set up our own config self.acc_dtype = acc_dtype @@ -59,6 +146,27 @@ def __init__( self.fp8_slow_accum = False self.gather_A = gather_A self.concat_layout = concat_layout or () + self.blockscaled = sf_vec_size is not None + self.sf_vec_size = sf_vec_size + self.sf_dtype = sf_dtype + if self.blockscaled: + if a_dtype is not cutlass.Float4E2M1FN: + raise ValueError("SM120 blockscaled path currently supports Float4E2M1FN A/B only") + if acc_dtype is not cutlass.Float32: + raise ValueError("SM120 blockscaled path requires Float32 accumulation") + if sf_vec_size not in (16, 32): + raise ValueError("SM120 blockscaled path supports sf_vec_size 16 or 32") + expected_sf_dtype = cutlass.Float8E4M3FN if sf_vec_size == 16 else cutlass.Float8E8M0FNU + if sf_dtype is not expected_sf_dtype: + raise ValueError( + f"SM120 blockscaled sf_vec_size={sf_vec_size} requires {expected_sf_dtype}" + ) + if pingpong: + raise NotImplementedError("SM120 blockscaled pingpong is not implemented") + if gather_A: + raise NotImplementedError("SM120 blockscaled gather_A is not implemented") + if cluster_shape_mnk != (1, 1, 1): + raise ValueError("SM120 blockscaled path requires cluster_shape_mnk=(1,1,1)") if self.pingpong: assert self.is_persistent, "Pingpong gemm requires persistent scheduler" if gather_A: @@ -74,7 +182,7 @@ def __init__( # Pingpong: 2 warp groups each with (2,2,1) atom layout # Non-pingpong: 1 group of 8 warps with (4,2,1) atom layout - self.mma_inst_mnk = (16, 8, 16) + self.mma_inst_mnk = (16, 8, 64) if self.blockscaled else (16, 8, 16) self.atom_layout_mnk = (4, 2, 1) if not self.pingpong else (2, 2, 1) # num_mma_warps = total warps doing MMA (both warp groups in pingpong) self.num_mma_warps = math.prod(self.atom_layout_mnk) * (1 if not self.pingpong else 2) @@ -129,7 +237,14 @@ def epi_smem_warp_shape_mnk(self): def _setup_tiled_mma(self): """Set up warp-level MMA (MmaF16BF16Op) and tile K dimension.""" - op = warp.MmaF16BF16Op(self.a_dtype, self.acc_dtype, self.mma_inst_mnk) + if const_expr(self.blockscaled): + if self.sf_vec_size == 16: + op = warp.MmaMXF4NVF4Op(self.a_dtype, self.acc_dtype, self.sf_dtype) + else: + op = warp.MmaMXF4Op(self.a_dtype, self.acc_dtype, self.sf_dtype) + self.mma_inst_mnk = (16, 8, 64) + else: + op = warp.MmaF16BF16Op(self.a_dtype, self.acc_dtype, self.mma_inst_mnk) tC = cute.make_layout(self.atom_layout_mnk) atom_m, atom_n, atom_k = self.atom_layout_mnk # We want each warp to have 16 consecutive elements in the N direction, for STSM @@ -152,8 +267,963 @@ def _setup_tiled_mma(self): ) self.cta_tile_shape_mnk = (self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1], tile_k) - # __call__, _setup_attributes, make_ab_pipeline, make_epi_store_pipeline, - # make_sched_pipeline, epilogue are all inherited from GemmSm90. + # Dense __call__, _setup_attributes, make_ab_pipeline, make_epi_store_pipeline, + # make_sched_pipeline, epilogue are inherited from GemmSm90. + + @staticmethod + def padded_blockscale_cols(k: int, sf_vec_size: int) -> int: + if k % 64 != 0: + raise ValueError("SM120 blockscaled GEMM requires K divisible by 64") + return _round_up(k // sf_vec_size, 16) + + @staticmethod + def can_implement_blockscaled( + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + d_dtype: Type[cutlass.Numeric], + mma_tiler_mnk: Tuple[int, int] | Tuple[int, int, int], + cluster_shape_mn: Tuple[int, int], + m: int, + n: int, + k: int, + l: int, + a_major: str, + b_major: str, + d_major: str, + ) -> bool: + del d_major + if ab_dtype is not cutlass.Float4E2M1FN: + return False + if sf_vec_size == 16: + if sf_dtype is not cutlass.Float8E4M3FN: + return False + elif sf_vec_size == 32: + if sf_dtype is not cutlass.Float8E8M0FNU: + return False + else: + return False + if d_dtype is not cutlass.BFloat16: + return False + if a_major != "k" or b_major != "k": + return False + if cluster_shape_mn != (1, 1): + return False + if len(mma_tiler_mnk) == 3 and mma_tiler_mnk[2] != 64: + return False + if mma_tiler_mnk[0] != 128 or mma_tiler_mnk[1] != 128: + return False + return m % 128 == 0 and n % 128 == 0 and k % 64 == 0 and l == 1 + + @staticmethod + def _shape_tuple(tensor: cute.Tensor) -> tuple[int, ...]: + return tuple(int(dim) for dim in tensor.shape) + + @staticmethod + def _is_empty_varlen(varlen_args: VarlenArguments) -> bool: + return ( + varlen_args.mCuSeqlensM is None + and varlen_args.mCuSeqlensK is None + and varlen_args.mAIdx is None + ) + + def blockscaled_call( + self, + mA: cute.Tensor, + mB: cute.Tensor, + mD: Optional[cute.Tensor], + mC: Optional[cute.Tensor], + epilogue_args: tuple, + scheduler_args, + varlen_args: Optional[VarlenArguments], + stream, + mSFA: Optional[cute.Tensor] = None, + mSFB: Optional[cute.Tensor] = None, + trace_ptr: Optional[cutlass.Int64] = None, + ): + varlen_args = self._validate_blockscaled_call( + mA, mB, mD, mC, mSFA, mSFB, epilogue_args, scheduler_args, varlen_args, trace_ptr + ) + return self._blockscaled_call_jit( + mA, + mB, + mD, + varlen_args, + stream, + mSFA, + mSFB, + trace_ptr, + ) + + @cute.jit + def _blockscaled_call_jit( + self, + mA: cute.Tensor, + mB: cute.Tensor, + mD: cute.Tensor, + varlen_args: Optional[VarlenArguments], + stream, + mSFA: cute.Tensor, + mSFB: cute.Tensor, + trace_ptr: Optional[cutlass.Int64] = None, + ): + if const_expr(varlen_args is None): + varlen_args = VarlenArguments() + return self._call_blockscaled( + mA, + mB, + mD, + varlen_args, + stream, + mSFA, + mSFB, + trace_ptr, + ) + + def _validate_blockscaled_call( + self, + mA: cute.Tensor, + mB: cute.Tensor, + mD: cute.Tensor, + mC: Optional[cute.Tensor], + mSFA: Optional[cute.Tensor], + mSFB: Optional[cute.Tensor], + epilogue_args: tuple, + scheduler_args, + varlen_args: Optional[VarlenArguments], + trace_ptr: Optional[cutlass.Int64], + ) -> VarlenArguments: + if mSFA is None or mSFB is None: + raise ValueError("SM120 blockscaled GEMM requires SFA and SFB scale tensors") + if mD is None: + raise ValueError("SM120 blockscaled GEMM requires an output tensor") + if mC is not None: + raise NotImplementedError("SM120 blockscaled C/beta path is not implemented") + if epilogue_args not in (None, ()): + beta = getattr(epilogue_args, "beta", None) + add_to_output = getattr(epilogue_args, "add_to_output", False) + if beta not in (None, 0) or add_to_output: + raise NotImplementedError("SM120 blockscaled C/beta path is not implemented") + del scheduler_args + if trace_ptr is not None: + raise NotImplementedError("SM120 blockscaled trace path is not implemented") + if self.cta_tile_shape_mnk != (128, 128, 64): + raise NotImplementedError("SM120 blockscaled path currently supports 128x128x64 tiles") + if self.cluster_shape_mnk != (1, 1, 1): + raise NotImplementedError("SM120 blockscaled path requires cluster_shape_mnk=(1,1,1)") + if ( + mA.element_type is not cutlass.Float4E2M1FN + or mB.element_type is not cutlass.Float4E2M1FN + ): + raise TypeError("SM120 blockscaled path requires Float4E2M1FN A/B") + if mSFA.element_type is not self.sf_dtype or mSFB.element_type is not self.sf_dtype: + raise TypeError(f"SM120 blockscaled path requires {self.sf_dtype} SFA/SFB") + if mD.element_type is not cutlass.BFloat16: + raise NotImplementedError("SM120 blockscaled path currently supports only BF16 D") + + if varlen_args is None: + varlen_args = VarlenArguments() + if not self._is_empty_varlen(varlen_args): + raise NotImplementedError("SM120 blockscaled varlen GEMM is not implemented") + + a_shape = self._shape_tuple(mA) + b_shape = self._shape_tuple(mB) + d_shape = self._shape_tuple(mD) + if len(a_shape) != 3 or len(b_shape) != 3 or len(d_shape) != 3: + raise ValueError("SM120 blockscaled tensors must use logical rank-3 shapes") + m, k, l = a_shape + n, kb, lb = b_shape + if k != kb or l != lb: + raise ValueError("SM120 blockscaled A/B K and L extents must match") + if k % 64 != 0: + if k * 2 % 64 == 0: + raise ValueError( + "SM120 blockscaled class call expects logical Float4E2M1FN K extent; " + "use compile_blockscaled_gemm_tvm_ffi for packed torch.float4_e2m1fn_x2 " + "storage" + ) + raise ValueError("SM120 blockscaled path requires logical K to be divisible by 64") + if d_shape != (m, n, l): + raise ValueError(f"SM120 blockscaled D shape must be {(m, n, l)}, got {d_shape}") + if m % 128 != 0 or n % 128 != 0: + raise NotImplementedError("SM120 blockscaled path requires M/N multiples of 128") + if l != 1: + raise NotImplementedError("SM120 blockscaled path currently supports L=1") + scale_cols = self.padded_blockscale_cols(k, self.sf_vec_size) + if self._shape_tuple(mSFA) != (m, scale_cols, l): + raise ValueError( + f"SFA shape must be {(m, scale_cols, l)}, got {self._shape_tuple(mSFA)}" + ) + if self._shape_tuple(mSFB) != (n, scale_cols, l): + raise ValueError( + f"SFB shape must be {(n, scale_cols, l)}, got {self._shape_tuple(mSFB)}" + ) + return varlen_args + + @cute.jit + def _call_blockscaled( + self, + mA: cute.Tensor, + mB: cute.Tensor, + mD: cute.Tensor, + varlen_args: Optional[VarlenArguments], + stream, + mSFA: cute.Tensor, + mSFB: cute.Tensor, + trace_ptr: Optional[cutlass.Int64] = None, + ): + del stream, trace_ptr + + self.a_dtype = mA.element_type + self.b_dtype = mB.element_type + self.d_dtype = mD.element_type + self.c_dtype = None + self.sf_dtype = mSFA.element_type + self.a_layout = LayoutEnum.from_tensor(mA) + self.b_layout = LayoutEnum.from_tensor(mB) + self.d_layout = LayoutEnum.from_tensor(mD) + self.c_layout = None + self._setup_attributes(()) + self.ab_stage = 1 + tile_m, tile_n, tile_k = self.cta_tile_shape_mnk + + self.a_smem_layout_staged = cute.make_layout( + (tile_m, tile_k, self.ab_stage), stride=(tile_k, 1, tile_m * tile_k) + ) + self.b_smem_layout_staged = cute.make_layout( + (tile_n, tile_k, self.ab_stage), stride=(tile_k, 1, tile_n * tile_k) + ) + scale_tile_k = 16 + a_compact_smem_layout_staged = cute.make_layout( + (tile_m, tile_k // 2, self.ab_stage), + stride=(tile_k // 2, 1, tile_m * (tile_k // 2)), + ) + b_compact_smem_layout_staged = cute.make_layout( + (tile_n, tile_k // 2, self.ab_stage), + stride=(tile_k // 2, 1, tile_n * (tile_k // 2)), + ) + self.sfa_smem_layout_staged = cute.make_layout( + (tile_m, scale_tile_k, self.ab_stage), + stride=(scale_tile_k, 1, tile_m * scale_tile_k), + ) + self.sfb_smem_layout_staged = cute.make_layout( + (tile_n, scale_tile_k, self.ab_stage), + stride=(scale_tile_k, 1, tile_n * scale_tile_k), + ) + acc_smem_layout = cute.make_layout((tile_m, tile_n), stride=(tile_n, 1)) + + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, 0)) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, 0)) + a_compact_smem_layout = cute.slice_(a_compact_smem_layout_staged, (None, None, 0)) + b_compact_smem_layout = cute.slice_(b_compact_smem_layout_staged, (None, None, 0)) + sfa_smem_layout = cute.slice_(self.sfa_smem_layout_staged, (None, None, 0)) + sfb_smem_layout = cute.slice_(self.sfb_smem_layout_staged, (None, None, 0)) + + m_extent = cute.size(mA, mode=[0]) + k_extent = cute.size(mA, mode=[1]) + packed_k_extent = k_extent // 2 + n_extent = cute.size(mB, mode=[0]) + l_extent = cute.size(mA, mode=[2]) + mA_u8 = cute.make_tensor( + cute.recast_ptr(mA.iterator, dtype=cutlass.Uint8), + cute.make_layout( + (m_extent, packed_k_extent, l_extent), + stride=(packed_k_extent, 1, m_extent * packed_k_extent), + ), + ) + mB_u8 = cute.make_tensor( + cute.recast_ptr(mB.iterator, dtype=cutlass.Uint8), + cute.make_layout( + (n_extent, packed_k_extent, l_extent), + stride=(packed_k_extent, 1, n_extent * packed_k_extent), + ), + ) + op = cpasync.CopyBulkTensorTileG2SOp() + tma_atom_a, tma_tensor_a = cpasync.make_tiled_tma_atom( + op, mA_u8, a_compact_smem_layout, (tile_m, tile_k // 2) + ) + tma_atom_b, tma_tensor_b = cpasync.make_tiled_tma_atom( + op, mB_u8, b_compact_smem_layout, (tile_n, tile_k // 2) + ) + tma_atom_sfa, tma_tensor_sfa = self._make_tma_atoms_and_tensors( + mSFA, sfa_smem_layout, (tile_m, scale_tile_k), 1 + ) + tma_atom_sfb, tma_tensor_sfb = self._make_tma_atoms_and_tensors( + mSFB, sfb_smem_layout, (tile_n, scale_tile_k), 1 + ) + self.num_tma_load_bytes = ( + cute.size_in_bytes(cutlass.Uint8, a_compact_smem_layout) + + cute.size_in_bytes(cutlass.Uint8, b_compact_smem_layout) + + cute.size_in_bytes(self.sf_dtype, sfa_smem_layout) + + cute.size_in_bytes(self.sf_dtype, sfb_smem_layout) + ) + + a_compact_smem_size = cute.cosize(a_compact_smem_layout_staged) + b_compact_smem_size = cute.cosize(b_compact_smem_layout_staged) + sfa_smem_size = cute.cosize(self.sfa_smem_layout_staged) + sfb_smem_size = cute.cosize(self.sfb_smem_layout_staged) + acc_smem_size = cute.cosize(acc_smem_layout) + + @cute.struct + class BlockscaledSharedStorage: + ab_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.ab_stage * 2] + sA: cute.struct.Align[ + cute.struct.MemRange[cutlass.Int8, cute.cosize(self.a_smem_layout_staged)], + self.buffer_align_bytes, + ] + sB: cute.struct.Align[ + cute.struct.MemRange[cutlass.Int8, cute.cosize(self.b_smem_layout_staged)], + self.buffer_align_bytes, + ] + sACompact: cute.struct.Align[ + cute.struct.MemRange[cutlass.Uint8, a_compact_smem_size], + self.buffer_align_bytes, + ] + sBCompact: cute.struct.Align[ + cute.struct.MemRange[cutlass.Uint8, b_compact_smem_size], + self.buffer_align_bytes, + ] + sSFA: cute.struct.Align[ + cute.struct.MemRange[self.sf_dtype, sfa_smem_size], + self.buffer_align_bytes, + ] + sSFB: cute.struct.Align[ + cute.struct.MemRange[self.sf_dtype, sfb_smem_size], + self.buffer_align_bytes, + ] + # Correctness scratch: for K > 64, K64 partials stay in FP32 here + # until the final K tile writes BF16 D exactly once. + sAcc: cute.struct.Align[ + cute.struct.MemRange[cutlass.Float32, acc_smem_size], + self.buffer_align_bytes, + ] + + self.shared_storage = BlockscaledSharedStorage + + if const_expr(self.sf_vec_size == 16): + mma_op = warp.MmaMXF4NVF4Op(cutlass.Float4E2M1FN, cutlass.Float32, self.sf_dtype) + else: + mma_op = warp.MmaMXF4Op(cutlass.Float4E2M1FN, cutlass.Float32, self.sf_dtype) + one_warp_mma = cute.make_tiled_mma(mma_op) + varlen_params = VarlenManager.to_underlying_arguments(varlen_args) + self.blockscaled_kernel( + one_warp_mma, + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_sfa, + tma_tensor_sfa, + tma_atom_sfb, + tma_tensor_sfb, + mD, + varlen_params, + cute.make_layout((1, 1, 1)), + self.a_smem_layout_staged, + self.b_smem_layout_staged, + a_compact_smem_layout_staged, + b_compact_smem_layout_staged, + self.sfa_smem_layout_staged, + self.sfb_smem_layout_staged, + acc_smem_layout, + ).launch( + grid=[cute.ceil_div(m_extent, tile_m), cute.ceil_div(n_extent, tile_n), l_extent], + block=[64, 1, 1], + cluster=(1, 1, 1), + ) + + @cute.kernel + def blockscaled_kernel( + self, + tiled_mma: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_sfa: cute.CopyAtom, + mSFA_mkl16: cute.Tensor, + tma_atom_sfb: cute.CopyAtom, + mSFB_nkl16: cute.Tensor, + mD_mnl: cute.Tensor, + varlen_params: VarlenManager.Params, + cluster_layout_mnk: cute.Layout, + a_smem_layout: cute.Layout, + b_smem_layout: cute.Layout, + a_compact_smem_layout: cute.Layout, + b_compact_smem_layout: cute.Layout, + sfa_smem_layout: cute.Layout, + sfb_smem_layout: cute.Layout, + acc_smem_layout: cute.Layout, + ): + del varlen_params + + tidx, _, _ = cute.arch.thread_idx() + cta_m, cta_n, cta_l = cute.arch.block_idx() + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + + if warp_idx == 1: + cpasync.prefetch_descriptor(tma_atom_a) + cpasync.prefetch_descriptor(tma_atom_b) + cpasync.prefetch_descriptor(tma_atom_sfa) + cpasync.prefetch_descriptor(tma_atom_sfb) + + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + ab_pipeline = self.make_ab_pipeline( + tiled_mma=tiled_mma, + cluster_layout_vmnk=cute.make_layout((1, *cluster_layout_mnk.shape)), + ab_pipeline_mbar_ptr=storage.ab_pipeline_array_ptr.data_ptr(), + ) + + pipeline_init_arrive(cluster_shape_mn=(1, 1), is_relaxed=True) + sA = storage.sA.get_tensor(a_smem_layout) + sB = storage.sB.get_tensor(b_smem_layout) + sACompact = storage.sACompact.get_tensor(a_compact_smem_layout) + sBCompact = storage.sBCompact.get_tensor(b_compact_smem_layout) + sSFA = storage.sSFA.get_tensor(sfa_smem_layout) + sSFB = storage.sSFB.get_tensor(sfb_smem_layout) + sAcc = storage.sAcc.get_tensor(acc_smem_layout) + pipeline_init_wait(cluster_shape_mn=(1, 1)) + + k_tile_count = cute.size(mA_mkl, mode=[1]) // (self.cta_tile_shape_mnk[2] // 2) + scales_per_k_tile = 64 // self.sf_vec_size + + if warp_idx == 1: + producer_state = make_pipeline_state(pipeline.PipelineUserType.Producer, self.ab_stage) + gA_mk = cute.local_tile( + mA_mkl[None, None, cta_l], + (self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[2] // 2), + (cta_m, None), + ) + gB_nk = cute.local_tile( + mB_nkl[None, None, cta_l], + (self.cta_tile_shape_mnk[1], self.cta_tile_shape_mnk[2] // 2), + (cta_n, None), + ) + gSFA_mk16 = cute.local_tile( + mSFA_mkl16[None, None, cta_l], + (self.cta_tile_shape_mnk[0], 16), + (cta_m, None), + ) + gSFB_nk16 = cute.local_tile( + mSFB_nkl16[None, None, cta_l], + (self.cta_tile_shape_mnk[1], 16), + (cta_n, None), + ) + if const_expr(k_tile_count == 1): + producer_state = self.load_blockscaled_tma_tile( + ab_pipeline, + producer_state, + tma_atom_a, + gA_mk, + tma_atom_b, + gB_nk, + tma_atom_sfa, + gSFA_mk16, + tma_atom_sfb, + gSFB_nk16, + sACompact, + sBCompact, + sSFA, + sSFB, + ) + else: + for k_tile in cutlass.range(k_tile_count, unroll=1): + scale_base = k_tile * scales_per_k_tile + scale_page = scale_base // 16 + producer_state = self.load_blockscaled_tma_tile_indexed( + ab_pipeline, + producer_state, + k_tile, + scale_page, + tma_atom_a, + gA_mk, + tma_atom_b, + gB_nk, + tma_atom_sfa, + gSFA_mk16, + tma_atom_sfb, + gSFB_nk16, + sACompact, + sBCompact, + sSFA, + sSFB, + ) + ab_pipeline.producer_tail(producer_state) + + if warp_idx == 0: + gD_mn = cute.local_tile( + mD_mnl[None, None, cta_l], + (self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1]), + (cta_m, cta_n), + ) + read_state = make_pipeline_state(pipeline.PipelineUserType.Consumer, self.ab_stage) + self.mma_blockscaled_kloop_store_bf16( + ab_pipeline, + read_state, + tiled_mma, + sA, + sB, + sACompact, + sBCompact, + sSFA, + sSFB, + sAcc, + gD_mn, + k_tile_count, + tidx, + ) + + @cute.jit + def load_blockscaled_tma_tile( + self, + ab_pipeline: pipeline.PipelineAsync, + producer_state: pipeline.PipelineState, + tma_atom_a: cute.CopyAtom, + gA_mk: cute.Tensor, + tma_atom_b: cute.CopyAtom, + gB_nk: cute.Tensor, + tma_atom_sfa: cute.CopyAtom, + gSFA_mk16: cute.Tensor, + tma_atom_sfb: cute.CopyAtom, + gSFB_nk16: cute.Tensor, + sACompact: cute.Tensor, + sBCompact: cute.Tensor, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + ) -> pipeline.PipelineState: + copy_A, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_a, + cta_coord=0, + cta_layout=cute.make_layout(1), + src_tensor=gA_mk, + dst_tensor=sACompact, + mcast_mask=0, + ) + copy_B, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_b, + cta_coord=0, + cta_layout=cute.make_layout(1), + src_tensor=gB_nk, + dst_tensor=sBCompact, + mcast_mask=0, + ) + copy_SFA, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_sfa, + cta_coord=0, + cta_layout=cute.make_layout(1), + src_tensor=gSFA_mk16, + dst_tensor=sSFA, + mcast_mask=0, + ) + copy_SFB, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_sfb, + cta_coord=0, + cta_layout=cute.make_layout(1), + src_tensor=gSFB_nk16, + dst_tensor=sSFB, + mcast_mask=0, + ) + return self.load_tma( + ab_pipeline, + producer_state, + [copy_A, copy_B, copy_SFA, copy_SFB], + Int32(1), + ) + + @cute.jit + def load_blockscaled_tma_tile_indexed( + self, + ab_pipeline: pipeline.PipelineAsync, + producer_state: pipeline.PipelineState, + k_tile: cutlass.Int32, + scale_page: cutlass.Int32, + tma_atom_a: cute.CopyAtom, + gA_mk: cute.Tensor, + tma_atom_b: cute.CopyAtom, + gB_nk: cute.Tensor, + tma_atom_sfa: cute.CopyAtom, + gSFA_mk16: cute.Tensor, + tma_atom_sfb: cute.CopyAtom, + gSFB_nk16: cute.Tensor, + sACompact: cute.Tensor, + sBCompact: cute.Tensor, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + ) -> pipeline.PipelineState: + copy_A, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_a, + cta_coord=0, + cta_layout=cute.make_layout(1), + src_tensor=gA_mk, + dst_tensor=sACompact, + mcast_mask=0, + ) + copy_B, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_b, + cta_coord=0, + cta_layout=cute.make_layout(1), + src_tensor=gB_nk, + dst_tensor=sBCompact, + mcast_mask=0, + ) + copy_SFA, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_sfa, + cta_coord=0, + cta_layout=cute.make_layout(1), + src_tensor=gSFA_mk16, + dst_tensor=sSFA, + mcast_mask=0, + ) + copy_SFB, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_sfb, + cta_coord=0, + cta_layout=cute.make_layout(1), + src_tensor=gSFB_nk16, + dst_tensor=sSFB, + mcast_mask=0, + ) + peek_empty_status = ab_pipeline.producer_try_acquire(producer_state) + ab_pipeline.producer_acquire(producer_state, peek_empty_status) + tma_bar_ptr = ab_pipeline.producer_get_barrier(producer_state) + smem_idx = producer_state.index + copy_A(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr) + copy_B(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr) + copy_SFA(scale_page, smem_idx, tma_bar_ptr=tma_bar_ptr) + copy_SFB(scale_page, smem_idx, tma_bar_ptr=tma_bar_ptr) + ab_pipeline.producer_commit(producer_state) + producer_state.advance() + return producer_state + + @cute.jit + def mma_blockscaled_kloop_store_bf16( + self, + ab_pipeline: pipeline.PipelineAsync, + read_state: pipeline.PipelineState, + tiled_mma: cute.TiledMma, + sA: cute.Tensor, + sB: cute.Tensor, + sACompact: cute.Tensor, + sBCompact: cute.Tensor, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + sAcc: cute.Tensor, + gD_mn: cute.Tensor, + k_tile_count: cutlass.Int32, + tidx: cutlass.Int32, + ) -> None: + thr_mma = tiled_mma.get_slice(tidx) + a_shape = tiled_mma.partition_shape_A((16, 64)) + b_shape = tiled_mma.partition_shape_B((8, 64)) + acc_shape = tiled_mma.partition_shape_C((16, 8)) + accum_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), cutlass.Float32) + store_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gD_mn.element_type) + scales_per_k_tile = 64 // self.sf_vec_size + if const_expr(k_tile_count == 1): + read_state = self.mma_blockscaled_one_k_tile_accumulate_smem( + ab_pipeline, + read_state, + tiled_mma, + sA, + sB, + sACompact, + sBCompact, + sSFA, + sSFB, + sAcc, + gD_mn, + tidx, + thr_mma, + a_shape, + b_shape, + acc_shape, + accum_atom, + store_atom, + scales_per_k_tile, + cutlass.Int32(0), + False, + True, + ) + else: + read_state = self.mma_blockscaled_one_k_tile_accumulate_smem( + ab_pipeline, + read_state, + tiled_mma, + sA, + sB, + sACompact, + sBCompact, + sSFA, + sSFB, + sAcc, + gD_mn, + tidx, + thr_mma, + a_shape, + b_shape, + acc_shape, + accum_atom, + store_atom, + scales_per_k_tile, + cutlass.Int32(0), + False, + False, + ) + for k_iter in cutlass.range(k_tile_count - 2, unroll=1): + k_tile = k_iter + 1 + read_state = self.mma_blockscaled_one_k_tile_accumulate_smem( + ab_pipeline, + read_state, + tiled_mma, + sA, + sB, + sACompact, + sBCompact, + sSFA, + sSFB, + sAcc, + gD_mn, + tidx, + thr_mma, + a_shape, + b_shape, + acc_shape, + accum_atom, + store_atom, + scales_per_k_tile, + k_tile, + True, + False, + ) + read_state = self.mma_blockscaled_one_k_tile_accumulate_smem( + ab_pipeline, + read_state, + tiled_mma, + sA, + sB, + sACompact, + sBCompact, + sSFA, + sSFB, + sAcc, + gD_mn, + tidx, + thr_mma, + a_shape, + b_shape, + acc_shape, + accum_atom, + store_atom, + scales_per_k_tile, + k_tile_count - 1, + True, + True, + ) + + @cute.jit + def mma_blockscaled_one_k_tile_accumulate_smem( + self, + ab_pipeline: pipeline.PipelineAsync, + read_state: pipeline.PipelineState, + tiled_mma: cute.TiledMma, + sA: cute.Tensor, + sB: cute.Tensor, + sACompact: cute.Tensor, + sBCompact: cute.Tensor, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + sAcc: cute.Tensor, + gD_mn: cute.Tensor, + tidx: cutlass.Int32, + thr_mma: cute.TiledMmaThrVal, + a_shape: cute.Shape, + b_shape: cute.Shape, + acc_shape: cute.Shape, + accum_atom: cute.CopyAtom, + store_atom: cute.CopyAtom, + scales_per_k_tile: cutlass.Constexpr[int], + k_tile: cutlass.Int32, + add_existing: cutlass.Constexpr[bool], + store_final: cutlass.Constexpr[bool], + ) -> pipeline.PipelineState: + peek_ab_full_status = ab_pipeline.consumer_try_wait(read_state) + ab_pipeline.consumer_wait(read_state, peek_ab_full_status) + _expand_compact_fp4_to_sm120_ldmatrix_smem( + sACompact[None, None, read_state.index], + sA[None, None, read_state.index], + 128, + ) + _expand_compact_fp4_to_sm120_ldmatrix_smem( + sBCompact[None, None, read_state.index], + sB[None, None, read_state.index], + 128, + ) + cute.arch.sync_warp() + scale_base = k_tile * scales_per_k_tile + scale_page = scale_base // 16 + scale_page_offset = scale_base - scale_page * 16 + for m_atom in cutlass.range_constexpr(8): + for n_atom in cutlass.range_constexpr(16): + acc = cute.make_rmem_tensor(acc_shape, cutlass.Float32) + acc.fill(0.0) + self.mma_blockscaled_tile_k64_accumulate( + tiled_mma, + acc, + sA[None, None, read_state.index], + sB[None, None, read_state.index], + sSFA, + sSFB, + read_state.index, + scale_page_offset, + cutlass.Int32(m_atom * 16), + cutlass.Int32(n_atom * 8), + tidx, + a_shape, + b_shape, + ) + self.store_blockscaled_accum_smem_atom( + thr_mma, + accum_atom, + acc, + sAcc, + store_atom, + gD_mn, + m_atom, + n_atom, + add_existing, + store_final, + ) + cute.arch.fence_view_async_shared() + cute.arch.sync_warp() + ab_pipeline.consumer_release(read_state) + read_state.advance() + return read_state + + @cute.jit + def mma_blockscaled( + self, + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + sA_stage: cute.Tensor, + sB_stage: cute.Tensor, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + stage: cutlass.Int32, + m_atom_base: cutlass.Int32, + n_atom_base: cutlass.Int32, + k_scale_base: cutlass.Int32, + tidx: cutlass.Int32, + a_shape: cute.Shape, + b_shape: cute.Shape, + ) -> None: + sA_atom = cute.domain_offset((m_atom_base, None), sA_stage) + sB_atom = cute.domain_offset((n_atom_base, None), sB_stage) + copy_atom = cute.make_copy_atom( + warp.LdMatrix8x16x8bOp(num_matrices=1, unpack_bits=4), + cutlass.Int8, + ) + tiled_copy_a = cute.make_tiled_copy_A(copy_atom, tiled_mma) + tiled_copy_b = cute.make_tiled_copy_B(copy_atom, tiled_mma) + a0, a1, a2, a3 = warp.sm120_mxf4nvf4_ldmatrix_A_regs( + tiled_copy_a, + tidx, + _make_sm120_fp4_ldmatrix_smem_view(sA_atom, 16), + ) + b0, b1 = warp.sm120_mxf4nvf4_ldmatrix_B_regs( + tiled_copy_b, + tidx, + _make_sm120_fp4_ldmatrix_smem_view(sB_atom, 8), + ) + a = cute.make_rmem_tensor(a_shape, cutlass.Float4E2M1FN) + b = cute.make_rmem_tensor(b_shape, cutlass.Float4E2M1FN) + a_i32 = cute.recast_tensor(a, cutlass.Int32) + b_i32 = cute.recast_tensor(b, cutlass.Int32) + # The asymmetric SM120 blockscaled tests catch this placement. + a_i32[0] = a0 + a_i32[1] = a2 + a_i32[2] = a1 + a_i32[3] = a3 + b_i32[0] = b0 + b_i32[1] = b1 + sfa, sfb = _load_sm120_blockscaled_selector0_scale_fragments( + sSFA, + sSFB, + stage, + m_atom_base, + n_atom_base, + k_scale_base, + self.sf_vec_size, + self.sf_dtype, + ) + cute.gemm(tiled_mma, acc, (a, sfa), (b, sfb), acc) + + @cute.jit + def mma_blockscaled_tile_k64_accumulate( + self, + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + sA: cute.Tensor, + sB: cute.Tensor, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + stage: cutlass.Int32, + scale_page_offset: cutlass.Int32, + m_atom_base: cutlass.Int32, + n_atom_base: cutlass.Int32, + tidx: cutlass.Int32, + a_shape: cute.Shape, + b_shape: cute.Shape, + ) -> None: + del tiled_mma + if const_expr(self.sf_vec_size == 16): + mma_op = warp.MmaMXF4NVF4Op(cutlass.Float4E2M1FN, cutlass.Float32, self.sf_dtype) + else: + mma_op = warp.MmaMXF4Op(cutlass.Float4E2M1FN, cutlass.Float32, self.sf_dtype) + local_tiled_mma = cute.make_tiled_mma(mma_op) + self.mma_blockscaled( + local_tiled_mma, + acc, + sA, + sB, + sSFA, + sSFB, + stage, + m_atom_base, + n_atom_base, + scale_page_offset, + tidx, + a_shape, + b_shape, + ) + + @cute.jit + def store_blockscaled_accum_smem_atom( + self, + thr_mma: cute.TiledMmaThrVal, + accum_atom: cute.CopyAtom, + acc: cute.Tensor, + sAcc: cute.Tensor, + store_atom: cute.CopyAtom, + mD_mn: cute.Tensor, + m_atom: cutlass.Constexpr[int], + n_atom: cutlass.Constexpr[int], + add_existing: cutlass.Constexpr[bool], + store_final: cutlass.Constexpr[bool], + ) -> None: + sAcc_atom = cute.local_tile(sAcc, (16, 8), (m_atom, n_atom)) + tCsAcc = thr_mma.partition_C(sAcc_atom) + if const_expr(add_existing): + tCrPrev = cute.make_rmem_tensor(acc.layout, cutlass.Float32) + cute.copy(accum_atom, tCsAcc, tCrPrev) + acc.store(acc.load() + tCrPrev.load()) + if const_expr(store_final): + gD_atom = cute.local_tile(mD_mn, (16, 8), (m_atom, n_atom)) + tCgD = thr_mma.partition_C(gD_atom) + tCrD = cute.make_rmem_tensor(acc.layout, mD_mn.element_type) + tCrD.store(acc.load().to(mD_mn.element_type)) + cute.copy(store_atom, tCrD, tCgD) + else: + cute.copy(accum_atom, acc, tCsAcc) @cute.kernel def kernel( diff --git a/tests/test_gemm_sm100_blockscaled.py b/tests/test_gemm_blockscaled.py similarity index 67% rename from tests/test_gemm_sm100_blockscaled.py rename to tests/test_gemm_blockscaled.py index 315697d8..6822002e 100644 --- a/tests/test_gemm_sm100_blockscaled.py +++ b/tests/test_gemm_blockscaled.py @@ -4,24 +4,38 @@ import cutlass from quack.blockscaled_gemm_utils import ( + FP4_E2M1FN_VALUES, blockscaled_gemm_reference, compile_blockscaled_gemm_tvm_ffi, create_blockscaled_operand_quantized, create_blockscaled_operand_tensor, create_blockscaled_scale_tensor, + create_sm120_blockscaled_scale_tensor, create_blockscaled_varlen_k_operands, create_blockscaled_varlen_m_operands, scale_blocked_for_cublas, scale_view_for_kernel, ) -from quack.gemm_default_epi import GemmDefaultSm100 +from quack.compile_utils import make_fake_tensor as fake_tensor +from quack.gemm_default_epi import GemmDefaultSm100, GemmDefaultSm120 +from quack.varlen_utils import VarlenArguments from quack.mx_utils import to_blocked def _skip_if_not_sm100(): + if not torch.cuda.is_available(): + pytest.skip("CUDA required") major = torch.cuda.get_device_properties(0).major - if major < 10: - pytest.skip("SM100+ required") + if major not in (10, 11): + pytest.skip("SM100/SM110 required") + + +def _skip_if_not_sm120(): + if not torch.cuda.is_available(): + pytest.skip("CUDA required") + major = torch.cuda.get_device_properties(0).major + if major != 12: + pytest.skip("SM120 required") def _compile_blockscaled_gemm( @@ -158,6 +172,99 @@ def test_blockscaled_validation(): "k", "n", ) + + +def test_sm120_blockscaled_validation(): + assert GemmDefaultSm120.can_implement_blockscaled( + cutlass.Float4E2M1FN, + cutlass.Float8E4M3FN, + 16, + cutlass.BFloat16, + (128, 128, 64), + (1, 1), + 128, + 128, + 64, + 1, + "k", + "k", + "n", + ) + assert GemmDefaultSm120.can_implement_blockscaled( + cutlass.Float4E2M1FN, + cutlass.Float8E8M0FNU, + 32, + cutlass.BFloat16, + (128, 128, 64), + (1, 1), + 256, + 256, + 128, + 1, + "k", + "k", + "n", + ) + assert not GemmDefaultSm120.can_implement_blockscaled( + cutlass.Float8E4M3FN, + cutlass.Float8E8M0FNU, + 32, + cutlass.BFloat16, + (128, 128, 64), + (1, 1), + 128, + 128, + 64, + 1, + "k", + "k", + "n", + ) + assert not GemmDefaultSm120.can_implement_blockscaled( + cutlass.Int8, + cutlass.Float8E8M0FNU, + 32, + cutlass.BFloat16, + (128, 128, 64), + (1, 1), + 128, + 128, + 64, + 1, + "k", + "k", + "n", + ) + assert not GemmDefaultSm120.can_implement_blockscaled( + cutlass.Float4E2M1FN, + cutlass.Float8E4M3FN, + 32, + cutlass.BFloat16, + (128, 128, 64), + (1, 1), + 128, + 128, + 64, + 1, + "k", + "k", + "n", + ) + assert not GemmDefaultSm120.can_implement_blockscaled( + cutlass.Float4E2M1FN, + cutlass.Float8E4M3FN, + 16, + cutlass.BFloat16, + (128, 128, 64), + (1, 1), + 128, + 128, + 96, + 1, + "k", + "k", + "n", + ) assert not GemmDefaultSm100.can_implement_blockscaled( cutlass.Float8E4M3FN, cutlass.Float8E8M0FNU, @@ -235,6 +342,65 @@ def test_blockscaled_validation(): ) +def test_sm120_blockscaled_class_call_validation(): + m = n = 128 + k = 64 + l = 1 + gemm = GemmDefaultSm120( + cutlass.Float32, + cutlass.Float4E2M1FN, + (128, 128, 64), + (1, 1, 1), + is_persistent=False, + sf_vec_size=16, + sf_dtype=cutlass.Float8E4M3FN, + ) + mA = fake_tensor(cutlass.Float4E2M1FN, (m, k, l), leading_dim=1, divisibility=4) + mB = fake_tensor(cutlass.Float4E2M1FN, (n, k, l), leading_dim=1, divisibility=4) + mD = fake_tensor(cutlass.BFloat16, (m, n, l), leading_dim=1, divisibility=8) + mSFA = fake_tensor(cutlass.Float8E4M3FN, (m, 16, l), leading_dim=1, divisibility=4) + mSFB = fake_tensor(cutlass.Float8E4M3FN, (n, 16, l), leading_dim=1, divisibility=4) + + assert ( + gemm._validate_blockscaled_call( + mA, + mB, + mD, + None, + mSFA, + mSFB, + gemm.EpilogueArguments(), + None, + None, + None, + ) + == VarlenArguments() + ) + with pytest.raises(ValueError, match="requires SFA and SFB"): + gemm._validate_blockscaled_call( + mA, mB, mD, None, None, mSFB, gemm.EpilogueArguments(), None, None, None + ) + with pytest.raises(NotImplementedError, match="C/beta"): + gemm._validate_blockscaled_call( + mA, mB, mD, mD, mSFA, mSFB, gemm.EpilogueArguments(), None, None, None + ) + packed_k_a = fake_tensor(cutlass.Float4E2M1FN, (m, k // 2, l), leading_dim=1, divisibility=4) + packed_k_b = fake_tensor(cutlass.Float4E2M1FN, (n, k // 2, l), leading_dim=1, divisibility=4) + with pytest.raises(ValueError, match="expects logical Float4E2M1FN K extent"): + gemm._validate_blockscaled_call( + packed_k_a, + packed_k_b, + mD, + None, + mSFA, + mSFB, + gemm.EpilogueArguments(), + None, + None, + None, + ) + + @pytest.mark.parametrize( "ab_dtype,sf_dtype,sf_vec_size,d_dtype,mma_tiler_mn,cluster_shape_mn,m,n,k,l", [ @@ -485,6 +651,228 @@ def test_scale_layout_matches_cublas(mn, sf_k, l): assert torch.equal(ours, ref), f"to_blocked mismatch at l={li}" +@pytest.mark.parametrize( + "k,sf_vec_size,expected_cols", + [ + (64, 16, 16), + (128, 16, 16), + (256, 16, 16), + (384, 16, 32), + (64, 32, 16), + (256, 32, 16), + (576, 32, 32), + ], +) +def test_sm120_blockscaled_padded_scale_layout(k, sf_vec_size, expected_cols): + _skip_if_not_sm120() + mn, l = 128, 1 + sf_dtype = cutlass.Float8E4M3FN if sf_vec_size == 16 else cutlass.Float8E8M0FNU + ref, physical = create_sm120_blockscaled_scale_tensor(l, mn, k, sf_vec_size, sf_dtype) + assert tuple(physical.shape) == (mn, expected_cols, l) + assert tuple(ref.shape) == (mn, k, l) + + logical_cols = (k + sf_vec_size - 1) // sf_vec_size + if expected_cols > logical_cols: + padding = physical[:, logical_cols:, :].view(torch.uint8) + assert torch.any(padding != 0) + + +def test_sm120_blockscaled_scale_helper_validation(): + _skip_if_not_sm120() + with pytest.raises(ValueError, match="K divisible by 64"): + create_sm120_blockscaled_scale_tensor(1, 128, 96, 16, cutlass.Float8E4M3FN) + with pytest.raises(ValueError, match="sf_vec_size 16 or 32"): + create_sm120_blockscaled_scale_tensor(1, 128, 64, 8, cutlass.Float8E4M3FN) + with pytest.raises(ValueError, match="sf_vec_size=16 requires"): + create_sm120_blockscaled_scale_tensor(1, 128, 64, 16, cutlass.Float8E8M0FNU) + with pytest.raises(ValueError, match="sf_vec_size=32 requires"): + create_sm120_blockscaled_scale_tensor(1, 128, 64, 32, cutlass.Float8E4M3FN) + + +def _pack_sm120_fp4_codes(codes: torch.Tensor) -> torch.Tensor: + packed = torch.empty( + (codes.shape[0], codes.shape[1] // 2, 1), + device=codes.device, + dtype=torch.float4_e2m1fn_x2, + ) + packed.view(torch.uint8).copy_(codes[:, 0::2, None] | (codes[:, 1::2, None] << 4)) + return packed + + +def _sm120_fp4_blockscaled_reference( + a_codes: torch.Tensor, + b_codes: torch.Tensor, + sfa: torch.Tensor, + sfb: torch.Tensor, + sf_vec_size: int, +) -> torch.Tensor: + table = torch.tensor(FP4_E2M1FN_VALUES, dtype=torch.float32, device=a_codes.device) + scale_k = torch.arange(a_codes.shape[1], device=a_codes.device) // sf_vec_size + a = table[a_codes.long()] * sfa.float()[:, scale_k, 0] + b = table[b_codes.long()] * sfb.float()[:, scale_k, 0] + return torch.einsum("mk,nk->mn", a, b).unsqueeze(-1) + + +def _make_sm120_scales(mn, k, sf_vec_size, sf_dtype, row_or_col_sensitive=True): + _, scales = create_sm120_blockscaled_scale_tensor(1, mn, k, sf_vec_size, sf_dtype) + logical_cols = (k + sf_vec_size - 1) // sf_vec_size + if sf_dtype == cutlass.Float8E8M0FNU: + base = torch.tensor([1.0, 2.0], device="cuda", dtype=torch.float32) + else: + base = torch.tensor([1.0, 2.0, 0.5, 1.5], device="cuda", dtype=torch.float32) + for idx in range(mn): + values = base[torch.arange(logical_cols, device="cuda") % base.numel()] + if row_or_col_sensitive: + values = values * (1.0 + 0.125 * (idx % 4)) + scales[idx, :logical_cols, 0] = values.to(scales.dtype) + return scales + + +def _compile_sm120_blockscaled_gemm(ab_dtype, sf_dtype, sf_vec_size, m, n, k, mA, mB): + l = 1 + _, mD = create_blockscaled_operand_tensor(l, m, n, False, cutlass.BFloat16, init="empty") + mSFA = _make_sm120_scales(m, k, sf_vec_size, sf_dtype) + mSFB = _make_sm120_scales(n, k, sf_vec_size, sf_dtype) + compiled = compile_blockscaled_gemm_tvm_ffi( + ab_dtype, + sf_dtype, + sf_vec_size, + cutlass.BFloat16, + (128, 128), + (1, 1), + mA, + mB, + mD, + mSFA, + mSFB, + ) + return compiled, (mA, mB, mD, mSFA, mSFB) + + +@pytest.mark.parametrize( + "sf_dtype,sf_vec_size,m,n,k", + [ + (cutlass.Float8E4M3FN, 16, 128, 128, 64), + (cutlass.Float8E4M3FN, 16, 256, 128, 128), + (cutlass.Float8E4M3FN, 16, 128, 128, 320), + (cutlass.Float8E8M0FNU, 32, 128, 128, 64), + ], +) +def test_sm120_blockscaled_scale_correctness(sf_dtype, sf_vec_size, m, n, k): + _skip_if_not_sm120() + a_codes = torch.full((m, k), 0x2, device="cuda", dtype=torch.uint8) + b_codes = torch.full((n, k), 0x2, device="cuda", dtype=torch.uint8) + mA = _pack_sm120_fp4_codes(a_codes) + mB = _pack_sm120_fp4_codes(b_codes) + compiled, args = _compile_sm120_blockscaled_gemm( + cutlass.Float4E2M1FN, sf_dtype, sf_vec_size, m, n, k, mA, mB + ) + _run_blockscaled_gemm(compiled, args) + + _, _, d_torch, mSFA, mSFB = args + ref = _sm120_fp4_blockscaled_reference(a_codes, b_codes, mSFA, mSFB, sf_vec_size).to( + torch.bfloat16 + ) + err = (d_torch.float() - ref).abs().max().item() + assert err < 1e-1, f"max_err={err}" + + +def test_sm120_blockscaled_k_loop_accumulates_before_bf16_store(): + _skip_if_not_sm120() + m = n = 128 + k = 384 + sf_vec_size = 16 + sf_dtype = cutlass.Float8E4M3FN + a_codes = torch.full((m, k), 0x2, device="cuda", dtype=torch.uint8) + b_codes = torch.full((n, k), 0x2, device="cuda", dtype=torch.uint8) + # Make the first K64 tile very large and the remaining K64 tiles small. + # Storing BF16 after each K tile loses the later small partials; true FP32 + # accumulation keeps them until the final BF16 conversion. + a_codes[:, :64] = 0x7 + b_codes[:, :64] = 0x7 + mA = _pack_sm120_fp4_codes(a_codes) + mB = _pack_sm120_fp4_codes(b_codes) + compiled, args = _compile_sm120_blockscaled_gemm( + cutlass.Float4E2M1FN, sf_dtype, sf_vec_size, m, n, k, mA, mB + ) + _, _, _, mSFA, mSFB = args + logical_cols = k // sf_vec_size + mSFA[:, :logical_cols, 0] = torch.tensor(1.0, device="cuda", dtype=mSFA.dtype) + mSFB[:, :logical_cols, 0] = torch.tensor(1.0, device="cuda", dtype=mSFB.dtype) + mSFA[:, :4, 0] = torch.tensor(3.0, device="cuda", dtype=mSFA.dtype) + mSFB[:, :4, 0] = torch.tensor(3.0, device="cuda", dtype=mSFB.dtype) + _run_blockscaled_gemm(compiled, args) + + _, _, d_torch, mSFA, mSFB = args + ref = _sm120_fp4_blockscaled_reference(a_codes, b_codes, mSFA, mSFB, sf_vec_size).to( + torch.bfloat16 + ) + torch.testing.assert_close(d_torch.float(), ref.float(), rtol=0, atol=0) + + +def test_sm120_blockscaled_asymmetric_fp4_and_scale_page_crossing(): + _skip_if_not_sm120() + m = n = 128 + k = 320 + sf_vec_size = 16 + sf_dtype = cutlass.Float8E4M3FN + ks = torch.arange(k, device="cuda")[None, :] + a_codes = torch.where( + (ks % 4) < 2, + torch.tensor(0x2, device="cuda", dtype=torch.uint8), + torch.tensor(0x4, device="cuda", dtype=torch.uint8), + ).expand(m, k) + b_codes = torch.where( + (ks % 8) < 4, + torch.tensor(0x3, device="cuda", dtype=torch.uint8), + torch.tensor(0x5, device="cuda", dtype=torch.uint8), + ).expand(n, k) + mA = _pack_sm120_fp4_codes(a_codes) + mB = _pack_sm120_fp4_codes(b_codes) + compiled, args = _compile_sm120_blockscaled_gemm( + cutlass.Float4E2M1FN, sf_dtype, sf_vec_size, m, n, k, mA, mB + ) + _run_blockscaled_gemm(compiled, args) + + _, _, d_torch, mSFA, mSFB = args + logical_cols = k // sf_vec_size + assert mSFA.shape[1] == 32 + assert torch.any(mSFA[:, logical_cols:, :].float() != 1.0) + ref = _sm120_fp4_blockscaled_reference(a_codes, b_codes, mSFA, mSFB, sf_vec_size).to( + torch.bfloat16 + ) + err = (d_torch.float() - ref.float()).abs().max().item() + assert err < 1e-1, f"max_err={err}" + + +def test_sm120_blockscaled_rejects_compact_scale_layout(): + _skip_if_not_sm120() + l, m, n, k, sf_vec_size = 1, 128, 128, 64, 16 + ab_dtype = cutlass.Float4E2M1FN + sf_dtype = cutlass.Float8E4M3FN + _, mA = create_blockscaled_operand_tensor(l, m, k, False, ab_dtype) + _, mB = create_blockscaled_operand_tensor(l, n, k, False, ab_dtype) + _, mD = create_blockscaled_operand_tensor(l, m, n, False, cutlass.BFloat16, init="empty") + mSFA = torch.empty((m, k // sf_vec_size, l), device="cuda", dtype=torch.float8_e4m3fn) + mSFB = torch.empty((n, k // sf_vec_size, l), device="cuda", dtype=torch.float8_e4m3fn) + + with pytest.raises(ValueError, match="SFA shape"): + runner = compile_blockscaled_gemm_tvm_ffi( + ab_dtype, + sf_dtype, + sf_vec_size, + cutlass.BFloat16, + (128, 128), + (1, 1), + mA, + mB, + mD, + mSFA, + mSFB, + ) + runner(mA, mB, mD, mSFA, mSFB) + + # --------------------------------------------------------------------------- # End-to-end: quantized MXFP8 inputs through quack kernel vs cuBLAS vs dequant ref # --------------------------------------------------------------------------- From 09dbc28ab881101705a4bbfc0ddbdb068de40abd Mon Sep 17 00:00:00 2001 From: agent Date: Thu, 30 Apr 2026 15:24:21 +0200 Subject: [PATCH 2/4] [Benchmark] refresh blockscaled GEMM examples Update the benchmark module examples to match the current CLI. The blockscaled path is selected by --sf_dtype and/or --sf_vec_size, so remove the stale --blockscaled flag and other old example-only flags that are not parsed by the script.\n\nUse BF16 output in the FP4 SM120 examples because that is the supported SM120 benchmark path. The benchmark remains a launch/timing utility; correctness is checked against the local reference, and cuBLAS comparison remains skipped for SM120 padded row-major scale tensors. --- benchmarks/benchmark_gemm.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/benchmarks/benchmark_gemm.py b/benchmarks/benchmark_gemm.py index e5abfe10..7a00e37d 100644 --- a/benchmarks/benchmark_gemm.py +++ b/benchmarks/benchmark_gemm.py @@ -7,8 +7,9 @@ from quack.gemm import gemm as quack_gemm """ -GEMM benchmark using quack.gemm.gemm() (dense path) or the SM100 blockscaled -path (MXFP8 / MXFP4 / NVFP4) via --blockscaled. +GEMM benchmark using quack.gemm.gemm() (dense path) or the blockscaled +path (MXFP8 / MXFP4 / NVFP4). The blockscaled path is selected by passing +--sf_dtype and/or --sf_vec_size. Usage (dense): python benchmarks/benchmark_gemm.py --mnkl 512,7168,2048,256 \ @@ -17,18 +18,17 @@ Usage (blockscaled MXFP8, with cuBLAS comparison): python benchmarks/benchmark_gemm.py --mnkl 4096,4096,4096,1 \ - --blockscaled --ab_dtype Float8E4M3FN --sf_dtype Float8E8M0FNU \ - --sf_vec_size 32 --init quant --compare_cublas + --ab_dtype Float8E4M3FN --sf_dtype Float8E8M0FNU --sf_vec_size 32 Usage (blockscaled MXFP4): python benchmarks/benchmark_gemm.py --mnkl 4096,4096,4096,1 \ - --blockscaled --ab_dtype Float4E2M1FN --sf_dtype Float8E8M0FNU \ - --sf_vec_size 32 --d_dtype Float32 + --ab_dtype Float4E2M1FN --sf_dtype Float8E8M0FNU \ + --sf_vec_size 32 --d_dtype BFloat16 Usage (blockscaled NVFP4): python benchmarks/benchmark_gemm.py --mnkl 4096,4096,4096,1 \ - --blockscaled --ab_dtype Float4E2M1FN --sf_dtype Float8E4M3FN \ - --sf_vec_size 16 --d_dtype Float32 + --ab_dtype Float4E2M1FN --sf_dtype Float8E4M3FN \ + --sf_vec_size 16 --d_dtype BFloat16 """ From 83c706d66e2a700e1e9995f48b425b9831a9c344 Mon Sep 17 00:00:00 2001 From: agent Date: Fri, 1 May 2026 16:04:43 +0200 Subject: [PATCH 3/4] [Sm120] add packed blockscaled LDSM path Add an opt-in SM120 blockscaled path that uses packed subbyte shared-memory fragments with ordinary m8n8.x4 ldmatrix instead of the original byte-expanded b4x16_p64 unpack path. The new path is gated by QUACK_SM120_BLOCKSCALED_PACKED_LDSM=1 so the existing correctness-first path remains available while the packed path is reviewed and tuned. The direction follows the local CUTLASS GeForce reference in examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu: CUTLASS uses a packed subbyte shared-memory consumer with m8n8 ldmatrix for the SM120 NVFP4 path rather than QuACK's earlier padded b4x16_p64 unpack route. This commit ports the relevant consumer-side idea while deliberately keeping QuACK's narrower, already-proven per-atom TMA producer instead of importing the full CUTLASS mainloop. The implementation keeps the scope narrow: Float4E2M1FN A/B, Float8E4M3FN NVFP4 or Float8E8M0FNU MXFP4 scales, BF16 D, C None, beta 0, cluster 1x1x1, and tile-aligned SM120 shapes. can_implement_blockscaled and direct class-call validation now share the same advertised tile contract: the existing 128x128x64 correctness path, 64x64x64, and 64x64x128 only when the packed path is explicitly enabled. Other mixed 64/128 tile shapes are rejected before launch. A globally set QUACK_SM120_BLOCKSCALED_PACKED_LDSM no longer hijacks the existing 128x128x64 fallback path. Packed mode is only activated for the two packed-supported tiles, so the env var can be left set while callers still compile or validate the correctness-first 128x128x64 path. For 64x64 tiles, four consumer warps each own one 16-row band and keep eight 16x8 FP32 accumulators live across all K tiles before storing BF16 once. This preserves Float32 accumulation semantics without the generic sAcc scratch traffic used by the wider fallback path. The 64x64x128 path also avoids allocating the full sAcc scratch because it uses the same register-resident accumulator path. This PR intentionally ships only one runtime knob: QUACK_SM120_BLOCKSCALED_PACKED_LDSM. Earlier profiling-only controls for stage count, consumer-warp count, and sync mode are omitted from the production PR surface so reviewers do not have to reason about untested scheduling variants. Full-tile and grouped TMA are deliberately not included here. Local experiments showed the current CuTe DSL subbyte/swizzled full-tile TMA layout either fails legalization, degenerates into many tiny TMA sites, or times out for nested raw FP4 layouts. This commit keeps the production path on the proven per-atom TMA mechanism and uses tile_K=128 to amortize producer/barrier overhead while leaving the full-tile layout issue for a separate repro/upstream track. Tests cover capability-contract negatives for unsupported K64 and K128 tile shapes, direct class-call rejection of unadvertised shapes, global packed-env compatibility with the 128x128x64 fallback path, packed K64/K128 NVFP4 correctness, packed K64/K128 MXFP4 correctness, asymmetric FP4 data, scale-page crossing with poisoned scale padding, a K128 scale-offset regression where the first and second K64 halves use different scales, pre-launch rejection for tile_K=128 without the packed path, and PTX regression checks requiring m8n8.x4.shared.b16 plus m16n8k64 mxf4nvf4 while rejecting b4x16_p64, m8n16, multicast, and shared::cluster. --- quack/blockscaled_gemm_utils.py | 20 +- quack/gemm_sm120.py | 1144 ++++++++++++++++++++++++++----- tests/test_gemm_blockscaled.py | 553 ++++++++++++++- 3 files changed, 1529 insertions(+), 188 deletions(-) diff --git a/quack/blockscaled_gemm_utils.py b/quack/blockscaled_gemm_utils.py index 115f999a..ed9580ea 100644 --- a/quack/blockscaled_gemm_utils.py +++ b/quack/blockscaled_gemm_utils.py @@ -1,6 +1,7 @@ # Copyright (c) 2026, Tri Dao. import itertools +import os from functools import partial from typing import Callable, Optional, Type, Tuple @@ -655,6 +656,7 @@ def compile_blockscaled_gemm_tvm_ffi( use_clc_persistence: bool = True, varlen_m: bool = False, varlen_k: bool = False, + compile_options: str = "--enable-tvm-ffi", ) -> Callable: """Compile the blockscaled GEMM. @@ -674,8 +676,12 @@ def compile_blockscaled_gemm_tvm_ffi( if device_capacity[0] == 12: if varlen_m or varlen_k: raise NotImplementedError("SM120 blockscaled GEMM does not support varlen") - if mma_tiler_k != 64: - raise NotImplementedError("SM120 blockscaled GEMM requires tile_K=64") + if mma_tiler_k not in (64, 128): + raise NotImplementedError("SM120 blockscaled GEMM requires tile_K in {64,128}") + if mma_tiler_k == 128 and os.environ.get("QUACK_SM120_BLOCKSCALED_PACKED_LDSM") != "1": + raise NotImplementedError( + "SM120 blockscaled tile_K=128 requires QUACK_SM120_BLOCKSCALED_PACKED_LDSM=1" + ) if len(mA.shape) != 3 or len(mB.shape) != 3 or len(mD.shape) != 3: raise ValueError("SM120 blockscaled GEMM requires rank-3 A/B/D tensors") logical_k = mA.shape[1] * (2 if ab_dtype is cutlass.Float4E2M1FN else 1) @@ -691,7 +697,7 @@ def compile_blockscaled_gemm_tvm_ffi( sf_dtype, sf_vec_size, d_dtype, - (*mma_tiler_mn_only, 64), + (*mma_tiler_mn_only, mma_tiler_k), cluster_shape_mn, mA.shape[0], mB.shape[0], @@ -709,7 +715,7 @@ def compile_blockscaled_gemm_tvm_ffi( gemm = GemmDefaultSm120( cutlass.Float32, ab_dtype, - (*mma_tiler_mn_only, 64), + (*mma_tiler_mn_only, mma_tiler_k), (*cluster_shape_mn, 1), is_persistent=False, sf_vec_size=sf_vec_size, @@ -824,7 +830,7 @@ def runner( _make_compile_tensor_like(mSFB, sf_dtype, dynamic_layout=True), varlen_args_fake, stream, - options="--enable-tvm-ffi", + options=compile_options, ) if varlen_m or varlen_k: @@ -840,6 +846,10 @@ def run(a, b, d, sfa, sfb, cu_seqlens): def run(a, b, d, sfa, sfb): compiled(a, b, d, sfa, sfb, VarlenArguments()) + for attr in ("__ptx__", "__cubin__"): + if hasattr(compiled, attr): + setattr(run, attr, getattr(compiled, attr)) + return run diff --git a/quack/gemm_sm120.py b/quack/gemm_sm120.py index 02f5dff9..a279fe00 100644 --- a/quack/gemm_sm120.py +++ b/quack/gemm_sm120.py @@ -8,6 +8,7 @@ # This is a work in progress and not very optimized. import math +import os from typing import Tuple, Type, Callable, Optional from functools import partial @@ -91,17 +92,18 @@ def _expand_compact_fp4_to_sm120_ldmatrix_smem( compact: cute.Tensor, padded: cute.Tensor, mn: cutlass.Constexpr[int], + thread_count: cutlass.Constexpr[int] = 32, ): """Expand packed FP4 bytes into the padded SM120 ldmatrix SMEM layout.""" tidx, _, _ = cute.arch.thread_idx() - for i in cutlass.range((mn * 64 + 31) // 32, unroll_full=True): - flat = tidx + i * 32 + for i in cutlass.range((mn * 64 + thread_count - 1) // thread_count, unroll_full=True): + flat = tidx + i * thread_count if flat < mn * 64: padded[flat // 64, flat % 64] = cutlass.Int8(0) - for i in cutlass.range((mn * 32 + 31) // 32, unroll_full=True): - flat = tidx + i * 32 + for i in cutlass.range((mn * 32 + thread_count - 1) // thread_count, unroll_full=True): + flat = tidx + i * thread_count if flat < mn * 32: row = flat // 32 packed_k = flat - row * 32 @@ -227,6 +229,8 @@ def __init__( self.epi_m_major = True self.a_smem_layout_staged = None self.b_smem_layout_staged = None + self.a_smem_store_layout_staged = None + self.b_smem_store_layout_staged = None self.epi_smem_layout_staged = None self.epi_tile = None self.shared_storage = None @@ -309,11 +313,16 @@ def can_implement_blockscaled( return False if cluster_shape_mn != (1, 1): return False - if len(mma_tiler_mnk) == 3 and mma_tiler_mnk[2] != 64: + if len(mma_tiler_mnk) == 3 and mma_tiler_mnk[2] not in (64, 128): return False - if mma_tiler_mnk[0] != 128 or mma_tiler_mnk[1] != 128: + tile_m, tile_n = mma_tiler_mnk[:2] + tile_k = mma_tiler_mnk[2] if len(mma_tiler_mnk) == 3 else 64 + supported_tiles = {(128, 128, 64), (64, 64, 64)} + if os.environ.get("QUACK_SM120_BLOCKSCALED_PACKED_LDSM") == "1": + supported_tiles.add((64, 64, 128)) + if (tile_m, tile_n, tile_k) not in supported_tiles: return False - return m % 128 == 0 and n % 128 == 0 and k % 64 == 0 and l == 1 + return m % tile_m == 0 and n % tile_n == 0 and k % tile_k == 0 and l == 1 @staticmethod def _shape_tuple(tensor: cute.Tensor) -> tuple[int, ...]: @@ -407,8 +416,15 @@ def _validate_blockscaled_call( del scheduler_args if trace_ptr is not None: raise NotImplementedError("SM120 blockscaled trace path is not implemented") - if self.cta_tile_shape_mnk != (128, 128, 64): - raise NotImplementedError("SM120 blockscaled path currently supports 128x128x64 tiles") + tile_m, tile_n, tile_k = self.cta_tile_shape_mnk + supported_tiles = {(128, 128, 64), (64, 64, 64)} + if os.environ.get("QUACK_SM120_BLOCKSCALED_PACKED_LDSM") == "1": + supported_tiles.add((64, 64, 128)) + if (tile_m, tile_n, tile_k) not in supported_tiles: + raise NotImplementedError( + "SM120 blockscaled path currently supports tile shapes " + "(128,128,64), (64,64,64), and opt-in (64,64,128)" + ) if self.cluster_shape_mnk != (1, 1, 1): raise NotImplementedError("SM120 blockscaled path requires cluster_shape_mnk=(1,1,1)") if ( @@ -435,18 +451,18 @@ def _validate_blockscaled_call( n, kb, lb = b_shape if k != kb or l != lb: raise ValueError("SM120 blockscaled A/B K and L extents must match") - if k % 64 != 0: - if k * 2 % 64 == 0: + if k % tile_k != 0: + if k * 2 % tile_k == 0: raise ValueError( "SM120 blockscaled class call expects logical Float4E2M1FN K extent; " "use compile_blockscaled_gemm_tvm_ffi for packed torch.float4_e2m1fn_x2 " "storage" ) - raise ValueError("SM120 blockscaled path requires logical K to be divisible by 64") + raise ValueError("SM120 blockscaled path requires logical K to be divisible by tile_K") if d_shape != (m, n, l): raise ValueError(f"SM120 blockscaled D shape must be {(m, n, l)}, got {d_shape}") - if m % 128 != 0 or n % 128 != 0: - raise NotImplementedError("SM120 blockscaled path requires M/N multiples of 128") + if m % tile_m != 0 or n % tile_n != 0: + raise NotImplementedError("SM120 blockscaled path requires M/N multiples of CTA M/N") if l != 1: raise NotImplementedError("SM120 blockscaled path currently supports L=1") scale_cols = self.padded_blockscale_cols(k, self.sf_vec_size) @@ -484,24 +500,80 @@ def _call_blockscaled( self.d_layout = LayoutEnum.from_tensor(mD) self.c_layout = None self._setup_attributes(()) - self.ab_stage = 1 + self.ab_stage = 2 if self.cta_tile_shape_mnk in ((64, 64, 64), (64, 64, 128)) else 1 + # Split independent 16x8 output atoms across four consumer warps. + # Warp 4 owns TMA production; warp 0 owns pipeline wait/release. + self.blockscaled_consumer_warps = 4 + self.blockscaled_producer_warp = 4 + packed_ldsm_override = os.environ.get("QUACK_SM120_BLOCKSCALED_PACKED_LDSM") + if const_expr(packed_ldsm_override is not None and packed_ldsm_override not in ("0", "1")): + raise ValueError("QUACK_SM120_BLOCKSCALED_PACKED_LDSM must be 0 or 1") tile_m, tile_n, tile_k = self.cta_tile_shape_mnk - - self.a_smem_layout_staged = cute.make_layout( - (tile_m, tile_k, self.ab_stage), stride=(tile_k, 1, tile_m * tile_k) - ) - self.b_smem_layout_staged = cute.make_layout( - (tile_n, tile_k, self.ab_stage), stride=(tile_k, 1, tile_n * tile_k) + self.blockscaled_packed_ldsm = packed_ldsm_override == "1" and (tile_m, tile_n, tile_k) in ( + (64, 64, 64), + (64, 64, 128), ) + if const_expr(not self.blockscaled_packed_ldsm and tile_k != 64): + raise NotImplementedError( + "SM120 blockscaled tile_K=128 currently requires " + "QUACK_SM120_BLOCKSCALED_PACKED_LDSM=1" + ) + if const_expr( + self.blockscaled_packed_ldsm + and (tile_m, tile_n, tile_k) == (64, 64, 128) + and self.blockscaled_consumer_warps != 4 + ): + raise NotImplementedError( + "SM120 packed-subbyte ldmatrix tile 64x64x128 currently requires " + "four consumer warps" + ) + + if const_expr(self.blockscaled_packed_ldsm): + # Direct TMA uses a visible 16x64 / 8x64 FP4 tile composed with + # the same swizzle that the packed ldmatrix consumer partitions. + # The separate 8x128 atom view below is only for the consumer. + self.a_smem_store_layout_staged = cute.make_layout( + (8, 128, tile_m // 16, self.ab_stage), + stride=(128, 1, 8 * 128, (tile_m // 16) * 8 * 128), + ) + self.b_smem_store_layout_staged = cute.make_layout( + (8, 128, tile_n // 8, self.ab_stage), + stride=(128, 1, 8 * 128, (tile_n // 8) * 8 * 128), + ) + # The consumer view is typed as 4-bit elements, so the outer atom + # and stage strides are expressed in FP4 element units. They are + # twice the raw Uint8 store strides above to advance by the same + # physical byte distance. + self.a_smem_layout_staged = cute.make_layout( + (8, 128, tile_m // 16, self.ab_stage), + stride=(128, 1, 2 * 8 * 128, (tile_m // 16) * 2 * 8 * 128), + ) + self.b_smem_layout_staged = cute.make_layout( + (8, 128, tile_n // 8, self.ab_stage), + stride=(128, 1, 2 * 8 * 128, (tile_n // 8) * 2 * 8 * 128), + ) + else: + self.a_smem_layout_staged = cute.make_layout( + (tile_m, tile_k, self.ab_stage), stride=(tile_k, 1, tile_m * tile_k) + ) + self.b_smem_layout_staged = cute.make_layout( + (tile_n, tile_k, self.ab_stage), stride=(tile_k, 1, tile_n * tile_k) + ) + self.a_smem_store_layout_staged = self.a_smem_layout_staged + self.b_smem_store_layout_staged = self.b_smem_layout_staged scale_tile_k = 16 - a_compact_smem_layout_staged = cute.make_layout( - (tile_m, tile_k // 2, self.ab_stage), - stride=(tile_k // 2, 1, tile_m * (tile_k // 2)), - ) - b_compact_smem_layout_staged = cute.make_layout( - (tile_n, tile_k // 2, self.ab_stage), - stride=(tile_k // 2, 1, tile_n * (tile_k // 2)), - ) + if const_expr(self.blockscaled_packed_ldsm): + a_compact_smem_layout_staged = cute.make_layout((1, 1, self.ab_stage), stride=(1, 1, 1)) + b_compact_smem_layout_staged = cute.make_layout((1, 1, self.ab_stage), stride=(1, 1, 1)) + else: + a_compact_smem_layout_staged = cute.make_layout( + (tile_m, tile_k // 2, self.ab_stage), + stride=(tile_k // 2, 1, tile_m * (tile_k // 2)), + ) + b_compact_smem_layout_staged = cute.make_layout( + (tile_n, tile_k // 2, self.ab_stage), + stride=(tile_k // 2, 1, tile_n * (tile_k // 2)), + ) self.sfa_smem_layout_staged = cute.make_layout( (tile_m, scale_tile_k, self.ab_stage), stride=(scale_tile_k, 1, tile_m * scale_tile_k), @@ -512,67 +584,110 @@ def _call_blockscaled( ) acc_smem_layout = cute.make_layout((tile_m, tile_n), stride=(tile_n, 1)) - a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, 0)) - b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, 0)) + if const_expr(self.blockscaled_packed_ldsm): + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + else: + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, 0)) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, 0)) a_compact_smem_layout = cute.slice_(a_compact_smem_layout_staged, (None, None, 0)) b_compact_smem_layout = cute.slice_(b_compact_smem_layout_staged, (None, None, 0)) sfa_smem_layout = cute.slice_(self.sfa_smem_layout_staged, (None, None, 0)) sfb_smem_layout = cute.slice_(self.sfb_smem_layout_staged, (None, None, 0)) + op = cpasync.CopyBulkTensorTileG2SOp() m_extent = cute.size(mA, mode=[0]) k_extent = cute.size(mA, mode=[1]) - packed_k_extent = k_extent // 2 n_extent = cute.size(mB, mode=[0]) l_extent = cute.size(mA, mode=[2]) - mA_u8 = cute.make_tensor( - cute.recast_ptr(mA.iterator, dtype=cutlass.Uint8), - cute.make_layout( - (m_extent, packed_k_extent, l_extent), - stride=(packed_k_extent, 1, m_extent * packed_k_extent), - ), - ) - mB_u8 = cute.make_tensor( - cute.recast_ptr(mB.iterator, dtype=cutlass.Uint8), - cute.make_layout( - (n_extent, packed_k_extent, l_extent), - stride=(packed_k_extent, 1, n_extent * packed_k_extent), - ), - ) - op = cpasync.CopyBulkTensorTileG2SOp() - tma_atom_a, tma_tensor_a = cpasync.make_tiled_tma_atom( - op, mA_u8, a_compact_smem_layout, (tile_m, tile_k // 2) - ) - tma_atom_b, tma_tensor_b = cpasync.make_tiled_tma_atom( - op, mB_u8, b_compact_smem_layout, (tile_n, tile_k // 2) - ) + if const_expr(self.blockscaled_packed_ldsm): + a_tma_payload_layout = cute.make_layout((16, tile_k), stride=(tile_k, 1)) + b_tma_payload_layout = cute.make_layout((8, tile_k), stride=(tile_k, 1)) + a_tma_smem_layout = cute.make_composed_layout( + cute.make_swizzle(2, 4, 3), + 0, + a_tma_payload_layout, + ) + b_tma_smem_layout = cute.make_composed_layout( + cute.make_swizzle(2, 4, 3), + 0, + b_tma_payload_layout, + ) + tma_atom_a, tma_tensor_a = cpasync.make_tiled_tma_atom( + op, mA, a_tma_smem_layout, (16, tile_k) + ) + tma_atom_b, tma_tensor_b = cpasync.make_tiled_tma_atom( + op, mB, b_tma_smem_layout, (8, tile_k) + ) + else: + packed_k_extent = k_extent // 2 + mA_u8 = cute.make_tensor( + cute.recast_ptr(mA.iterator, dtype=cutlass.Uint8), + cute.make_layout( + (m_extent, packed_k_extent, l_extent), + stride=(packed_k_extent, 1, m_extent * packed_k_extent), + ), + ) + mB_u8 = cute.make_tensor( + cute.recast_ptr(mB.iterator, dtype=cutlass.Uint8), + cute.make_layout( + (n_extent, packed_k_extent, l_extent), + stride=(packed_k_extent, 1, n_extent * packed_k_extent), + ), + ) + tma_atom_a, tma_tensor_a = cpasync.make_tiled_tma_atom( + op, mA_u8, a_compact_smem_layout, (tile_m, tile_k // 2) + ) + tma_atom_b, tma_tensor_b = cpasync.make_tiled_tma_atom( + op, mB_u8, b_compact_smem_layout, (tile_n, tile_k // 2) + ) tma_atom_sfa, tma_tensor_sfa = self._make_tma_atoms_and_tensors( mSFA, sfa_smem_layout, (tile_m, scale_tile_k), 1 ) tma_atom_sfb, tma_tensor_sfb = self._make_tma_atoms_and_tensors( mSFB, sfb_smem_layout, (tile_n, scale_tile_k), 1 ) - self.num_tma_load_bytes = ( - cute.size_in_bytes(cutlass.Uint8, a_compact_smem_layout) - + cute.size_in_bytes(cutlass.Uint8, b_compact_smem_layout) - + cute.size_in_bytes(self.sf_dtype, sfa_smem_layout) - + cute.size_in_bytes(self.sf_dtype, sfb_smem_layout) - ) + if const_expr(self.blockscaled_packed_ldsm): + self.num_tma_load_bytes = ( + (tile_m // 16) * cute.size_in_bytes(cutlass.Float4E2M1FN, a_tma_payload_layout) + + (tile_n // 8) * cute.size_in_bytes(cutlass.Float4E2M1FN, b_tma_payload_layout) + + cute.size_in_bytes(self.sf_dtype, sfa_smem_layout) + + cute.size_in_bytes(self.sf_dtype, sfb_smem_layout) + ) + else: + self.num_tma_load_bytes = ( + cute.size_in_bytes(cutlass.Uint8, a_compact_smem_layout) + + cute.size_in_bytes(cutlass.Uint8, b_compact_smem_layout) + + cute.size_in_bytes(self.sf_dtype, sfa_smem_layout) + + cute.size_in_bytes(self.sf_dtype, sfb_smem_layout) + ) a_compact_smem_size = cute.cosize(a_compact_smem_layout_staged) b_compact_smem_size = cute.cosize(b_compact_smem_layout_staged) sfa_smem_size = cute.cosize(self.sfa_smem_layout_staged) sfb_smem_size = cute.cosize(self.sfb_smem_layout_staged) - acc_smem_size = cute.cosize(acc_smem_layout) + uses_64x64_register_accum = ( + self.cta_tile_shape_mnk[0] == 64 + and self.cta_tile_shape_mnk[1] == 64 + and self.blockscaled_consumer_warps == 4 + ) + acc_smem_size = 1 if const_expr(uses_64x64_register_accum) else cute.cosize(acc_smem_layout) @cute.struct class BlockscaledSharedStorage: ab_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.ab_stage * 2] sA: cute.struct.Align[ - cute.struct.MemRange[cutlass.Int8, cute.cosize(self.a_smem_layout_staged)], + cute.struct.MemRange[ + cutlass.Uint8 if self.blockscaled_packed_ldsm else cutlass.Int8, + cute.cosize(self.a_smem_store_layout_staged), + ], self.buffer_align_bytes, ] sB: cute.struct.Align[ - cute.struct.MemRange[cutlass.Int8, cute.cosize(self.b_smem_layout_staged)], + cute.struct.MemRange[ + cutlass.Uint8 if self.blockscaled_packed_ldsm else cutlass.Int8, + cute.cosize(self.b_smem_store_layout_staged), + ], self.buffer_align_bytes, ] sACompact: cute.struct.Align[ @@ -619,6 +734,8 @@ class BlockscaledSharedStorage: mD, varlen_params, cute.make_layout((1, 1, 1)), + self.a_smem_store_layout_staged, + self.b_smem_store_layout_staged, self.a_smem_layout_staged, self.b_smem_layout_staged, a_compact_smem_layout_staged, @@ -628,7 +745,7 @@ class BlockscaledSharedStorage: acc_smem_layout, ).launch( grid=[cute.ceil_div(m_extent, tile_m), cute.ceil_div(n_extent, tile_n), l_extent], - block=[64, 1, 1], + block=[(self.blockscaled_consumer_warps + 1) * cute.arch.WARP_SIZE, 1, 1], cluster=(1, 1, 1), ) @@ -647,6 +764,8 @@ def blockscaled_kernel( mD_mnl: cute.Tensor, varlen_params: VarlenManager.Params, cluster_layout_mnk: cute.Layout, + a_smem_store_layout: cute.Layout, + b_smem_store_layout: cute.Layout, a_smem_layout: cute.Layout, b_smem_layout: cute.Layout, a_compact_smem_layout: cute.Layout, @@ -661,7 +780,7 @@ def blockscaled_kernel( cta_m, cta_n, cta_l = cute.arch.block_idx() warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) - if warp_idx == 1: + if warp_idx == self.blockscaled_producer_warp: cpasync.prefetch_descriptor(tma_atom_a) cpasync.prefetch_descriptor(tma_atom_b) cpasync.prefetch_descriptor(tma_atom_sfa) @@ -676,30 +795,64 @@ def blockscaled_kernel( ) pipeline_init_arrive(cluster_shape_mn=(1, 1), is_relaxed=True) - sA = storage.sA.get_tensor(a_smem_layout) - sB = storage.sB.get_tensor(b_smem_layout) sACompact = storage.sACompact.get_tensor(a_compact_smem_layout) sBCompact = storage.sBCompact.get_tensor(b_compact_smem_layout) + if const_expr(self.blockscaled_packed_ldsm): + sAStore = storage.sA.get_tensor(a_smem_store_layout) + sBStore = storage.sB.get_tensor(b_smem_store_layout) + a_tma_smem_layout = cute.make_layout( + ((16, self.cta_tile_shape_mnk[0] // 16), self.cta_tile_shape_mnk[2], self.ab_stage), + stride=( + (self.cta_tile_shape_mnk[2], 2 * 8 * 128), + 1, + (self.cta_tile_shape_mnk[0] // 16) * 2 * 8 * 128, + ), + ) + b_tma_smem_layout = cute.make_layout( + ((8, self.cta_tile_shape_mnk[1] // 8), self.cta_tile_shape_mnk[2], self.ab_stage), + stride=( + (self.cta_tile_shape_mnk[2], 2 * 8 * 128), + 1, + (self.cta_tile_shape_mnk[1] // 8) * 2 * 8 * 128, + ), + ) + sATma = storage.sA.get_tensor( + a_tma_smem_layout, + swizzle=cute.make_swizzle(2, 4, 3), + dtype=cutlass.Float4E2M1FN, + ) + sBTma = storage.sB.get_tensor( + b_tma_smem_layout, + swizzle=cute.make_swizzle(2, 4, 3), + dtype=cutlass.Float4E2M1FN, + ) + sA = storage.sA.get_tensor( + a_smem_layout, swizzle=cute.make_swizzle(2, 4, 3), dtype=cutlass.Float4E2M1FN + ) + sB = storage.sB.get_tensor( + b_smem_layout, swizzle=cute.make_swizzle(2, 4, 3), dtype=cutlass.Float4E2M1FN + ) + else: + sA = storage.sA.get_tensor(a_smem_layout) + sB = storage.sB.get_tensor(b_smem_layout) + sAStore = sA + sBStore = sB + sATma = sACompact + sBTma = sBCompact sSFA = storage.sSFA.get_tensor(sfa_smem_layout) sSFB = storage.sSFB.get_tensor(sfb_smem_layout) sAcc = storage.sAcc.get_tensor(acc_smem_layout) pipeline_init_wait(cluster_shape_mn=(1, 1)) - k_tile_count = cute.size(mA_mkl, mode=[1]) // (self.cta_tile_shape_mnk[2] // 2) - scales_per_k_tile = 64 // self.sf_vec_size + k_tile_count = cute.size(mA_mkl, mode=[1]) // ( + self.cta_tile_shape_mnk[2] + if const_expr(self.blockscaled_packed_ldsm) + else self.cta_tile_shape_mnk[2] // 2 + ) + scales_per_k_tile = self.cta_tile_shape_mnk[2] // self.sf_vec_size - if warp_idx == 1: + if warp_idx == self.blockscaled_producer_warp: producer_state = make_pipeline_state(pipeline.PipelineUserType.Producer, self.ab_stage) - gA_mk = cute.local_tile( - mA_mkl[None, None, cta_l], - (self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[2] // 2), - (cta_m, None), - ) - gB_nk = cute.local_tile( - mB_nkl[None, None, cta_l], - (self.cta_tile_shape_mnk[1], self.cta_tile_shape_mnk[2] // 2), - (cta_n, None), - ) gSFA_mk16 = cute.local_tile( mSFA_mkl16[None, None, cta_l], (self.cta_tile_shape_mnk[0], 16), @@ -710,32 +863,69 @@ def blockscaled_kernel( (self.cta_tile_shape_mnk[1], 16), (cta_n, None), ) - if const_expr(k_tile_count == 1): - producer_state = self.load_blockscaled_tma_tile( - ab_pipeline, - producer_state, - tma_atom_a, - gA_mk, - tma_atom_b, - gB_nk, - tma_atom_sfa, - gSFA_mk16, - tma_atom_sfb, - gSFB_nk16, - sACompact, - sBCompact, - sSFA, - sSFB, - ) + if const_expr(self.blockscaled_packed_ldsm): + if const_expr(k_tile_count == 1): + producer_state = self.load_blockscaled_tma_tile_packed_direct( + ab_pipeline, + producer_state, + cutlass.Int32(0), + cutlass.Int32(0), + tma_atom_a, + mA_mkl, + tma_atom_b, + mB_nkl, + tma_atom_sfa, + gSFA_mk16, + tma_atom_sfb, + gSFB_nk16, + sATma, + sBTma, + sSFA, + sSFB, + cta_m, + cta_n, + cta_l, + ) + else: + for k_tile in cutlass.range(k_tile_count, unroll=1): + scale_base = k_tile * scales_per_k_tile + scale_page = scale_base // 16 + producer_state = self.load_blockscaled_tma_tile_packed_direct( + ab_pipeline, + producer_state, + k_tile, + scale_page, + tma_atom_a, + mA_mkl, + tma_atom_b, + mB_nkl, + tma_atom_sfa, + gSFA_mk16, + tma_atom_sfb, + gSFB_nk16, + sATma, + sBTma, + sSFA, + sSFB, + cta_m, + cta_n, + cta_l, + ) else: - for k_tile in cutlass.range(k_tile_count, unroll=1): - scale_base = k_tile * scales_per_k_tile - scale_page = scale_base // 16 - producer_state = self.load_blockscaled_tma_tile_indexed( + gA_mk = cute.local_tile( + mA_mkl[None, None, cta_l], + (self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[2] // 2), + (cta_m, None), + ) + gB_nk = cute.local_tile( + mB_nkl[None, None, cta_l], + (self.cta_tile_shape_mnk[1], self.cta_tile_shape_mnk[2] // 2), + (cta_n, None), + ) + if const_expr(k_tile_count == 1): + producer_state = self.load_blockscaled_tma_tile( ab_pipeline, producer_state, - k_tile, - scale_page, tma_atom_a, gA_mk, tma_atom_b, @@ -744,14 +934,37 @@ def blockscaled_kernel( gSFA_mk16, tma_atom_sfb, gSFB_nk16, - sACompact, - sBCompact, + sATma, + sBTma, sSFA, sSFB, ) + else: + for k_tile in cutlass.range(k_tile_count, unroll=1): + scale_base = k_tile * scales_per_k_tile + scale_page = scale_base // 16 + producer_state = self.load_blockscaled_tma_tile_indexed( + ab_pipeline, + producer_state, + k_tile, + scale_page, + tma_atom_a, + gA_mk, + tma_atom_b, + gB_nk, + tma_atom_sfa, + gSFA_mk16, + tma_atom_sfb, + gSFB_nk16, + sACompact, + sBCompact, + sSFA, + sSFB, + ) ab_pipeline.producer_tail(producer_state) - if warp_idx == 0: + if warp_idx < self.blockscaled_consumer_warps: + lane_idx = cute.arch.lane_idx() gD_mn = cute.local_tile( mD_mnl[None, None, cta_l], (self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1]), @@ -764,6 +977,8 @@ def blockscaled_kernel( tiled_mma, sA, sB, + sAStore, + sBStore, sACompact, sBCompact, sSFA, @@ -771,9 +986,93 @@ def blockscaled_kernel( sAcc, gD_mn, k_tile_count, - tidx, + lane_idx, + warp_idx, ) + @cute.jit + def load_blockscaled_tma_tile_packed_direct( + self, + ab_pipeline: pipeline.PipelineAsync, + producer_state: pipeline.PipelineState, + k_tile: cutlass.Int32, + scale_page: cutlass.Int32, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_sfa: cute.CopyAtom, + gSFA_mk16: cute.Tensor, + tma_atom_sfb: cute.CopyAtom, + gSFB_nk16: cute.Tensor, + sATma: cute.Tensor, + sBTma: cute.Tensor, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + cta_m: cutlass.Int32, + cta_n: cutlass.Int32, + cta_l: cutlass.Int32, + ) -> pipeline.PipelineState: + copy_SFA, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_sfa, + cta_coord=0, + cta_layout=cute.make_layout(1), + src_tensor=gSFA_mk16, + dst_tensor=sSFA, + mcast_mask=0, + ) + copy_SFB, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_sfb, + cta_coord=0, + cta_layout=cute.make_layout(1), + src_tensor=gSFB_nk16, + dst_tensor=sSFB, + mcast_mask=0, + ) + + peek_empty_status = ab_pipeline.producer_try_acquire(producer_state) + ab_pipeline.producer_acquire(producer_state, peek_empty_status) + tma_bar_ptr = ab_pipeline.producer_get_barrier(producer_state) + smem_idx = producer_state.index + + for m_atom in cutlass.range_constexpr(self.cta_tile_shape_mnk[0] // 16): + gA_atom = cute.local_tile( + mA_mkl[None, None, cta_l], + (16, self.cta_tile_shape_mnk[2]), + (cta_m * (self.cta_tile_shape_mnk[0] // 16) + m_atom, None), + ) + copy_A, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_a, + cta_coord=0, + cta_layout=cute.make_layout(1), + src_tensor=gA_atom, + dst_tensor=sATma[(None, m_atom), None, None], + mcast_mask=0, + ) + copy_A(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr) + + for n_atom in cutlass.range_constexpr(self.cta_tile_shape_mnk[1] // 8): + gB_atom = cute.local_tile( + mB_nkl[None, None, cta_l], + (8, self.cta_tile_shape_mnk[2]), + (cta_n * (self.cta_tile_shape_mnk[1] // 8) + n_atom, None), + ) + copy_B, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_b, + cta_coord=0, + cta_layout=cute.make_layout(1), + src_tensor=gB_atom, + dst_tensor=sBTma[(None, n_atom), None, None], + mcast_mask=0, + ) + copy_B(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr) + + copy_SFA(scale_page, smem_idx, tma_bar_ptr=tma_bar_ptr) + copy_SFB(scale_page, smem_idx, tma_bar_ptr=tma_bar_ptr) + ab_pipeline.producer_commit(producer_state) + producer_state.advance() + return producer_state + @cute.jit def load_blockscaled_tma_tile( self, @@ -903,6 +1202,8 @@ def mma_blockscaled_kloop_store_bf16( tiled_mma: cute.TiledMma, sA: cute.Tensor, sB: cute.Tensor, + sAStore: cute.Tensor, + sBStore: cute.Tensor, sACompact: cute.Tensor, sBCompact: cute.Tensor, sSFA: cute.Tensor, @@ -911,6 +1212,7 @@ def mma_blockscaled_kloop_store_bf16( gD_mn: cute.Tensor, k_tile_count: cutlass.Int32, tidx: cutlass.Int32, + warp_idx: cutlass.Int32, ) -> None: thr_mma = tiled_mma.get_slice(tidx) a_shape = tiled_mma.partition_shape_A((16, 64)) @@ -918,14 +1220,45 @@ def mma_blockscaled_kloop_store_bf16( acc_shape = tiled_mma.partition_shape_C((16, 8)) accum_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), cutlass.Float32) store_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gD_mn.element_type) - scales_per_k_tile = 64 // self.sf_vec_size - if const_expr(k_tile_count == 1): + scales_per_k_tile = self.cta_tile_shape_mnk[2] // self.sf_vec_size + if const_expr( + self.cta_tile_shape_mnk in ((64, 64, 64), (64, 64, 128)) + and self.blockscaled_consumer_warps == 4 + ): + self.mma_blockscaled_64x64_kloop_store_bf16( + ab_pipeline, + read_state, + tiled_mma, + sA, + sB, + sAStore, + sBStore, + sACompact, + sBCompact, + sSFA, + sSFB, + sAcc, + gD_mn, + k_tile_count, + tidx, + thr_mma, + a_shape, + b_shape, + acc_shape, + accum_atom, + store_atom, + scales_per_k_tile, + warp_idx, + ) + elif const_expr(k_tile_count == 1): read_state = self.mma_blockscaled_one_k_tile_accumulate_smem( ab_pipeline, read_state, tiled_mma, sA, sB, + sAStore, + sBStore, sACompact, sBCompact, sSFA, @@ -943,6 +1276,7 @@ def mma_blockscaled_kloop_store_bf16( cutlass.Int32(0), False, True, + warp_idx, ) else: read_state = self.mma_blockscaled_one_k_tile_accumulate_smem( @@ -951,6 +1285,8 @@ def mma_blockscaled_kloop_store_bf16( tiled_mma, sA, sB, + sAStore, + sBStore, sACompact, sBCompact, sSFA, @@ -968,6 +1304,7 @@ def mma_blockscaled_kloop_store_bf16( cutlass.Int32(0), False, False, + warp_idx, ) for k_iter in cutlass.range(k_tile_count - 2, unroll=1): k_tile = k_iter + 1 @@ -977,6 +1314,8 @@ def mma_blockscaled_kloop_store_bf16( tiled_mma, sA, sB, + sAStore, + sBStore, sACompact, sBCompact, sSFA, @@ -994,6 +1333,7 @@ def mma_blockscaled_kloop_store_bf16( k_tile, True, False, + warp_idx, ) read_state = self.mma_blockscaled_one_k_tile_accumulate_smem( ab_pipeline, @@ -1001,6 +1341,8 @@ def mma_blockscaled_kloop_store_bf16( tiled_mma, sA, sB, + sAStore, + sBStore, sACompact, sBCompact, sSFA, @@ -1018,22 +1360,26 @@ def mma_blockscaled_kloop_store_bf16( k_tile_count - 1, True, True, + warp_idx, ) @cute.jit - def mma_blockscaled_one_k_tile_accumulate_smem( + def mma_blockscaled_64x64_kloop_store_bf16( self, ab_pipeline: pipeline.PipelineAsync, read_state: pipeline.PipelineState, tiled_mma: cute.TiledMma, sA: cute.Tensor, sB: cute.Tensor, + sAStore: cute.Tensor, + sBStore: cute.Tensor, sACompact: cute.Tensor, sBCompact: cute.Tensor, sSFA: cute.Tensor, sSFB: cute.Tensor, sAcc: cute.Tensor, gD_mn: cute.Tensor, + k_tile_count: cutlass.Int32, tidx: cutlass.Int32, thr_mma: cute.TiledMmaThrVal, a_shape: cute.Shape, @@ -1042,60 +1388,453 @@ def mma_blockscaled_one_k_tile_accumulate_smem( accum_atom: cute.CopyAtom, store_atom: cute.CopyAtom, scales_per_k_tile: cutlass.Constexpr[int], - k_tile: cutlass.Int32, - add_existing: cutlass.Constexpr[bool], - store_final: cutlass.Constexpr[bool], - ) -> pipeline.PipelineState: - peek_ab_full_status = ab_pipeline.consumer_try_wait(read_state) - ab_pipeline.consumer_wait(read_state, peek_ab_full_status) - _expand_compact_fp4_to_sm120_ldmatrix_smem( - sACompact[None, None, read_state.index], - sA[None, None, read_state.index], - 128, - ) - _expand_compact_fp4_to_sm120_ldmatrix_smem( - sBCompact[None, None, read_state.index], - sB[None, None, read_state.index], - 128, - ) - cute.arch.sync_warp() - scale_base = k_tile * scales_per_k_tile - scale_page = scale_base // 16 - scale_page_offset = scale_base - scale_page * 16 - for m_atom in cutlass.range_constexpr(8): - for n_atom in cutlass.range_constexpr(16): - acc = cute.make_rmem_tensor(acc_shape, cutlass.Float32) - acc.fill(0.0) - self.mma_blockscaled_tile_k64_accumulate( - tiled_mma, - acc, + warp_idx: cutlass.Int32, + ) -> None: + # 64x64 has four 16-row bands and eight 8-column atoms. With four + # consumer warps, each warp owns one row band and keeps all eight atom + # accumulators live across K. This preserves FP32 accumulation while + # avoiding the LDS/STS traffic from the generic sAcc scratch path. + if warp_idx == 0: + self.mma_blockscaled_64x64_mrow_kloop_store_bf16( + ab_pipeline, + read_state, + tiled_mma, + sA, + sB, + sAStore, + sBStore, + sACompact, + sBCompact, + sSFA, + sSFB, + sAcc, + gD_mn, + k_tile_count, + tidx, + thr_mma, + a_shape, + b_shape, + acc_shape, + accum_atom, + store_atom, + scales_per_k_tile, + 0, + True, + ) + elif warp_idx == 1: + self.mma_blockscaled_64x64_mrow_kloop_store_bf16( + ab_pipeline, + read_state, + tiled_mma, + sA, + sB, + sAStore, + sBStore, + sACompact, + sBCompact, + sSFA, + sSFB, + sAcc, + gD_mn, + k_tile_count, + tidx, + thr_mma, + a_shape, + b_shape, + acc_shape, + accum_atom, + store_atom, + scales_per_k_tile, + 1, + False, + ) + elif warp_idx == 2: + self.mma_blockscaled_64x64_mrow_kloop_store_bf16( + ab_pipeline, + read_state, + tiled_mma, + sA, + sB, + sAStore, + sBStore, + sACompact, + sBCompact, + sSFA, + sSFB, + sAcc, + gD_mn, + k_tile_count, + tidx, + thr_mma, + a_shape, + b_shape, + acc_shape, + accum_atom, + store_atom, + scales_per_k_tile, + 2, + False, + ) + elif warp_idx == 3: + self.mma_blockscaled_64x64_mrow_kloop_store_bf16( + ab_pipeline, + read_state, + tiled_mma, + sA, + sB, + sAStore, + sBStore, + sACompact, + sBCompact, + sSFA, + sSFB, + sAcc, + gD_mn, + k_tile_count, + tidx, + thr_mma, + a_shape, + b_shape, + acc_shape, + accum_atom, + store_atom, + scales_per_k_tile, + 3, + False, + ) + + @cute.jit + def mma_blockscaled_64x64_mrow_kloop_store_bf16( + self, + ab_pipeline: pipeline.PipelineAsync, + read_state: pipeline.PipelineState, + tiled_mma: cute.TiledMma, + sA: cute.Tensor, + sB: cute.Tensor, + sAStore: cute.Tensor, + sBStore: cute.Tensor, + sACompact: cute.Tensor, + sBCompact: cute.Tensor, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + sAcc: cute.Tensor, + gD_mn: cute.Tensor, + k_tile_count: cutlass.Int32, + tidx: cutlass.Int32, + thr_mma: cute.TiledMmaThrVal, + a_shape: cute.Shape, + b_shape: cute.Shape, + acc_shape: cute.Shape, + accum_atom: cute.CopyAtom, + store_atom: cute.CopyAtom, + scales_per_k_tile: cutlass.Constexpr[int], + m_atom: cutlass.Constexpr[int], + releases_pipeline: cutlass.Constexpr[bool], + ) -> None: + acc0 = cute.make_rmem_tensor(acc_shape, cutlass.Float32) + acc1 = cute.make_rmem_tensor(acc_shape, cutlass.Float32) + acc2 = cute.make_rmem_tensor(acc_shape, cutlass.Float32) + acc3 = cute.make_rmem_tensor(acc_shape, cutlass.Float32) + acc4 = cute.make_rmem_tensor(acc_shape, cutlass.Float32) + acc5 = cute.make_rmem_tensor(acc_shape, cutlass.Float32) + acc6 = cute.make_rmem_tensor(acc_shape, cutlass.Float32) + acc7 = cute.make_rmem_tensor(acc_shape, cutlass.Float32) + acc0.fill(0.0) + acc1.fill(0.0) + acc2.fill(0.0) + acc3.fill(0.0) + acc4.fill(0.0) + acc5.fill(0.0) + acc6.fill(0.0) + acc7.fill(0.0) + + for k_tile in cutlass.range(k_tile_count, unroll=1): + if const_expr(releases_pipeline): + peek_ab_full_status = ab_pipeline.consumer_try_wait(read_state) + ab_pipeline.consumer_wait(read_state, peek_ab_full_status) + cute.arch.barrier( + barrier_id=int(NamedBarrierGemm.MmaWG0), + number_of_threads=self.blockscaled_consumer_warps * cute.arch.WARP_SIZE, + ) + if const_expr(not self.blockscaled_packed_ldsm): + _expand_compact_fp4_to_sm120_ldmatrix_smem( + sACompact[None, None, read_state.index], sA[None, None, read_state.index], + self.cta_tile_shape_mnk[0], + self.blockscaled_consumer_warps * cute.arch.WARP_SIZE, + ) + _expand_compact_fp4_to_sm120_ldmatrix_smem( + sBCompact[None, None, read_state.index], sB[None, None, read_state.index], + self.cta_tile_shape_mnk[1], + self.blockscaled_consumer_warps * cute.arch.WARP_SIZE, + ) + cute.arch.fence_view_async_shared() + cute.arch.barrier( + barrier_id=int(NamedBarrierGemm.MmaWG0), + number_of_threads=self.blockscaled_consumer_warps * cute.arch.WARP_SIZE, + ) + + for k_block in cutlass.range_constexpr(self.cta_tile_shape_mnk[2] // 64): + scale_base = k_tile * scales_per_k_tile + k_block * (64 // self.sf_vec_size) + scale_page = scale_base // 16 + scale_page_offset = scale_base - scale_page * 16 + del scale_page + self.mma_blockscaled_tile_k64_accumulate( + tiled_mma, + acc0, + sA, + sB, sSFA, sSFB, read_state.index, scale_page_offset, cutlass.Int32(m_atom * 16), - cutlass.Int32(n_atom * 8), + cutlass.Int32(0), tidx, a_shape, b_shape, + k_block, ) - self.store_blockscaled_accum_smem_atom( - thr_mma, - accum_atom, - acc, - sAcc, - store_atom, - gD_mn, - m_atom, - n_atom, - add_existing, - store_final, + self.mma_blockscaled_tile_k64_accumulate( + tiled_mma, + acc1, + sA, + sB, + sSFA, + sSFB, + read_state.index, + scale_page_offset, + cutlass.Int32(m_atom * 16), + cutlass.Int32(8), + tidx, + a_shape, + b_shape, + k_block, + ) + self.mma_blockscaled_tile_k64_accumulate( + tiled_mma, + acc2, + sA, + sB, + sSFA, + sSFB, + read_state.index, + scale_page_offset, + cutlass.Int32(m_atom * 16), + cutlass.Int32(16), + tidx, + a_shape, + b_shape, + k_block, + ) + self.mma_blockscaled_tile_k64_accumulate( + tiled_mma, + acc3, + sA, + sB, + sSFA, + sSFB, + read_state.index, + scale_page_offset, + cutlass.Int32(m_atom * 16), + cutlass.Int32(24), + tidx, + a_shape, + b_shape, + k_block, + ) + self.mma_blockscaled_tile_k64_accumulate( + tiled_mma, + acc4, + sA, + sB, + sSFA, + sSFB, + read_state.index, + scale_page_offset, + cutlass.Int32(m_atom * 16), + cutlass.Int32(32), + tidx, + a_shape, + b_shape, + k_block, + ) + self.mma_blockscaled_tile_k64_accumulate( + tiled_mma, + acc5, + sA, + sB, + sSFA, + sSFB, + read_state.index, + scale_page_offset, + cutlass.Int32(m_atom * 16), + cutlass.Int32(40), + tidx, + a_shape, + b_shape, + k_block, + ) + self.mma_blockscaled_tile_k64_accumulate( + tiled_mma, + acc6, + sA, + sB, + sSFA, + sSFB, + read_state.index, + scale_page_offset, + cutlass.Int32(m_atom * 16), + cutlass.Int32(48), + tidx, + a_shape, + b_shape, + k_block, + ) + self.mma_blockscaled_tile_k64_accumulate( + tiled_mma, + acc7, + sA, + sB, + sSFA, + sSFB, + read_state.index, + scale_page_offset, + cutlass.Int32(m_atom * 16), + cutlass.Int32(56), + tidx, + a_shape, + b_shape, + k_block, + ) + + cute.arch.fence_view_async_shared() + cute.arch.barrier( + barrier_id=int(NamedBarrierGemm.MmaWG0), + number_of_threads=self.blockscaled_consumer_warps * cute.arch.WARP_SIZE, + ) + if const_expr(releases_pipeline): + ab_pipeline.consumer_release(read_state) + read_state.advance() + + self.store_blockscaled_accum_direct_atom(thr_mma, store_atom, acc0, gD_mn, m_atom, 0) + self.store_blockscaled_accum_direct_atom(thr_mma, store_atom, acc1, gD_mn, m_atom, 1) + self.store_blockscaled_accum_direct_atom(thr_mma, store_atom, acc2, gD_mn, m_atom, 2) + self.store_blockscaled_accum_direct_atom(thr_mma, store_atom, acc3, gD_mn, m_atom, 3) + self.store_blockscaled_accum_direct_atom(thr_mma, store_atom, acc4, gD_mn, m_atom, 4) + self.store_blockscaled_accum_direct_atom(thr_mma, store_atom, acc5, gD_mn, m_atom, 5) + self.store_blockscaled_accum_direct_atom(thr_mma, store_atom, acc6, gD_mn, m_atom, 6) + self.store_blockscaled_accum_direct_atom(thr_mma, store_atom, acc7, gD_mn, m_atom, 7) + + @cute.jit + def mma_blockscaled_one_k_tile_accumulate_smem( + self, + ab_pipeline: pipeline.PipelineAsync, + read_state: pipeline.PipelineState, + tiled_mma: cute.TiledMma, + sA: cute.Tensor, + sB: cute.Tensor, + sAStore: cute.Tensor, + sBStore: cute.Tensor, + sACompact: cute.Tensor, + sBCompact: cute.Tensor, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + sAcc: cute.Tensor, + gD_mn: cute.Tensor, + tidx: cutlass.Int32, + thr_mma: cute.TiledMmaThrVal, + a_shape: cute.Shape, + b_shape: cute.Shape, + acc_shape: cute.Shape, + accum_atom: cute.CopyAtom, + store_atom: cute.CopyAtom, + scales_per_k_tile: cutlass.Constexpr[int], + k_tile: cutlass.Int32, + add_existing: cutlass.Constexpr[bool], + store_final: cutlass.Constexpr[bool], + warp_idx: cutlass.Int32, + ) -> pipeline.PipelineState: + if warp_idx == 0: + peek_ab_full_status = ab_pipeline.consumer_try_wait(read_state) + ab_pipeline.consumer_wait(read_state, peek_ab_full_status) + if const_expr(self.blockscaled_consumer_warps > 1): + cute.arch.barrier( + barrier_id=int(NamedBarrierGemm.MmaWG0), + number_of_threads=self.blockscaled_consumer_warps * cute.arch.WARP_SIZE, + ) + if warp_idx == 0: + if const_expr(self.blockscaled_packed_ldsm): + pass + else: + _expand_compact_fp4_to_sm120_ldmatrix_smem( + sACompact[None, None, read_state.index], + sA[None, None, read_state.index], + self.cta_tile_shape_mnk[0], + ) + _expand_compact_fp4_to_sm120_ldmatrix_smem( + sBCompact[None, None, read_state.index], + sB[None, None, read_state.index], + self.cta_tile_shape_mnk[1], ) cute.arch.fence_view_async_shared() - cute.arch.sync_warp() - ab_pipeline.consumer_release(read_state) + if const_expr(self.blockscaled_consumer_warps > 1): + cute.arch.barrier( + barrier_id=int(NamedBarrierGemm.MmaWG0), + number_of_threads=self.blockscaled_consumer_warps * cute.arch.WARP_SIZE, + ) + else: + cute.arch.sync_warp() + scale_base = k_tile * scales_per_k_tile + scale_page = scale_base // 16 + scale_page_offset = scale_base - scale_page * 16 + for m_atom in cutlass.range_constexpr(self.cta_tile_shape_mnk[0] // 16): + for n_atom in cutlass.range_constexpr(self.cta_tile_shape_mnk[1] // 8): + atom_owner = (m_atom * (self.cta_tile_shape_mnk[1] // 8) + n_atom) % ( + self.blockscaled_consumer_warps + ) + if warp_idx == atom_owner: + acc = cute.make_rmem_tensor(acc_shape, cutlass.Float32) + acc.fill(0.0) + self.mma_blockscaled_tile_k64_accumulate( + tiled_mma, + acc, + sA, + sB, + sSFA, + sSFB, + read_state.index, + scale_page_offset, + cutlass.Int32(m_atom * 16), + cutlass.Int32(n_atom * 8), + tidx, + a_shape, + b_shape, + 0, + ) + self.store_blockscaled_accum_smem_atom( + thr_mma, + accum_atom, + acc, + sAcc, + store_atom, + gD_mn, + m_atom, + n_atom, + add_existing, + store_final, + ) + cute.arch.fence_view_async_shared() + if const_expr(self.blockscaled_consumer_warps > 1): + cute.arch.barrier( + barrier_id=int(NamedBarrierGemm.MmaWG0), + number_of_threads=self.blockscaled_consumer_warps * cute.arch.WARP_SIZE, + ) + else: + cute.arch.sync_warp() + if warp_idx == 0: + ab_pipeline.consumer_release(read_state) read_state.advance() return read_state @@ -1104,8 +1843,8 @@ def mma_blockscaled( self, tiled_mma: cute.TiledMma, acc: cute.Tensor, - sA_stage: cute.Tensor, - sB_stage: cute.Tensor, + sA: cute.Tensor, + sB: cute.Tensor, sSFA: cute.Tensor, sSFB: cute.Tensor, stage: cutlass.Int32, @@ -1115,36 +1854,60 @@ def mma_blockscaled( tidx: cutlass.Int32, a_shape: cute.Shape, b_shape: cute.Shape, + k_block: cutlass.Constexpr[int], ) -> None: - sA_atom = cute.domain_offset((m_atom_base, None), sA_stage) - sB_atom = cute.domain_offset((n_atom_base, None), sB_stage) - copy_atom = cute.make_copy_atom( - warp.LdMatrix8x16x8bOp(num_matrices=1, unpack_bits=4), - cutlass.Int8, - ) - tiled_copy_a = cute.make_tiled_copy_A(copy_atom, tiled_mma) - tiled_copy_b = cute.make_tiled_copy_B(copy_atom, tiled_mma) - a0, a1, a2, a3 = warp.sm120_mxf4nvf4_ldmatrix_A_regs( - tiled_copy_a, - tidx, - _make_sm120_fp4_ldmatrix_smem_view(sA_atom, 16), - ) - b0, b1 = warp.sm120_mxf4nvf4_ldmatrix_B_regs( - tiled_copy_b, - tidx, - _make_sm120_fp4_ldmatrix_smem_view(sB_atom, 8), - ) a = cute.make_rmem_tensor(a_shape, cutlass.Float4E2M1FN) b = cute.make_rmem_tensor(b_shape, cutlass.Float4E2M1FN) - a_i32 = cute.recast_tensor(a, cutlass.Int32) - b_i32 = cute.recast_tensor(b, cutlass.Int32) - # The asymmetric SM120 blockscaled tests catch this placement. - a_i32[0] = a0 - a_i32[1] = a2 - a_i32[2] = a1 - a_i32[3] = a3 - b_i32[0] = b0 - b_i32[1] = b1 + if const_expr(self.blockscaled_packed_ldsm): + # Clear through the byte view: this packed-subbyte fragment layout + # is not fully covered by the Int32 view on all CuTe DSL builds. + cute.recast_tensor(a, cutlass.Uint8).fill(0) + cute.recast_tensor(b, cutlass.Uint8).fill(0) + copy_atom = cute.make_copy_atom( + warp.LdMatrix8x8x16bOp(num_matrices=4), + cutlass.Int4, + ) + tiled_copy_a = cute.make_tiled_copy_A(copy_atom, tiled_mma) + tiled_copy_b = cute.make_tiled_copy_B(copy_atom, tiled_mma) + lane = cute.arch.lane_idx() + thr_copy_a = tiled_copy_a.get_slice(lane) + thr_copy_b = tiled_copy_b.get_slice(lane) + a_atom = m_atom_base // 16 + b_atom = n_atom_base // 8 + tCsA = thr_copy_a.partition_S(sA[None, None, a_atom, stage]) + tCsB = thr_copy_b.partition_S(sB[None, None, b_atom, stage]) + a_copy_view = thr_copy_a.retile(a) + b_copy_view = thr_copy_b.retile(b) + cute.copy(tiled_copy_a, tCsA[None, None, k_block], a_copy_view[None, None, 0]) + cute.copy(tiled_copy_b, tCsB[None, None, k_block], b_copy_view[None, None, 0]) + else: + sA_atom = cute.domain_offset((m_atom_base, None), sA[None, None, stage]) + sB_atom = cute.domain_offset((n_atom_base, None), sB[None, None, stage]) + copy_atom = cute.make_copy_atom( + warp.LdMatrix8x16x8bOp(num_matrices=1, unpack_bits=4), + cutlass.Int8, + ) + tiled_copy_a = cute.make_tiled_copy_A(copy_atom, tiled_mma) + tiled_copy_b = cute.make_tiled_copy_B(copy_atom, tiled_mma) + a0, a1, a2, a3 = warp.sm120_mxf4nvf4_ldmatrix_A_regs( + tiled_copy_a, + tidx, + _make_sm120_fp4_ldmatrix_smem_view(sA_atom, 16), + ) + b0, b1 = warp.sm120_mxf4nvf4_ldmatrix_B_regs( + tiled_copy_b, + tidx, + _make_sm120_fp4_ldmatrix_smem_view(sB_atom, 8), + ) + a_i32 = cute.recast_tensor(a, cutlass.Int32) + b_i32 = cute.recast_tensor(b, cutlass.Int32) + # The asymmetric SM120 blockscaled tests catch this placement. + a_i32[0] = a0 + a_i32[1] = a2 + a_i32[2] = a1 + a_i32[3] = a3 + b_i32[0] = b0 + b_i32[1] = b1 sfa, sfb = _load_sm120_blockscaled_selector0_scale_fragments( sSFA, sSFB, @@ -1155,7 +1918,10 @@ def mma_blockscaled( self.sf_vec_size, self.sf_dtype, ) - cute.gemm(tiled_mma, acc, (a, sfa), (b, sfb), acc) + if const_expr(self.blockscaled_packed_ldsm): + cute.gemm(tiled_mma, acc, (a[None, None, 0], sfa), (b[None, None, 0], sfb), acc) + else: + cute.gemm(tiled_mma, acc, (a, sfa), (b, sfb), acc) @cute.jit def mma_blockscaled_tile_k64_accumulate( @@ -1173,6 +1939,7 @@ def mma_blockscaled_tile_k64_accumulate( tidx: cutlass.Int32, a_shape: cute.Shape, b_shape: cute.Shape, + k_block: cutlass.Constexpr[int], ) -> None: del tiled_mma if const_expr(self.sf_vec_size == 16): @@ -1194,6 +1961,7 @@ def mma_blockscaled_tile_k64_accumulate( tidx, a_shape, b_shape, + k_block, ) @cute.jit @@ -1225,6 +1993,22 @@ def store_blockscaled_accum_smem_atom( else: cute.copy(accum_atom, acc, tCsAcc) + @cute.jit + def store_blockscaled_accum_direct_atom( + self, + thr_mma: cute.TiledMmaThrVal, + store_atom: cute.CopyAtom, + acc: cute.Tensor, + mD_mn: cute.Tensor, + m_atom: cutlass.Constexpr[int], + n_atom: cutlass.Constexpr[int], + ) -> None: + gD_atom = cute.local_tile(mD_mn, (16, 8), (m_atom, n_atom)) + tCgD = thr_mma.partition_C(gD_atom) + tCrD = cute.make_rmem_tensor(acc.layout, mD_mn.element_type) + tCrD.store(acc.load().to(mD_mn.element_type)) + cute.copy(store_atom, tCrD, tCgD) + @cute.kernel def kernel( self, diff --git a/tests/test_gemm_blockscaled.py b/tests/test_gemm_blockscaled.py index 6822002e..4f086422 100644 --- a/tests/test_gemm_blockscaled.py +++ b/tests/test_gemm_blockscaled.py @@ -1,3 +1,5 @@ +from pathlib import Path + import pytest import torch @@ -174,7 +176,7 @@ def test_blockscaled_validation(): ) -def test_sm120_blockscaled_validation(): +def test_sm120_blockscaled_validation(monkeypatch): assert GemmDefaultSm120.can_implement_blockscaled( cutlass.Float4E2M1FN, cutlass.Float8E4M3FN, @@ -190,6 +192,84 @@ def test_sm120_blockscaled_validation(): "k", "n", ) + assert not GemmDefaultSm120.can_implement_blockscaled( + cutlass.Float4E2M1FN, + cutlass.Float8E4M3FN, + 16, + cutlass.BFloat16, + (64, 64, 128), + (1, 1), + 128, + 128, + 128, + 1, + "k", + "k", + "n", + ) + monkeypatch.setenv("QUACK_SM120_BLOCKSCALED_PACKED_LDSM", "1") + assert GemmDefaultSm120.can_implement_blockscaled( + cutlass.Float4E2M1FN, + cutlass.Float8E4M3FN, + 16, + cutlass.BFloat16, + (64, 64, 64), + (1, 1), + 128, + 128, + 64, + 1, + "k", + "k", + "n", + ) + assert GemmDefaultSm120.can_implement_blockscaled( + cutlass.Float4E2M1FN, + cutlass.Float8E4M3FN, + 16, + cutlass.BFloat16, + (64, 64, 128), + (1, 1), + 128, + 128, + 128, + 1, + "k", + "k", + "n", + ) + for tile_shape in ((128, 128, 128), (64, 128, 128), (128, 64, 128)): + assert not GemmDefaultSm120.can_implement_blockscaled( + cutlass.Float4E2M1FN, + cutlass.Float8E4M3FN, + 16, + cutlass.BFloat16, + tile_shape, + (1, 1), + 256, + 256, + 128, + 1, + "k", + "k", + "n", + ) + for tile_shape in ((64, 128, 64), (128, 64, 64)): + assert not GemmDefaultSm120.can_implement_blockscaled( + cutlass.Float4E2M1FN, + cutlass.Float8E4M3FN, + 16, + cutlass.BFloat16, + tile_shape, + (1, 1), + 256, + 256, + 64, + 1, + "k", + "k", + "n", + ) assert GemmDefaultSm120.can_implement_blockscaled( cutlass.Float4E2M1FN, cutlass.Float8E8M0FNU, @@ -205,6 +285,21 @@ def test_sm120_blockscaled_validation(): "k", "n", ) + assert not GemmDefaultSm120.can_implement_blockscaled( + cutlass.Float4E2M1FN, + cutlass.Float8E4M3FN, + 16, + cutlass.BFloat16, + (32, 64, 64), + (1, 1), + 128, + 128, + 64, + 1, + "k", + "k", + "n", + ) assert not GemmDefaultSm120.can_implement_blockscaled( cutlass.Float8E4M3FN, cutlass.Float8E8M0FNU, @@ -401,6 +496,78 @@ def test_sm120_blockscaled_class_call_validation(): ) +@pytest.mark.parametrize("tile_shape", [(64, 128, 64), (128, 64, 64)]) +def test_sm120_blockscaled_class_call_rejects_unadvertised_tile_shapes(tile_shape): + m = n = 128 + k = 64 + l = 1 + gemm = GemmDefaultSm120( + cutlass.Float32, + cutlass.Float4E2M1FN, + tile_shape, + (1, 1, 1), + is_persistent=False, + sf_vec_size=16, + sf_dtype=cutlass.Float8E4M3FN, + ) + mA = fake_tensor(cutlass.Float4E2M1FN, (m, k, l), leading_dim=1, divisibility=4) + mB = fake_tensor(cutlass.Float4E2M1FN, (n, k, l), leading_dim=1, divisibility=4) + mD = fake_tensor(cutlass.BFloat16, (m, n, l), leading_dim=1, divisibility=8) + mSFA = fake_tensor(cutlass.Float8E4M3FN, (m, 16, l), leading_dim=1, divisibility=4) + mSFB = fake_tensor(cutlass.Float8E4M3FN, (n, 16, l), leading_dim=1, divisibility=4) + + with pytest.raises(NotImplementedError, match="supports tile shapes"): + gemm._validate_blockscaled_call( + mA, + mB, + mD, + None, + mSFA, + mSFB, + gemm.EpilogueArguments(), + None, + None, + None, + ) + + +def test_sm120_blockscaled_packed_env_keeps_128x128x64_class_call(monkeypatch): + monkeypatch.setenv("QUACK_SM120_BLOCKSCALED_PACKED_LDSM", "1") + m = n = 128 + k = 64 + l = 1 + gemm = GemmDefaultSm120( + cutlass.Float32, + cutlass.Float4E2M1FN, + (128, 128, 64), + (1, 1, 1), + is_persistent=False, + sf_vec_size=16, + sf_dtype=cutlass.Float8E4M3FN, + ) + mA = fake_tensor(cutlass.Float4E2M1FN, (m, k, l), leading_dim=1, divisibility=4) + mB = fake_tensor(cutlass.Float4E2M1FN, (n, k, l), leading_dim=1, divisibility=4) + mD = fake_tensor(cutlass.BFloat16, (m, n, l), leading_dim=1, divisibility=8) + mSFA = fake_tensor(cutlass.Float8E4M3FN, (m, 16, l), leading_dim=1, divisibility=4) + mSFB = fake_tensor(cutlass.Float8E4M3FN, (n, 16, l), leading_dim=1, divisibility=4) + + assert ( + gemm._validate_blockscaled_call( + mA, + mB, + mD, + None, + mSFA, + mSFB, + gemm.EpilogueArguments(), + None, + None, + None, + ) + == VarlenArguments() + ) + + @pytest.mark.parametrize( "ab_dtype,sf_dtype,sf_vec_size,d_dtype,mma_tiler_mn,cluster_shape_mn,m,n,k,l", [ @@ -728,7 +895,18 @@ def _make_sm120_scales(mn, k, sf_vec_size, sf_dtype, row_or_col_sensitive=True): return scales -def _compile_sm120_blockscaled_gemm(ab_dtype, sf_dtype, sf_vec_size, m, n, k, mA, mB): +def _compile_sm120_blockscaled_gemm( + ab_dtype, + sf_dtype, + sf_vec_size, + m, + n, + k, + mA, + mB, + tile_shape_mn=(128, 128), + compile_options="--enable-tvm-ffi", +): l = 1 _, mD = create_blockscaled_operand_tensor(l, m, n, False, cutlass.BFloat16, init="empty") mSFA = _make_sm120_scales(m, k, sf_vec_size, sf_dtype) @@ -738,17 +916,31 @@ def _compile_sm120_blockscaled_gemm(ab_dtype, sf_dtype, sf_vec_size, m, n, k, mA sf_dtype, sf_vec_size, cutlass.BFloat16, - (128, 128), + tile_shape_mn, (1, 1), mA, mB, mD, mSFA, mSFB, + compile_options=compile_options, ) return compiled, (mA, mB, mD, mSFA, mSFB) +def _compiled_ptx_text(compiled) -> str: + ptx = getattr(compiled, "__ptx__", None) + if isinstance(ptx, bytes): + return ptx.decode("utf-8", errors="replace") + if isinstance(ptx, str): + if "\n" not in ptx and len(ptx) < 4096: + path = Path(ptx) + if path.exists(): + return path.read_text(errors="replace") + return ptx + raise AssertionError("compiled kernel did not expose PTX") + + @pytest.mark.parametrize( "sf_dtype,sf_vec_size,m,n,k", [ @@ -777,6 +969,36 @@ def test_sm120_blockscaled_scale_correctness(sf_dtype, sf_vec_size, m, n, k): assert err < 1e-1, f"max_err={err}" +def test_sm120_blockscaled_scale_correctness_64x64_tile(): + _skip_if_not_sm120() + m = n = 128 + k = 128 + sf_vec_size = 16 + sf_dtype = cutlass.Float8E4M3FN + a_codes = torch.full((m, k), 0x2, device="cuda", dtype=torch.uint8) + b_codes = torch.full((n, k), 0x2, device="cuda", dtype=torch.uint8) + mA = _pack_sm120_fp4_codes(a_codes) + mB = _pack_sm120_fp4_codes(b_codes) + compiled, args = _compile_sm120_blockscaled_gemm( + cutlass.Float4E2M1FN, + sf_dtype, + sf_vec_size, + m, + n, + k, + mA, + mB, + tile_shape_mn=(64, 64), + ) + _run_blockscaled_gemm(compiled, args) + + _, _, d_torch, mSFA, mSFB = args + ref = _sm120_fp4_blockscaled_reference(a_codes, b_codes, mSFA, mSFB, sf_vec_size).to( + torch.bfloat16 + ) + torch.testing.assert_close(d_torch.float(), ref.float(), rtol=0, atol=0) + + def test_sm120_blockscaled_k_loop_accumulates_before_bf16_store(): _skip_if_not_sm120() m = n = 128 @@ -845,6 +1067,331 @@ def test_sm120_blockscaled_asymmetric_fp4_and_scale_page_crossing(): assert err < 1e-1, f"max_err={err}" +def test_sm120_blockscaled_packed_ldsm_scale_correctness_64x64_tile(monkeypatch): + _skip_if_not_sm120() + monkeypatch.setenv("QUACK_SM120_BLOCKSCALED_PACKED_LDSM", "1") + m = n = 128 + k = 128 + sf_vec_size = 16 + sf_dtype = cutlass.Float8E4M3FN + a_codes = torch.full((m, k), 0x2, device="cuda", dtype=torch.uint8) + b_codes = torch.full((n, k), 0x2, device="cuda", dtype=torch.uint8) + mA = _pack_sm120_fp4_codes(a_codes) + mB = _pack_sm120_fp4_codes(b_codes) + compiled, args = _compile_sm120_blockscaled_gemm( + cutlass.Float4E2M1FN, + sf_dtype, + sf_vec_size, + m, + n, + k, + mA, + mB, + tile_shape_mn=(64, 64), + ) + _run_blockscaled_gemm(compiled, args) + + _, _, d_torch, mSFA, mSFB = args + ref = _sm120_fp4_blockscaled_reference(a_codes, b_codes, mSFA, mSFB, sf_vec_size).to( + torch.bfloat16 + ) + torch.testing.assert_close(d_torch.float(), ref.float(), rtol=0, atol=0) + + +def test_sm120_blockscaled_packed_ldsm_asymmetric_fp4_and_scale_page_crossing(monkeypatch): + _skip_if_not_sm120() + monkeypatch.setenv("QUACK_SM120_BLOCKSCALED_PACKED_LDSM", "1") + m = n = 128 + k = 320 + sf_vec_size = 16 + sf_dtype = cutlass.Float8E4M3FN + ks = torch.arange(k, device="cuda")[None, :] + a_codes = torch.where( + (ks % 4) < 2, + torch.tensor(0x2, device="cuda", dtype=torch.uint8), + torch.tensor(0x4, device="cuda", dtype=torch.uint8), + ).expand(m, k) + b_codes = torch.where( + (ks % 8) < 4, + torch.tensor(0x3, device="cuda", dtype=torch.uint8), + torch.tensor(0x5, device="cuda", dtype=torch.uint8), + ).expand(n, k) + mA = _pack_sm120_fp4_codes(a_codes) + mB = _pack_sm120_fp4_codes(b_codes) + compiled, args = _compile_sm120_blockscaled_gemm( + cutlass.Float4E2M1FN, + sf_dtype, + sf_vec_size, + m, + n, + k, + mA, + mB, + tile_shape_mn=(64, 64), + ) + _run_blockscaled_gemm(compiled, args) + + _, _, d_torch, mSFA, mSFB = args + ref = _sm120_fp4_blockscaled_reference(a_codes, b_codes, mSFA, mSFB, sf_vec_size).to( + torch.bfloat16 + ) + torch.testing.assert_close(d_torch.float(), ref.float(), rtol=0, atol=0) + + +def test_sm120_blockscaled_packed_ldsm_ptx_regression(monkeypatch): + _skip_if_not_sm120() + monkeypatch.setenv("QUACK_SM120_BLOCKSCALED_PACKED_LDSM", "1") + m = n = 128 + k = 128 + sf_vec_size = 16 + sf_dtype = cutlass.Float8E4M3FN + a_codes = torch.full((m, k), 0x2, device="cuda", dtype=torch.uint8) + b_codes = torch.full((n, k), 0x2, device="cuda", dtype=torch.uint8) + compiled, _ = _compile_sm120_blockscaled_gemm( + cutlass.Float4E2M1FN, + sf_dtype, + sf_vec_size, + m, + n, + k, + _pack_sm120_fp4_codes(a_codes), + _pack_sm120_fp4_codes(b_codes), + tile_shape_mn=(64, 64), + compile_options="--enable-tvm-ffi --keep-ptx", + ) + ptx = _compiled_ptx_text(compiled) + assert "ldmatrix.sync.aligned.m8n8.x4.shared.b16" in ptx + assert "mma.sync.aligned" in ptx + assert "m16n8k64" in ptx + assert "kind::mxf4nvf4" in ptx + assert "cp.async.bulk.tensor.2d.shared::cta.global.tile" in ptx + assert "b4x16_p64" not in ptx + assert "ldmatrix.sync.aligned.m8n16" not in ptx + assert ".multicast" not in ptx + assert "shared::cluster" not in ptx + + +def test_sm120_blockscaled_packed_ldsm_scale_correctness_64x64x128_tile(monkeypatch): + _skip_if_not_sm120() + monkeypatch.setenv("QUACK_SM120_BLOCKSCALED_PACKED_LDSM", "1") + m = n = 128 + k = 128 + sf_vec_size = 16 + sf_dtype = cutlass.Float8E4M3FN + a_codes = torch.full((m, k), 0x2, device="cuda", dtype=torch.uint8) + b_codes = torch.full((n, k), 0x2, device="cuda", dtype=torch.uint8) + mA = _pack_sm120_fp4_codes(a_codes) + mB = _pack_sm120_fp4_codes(b_codes) + compiled, args = _compile_sm120_blockscaled_gemm( + cutlass.Float4E2M1FN, + sf_dtype, + sf_vec_size, + m, + n, + k, + mA, + mB, + tile_shape_mn=(64, 64, 128), + ) + _run_blockscaled_gemm(compiled, args) + + _, _, d_torch, mSFA, mSFB = args + ref = _sm120_fp4_blockscaled_reference(a_codes, b_codes, mSFA, mSFB, sf_vec_size).to( + torch.bfloat16 + ) + torch.testing.assert_close(d_torch.float(), ref.float(), rtol=0, atol=0) + + +def test_sm120_blockscaled_packed_ldsm_k128_uses_second_scale_offset(monkeypatch): + _skip_if_not_sm120() + monkeypatch.setenv("QUACK_SM120_BLOCKSCALED_PACKED_LDSM", "1") + m = n = 128 + k = 128 + sf_vec_size = 16 + sf_dtype = cutlass.Float8E4M3FN + a_codes = torch.full((m, k), 0x2, device="cuda", dtype=torch.uint8) + b_codes = torch.full((n, k), 0x2, device="cuda", dtype=torch.uint8) + mA = _pack_sm120_fp4_codes(a_codes) + mB = _pack_sm120_fp4_codes(b_codes) + compiled, args = _compile_sm120_blockscaled_gemm( + cutlass.Float4E2M1FN, + sf_dtype, + sf_vec_size, + m, + n, + k, + mA, + mB, + tile_shape_mn=(64, 64, 128), + ) + _, _, _, mSFA, mSFB = args + mSFA[:, :4, 0] = torch.tensor(1.0, device="cuda", dtype=mSFA.dtype) + mSFA[:, 4:8, 0] = torch.tensor(2.0, device="cuda", dtype=mSFA.dtype) + mSFB[:, :8, 0] = torch.tensor(1.0, device="cuda", dtype=mSFB.dtype) + _run_blockscaled_gemm(compiled, args) + + _, _, d_torch, mSFA, mSFB = args + ref = _sm120_fp4_blockscaled_reference(a_codes, b_codes, mSFA, mSFB, sf_vec_size).to( + torch.bfloat16 + ) + torch.testing.assert_close(d_torch.float(), ref.float(), rtol=0, atol=0) + assert torch.all(d_torch.float() == torch.tensor(192.0, device="cuda")) + + +@pytest.mark.parametrize( + "tile_shape_mn,k", + [ + ((64, 64), 128), + ((64, 64, 128), 640), + ], +) +def test_sm120_blockscaled_packed_ldsm_mxfp4_scale_correctness(monkeypatch, tile_shape_mn, k): + _skip_if_not_sm120() + monkeypatch.setenv("QUACK_SM120_BLOCKSCALED_PACKED_LDSM", "1") + m = n = 128 + sf_vec_size = 32 + sf_dtype = cutlass.Float8E8M0FNU + rows = torch.arange(m, device="cuda")[:, None] + cols = torch.arange(n, device="cuda")[:, None] + ks = torch.arange(k, device="cuda")[None, :] + a_codes = torch.where( + ((rows + ks) % 4) < 2, + torch.tensor(0x2, device="cuda", dtype=torch.uint8), + torch.tensor(0x4, device="cuda", dtype=torch.uint8), + ) + b_codes = torch.where( + ((cols + ks * 3) % 8) < 4, + torch.tensor(0x3, device="cuda", dtype=torch.uint8), + torch.tensor(0x5, device="cuda", dtype=torch.uint8), + ) + mA = _pack_sm120_fp4_codes(a_codes) + mB = _pack_sm120_fp4_codes(b_codes) + compiled, args = _compile_sm120_blockscaled_gemm( + cutlass.Float4E2M1FN, + sf_dtype, + sf_vec_size, + m, + n, + k, + mA, + mB, + tile_shape_mn=tile_shape_mn, + ) + _run_blockscaled_gemm(compiled, args) + + _, _, d_torch, mSFA, mSFB = args + logical_cols = k // sf_vec_size + if k == 640: + assert mSFA.shape[1] == 32 + assert torch.any(mSFA[:, logical_cols:, :].float() != 1.0) + assert torch.any(mSFB[:, logical_cols:, :].float() != 1.0) + ref = _sm120_fp4_blockscaled_reference(a_codes, b_codes, mSFA, mSFB, sf_vec_size).to( + torch.bfloat16 + ) + torch.testing.assert_close(d_torch.float(), ref.float(), rtol=0, atol=0) + + +def test_sm120_blockscaled_packed_ldsm_asymmetric_fp4_k128_page_crossing(monkeypatch): + _skip_if_not_sm120() + monkeypatch.setenv("QUACK_SM120_BLOCKSCALED_PACKED_LDSM", "1") + m = n = 128 + k = 384 + sf_vec_size = 16 + sf_dtype = cutlass.Float8E4M3FN + rows = torch.arange(m, device="cuda")[:, None] + cols = torch.arange(n, device="cuda")[:, None] + ks = torch.arange(k, device="cuda")[None, :] + a_codes = torch.where( + ((rows + ks) % 4) < 2, + torch.tensor(0x2, device="cuda", dtype=torch.uint8), + torch.tensor(0x4, device="cuda", dtype=torch.uint8), + ) + b_codes = torch.where( + ((cols * 3 + ks) % 8) < 4, + torch.tensor(0x3, device="cuda", dtype=torch.uint8), + torch.tensor(0x5, device="cuda", dtype=torch.uint8), + ) + mA = _pack_sm120_fp4_codes(a_codes) + mB = _pack_sm120_fp4_codes(b_codes) + compiled, args = _compile_sm120_blockscaled_gemm( + cutlass.Float4E2M1FN, + sf_dtype, + sf_vec_size, + m, + n, + k, + mA, + mB, + tile_shape_mn=(64, 64, 128), + ) + _run_blockscaled_gemm(compiled, args) + + _, _, d_torch, mSFA, mSFB = args + logical_cols = k // sf_vec_size + assert mSFA.shape[1] == 32 + assert torch.any(mSFA[:, logical_cols:, :].float() != 1.0) + ref = _sm120_fp4_blockscaled_reference(a_codes, b_codes, mSFA, mSFB, sf_vec_size).to( + torch.bfloat16 + ) + torch.testing.assert_close(d_torch.float(), ref.float(), rtol=0, atol=0) + + +def test_sm120_blockscaled_packed_ldsm_ptx_regression_64x64x128(monkeypatch): + _skip_if_not_sm120() + monkeypatch.setenv("QUACK_SM120_BLOCKSCALED_PACKED_LDSM", "1") + m = n = 128 + k = 128 + sf_vec_size = 16 + sf_dtype = cutlass.Float8E4M3FN + a_codes = torch.full((m, k), 0x2, device="cuda", dtype=torch.uint8) + b_codes = torch.full((n, k), 0x2, device="cuda", dtype=torch.uint8) + compiled, _ = _compile_sm120_blockscaled_gemm( + cutlass.Float4E2M1FN, + sf_dtype, + sf_vec_size, + m, + n, + k, + _pack_sm120_fp4_codes(a_codes), + _pack_sm120_fp4_codes(b_codes), + tile_shape_mn=(64, 64, 128), + compile_options="--enable-tvm-ffi --keep-ptx", + ) + ptx = _compiled_ptx_text(compiled) + assert "ldmatrix.sync.aligned.m8n8.x4.shared.b16" in ptx + assert "mma.sync.aligned" in ptx + assert "m16n8k64" in ptx + assert "kind::mxf4nvf4" in ptx + assert "cp.async.bulk.tensor.2d.shared::cta.global.tile" in ptx + assert "b4x16_p64" not in ptx + assert "ldmatrix.sync.aligned.m8n16" not in ptx + assert ".multicast" not in ptx + assert "shared::cluster" not in ptx + + +def test_sm120_blockscaled_tile_k128_requires_packed_ldsm(monkeypatch): + _skip_if_not_sm120() + monkeypatch.delenv("QUACK_SM120_BLOCKSCALED_PACKED_LDSM", raising=False) + m = n = 128 + k = 128 + sf_vec_size = 16 + sf_dtype = cutlass.Float8E4M3FN + a_codes = torch.full((m, k), 0x2, device="cuda", dtype=torch.uint8) + b_codes = torch.full((n, k), 0x2, device="cuda", dtype=torch.uint8) + with pytest.raises(NotImplementedError, match="tile_K=128"): + _compile_sm120_blockscaled_gemm( + cutlass.Float4E2M1FN, + sf_dtype, + sf_vec_size, + m, + n, + k, + _pack_sm120_fp4_codes(a_codes), + _pack_sm120_fp4_codes(b_codes), + tile_shape_mn=(64, 64, 128), + ) + + def test_sm120_blockscaled_rejects_compact_scale_layout(): _skip_if_not_sm120() l, m, n, k, sf_vec_size = 1, 128, 128, 64, 16 From a1696b5c55d6bde07d32f8b26910aed720121eeb Mon Sep 17 00:00:00 2001 From: agent Date: Fri, 1 May 2026 16:04:53 +0200 Subject: [PATCH 4/4] [Docs,Sm120] explain blockscaled packed LDSM path Document the narrow performance path added for SM120 blockscaled GEMM: opt-in packed LDSM, supported FP4 blockscaled formats, the recommended correctness gate, the benchmark command, and the current 64x64x128 tile target. The note records the reviewer-relevant experiment outcome without carrying experimental code into the PR. The previous correctness-first path was useful for proving tuple MMA and scale handling, but it used the b4x16_p64 unpack route and was shared-load bound. This PR follows the CUTLASS 79a GeForce NVFP4 example's packed LDSM direction while keeping QuACK's smaller per-atom TMA producer. Full-tile/grouped TMA remains deliberately out of scope because local CuTe DSL layout-lowering experiments either failed legalization, generated many tiny TMA sites, hung, or timed out. That work belongs in a separate minimal repro/upstream track. Update the benchmark module docstring with the SM120 packed-LDSM NVFP4 command and make the benchmark CLI report expected configuration errors without a Python traceback. The benchmark wording stays conservative: numbers are local RTX 5060 workstation timing signals with reference checking skipped, while numerical validation is covered by pytest. --- benchmarks/benchmark_gemm.py | 25 ++++++- docs/sm120_blockscaled_perf.md | 116 +++++++++++++++++++++++++++++++++ 2 files changed, 140 insertions(+), 1 deletion(-) create mode 100644 docs/sm120_blockscaled_perf.md diff --git a/benchmarks/benchmark_gemm.py b/benchmarks/benchmark_gemm.py index 7a00e37d..ffb9acd7 100644 --- a/benchmarks/benchmark_gemm.py +++ b/benchmarks/benchmark_gemm.py @@ -1,4 +1,5 @@ import argparse +import os import time import torch @@ -29,6 +30,13 @@ python benchmarks/benchmark_gemm.py --mnkl 4096,4096,4096,1 \ --ab_dtype Float4E2M1FN --sf_dtype Float8E4M3FN \ --sf_vec_size 16 --d_dtype BFloat16 + +Usage (SM120 packed-LDSM NVFP4 performance path): + QUACK_SM120_BLOCKSCALED_PACKED_LDSM=1 CUTE_DSL_ARCH=sm_120a \ + python benchmarks/benchmark_gemm.py --mnkl 4096,4096,4096,1 \ + --tile_shape_mnk 64,64,128 --cluster_shape_mnk 1,1,1 \ + --ab_dtype Float4E2M1FN --sf_dtype Float8E4M3FN \ + --sf_vec_size 16 --d_dtype BFloat16 --skip_ref_check """ @@ -261,6 +269,18 @@ def _run_blockscaled(args): mma_tiler_for_validation = ( tuple(mma_tiler_mnk) if len(mma_tiler_mnk) == 3 or sm_major != 12 else (*mma_tiler_mnk, 64) ) + if ( + sm_major == 12 + and len(mma_tiler_for_validation) == 3 + and mma_tiler_for_validation[2] == 128 + and os.environ.get("QUACK_SM120_BLOCKSCALED_PACKED_LDSM") != "1" + ): + raise NotImplementedError( + "SM120 blockscaled tile_K=128 requires the packed ldmatrix path. " + "Set QUACK_SM120_BLOCKSCALED_PACKED_LDSM=1, for example:\n" + " QUACK_SM120_BLOCKSCALED_PACKED_LDSM=1 CUTE_DSL_ARCH=sm_120a " + "python benchmarks/benchmark_gemm.py --tile_shape_mnk 64,64,128 ..." + ) if not GemmBlockscaledCls.can_implement_blockscaled( ab_dtype, sf_dtype, @@ -659,5 +679,8 @@ def fn(): if __name__ == "__main__": args = parse_arguments() - run(args) + try: + run(args) + except (NotImplementedError, TypeError, ValueError) as exc: + raise SystemExit(f"benchmark_gemm.py: error: {exc}") from None print("PASS") diff --git a/docs/sm120_blockscaled_perf.md b/docs/sm120_blockscaled_perf.md new file mode 100644 index 00000000..0891d0f4 --- /dev/null +++ b/docs/sm120_blockscaled_perf.md @@ -0,0 +1,116 @@ +# SM120 Blockscaled Performance Notes + +This note explains the first SM120 blockscaled performance path in +`GemmSm120`. It is intentionally narrower than the experimental branch: the PR +keeps the proven per-atom TMA mechanism and adds only the packed shared-memory +consumer path plus a `64x64x128` tile. + +## Supported Scope + +The performance path is opt-in: + +```bash +QUACK_SM120_BLOCKSCALED_PACKED_LDSM=1 +``` + +Current intended benchmark shape: + +```bash +QUACK_SM120_BLOCKSCALED_PACKED_LDSM=1 CUTE_DSL_ARCH=sm_120a \ +python benchmarks/benchmark_gemm.py \ + --mnkl 4096,4096,4096,1 \ + --tile_shape_mnk 64,64,128 \ + --cluster_shape_mnk 1,1,1 \ + --ab_dtype Float4E2M1FN \ + --sf_dtype Float8E4M3FN \ + --sf_vec_size 16 \ + --d_dtype BFloat16 \ + --warmup_iterations 5 \ + --iterations 30 \ + --skip_ref_check +``` + +The benchmark is a launch and timing harness. Numerical coverage lives in +`tests/test_gemm_blockscaled.py`, including asymmetric FP4 values, poisoned +scale padding, K-page crossing, and PTX checks. + +Targeted correctness gate for this path: + +```bash +QUACK_SM120_BLOCKSCALED_PACKED_LDSM=1 CUTE_DSL_ARCH=sm_120a \ +pytest -q tests/test_gemm_blockscaled.py -k "sm120 and packed_ldsm" -n 16 -s -rs +``` + +Supported formats for this SM120 path are: + +- NVFP4: `Float4E2M1FN` A/B, `Float8E4M3FN` scales, `sf_vec_size=16` +- MXFP4: `Float4E2M1FN` A/B, `Float8E8M0FNU` scales, `sf_vec_size=32` +- BF16 output, `C is None`, `beta=0` +- cluster shape `(1, 1, 1)` + +The packed performance path supports `64x64x64` and `64x64x128` CTA tiles. For +`tile_K=128`, logical K must be divisible by 128. + +## Why Packed LDSM + +The correctness-first SM120 path expanded compact FP4 bytes into the padded +`.b4x16_p64` ldmatrix shared-memory format. That path was useful for proving +the tuple MMA, scale mapping, and padded scale TMA, but profiling showed a large +shared-memory load bottleneck around the generated `b4x16_p64` ldmatrix +instruction. + +The packed path instead stages FP4 into a swizzled packed shared-memory layout +and consumes it with: + +```text +ldmatrix.sync.aligned.m8n8.x4.shared.b16 +mma.sync.aligned.m16n8k64.kind::mxf4nvf4 +``` + +This direction is based on the local CUTLASS GeForce reference in +`examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu`, +while this PR keeps QuACK's narrower per-atom TMA producer path. + +Tests assert that the packed path does not regress back to `b4x16_p64`, +`m8n16`, multicast TMA, or `shared::cluster`. + +## Why 64x64x128 First + +The packed `64x64x64` path removed most of the original shared-load wavefront +excess, but the kernel was still producer/barrier heavy. Moving to +`64x64x128` keeps the same accumulator ownership and correctness surface while +doubling the K work per producer/barrier cycle. Local Nsight Compute runs on a +noisy workstation showed the expected direction: + +- shared-load excessive wavefronts stayed near the packed-path level +- tensor pipe active increased materially +- barrier and MIO stall samples per issued instruction dropped +- runtime improved over `64x64x64` + +Treat these numbers as direction, not a stable benchmark claim. The following +benchmark runs were taken on an RTX 5060 workstation with reference checking +skipped because the pytest suite owns numerical validation: + +```text +base sm120-blockscaled, correctness-first path: + 4096x4096x4096 NVFP4 -> BF16, 128x128x64: 60.571 ms, 2.3 TFLOP/s + +this PR, packed-LDSM path: + 4096x4096x4096 NVFP4 -> BF16, 64x64x64: 2.988 ms, 46.0 TFLOP/s + 4096x4096x4096 NVFP4 -> BF16, 64x64x128: 1.329 ms, 103.4 TFLOP/s +``` + +## Why Not Full-Tile TMA In This PR + +The natural next architecture is full-tile or grouped A/B TMA into the final +packed/swizzled shared-memory layout. Local experiments were not clean enough +for this PR: + +- raw subbyte swizzled full-tile TMA failed CuTe DSL legalization +- byte-addressable recast layouts compiled but produced many tiny static TMA + sites and could hang at runtime +- nested grouped raw FP4 layouts hit compile/codegen timeouts + +Those findings point to a separate minimal layout-lowering repro/upstream issue. +This PR keeps production on the proven per-atom TMA path and uses `tile_K=128` +as the low-risk amortization step.