diff --git a/benchmarks/benchmark_gemm.py b/benchmarks/benchmark_gemm.py index bc69f709..6fa097e4 100644 --- a/benchmarks/benchmark_gemm.py +++ b/benchmarks/benchmark_gemm.py @@ -7,8 +7,9 @@ from quack.gemm import gemm as quack_gemm """ -GEMM benchmark using quack.gemm.gemm() (dense path) or the SM100 blockscaled -path (MXFP8 / MXFP4 / NVFP4) via --blockscaled. +GEMM benchmark using quack.gemm.gemm() (dense path) or the blockscaled +path (MXFP8 / MXFP4 / NVFP4). The blockscaled path is selected by passing +--sf_dtype and/or --sf_vec_size. Usage (dense): python benchmarks/benchmark_gemm.py --mnkl 512,7168,2048,256 \ @@ -17,18 +18,22 @@ Usage (blockscaled MXFP8, with cuBLAS comparison): python benchmarks/benchmark_gemm.py --mnkl 4096,4096,4096,1 \ - --blockscaled --ab_dtype Float8E4M3FN --sf_dtype Float8E8M0FNU \ - --sf_vec_size 32 --init quant --compare_cublas + --ab_dtype Float8E4M3FN --sf_dtype Float8E8M0FNU --sf_vec_size 32 Usage (blockscaled MXFP4): python benchmarks/benchmark_gemm.py --mnkl 4096,4096,4096,1 \ - --blockscaled --ab_dtype Float4E2M1FN --sf_dtype Float8E8M0FNU \ - --sf_vec_size 32 --d_dtype Float32 + --ab_dtype Float4E2M1FN --sf_dtype Float8E8M0FNU \ + --sf_vec_size 32 --d_dtype BFloat16 Usage (blockscaled NVFP4): python benchmarks/benchmark_gemm.py --mnkl 4096,4096,4096,1 \ - --blockscaled --ab_dtype Float4E2M1FN --sf_dtype Float8E4M3FN \ - --sf_vec_size 16 --d_dtype Float32 + --ab_dtype Float4E2M1FN --sf_dtype Float8E4M3FN \ + --sf_vec_size 16 --d_dtype BFloat16 + +Usage (blockscaled NVFP4 fast SM120 path): + python benchmarks/benchmark_gemm.py --mnkl 4096,4096,4096,1 \ + --ab_dtype Float4E2M1FN --sf_dtype Float8E4M3FN \ + --sf_vec_size 16 --d_dtype BFloat16 --sm120_nvfp4_path fast """ @@ -124,6 +129,22 @@ def parse_arguments() -> argparse.Namespace: parser.add_argument("--use_tma_gather", action="store_true", help="Use TMA gather4 for A") parser.add_argument("--max_swizzle_size", type=int, default=8, help="Max swizzle size") parser.add_argument("--skip_ref_check", action="store_true", help="Skip reference checking") + parser.add_argument( + "--sm120_nvfp4_init", + choices=("tensorfill", "ones"), + default="tensorfill", + help="SM120 NVFP4 input initialization. tensorfill uses bounded random non-zero " + "FP4/scales close to CUTLASS 79a TensorFillRandomUniform; ones preserves the " + "old all-ones microbenchmark.", + ) + parser.add_argument( + "--sm120_nvfp4_path", + choices=("validated", "fast"), + default="validated", + help="SM120 NVFP4 kernel policy. validated uses the conservative direct-store " + "static-scheduler path; fast uses the CLC/full-grid scheduler and delayed TMA " + "epilogue path.", + ) # Dtype flags. Blockscaled path is selected automatically when --sf_dtype is passed. parser.add_argument( "--ab_dtype", @@ -202,12 +223,20 @@ def _run_blockscaled(args): ) from quack.cute_dsl_utils import get_device_capacity from quack.gemm_default_epi import GemmDefaultSm100 + from quack.gemm_sm120 import GemmSm120 + from quack.sm120_blockscaled_utils import ( + create_sm120_nvfp4_ab_tensor, + create_sm120_nvfp4_scale_tensor, + create_sm120_nvfp4_tensorfill_like_ab_tensor, + create_sm120_nvfp4_tensorfill_like_scale_tensor, + ) sm_major = get_device_capacity(torch.device("cuda"))[0] - assert sm_major in (10, 11), ( - f"Blockscaled GEMM requires SM100 (B200/B300) or SM110; got SM{sm_major}x. " - "MXFP8/MXFP4/NVFP4 use tcgen05 UMMA which is SM100+." - ) + if sm_major not in (10, 11, 12): + raise RuntimeError( + f"Blockscaled GEMM requires SM100/SM110 or SM120; got SM{sm_major}x. " + "SM120 currently supports only the narrow NVFP4 path." + ) if args.varlen_k or args.gather_A or args.pingpong: raise NotImplementedError( @@ -217,6 +246,7 @@ def _run_blockscaled(args): m, n, k, l = args.mnkl mma_tiler_mnk = args.tile_shape_mnk + mma_tiler_mn = mma_tiler_mnk[:2] cluster_shape_mnk = args.cluster_shape_mnk cluster_shape_mn = cluster_shape_mnk[:2] if cluster_shape_mnk[2] != 1: @@ -257,7 +287,23 @@ def _run_blockscaled(args): raise ValueError( f"MXFP4/NVFP4 require K-major for both A and B; got a_major={a_major}, b_major={b_major}" ) - if not GemmDefaultSm100.can_implement_blockscaled( + is_sm120_nvfp4 = ( + sm_major == 12 + and ab_dtype == cutlass.Float4E2M1FN + and sf_dtype == cutlass.Float8E4M3FN + and sf_vec_size == 16 + ) + can_implement = ( + GemmSm120.can_implement_blockscaled + if is_sm120_nvfp4 + else GemmDefaultSm100.can_implement_blockscaled + ) + if sm_major == 12 and not is_sm120_nvfp4: + raise TypeError( + "SM120 blockscaled benchmark currently supports only NVFP4 " + "(Float4E2M1FN A/B, Float8E4M3FN scales, sf_vec_size=16)" + ) + if not can_implement( ab_dtype, sf_dtype, sf_vec_size, @@ -309,7 +355,7 @@ def _run_blockscaled(args): sf_dtype, sf_vec_size, d_dtype, - mma_tiler_mnk, + mma_tiler_mn, cluster_shape_mn, mA, mB, @@ -317,46 +363,66 @@ def _run_blockscaled(args): mSFA, mSFB, varlen_m=True, + sm120_nvfp4_path=args.sm120_nvfp4_path, ) def fn(): runner(mA, mB, mD, mSFA, mSFB, cu_seqlens_m) else: - a_ref, mA, a_sc_contig = create_blockscaled_operand_quantized( - l, - m, - k, - a_major == "m", - sf_vec_size, - ab_dtype, - sf_dtype, - ) - b_ref, mB, b_sc_contig = create_blockscaled_operand_quantized( - l, - n, - k, - b_major == "n", - sf_vec_size, - ab_dtype, - sf_dtype, - ) - # (l, rm, rk, 512) contig scale — consumed directly by the kernel. - mSFA, mSFB = a_sc_contig, b_sc_contig - sfa_ref = torch.ones_like(a_ref) - sfb_ref = torch.ones_like(b_ref) + if is_sm120_nvfp4: + if args.sm120_nvfp4_init == "ones": + mA = create_sm120_nvfp4_ab_tensor(l, m, k, fill_byte=0x22) + mB = create_sm120_nvfp4_ab_tensor(l, n, k, fill_byte=0x22) + a_ref = torch.ones((m, k, l), device="cuda", dtype=torch.float32) + b_ref = torch.ones((n, k, l), device="cuda", dtype=torch.float32) + _, mSFA = create_sm120_nvfp4_scale_tensor(l, m, k) + _, mSFB = create_sm120_nvfp4_scale_tensor(l, n, k) + mSFA.fill_(1.0) + mSFB.fill_(1.0) + sfa_ref = torch.ones_like(a_ref) + sfb_ref = torch.ones_like(b_ref) + else: + a_ref, mA = create_sm120_nvfp4_tensorfill_like_ab_tensor(l, m, k) + b_ref, mB = create_sm120_nvfp4_tensorfill_like_ab_tensor(l, n, k) + sfa_ref, mSFA = create_sm120_nvfp4_tensorfill_like_scale_tensor(l, m, k) + sfb_ref, mSFB = create_sm120_nvfp4_tensorfill_like_scale_tensor(l, n, k) + else: + a_ref, mA, a_sc_contig = create_blockscaled_operand_quantized( + l, + m, + k, + a_major == "m", + sf_vec_size, + ab_dtype, + sf_dtype, + ) + b_ref, mB, b_sc_contig = create_blockscaled_operand_quantized( + l, + n, + k, + b_major == "n", + sf_vec_size, + ab_dtype, + sf_dtype, + ) + # (l, rm, rk, 512) contig scale — consumed directly by the SM100 kernel. + mSFA, mSFB = a_sc_contig, b_sc_contig + sfa_ref = torch.ones_like(a_ref) + sfb_ref = torch.ones_like(b_ref) _, mD = create_blockscaled_operand_tensor(l, m, n, False, d_dtype, init="empty") runner = compile_blockscaled_gemm_tvm_ffi( ab_dtype, sf_dtype, sf_vec_size, d_dtype, - mma_tiler_mnk, + mma_tiler_mn, cluster_shape_mn, mA, mB, mD, mSFA, mSFB, + sm120_nvfp4_path=args.sm120_nvfp4_path, ) def fn(): @@ -365,29 +431,42 @@ def fn(): if not args.skip_ref_check: fn() torch.cuda.synchronize() - tol = 5e-3 if d_dtype != cutlass.Float32 else 5e-4 + tol = ( + 0.25 + if is_sm120_nvfp4 and args.sm120_nvfp4_init == "tensorfill" + else 5e-3 + if d_dtype != cutlass.Float32 + else 5e-4 + ) + rtol = 2e-2 if is_sm120_nvfp4 and args.sm120_nvfp4_init == "tensorfill" else 1e-3 if args.varlen_m: # Per-expert matmul reference using dequantized operands ref = torch.cat( [a_ref_dq[cu_seqlens_m[i] : cu_seqlens_m[i + 1]] @ b_ref_dq[i].T for i in range(l)] ) - torch.testing.assert_close(mD.float(), ref, atol=tol, rtol=1e-3) + torch.testing.assert_close(mD.float(), ref, atol=tol, rtol=rtol) else: ref = blockscaled_gemm_reference(a_ref, b_ref, sfa_ref, sfb_ref) - torch.testing.assert_close(mD.float(), ref, atol=tol, rtol=1e-3) + torch.testing.assert_close(mD.float(), ref, atol=tol, rtol=rtol) print("Ref check PASSED") - print("Running SM100 Blockscaled GEMM with:") + print(f"Running SM{sm_major} Blockscaled GEMM with:") print(f"mnkl: {args.mnkl}") print(f"tile_shape_mnk: {mma_tiler_mnk}, cluster_shape_mnk: {cluster_shape_mnk}") print( f"ab_dtype: {ab_dtype}, sf_dtype: {sf_dtype}, sf_vec_size: {sf_vec_size}, d_dtype: {args.d_dtype}" ) print(f"a_major: {a_major}, b_major: {b_major}") + if is_sm120_nvfp4: + print(f"sm120_nvfp4_init: {args.sm120_nvfp4_init}") + print(f"sm120_nvfp4_path: {args.sm120_nvfp4_path}") flops = 2 * m * n * k * l timing = _bench_and_report("quack ", fn, flops, args.warmup_iterations, args.iterations) + if is_sm120_nvfp4: + print("(skipping cuBLAS: benchmark uses SM120 native NVFP4 scale storage)") + return if args.varlen_m: print("(skipping cuBLAS: varlen_m not supported)") return diff --git a/quack/_sm120_nvfp4_utils.py b/quack/_sm120_nvfp4_utils.py new file mode 100644 index 00000000..a5c1bf41 --- /dev/null +++ b/quack/_sm120_nvfp4_utils.py @@ -0,0 +1,5099 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +"""SM120 MXF4NVF4 warp-GEMM helper API.""" + +from typing import Optional, Tuple, Type + +import cutlass +import cutlass.cute as cute +import cutlass.pipeline as pipeline +from cutlass import const_expr +from cutlass._mlir import ir +import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir +from cutlass._mlir.dialects import nvvm as _nvvm_ir +from cutlass.base_dsl.typing import Numeric +from cutlass.cute.nvgpu import cpasync, warp +from cutlass.cutlass_dsl import dsl_user_op, for_generate, yield_out +from cutlass.utils import blockscaled_layout +from cutlass.utils.blackwell_helpers import ( + get_layoutSFA_TV, + partition_fragment_SFA, + partition_fragment_SFB, + thrfrg_SFA, + thrfrg_SFB, +) +from cutlass.utils.smem_allocator import SmemAllocator +from cutlass.utils.static_persistent_tile_scheduler import ( + PersistentTileSchedulerParams, + StaticPersistentTileScheduler, +) +from quack.sm120_pipeline import PipelineTmaWarpMma + + +MXF4NVF4_CTA_SHAPE_MNK = (128, 128, 128) +MXF4NVF4_MMA_SHAPE_MNK = (16, 8, 64) +MXF4NVF4_SCALE_VEC_SIZE = 16 +MXF4NVF4_SCALE_K = MXF4NVF4_CTA_SHAPE_MNK[2] // MXF4NVF4_SCALE_VEC_SIZE +MXF4NVF4_SCALE_TMA_MIN_L = 2 +MXF4NVF4_AB_PACKED_TMA_BYTES = MXF4NVF4_CTA_SHAPE_MNK[0] * MXF4NVF4_CTA_SHAPE_MNK[2] // 2 +MXF4NVF4_AB_UNPACK_TMA_BYTES = MXF4NVF4_AB_PACKED_TMA_BYTES +MXF4NVF4_AB_TMA_BYTES = MXF4NVF4_AB_PACKED_TMA_BYTES +MXF4NVF4_AB_SMEM_BYTES = MXF4NVF4_CTA_SHAPE_MNK[0] * MXF4NVF4_CTA_SHAPE_MNK[2] +MXF4NVF4_SCALE_TMA_BYTES = MXF4NVF4_CTA_SHAPE_MNK[0] * MXF4NVF4_SCALE_K +MXF4NVF4_FULL_TMA_BYTES = 2 * MXF4NVF4_AB_TMA_BYTES + 2 * MXF4NVF4_SCALE_TMA_BYTES +MXF4NVF4_FULL_UNPACK_TMA_BYTES = MXF4NVF4_FULL_TMA_BYTES +MXF4NVF4_COOPERATIVE_PRODUCER_REGS = 40 +MXF4NVF4_COOPERATIVE_CONSUMER_REGS = 232 +MXF4NVF4_COOPERATIVE_THREADS_PER_WARPGROUP = 128 +MXF4NVF4_PINGPONG_MMA_BARRIER_ID = 3 +MXF4NVF4_PINGPONG_EPI_BARRIER_ID = 5 +MXF4NVF4_PINGPONG_BARRIER_THREADS = 2 * MXF4NVF4_COOPERATIVE_THREADS_PER_WARPGROUP + + +def _check_positive(name: str, value: int) -> None: + if value <= 0: + raise ValueError(f"`{name}` must be positive, but got {value}") + + +def _check_default_tile(tile_mn: int, tile_k: int, sf_vec_size: int) -> None: + _check_positive("tile_mn", tile_mn) + _check_positive("tile_k", tile_k) + _check_positive("sf_vec_size", sf_vec_size) + if tile_k != 128: + raise ValueError("SM120 MXF4NVF4 helpers currently support tile_k=128") + if sf_vec_size != MXF4NVF4_SCALE_VEC_SIZE: + raise ValueError("SM120 MXF4NVF4 helpers currently support sf_vec_size=16") + + +def _normalize_mxf4nvf4_ab_smem_format(smem_format: str) -> str: + normalized = smem_format.replace("-", "_").lower() + if normalized in ("packed", "align8", "16u4_align8b"): + return "packed" + if normalized in ( + "unpack", + "unpacked", + "unpack_smem", + "unpacksmem", + "align16", + "16u4_align16b", + ): + return "unpack" + raise ValueError( + "`smem_format` must be 'packed'/'16u4_align8b' or " + f"'unpack'/'16u4_align16b', but got {smem_format!r}" + ) + + +def _mxf4nvf4_ab_tma_internal_type(smem_format: str) -> Optional[Type[Numeric]]: + if _normalize_mxf4nvf4_ab_smem_format(smem_format) == "unpack": + return cutlass.Uint8 + return None + + +def _require_zero_major_offset(name: str, value: cutlass.Int32 | int) -> None: + raw_value = getattr(value, "value", value) + if raw_value != 0: + raise ValueError( + f"`{name}` is not supported by this helper; encode the global major " + "tile in the TMA descriptor coordinates and stage the local 128-major tile" + ) + + +def _require_zero_scale_major_offset(name: str, value: cutlass.Int32 | int) -> None: + _require_zero_major_offset(name, value) + + +def _check_tuple(name: str, value: tuple[int, ...], rank: int) -> None: + if len(value) != rank: + raise ValueError(f"`{name}` must have rank {rank}, but got {value}") + + +def _mxf4nvf4_contiguous_alignment(dtype: Type[Numeric]) -> int: + return 16 * 8 // dtype.width + + +def _mxf4nvf4_gemm_config_errors( + *, + m: int = 128, + n: int = 128, + k: int = 128, + l_extent: int = 1, + a_dtype: Type[Numeric] = cutlass.Float4E2M1FN, + b_dtype: Type[Numeric] = cutlass.Float4E2M1FN, + sf_dtype: Type[Numeric] = cutlass.Float8E4M3FN, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, + c_dtype: Type[Numeric] = cutlass.BFloat16, + acc_dtype: Type[Numeric] = cutlass.Float32, + tile_shape_mnk: tuple[int, int, int] = MXF4NVF4_CTA_SHAPE_MNK, + cluster_shape_mnk: tuple[int, int, int] = (1, 1, 1), + a_major: str = "k", + b_major: str = "k", + c_major: str = "n", + ab_smem_format: str = "packed", +) -> list[str]: + errors: list[str] = [] + for name, value in (("m", m), ("n", n), ("k", k), ("l_extent", l_extent)): + if value <= 0: + errors.append(f"`{name}` must be positive") + + try: + _check_tuple("tile_shape_mnk", tile_shape_mnk, 3) + except ValueError as exc: + errors.append(str(exc)) + try: + _check_tuple("cluster_shape_mnk", cluster_shape_mnk, 3) + except ValueError as exc: + errors.append(str(exc)) + + if a_dtype != cutlass.Float4E2M1FN: + errors.append("A dtype must be Float4E2M1FN") + if b_dtype != cutlass.Float4E2M1FN: + errors.append("B dtype must be Float4E2M1FN") + if sf_dtype != cutlass.Float8E4M3FN: + errors.append("scale dtype must be Float8E4M3FN") + if sf_vec_size != MXF4NVF4_SCALE_VEC_SIZE: + errors.append(f"sf_vec_size must be {MXF4NVF4_SCALE_VEC_SIZE}") + if acc_dtype != cutlass.Float32: + errors.append("accumulator dtype must be Float32") + if c_dtype not in {cutlass.Float32, cutlass.Float16, cutlass.BFloat16}: + errors.append("output dtype must be Float32, Float16, or BFloat16") + + if tile_shape_mnk != MXF4NVF4_CTA_SHAPE_MNK: + errors.append(f"tile_shape_mnk must be {MXF4NVF4_CTA_SHAPE_MNK}") + if cluster_shape_mnk != (1, 1, 1): + errors.append("cluster_shape_mnk must be (1, 1, 1)") + if a_major != "k": + errors.append("A layout must be K-major") + if b_major != "k": + errors.append("B layout must be K-major") + if c_major not in {"n", "m"}: + errors.append("output layout must be N-major or M-major") + + try: + normalized_ab_smem_format = _normalize_mxf4nvf4_ab_smem_format(ab_smem_format) + except ValueError as exc: + errors.append(str(exc)) + else: + if normalized_ab_smem_format != "packed": + errors.append("native SM120 MXF4NVF4 GEMM currently supports only packed A/B TMA") + + if len(tile_shape_mnk) == 3: + tile_m, tile_n, tile_k = tile_shape_mnk + if m % tile_m != 0: + errors.append("m must be divisible by tile_shape_mnk[0]") + if n % tile_n != 0: + errors.append("n must be divisible by tile_shape_mnk[1]") + if k % tile_k != 0: + errors.append("k must be divisible by tile_shape_mnk[2]") + + if a_dtype == cutlass.Float4E2M1FN and k % _mxf4nvf4_contiguous_alignment(a_dtype): + errors.append("K-major A requires k to be 16-byte aligned") + if b_dtype == cutlass.Float4E2M1FN and k % _mxf4nvf4_contiguous_alignment(b_dtype): + errors.append("K-major B requires k to be 16-byte aligned") + if c_dtype in {cutlass.Float32, cutlass.Float16, cutlass.BFloat16}: + c_contiguous_extent = m if c_major == "m" else n + if c_contiguous_extent % _mxf4nvf4_contiguous_alignment(c_dtype): + errors.append("output contiguous dimension must be 16-byte aligned") + + return errors + + +def mxf4nvf4_can_implement( + *, + m: int = 128, + n: int = 128, + k: int = 128, + l_extent: int = 1, + a_dtype: Type[Numeric] = cutlass.Float4E2M1FN, + b_dtype: Type[Numeric] = cutlass.Float4E2M1FN, + sf_dtype: Type[Numeric] = cutlass.Float8E4M3FN, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, + c_dtype: Type[Numeric] = cutlass.BFloat16, + acc_dtype: Type[Numeric] = cutlass.Float32, + tile_shape_mnk: tuple[int, int, int] = MXF4NVF4_CTA_SHAPE_MNK, + cluster_shape_mnk: tuple[int, int, int] = (1, 1, 1), + a_major: str = "k", + b_major: str = "k", + c_major: str = "n", + ab_smem_format: str = "packed", +) -> bool: + """Return whether the public packed native-TMA SM120 NVFP4 path supports a GEMM. + + This is a conservative public contract for the currently validated SM120 + Blackwell GeForce path. It describes the supported building block for + downstream kernels instead of implying that experimental descriptor or + unpack-SMEM probes are production-supported. + """ + return not _mxf4nvf4_gemm_config_errors( + m=m, + n=n, + k=k, + l_extent=l_extent, + a_dtype=a_dtype, + b_dtype=b_dtype, + sf_dtype=sf_dtype, + sf_vec_size=sf_vec_size, + c_dtype=c_dtype, + acc_dtype=acc_dtype, + tile_shape_mnk=tile_shape_mnk, + cluster_shape_mnk=cluster_shape_mnk, + a_major=a_major, + b_major=b_major, + c_major=c_major, + ab_smem_format=ab_smem_format, + ) + + +def validate_mxf4nvf4_gemm_config( + *, + m: int = 128, + n: int = 128, + k: int = 128, + l_extent: int = 1, + a_dtype: Type[Numeric] = cutlass.Float4E2M1FN, + b_dtype: Type[Numeric] = cutlass.Float4E2M1FN, + sf_dtype: Type[Numeric] = cutlass.Float8E4M3FN, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, + c_dtype: Type[Numeric] = cutlass.BFloat16, + acc_dtype: Type[Numeric] = cutlass.Float32, + tile_shape_mnk: tuple[int, int, int] = MXF4NVF4_CTA_SHAPE_MNK, + cluster_shape_mnk: tuple[int, int, int] = (1, 1, 1), + a_major: str = "k", + b_major: str = "k", + c_major: str = "n", + ab_smem_format: str = "packed", +) -> None: + """Raise if a GEMM is outside the public SM120 NVFP4 native-TMA contract.""" + errors = _mxf4nvf4_gemm_config_errors( + m=m, + n=n, + k=k, + l_extent=l_extent, + a_dtype=a_dtype, + b_dtype=b_dtype, + sf_dtype=sf_dtype, + sf_vec_size=sf_vec_size, + c_dtype=c_dtype, + acc_dtype=acc_dtype, + tile_shape_mnk=tile_shape_mnk, + cluster_shape_mnk=cluster_shape_mnk, + a_major=a_major, + b_major=b_major, + c_major=c_major, + ab_smem_format=ab_smem_format, + ) + if errors: + raise ValueError("Unsupported SM120 MXF4NVF4 GEMM configuration: " + "; ".join(errors)) + + +def mxf4nvf4_native_tma_tile_coords( + m_tile: cutlass.Int32 | int = 0, + n_tile: cutlass.Int32 | int = 0, + k_tile: cutlass.Int32 | int = 0, + l_tile: cutlass.Int32 | int = 0, +) -> dict[str, tuple[cutlass.Int32 | int, ...]]: + """Map one GEMM tile coordinate to native SM120 A/B/SFA/SFB TMA coords.""" + scale_k_tile = k_tile % 2 + scale_page = k_tile // 2 + return { + "ab_tile_coord_a": (m_tile, k_tile, l_tile), + "ab_tile_coord_b": (n_tile, k_tile, l_tile), + "scale_tile_coord_sfa": (m_tile, scale_k_tile, scale_page, l_tile), + "scale_tile_coord_sfb": (n_tile, scale_k_tile, scale_page, l_tile), + } + + +def mxf4nvf4_scheduler_tile_tma_coords( + tile_mnl: tuple[cutlass.Int32 | int, ...], + k_tile: cutlass.Int32 | int = 0, +) -> dict[str, tuple[cutlass.Int32 | int, ...]]: + """Map a persistent scheduler tile coordinate to native SM120 TMA coords.""" + _check_tuple("tile_mnl", tile_mnl, 3) + return mxf4nvf4_native_tma_tile_coords( + tile_mnl[0], + tile_mnl[1], + k_tile, + tile_mnl[2], + ) + + +def mxf4nvf4_tiled_problem_shape( + *, + m: int = 128, + n: int = 128, + k: int = 128, + l_extent: int = 1, + a_dtype: Type[Numeric] = cutlass.Float4E2M1FN, + b_dtype: Type[Numeric] = cutlass.Float4E2M1FN, + sf_dtype: Type[Numeric] = cutlass.Float8E4M3FN, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, + c_dtype: Type[Numeric] = cutlass.BFloat16, + acc_dtype: Type[Numeric] = cutlass.Float32, + tile_shape_mnk: tuple[int, int, int] = MXF4NVF4_CTA_SHAPE_MNK, + cluster_shape_mnk: tuple[int, int, int] = (1, 1, 1), + a_major: str = "k", + b_major: str = "k", + c_major: str = "n", + ab_smem_format: str = "packed", +) -> dict[str, tuple[int, ...] | int]: + """Return host-side tiling metadata for the public SM120 NVFP4 path.""" + validate_mxf4nvf4_gemm_config( + m=m, + n=n, + k=k, + l_extent=l_extent, + a_dtype=a_dtype, + b_dtype=b_dtype, + sf_dtype=sf_dtype, + sf_vec_size=sf_vec_size, + c_dtype=c_dtype, + acc_dtype=acc_dtype, + tile_shape_mnk=tile_shape_mnk, + cluster_shape_mnk=cluster_shape_mnk, + a_major=a_major, + b_major=b_major, + c_major=c_major, + ab_smem_format=ab_smem_format, + ) + + tile_m, tile_n, tile_k = tile_shape_mnk + num_ctas_mnl = (m // tile_m, n // tile_n, l_extent) + cluster_shape_mnl = (cluster_shape_mnk[0], cluster_shape_mnk[1], 1) + return { + "problem_shape_mnkl": (m, n, k, l_extent), + "tile_shape_mnk": tile_shape_mnk, + "cluster_shape_mnk": cluster_shape_mnk, + "cluster_shape_mnl": cluster_shape_mnl, + "num_ctas_mnl": num_ctas_mnl, + "k_tile_count": k // tile_k, + "logical_grid_shape": num_ctas_mnl, + } + + +@dsl_user_op +def make_mxf4nvf4_static_tile_scheduler_params( + *, + m: int = 128, + n: int = 128, + k: int = 128, + l_extent: int = 1, + max_active_clusters: int = 1, + swizzle_size: int = 1, + raster_along_m: bool = True, + a_dtype: Type[Numeric] = cutlass.Float4E2M1FN, + b_dtype: Type[Numeric] = cutlass.Float4E2M1FN, + sf_dtype: Type[Numeric] = cutlass.Float8E4M3FN, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, + c_dtype: Type[Numeric] = cutlass.BFloat16, + acc_dtype: Type[Numeric] = cutlass.Float32, + tile_shape_mnk: tuple[int, int, int] = MXF4NVF4_CTA_SHAPE_MNK, + cluster_shape_mnk: tuple[int, int, int] = (1, 1, 1), + a_major: str = "k", + b_major: str = "k", + c_major: str = "n", + ab_smem_format: str = "packed", + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> tuple[PersistentTileSchedulerParams, Tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32]]: + """Return static persistent scheduler params and launch grid for SM120 NVFP4.""" + problem_shape = mxf4nvf4_tiled_problem_shape( + m=m, + n=n, + k=k, + l_extent=l_extent, + a_dtype=a_dtype, + b_dtype=b_dtype, + sf_dtype=sf_dtype, + sf_vec_size=sf_vec_size, + c_dtype=c_dtype, + acc_dtype=acc_dtype, + tile_shape_mnk=tile_shape_mnk, + cluster_shape_mnk=cluster_shape_mnk, + a_major=a_major, + b_major=b_major, + c_major=c_major, + ab_smem_format=ab_smem_format, + ) + tile_sched_params = PersistentTileSchedulerParams( + problem_shape["num_ctas_mnl"], + problem_shape["cluster_shape_mnl"], + swizzle_size=swizzle_size, + raster_along_m=raster_along_m, + loc=loc, + ip=ip, + ) + grid = StaticPersistentTileScheduler.get_grid_shape( + tile_sched_params, + cutlass.Int32(max_active_clusters), + loc=loc, + ip=ip, + ) + return tile_sched_params, grid + + +@dsl_user_op +def make_mxf4nvf4_static_tile_scheduler( + tile_sched_params: PersistentTileSchedulerParams, + block_idx: Tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32], + grid_dim: Tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> StaticPersistentTileScheduler: + """Create the static persistent tile scheduler for an SM120 NVFP4 kernel.""" + return StaticPersistentTileScheduler.create( + tile_sched_params, + block_idx, + grid_dim, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_mxf4nvf4_a_gmem_layout( + m: int = 128, + k: int = 128, + l_extent: int = 1, + major: str = "k", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.Layout: + """Return the public logical A GMEM layout for the SM120 NVFP4 path.""" + _check_positive("m", m) + _check_positive("k", k) + _check_positive("l_extent", l_extent) + if major != "k": + raise ValueError("SM120 MXF4NVF4 A GMEM layout currently requires major='k'") + return cute.make_layout( + (m, k, l_extent), + stride=(k, 1, m * k), + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_mxf4nvf4_b_gmem_layout( + n: int = 128, + k: int = 128, + l_extent: int = 1, + major: str = "k", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.Layout: + """Return the public logical B GMEM layout for the SM120 NVFP4 path.""" + _check_positive("n", n) + _check_positive("k", k) + _check_positive("l_extent", l_extent) + if major != "k": + raise ValueError("SM120 MXF4NVF4 B GMEM layout currently requires major='k'") + return cute.make_layout( + (n, k, l_extent), + stride=(k, 1, n * k), + loc=loc, + ip=ip, + ) + + +def _preserve_mxf4nvf4_ab_tma_l_mode(gmem_tensor: cute.Tensor) -> cute.Tensor: + """Keep A/B tensor maps rank-3 even for logical L=1. + + This mirrors the 79a C++ path, which builds A/B tensor maps over + ``(M,K,L)`` / ``(N,K,L)`` and keeps the L coordinate in the TMA + instruction stream. + """ + if const_expr(cute.size(gmem_tensor, mode=[2]) != 1): + return gmem_tensor + return cute.make_tensor( + gmem_tensor.iterator, + cute.make_layout( + (gmem_tensor.shape[0], gmem_tensor.shape[1], MXF4NVF4_SCALE_TMA_MIN_L), + stride=gmem_tensor.layout.stride, + ), + ) + + +@dsl_user_op +def make_mxf4nvf4_d_gmem_layout( + m: int = 128, + n: int = 128, + l_extent: int = 1, + major: str = "n", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.Layout: + """Return the public logical D GMEM layout for the SM120 NVFP4 path.""" + _check_positive("m", m) + _check_positive("n", n) + _check_positive("l_extent", l_extent) + if major == "n": + stride = (n, 1, m * n) + elif major == "m": + stride = (1, m, m * n) + else: + raise ValueError("SM120 MXF4NVF4 D GMEM layout requires major='n' or 'm'") + return cute.make_layout( + (m, n, l_extent), + stride=stride, + loc=loc, + ip=ip, + ) + + +def mxf4nvf4_ab_tma_tx_bytes( + tile_mn: int = 128, + tile_k: int = 128, + *, + smem_format: str = "packed", +) -> int: + """Return bytes completed by one A or B full-tile TMA transaction. + + The unpack-SMEM FP4 tensor-map format expands the destination SMEM footprint + to 16 KiB for a 128x128 tile, but the transaction barrier count follows the + logical packed FP4 payload bytes, matching the SM120 C++ collectives. + """ + _check_default_tile(tile_mn, tile_k, MXF4NVF4_SCALE_VEC_SIZE) + _normalize_mxf4nvf4_ab_smem_format(smem_format) + return tile_mn * tile_k // 2 + + +def mxf4nvf4_ab_packed_tma_tx_bytes(tile_mn: int = 128, tile_k: int = 128) -> int: + """Return A/B TMA bytes for the normal packed FP4 ALIGN8B format.""" + return mxf4nvf4_ab_tma_tx_bytes(tile_mn, tile_k, smem_format="packed") + + +def mxf4nvf4_ab_unpack_tma_tx_bytes(tile_mn: int = 128, tile_k: int = 128) -> int: + """Return A/B barrier bytes for the FP4 unpack-SMEM ALIGN16B format.""" + return mxf4nvf4_ab_tma_tx_bytes(tile_mn, tile_k, smem_format="unpack") + + +def mxf4nvf4_ab_physical_smem_bytes(tile_mn: int = 128, tile_k: int = 128) -> int: + """Return bytes reserved for one A or B physical-SMEM tile.""" + _check_default_tile(tile_mn, tile_k, MXF4NVF4_SCALE_VEC_SIZE) + return tile_mn * tile_k + + +def mxf4nvf4_scale_tma_tx_bytes( + tile_mn: int = 128, + tile_k: int = 128, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, +) -> int: + """Return bytes completed by one SFA or SFB full-tile TMA transaction.""" + _check_default_tile(tile_mn, tile_k, sf_vec_size) + return tile_mn * (tile_k // sf_vec_size) + + +def mxf4nvf4_scale_physical_smem_bytes( + tile_mn: int = 128, + tile_k: int = 128, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, +) -> int: + """Return bytes reserved for one SFA or SFB physical-SMEM tile.""" + _check_default_tile(tile_mn, tile_k, sf_vec_size) + return max(tile_mn, 128) * (tile_k // sf_vec_size) + + +def mxf4nvf4_full_tma_tx_bytes( + tile_m: int = 128, + tile_n: int = 128, + tile_k: int = 128, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, + *, + ab_smem_format: str = "packed", +) -> int: + """Return the barrier transaction byte count for A/B/SFA/SFB TMA.""" + return ( + mxf4nvf4_ab_tma_tx_bytes(tile_m, tile_k, smem_format=ab_smem_format) + + mxf4nvf4_ab_tma_tx_bytes(tile_n, tile_k, smem_format=ab_smem_format) + + mxf4nvf4_scale_tma_tx_bytes(tile_m, tile_k, sf_vec_size) + + mxf4nvf4_scale_tma_tx_bytes(tile_n, tile_k, sf_vec_size) + ) + + +def mxf4nvf4_full_packed_tma_tx_bytes( + tile_m: int = 128, + tile_n: int = 128, + tile_k: int = 128, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, +) -> int: + """Return full-tile TMA bytes for packed FP4 A/B plus SFA/SFB.""" + return mxf4nvf4_full_tma_tx_bytes( + tile_m, + tile_n, + tile_k, + sf_vec_size, + ab_smem_format="packed", + ) + + +def mxf4nvf4_full_unpack_tma_tx_bytes( + tile_m: int = 128, + tile_n: int = 128, + tile_k: int = 128, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, +) -> int: + """Return full-tile TMA bytes for unpack-SMEM FP4 A/B plus SFA/SFB.""" + return mxf4nvf4_full_tma_tx_bytes( + tile_m, + tile_n, + tile_k, + sf_vec_size, + ab_smem_format="unpack", + ) + + +@dsl_user_op +def make_mxf4nvf4_sfa_gmem_layout( + m: int = 128, + k: int = 128, + l_extent: int = 1, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.Layout: + """Return the C++ 79a-style GMEM layout for SFA scale tensors.""" + _check_positive("m", m) + _check_positive("k", k) + _check_positive("l_extent", l_extent) + _check_positive("sf_vec_size", sf_vec_size) + if sf_vec_size != MXF4NVF4_SCALE_VEC_SIZE: + raise ValueError("SM120 MXF4NVF4 scale GMEM layout requires sf_vec_size=16") + return blockscaled_layout.tile_atom_to_shape_SF( + (m, k, l_extent), + sf_vec_size, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_mxf4nvf4_sfb_gmem_layout( + n: int = 128, + k: int = 128, + l_extent: int = 1, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.Layout: + """Return the C++ 79a-style GMEM layout for SFB scale tensors.""" + _check_positive("n", n) + _check_positive("k", k) + _check_positive("l_extent", l_extent) + _check_positive("sf_vec_size", sf_vec_size) + if sf_vec_size != MXF4NVF4_SCALE_VEC_SIZE: + raise ValueError("SM120 MXF4NVF4 scale GMEM layout requires sf_vec_size=16") + return blockscaled_layout.tile_atom_to_shape_SF( + (n, k, l_extent), + sf_vec_size, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_mxf4nvf4_scale_tma_physical_gmem_layout( + major_extent: int = 128, + scale_k_extent: int = MXF4NVF4_SCALE_K * 2, + tile_extent: int = 1, + l_extent: int = 1, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.Layout: + """Return the compact GMEM layout consumed by native SM120 scale TMA. + + The native scale TMA atom keeps the public scale tensor typed as FP8 and + exposes a layout with major as the contiguous mode. This is distinct from a + row-major Torch view of the same storage and from the logical blockscaled + SFA/SFB layout used by fragments. + """ + _check_positive("major_extent", major_extent) + _check_positive("scale_k_extent", scale_k_extent) + _check_positive("tile_extent", tile_extent) + _check_positive("l_extent", l_extent) + return cute.make_layout( + (major_extent, scale_k_extent, tile_extent, l_extent), + stride=( + 1, + major_extent, + major_extent * scale_k_extent, + major_extent * scale_k_extent * tile_extent, + ), + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_mxf4nvf4_scale_interleaved_gmem_layout( + major_extent: int = 128, + logical_k_extent: int = 128, + l_extent: int = 1, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.Layout: + """Return the compact 4D FP8 scale layout consumed by SM120 TMA.""" + _check_positive("major_extent", major_extent) + _check_positive("logical_k_extent", logical_k_extent) + _check_positive("l_extent", l_extent) + _check_positive("sf_vec_size", sf_vec_size) + if sf_vec_size != MXF4NVF4_SCALE_VEC_SIZE: + raise ValueError("SM120 MXF4NVF4 scale layout requires sf_vec_size=16") + if major_extent % 128 != 0: + raise ValueError("SM120 scale interleaved layout requires major_extent % 128 == 0") + if logical_k_extent % sf_vec_size != 0: + raise ValueError( + "SM120 scale interleaved layout requires logical_k_extent % sf_vec_size == 0" + ) + logical_scale_k = cute.ceil_div(logical_k_extent, sf_vec_size) + if logical_scale_k % 4 != 0: + raise ValueError("SM120 scale interleaved layout requires scale_k % 4 == 0") + major_tiles = major_extent // 128 + scale_tiles = logical_scale_k // 4 + l_stride = major_tiles * scale_tiles * 512 + return cute.make_layout( + (((32, 4), major_tiles), 4, scale_tiles, l_extent), + stride=(((16, 4), 512), 1, major_tiles * 512, l_stride), + loc=loc, + ip=ip, + ) + + +def mxf4nvf4_padded_scale_k_extent(logical_scale_k_extent: int) -> int: + """Return the padded physical scale-K extent for SM120 NVFP4 scale TMA.""" + _check_positive("logical_scale_k_extent", logical_scale_k_extent) + granularity = MXF4NVF4_SCALE_K * 2 + return ((logical_scale_k_extent + granularity - 1) // granularity) * granularity + + +def mxf4nvf4_scale_tma_physical_k_extent( + k: int, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, +) -> int: + """Return the physical scale-K extent needed to back a logical K extent.""" + _check_positive("k", k) + _check_positive("sf_vec_size", sf_vec_size) + if sf_vec_size != MXF4NVF4_SCALE_VEC_SIZE: + raise ValueError(f"SM120 MXF4NVF4 scale TMA requires sf_vec_size={MXF4NVF4_SCALE_VEC_SIZE}") + if k % sf_vec_size != 0: + raise ValueError("SM120 MXF4NVF4 K extent must be divisible by sf_vec_size") + return mxf4nvf4_padded_scale_k_extent(k // sf_vec_size) + + +def mxf4nvf4_scale_tma_physical_l_extent(logical_l_extent: int) -> int: + """Return the physical scale-L extent used by native SM120 scale TMA. + + Keeping the scale tensor-map rank-4 even for logical L=1 preserves the + compact scale TMA path used by the native SM120 MXF4/NVFP4 mainloop. + """ + _check_positive("logical_l_extent", logical_l_extent) + return max(logical_l_extent, MXF4NVF4_SCALE_TMA_MIN_L) + + +def mxf4nvf4_cooperative_launch_kwargs( + *, + producer_warpgroups: int = 1, + consumer_warpgroups: int = 2, + min_ctas_per_sm: int = 1, +) -> dict[str, tuple[int, int, int] | int]: + """Return launch metadata required for SM120 dynamic register allocation. + + PTX `setmaxnreg` only lowers to SASS `USETMAXREG` when ptxas sees an entry + metadata context such as `.maxntid` plus `.minnctapersm`. This helper keeps + that contract close to the cooperative SM120 schedule shape instead of + requiring each caller to spell the metadata by hand. + """ + _check_positive("producer_warpgroups", producer_warpgroups) + _check_positive("consumer_warpgroups", consumer_warpgroups) + _check_positive("min_ctas_per_sm", min_ctas_per_sm) + warpgroups = producer_warpgroups + consumer_warpgroups + threads = warpgroups * MXF4NVF4_COOPERATIVE_THREADS_PER_WARPGROUP + return { + "block": (threads, 1, 1), + "max_number_threads": (threads, 1, 1), + "min_blocks_per_mp": min_ctas_per_sm, + } + + +def mxf4nvf4_cooperative_sass_count_targets( + *, + tma_issue_groups: int = 3, + consumer_issue_groups: int = 2, +) -> dict[str, int]: + """Return expected static SASS count targets for SM120 schedule probes.""" + _check_positive("tma_issue_groups", tma_issue_groups) + _check_positive("consumer_issue_groups", consumer_issue_groups) + return { + "USETMAXREG": 2, + "UTMALDG": 4 * tma_issue_groups, + "OMMA.SF": 32 * consumer_issue_groups, + "LDSM": 12 * consumer_issue_groups, + } + + +def _mxf4nvf4_pingpong_barrier_base_id(stage: str) -> int: + if stage == "mma": + return MXF4NVF4_PINGPONG_MMA_BARRIER_ID + if stage == "epi": + return MXF4NVF4_PINGPONG_EPI_BARRIER_ID + raise ValueError("SM120 ping-pong barrier stage must be 'mma' or 'epi'") + + +@dsl_user_op +def mxf4nvf4_pingpong_barrier_arrive( + warpgroup_idx: cutlass.Int32 | int, + stage: str, + *, + number_of_threads: cutlass.Int32 | int = MXF4NVF4_PINGPONG_BARRIER_THREADS, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Arrive on the SM120 two-warpgroup ping-pong named barrier.""" + cute.arch.barrier_arrive( + barrier_id=_mxf4nvf4_pingpong_barrier_base_id(stage) + warpgroup_idx, + number_of_threads=number_of_threads, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def mxf4nvf4_pingpong_barrier_sync( + warpgroup_idx: cutlass.Int32 | int, + stage: str, + *, + number_of_threads: cutlass.Int32 | int = MXF4NVF4_PINGPONG_BARRIER_THREADS, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Arrive and wait on the SM120 two-warpgroup ping-pong named barrier.""" + cute.arch.barrier( + barrier_id=_mxf4nvf4_pingpong_barrier_base_id(stage) + warpgroup_idx, + number_of_threads=number_of_threads, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def mxf4nvf4_mma_warpgroup_barrier_sync( + *, + barrier_id: cutlass.Int32 | int = MXF4NVF4_PINGPONG_MMA_BARRIER_ID, + number_of_threads: cutlass.Int32 | int = MXF4NVF4_PINGPONG_BARRIER_THREADS, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Arrive and wait on the SM120 MXF4/NVFP4 MMA warpgroup barrier.""" + cute.arch.barrier( + barrier_id=barrier_id, + number_of_threads=number_of_threads, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def mxf4nvf4_split_tma_consumer_wait( + pipe_mk: PipelineTmaWarpMma, + pipe_nk: PipelineTmaWarpMma, + consumer_state_mk: pipeline.PipelineState, + consumer_state_nk: pipeline.PipelineState, + *, + join_split_tma_barrier: bool = True, + try_wait_token_mk: Optional[cutlass.Boolean] = None, + try_wait_token_nk: Optional[cutlass.Boolean] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Wait for the SM120 split-TMA MK/NK consumer stages. + + When ``join_split_tma_barrier`` is true, the MK and NK TMA streams share one + full barrier and only the MK pipe is waited on. Callers may pass tokens + from an earlier ``consumer_try_wait`` to separate mbarrier probing from the + actual wait, matching the 79a-style ping-pong handoff. + """ + if const_expr(try_wait_token_mk is None): + try_wait_token_mk = pipe_mk.consumer_try_wait(consumer_state_mk, loc=loc, ip=ip) + pipe_mk.consumer_wait( + consumer_state_mk, + try_wait_token_mk, + loc=loc, + ip=ip, + ) + if const_expr(not join_split_tma_barrier): + if const_expr(try_wait_token_nk is None): + try_wait_token_nk = pipe_nk.consumer_try_wait(consumer_state_nk, loc=loc, ip=ip) + pipe_nk.consumer_wait( + consumer_state_nk, + try_wait_token_nk, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def mxf4nvf4_split_tma_consumer_release( + pipe_mk: PipelineTmaWarpMma, + pipe_nk: PipelineTmaWarpMma, + consumer_state_mk: pipeline.PipelineState, + consumer_state_nk: pipeline.PipelineState, + *, + join_split_tma_barrier: bool = True, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Release the SM120 split-TMA MK/NK consumer stages.""" + pipe_mk.consumer_release(consumer_state_mk, loc=loc, ip=ip) + if const_expr(not join_split_tma_barrier): + pipe_nk.consumer_release(consumer_state_nk, loc=loc, ip=ip) + + +def make_mxf4nvf4_native_tma_pipeline( + barrier_storage: cute.Tensor, + *, + num_stages: int = 1, + tile_m: int = 128, + tile_n: int = 128, + tile_k: int = 128, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, + ab_smem_format: str = "packed", + producer_group=None, + consumer_group=None, +): + """Create the SM120 native A/B/SFA/SFB TMA load pipeline.""" + _check_positive("num_stages", num_stages) + if producer_group is None: + producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + if consumer_group is None: + consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, 8) + return PipelineTmaWarpMma.create( + num_stages=num_stages, + producer_group=producer_group, + consumer_group=consumer_group, + tx_count=mxf4nvf4_full_tma_tx_bytes( + tile_m, + tile_n, + tile_k, + sf_vec_size, + ab_smem_format=ab_smem_format, + ), + barrier_storage=barrier_storage, + ) + + +@dsl_user_op +def producer_acquire_native_tma_already_elected( + pipe: PipelineTmaWarpMma, + state: pipeline.PipelineState, + try_acquire_token: Optional[cutlass.Boolean] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Acquire a native-TMA pipeline stage from inside an elected producer lane.""" + pipe.producer_acquire_already_elected( + state, + try_acquire_token, + loc=loc, + ip=ip, + ) + + +def _as_i32_ir_value(value, *, loc=None, ip=None): + if hasattr(value, "ir_value"): + return cutlass.Int32(value).ir_value(loc=loc, ip=ip) + return cutlass.Int32(value).ir_value(loc=loc, ip=ip) + + +def _flatten_coord_values(coord) -> list: + if isinstance(coord, tuple): + values = [] + for item in coord: + values.extend(_flatten_coord_values(item)) + return values + return [coord] + + +@dsl_user_op +def _issue_native_tma_load_already_elected( + tma_atom: cute.CopyAtom, + src: cute.Tensor, + dst: cute.Tensor, + tma_bar_ptr: cute.Pointer, + *, + cache_policy: Optional[cutlass.Int64] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Issue one non-multicast native TMA load from an already elected lane.""" + exec_atom = _cute_nvgpu_ir.atom_make_exec_tma(tma_atom._trait.value, loc=loc, ip=ip) + tma_desc_ptr_type = ir.Type.parse( + "!cute.ptr>" + ) + tma_desc_ptr = _cute_nvgpu_ir.get_tma_desc_addr(tma_desc_ptr_type, exec_atom, loc=loc, ip=ip) + coords = [ + _as_i32_ir_value(coord, loc=loc, ip=ip) for coord in _flatten_coord_values(src.iterator) + ] + _nvvm_ir.CpAsyncBulkTensorGlobalToSharedCTAOp( + dstMem=dst.iterator.llvm_ptr, + tmaDescriptor=tma_desc_ptr.llvm_ptr, + mbar=tma_bar_ptr.llvm_ptr, + coordinates=coords, + l2CacheHint=cache_policy.ir_value(loc=loc, ip=ip) if cache_policy is not None else None, + loc=loc, + ip=ip, + ) + + +class Mxf4Nvf4CooperativeSchedule: + """Composable SM120 MXF4NVF4 cooperative schedule contract. + + This is intentionally small: it owns the launch metadata and dynamic + register-allocation role split that are easy to get wrong, while leaving + descriptor routing, pipelines, fragment movement, MMA issue order, and + epilogue composition to the caller. + """ + + def __init__( + self, + *, + producer_warpgroups: int = 1, + consumer_warpgroups: int = 2, + producer_warpgroup_start: int | None = None, + regs_producer: int = MXF4NVF4_COOPERATIVE_PRODUCER_REGS, + regs_consumer: int = MXF4NVF4_COOPERATIVE_CONSUMER_REGS, + min_ctas_per_sm: int = 1, + tma_issue_groups: int = 3, + consumer_issue_groups: int = 2, + ) -> None: + _check_positive("producer_warpgroups", producer_warpgroups) + _check_positive("consumer_warpgroups", consumer_warpgroups) + _check_positive("regs_producer", regs_producer) + _check_positive("regs_consumer", regs_consumer) + _check_positive("min_ctas_per_sm", min_ctas_per_sm) + _check_positive("tma_issue_groups", tma_issue_groups) + _check_positive("consumer_issue_groups", consumer_issue_groups) + if producer_warpgroup_start is None: + producer_warpgroup_start = consumer_warpgroups + if producer_warpgroup_start < 0: + raise ValueError("`producer_warpgroup_start` must be non-negative") + self.producer_warpgroups = producer_warpgroups + self.consumer_warpgroups = consumer_warpgroups + self.producer_warpgroup_start = producer_warpgroup_start + self.regs_producer = regs_producer + self.regs_consumer = regs_consumer + self.min_ctas_per_sm = min_ctas_per_sm + self.tma_issue_groups = tma_issue_groups + self.consumer_issue_groups = consumer_issue_groups + + @property + def threads_per_cta(self) -> int: + warpgroups = self.producer_warpgroups + self.consumer_warpgroups + return warpgroups * MXF4NVF4_COOPERATIVE_THREADS_PER_WARPGROUP + + @property + def warps_per_warpgroup(self) -> int: + return MXF4NVF4_COOPERATIVE_THREADS_PER_WARPGROUP // 32 + + @property + def producer_warps(self) -> int: + return self.producer_warpgroups * self.warps_per_warpgroup + + @property + def consumer_warps(self) -> int: + return self.consumer_warpgroups * self.warps_per_warpgroup + + @property + def consumer_warp_start(self) -> int: + return 0 + + @property + def consumer_warp_end(self) -> int: + return self.consumer_warp_start + self.consumer_warps + + @property + def producer_warp_start(self) -> int: + return self.producer_warpgroup_start * self.warps_per_warpgroup + + @property + def producer_warp_end(self) -> int: + return self.producer_warp_start + self.producer_warps + + @property + def producer_issue_warp(self) -> int: + return self.producer_warp_start + + def is_consumer_warp(self, warp_idx: cutlass.Int32 | int) -> cutlass.Boolean: + return cutlass.Boolean(warp_idx >= self.consumer_warp_start) & cutlass.Boolean( + warp_idx < self.consumer_warp_end + ) + + def is_producer_warp(self, warp_idx: cutlass.Int32 | int) -> cutlass.Boolean: + return cutlass.Boolean(warp_idx >= self.producer_warp_start) & cutlass.Boolean( + warp_idx < self.producer_warp_end + ) + + def is_producer_issue_warp(self, warp_idx: cutlass.Int32 | int) -> cutlass.Boolean: + return cutlass.Boolean(warp_idx == self.producer_issue_warp) + + def launch_kwargs(self) -> dict[str, tuple[int, int, int] | int]: + return mxf4nvf4_cooperative_launch_kwargs( + producer_warpgroups=self.producer_warpgroups, + consumer_warpgroups=self.consumer_warpgroups, + min_ctas_per_sm=self.min_ctas_per_sm, + ) + + def sass_count_targets(self) -> dict[str, int]: + return mxf4nvf4_cooperative_sass_count_targets( + tma_issue_groups=self.tma_issue_groups, + consumer_issue_groups=self.consumer_issue_groups, + ) + + def setmaxregister_role(self, warp_idx: cutlass.Int32) -> None: + setmaxregister_mxf4nvf4_cooperative_role( + warp_idx, + producer_warpgroup_start=self.producer_warpgroup_start, + producer_warpgroups=self.producer_warpgroups, + regs_producer=self.regs_producer, + regs_consumer=self.regs_consumer, + ) + + def setmaxregister_producer(self) -> None: + setmaxregister_mxf4nvf4_producer(self.regs_producer) + + def setmaxregister_consumer(self) -> None: + setmaxregister_mxf4nvf4_consumer(self.regs_consumer) + + +def make_mxf4nvf4_cooperative_schedule( + **kwargs, +) -> Mxf4Nvf4CooperativeSchedule: + """Create an SM120 MXF4NVF4 cooperative schedule facade.""" + return Mxf4Nvf4CooperativeSchedule(**kwargs) + + +@dsl_user_op +def setmaxregister_mxf4nvf4_producer( + regs_producer: int = MXF4NVF4_COOPERATIVE_PRODUCER_REGS, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Apply the SM120 MXF4NVF4 producer-side register deallocation.""" + _check_positive("regs_producer", regs_producer) + cute.arch.setmaxregister_decrease(regs_producer, loc=loc, ip=ip) + + +@dsl_user_op +def setmaxregister_mxf4nvf4_consumer( + regs_consumer: int = MXF4NVF4_COOPERATIVE_CONSUMER_REGS, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Apply the SM120 MXF4NVF4 consumer-side register allocation.""" + _check_positive("regs_consumer", regs_consumer) + cute.arch.setmaxregister_increase(regs_consumer, loc=loc, ip=ip) + + +@cute.jit(preprocess=True) +def setmaxregister_mxf4nvf4_cooperative_role( + warp_idx: cutlass.Int32, + producer_warpgroup_start: int = 0, + producer_warpgroups: int = 1, + regs_producer: int = 40, + regs_consumer: int = 232, +) -> None: + """Apply SM120 cooperative producer/consumer dynamic register allocation. + + The role is selected at warpgroup granularity. Callers should launch with + `max_number_threads` and `min_blocks_per_mp` metadata, for example via + `mxf4nvf4_cooperative_launch_kwargs()`, otherwise ptxas may keep PTX + `setmaxnreg` text but omit SASS `USETMAXREG`. + """ + if const_expr(producer_warpgroups <= 0): + raise ValueError("`producer_warpgroups` must be positive") + if const_expr(regs_producer <= 0): + raise ValueError("`regs_producer` must be positive") + if const_expr(regs_consumer <= 0): + raise ValueError("`regs_consumer` must be positive") + if const_expr(producer_warpgroup_start < 0): + raise ValueError("`producer_warpgroup_start` must be non-negative") + warpgroup_idx = cute.arch.make_warp_uniform(warp_idx // 4) + producer_warpgroup_end = producer_warpgroup_start + producer_warpgroups + if warpgroup_idx < producer_warpgroup_start: + cute.arch.setmaxregister_increase(regs_consumer) + else: + if warpgroup_idx < producer_warpgroup_end: + cute.arch.setmaxregister_decrease(regs_producer) + else: + cute.arch.setmaxregister_increase(regs_consumer) + + +@dsl_user_op +def issue_mxf4nvf4_native_tma_stage( + tma_atom_a: cute.CopyAtom, + tma_tensor_a: cute.Tensor, + tma_atom_b: cute.CopyAtom, + tma_tensor_b: cute.Tensor, + tma_atom_sfa: cute.CopyAtom, + tma_tensor_sfa: cute.Tensor, + tma_atom_sfb: cute.CopyAtom, + tma_tensor_sfb: cute.Tensor, + sA: cute.Tensor, + sB: cute.Tensor, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + tma_bar_ptr: cute.Pointer, + stage_idx: cutlass.Int32 | int = 0, + *, + batch_idx: cutlass.Int32 | int = 0, + already_elected: cutlass.Constexpr[bool] = False, + cache_policy=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Issue one native TMA stage for A/B and SFA/SFB atoms. + + The TMA tensors are expected to be already tiled/sliced to the CTA tile the + caller wants to load. This helper owns the repetitive SM120 partition and + copy plumbing for the descriptor-free atom path. + """ + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + 0, + cute.make_layout(1, loc=loc, ip=ip), + cute.group_modes(sA, 0, 2, loc=loc, ip=ip), + cute.group_modes(tma_tensor_a, 0, 2, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + 0, + cute.make_layout(1, loc=loc, ip=ip), + cute.group_modes(sB, 0, 2, loc=loc, ip=ip), + cute.group_modes(tma_tensor_b, 0, 2, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + tSFAs, tSFAg = cpasync.tma_partition( + tma_atom_sfa, + 0, + cute.make_layout(1, loc=loc, ip=ip), + sSFA, + tma_tensor_sfa, + loc=loc, + ip=ip, + ) + tSFBs, tSFBg = cpasync.tma_partition( + tma_atom_sfb, + 0, + cute.make_layout(1, loc=loc, ip=ip), + sSFB, + tma_tensor_sfb, + loc=loc, + ip=ip, + ) + + if cutlass.const_expr(already_elected): + _issue_native_tma_load_already_elected( + tma_atom_a, + tAgA[(None, batch_idx)], + tAsA[(None, stage_idx)], + tma_bar_ptr, + cache_policy=cache_policy, + loc=loc, + ip=ip, + ) + _issue_native_tma_load_already_elected( + tma_atom_b, + tBgB[(None, batch_idx)], + tBsB[(None, stage_idx)], + tma_bar_ptr, + cache_policy=cache_policy, + loc=loc, + ip=ip, + ) + _issue_native_tma_load_already_elected( + tma_atom_sfa, + tSFAg[(None, 0, 0, batch_idx)], + tSFAs[(None, 0, 0, stage_idx)], + tma_bar_ptr, + cache_policy=cache_policy, + loc=loc, + ip=ip, + ) + _issue_native_tma_load_already_elected( + tma_atom_sfb, + tSFBg[(None, 0, 0, batch_idx)], + tSFBs[(None, 0, 0, stage_idx)], + tma_bar_ptr, + cache_policy=cache_policy, + loc=loc, + ip=ip, + ) + return + + cute.copy( + tma_atom_a, + tAgA[(None, batch_idx)], + tAsA[(None, stage_idx)], + tma_bar_ptr=tma_bar_ptr, + cache_policy=cache_policy, + loc=loc, + ip=ip, + ) + cute.copy( + tma_atom_b, + tBgB[(None, batch_idx)], + tBsB[(None, stage_idx)], + tma_bar_ptr=tma_bar_ptr, + cache_policy=cache_policy, + loc=loc, + ip=ip, + ) + cute.copy( + tma_atom_sfa, + tSFAg[(None, 0, 0, batch_idx)], + tSFAs[(None, 0, 0, stage_idx)], + tma_bar_ptr=tma_bar_ptr, + cache_policy=cache_policy, + loc=loc, + ip=ip, + ) + cute.copy( + tma_atom_sfb, + tSFBg[(None, 0, 0, batch_idx)], + tSFBs[(None, 0, 0, stage_idx)], + tma_bar_ptr=tma_bar_ptr, + cache_policy=cache_policy, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def local_tile_mxf4nvf4_native_tma_tensors( + tma_tensor_a: cute.Tensor, + tma_tensor_b: cute.Tensor, + tma_tensor_sfa: cute.Tensor, + tma_tensor_sfb: cute.Tensor, + tile_mnl: tuple[cutlass.Int32 | int, ...], + k_tile: cutlass.Int32 | int, + *, + ab_cta_tiler: cute.Tile = (128, 128, 1), + scale_cta_tiler: cute.Tile = (128, 8, 1, 1), + scale_smem_format: str = "physical", + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> tuple[cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor]: + """Local-tile native SM120 TMA tensors for one scheduler work tile.""" + tile_coords = mxf4nvf4_scheduler_tile_tma_coords(tile_mnl, k_tile) + if scale_smem_format == "interleaved": + scale_cta_tiler = ( + scale_cta_tiler[0], + 4, + MXF4NVF4_CTA_SHAPE_MNK[2] // (MXF4NVF4_SCALE_VEC_SIZE * 4), + 1, + ) + scale_tile_coord_sfa = (tile_mnl[0], 0, k_tile, tile_mnl[2]) + scale_tile_coord_sfb = (tile_mnl[1], 0, k_tile, tile_mnl[2]) + elif scale_smem_format == "physical": + scale_tile_coord_sfa = tile_coords["scale_tile_coord_sfa"] + scale_tile_coord_sfb = tile_coords["scale_tile_coord_sfb"] + else: + raise ValueError("scale_smem_format must be either 'physical' or 'interleaved'") + return ( + cute.local_tile( + tma_tensor_a, + ab_cta_tiler, + tile_coords["ab_tile_coord_a"], + loc=loc, + ip=ip, + ), + cute.local_tile( + tma_tensor_b, + ab_cta_tiler, + tile_coords["ab_tile_coord_b"], + loc=loc, + ip=ip, + ), + cute.local_tile( + tma_tensor_sfa, + scale_cta_tiler, + scale_tile_coord_sfa, + loc=loc, + ip=ip, + ), + cute.local_tile( + tma_tensor_sfb, + scale_cta_tiler, + scale_tile_coord_sfb, + loc=loc, + ip=ip, + ), + ) + + +@dsl_user_op +def issue_mxf4nvf4_native_tma_stage_for_tile( + tma_atom_a: cute.CopyAtom, + tma_tensor_a: cute.Tensor, + tma_atom_b: cute.CopyAtom, + tma_tensor_b: cute.Tensor, + tma_atom_sfa: cute.CopyAtom, + tma_tensor_sfa: cute.Tensor, + tma_atom_sfb: cute.CopyAtom, + tma_tensor_sfb: cute.Tensor, + sA: cute.Tensor, + sB: cute.Tensor, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + tma_bar_ptr: cute.Pointer, + tile_mnl: tuple[cutlass.Int32 | int, ...], + k_tile: cutlass.Int32 | int = 0, + stage_idx: cutlass.Int32 | int = 0, + *, + batch_idx: cutlass.Int32 | int = 0, + ab_cta_tiler: cute.Tile = (128, 128, 1), + scale_cta_tiler: cute.Tile = (128, 8, 1, 1), + scale_smem_format: str = "physical", + already_elected: cutlass.Constexpr[bool] = False, + cache_policy=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Local-tile native TMA tensors for one scheduler tile and issue the stage.""" + ( + tiled_tma_tensor_a, + tiled_tma_tensor_b, + tiled_tma_tensor_sfa, + tiled_tma_tensor_sfb, + ) = local_tile_mxf4nvf4_native_tma_tensors( + tma_tensor_a, + tma_tensor_b, + tma_tensor_sfa, + tma_tensor_sfb, + tile_mnl, + k_tile, + ab_cta_tiler=ab_cta_tiler, + scale_cta_tiler=scale_cta_tiler, + scale_smem_format=scale_smem_format, + loc=loc, + ip=ip, + ) + issue_mxf4nvf4_native_tma_stage( + tma_atom_a, + tiled_tma_tensor_a, + tma_atom_b, + tiled_tma_tensor_b, + tma_atom_sfa, + tiled_tma_tensor_sfa, + tma_atom_sfb, + tiled_tma_tensor_sfb, + sA, + sB, + sSFA, + sSFB, + tma_bar_ptr, + stage_idx, + batch_idx=batch_idx, + already_elected=already_elected, + cache_policy=cache_policy, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def issue_mxf4nvf4_native_tma_full_tile_consumer_group( + tiled_mma: cute.TiledMma, + sA_consumer: cute.Tensor, + sB_consumer: cute.Tensor, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + acc: cute.Tensor, + tidx: cutlass.Int32, + stage_idx: cutlass.Int32 | int = 0, + *, + tile_m: int = 128, + tile_n: int = 128, + tile_k: int = 128, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, + ab_smem_format: str = "packed", + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Issue one K128 native-TMA consumer group for a full 128x128 tile.""" + _check_default_tile(tile_m, tile_k, sf_vec_size) + _check_default_tile(tile_n, tile_k, sf_vec_size) + ab_smem_format = _normalize_mxf4nvf4_ab_smem_format(ab_smem_format) + if ab_smem_format != "packed": + raise ValueError( + "SM120 native full-tile consumer group currently supports only ab_smem_format='packed'" + ) + + a_frag = cute.make_rmem_tensor( + tiled_mma.partition_shape_A((tile_m, tile_k), loc=loc, ip=ip), + cutlass.Float4E2M1FN, + loc=loc, + ip=ip, + ) + b_frag = cute.make_rmem_tensor( + tiled_mma.partition_shape_B((tile_n, tile_k), loc=loc, ip=ip), + cutlass.Float4E2M1FN, + loc=loc, + ip=ip, + ) + copy_atom_a, copy_atom_b = make_mxf4nvf4_ab_smem_copy_atoms(loc=loc, ip=ip) + tiled_copy_a = cute.make_tiled_copy_A(copy_atom_a, tiled_mma, loc=loc, ip=ip) + tiled_copy_b = cute.make_tiled_copy_B(copy_atom_b, tiled_mma, loc=loc, ip=ip) + thr_copy_a = tiled_copy_a.get_slice(tidx) + thr_copy_b = tiled_copy_b.get_slice(tidx) + sA_src = cute.as_position_independent_swizzle_tensor(sA_consumer, loc=loc, ip=ip) + sB_src = cute.as_position_independent_swizzle_tensor(sB_consumer, loc=loc, ip=ip) + tCsA = thr_copy_a.partition_S(sA_src, loc=loc, ip=ip) + tCsB = thr_copy_b.partition_S(sB_src, loc=loc, ip=ip) + tCrA = thr_copy_a.retile_D(a_frag, loc=loc, ip=ip) + tCrB = thr_copy_b.retile_D(b_frag, loc=loc, ip=ip) + + sfa_frag, sfb_frag = make_mxf4nvf4_direct_tma_scale_fragments( + tiled_mma, + sSFA, + sSFB, + tidx, + tile_shape_mnk=(tile_m, tile_n, tile_k), + sf_vec_size=sf_vec_size, + loc=loc, + ip=ip, + ) + + issue_mxf4nvf4_direct_tma_consumer_group( + tiled_mma, + tiled_copy_a, + tiled_copy_b, + tCsA, + tCsB, + tCrA, + tCrB, + a_frag, + b_frag, + sSFA, + sSFB, + sfa_frag, + sfb_frag, + acc, + tidx, + stage_idx, + major_extent_sfa=tile_m, + major_extent_sfb=tile_n, + tile_k=tile_k, + sf_vec_size=sf_vec_size, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_mxf4nvf4_tiled_mma( + atom_layout_mnk: Tuple[int, int, int] = (1, 1, 1), + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.TiledMma: + """Create the SM120 warp-level MXF4NVF4 tiled MMA.""" + mma_op = warp.MmaMXF4NVF4Op( + cutlass.Float4E2M1FN, + cutlass.Float32, + cutlass.Float8E4M3FN, + ) + return cute.make_tiled_mma(mma_op, atom_layout_mnk=atom_layout_mnk, loc=loc, ip=ip) + + +@dsl_user_op +def make_mxf4nvf4_79a_tiled_mma( + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.TiledMma: + """Create the 79a-style SM120 128x128 ping-pong tiled MMA. + + This is the compact 4-warpgroup-local layout used by the fast SM120 + NVFP4 path: a (2,2,1) MMA atom layout with the N-major permutation needed + by the STSM epilogue schedule. + """ + mma_op = warp.MmaMXF4NVF4Op( + cutlass.Float4E2M1FN, + cutlass.Float32, + cutlass.Float8E4M3FN, + ) + return cute.make_tiled_mma( + mma_op, + atom_layout_mnk=cute.make_layout((2, 2, 1), stride=(1, 2, 0), loc=loc, ip=ip), + permutation_mnk=( + 128, + cute.make_layout((8, 2, 2), stride=(1, 16, 8), loc=loc, ip=ip), + 64, + ), + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def convert_mxf4nvf4_acc_layout_for_epilogue_stmatrix( + acc_layout: cute.Layout, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.Layout: + """View SM120 accumulator registers in the fragment order used by STSM.""" + if const_expr(cute.rank(acc_layout.shape[0]) == 3): + div = 2 if const_expr(acc_layout.shape[0][2] % 2 == 0) else 1 + divided = cute.logical_divide(acc_layout, ((None, None, div), None, None), loc=loc, ip=ip) + return cute.make_layout( + ( + (divided.shape[0][0], divided.shape[0][1], divided.shape[0][2][0]), + divided.shape[1], + (divided.shape[0][2][1], divided.shape[2]), + ), + stride=( + ( + divided.stride[0][0], + divided.stride[0][1], + divided.stride[0][2][0], + ), + divided.stride[1], + (divided.stride[0][2][1], divided.stride[2]), + ), + loc=loc, + ip=ip, + ) + if acc_layout.shape[2] % 2 != 0: + raise ValueError("SM120 epilogue STSM accumulator view requires even N modes") + divided = cute.logical_divide(acc_layout, (None, None, 2), loc=loc, ip=ip) + return cute.make_layout( + ( + (divided.shape[0][0], divided.shape[0][1], divided.shape[2][0]), + divided.shape[1], + divided.shape[2][1], + ), + stride=( + (divided.stride[0][0], divided.stride[0][1], divided.stride[2][0]), + divided.stride[1], + divided.stride[2][1], + ), + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def retile_mxf4nvf4_accumulator_for_epilogue_stmatrix( + acc: cute.Tensor, + tRS_rD: cute.Tensor, + tiled_copy_r2s: cute.TiledCopy, + *, + epi_tile_shape: cute.Tile = (2, 1), + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.Tensor: + """Retile a full SM120 accumulator fragment for one epilogue SMEM tile.""" + return tiled_copy_r2s.retile(acc, loc=loc, ip=ip) + + +@dsl_user_op +def load_mxf4nvf4_accumulator_epilogue_subtile( + tRS_rAcc: cute.Tensor, + tRS_rD: cute.Tensor, + epi_coord: cute.Coord, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Load one accumulator epilogue subtile into the STSM source registers.""" + tRS_rD_flat = cute.coalesce(tRS_rD, loc=loc, ip=ip) + for mma_n_in_epi in range(2): + for mma_m_in_epi in range(2): + idx = mma_n_in_epi * 2 + mma_m_in_epi + tRS_rAcc_flat = cute.coalesce( + tRS_rAcc[ + None, + epi_coord[0] * 2 + mma_m_in_epi, + epi_coord[1] * 2 + mma_n_in_epi, + ], + loc=loc, + ip=ip, + ) + for epi_v in range(4): + tRS_rD_flat[idx * 4 + epi_v] = tRS_rAcc_flat[epi_v].to(tRS_rD.element_type) + + +@dsl_user_op +def copy_mxf4nvf4_epilogue_registers_to_smem( + tiled_copy_r2s: cute.TiledCopy, + src: cute.Tensor, + dst: cute.Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Convert epilogue registers to the SMEM type and issue the STSM copy.""" + if const_expr(src.element_type != dst.element_type): + src_cvt = cute.make_rmem_tensor_like(src, dst.element_type, loc=loc, ip=ip) + src_cvt.store(src.load().to(dst.element_type), loc=loc, ip=ip) + src = src_cvt + cute.copy(tiled_copy_r2s, src, dst, loc=loc, ip=ip) + cute.arch.fence_view_async_shared() + + +@dsl_user_op +def make_mxf4nvf4_epilogue_stmatrix_views( + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + sD_tile: cute.Tensor, + tidx: cutlass.Int32, + *, + epi_tile_shape: cute.Tile = (2, 1), + num_matrices: int = 2, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> tuple[cute.TiledCopy, cute.Tensor, cute.Tensor, cute.Tensor]: + """Create the SM120 STSM epilogue copy views for a BF16 SMEM tile.""" + if num_matrices != 2: + raise ValueError("SM120 MXF4NVF4 epilogue STSM helper currently requires x2") + copy_atom_c = cute.make_copy_atom( + warp.StMatrix8x8x16bOp(transpose=False, num_matrices=num_matrices), + cutlass.Float16, + loc=loc, + ip=ip, + ) + tiled_copy_c_atom = cute.make_tiled_copy_C_atom(copy_atom_c, tiled_mma, loc=loc, ip=ip) + copy_atom_r2s = cute.make_copy_atom( + warp.StMatrix8x8x16bOp(transpose=False, num_matrices=num_matrices), + sD_tile.element_type, + loc=loc, + ip=ip, + ) + tiled_copy_r2s = cute.make_tiled_copy_S(copy_atom_r2s, tiled_copy_c_atom, loc=loc, ip=ip) + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + tRS_sD = thr_copy_r2s.partition_D(sD_tile, loc=loc, ip=ip) + tRS_rD_shape = thr_copy_r2s.partition_S( + cute.make_identity_tensor(sD_tile.shape, loc=loc, ip=ip), loc=loc, ip=ip + ).shape + tRS_rD = cute.make_rmem_tensor(tRS_rD_shape, acc.element_type, loc=loc, ip=ip) + tRS_rAcc = retile_mxf4nvf4_accumulator_for_epilogue_stmatrix( + acc, + tRS_rD, + tiled_copy_r2s, + epi_tile_shape=epi_tile_shape, + loc=loc, + ip=ip, + ) + return tiled_copy_r2s, tRS_rD, tRS_sD, tRS_rAcc + + +def mxf4nvf4_79a_epilogue_tile( + tile_m: int = 128, + tile_n: int = 128, +) -> tuple[int, int]: + """Return the 79a-style SM120 NVFP4 epilogue TMA-store subtile.""" + if tile_m != 128 or tile_n != 128: + raise ValueError("SM120 MXF4NVF4 79a epilogue tile currently requires a 128x128 CTA tile") + return (64, 32) + + +@dsl_user_op +def make_mxf4nvf4_epilogue_smem_layout( + *, + epi_tile: cute.Tile = (64, 32), + num_stages: int = 1, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.Layout: + """Create the BF16 epilogue SMEM layout used by the SM120 fast store path.""" + _check_tuple("epi_tile", epi_tile, 2) + epi_m, epi_n = epi_tile + _check_positive("epi_m", epi_m) + _check_positive("epi_n", epi_n) + _check_positive("num_stages", num_stages) + return cute.make_layout( + (epi_m, epi_n, num_stages), + stride=(epi_n, 1, epi_m * epi_n), + loc=loc, + ip=ip, + ) + + +def make_mxf4nvf4_epilogue_tma_store_atom( + gD: cute.Tensor, + smem_layout, + *, + epi_tile: cute.Tile = (64, 32), + op=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +): + """Create the SM120 MXF4/NVFP4 BF16 epilogue S2G TMA atom and tensor.""" + _check_tuple("epi_tile", epi_tile, 2) + if op is None: + op = cpasync.CopyBulkTensorTileS2GOp() + smem_rank = cute.rank(smem_layout) + if smem_rank == cute.rank(epi_tile) + 1: + smem_layout = cute.slice_(smem_layout, (None, None, 0), loc=loc, ip=ip) + d_cta_v_layout = cute.composition( + cute.make_identity_layout(gD.shape, loc=loc, ip=ip), + epi_tile, + loc=loc, + ip=ip, + ) + return cpasync.make_tiled_tma_atom( + op, + gD, + smem_layout, + d_cta_v_layout, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def partition_mxf4nvf4_epilogue_tma_store( + tma_atom_d: cute.CopyAtom, + tma_tensor_d: cute.Tensor, + sD_epi: cute.Tensor, + tile_mnl: tuple[cutlass.Int32 | int, ...], + *, + cta_tiler: cute.Tile = (128, 128, 1), + epi_tile: cute.Tile = (64, 32), + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> tuple[cute.Tensor, cute.Tensor]: + """Partition one scheduler-selected BF16 epilogue S2G TMA store tile.""" + _check_tuple("tile_mnl", tile_mnl, 3) + _check_tuple("cta_tiler", cta_tiler, 3) + _check_tuple("epi_tile", epi_tile, 2) + tiled_d = cute.local_tile( + tma_tensor_d, + cta_tiler[:2], + (None, None, None), + loc=loc, + ip=ip, + ) + tile_d = tiled_d[(None, None, tile_mnl[0], tile_mnl[1], tile_mnl[2])] + epi_d = cute.zipped_divide(tile_d, epi_tile, loc=loc, ip=ip) + return cpasync.tma_partition( + tma_atom_d, + 0, + cute.make_layout(1), + cute.group_modes(sD_epi, 0, 2, loc=loc, ip=ip), + epi_d, + loc=loc, + ip=ip, + ) + + +def issue_mxf4nvf4_epilogue_tma_store( + tma_atom_d: cute.CopyAtom, + tDsD: cute.Tensor, + tDgD: cute.Tensor, + *, + epi_m: int = 0, + epi_n: int = 0, + stage_idx: int = 0, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Issue one selected SM120 MXF4/NVFP4 BF16 epilogue S2G TMA subtile.""" + cute.copy( + tma_atom_d, + tDsD[None, stage_idx], + tDgD[None, (epi_m, epi_n)], + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def stage_mxf4nvf4_accumulator_fragment_D_to_smem( + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + sD_tile: cute.Tensor, + tidx: cutlass.Int32, + *, + epi_m: int = 0, + epi_n: int = 0, + epi_tile_shape: cute.Tile = (2, 1), + num_matrices: int = 2, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Stage one SM120 MXF4/NVFP4 accumulator epilogue tile with STSM. + + The default shape matches the 79a-style 128x128 CTA tile split into + subtiles. ``epi_m`` and ``epi_n`` select the epilogue subtile to stage. + """ + tiled_copy_r2s, tRS_rD, tRS_sD, tRS_rAcc = make_mxf4nvf4_epilogue_stmatrix_views( + tiled_mma, + acc, + sD_tile, + tidx, + epi_tile_shape=epi_tile_shape, + num_matrices=num_matrices, + loc=loc, + ip=ip, + ) + load_mxf4nvf4_accumulator_epilogue_subtile(tRS_rAcc, tRS_rD, (epi_m, epi_n), loc=loc, ip=ip) + copy_mxf4nvf4_epilogue_registers_to_smem(tiled_copy_r2s, tRS_rD, tRS_sD, loc=loc, ip=ip) + + +@dsl_user_op +def store_mxf4nvf4_accumulator_fragment_D( + thr_mma: cute.ThrMma, + acc: cute.Tensor, + gD: cute.Tensor, + pred: Optional[cute.Tensor] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Store one SM120 warp-MMA accumulator fragment directly to global D.""" + tDgD = thr_mma.partition_C(gD) + rD = cute.make_rmem_tensor(acc.layout, gD.element_type) + rD.store(acc.load().to(gD.element_type)) + copy_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + gD.element_type, + loc=loc, + ip=ip, + ) + if const_expr(pred is None): + cute.copy(copy_atom, rD, tDgD, loc=loc, ip=ip) + else: + cute.copy(copy_atom, rD, tDgD, pred=pred, loc=loc, ip=ip) + + +@dsl_user_op +def local_tile_mxf4nvf4_d_tensor( + gD: cute.Tensor, + tile_mnl: tuple[cutlass.Int32 | int, ...], + *, + cta_tiler: cute.Tile = (128, 128, 1), + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.Tensor: + """Local-tile the SM120 output tensor for one scheduler work tile.""" + _check_tuple("tile_mnl", tile_mnl, 3) + if const_expr(cute.rank(gD) == 2): + return cute.local_tile( + gD, + (cta_tiler[0], cta_tiler[1]), + (tile_mnl[0], tile_mnl[1]), + loc=loc, + ip=ip, + ) + gD_tile = cute.local_tile( + gD, + cta_tiler, + tile_mnl, + loc=loc, + ip=ip, + ) + return gD_tile[(None, None, 0)] + + +@dsl_user_op +def store_mxf4nvf4_accumulator_fragment_D_for_tile( + thr_mma: cute.ThrMma, + acc: cute.Tensor, + gD: cute.Tensor, + tile_mnl: tuple[cutlass.Int32 | int, ...], + pred: Optional[cute.Tensor] = None, + *, + cta_tiler: cute.Tile = (128, 128, 1), + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Store one accumulator fragment to the scheduler-selected D tile.""" + gD_tile = local_tile_mxf4nvf4_d_tensor( + gD, + tile_mnl, + cta_tiler=cta_tiler, + loc=loc, + ip=ip, + ) + store_mxf4nvf4_accumulator_fragment_D( + thr_mma, + acc, + gD_tile, + pred=pred, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def store_mxf4nvf4_accumulator_fragment_D_for_tiled_mma_tile( + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + gD: cute.Tensor, + tile_mnl: tuple[cutlass.Int32 | int, ...], + tidx: cutlass.Int32, + pred: Optional[cute.Tensor] = None, + *, + cta_tiler: cute.Tile = (128, 128, 1), + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Direct-store one SM120 accumulator tile from a tiled-MMA slice. + + Ping-pong mainloops commonly have two consumer warpgroups resident in one + CTA. Each warpgroup can store a complete 128x128 accumulator tile when the + caller gates ownership with a surrounding runtime branch. This helper keeps + the selected warpgroup's direct global-store path compact so callers do not + need to route through BF16 epilogue SMEM/TMA staging. + """ + thr_mma = tiled_mma.get_slice(tidx) + store_mxf4nvf4_accumulator_fragment_D_for_tile( + thr_mma, + acc, + gD, + tile_mnl, + pred=pred, + cta_tiler=cta_tiler, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_mxf4nvf4_ab_tma_physical_layout_staged( + major_extent: int = 128, + tile_k: int = 128, + num_stages: int = 1, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.Layout: + """Return the A/B physical SMEM byte layout populated by external TMA.""" + _check_default_tile(major_extent, tile_k, MXF4NVF4_SCALE_VEC_SIZE) + _check_positive("num_stages", num_stages) + return cute.make_layout( + (major_extent, tile_k, num_stages), + stride=(tile_k, 1, major_extent * tile_k), + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_mxf4nvf4_consumer_smem_layout_atom_ab( + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.ComposedLayout: + """Return the SM120 packed-FP4 consumer SMEM atom layout. + + This mirrors the layout atom selected by the 79a C++ collective: + `Sw<2,4,3> o smem_ptr[4b] o (_8,_128):(_128,_1)`. + """ + return cute.make_composed_layout( + cute.make_swizzle(2, 4, 3, loc=loc, ip=ip), + 0, + cute.make_layout((8, 128), stride=(128, 1), loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_mxf4nvf4_a_consumer_smem_layout_staged( + major_extent: int = 128, + tile_k: int = 128, + num_stages: int = 1, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.ComposedLayout: + """Return the staged 79a-style A consumer SMEM layout.""" + _check_default_tile(major_extent, tile_k, MXF4NVF4_SCALE_VEC_SIZE) + _check_positive("num_stages", num_stages) + return cute.tile_to_shape( + make_mxf4nvf4_consumer_smem_layout_atom_ab(loc=loc, ip=ip), + (major_extent, tile_k, num_stages), + (0, 1, 2), + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_mxf4nvf4_b_consumer_smem_layout_staged( + major_extent: int = 128, + tile_k: int = 128, + num_stages: int = 1, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.ComposedLayout: + """Return the staged 79a-style B consumer SMEM layout.""" + _check_default_tile(major_extent, tile_k, MXF4NVF4_SCALE_VEC_SIZE) + _check_positive("num_stages", num_stages) + return cute.tile_to_shape( + make_mxf4nvf4_consumer_smem_layout_atom_ab(loc=loc, ip=ip), + (major_extent, tile_k, num_stages), + (0, 1, 2), + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_mxf4nvf4_direct_tma_consumer_smem_layout_atom_ab( + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.ComposedLayout: + """Return the packed-U4 direct-TMA consumer SMEM atom layout. + + This is the 79a-style `UMMA::Layout_K_SW128_Atom` layout for + loading A/B directly from TMA into the SMEM layout consumed by LDSM. + """ + return cute.make_composed_layout( + cute.make_swizzle(3, 4, 3, loc=loc, ip=ip), + 0, + cute.make_layout((8, 128), stride=(128, 1), loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_mxf4nvf4_a_direct_tma_consumer_smem_layout_staged( + major_extent: int = 128, + tile_k: int = 128, + num_stages: int = 1, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.ComposedLayout: + """Return the staged A direct-TMA consumer SMEM layout.""" + _check_default_tile(major_extent, tile_k, MXF4NVF4_SCALE_VEC_SIZE) + _check_positive("num_stages", num_stages) + return cute.tile_to_shape( + make_mxf4nvf4_direct_tma_consumer_smem_layout_atom_ab(loc=loc, ip=ip), + (major_extent, tile_k, num_stages), + (0, 1, 2), + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_mxf4nvf4_b_direct_tma_consumer_smem_layout_staged( + major_extent: int = 128, + tile_k: int = 128, + num_stages: int = 1, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.ComposedLayout: + """Return the staged B direct-TMA consumer SMEM layout.""" + _check_default_tile(major_extent, tile_k, MXF4NVF4_SCALE_VEC_SIZE) + _check_positive("num_stages", num_stages) + return cute.tile_to_shape( + make_mxf4nvf4_direct_tma_consumer_smem_layout_atom_ab(loc=loc, ip=ip), + (major_extent, tile_k, num_stages), + (0, 1, 2), + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_mxf4nvf4_a_packed_direct_tma_consumer_smem_layout_staged( + major_extent: int = 128, + tile_k: int = 128, + num_stages: int = 1, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.ComposedLayout: + """Return the packed-FP4 A direct-TMA consumer SMEM layout. + + This is the compact FP4 consumer layout used by the native CuTe TMA atom + path when A/B SMEM format is packed. + """ + return make_mxf4nvf4_a_consumer_smem_layout_staged( + major_extent, tile_k, num_stages, loc=loc, ip=ip + ) + + +@dsl_user_op +def make_mxf4nvf4_b_packed_direct_tma_consumer_smem_layout_staged( + major_extent: int = 128, + tile_k: int = 128, + num_stages: int = 1, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.ComposedLayout: + """Return the packed-FP4 B direct-TMA consumer SMEM layout.""" + return make_mxf4nvf4_b_consumer_smem_layout_staged( + major_extent, tile_k, num_stages, loc=loc, ip=ip + ) + + +def make_mxf4nvf4_ab_direct_tma_consumer_smem_views( + smem: SmemAllocator, + *, + num_stages: int = 1, + tile_m: int = 128, + tile_n: int = 128, + tile_k: int = 128, +) -> Tuple[cute.Tensor, cute.Tensor]: + """Allocate Uint8 A/B SMEM views for direct consumer-layout TMA.""" + layout_a = make_mxf4nvf4_a_direct_tma_consumer_smem_layout_staged(tile_m, tile_k, num_stages) + layout_b = make_mxf4nvf4_b_direct_tma_consumer_smem_layout_staged(tile_n, tile_k, num_stages) + return ( + smem.allocate_tensor(cutlass.Uint8, layout_a, byte_alignment=128), + smem.allocate_tensor(cutlass.Uint8, layout_b, byte_alignment=128), + ) + + +def make_mxf4nvf4_ab_packed_direct_tma_consumer_smem_views( + smem: SmemAllocator, + *, + num_stages: int = 1, + tile_m: int = 128, + tile_n: int = 128, + tile_k: int = 128, +) -> Tuple[cute.Tensor, cute.Tensor]: + """Allocate packed-FP4 A/B SMEM views for direct consumer-layout TMA.""" + return make_mxf4nvf4_ab_consumer_smem_views( + smem, num_stages=num_stages, tile_m=tile_m, tile_n=tile_n, tile_k=tile_k + ) + + +def make_mxf4nvf4_native_tma_smem_views( + smem: SmemAllocator, + *, + tiled_mma: Optional[cute.TiledMma] = None, + num_stages: int = 1, + tile_m: int = 128, + tile_n: int = 128, + tile_k: int = 128, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, + ab_smem_format: str = "packed", + scale_smem_format: str = "physical", +) -> Tuple[cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor]: + """Allocate A/B/SFA/SFB SMEM views for native SM120 TMA atoms.""" + ab_smem_format = _normalize_mxf4nvf4_ab_smem_format(ab_smem_format) + if tiled_mma is None: + tiled_mma = make_mxf4nvf4_tiled_mma() + if ab_smem_format == "unpack": + sA, sB = make_mxf4nvf4_ab_direct_tma_consumer_smem_views( + smem, + num_stages=num_stages, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + ) + else: + sA, sB = make_mxf4nvf4_ab_packed_direct_tma_consumer_smem_views( + smem, + num_stages=num_stages, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + ) + if scale_smem_format == "interleaved": + sSFA, sSFB = allocate_mxf4nvf4_scale_tma_interleaved( + smem, + tiled_mma=tiled_mma, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + sf_vec_size=sf_vec_size, + num_stages=num_stages, + ) + elif scale_smem_format == "physical": + sSFA = allocate_mxf4nvf4_scale_tma_physical( + smem, + major_extent=tile_m, + tile_k=tile_k, + sf_vec_size=sf_vec_size, + num_stages=num_stages, + ) + sSFB = allocate_mxf4nvf4_scale_tma_physical( + smem, + major_extent=tile_n, + tile_k=tile_k, + sf_vec_size=sf_vec_size, + num_stages=num_stages, + ) + else: + raise ValueError("scale_smem_format must be either 'physical' or 'interleaved'") + return (sA, sB, sSFA, sSFB) + + +@dsl_user_op +def make_mxf4nvf4_ab_packed_direct_tma_consumer_tma_views( + sA_consumer: cute.Tensor, + sB_consumer: cute.Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Tuple[cute.Tensor, cute.Tensor]: + """Reinterpret packed-FP4 direct consumer SMEM as byte TMA destinations.""" + return ( + cute.recast_tensor(sA_consumer, cutlass.Uint8, loc=loc, ip=ip), + cute.recast_tensor(sB_consumer, cutlass.Uint8, loc=loc, ip=ip), + ) + + +@dsl_user_op +def make_mxf4nvf4_ab_direct_tma_consumer_fp4_views( + sA_direct: cute.Tensor, + sB_direct: cute.Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Tuple[cute.Tensor, cute.Tensor]: + """Reinterpret direct-TMA Uint8 A/B SMEM as packed FP4 consumer views.""" + return ( + cute.recast_tensor(sA_direct, cutlass.Float4E2M1FN, loc=loc, ip=ip), + cute.recast_tensor(sB_direct, cutlass.Float4E2M1FN, loc=loc, ip=ip), + ) + + +def make_mxf4nvf4_ab_consumer_smem_views( + smem: SmemAllocator, + *, + num_stages: int = 1, + tile_m: int = 128, + tile_n: int = 128, + tile_k: int = 128, +) -> Tuple[cute.Tensor, cute.Tensor]: + """Allocate A/B SMEM views for the 79a-style consumer LDSM path.""" + layout_a = make_mxf4nvf4_a_consumer_smem_layout_staged(tile_m, tile_k, num_stages) + layout_b = make_mxf4nvf4_b_consumer_smem_layout_staged(tile_n, tile_k, num_stages) + return ( + smem.allocate_tensor(cutlass.Float4E2M1FN, layout_a, byte_alignment=128), + smem.allocate_tensor(cutlass.Float4E2M1FN, layout_b, byte_alignment=128), + ) + + +@dsl_user_op +def make_mxf4nvf4_ab_consumer_microtile_views( + sA_consumer: cute.Tensor, + sB_consumer: cute.Tensor, + m_atom: cutlass.Int32 | int = 0, + n_atom: cutlass.Int32 | int = 0, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Tuple[cute.Tensor, cute.Tensor]: + """Return local 16x8 MMA microtile views from a staged CTA consumer tile. + + Global CTA M/N selection belongs in the tensor-map descriptor coordinates. + This helper only selects the local output atom within the already-staged + 128x128 CTA tile. + """ + return ( + cute.domain_offset( + (m_atom * MXF4NVF4_MMA_SHAPE_MNK[0], 0, 0), + sA_consumer, + loc=loc, + ip=ip, + ), + cute.domain_offset( + (n_atom * MXF4NVF4_MMA_SHAPE_MNK[1], 0, 0), + sB_consumer, + loc=loc, + ip=ip, + ), + ) + + +@dsl_user_op +def make_mxf4nvf4_sfa_smem_layout_staged( + tiled_mma: Optional[cute.TiledMma] = None, + tile_shape_mnk: cute.Tile = MXF4NVF4_CTA_SHAPE_MNK, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, + num_stages: int = 1, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.Layout: + """Return the staged SFA SMEM layout for SM120 MXF4NVF4.""" + tiled_mma = make_mxf4nvf4_tiled_mma(loc=loc, ip=ip) if tiled_mma is None else tiled_mma + return blockscaled_layout.sm120_make_smem_layout_sfa( + tiled_mma, + tile_shape_mnk, + sf_vec_size, + num_stages, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_mxf4nvf4_sfb_smem_layout_staged( + tiled_mma: Optional[cute.TiledMma] = None, + tile_shape_mnk: cute.Tile = MXF4NVF4_CTA_SHAPE_MNK, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, + num_stages: int = 1, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.Layout: + """Return the staged SFB SMEM layout for SM120 MXF4NVF4.""" + tiled_mma = make_mxf4nvf4_tiled_mma(loc=loc, ip=ip) if tiled_mma is None else tiled_mma + return blockscaled_layout.sm120_make_smem_layout_sfb( + tiled_mma, + tile_shape_mnk, + sf_vec_size, + num_stages, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_mxf4nvf4_scale_tma_physical_as_tiled_smem_layout_staged( + major_extent: int = 128, + tile_k: int = 128, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, + num_stages: int = 1, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.ComposedLayout: + """Return a logical-K view over rank-4 native scale-TMA SMEM bytes. + + Native scale TMA writes compact FP8 scale columns through the tensor-map + 128B swizzle. The consumer scale fragments are indexed by logical FP4 K + coordinates, where all 16 FP4 values in one scale vector share one FP8 + scale. This layout keeps that logical K surface while preserving the TMA + physical byte mapping in shared memory. + """ + _check_default_tile(major_extent, tile_k, sf_vec_size) + _check_positive("num_stages", num_stages) + if major_extent % 128 != 0: + raise ValueError("SM120 scale TMA logical view requires major_extent % 128 == 0") + scale_k = tile_k // sf_vec_size + if scale_k % 4 != 0: + raise ValueError("SM120 scale TMA logical view requires scale_k % 4 == 0") + physical_major_extent = max(major_extent, 128) + major_tiles = physical_major_extent // 128 + physical_bytes = physical_major_extent * scale_k + layout = cute.make_layout( + (((32, 4), major_tiles), ((sf_vec_size, 4), scale_k // 4), num_stages), + stride=( + ((1, 32), 128), + ((0, physical_major_extent), physical_major_extent * 4), + physical_bytes, + ), + loc=loc, + ip=ip, + ) + return cute.make_composed_layout( + cute.make_swizzle(3, 4, 3, loc=loc, ip=ip), + 0, + layout, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_mxf4nvf4_scale_tma_physical_layout_staged( + major_extent: int = 128, + tile_k: int = 128, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, + num_stages: int = 1, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.ComposedLayout: + """Return the rank-4 scale TMA physical destination layout. + + The returned layout carries the 128B tensor-map swizzle used by the SM120 + native scale TMA path. Flattened byte consumers can still use the raw + iterator with mxf4nvf4_scale_tma_physical_offset. + """ + _check_default_tile(major_extent, tile_k, sf_vec_size) + _check_positive("num_stages", num_stages) + scale_k = tile_k // sf_vec_size + physical_major_extent = max(major_extent, 128) + physical_bytes = physical_major_extent * scale_k + if major_extent < physical_major_extent: + layout = cute.make_layout( + (major_extent, scale_k, 1, num_stages), + stride=(1, physical_major_extent, physical_bytes, physical_bytes), + loc=loc, + ip=ip, + ) + else: + layout = cute.make_layout( + (physical_major_extent, scale_k, 1, num_stages), + stride=(1, physical_major_extent, physical_bytes, physical_bytes), + loc=loc, + ip=ip, + ) + return cute.make_composed_layout( + cute.make_swizzle(3, 4, 3, loc=loc, ip=ip), + 0, + layout, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_mxf4nvf4_scale_interleaved_tma_layout_staged( + major_extent: int = 128, + tile_k: int = 128, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, + num_stages: int = 1, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.Layout: + """Return compact interleaved FP8 scale TMA SMEM layout.""" + _check_default_tile(major_extent, tile_k, sf_vec_size) + _check_positive("num_stages", num_stages) + scale_k = tile_k // sf_vec_size + if scale_k % 4 != 0: + raise ValueError("SM120 scale interleaved SMEM layout requires scale_k % 4 == 0") + major_tiles = major_extent // 128 + scale_tiles = scale_k // 4 + stage_stride = major_tiles * scale_tiles * 512 + return cute.make_layout( + (((32, 4), major_tiles), 4, scale_tiles, num_stages), + stride=(((16, 4), 512), 1, major_tiles * 512, stage_stride), + loc=loc, + ip=ip, + ) + + +def make_mxf4nvf4_tma_scale_layout_staged( + major_extent: int = 128, + tile_k: int = 128, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, + num_stages: int = 1, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.ComposedLayout: + """Compatibility alias for the scale TMA physical SMEM layout.""" + return make_mxf4nvf4_scale_tma_physical_layout_staged( + major_extent, tile_k, sf_vec_size, num_stages, loc=loc, ip=ip + ) + + +def allocate_mxf4nvf4_scale_tma_physical( + smem: SmemAllocator, + major_extent: int = 128, + tile_k: int = 128, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, + num_stages: int = 1, +) -> cute.Tensor: + """Allocate padded SFA/SFB TMA physical storage and return its logical view.""" + view_layout = make_mxf4nvf4_scale_tma_physical_layout_staged( + major_extent, + tile_k, + sf_vec_size, + num_stages, + ) + physical_bytes = mxf4nvf4_scale_physical_smem_bytes(major_extent, tile_k, sf_vec_size) + backing = smem.allocate_tensor( + cutlass.Uint8, + cute.make_layout((physical_bytes, num_stages), stride=(1, physical_bytes)), + byte_alignment=128, + ) + return cute.make_tensor(backing.iterator, view_layout) + + +def allocate_mxf4nvf4_scale_tma_interleaved( + smem: SmemAllocator, + *, + tiled_mma: Optional[cute.TiledMma] = None, + tile_m: int = 128, + tile_n: int = 128, + tile_k: int = 128, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, + num_stages: int = 1, +) -> Tuple[cute.Tensor, cute.Tensor]: + """Allocate 79a/SM100-style interleaved scale TMA SMEM views.""" + if tiled_mma is None: + tiled_mma = make_mxf4nvf4_tiled_mma() + sfa_layout = make_mxf4nvf4_scale_interleaved_tma_layout_staged( + tile_m, + tile_k, + sf_vec_size, + num_stages, + ) + sfb_layout = make_mxf4nvf4_scale_interleaved_tma_layout_staged( + tile_n, + tile_k, + sf_vec_size, + num_stages, + ) + return ( + smem.allocate_tensor(cutlass.Uint8, sfa_layout, byte_alignment=128), + smem.allocate_tensor(cutlass.Uint8, sfb_layout, byte_alignment=128), + ) + + +def _make_mxf4nvf4_tiled_tma_atom( + gmem_tensor: cute.Tensor, + smem_layout: cute.Layout, + cta_tiler: cute.Tile, + *, + internal_type: Optional[Type[Numeric]] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +): + op = cpasync.CopyBulkTensorTileG2SOp() + return cpasync.make_tiled_tma_atom( + op, + gmem_tensor, + smem_layout, + cta_tiler, + internal_type=internal_type, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_mxf4nvf4_tiled_tma_atom_a( + gmem_tensor: cute.Tensor, + smem_layout: Optional[cute.Layout] = None, + cta_tiler: cute.Tile = (128, 128, 1), + smem_format: str = "packed", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +): + """Create the layout-aware TMA atom/tensor for one A tile. + + The default uses the runtime-validated packed FP4 tensor-map format. Keep + the GMEM tensor logically FP4 and pass ``smem_format="unpack"`` or call + ``make_mxf4nvf4_unpack_tiled_tma_atom_a`` explicitly for the experimental + FP4 unpack-SMEM tensor-map path + (``CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B``). + """ + smem_format = _normalize_mxf4nvf4_ab_smem_format(smem_format) + if const_expr(smem_layout is None): + if const_expr(smem_format == "unpack"): + smem_layout = make_mxf4nvf4_a_direct_tma_consumer_smem_layout_staged(loc=loc, ip=ip) + else: + smem_layout = make_mxf4nvf4_a_packed_direct_tma_consumer_smem_layout_staged( + loc=loc, ip=ip + ) + return _make_mxf4nvf4_tiled_tma_atom( + gmem_tensor, + smem_layout, + cta_tiler, + internal_type=_mxf4nvf4_ab_tma_internal_type(smem_format), + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_mxf4nvf4_tiled_tma_atom_b( + gmem_tensor: cute.Tensor, + smem_layout: Optional[cute.Layout] = None, + cta_tiler: cute.Tile = (128, 128, 1), + smem_format: str = "packed", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +): + """Create the layout-aware TMA atom/tensor for one B tile.""" + smem_format = _normalize_mxf4nvf4_ab_smem_format(smem_format) + if const_expr(smem_layout is None): + if const_expr(smem_format == "unpack"): + smem_layout = make_mxf4nvf4_b_direct_tma_consumer_smem_layout_staged(loc=loc, ip=ip) + else: + smem_layout = make_mxf4nvf4_b_packed_direct_tma_consumer_smem_layout_staged( + loc=loc, ip=ip + ) + return _make_mxf4nvf4_tiled_tma_atom( + gmem_tensor, + smem_layout, + cta_tiler, + internal_type=_mxf4nvf4_ab_tma_internal_type(smem_format), + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_mxf4nvf4_packed_tiled_tma_atom_a( + gmem_tensor: cute.Tensor, + smem_layout: Optional[cute.Layout] = None, + cta_tiler: cute.Tile = (128, 128, 1), + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +): + """Create an A TMA atom for packed FP4 SMEM format.""" + return make_mxf4nvf4_tiled_tma_atom_a( + gmem_tensor, + smem_layout, + cta_tiler, + smem_format="packed", + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_mxf4nvf4_packed_tiled_tma_atom_b( + gmem_tensor: cute.Tensor, + smem_layout: Optional[cute.Layout] = None, + cta_tiler: cute.Tile = (128, 128, 1), + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +): + """Create a B TMA atom for packed FP4 SMEM format.""" + return make_mxf4nvf4_tiled_tma_atom_b( + gmem_tensor, + smem_layout, + cta_tiler, + smem_format="packed", + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_mxf4nvf4_unpack_tiled_tma_atom_a( + gmem_tensor: cute.Tensor, + smem_layout: Optional[cute.Layout] = None, + cta_tiler: cute.Tile = (128, 128, 1), + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +): + """Create an A TMA atom for FP4 unpack-SMEM format.""" + return make_mxf4nvf4_tiled_tma_atom_a( + gmem_tensor, + smem_layout, + cta_tiler, + smem_format="unpack", + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_mxf4nvf4_unpack_tiled_tma_atom_b( + gmem_tensor: cute.Tensor, + smem_layout: Optional[cute.Layout] = None, + cta_tiler: cute.Tile = (128, 128, 1), + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +): + """Create a B TMA atom for FP4 unpack-SMEM format.""" + return make_mxf4nvf4_tiled_tma_atom_b( + gmem_tensor, + smem_layout, + cta_tiler, + smem_format="unpack", + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_mxf4nvf4_sfa_tiled_tma_atom( + gmem_tensor: cute.Tensor, + smem_layout: Optional[cute.Layout] = None, + cta_tiler: cute.Tile = (128, 8, 1, 1), + tiled_mma: Optional[cute.TiledMma] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +): + """Create the layout-aware TMA atom/tensor for one SFA tile.""" + if const_expr(smem_layout is None): + smem_layout = make_mxf4nvf4_tma_scale_layout_staged(loc=loc, ip=ip) + return _make_mxf4nvf4_tiled_tma_atom( + gmem_tensor, + smem_layout, + cta_tiler, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_mxf4nvf4_sfb_tiled_tma_atom( + gmem_tensor: cute.Tensor, + smem_layout: Optional[cute.Layout] = None, + cta_tiler: cute.Tile = (128, 8, 1, 1), + tiled_mma: Optional[cute.TiledMma] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +): + """Create the layout-aware TMA atom/tensor for one SFB tile.""" + if const_expr(smem_layout is None): + smem_layout = make_mxf4nvf4_tma_scale_layout_staged(loc=loc, ip=ip) + return _make_mxf4nvf4_tiled_tma_atom( + gmem_tensor, + smem_layout, + cta_tiler, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_mxf4nvf4_native_tma_atoms( + gA: cute.Tensor, + gB: cute.Tensor, + gSFA: cute.Tensor, + gSFB: cute.Tensor, + *, + tiled_mma: Optional[cute.TiledMma] = None, + ab_smem_format: str = "packed", + ab_cta_tiler: cute.Tile = (128, 128, 1), + ab_tile_coord: Optional[Tuple[int, int, int]] = None, + ab_tile_coord_a: Optional[Tuple[int, int, int]] = None, + ab_tile_coord_b: Optional[Tuple[int, int, int]] = None, + scale_cta_tiler: cute.Tile = (128, 8, 1, 1), + scale_tile_coord: Optional[Tuple[int, int, int, int]] = (0, 0, 0, 0), + scale_tile_coord_sfa: Optional[Tuple[int, int, int, int]] = None, + scale_tile_coord_sfb: Optional[Tuple[int, int, int, int]] = None, + scale_smem_format: str = "physical", + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +): + """Create A/B/SFA/SFB native TMA atoms for the SM120 NVFP4 path. + + `ab_tile_coord` is optional to preserve the legacy single-tile behavior. + Tiled GEMM callers can pass independent `ab_tile_coord_a` and + `ab_tile_coord_b` values because A is tiled by M while B is tiled by N. + + `scale_tile_coord` preserves the single-tile default. Tiled GEMM callers can + pass independent `scale_tile_coord_sfa` and `scale_tile_coord_sfb` values + because SFA is tiled by M while SFB is tiled by N. + """ + ab_smem_format = _normalize_mxf4nvf4_ab_smem_format(ab_smem_format) + gA = _preserve_mxf4nvf4_ab_tma_l_mode(gA) + gB = _preserve_mxf4nvf4_ab_tma_l_mode(gB) + if const_expr(ab_smem_format == "unpack"): + tma_atom_a, tma_tensor_a = make_mxf4nvf4_unpack_tiled_tma_atom_a( + gA, cta_tiler=ab_cta_tiler, loc=loc, ip=ip + ) + tma_atom_b, tma_tensor_b = make_mxf4nvf4_unpack_tiled_tma_atom_b( + gB, cta_tiler=ab_cta_tiler, loc=loc, ip=ip + ) + else: + tma_atom_a, tma_tensor_a = make_mxf4nvf4_packed_tiled_tma_atom_a( + gA, cta_tiler=ab_cta_tiler, loc=loc, ip=ip + ) + tma_atom_b, tma_tensor_b = make_mxf4nvf4_packed_tiled_tma_atom_b( + gB, cta_tiler=ab_cta_tiler, loc=loc, ip=ip + ) + if ab_tile_coord_a is None: + ab_tile_coord_a = ab_tile_coord + if ab_tile_coord_b is None: + ab_tile_coord_b = ab_tile_coord + if ab_tile_coord_a is not None: + tma_tensor_a = cute.local_tile( + tma_tensor_a, + ab_cta_tiler, + ab_tile_coord_a, + loc=loc, + ip=ip, + ) + if ab_tile_coord_b is not None: + tma_tensor_b = cute.local_tile( + tma_tensor_b, + ab_cta_tiler, + ab_tile_coord_b, + loc=loc, + ip=ip, + ) + if scale_smem_format == "interleaved": + if tiled_mma is None: + tiled_mma = make_mxf4nvf4_tiled_mma(loc=loc, ip=ip) + scale_cta_tiler = ( + scale_cta_tiler[0], + 4, + MXF4NVF4_CTA_SHAPE_MNK[2] // (MXF4NVF4_SCALE_VEC_SIZE * 4), + 1, + ) + sfa_smem_layout = make_mxf4nvf4_scale_interleaved_tma_layout_staged( + MXF4NVF4_CTA_SHAPE_MNK[0], + MXF4NVF4_CTA_SHAPE_MNK[2], + MXF4NVF4_SCALE_VEC_SIZE, + 1, + loc=loc, + ip=ip, + ) + sfb_smem_layout = make_mxf4nvf4_scale_interleaved_tma_layout_staged( + MXF4NVF4_CTA_SHAPE_MNK[1], + MXF4NVF4_CTA_SHAPE_MNK[2], + MXF4NVF4_SCALE_VEC_SIZE, + 1, + loc=loc, + ip=ip, + ) + tma_atom_sfa, tma_tensor_sfa = make_mxf4nvf4_sfa_tiled_tma_atom( + gSFA, + smem_layout=sfa_smem_layout, + cta_tiler=scale_cta_tiler, + tiled_mma=tiled_mma, + loc=loc, + ip=ip, + ) + tma_atom_sfb, tma_tensor_sfb = make_mxf4nvf4_sfb_tiled_tma_atom( + gSFB, + smem_layout=sfb_smem_layout, + cta_tiler=scale_cta_tiler, + tiled_mma=tiled_mma, + loc=loc, + ip=ip, + ) + elif scale_smem_format == "physical": + tma_atom_sfa, tma_tensor_sfa = make_mxf4nvf4_sfa_tiled_tma_atom( + gSFA, cta_tiler=scale_cta_tiler, tiled_mma=tiled_mma, loc=loc, ip=ip + ) + tma_atom_sfb, tma_tensor_sfb = make_mxf4nvf4_sfb_tiled_tma_atom( + gSFB, cta_tiler=scale_cta_tiler, tiled_mma=tiled_mma, loc=loc, ip=ip + ) + else: + raise ValueError("scale_smem_format must be either 'physical' or 'interleaved'") + if scale_tile_coord_sfa is None: + scale_tile_coord_sfa = scale_tile_coord + if scale_tile_coord_sfb is None: + scale_tile_coord_sfb = scale_tile_coord + if scale_tile_coord_sfa is not None: + tma_tensor_sfa = cute.local_tile( + tma_tensor_sfa, + scale_cta_tiler, + scale_tile_coord_sfa, + loc=loc, + ip=ip, + ) + if scale_tile_coord_sfb is not None: + tma_tensor_sfb = cute.local_tile( + tma_tensor_sfb, + scale_cta_tiler, + scale_tile_coord_sfb, + loc=loc, + ip=ip, + ) + return ( + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_sfa, + tma_tensor_sfa, + tma_atom_sfb, + tma_tensor_sfb, + ) + + +@dsl_user_op +def make_mxf4nvf4_native_tma_atoms_for_scheduler( + gA: cute.Tensor, + gB: cute.Tensor, + gSFA: cute.Tensor, + gSFB: cute.Tensor, + *, + tiled_mma: Optional[cute.TiledMma] = None, + ab_smem_format: str = "packed", + ab_cta_tiler: cute.Tile = (128, 128, 1), + scale_cta_tiler: cute.Tile = (128, 8, 1, 1), + scale_smem_format: str = "physical", + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +): + """Create unlocalized native TMA atoms for scheduler-driven SM120 tiles.""" + return make_mxf4nvf4_native_tma_atoms( + gA, + gB, + gSFA, + gSFB, + tiled_mma=tiled_mma, + ab_smem_format=ab_smem_format, + ab_cta_tiler=ab_cta_tiler, + ab_tile_coord=None, + ab_tile_coord_a=None, + ab_tile_coord_b=None, + scale_cta_tiler=scale_cta_tiler, + scale_smem_format=scale_smem_format, + scale_tile_coord=None, + scale_tile_coord_sfa=None, + scale_tile_coord_sfb=None, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def partition_mxf4nvf4_native_tma_tensors_for_scheduler( + tma_atom_a: cute.CopyAtom, + tma_tensor_a: cute.Tensor, + tma_atom_b: cute.CopyAtom, + tma_tensor_b: cute.Tensor, + tma_atom_sfa: cute.CopyAtom, + tma_tensor_sfa: cute.Tensor, + tma_atom_sfb: cute.CopyAtom, + tma_tensor_sfb: cute.Tensor, + sA: cute.Tensor, + sB: cute.Tensor, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + *, + tile_shape_mnk: cute.Tile = MXF4NVF4_CTA_SHAPE_MNK, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, + scale_group_rank_smem: int = 3, + scale_smem_format: str = "physical", + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +): + """Partition native TMA tensors by scheduler M/N/K/L tile coordinates. + + The returned GMEM partitions are indexed as: + A: ``(None, tile_m, k_tile, tile_l)`` + B: ``(None, tile_n, k_tile, tile_l)`` + SFA: ``(None, tile_m, k_tile % 2, k_tile // 2, tile_l)`` + SFB: ``(None, tile_n, k_tile % 2, k_tile // 2, tile_l)`` + """ + _check_tuple("tile_shape_mnk", tile_shape_mnk, 3) + tile_m, tile_n, tile_k = tile_shape_mnk + _check_default_tile(tile_m, tile_k, sf_vec_size) + _check_default_tile(tile_n, tile_k, sf_vec_size) + scale_k = tile_k // sf_vec_size + gA_mkl = cute.local_tile( + tma_tensor_a, + (tile_m, tile_k, 1), + (None, None, None), + loc=loc, + ip=ip, + ) + gB_nkl = cute.local_tile( + tma_tensor_b, + (tile_n, tile_k, 1), + (None, None, None), + loc=loc, + ip=ip, + ) + if scale_smem_format == "interleaved": + scale_tiles_per_tma = tile_k // (sf_vec_size * 4) + gSFA_mkl = cute.local_tile( + tma_tensor_sfa, + (tile_m, 4, scale_tiles_per_tma, 1), + (None, None, None, None), + loc=loc, + ip=ip, + ) + gSFB_nkl = cute.local_tile( + tma_tensor_sfb, + (tile_n, 4, scale_tiles_per_tma, 1), + (None, None, None, None), + loc=loc, + ip=ip, + ) + scale_group_rank_smem = 3 + scale_group_rank_gmem = 4 + elif scale_smem_format == "physical": + gSFA_mkl = cute.local_tile( + tma_tensor_sfa, + (tile_m, scale_k, 1, 1), + (None, None, None, None), + loc=loc, + ip=ip, + ) + gSFB_nkl = cute.local_tile( + tma_tensor_sfb, + (tile_n, scale_k, 1, 1), + (None, None, None, None), + loc=loc, + ip=ip, + ) + scale_group_rank_gmem = 4 + else: + raise ValueError("scale_smem_format must be either 'physical' or 'interleaved'") + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + 0, + cute.make_layout(1, loc=loc, ip=ip), + cute.group_modes(sA, 0, 2, loc=loc, ip=ip), + cute.group_modes(gA_mkl, 0, 3, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + 0, + cute.make_layout(1, loc=loc, ip=ip), + cute.group_modes(sB, 0, 2, loc=loc, ip=ip), + cute.group_modes(gB_nkl, 0, 3, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + tSFAs, tSFAg = cpasync.tma_partition( + tma_atom_sfa, + 0, + cute.make_layout(1, loc=loc, ip=ip), + cute.group_modes(sSFA, 0, scale_group_rank_smem, loc=loc, ip=ip), + cute.group_modes(gSFA_mkl, 0, scale_group_rank_gmem, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + tSFBs, tSFBg = cpasync.tma_partition( + tma_atom_sfb, + 0, + cute.make_layout(1, loc=loc, ip=ip), + cute.group_modes(sSFB, 0, scale_group_rank_smem, loc=loc, ip=ip), + cute.group_modes(gSFB_nkl, 0, scale_group_rank_gmem, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + return tAsA, tAgA, tBsB, tBgB, tSFAs, tSFAg, tSFBs, tSFBg + + +@dsl_user_op +def issue_mxf4nvf4_partitioned_native_tma_stage_for_tile( + tma_atom_a: cute.CopyAtom, + tAsA: cute.Tensor, + tAgA: cute.Tensor, + tma_atom_b: cute.CopyAtom, + tBsB: cute.Tensor, + tBgB: cute.Tensor, + tma_atom_sfa: cute.CopyAtom, + tSFAs: cute.Tensor, + tSFAg: cute.Tensor, + tma_atom_sfb: cute.CopyAtom, + tSFBs: cute.Tensor, + tSFBg: cute.Tensor, + tma_bar_ptr: cute.Pointer, + tile_mnl: tuple[cutlass.Int32 | int, ...], + k_tile: cutlass.Int32 | int = 0, + stage_idx: cutlass.Int32 | int = 0, + *, + already_elected: cutlass.Constexpr[bool] = False, + scale_smem_format: str = "physical", + cache_policy=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Issue one scheduler-selected stage from pre-partitioned native TMA tensors.""" + _check_tuple("tile_mnl", tile_mnl, 3) + tile_m, tile_n, tile_l = tile_mnl + scale_k_tile = k_tile % 2 + scale_page = k_tile // 2 + if scale_smem_format == "interleaved": + sfa_coord = (None, tile_m, 0, k_tile, tile_l) + sfb_coord = (None, tile_n, 0, k_tile, tile_l) + elif scale_smem_format == "physical": + sfa_coord = (None, tile_m, scale_k_tile, scale_page, tile_l) + sfb_coord = (None, tile_n, scale_k_tile, scale_page, tile_l) + else: + raise ValueError("scale_smem_format must be either 'physical' or 'interleaved'") + if cutlass.const_expr(already_elected): + _issue_native_tma_load_already_elected( + tma_atom_a, + tAgA[(None, tile_m, k_tile, tile_l)], + tAsA[(None, stage_idx)], + tma_bar_ptr, + cache_policy=cache_policy, + loc=loc, + ip=ip, + ) + _issue_native_tma_load_already_elected( + tma_atom_b, + tBgB[(None, tile_n, k_tile, tile_l)], + tBsB[(None, stage_idx)], + tma_bar_ptr, + cache_policy=cache_policy, + loc=loc, + ip=ip, + ) + _issue_native_tma_load_already_elected( + tma_atom_sfa, + tSFAg[sfa_coord], + tSFAs[(None, stage_idx)], + tma_bar_ptr, + cache_policy=cache_policy, + loc=loc, + ip=ip, + ) + _issue_native_tma_load_already_elected( + tma_atom_sfb, + tSFBg[sfb_coord], + tSFBs[(None, stage_idx)], + tma_bar_ptr, + cache_policy=cache_policy, + loc=loc, + ip=ip, + ) + return + cute.copy( + tma_atom_a, + tAgA[(None, tile_m, k_tile, tile_l)], + tAsA[(None, stage_idx)], + tma_bar_ptr=tma_bar_ptr, + cache_policy=cache_policy, + loc=loc, + ip=ip, + ) + cute.copy( + tma_atom_b, + tBgB[(None, tile_n, k_tile, tile_l)], + tBsB[(None, stage_idx)], + tma_bar_ptr=tma_bar_ptr, + cache_policy=cache_policy, + loc=loc, + ip=ip, + ) + cute.copy( + tma_atom_sfa, + tSFAg[sfa_coord], + tSFAs[(None, stage_idx)], + tma_bar_ptr=tma_bar_ptr, + cache_policy=cache_policy, + loc=loc, + ip=ip, + ) + cute.copy( + tma_atom_sfb, + tSFBg[sfb_coord], + tSFBs[(None, stage_idx)], + tma_bar_ptr=tma_bar_ptr, + cache_policy=cache_policy, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def issue_mxf4nvf4_partitioned_native_tma_mk_stage_for_tile( + tma_atom_a: cute.CopyAtom, + tAsA: cute.Tensor, + tAgA: cute.Tensor, + tma_atom_sfa: cute.CopyAtom, + tSFAs: cute.Tensor, + tSFAg: cute.Tensor, + tma_bar_ptr: cute.Pointer, + tile_mnl: tuple[cutlass.Int32 | int, ...], + k_tile: cutlass.Int32 | int = 0, + stage_idx: cutlass.Int32 | int = 0, + *, + scale_smem_format: str = "physical", + cache_policy=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Issue the A/SFA half of one scheduler-selected native TMA stage.""" + _check_tuple("tile_mnl", tile_mnl, 3) + tile_m, _, tile_l = tile_mnl + scale_k_tile = k_tile % 2 + scale_page = k_tile // 2 + if scale_smem_format == "interleaved": + sfa_coord = (None, tile_m, 0, k_tile, tile_l) + elif scale_smem_format == "physical": + sfa_coord = (None, tile_m, scale_k_tile, scale_page, tile_l) + else: + raise ValueError("scale_smem_format must be either 'physical' or 'interleaved'") + cute.copy( + tma_atom_a, + tAgA[(None, tile_m, k_tile, tile_l)], + tAsA[(None, stage_idx)], + tma_bar_ptr=tma_bar_ptr, + cache_policy=cache_policy, + loc=loc, + ip=ip, + ) + cute.copy( + tma_atom_sfa, + tSFAg[sfa_coord], + tSFAs[(None, stage_idx)], + tma_bar_ptr=tma_bar_ptr, + cache_policy=cache_policy, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def issue_mxf4nvf4_partitioned_native_tma_nk_stage_for_tile( + tma_atom_b: cute.CopyAtom, + tBsB: cute.Tensor, + tBgB: cute.Tensor, + tma_atom_sfb: cute.CopyAtom, + tSFBs: cute.Tensor, + tSFBg: cute.Tensor, + tma_bar_ptr: cute.Pointer, + tile_mnl: tuple[cutlass.Int32 | int, ...], + k_tile: cutlass.Int32 | int = 0, + stage_idx: cutlass.Int32 | int = 0, + *, + scale_smem_format: str = "physical", + cache_policy=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Issue the B/SFB half of one scheduler-selected native TMA stage.""" + _check_tuple("tile_mnl", tile_mnl, 3) + _, tile_n, tile_l = tile_mnl + scale_k_tile = k_tile % 2 + scale_page = k_tile // 2 + if scale_smem_format == "interleaved": + sfb_coord = (None, tile_n, 0, k_tile, tile_l) + elif scale_smem_format == "physical": + sfb_coord = (None, tile_n, scale_k_tile, scale_page, tile_l) + else: + raise ValueError("scale_smem_format must be either 'physical' or 'interleaved'") + cute.copy( + tma_atom_b, + tBgB[(None, tile_n, k_tile, tile_l)], + tBsB[(None, stage_idx)], + tma_bar_ptr=tma_bar_ptr, + cache_policy=cache_policy, + loc=loc, + ip=ip, + ) + cute.copy( + tma_atom_sfb, + tSFBg[sfb_coord], + tSFBs[(None, stage_idx)], + tma_bar_ptr=tma_bar_ptr, + cache_policy=cache_policy, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_mxf4nvf4_ldsm_copy_atom( + *, + transpose: bool = False, + dtype: Type[Numeric] = cutlass.Uint8, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.CopyAtom: + """Create the packed 16-bit LDSM atom used by SM120 MXF4NVF4 A/B loads.""" + return cute.make_copy_atom( + warp.LdMatrix8x8x16bOp(transpose=transpose, num_matrices=4), + dtype, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_mxf4nvf4_ab_smem_copy_atoms( + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Tuple[cute.CopyAtom, cute.CopyAtom]: + """Return the non-transposed packed-FP4 A/B LDSM copy atoms.""" + return ( + make_mxf4nvf4_ldsm_copy_atom(transpose=False, dtype=cutlass.Float4E2M1FN, loc=loc, ip=ip), + make_mxf4nvf4_ldsm_copy_atom(transpose=False, dtype=cutlass.Float4E2M1FN, loc=loc, ip=ip), + ) + + +@dsl_user_op +def make_mxf4nvf4_unpack_ldsm_copy_atom( + *, + dtype: Type[Numeric] = cutlass.Uint8, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.CopyAtom: + """Create the 79a-style FP4 unpack-SMEM LDSM atom. + + This mirrors C++ ``SM100_SU4_DU8x16_x4_LDSM_N``: + ``ldmatrix.sync.aligned.m8n16.x4.shared.b8x16.b4x16_p64``. + """ + return cute.make_copy_atom( + warp.LdMatrix8x16x8bOp(transpose=False, num_matrices=4, unpack_bits=4), + dtype, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_mxf4nvf4_ab_unpack_smem_copy_atoms( + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Tuple[cute.CopyAtom, cute.CopyAtom]: + """Return A/B LDSM atoms for FP4 unpack-SMEM consumer tiles.""" + return ( + make_mxf4nvf4_unpack_ldsm_copy_atom(dtype=cutlass.Uint8, loc=loc, ip=ip), + make_mxf4nvf4_unpack_ldsm_copy_atom(dtype=cutlass.Uint8, loc=loc, ip=ip), + ) + + +@dsl_user_op +def make_mxf4nvf4_ab_ldsm_copy_views_from_consumer_smem( + tiled_mma: cute.TiledMma, + sA_consumer: cute.Tensor, + sB_consumer: cute.Tensor, + a_frag: cute.Tensor, + b_frag: cute.Tensor, + lane_idx: cutlass.Int32, + *, + m_atom: cutlass.Int32 | int = 0, + n_atom: cutlass.Int32 | int = 0, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +): + """Return 79a-style A/B tiled-copy views from consumer SMEM. + + Input tensors must use the consumer SMEM layouts produced by + `make_mxf4nvf4_ab_consumer_smem_views()` or the corresponding layout + helpers. Do not pass raw external-TMA physical SMEM to this helper. + `lane_idx` is the warp-local lane index, not the CTA thread index. + """ + sA_consumer, sB_consumer = make_mxf4nvf4_ab_consumer_microtile_views( + sA_consumer, + sB_consumer, + m_atom=m_atom, + n_atom=n_atom, + loc=loc, + ip=ip, + ) + copy_atom_a, copy_atom_b = make_mxf4nvf4_ab_smem_copy_atoms(loc=loc, ip=ip) + tiled_copy_a = cute.make_tiled_copy_A(copy_atom_a, tiled_mma, loc=loc, ip=ip) + tiled_copy_b = cute.make_tiled_copy_B(copy_atom_b, tiled_mma, loc=loc, ip=ip) + thr_copy_a = tiled_copy_a.get_slice(lane_idx) + thr_copy_b = tiled_copy_b.get_slice(lane_idx) + sA_src = cute.as_position_independent_swizzle_tensor(sA_consumer, loc=loc, ip=ip) + sB_src = cute.as_position_independent_swizzle_tensor(sB_consumer, loc=loc, ip=ip) + tCsA = thr_copy_a.partition_S(sA_src, loc=loc, ip=ip) + tCsB = thr_copy_b.partition_S(sB_src, loc=loc, ip=ip) + tCrA = thr_copy_a.retile_D(a_frag, loc=loc, ip=ip) + tCrB = thr_copy_b.retile_D(b_frag, loc=loc, ip=ip) + return tiled_copy_a, tCsA, tCrA, tiled_copy_b, tCsB, tCrB + + +@dsl_user_op +def make_mxf4nvf4_ab_unpack_ldsm_copy_views_from_consumer_smem( + tiled_mma: cute.TiledMma, + sA_consumer: cute.Tensor, + sB_consumer: cute.Tensor, + a_frag: cute.Tensor, + b_frag: cute.Tensor, + lane_idx: cutlass.Int32, + *, + m_atom: cutlass.Int32 | int = 0, + n_atom: cutlass.Int32 | int = 0, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +): + """Return 79a-style unpack-SMEM A/B tiled-copy views. + + Input tensors must be Uint8 direct-consumer SMEM views populated by a + logical-FP4 TMA atom built with ``internal_type=cutlass.Uint8``. The + LDSM source pointer is aligned to 16 bytes because the unpack form loads + 128-bit source rows. + """ + sA_consumer, sB_consumer = make_mxf4nvf4_ab_consumer_microtile_views( + sA_consumer, + sB_consumer, + m_atom=m_atom, + n_atom=n_atom, + loc=loc, + ip=ip, + ) + copy_atom_a, copy_atom_b = make_mxf4nvf4_ab_unpack_smem_copy_atoms(loc=loc, ip=ip) + tiled_copy_a = cute.make_tiled_copy_A(copy_atom_a, tiled_mma, loc=loc, ip=ip) + tiled_copy_b = cute.make_tiled_copy_B(copy_atom_b, tiled_mma, loc=loc, ip=ip) + thr_copy_a = tiled_copy_a.get_slice(lane_idx) + thr_copy_b = tiled_copy_b.get_slice(lane_idx) + sA_src = cute.as_position_independent_swizzle_tensor(sA_consumer, loc=loc, ip=ip) + sB_src = cute.as_position_independent_swizzle_tensor(sB_consumer, loc=loc, ip=ip) + tCsA = thr_copy_a.partition_S(sA_src, loc=loc, ip=ip) + tCsB = thr_copy_b.partition_S(sB_src, loc=loc, ip=ip) + tCsA = cute.make_tensor(tCsA.iterator.align(16), tCsA.layout, loc=loc, ip=ip) + tCsB = cute.make_tensor(tCsB.iterator.align(16), tCsB.layout, loc=loc, ip=ip) + tCrA = thr_copy_a.retile_D(a_frag, loc=loc, ip=ip) + tCrB = thr_copy_b.retile_D(b_frag, loc=loc, ip=ip) + return tiled_copy_a, tCsA, tCrA, tiled_copy_b, tCsB, tCrB + + +@dsl_user_op +def make_mxf4nvf4_ab_fragments_from_consumer_smem( + tiled_mma: cute.TiledMma, + sA_consumer: cute.Tensor, + sB_consumer: cute.Tensor, + lane_idx: cutlass.Int32, + *, + m_atom: cutlass.Int32 | int = 0, + n_atom: cutlass.Int32 | int = 0, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Tuple[cute.Tensor, cute.Tensor]: + """Allocate A/B fragments for one local output atom. + + `lane_idx` is the warp-local lane index. Use `m_atom`/`n_atom` to select + the local 16x8 atom inside the staged 128x128 CTA tile. + """ + sA_consumer, sB_consumer = make_mxf4nvf4_ab_consumer_microtile_views( + sA_consumer, + sB_consumer, + m_atom=m_atom, + n_atom=n_atom, + loc=loc, + ip=ip, + ) + thread_mma = tiled_mma.get_slice(lane_idx) + tCsA_mma = thread_mma.partition_A(sA_consumer, loc=loc, ip=ip) + tCsB_mma = thread_mma.partition_B(sB_consumer, loc=loc, ip=ip) + return ( + tiled_mma.make_fragment_A(tCsA_mma[None, None, None, 0], loc=loc, ip=ip), + tiled_mma.make_fragment_B(tCsB_mma[None, None, None, 0], loc=loc, ip=ip), + ) + + +@dsl_user_op +def shift_mxf4nvf4_post_ldsm_fp4_fragment( + fragment: cute.Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Validate the FP4 post-LDSM transform point for MXF4NVF4 MMA. + + The C++ path applies an explicit nibble shift after its packed LDSM copy. + Python CuTe's typed `Float4E2M1FN` LDSM copy already materializes + MMA-ready fragments, so this hook is intentionally a no-op after dtype + validation. Keeping the hook explicit makes the consumer path match the C++ + mainloop structure without corrupting the Python fragment encoding. + """ + if fragment.element_type is not cutlass.Float4E2M1FN: + raise TypeError( + "SM120 MXF4NVF4 post-LDSM shift expects a Float4E2M1FN fragment, " + f"got {fragment.element_type}" + ) + + +@dsl_user_op +def fp4_shift_mxf4nvf4_a( + fragment: cute.Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Apply the A-fragment FP4 post-LDSM shift.""" + shift_mxf4nvf4_post_ldsm_fp4_fragment(fragment, loc=loc, ip=ip) + + +@dsl_user_op +def fp4_shift_mxf4nvf4_b( + fragment: cute.Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Apply the B-fragment FP4 post-LDSM shift.""" + shift_mxf4nvf4_post_ldsm_fp4_fragment(fragment, loc=loc, ip=ip) + + +@dsl_user_op +def load_mxf4nvf4_ab_fragments_from_consumer_smem( + tiled_mma: cute.TiledMma, + sA_consumer: cute.Tensor, + sB_consumer: cute.Tensor, + a_frag: cute.Tensor, + b_frag: cute.Tensor, + lane_idx: cutlass.Int32, + k_block_idx: int, + consumer_stage_idx: int = 0, + *, + m_atom: cutlass.Int32 | int = 0, + n_atom: cutlass.Int32 | int = 0, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Load one K64 A/B block through the 79a-style consumer copy path. + + `lane_idx` is the warp-local lane index. Use `m_atom`/`n_atom` to select + the local 16x8 atom inside the staged 128x128 CTA tile. + """ + tiled_copy_a, tCsA, tCrA, tiled_copy_b, tCsB, tCrB = ( + make_mxf4nvf4_ab_ldsm_copy_views_from_consumer_smem( + tiled_mma, + sA_consumer, + sB_consumer, + a_frag, + b_frag, + lane_idx, + m_atom=m_atom, + n_atom=n_atom, + loc=loc, + ip=ip, + ) + ) + tCsA_stage = tCsA[(None, None, None, consumer_stage_idx)] + tCsB_stage = tCsB[(None, None, None, consumer_stage_idx)] + cute.copy( + tiled_copy_a, + tCsA_stage[(None, None, k_block_idx)], + tCrA[(None, None, k_block_idx)], + loc=loc, + ip=ip, + ) + cute.copy( + tiled_copy_b, + tCsB_stage[(None, None, k_block_idx)], + tCrB[(None, None, k_block_idx)], + loc=loc, + ip=ip, + ) + fp4_shift_mxf4nvf4_a(tCrA[(None, None, k_block_idx)], loc=loc, ip=ip) + fp4_shift_mxf4nvf4_b(tCrB[(None, None, k_block_idx)], loc=loc, ip=ip) + + +@dsl_user_op +def make_mxf4nvf4_ab_unpack_fragments_from_consumer_smem( + tiled_mma: cute.TiledMma, + sA_consumer: cute.Tensor, + sB_consumer: cute.Tensor, + lane_idx: cutlass.Int32, + *, + m_atom: cutlass.Int32 | int = 0, + n_atom: cutlass.Int32 | int = 0, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Tuple[cute.Tensor, cute.Tensor]: + """Allocate logical FP4 A/B fragments for the unpack-SMEM LDSM path.""" + return make_mxf4nvf4_ab_fragments_from_consumer_smem( + tiled_mma, + sA_consumer, + sB_consumer, + lane_idx, + m_atom=m_atom, + n_atom=n_atom, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def load_mxf4nvf4_ab_unpack_fragments_from_consumer_smem( + tiled_mma: cute.TiledMma, + sA_consumer: cute.Tensor, + sB_consumer: cute.Tensor, + a_frag: cute.Tensor, + b_frag: cute.Tensor, + lane_idx: cutlass.Int32, + k_block_idx: int, + consumer_stage_idx: int = 0, + *, + m_atom: cutlass.Int32 | int = 0, + n_atom: cutlass.Int32 | int = 0, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Load one K64 block through the 79a-style unpack-SMEM LDSM path.""" + a_copy_frag = cute.recast_tensor(a_frag, cutlass.Uint8, loc=loc, ip=ip) + b_copy_frag = cute.recast_tensor(b_frag, cutlass.Uint8, loc=loc, ip=ip) + tiled_copy_a, tCsA, tCrA, tiled_copy_b, tCsB, tCrB = ( + make_mxf4nvf4_ab_unpack_ldsm_copy_views_from_consumer_smem( + tiled_mma, + sA_consumer, + sB_consumer, + a_copy_frag, + b_copy_frag, + lane_idx, + m_atom=m_atom, + n_atom=n_atom, + loc=loc, + ip=ip, + ) + ) + tCsA_stage = tCsA[(None, None, None, consumer_stage_idx)] + tCsB_stage = tCsB[(None, None, None, consumer_stage_idx)] + cute.copy( + tiled_copy_a, + tCsA_stage[(None, None, k_block_idx)], + tCrA[(None, None, k_block_idx)], + loc=loc, + ip=ip, + ) + cute.copy( + tiled_copy_b, + tCsB_stage[(None, None, k_block_idx)], + tCrB[(None, None, k_block_idx)], + loc=loc, + ip=ip, + ) + + +def stage_mxf4nvf4_a_tma_physical_to_consumer_smem( + sA_tma_physical: cute.Tensor, + sA_consumer: cute.Tensor, + *, + a_major_tile: cutlass.Int32 = cutlass.Int32(0), + consumer_stage_idx: int = 0, + tile_m: int = 128, + tile_k: int = 128, + lane_idx: Optional[cutlass.Int32] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Stage one physical TMA A tile into the SM120 consumer SMEM layout. + + `consumer_stage_idx` selects the destination consumer stage. Pass a + physical-stage view as `sA_tma_physical` when staging from a nonzero + physical stage. + """ + _require_zero_major_offset("a_major_tile", a_major_tile) + smem_bytes = mxf4nvf4_ab_physical_smem_bytes(tile_m, tile_k) + tma_bytes = mxf4nvf4_ab_tma_tx_bytes(tile_m, tile_k) + k_bytes = tile_k // 2 + loop_start = lane_idx if lane_idx is not None else 0 + loop_step = 32 if lane_idx is not None else 1 + src = cute.make_tensor(sA_tma_physical.iterator, cute.make_layout(smem_bytes)) + dst = cute.recast_tensor(sA_consumer, cutlass.Uint8, loc=loc, ip=ip) + for i in for_generate(loop_start, tma_bytes, loop_step, loc=loc, ip=ip): + major = i // k_bytes + k_byte = i % k_bytes + payload_byte = major * k_bytes + k_byte + payload_chunk = payload_byte // 8 + payload_byte_in_chunk = payload_byte % 8 + physical_chunk = payload_chunk ^ ((payload_chunk >> 3) & 0x7) + dst[(major, k_byte, consumer_stage_idx)] = src[physical_chunk * 16 + payload_byte_in_chunk] + yield_out() + + +def stage_mxf4nvf4_b_tma_physical_to_consumer_smem( + sB_tma_physical: cute.Tensor, + sB_consumer: cute.Tensor, + *, + b_major_tile: cutlass.Int32 = cutlass.Int32(0), + consumer_stage_idx: int = 0, + tile_n: int = 128, + tile_k: int = 128, + lane_idx: Optional[cutlass.Int32] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Stage one physical TMA B tile into the SM120 consumer SMEM layout. + + `consumer_stage_idx` selects the destination consumer stage. Pass a + physical-stage view as `sB_tma_physical` when staging from a nonzero + physical stage. + """ + _require_zero_major_offset("b_major_tile", b_major_tile) + smem_bytes = mxf4nvf4_ab_physical_smem_bytes(tile_n, tile_k) + tma_bytes = mxf4nvf4_ab_tma_tx_bytes(tile_n, tile_k) + k_bytes = tile_k // 2 + loop_start = lane_idx if lane_idx is not None else 0 + loop_step = 32 if lane_idx is not None else 1 + src = cute.make_tensor(sB_tma_physical.iterator, cute.make_layout(smem_bytes)) + dst = cute.recast_tensor(sB_consumer, cutlass.Uint8, loc=loc, ip=ip) + for i in for_generate(loop_start, tma_bytes, loop_step, loc=loc, ip=ip): + major = i // k_bytes + k_byte = i % k_bytes + payload_byte = major * k_bytes + k_byte + payload_chunk = payload_byte // 8 + payload_byte_in_chunk = payload_byte % 8 + physical_chunk = payload_chunk ^ ((payload_chunk >> 3) & 0x7) + dst[(major, k_byte, consumer_stage_idx)] = src[physical_chunk * 16 + payload_byte_in_chunk] + yield_out() + + +def stage_mxf4nvf4_ab_tma_physical_to_consumer_smem( + sA_tma_physical: cute.Tensor, + sB_tma_physical: cute.Tensor, + sA_consumer: cute.Tensor, + sB_consumer: cute.Tensor, + *, + a_major_tile: cutlass.Int32 = cutlass.Int32(0), + b_major_tile: cutlass.Int32 = cutlass.Int32(0), + consumer_stage_idx: int = 0, + tile_m: int = 128, + tile_n: int = 128, + tile_k: int = 128, + lane_idx: Optional[cutlass.Int32] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Stage physical TMA A/B tiles into the SM120 consumer SMEM layouts.""" + stage_mxf4nvf4_a_tma_physical_to_consumer_smem( + sA_tma_physical, + sA_consumer, + a_major_tile=a_major_tile, + consumer_stage_idx=consumer_stage_idx, + tile_m=tile_m, + tile_k=tile_k, + lane_idx=lane_idx, + loc=loc, + ip=ip, + ) + stage_mxf4nvf4_b_tma_physical_to_consumer_smem( + sB_tma_physical, + sB_consumer, + b_major_tile=b_major_tile, + consumer_stage_idx=consumer_stage_idx, + tile_n=tile_n, + tile_k=tile_k, + lane_idx=lane_idx, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_mxf4nvf4_scale_smem_fragment_views( + sSFA: cute.Tensor, + sSFB: cute.Tensor, + k_block_idx: int, + stage_idx: int = 0, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Tuple[cute.Tensor, cute.Tensor]: + """Return SFA/SFB source views for one K64 block from staged scale SMEM.""" + scale_k_offset = k_block_idx * (MXF4NVF4_MMA_SHAPE_MNK[2] // MXF4NVF4_SCALE_VEC_SIZE) + stage_offset = stage_idx * MXF4NVF4_SCALE_TMA_BYTES + sfa_f8 = cute.recast_tensor(sSFA, cutlass.Float8E4M3FN, loc=loc, ip=ip) + sfb_f8 = cute.recast_tensor(sSFB, cutlass.Float8E4M3FN, loc=loc, ip=ip) + sfa_ptr = sfa_f8.iterator + scale_k_offset + stage_offset + sfb_ptr = sfb_f8.iterator + scale_k_offset + stage_offset + return ( + cute.make_tensor(sfa_ptr, warp.make_mxf4nvf4_sfa_layout(loc=loc, ip=ip), loc=loc, ip=ip), + cute.make_tensor(sfb_ptr, warp.make_mxf4nvf4_sfb_layout(loc=loc, ip=ip), loc=loc, ip=ip), + ) + + +def _mxf4nvf4_scale_tma_physical_offset_const( + major: int, + scale_col: int, + major_extent: int, +) -> int: + physical_major_extent = max(major_extent, 128) + payload_idx = scale_col * physical_major_extent + major + payload_chunk = payload_idx // 16 + payload_byte_in_chunk = payload_idx % 16 + physical_chunk = payload_chunk ^ ((payload_chunk >> 3) & 0x7) + return physical_chunk * 16 + payload_byte_in_chunk + + +@dsl_user_op +def make_mxf4nvf4_scale_fragment_views_from_direct_tma( + sSFA: cute.Tensor, + sSFB: cute.Tensor, + k_block_idx: int, + stage_idx: int = 0, + major_extent_sfa: int = 128, + major_extent_sfb: int = 128, + tile_k: int = 128, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Tuple[cute.Tensor, cute.Tensor]: + """Return SFA/SFB fragment views from compact direct scale TMA storage.""" + _check_default_tile(major_extent_sfa, tile_k, sf_vec_size) + _check_default_tile(major_extent_sfb, tile_k, sf_vec_size) + scale_col = k_block_idx * (MXF4NVF4_MMA_SHAPE_MNK[2] // sf_vec_size) + sfa_stage_offset = stage_idx * mxf4nvf4_scale_physical_smem_bytes( + major_extent_sfa, tile_k, sf_vec_size + ) + sfb_stage_offset = stage_idx * mxf4nvf4_scale_physical_smem_bytes( + major_extent_sfb, tile_k, sf_vec_size + ) + sfa_offset = sfa_stage_offset + _mxf4nvf4_scale_tma_physical_offset_const( + 0, scale_col, major_extent_sfa + ) + sfb_offset = sfb_stage_offset + _mxf4nvf4_scale_tma_physical_offset_const( + 0, scale_col, major_extent_sfb + ) + sfa_f8 = cute.recast_tensor(sSFA, cutlass.Float8E4M3FN, loc=loc, ip=ip) + sfb_f8 = cute.recast_tensor(sSFB, cutlass.Float8E4M3FN, loc=loc, ip=ip) + return ( + cute.make_tensor( + sfa_f8.iterator + sfa_offset, + warp.make_mxf4nvf4_sfa_layout(loc=loc, ip=ip), + loc=loc, + ip=ip, + ), + cute.make_tensor( + sfb_f8.iterator + sfb_offset, + warp.make_mxf4nvf4_sfb_layout(loc=loc, ip=ip), + loc=loc, + ip=ip, + ), + ) + + +make_mxf4nvf4_scale_fragment_views_from_compact_smem = make_mxf4nvf4_scale_smem_fragment_views + + +@dsl_user_op +def mxf4nvf4_scale_tma_physical_offset( + major: cutlass.Int32, + scale_col: cutlass.Int32, + major_extent: int = 128, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cutlass.Int32: + """Return rank-4 scale TMA physical byte offset for one logical scale.""" + _check_positive("major_extent", major_extent) + physical_major_extent = max(major_extent, 128) + payload_idx = scale_col * physical_major_extent + major + chunk = payload_idx // 16 + byte = payload_idx % 16 + phys_chunk = chunk ^ (chunk >> 3) + return phys_chunk * 16 + byte + + +def stage_mxf4nvf4_sfa_tma_physical_to_tiled_smem( + tiled_mma: cute.TiledMma, + sSFA_tma_physical: cute.Tensor, + sSFA_tiled_smem: cute.Tensor, + *, + major_extent: int = 128, + tile_n: int = 128, + tile_k: int = 128, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, + lane_idx: Optional[cutlass.Int32] = None, + thread_idx: Optional[cutlass.Int32] = None, + thread_count: int = 32, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Stage SFA physical TMA SMEM into the SM120 tiled scale SMEM layout.""" + _check_default_tile(major_extent, tile_k, sf_vec_size) + scale_k = tile_k // sf_vec_size + tma_bytes = mxf4nvf4_scale_tma_tx_bytes(major_extent, tile_k, sf_vec_size) + physical_bytes = mxf4nvf4_scale_physical_smem_bytes(major_extent, tile_k, sf_vec_size) + src_u8 = cute.make_tensor( + sSFA_tma_physical.iterator, + cute.make_layout(physical_bytes, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + dst_u8 = cute.make_tensor( + sSFA_tiled_smem.iterator, + make_mxf4nvf4_sfa_smem_layout_staged( + tiled_mma, + (major_extent, tile_n, tile_k), + sf_vec_size, + 1, + loc=loc, + ip=ip, + ), + loc=loc, + ip=ip, + ) + if thread_idx is not None: + loop_start = thread_idx + loop_step = thread_count + else: + loop_start = lane_idx if lane_idx is not None else 0 + loop_step = 32 if lane_idx is not None else 1 + for i in for_generate(loop_start, tma_bytes, loop_step, loc=loc, ip=ip): + local_major = i // scale_k + scale_col = i % scale_k + phys = mxf4nvf4_scale_tma_physical_offset( + local_major, scale_col, major_extent, loc=loc, ip=ip + ) + dst_u8[(local_major, scale_col * sf_vec_size, 0)] = src_u8[phys] + yield_out() + + +def stage_mxf4nvf4_sfb_tma_physical_to_tiled_smem( + tiled_mma: cute.TiledMma, + sSFB_tma_physical: cute.Tensor, + sSFB_tiled_smem: cute.Tensor, + *, + tile_m: int = 128, + major_extent: int = 128, + tile_k: int = 128, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, + lane_idx: Optional[cutlass.Int32] = None, + thread_idx: Optional[cutlass.Int32] = None, + thread_count: int = 32, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Stage SFB physical TMA SMEM into the SM120 tiled scale SMEM layout.""" + _check_default_tile(major_extent, tile_k, sf_vec_size) + scale_k = tile_k // sf_vec_size + tma_bytes = mxf4nvf4_scale_tma_tx_bytes(major_extent, tile_k, sf_vec_size) + physical_bytes = mxf4nvf4_scale_physical_smem_bytes(major_extent, tile_k, sf_vec_size) + src_u8 = cute.make_tensor( + sSFB_tma_physical.iterator, + cute.make_layout(physical_bytes, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + dst_u8 = cute.make_tensor( + sSFB_tiled_smem.iterator, + make_mxf4nvf4_sfb_smem_layout_staged( + tiled_mma, + (tile_m, major_extent, tile_k), + sf_vec_size, + 1, + loc=loc, + ip=ip, + ), + loc=loc, + ip=ip, + ) + if thread_idx is not None: + loop_start = thread_idx + loop_step = thread_count + else: + loop_start = lane_idx if lane_idx is not None else 0 + loop_step = 32 if lane_idx is not None else 1 + for i in for_generate(loop_start, tma_bytes, loop_step, loc=loc, ip=ip): + local_major = i // scale_k + scale_col = i % scale_k + phys = mxf4nvf4_scale_tma_physical_offset( + local_major, scale_col, major_extent, loc=loc, ip=ip + ) + dst_u8[(local_major, scale_col * sf_vec_size, 0)] = src_u8[phys] + yield_out() + + +def stage_mxf4nvf4_scale_tma_physical_to_tiled_smem( + tiled_mma: cute.TiledMma, + sSFA_tma_physical: cute.Tensor, + sSFB_tma_physical: cute.Tensor, + sSFA_tiled_smem: cute.Tensor, + sSFB_tiled_smem: cute.Tensor, + *, + tile_m: int = 128, + tile_n: int = 128, + tile_k: int = 128, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, + lane_idx: Optional[cutlass.Int32] = None, + thread_idx: Optional[cutlass.Int32] = None, + thread_count: int = 32, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Stage physical SFA/SFB TMA storage into tiled SM120 scale SMEM.""" + stage_mxf4nvf4_sfa_tma_physical_to_tiled_smem( + tiled_mma, + sSFA_tma_physical, + sSFA_tiled_smem, + major_extent=tile_m, + tile_n=tile_n, + tile_k=tile_k, + sf_vec_size=sf_vec_size, + lane_idx=lane_idx, + thread_idx=thread_idx, + thread_count=thread_count, + loc=loc, + ip=ip, + ) + stage_mxf4nvf4_sfb_tma_physical_to_tiled_smem( + tiled_mma, + sSFB_tma_physical, + sSFB_tiled_smem, + tile_m=tile_m, + major_extent=tile_n, + tile_k=tile_k, + sf_vec_size=sf_vec_size, + lane_idx=lane_idx, + thread_idx=thread_idx, + thread_count=thread_count, + loc=loc, + ip=ip, + ) + + +def copy_mxf4nvf4_tiled_smem_scale_fragments( + tiled_mma: cute.TiledMma, + sSFA_tiled_smem: cute.Tensor, + sSFB_tiled_smem: cute.Tensor, + sfa_frag_dst: cute.Tensor, + sfb_frag_dst: cute.Tensor, + tidx: cutlass.Int32, + k_block_idx: int, + stage_idx: int = 0, + *, + tile_m: int = 128, + tile_n: int = 128, + tile_k: int = 128, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Copy one K64 block from tiled scale SMEM into SFA/SFB fragments.""" + _check_default_tile(tile_m, tile_k, sf_vec_size) + _check_default_tile(tile_n, tile_k, sf_vec_size) + stage_stride_sfa = mxf4nvf4_scale_tma_tx_bytes(tile_m, tile_k, sf_vec_size) + stage_stride_sfb = mxf4nvf4_scale_tma_tx_bytes(tile_n, tile_k, sf_vec_size) + scale_copy_tile_k = MXF4NVF4_MMA_SHAPE_MNK[2] + scale_copy_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + cutlass.Float8E4M3FN, + loc=loc, + ip=ip, + ) + tiled_copy_sfa = cute.make_tiled_copy( + scale_copy_atom, + get_layoutSFA_TV(tiled_mma), + (tile_m, scale_copy_tile_k), + loc=loc, + ip=ip, + ) + sSFA_f8 = cute.recast_tensor( + cute.make_tensor( + sSFA_tiled_smem.iterator + stage_idx * stage_stride_sfa, + make_mxf4nvf4_sfa_smem_layout_staged( + tiled_mma, + (tile_m, tile_n, tile_k), + sf_vec_size, + 1, + loc=loc, + ip=ip, + ), + loc=loc, + ip=ip, + ), + cutlass.Float8E4M3FN, + loc=loc, + ip=ip, + ) + sSFB_f8 = cute.recast_tensor( + cute.make_tensor( + sSFB_tiled_smem.iterator + stage_idx * stage_stride_sfb, + make_mxf4nvf4_sfb_smem_layout_staged( + tiled_mma, + (tile_m, tile_n, tile_k), + sf_vec_size, + 1, + loc=loc, + ip=ip, + ), + loc=loc, + ip=ip, + ), + cutlass.Float8E4M3FN, + loc=loc, + ip=ip, + ) + thr_copy_sfa = tiled_copy_sfa.get_slice(tidx) + tCsSFA = thr_copy_sfa.partition_S( + cute.as_position_independent_swizzle_tensor(sSFA_f8, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + tCrSFA = thr_copy_sfa.retile(sfa_frag_dst, loc=loc, ip=ip) + thr_mma = tiled_mma.get_slice(tidx) + sfb_source_layout = thrfrg_SFB(sSFB_f8[(None, None, 0)].layout, thr_mma) + sfb_source = cute.make_tensor(sSFB_f8.iterator, sfb_source_layout, loc=loc, ip=ip) + thr_vmnk = thr_mma.thr_layout_vmnk.get_flat_coord(tidx) + thr_vnk = (thr_vmnk[0], (thr_vmnk[2], thr_vmnk[3])) + sfb_source = sfb_source[thr_vnk, (None, None)] + sfb_source = cute.group_modes(cute.flatten(sfb_source), 0, 2) + sfb_source = cute.group_modes(sfb_source, 1, 3) + cute.copy( + tiled_copy_sfa, + tCsSFA[(None, None, k_block_idx, 0)], + tCrSFA[(None, None, k_block_idx)], + loc=loc, + ip=ip, + ) + sfb_source_k = sfb_source[(None, None, k_block_idx)] + sfb_dst_k = sfb_frag_dst[(None, None, k_block_idx)] + sfb_dst_k.store(sfb_source_k.load(loc=loc, ip=ip), loc=loc, ip=ip) + + +@dsl_user_op +def make_mxf4nvf4_scale_fragments( + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Tuple[cute.Tensor, cute.Tensor]: + """Create SFA and SFB register fragments for bundled SM120 MXF4NVF4 MMA.""" + return ( + warp.make_mxf4nvf4_sfa_fragment(loc=loc, ip=ip), + warp.make_mxf4nvf4_sfb_fragment(loc=loc, ip=ip), + ) + + +@dsl_user_op +def make_mxf4nvf4_direct_tma_scale_fragments( + tiled_mma: cute.TiledMma, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + tidx: cutlass.Int32 | int, + *, + tile_shape_mnk: cute.Tile = MXF4NVF4_CTA_SHAPE_MNK, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Tuple[cute.Tensor, cute.Tensor]: + """Create bundled K128 scale fragments for direct native scale TMA SMEM.""" + _check_tuple("tile_shape_mnk", tile_shape_mnk, 3) + tile_m, tile_n, tile_k = tile_shape_mnk + _check_default_tile(tile_m, tile_k, sf_vec_size) + _check_default_tile(tile_n, tile_k, sf_vec_size) + sSFA_f8 = cute.recast_tensor(sSFA, cutlass.Float8E4M3FN, loc=loc, ip=ip) + sSFB_f8 = cute.recast_tensor(sSFB, cutlass.Float8E4M3FN, loc=loc, ip=ip) + sSFA_logical = cute.make_tensor( + sSFA_f8.iterator, + make_mxf4nvf4_sfa_smem_layout_staged( + tiled_mma, + tile_shape_mnk, + sf_vec_size, + 1, + loc=loc, + ip=ip, + ), + loc=loc, + ip=ip, + ) + sSFB_logical = cute.make_tensor( + sSFB_f8.iterator, + make_mxf4nvf4_sfb_smem_layout_staged( + tiled_mma, + tile_shape_mnk, + sf_vec_size, + 1, + loc=loc, + ip=ip, + ), + loc=loc, + ip=ip, + ) + thr_mma = tiled_mma.get_slice(tidx) + return ( + partition_fragment_SFA(sSFA_logical[(None, None, 0)], thr_mma, tidx), + partition_fragment_SFB(sSFB_logical[(None, None, 0)], thr_mma, tidx), + ) + + +@dsl_user_op +def make_mxf4nvf4_direct_tma_scale_fragment_source_views( + tiled_mma: cute.TiledMma, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + tidx: cutlass.Int32 | int, + *, + stage_idx: int = 0, + tile_shape_mnk: cute.Tile = MXF4NVF4_CTA_SHAPE_MNK, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, + scale_smem_format: str = "physical", + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Tuple[cute.Tensor, cute.Tensor]: + """Return per-thread source views over native physical scale-TMA SMEM.""" + _check_tuple("tile_shape_mnk", tile_shape_mnk, 3) + tile_m, tile_n, tile_k = tile_shape_mnk + _check_default_tile(tile_m, tile_k, sf_vec_size) + _check_default_tile(tile_n, tile_k, sf_vec_size) + if scale_smem_format == "interleaved": + stage_stride_sfa = mxf4nvf4_scale_tma_tx_bytes(tile_m, tile_k, sf_vec_size) + stage_stride_sfb = mxf4nvf4_scale_tma_tx_bytes(tile_n, tile_k, sf_vec_size) + sfa_layout = make_mxf4nvf4_sfa_smem_layout_staged( + tiled_mma, + tile_shape_mnk, + sf_vec_size, + 1, + loc=loc, + ip=ip, + ) + sfb_layout = make_mxf4nvf4_sfb_smem_layout_staged( + tiled_mma, + tile_shape_mnk, + sf_vec_size, + 1, + loc=loc, + ip=ip, + ) + elif scale_smem_format == "physical": + stage_stride_sfa = mxf4nvf4_scale_physical_smem_bytes(tile_m, tile_k, sf_vec_size) + stage_stride_sfb = mxf4nvf4_scale_physical_smem_bytes(tile_n, tile_k, sf_vec_size) + sfa_layout = make_mxf4nvf4_scale_tma_physical_as_tiled_smem_layout_staged( + tile_m, + tile_k, + sf_vec_size, + 1, + loc=loc, + ip=ip, + ) + sfb_layout = make_mxf4nvf4_scale_tma_physical_as_tiled_smem_layout_staged( + tile_n, + tile_k, + sf_vec_size, + 1, + loc=loc, + ip=ip, + ) + else: + raise ValueError("scale_smem_format must be either 'physical' or 'interleaved'") + sSFA_f8 = cute.recast_tensor( + cute.make_tensor( + sSFA.iterator + stage_idx * stage_stride_sfa, + sfa_layout, + loc=loc, + ip=ip, + ), + cutlass.Float8E4M3FN, + loc=loc, + ip=ip, + ) + sSFB_f8 = cute.recast_tensor( + cute.make_tensor( + sSFB.iterator + stage_idx * stage_stride_sfb, + sfb_layout, + loc=loc, + ip=ip, + ), + cutlass.Float8E4M3FN, + loc=loc, + ip=ip, + ) + sSFA_src = cute.as_position_independent_swizzle_tensor(sSFA_f8, loc=loc, ip=ip) + sSFB_src = cute.as_position_independent_swizzle_tensor(sSFB_f8, loc=loc, ip=ip) + thr_mma = tiled_mma.get_slice(tidx) + thr_vmnk = thr_mma.thr_layout_vmnk.get_flat_coord(tidx) + + sfa_source_layout = thrfrg_SFA(sSFA_src[(None, None, 0)].layout, thr_mma) + sfa_source = cute.make_tensor( + sSFA_src.iterator, + sfa_source_layout, + loc=loc, + ip=ip, + ) + thr_vmk = (thr_vmnk[0], (thr_vmnk[1], thr_vmnk[3])) + sfa_source = sfa_source[thr_vmk, (None, None)] + sfa_source = cute.group_modes(cute.flatten(sfa_source), 0, 2) + + sfb_source_layout = thrfrg_SFB(sSFB_src[(None, None, 0)].layout, thr_mma) + sfb_source = cute.make_tensor( + sSFB_src.iterator, + sfb_source_layout, + loc=loc, + ip=ip, + ) + thr_vnk = (thr_vmnk[0], (thr_vmnk[2], thr_vmnk[3])) + sfb_source = sfb_source[thr_vnk, (None, None)] + sfb_source = cute.group_modes(cute.flatten(sfb_source), 0, 2) + sfb_source = cute.group_modes(sfb_source, 1, 3) + return sfa_source, sfb_source + + +@dsl_user_op +def load_mxf4nvf4_sfa_fragment( + src: cute.Tensor, + dst: cute.Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Load a prepartitioned SFA scale view into its register fragment.""" + dst.store(src.load(loc=loc, ip=ip), loc=loc, ip=ip) + + +@dsl_user_op +def load_mxf4nvf4_sfb_fragment( + src: cute.Tensor, + dst: cute.Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Load a prepartitioned SFB scale view into its register fragment.""" + dst.store(src.load(loc=loc, ip=ip), loc=loc, ip=ip) + + +@dsl_user_op +def issue_mxf4nvf4_native_tma_consumer_group( + tiled_mma: cute.TiledMma, + sA_consumer: cute.Tensor, + sB_consumer: cute.Tensor, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + acc: cute.Tensor, + lane_idx: cutlass.Int32, + stage_idx: cutlass.Int32 | int = 0, + *, + m_atom: cutlass.Int32 | int = 0, + n_atom: cutlass.Int32 | int = 0, + major_extent_sfa: int = 128, + major_extent_sfb: int = 128, + tile_k: int = 128, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, + ab_smem_format: str = "packed", + sync_between_k_blocks: bool = True, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Issue one K128 consumer group from native-TMA staged SMEM. + + This is the compact descriptor-free path for a single local SM120 + MXF4/NVFP4 warp-MMA output atom. It loads both K64 A/B halves from the + consumer SMEM layout, pairs each half with the matching direct scale TMA + fragment views, and accumulates into ``acc``. + """ + ab_smem_format = _normalize_mxf4nvf4_ab_smem_format(ab_smem_format) + if ab_smem_format != "packed": + raise ValueError( + "SM120 native TMA consumer group currently supports only " + "ab_smem_format='packed'; use the lower-level unpack LDSM helpers " + "for unpack-SMEM experiments" + ) + a_frag, b_frag = make_mxf4nvf4_ab_fragments_from_consumer_smem( + tiled_mma, + sA_consumer, + sB_consumer, + lane_idx, + m_atom=m_atom, + n_atom=n_atom, + loc=loc, + ip=ip, + ) + sfa, sfb = make_mxf4nvf4_scale_fragments(loc=loc, ip=ip) + for k_block_idx in range(2): + load_mxf4nvf4_ab_fragments_from_consumer_smem( + tiled_mma, + sA_consumer, + sB_consumer, + a_frag, + b_frag, + lane_idx, + k_block_idx, + consumer_stage_idx=stage_idx, + m_atom=m_atom, + n_atom=n_atom, + loc=loc, + ip=ip, + ) + sfa_src, sfb_src = make_mxf4nvf4_scale_fragment_views_from_direct_tma( + sSFA, + sSFB, + k_block_idx, + stage_idx=stage_idx, + major_extent_sfa=major_extent_sfa, + major_extent_sfb=major_extent_sfb, + tile_k=tile_k, + sf_vec_size=sf_vec_size, + loc=loc, + ip=ip, + ) + load_mxf4nvf4_sfa_fragment(sfa_src, sfa, loc=loc, ip=ip) + load_mxf4nvf4_sfb_fragment(sfb_src, sfb, loc=loc, ip=ip) + cute.gemm( + tiled_mma, + acc, + (a_frag[(None, 0, k_block_idx)], sfa), + (b_frag[(None, 0, k_block_idx)], sfb), + acc, + loc=loc, + ip=ip, + ) + if const_expr(sync_between_k_blocks): + cute.arch.sync_threads() + + +@dsl_user_op +def copy_mxf4nvf4_direct_tma_scale_fragments( + tiled_mma: cute.TiledMma, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + sfa_frag_dst: cute.Tensor, + sfb_frag_dst: cute.Tensor, + tidx: cutlass.Int32, + k_block_idx: int, + stage_idx: int = 0, + major_extent_sfa: int = 128, + major_extent_sfb: int = 128, + tile_k: int = 128, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, + scale_smem_format: str = "physical", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Copy one K64 block from native scale-TMA SMEM into MMA fragments.""" + _check_default_tile(major_extent_sfa, tile_k, sf_vec_size) + _check_default_tile(major_extent_sfb, tile_k, sf_vec_size) + if const_expr(major_extent_sfa != major_extent_sfb): + raise ValueError("direct scale fragment copy currently requires square CTA tiles") + + sfa_source, sfb_source = make_mxf4nvf4_direct_tma_scale_fragment_source_views( + tiled_mma, + sSFA, + sSFB, + tidx, + stage_idx=stage_idx, + tile_shape_mnk=(major_extent_sfa, major_extent_sfb, tile_k), + sf_vec_size=sf_vec_size, + scale_smem_format=scale_smem_format, + loc=loc, + ip=ip, + ) + sfa_source_k = sfa_source[(None, None, k_block_idx)] + sfb_source_k = sfb_source[(None, None, k_block_idx)] + sfa_dst_k = sfa_frag_dst[(None, None, k_block_idx)] + sfb_dst_k = sfb_frag_dst[(None, None, k_block_idx)] + sfa_src_compact = cute.filter_zeros(sfa_source_k, loc=loc, ip=ip) + sfb_src_compact = cute.filter_zeros(sfb_source_k, loc=loc, ip=ip) + sfa_dst_compact = cute.filter_zeros(sfa_dst_k, loc=loc, ip=ip) + sfb_dst_compact = cute.filter_zeros(sfb_dst_k, loc=loc, ip=ip) + sfa_dst_compact.store(sfa_src_compact.load(loc=loc, ip=ip), loc=loc, ip=ip) + sfb_dst_compact.store(sfb_src_compact.load(loc=loc, ip=ip), loc=loc, ip=ip) + + +@dsl_user_op +def load_mxf4nvf4_direct_tma_k_block_fragments( + tiled_mma: cute.TiledMma, + tiled_copy_a: cute.TiledCopy, + tiled_copy_b: cute.TiledCopy, + tCsA: cute.Tensor, + tCsB: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + sfa_frag: cute.Tensor, + sfb_frag: cute.Tensor, + tidx: cutlass.Int32, + k_block_idx: int, + stage_idx: int = 0, + *, + major_extent_sfa: int = 128, + major_extent_sfb: int = 128, + tile_k: int = 128, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, + scale_first: bool = False, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Load one SM120 direct-TMA K64 A/B/scale block into MMA fragments.""" + if const_expr(scale_first): + copy_mxf4nvf4_direct_tma_scale_fragments( + tiled_mma, + sSFA, + sSFB, + sfa_frag, + sfb_frag, + tidx, + k_block_idx, + stage_idx=stage_idx, + major_extent_sfa=major_extent_sfa, + major_extent_sfb=major_extent_sfb, + tile_k=tile_k, + sf_vec_size=sf_vec_size, + loc=loc, + ip=ip, + ) + cute.copy( + tiled_copy_a, + tCsA[(None, None, k_block_idx, stage_idx)], + tCrA[(None, None, k_block_idx)], + loc=loc, + ip=ip, + ) + cute.copy( + tiled_copy_b, + tCsB[(None, None, k_block_idx, stage_idx)], + tCrB[(None, None, k_block_idx)], + loc=loc, + ip=ip, + ) + if const_expr(not scale_first): + copy_mxf4nvf4_direct_tma_scale_fragments( + tiled_mma, + sSFA, + sSFB, + sfa_frag, + sfb_frag, + tidx, + k_block_idx, + stage_idx=stage_idx, + major_extent_sfa=major_extent_sfa, + major_extent_sfb=major_extent_sfb, + tile_k=tile_k, + sf_vec_size=sf_vec_size, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def load_mxf4nvf4_direct_tma_k_block_a_scale_fragments( + tiled_mma: cute.TiledMma, + tiled_copy_a: cute.TiledCopy, + tCsA: cute.Tensor, + tCrA: cute.Tensor, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + sfa_frag: cute.Tensor, + sfb_frag: cute.Tensor, + tidx: cutlass.Int32, + k_block_idx: int, + stage_idx: int = 0, + *, + major_extent_sfa: int = 128, + major_extent_sfb: int = 128, + tile_k: int = 128, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, + scale_first: bool = False, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Load one K64 A fragment and matching direct-TMA scale fragments.""" + if const_expr(scale_first): + copy_mxf4nvf4_direct_tma_scale_fragments( + tiled_mma, + sSFA, + sSFB, + sfa_frag, + sfb_frag, + tidx, + k_block_idx, + stage_idx=stage_idx, + major_extent_sfa=major_extent_sfa, + major_extent_sfb=major_extent_sfb, + tile_k=tile_k, + sf_vec_size=sf_vec_size, + loc=loc, + ip=ip, + ) + cute.copy( + tiled_copy_a, + tCsA[(None, None, k_block_idx, stage_idx)], + tCrA[(None, None, k_block_idx)], + loc=loc, + ip=ip, + ) + if const_expr(not scale_first): + copy_mxf4nvf4_direct_tma_scale_fragments( + tiled_mma, + sSFA, + sSFB, + sfa_frag, + sfb_frag, + tidx, + k_block_idx, + stage_idx=stage_idx, + major_extent_sfa=major_extent_sfa, + major_extent_sfb=major_extent_sfb, + tile_k=tile_k, + sf_vec_size=sf_vec_size, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def load_mxf4nvf4_direct_tma_k_block_b_group_fragment( + tiled_copy_b: cute.TiledCopy, + tCsB: cute.Tensor, + tCrB: cute.Tensor, + k_block_idx: int, + b_group_idx: int, + stage_idx: int = 0, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Load one B LDSM group for a K64 block.""" + cute.copy( + tiled_copy_b, + tCsB[(None, b_group_idx, k_block_idx, stage_idx)], + tCrB[(None, b_group_idx, k_block_idx)], + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def issue_mxf4nvf4_direct_tma_eager_consumer_group( + tiled_mma: cute.TiledMma, + tiled_copy_a: cute.TiledCopy, + tiled_copy_b: cute.TiledCopy, + tCsA: cute.Tensor, + tCsB: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + a_frag: cute.Tensor, + b_frag: cute.Tensor, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + sfa_frag: cute.Tensor, + sfb_frag: cute.Tensor, + acc: cute.Tensor, + tidx: cutlass.Int32, + stage_idx: cutlass.Int32, + *, + major_extent_sfa: int = 128, + major_extent_sfb: int = 128, + tile_k: int = 128, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Issue one staged consumer group with eager K-block fragment loads. + + The helper loads both K64 halves from staged A/B consumer SMEM, copies the + matching direct-TMA scale fragments, and issues the two bundled warp MMAs + that cover one K128 CTA stage. This is retained as a comparison path; the + primary helper uses the 79a-style copy-next / compute-current schedule. + """ + cute.copy( + tiled_copy_a, + tCsA[(None, None, 0, stage_idx)], + tCrA[(None, None, 0)], + loc=loc, + ip=ip, + ) + cute.copy( + tiled_copy_b, + tCsB[(None, None, 0, stage_idx)], + tCrB[(None, None, 0)], + loc=loc, + ip=ip, + ) + cute.copy( + tiled_copy_a, + tCsA[(None, None, 1, stage_idx)], + tCrA[(None, None, 1)], + loc=loc, + ip=ip, + ) + cute.copy( + tiled_copy_b, + tCsB[(None, None, 1, stage_idx)], + tCrB[(None, None, 1)], + loc=loc, + ip=ip, + ) + for k_block_idx in range(2): + copy_mxf4nvf4_direct_tma_scale_fragments( + tiled_mma, + sSFA, + sSFB, + sfa_frag, + sfb_frag, + tidx, + k_block_idx, + stage_idx=stage_idx, + major_extent_sfa=major_extent_sfa, + major_extent_sfb=major_extent_sfb, + tile_k=tile_k, + sf_vec_size=sf_vec_size, + loc=loc, + ip=ip, + ) + cute.gemm( + tiled_mma, + acc, + ( + a_frag[(None, None, k_block_idx)], + sfa_frag[(None, None, k_block_idx)], + ), + ( + b_frag[(None, None, k_block_idx)], + sfb_frag[(None, None, k_block_idx)], + ), + acc, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def issue_mxf4nvf4_direct_tma_consumer_group( + tiled_mma: cute.TiledMma, + tiled_copy_a: cute.TiledCopy, + tiled_copy_b: cute.TiledCopy, + tCsA: cute.Tensor, + tCsB: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + a_frag: cute.Tensor, + b_frag: cute.Tensor, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + sfa_frag: cute.Tensor, + sfb_frag: cute.Tensor, + acc: cute.Tensor, + tidx: cutlass.Int32, + stage_idx: cutlass.Int32, + *, + major_extent_sfa: int = 128, + major_extent_sfb: int = 128, + tile_k: int = 128, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Issue one direct-TMA consumer group with a 79a-style K-block schedule. + + The first K64 block is loaded before compute starts. The second K64 block is + then loaded before issuing the first MMA group, matching the copy-next / + compute-current shape used by the C++ SM120 blockscaled mainloop. + """ + cute.copy( + tiled_copy_a, + tCsA[(None, None, 0, stage_idx)], + tCrA[(None, None, 0)], + loc=loc, + ip=ip, + ) + cute.copy( + tiled_copy_b, + tCsB[(None, None, 0, stage_idx)], + tCrB[(None, None, 0)], + loc=loc, + ip=ip, + ) + copy_mxf4nvf4_direct_tma_scale_fragments( + tiled_mma, + sSFA, + sSFB, + sfa_frag, + sfb_frag, + tidx, + 0, + stage_idx=stage_idx, + major_extent_sfa=major_extent_sfa, + major_extent_sfb=major_extent_sfb, + tile_k=tile_k, + sf_vec_size=sf_vec_size, + loc=loc, + ip=ip, + ) + cute.copy( + tiled_copy_a, + tCsA[(None, None, 1, stage_idx)], + tCrA[(None, None, 1)], + loc=loc, + ip=ip, + ) + cute.copy( + tiled_copy_b, + tCsB[(None, None, 1, stage_idx)], + tCrB[(None, None, 1)], + loc=loc, + ip=ip, + ) + copy_mxf4nvf4_direct_tma_scale_fragments( + tiled_mma, + sSFA, + sSFB, + sfa_frag, + sfb_frag, + tidx, + 1, + stage_idx=stage_idx, + major_extent_sfa=major_extent_sfa, + major_extent_sfb=major_extent_sfb, + tile_k=tile_k, + sf_vec_size=sf_vec_size, + loc=loc, + ip=ip, + ) + gemm_mxf4nvf4_direct_tma_k_block( + tiled_mma, + acc, + a_frag, + b_frag, + sfa_frag, + sfb_frag, + 0, + loc=loc, + ip=ip, + ) + gemm_mxf4nvf4_direct_tma_k_block( + tiled_mma, + acc, + a_frag, + b_frag, + sfa_frag, + sfb_frag, + 1, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def gemm_mxf4nvf4_direct_tma_k_block( + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + a_frag: cute.Tensor, + b_frag: cute.Tensor, + sfa_frag: cute.Tensor, + sfb_frag: cute.Tensor, + k_block_idx: int, + *, + ab_smem_format: str = "packed", + n_major: bool = False, + sync_warp_before: bool = False, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Issue one SM120 MXF4/NVFP4 K64 bundled-MMA block. + + This helper is the composable per-K-block compute primitive used by the + higher-level SM120 direct-TMA schedules. It keeps the logical K128 fragment + contract intact while letting callers choose the local MMA traversal order + independently from TMA staging. + """ + ab_smem_format = _normalize_mxf4nvf4_ab_smem_format(ab_smem_format) + if ab_smem_format == "unpack": + a_frag = cute.recast_tensor(a_frag, cutlass.Int8, loc=loc, ip=ip) + b_frag = cute.recast_tensor(b_frag, cutlass.Int8, loc=loc, ip=ip) + if const_expr(sync_warp_before): + cute.arch.sync_warp() + if const_expr(n_major): + a_block = a_frag[(None, None, k_block_idx)] + b_block = b_frag[(None, None, k_block_idx)] + sfa_block = sfa_frag[(None, None, k_block_idx)] + sfb_block = sfb_frag[(None, None, k_block_idx)] + a_tile_size = 16 if const_expr(ab_smem_format == "unpack") else 32 + b_tile_size = 8 if const_expr(ab_smem_format == "unpack") else 16 + a_tiles = cute.size(a_block) // a_tile_size + b_tiles = cute.size(b_block) // b_tile_size + for n_idx in range(b_tiles): + for m_idx in range(a_tiles): + warp.mma_mxf4nvf4( + tiled_mma, + acc[(None, m_idx, n_idx)], + (a_block[(None, m_idx)], sfa_block[(None, m_idx)]), + (b_block[(None, n_idx)], sfb_block[(None, n_idx)]), + acc[(None, m_idx, n_idx)], + loc=loc, + ip=ip, + ) + else: + cute.gemm( + tiled_mma, + acc, + ( + a_frag[(None, None, k_block_idx)], + sfa_frag[(None, None, k_block_idx)], + ), + ( + b_frag[(None, None, k_block_idx)], + sfb_frag[(None, None, k_block_idx)], + ), + acc, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def gemm_mxf4nvf4_direct_tma_k_block_b_group( + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + a_frag: cute.Tensor, + b_frag: cute.Tensor, + sfa_frag: cute.Tensor, + sfb_frag: cute.Tensor, + k_block_idx: int, + b_group_idx: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Issue the two N-major MMA columns covered by one packed B load group.""" + a_block = a_frag[(None, None, k_block_idx)] + b_block = b_frag[(None, None, k_block_idx)] + sfa_block = sfa_frag[(None, None, k_block_idx)] + sfb_block = sfb_frag[(None, None, k_block_idx)] + a_tiles = cute.size(a_block) // 32 + n_start = b_group_idx * 2 + for n_offset in range(2): + n_idx = n_start + n_offset + for m_idx in range(a_tiles): + warp.mma_mxf4nvf4( + tiled_mma, + acc[(None, m_idx, n_idx)], + (a_block[(None, m_idx)], sfa_block[(None, m_idx)]), + (b_block[(None, n_idx)], sfb_block[(None, n_idx)]), + acc[(None, m_idx, n_idx)], + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def issue_mxf4nvf4_direct_tma_pingpong_consumer_group( + tiled_mma: cute.TiledMma, + tiled_copy_a: cute.TiledCopy, + tiled_copy_b: cute.TiledCopy, + tCsA: cute.Tensor, + tCsB: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + a_frag: cute.Tensor, + b_frag: cute.Tensor, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + sfa_frag: cute.Tensor, + sfb_frag: cute.Tensor, + acc: cute.Tensor, + tidx: cutlass.Int32, + stage_idx: cutlass.Int32, + *, + major_extent_sfa: int = 128, + major_extent_sfb: int = 128, + tile_k: int = 128, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Compatibility alias for the primary pingpong consumer group helper.""" + issue_mxf4nvf4_direct_tma_consumer_group( + tiled_mma, + tiled_copy_a, + tiled_copy_b, + tCsA, + tCsB, + tCrA, + tCrB, + a_frag, + b_frag, + sSFA, + sSFB, + sfa_frag, + sfb_frag, + acc, + tidx, + stage_idx, + major_extent_sfa=major_extent_sfa, + major_extent_sfb=major_extent_sfb, + tile_k=tile_k, + sf_vec_size=sf_vec_size, + loc=loc, + ip=ip, + ) + + +__all__: tuple[str, ...] = () diff --git a/quack/blockscaled_gemm_utils.py b/quack/blockscaled_gemm_utils.py index 479c78ff..58b15a44 100644 --- a/quack/blockscaled_gemm_utils.py +++ b/quack/blockscaled_gemm_utils.py @@ -1,6 +1,7 @@ # Copyright (c) 2026, Tri Dao. import itertools +import os from functools import partial from typing import Callable, Optional, Type, Tuple @@ -11,7 +12,13 @@ from quack.compile_utils import make_fake_tensor as fake_tensor from quack.cute_dsl_utils import get_device_capacity, get_max_active_clusters -from quack.gemm_default_epi import GemmDefaultSm100 +from quack.gemm_default_epi import GemmDefaultSm100, GemmDefaultSm120 +from quack.gemm_sm120 import GemmSm120 +from quack.sm120_blockscaled_utils import ( + validate_sm120_nvfp4_ab_storage, + validate_sm120_nvfp4_d_storage, + validate_sm120_nvfp4_scale_storage, +) from quack.gemm_tvm_ffi_utils import div_for_dtype, make_scheduler_args from quack.mx_utils import ( to_mx_compiled, @@ -89,7 +96,10 @@ def _leading_dim_from_stride(tensor: torch.Tensor) -> int: def _make_compile_tensor_like( tensor: torch.Tensor, dtype: Type[cutlass.Numeric], dynamic_layout: bool = False ) -> cute.Tensor: - compile_tensor = cute.runtime.from_dlpack(tensor) + compile_tensor = cute.runtime.from_dlpack( + tensor, + enable_tvm_ffi=dtype is not cutlass.Float4E2M1FN, + ) compile_tensor.element_type = dtype if dynamic_layout: marked = compile_tensor.mark_layout_dynamic(leading_dim=_leading_dim_from_stride(tensor)) @@ -592,6 +602,162 @@ def create_blockscaled_varlen_k_operands( return a_ref_list, b_ref_list, qa, qb, a_sc_contig, b_sc_contig, cu_seqlens_k +def _compile_sm120_nvfp4_blockscaled_gemm_tvm_ffi( + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + d_dtype: Type[cutlass.Numeric], + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + mA: torch.Tensor, + mB: torch.Tensor, + mD: torch.Tensor, + mSFA: torch.Tensor, + mSFB: torch.Tensor, + *, + varlen_m: bool = False, + varlen_k: bool = False, + keep_ptx: bool = False, + sm120_nvfp4_path: str = "validated", +) -> Callable: + if sm120_nvfp4_path not in ("validated", "fast"): + raise ValueError("SM120 NVFP4 path must be 'validated' or 'fast'") + if varlen_m or varlen_k: + raise ValueError("SM120 NVFP4 blockscaled does not support varlen") + if ab_dtype is not cutlass.Float4E2M1FN: + raise ValueError("SM120 NVFP4 blockscaled requires Float4E2M1FN A/B") + if sf_dtype is not cutlass.Float8E4M3FN: + raise ValueError("SM120 NVFP4 blockscaled requires Float8E4M3FN scales") + if sf_vec_size != 16: + raise ValueError("SM120 NVFP4 blockscaled requires sf_vec_size=16") + if d_dtype is not cutlass.BFloat16: + raise ValueError("SM120 NVFP4 blockscaled requires BFloat16 output") + tile_m, tile_n = tuple(mma_tiler_mn) + tile_k = 128 + if (tile_m, tile_n) != (128, 128): + raise ValueError("SM120 NVFP4 blockscaled currently requires mma_tiler_mn=(128,128)") + if tuple(cluster_shape_mn) != (1, 1): + raise ValueError("SM120 NVFP4 blockscaled requires cluster_shape_mn=(1,1)") + + m, packed_k, l = mA.shape + n, packed_k_b, l_b = mB.shape + if packed_k != packed_k_b or l != l_b: + raise ValueError("A/B packed K and batch dimensions must match") + k = packed_k * 2 + validate_sm120_nvfp4_d_storage(mD, m=m, n=n, l=l) + if not GemmSm120.can_implement_blockscaled( + ab_dtype, + sf_dtype, + sf_vec_size, + d_dtype, + (tile_m, tile_n, tile_k), + cluster_shape_mn, + m, + n, + k, + l, + "k", + "k", + "n", + ): + raise ValueError( + f"unsupported SM120 NVFP4 config: m={m}, n={n}, k={k}, l={l}, " + f"tiler={mma_tiler_mn}, cluster={cluster_shape_mn}" + ) + validate_sm120_nvfp4_ab_storage(mA, logical_k=k, major_extent=m, batch_extent=l) + validate_sm120_nvfp4_ab_storage(mB, logical_k=k, major_extent=n, batch_extent=l) + validate_sm120_nvfp4_scale_storage(mSFA, logical_k=k, major_extent=m, batch_extent=l) + validate_sm120_nvfp4_scale_storage(mSFB, logical_k=k, major_extent=n, batch_extent=l) + + gemm = GemmDefaultSm120( + cutlass.Float32, + ab_dtype, + (tile_m, tile_n, tile_k), + (1, 1, 1), + pingpong=True, + use_pdl=True, + sf_vec_size=sf_vec_size, + sf_dtype=sf_dtype, + sm120_nvfp4_path=sm120_nvfp4_path, + ) + gemm.max_active_clusters = get_max_active_clusters( + 1, device_capacity=get_device_capacity(mA.device) + ) + compile_epi_args = gemm.EpilogueArguments() + stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) + + @cute.jit + def runner( + a: cute.Tensor, + b: cute.Tensor, + d: cute.Tensor, + sfa: cute.Tensor, + sfb: cute.Tensor, + problem_m: cutlass.Constexpr[int], + problem_n: cutlass.Constexpr[int], + problem_k: cutlass.Constexpr[int], + problem_l: cutlass.Constexpr[int], + stream, + ): + gemm.blockscaled_call( + a, + b, + d, + sfa, + sfb, + problem_m, + problem_n, + problem_k, + problem_l, + compile_epi_args, + stream, + ) + + gpu_arch = os.environ.get("GPU_ARCH") or os.environ.get("CUTE_DSL_ARCH", "sm_120a") + options = f"--gpu-arch={gpu_arch} --enable-tvm-ffi" + if keep_ptx: + options += " --keep-ptx" + + compiled = cute.compile( + runner, + _make_compile_tensor_like(mA, ab_dtype), + _make_compile_tensor_like(mB, ab_dtype), + _make_compile_tensor_like(mD, d_dtype), + _make_compile_tensor_like(mSFA, sf_dtype), + _make_compile_tensor_like(mSFB, sf_dtype), + m, + n, + k, + l, + stream, + options=options, + ) + compile_device = mA.device + + def run(a, b, d, sfa, sfb): + for tensor_name, tensor in ( + ("A", a), + ("B", b), + ("D", d), + ("SFA", sfa), + ("SFB", sfb), + ): + if tensor.device != compile_device: + raise ValueError( + f"SM120 NVFP4 {tensor_name} tensor must be on {compile_device}, " + f"got {tensor.device}" + ) + validate_sm120_nvfp4_ab_storage(a, logical_k=k, major_extent=m, batch_extent=l) + validate_sm120_nvfp4_ab_storage(b, logical_k=k, major_extent=n, batch_extent=l) + validate_sm120_nvfp4_d_storage(d, m=m, n=n, l=l) + validate_sm120_nvfp4_scale_storage(sfa, logical_k=k, major_extent=m, batch_extent=l) + validate_sm120_nvfp4_scale_storage(sfb, logical_k=k, major_extent=n, batch_extent=l) + compiled(a, b, d, sfa, sfb) + + run.compiled = compiled + return run + + def compile_blockscaled_gemm_tvm_ffi( ab_dtype: Type[cutlass.Numeric], sf_dtype: Type[cutlass.Numeric], @@ -608,8 +774,10 @@ def compile_blockscaled_gemm_tvm_ffi( use_clc_persistence: bool = True, varlen_m: bool = False, varlen_k: bool = False, + keep_ptx: bool = False, + sm120_nvfp4_path: str = "validated", ) -> Callable: - """Compile the SM100 blockscaled GEMM. + """Compile the blockscaled GEMM. When varlen_m: mA is (total_m, k) K-major, mD is (total_m, n) N-major, mB is (n, k, l); run(...) takes an extra cu_seqlens_m tensor. @@ -617,8 +785,46 @@ def compile_blockscaled_gemm_tvm_ffi( run(...) takes an extra cu_seqlens_k tensor. """ device_capacity = get_device_capacity(mA.device) + is_sm120_nvfp4 = ( + device_capacity[0] == 12 + and ab_dtype is cutlass.Float4E2M1FN + and sf_dtype is cutlass.Float8E4M3FN + and sf_vec_size == 16 + and d_dtype is cutlass.BFloat16 + and tuple(mma_tiler_mn) == (128, 128) + and tuple(cluster_shape_mn) == (1, 1) + and not varlen_m + and not varlen_k + ) + if is_sm120_nvfp4: + return _compile_sm120_nvfp4_blockscaled_gemm_tvm_ffi( + ab_dtype, + sf_dtype, + sf_vec_size, + d_dtype, + mma_tiler_mn, + cluster_shape_mn, + mA, + mB, + mD, + mSFA, + mSFB, + varlen_m=varlen_m, + varlen_k=varlen_k, + keep_ptx=keep_ptx, + sm120_nvfp4_path=sm120_nvfp4_path, + ) + if sm120_nvfp4_path != "validated": + raise ValueError("sm120_nvfp4_path applies only to the SM120 NVFP4 blockscaled path") + if device_capacity[0] == 12: + raise RuntimeError( + "SM120 blockscaled GEMM currently supports only NVFP4 " + "(Float4E2M1FN A/B, Float8E4M3FN scales, sf_vec_size=16, " + "BFloat16 output, mma_tiler_mn=(128,128), " + "cluster_shape_mn=(1,1), no varlen)" + ) if device_capacity[0] not in (10, 11): - raise RuntimeError("Blockscaled SM100 GEMM requires SM100/SM110") + raise RuntimeError("Blockscaled GEMM requires SM100/SM110 or SM120") assert not (varlen_m and varlen_k), "Only one of varlen_m / varlen_k" gemm = partial( diff --git a/quack/gemm_sm120.py b/quack/gemm_sm120.py index 738b17b3..6e9804ef 100644 --- a/quack/gemm_sm120.py +++ b/quack/gemm_sm120.py @@ -8,17 +8,27 @@ # This is a work in progress and not very optimized. import math -from typing import Tuple, Type, Callable, Optional +from typing import Tuple, Type, Callable, Optional, Union from functools import partial +import cuda.bindings.driver as cuda import cutlass import cutlass.cute as cute import cutlass.pipeline as pipeline from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait from cutlass.cute.nvgpu import cpasync, warp from cutlass import Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from quack import _sm120_nvfp4_utils as _sm120 from quack.varlen_utils import VarlenManager +from quack.sm120_pipeline import PipelineTmaWarpMma +from quack.tile_scheduler import ( + PersistenceMode, + RasterOrderOption, + TileScheduler, + TileSchedulerArguments, +) from quack.pipeline import make_pipeline_state from quack import copy_utils from quack.gemm_sm90 import GemmSm90, NamedBarrierGemm @@ -49,9 +59,32 @@ def __init__( gather_A: bool = False, concat_layout: tuple | None = None, use_pdl: bool = True, + sf_vec_size: Optional[int] = None, + sf_dtype: Optional[Type[cutlass.Numeric]] = None, + sm120_nvfp4_path: str = "validated", ): # Don't call super().__init__ — we set up our own config self.acc_dtype = acc_dtype + self.sf_vec_size = sf_vec_size + self.sf_dtype = sf_dtype + self.blockscaled = sf_vec_size is not None + if sm120_nvfp4_path not in ("validated", "fast"): + raise ValueError("SM120 NVFP4 path must be 'validated' or 'fast'") + if not self.blockscaled and sm120_nvfp4_path != "validated": + raise ValueError("SM120 NVFP4 fast path requires blockscaled NVFP4") + self.sm120_nvfp4_path = sm120_nvfp4_path + if self.blockscaled: + self._validate_blockscaled_nvfp4_config( + acc_dtype, + a_dtype, + tile_shape_mnk, + cluster_shape_mnk, + sf_vec_size, + sf_dtype, + pingpong, + is_persistent, + gather_A, + ) self.pingpong = pingpong self.is_persistent = is_persistent self.use_clc_persistence = False @@ -72,18 +105,36 @@ def __init__( ) tile_M, tile_N = self.cta_tile_shape_mnk[:2] - # Pingpong: 2 warp groups each with (2,2,1) atom layout - # Non-pingpong: 1 group of 8 warps with (4,2,1) atom layout self.mma_inst_mnk = (16, 8, 16) - self.atom_layout_mnk = (4, 2, 1) if not self.pingpong else (2, 2, 1) - # num_mma_warps = total warps doing MMA (both warp groups in pingpong) - self.num_mma_warps = math.prod(self.atom_layout_mnk) * (1 if not self.pingpong else 2) - # For compatibility with SM90 code that uses warp groups + if self.blockscaled: + if self.pingpong: + self.atom_layout_mnk = (2, 2, 1) + self.num_mma_warps = 8 + else: + self.atom_layout_mnk = ( + (4, 2, 1) + if tile_M == 128 and tile_N == 128 + else (tile_M // self.mma_inst_mnk[0], 1, 1) + ) + self.num_mma_warps = tile_M // self.mma_inst_mnk[0] + else: + # Pingpong: 2 warp groups each with (2,2,1) atom layout + # Non-pingpong: 1 group of 8 warps with (4,2,1) atom layout + self.atom_layout_mnk = (4, 2, 1) if not self.pingpong else (2, 2, 1) + # num_mma_warps = total warps doing MMA (both warp groups in pingpong) + self.num_mma_warps = math.prod(self.atom_layout_mnk) * (1 if not self.pingpong else 2) + # Keep the warp-group-sized thread count used by SM120 scheduling helpers. self.num_threads_per_warp_group = 128 assert self.num_mma_warps % 4 == 0 self.mma_warp_groups = self.num_mma_warps // 4 if self.pingpong: assert self.mma_warp_groups == 2 + direct_128_pingpong = ( + self.blockscaled and self.pingpong and self.cta_tile_shape_mnk[:2] == (128, 128) + ) + self.blockscaled_pingpong_split_tiles = direct_128_pingpong + self.blockscaled_pingpong_full_tma_pipeline = False + self.blockscaled_pingpong_elected_tma = False # threads_per_cta must be a multiple of 128 (warp group size) so that # the DMA warp's setmaxnreg.dec.sync has a complete warp group to sync with. self.threads_per_cta = (self.mma_warp_groups + 1) * self.num_threads_per_warp_group @@ -99,13 +150,24 @@ def __init__( self.smem_capacity = cutlass.utils.get_smem_capacity_in_bytes(f"sm_{self.arch}") # In pingpong, only 1 warp group (4 warps) participates in epilogue at a time - self.num_epi_warps = (self.mma_warp_groups if not self.pingpong else 1) * 4 + split_epi_by_warpgroup = self.pingpong and ( + not self.blockscaled + or self.blockscaled_pingpong_split_tiles + or self.blockscaled_pingpong_full_tma_pipeline + or self.blockscaled_pingpong_elected_tma + ) + self.num_epi_warps = (self.mma_warp_groups if not split_epi_by_warpgroup else 1) * 4 self.epilogue_barrier = pipeline.NamedBarrier( barrier_id=int(NamedBarrierGemm.Epilogue), num_threads=self.num_epi_warps * cute.arch.WARP_SIZE, ) self.num_ab_load_warps = 1 if not self.gather_A else 4 self.ab_load_warp_id = self.num_mma_warps + self.mma_warp_id_start = 0 + if self.blockscaled and self.pingpong: + self.ab_load_warp_id = 0 + self.mma_warp_id_start = 4 + self.mma_warp_id_end = self.mma_warp_id_start + self.num_mma_warps if not self.gather_A: self.num_regs_load = 40 @@ -123,6 +185,175 @@ def __init__( self.epi_tile = None self.shared_storage = None self.buffer_align_bytes = 1024 + self.max_active_clusters = 1 + self.direct_producer_unroll = 1 + self.direct_consumer_unroll = 1 + self.direct_consumer_barrier = True + self.direct_consumer_fence = True + self.direct_consumer_warp_sync = False + self.direct_release_before_sync = False + self.direct_kblock_pipeline = False + self.direct_kblock_barrier = False + self.direct_ab_tma_layout = "packed" + self.direct_unpack_shift = False + direct_128_default = self.blockscaled and self.cta_tile_shape_mnk[:2] == (128, 128) + self.direct_pre_mma_warp_sync = False + ab_stage_override = 0 + self.direct_elected_tma = direct_128_default and self.blockscaled_pingpong_elected_tma + self.direct_setmaxregister = True + self.direct_cute_dsl_helpers = False + self.direct_global_store = sm120_nvfp4_path == "validated" + self.direct_global_store_probe = False + self.direct_tile_scheduler = direct_128_default + self.direct_cute_static_scheduler = sm120_nvfp4_path == "validated" + self.direct_pipelined_consumer = direct_128_default + self.direct_split_tma_pipelines = direct_128_default and not self.direct_elected_tma + self.direct_full_tma_pipeline = False + self.direct_single_tma_pipeline = False + self.direct_join_split_tma_barrier = direct_128_default and self.direct_split_tma_pipelines + if self.direct_split_tma_pipelines: + if not self.direct_tile_scheduler or not self.direct_pipelined_consumer: + raise ValueError("SM120 NVFP4 split TMA requires scheduler and consumer pipelines") + self.num_ab_load_warps = 3 + elif self.direct_join_split_tma_barrier: + raise ValueError("SM120 NVFP4 joined split barrier requires split TMA pipelines") + if ( + self.blockscaled + and self.pingpong + and not ( + self.direct_split_tma_pipelines + or self.direct_full_tma_pipeline + or self.direct_single_tma_pipeline + ) + ): + raise ValueError("SM120 NVFP4 blockscaled pingpong requires split TMA pipelines") + self.direct_skip_split_tma_tail = False + if self.direct_skip_split_tma_tail and not self.direct_split_tma_pipelines: + raise ValueError("SM120 NVFP4 skip split TMA tail requires split TMA pipelines") + self.direct_scheduler_local_tma = False + self.direct_sched_exclude_producer = False + self.direct_skip_scheduler_tail = False + self.direct_pingpong_barriers = True + self.direct_try_wait_before_pingpong_barrier = ( + direct_128_default + and self.direct_split_tma_pipelines + and self.blockscaled_pingpong_split_tiles + ) + self.direct_pingpong_split_tiles = ( + self.blockscaled_pingpong_split_tiles and not self.direct_elected_tma + ) + if self.pingpong and self.direct_pingpong_split_tiles and not self.direct_pingpong_barriers: + raise ValueError("SM120 NVFP4 pingpong split path requires pingpong barriers") + if self.direct_try_wait_before_pingpong_barrier and not ( + self.direct_split_tma_pipelines + and self.direct_pingpong_split_tiles + and self.direct_pingpong_barriers + ): + raise ValueError("SM120 NVFP4 try-wait path requires split-TMA pingpong barriers") + self.direct_scale_prefetch_first = False + self.direct_scale_smem_format = "interleaved" + direct_bgroup_pipeline_requested = False + self.direct_epi_barrier_trim = direct_128_default + if self.direct_epi_barrier_trim and not ( + self.direct_split_tma_pipelines or self.direct_full_tma_pipeline + ): + raise ValueError("SM120 NVFP4 epilogue barrier trim requires split/full TMA") + self.direct_mma_n_major = direct_128_default and self.pingpong + self.direct_bgroup_pipeline = ( + direct_128_default + and direct_bgroup_pipeline_requested + and self.direct_full_tma_pipeline + and self.direct_ab_tma_layout == "packed" + and self.direct_mma_n_major + ) + self.direct_fragment_contract = "shape" + self.direct_tma_scale_first = False + self.direct_tma_prefetch = False + self.direct_skip_tma_acquire = False + self.direct_tma_policy = "zero" + self.direct_epi_stsm_matrices = 2 + if self.direct_epi_stsm_matrices != 2: + raise ValueError("SM120 NVFP4 epilogue STSM matrices must be 2") + self.direct_epi_tile_m = 0 + if self.direct_epi_tile_m not in (0, 64, 128): + raise ValueError("SM120 NVFP4 epilogue tile M must be 0, 64, or 128") + self.direct_epi_tile_n = 0 + if self.direct_epi_tile_n not in (0, 32, 64, 128): + raise ValueError("SM120 NVFP4 epilogue tile N must be 0, 32, 64, or 128") + both_epi_dims_split = self.direct_epi_tile_m not in ( + 0, + 128, + ) and self.direct_epi_tile_n not in (0, 128) + if both_epi_dims_split and (self.direct_epi_tile_m, self.direct_epi_tile_n) != ( + 64, + 32, + ): + raise ValueError("Only the 79a-style SM120 NVFP4 (64, 32) epilogue split is supported") + self.direct_epi_tma_rank3 = direct_128_default and self.pingpong + + @staticmethod + def _validate_blockscaled_nvfp4_config( + acc_dtype: Type[cutlass.Numeric], + a_dtype: Type[cutlass.Numeric], + tile_shape_mnk: Tuple[int, int] | Tuple[int, int, int], + cluster_shape_mnk: Tuple[int, int, int], + sf_vec_size: Optional[int], + sf_dtype: Optional[Type[cutlass.Numeric]], + pingpong: bool, + is_persistent: bool, + gather_A: bool, + ) -> None: + if acc_dtype is not cutlass.Float32: + raise ValueError("SM120 NVFP4 blockscaled requires Float32 accumulation") + if a_dtype is not cutlass.Float4E2M1FN: + raise ValueError("SM120 NVFP4 blockscaled requires Float4E2M1FN A/B operands") + if sf_dtype is not cutlass.Float8E4M3FN: + raise ValueError("SM120 NVFP4 blockscaled requires Float8E4M3FN scales") + if sf_vec_size != 16: + raise ValueError("SM120 NVFP4 blockscaled requires sf_vec_size=16") + if tuple(tile_shape_mnk) != (128, 128, 128): + raise ValueError("SM120 NVFP4 blockscaled supports CTA tile (128,128,128)") + if tuple(cluster_shape_mnk) != (1, 1, 1): + raise ValueError("SM120 NVFP4 blockscaled initially supports cluster (1,1,1)") + if pingpong and (tuple(tile_shape_mnk) != (128, 128, 128) or not is_persistent): + raise ValueError( + "SM120 NVFP4 blockscaled pingpong requires persistent (128,128,128) tiles" + ) + if gather_A: + raise ValueError("SM120 NVFP4 blockscaled does not support gather_A") + + @staticmethod + def can_implement_blockscaled( + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + d_dtype: Type[cutlass.Numeric], + mma_tiler_mnk: Tuple[int, int] | Tuple[int, int, int], + cluster_shape_mn: Tuple[int, int], + m: int, + n: int, + k: int, + l: int, + a_major: str, + b_major: str, + d_major: str, + ) -> bool: + tile_shape = tuple(mma_tiler_mnk) if len(mma_tiler_mnk) == 3 else (*mma_tiler_mnk, 128) + return ( + ab_dtype is cutlass.Float4E2M1FN + and sf_dtype is cutlass.Float8E4M3FN + and sf_vec_size == 16 + and d_dtype is cutlass.BFloat16 + and tile_shape == (128, 128, 128) + and cluster_shape_mn == (1, 1) + and m % tile_shape[0] == 0 + and n % tile_shape[1] == 0 + and k % 128 == 0 + and l >= 1 + and a_major == "k" + and b_major == "k" + and d_major == "n" + ) def epi_smem_warp_shape_mnk(self): return self.atom_layout_mnk @@ -153,7 +384,2928 @@ def _setup_tiled_mma(self): self.cta_tile_shape_mnk = (self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1], tile_k) # __call__, _setup_attributes, make_ab_pipeline, make_epi_store_pipeline, - # make_sched_pipeline, epilogue are all inherited from GemmSm90. + # epilogue are all inherited from GemmSm90. + + def make_sm120_single_warp_epi_store_pipeline(self): + return pipeline.PipelineTmaStore.create( + num_stages=self.epi_stage, + producer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, + cute.arch.WARP_SIZE, + ), + ) + + def make_sched_pipeline( + self, cluster_layout_mnk: cute.Layout, sched_pipeline_mbar_ptr: cute.Pointer, varlen_k: bool + ): + if not (self.blockscaled and (self.pingpong or self.direct_sched_exclude_producer)): + return super().make_sched_pipeline( + cluster_layout_mnk, sched_pipeline_mbar_ptr, varlen_k + ) + + # The inherited SM90 pingpong scheduler counts one MMA warpgroup when + # varlen_k is false. This SM120 NVFP4 non-split path has both consumer + # warpgroups waiting on the scheduler tile. + sched_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + cluster_size = cute.size(cluster_layout_mnk) + sched_mma_warp_groups = ( + 1 + if const_expr(self.direct_full_tma_pipeline) + else (self.mma_warp_groups if not self.direct_pingpong_split_tiles else 1) + ) + sched_producer_warps = ( + 0 if const_expr(self.direct_sched_exclude_producer) else self.num_ab_load_warps + ) + consumer_arrive_cnt = (sched_mma_warp_groups * 4 + sched_producer_warps) * cluster_size + sched_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, consumer_arrive_cnt + ) + return pipeline.PipelineAsync.create( + barrier_storage=sched_pipeline_mbar_ptr, + num_stages=self.sched_stage, + producer_group=sched_pipeline_producer_group, + consumer_group=sched_pipeline_consumer_group, + consumer_mask=None if const_expr(cluster_size == 1) else 0, + defer_sync=True, + ) + + @cute.jit + def blockscaled_call( + self, + gA_storage: cute.Tensor, + gB_storage: cute.Tensor, + mD: cute.Tensor, + gSFA_storage: cute.Tensor, + gSFB_storage: cute.Tensor, + problem_m: cutlass.Constexpr[int], + problem_n: cutlass.Constexpr[int], + problem_k: cutlass.Constexpr[int], + problem_l: cutlass.Constexpr[int], + epilogue_args, + stream: cuda.CUstream, + ): + self.a_dtype = cutlass.Float4E2M1FN + self.b_dtype = cutlass.Float4E2M1FN + self.d_dtype = mD.element_type + self.d_layout = LayoutEnum.from_tensor(mD) + if const_expr(self.d_dtype is not cutlass.BFloat16): + raise TypeError("SM120 NVFP4 blockscaled output must be BFloat16") + if const_expr(not self.d_layout.is_n_major_c()): + raise ValueError("SM120 NVFP4 blockscaled output must be N-major") + epilogue_params = self.epi_to_underlying_arguments(epilogue_args) + + tile_extent_m, tile_extent_n, tile_extent_k = self.cta_tile_shape_mnk + gA = cute.make_tensor( + gA_storage.iterator, + _sm120.make_mxf4nvf4_a_gmem_layout(problem_m, problem_k, problem_l), + ) + gB = cute.make_tensor( + gB_storage.iterator, + _sm120.make_mxf4nvf4_b_gmem_layout(problem_n, problem_k, problem_l), + ) + gSFA = cute.make_tensor( + gSFA_storage.iterator, + _sm120.make_mxf4nvf4_scale_interleaved_gmem_layout(problem_m, problem_k, problem_l), + ) + gSFB = cute.make_tensor( + gSFB_storage.iterator, + _sm120.make_mxf4nvf4_scale_interleaved_gmem_layout(problem_n, problem_k, problem_l), + ) + k_tile_count = problem_k // tile_extent_k + m_tile_count = cute.ceil_div(problem_m, tile_extent_m) + n_tile_count = cute.ceil_div(problem_n, tile_extent_n) + if const_expr(tile_extent_m == 128 and tile_extent_n == 128): + if const_expr(self.direct_ab_tma_layout == "unpack"): + self.ab_stage = 2 + else: + self.ab_stage = 4 + self.sched_stage = 3 + else: + self.ab_stage = 2 + self.sched_stage = 1 + direct_128_default = tile_extent_m == 128 and tile_extent_n == 128 + delay_tma_store = direct_128_default and self.pingpong + use_default_79a_epi_tile = ( + self.pingpong + and direct_128_default + and self.direct_epi_tile_m == 0 + and self.direct_epi_tile_n == 0 + ) + effective_epi_tile_m = ( + 64 if const_expr(use_default_79a_epi_tile) else self.direct_epi_tile_m + ) + effective_epi_tile_n = ( + 32 if const_expr(use_default_79a_epi_tile) else self.direct_epi_tile_n + ) + if const_expr(effective_epi_tile_n == 32 and effective_epi_tile_m == 0): + effective_epi_tile_m = tile_extent_m + if const_expr(effective_epi_tile_n == 64 and effective_epi_tile_m == 0): + effective_epi_tile_m = tile_extent_m + if const_expr(effective_epi_tile_m == 64 and effective_epi_tile_n == 0): + effective_epi_tile_n = tile_extent_n + use_split_epi_tile = ( + tile_extent_m == 128 + and tile_extent_n == 128 + and effective_epi_tile_m != 0 + and effective_epi_tile_n != 0 + and (effective_epi_tile_m != tile_extent_m or effective_epi_tile_n != tile_extent_n) + ) + self.epi_tile = ( + (effective_epi_tile_m, effective_epi_tile_n) + if const_expr(use_split_epi_tile) + else (tile_extent_m, tile_extent_n) + ) + self.direct_epi_m_tiles = ( + tile_extent_m // effective_epi_tile_m if const_expr(use_split_epi_tile) else 1 + ) + self.direct_epi_n_tiles = ( + tile_extent_n // effective_epi_tile_n if const_expr(use_split_epi_tile) else 1 + ) + self.direct_epi_tiles = self.direct_epi_m_tiles * self.direct_epi_n_tiles + self.epi_tile_shape = cute.ceil_div(self.cta_tile_shape_mnk[:2], self.epi_tile) + self.direct_delay_tma_store = delay_tma_store + if const_expr(self.direct_delay_tma_store and self.direct_epi_tiles < 2): + raise ValueError("SM120 NVFP4 delayed TMA store requires a split epilogue tile") + self.epi_stage = 2 if const_expr(self.direct_delay_tma_store) else 1 + self.direct_epi_tma_rank3 = direct_128_default and self.pingpong + self.epi_c_stage = 0 + epi_smem_m, epi_smem_n = self.epi_tile + self.epi_smem_layout_staged = _sm120.make_mxf4nvf4_epilogue_smem_layout( + epi_tile=(epi_smem_m, epi_smem_n), + num_stages=self.epi_stage, + ) + if const_expr(tile_extent_m == 128 and tile_extent_n == 128): + # Match CUTLASS 79a's 128x128 NVFP4 tiled-MMA layout. + if const_expr(self.pingpong): + self.tiled_mma = _sm120.make_mxf4nvf4_79a_tiled_mma() + else: + mma_op = warp.MmaMXF4NVF4Op( + cutlass.Float4E2M1FN, + cutlass.Float32, + cutlass.Float8E4M3FN, + ) + self.tiled_mma = cute.make_tiled_mma( + mma_op, + atom_layout_mnk=cute.make_layout((4, 2, 1), stride=(1, 4, 0)), + permutation_mnk=( + 128, + cute.make_layout((8, 2, 2), stride=(1, 16, 8)), + 64, + ), + ) + else: + self.tiled_mma = _sm120.make_mxf4nvf4_tiled_mma( + atom_layout_mnk=self.atom_layout_mnk, + ) + + if const_expr(self.direct_epi_tma_rank3): + mD_tma = mD + if const_expr(cute.size(mD, mode=[2]) == 1): + mD_tma = cute.make_tensor( + mD.iterator, + cute.make_layout( + (mD.shape[0], mD.shape[1], 2), + stride=mD.layout.stride, + ), + ) + epi_tma_smem_layout = self.epi_smem_layout_staged + if const_expr(self.epi_stage != 1): + epi_tma_smem_layout = cute.slice_( + self.epi_smem_layout_staged, + (None, None, 0), + ) + tma_atom_d, tma_tensor_d = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + mD_tma, + epi_tma_smem_layout, + self.epi_tile, + ) + else: + tma_atom_d, tma_tensor_d = _sm120.make_mxf4nvf4_epilogue_tma_store_atom( + mD, + self.epi_smem_layout_staged, + epi_tile=self.epi_tile, + ) + + grid = ( + cute.ceil_div(cute.size(mD, mode=[0]), tile_extent_m), + cute.ceil_div(cute.size(mD, mode=[1]), tile_extent_n), + cute.size(mD, mode=[2]), + ) + if const_expr(tile_extent_m == 128 and tile_extent_n == 128): + ( + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_sfa, + tma_tensor_sfa, + tma_atom_sfb, + tma_tensor_sfb, + ) = _sm120.make_mxf4nvf4_native_tma_atoms_for_scheduler( + gA, + gB, + gSFA, + gSFB, + tiled_mma=self.tiled_mma, + ab_smem_format=self.direct_ab_tma_layout, + scale_smem_format=self.direct_scale_smem_format, + ) + persistence_mode = PersistenceMode.CLC if self.pingpong else PersistenceMode.STATIC + if const_expr(self.direct_cute_static_scheduler): + static_scheduler_swizzle_size = 1 + if const_expr(static_scheduler_swizzle_size < 1): + raise ValueError("SM120 NVFP4 static scheduler swizzle must be >= 1") + tile_sched_params, static_scheduler_grid = ( + _sm120.make_mxf4nvf4_static_tile_scheduler_params( + m=problem_m, + n=problem_n, + k=problem_k, + l_extent=problem_l, + max_active_clusters=self.max_active_clusters, + swizzle_size=static_scheduler_swizzle_size, + ) + ) + direct_grid = static_scheduler_grid + else: + tile_sched_args = TileSchedulerArguments( + problem_shape_ntile_mnl=grid, + raster_order=RasterOrderOption.Heuristic, + group_size=Int32(8), + cluster_shape_mnk=self.cluster_shape_mnk, + tile_count_semaphore=None, + batch_idx_permute=None, + persistence_mode=persistence_mode, + ) + tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) + direct_grid = ( + TileScheduler.get_grid_shape(tile_sched_params, self.max_active_clusters) + if const_expr(self.direct_tile_scheduler) + else (1, 1, m_tile_count * n_tile_count * problem_l) + ) + cooperative_schedule = _sm120.make_mxf4nvf4_cooperative_schedule( + producer_warpgroup_start=self.mma_warp_groups, + consumer_warpgroups=self.mma_warp_groups, + ) + cooperative_launch_kwargs = cooperative_schedule.launch_kwargs() + scheduler_smem_bytes = 0 + if const_expr(not self.direct_cute_static_scheduler): + scheduler_data_rows = ( + 12 if const_expr(persistence_mode == PersistenceMode.CLC) else 4 + ) + scheduler_smem_bytes = ( + 8 * self.sched_stage * 2 + 4 * scheduler_data_rows * self.sched_stage + ) + split_smem_bytes = 82688 + max(0, self.ab_stage - 4) * 9728 + scheduler_smem_bytes + single_smem_bytes = max(82432, self.ab_stage * (18432 + 16)) + scheduler_smem_bytes + launch_smem_bytes = ( + split_smem_bytes + if const_expr(self.direct_split_tma_pipelines) + else single_smem_bytes + ) + if const_expr(launch_smem_bytes > 101376): + raise ValueError( + "SM120 NVFP4 kernel exceeds the 128x128 shared-memory budget; " + "use fewer A/B stages or disable the tiled scale SMEM path" + ) + self.blockscaled_kernel( + self.tiled_mma, + tma_atom_d, + tma_tensor_d, + mD, + self.epi_smem_layout_staged, + cute.size(mD, mode=[2]), + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_sfa, + tma_tensor_sfa, + tma_atom_sfb, + tma_tensor_sfb, + epilogue_params, + tile_sched_params, + k_tile_count, + m_tile_count, + n_tile_count, + ).launch( + grid=direct_grid, + block=cooperative_launch_kwargs["block"], + max_number_threads=cooperative_launch_kwargs["max_number_threads"], + min_blocks_per_mp=cooperative_launch_kwargs["min_blocks_per_mp"], + cluster=(1, 1, 1), + stream=stream, + smem=launch_smem_bytes, + ) + else: + raise NotImplementedError("SM120 NVFP4 blockscaled requires tile (128,128,128)") + + @cute.kernel + def blockscaled_kernel( + self, + tiled_mma: cute.TiledMma, + tma_atom_d: cute.CopyAtom, + tma_tensor_d: cute.Tensor, + gD: cute.Tensor, + epi_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout], + l_tiles: Int32, + tma_atom_a: cute.CopyAtom, + tma_tensor_a: cute.Tensor, + tma_atom_b: cute.CopyAtom, + tma_tensor_b: cute.Tensor, + tma_atom_sfa: cute.CopyAtom, + tma_tensor_sfa: cute.Tensor, + tma_atom_sfb: cute.CopyAtom, + tma_tensor_sfb: cute.Tensor, + epilogue_params, + tile_sched_params, + k_tile_count: cutlass.Constexpr[int], + m_tile_count: cutlass.Constexpr[int], + n_tile_count: cutlass.Constexpr[int], + ): + tidx, _, _ = cute.arch.thread_idx() + warp_idx = cute.arch.make_warp_uniform(tidx // 32) + lane_idx = tidx % 32 + cta_m, cta_n, batch_idx = cute.arch.block_idx() + tile_extent_m, tile_extent_n, tile_extent_k = self.cta_tile_shape_mnk + m_atoms = tile_extent_m // 16 + n_atoms = tile_extent_n // 8 + + if const_expr(self.direct_tma_prefetch): + if warp_idx == self.ab_load_warp_id: + cpasync.prefetch_descriptor(tma_atom_a) + cpasync.prefetch_descriptor(tma_atom_b) + cpasync.prefetch_descriptor(tma_atom_sfa) + cpasync.prefetch_descriptor(tma_atom_sfb) + + smem = cutlass.utils.SmemAllocator() + sA_consumer, sB_consumer, sSFA, sSFB = _sm120.make_mxf4nvf4_native_tma_smem_views( + smem, + tiled_mma=tiled_mma, + num_stages=self.ab_stage, + tile_m=tile_extent_m, + tile_n=tile_extent_n, + tile_k=tile_extent_k, + sf_vec_size=16, + ab_smem_format=self.direct_ab_tma_layout, + scale_smem_format=self.direct_scale_smem_format, + ) + if const_expr(self.direct_ab_tma_layout == "unpack"): + sA_direct = sA_consumer + sB_direct = sB_consumer + else: + sA_direct, sB_direct = _sm120.make_mxf4nvf4_ab_packed_direct_tma_consumer_tma_views( + sA_consumer, + sB_consumer, + ) + if const_expr(self.direct_split_tma_pipelines): + barriers_mk = smem.allocate_array(cutlass.Int64, self.ab_stage * 2, byte_alignment=8) + if const_expr(self.direct_join_split_tma_barrier): + barriers_nk = barriers_mk + else: + barriers_nk = smem.allocate_array( + cutlass.Int64, self.ab_stage * 2, byte_alignment=8 + ) + else: + barriers = smem.allocate_array(cutlass.Int64, self.ab_stage * 2, byte_alignment=8) + sD_epi = cute.make_tensor( + cute.recast_ptr(sA_direct.iterator, dtype=cutlass.BFloat16), + epi_smem_layout_staged, + ) + epi_store_pipeline = self.make_sm120_single_warp_epi_store_pipeline() + + if const_expr(self.direct_split_tma_pipelines): + producer_arrive_count = 2 if const_expr(self.direct_join_split_tma_barrier) else 1 + tma_consumer_warps = ( + 4 if const_expr(self.direct_pingpong_split_tiles) else self.num_mma_warps + ) + pipe_mk = PipelineTmaWarpMma.create( + num_stages=self.ab_stage, + producer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, producer_arrive_count + ), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, tma_consumer_warps), + tx_count=( + _sm120.mxf4nvf4_ab_tma_tx_bytes(tile_extent_m, tile_extent_k) + + _sm120.mxf4nvf4_scale_tma_tx_bytes(tile_extent_m, tile_extent_k, 16) + ), + barrier_storage=barriers_mk, + ) + if const_expr(self.direct_join_split_tma_barrier): + pipe_nk = pipe_mk + else: + pipe_nk = PipelineTmaWarpMma.create( + num_stages=self.ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, tma_consumer_warps + ), + tx_count=( + _sm120.mxf4nvf4_ab_tma_tx_bytes(tile_extent_n, tile_extent_k) + + _sm120.mxf4nvf4_scale_tma_tx_bytes(tile_extent_n, tile_extent_k, 16) + ), + barrier_storage=barriers_nk, + ) + else: + tma_consumer_warps = ( + 4 + if const_expr(self.direct_full_tma_pipeline and self.pingpong) + else self.num_mma_warps + ) + pipe = _sm120.make_mxf4nvf4_native_tma_pipeline( + barrier_storage=barriers, + num_stages=self.ab_stage, + tile_m=tile_extent_m, + tile_n=tile_extent_n, + tile_k=tile_extent_k, + sf_vec_size=16, + ab_smem_format=self.direct_ab_tma_layout, + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, + tma_consumer_warps, + ), + ) + + if const_expr(self.direct_tile_scheduler or self.direct_split_tma_pipelines): + ( + tAsA, + tAgA, + tBsB, + tBgB, + tSFAs, + tSFAg, + tSFBs, + tSFBg, + ) = _sm120.partition_mxf4nvf4_native_tma_tensors_for_scheduler( + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_sfa, + tma_tensor_sfa, + tma_atom_sfb, + tma_tensor_sfb, + sA_consumer, + sB_consumer, + sSFA, + sSFB, + tile_shape_mnk=(tile_extent_m, tile_extent_n, tile_extent_k), + sf_vec_size=16, + scale_smem_format=self.direct_scale_smem_format, + ) + mn_tile_count = m_tile_count * n_tile_count + total_work = mn_tile_count * l_tiles + work_stride = cute.arch.grid_dim()[2] + if const_expr(self.direct_tile_scheduler): + cluster_layout_mnk = cute.make_layout(self.cluster_shape_mnk) + if const_expr(not self.direct_cute_static_scheduler): + sched_barriers = smem.allocate_array( + cutlass.Int64, self.sched_stage * 2, byte_alignment=8 + ) + sched_data_rows = ( + 12 + if const_expr(tile_sched_params.persistence_mode == PersistenceMode.CLC) + else 4 + ) + sched_data = smem.allocate_tensor( + Int32, + cute.make_layout((sched_data_rows, self.sched_stage)), + byte_alignment=16, + ) + sched_pipeline = self.make_sched_pipeline( + cluster_layout_mnk, + sched_pipeline_mbar_ptr=sched_barriers, + varlen_k=False, + ) + TileSchedulerCls = partial( + TileScheduler.create, tile_sched_params, sched_data, sched_pipeline + ) + pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mnk[:-1], is_relaxed=True) + pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mnk[:-1]) + + if ( + const_expr(self.direct_setmaxregister) + and self.ab_load_warp_id <= warp_idx + and warp_idx < self.ab_load_warp_id + 4 + ): + _sm120.setmaxregister_mxf4nvf4_producer(self.num_regs_load) + + tma_cache_policy = None + if const_expr(self.direct_tma_policy == "evict_last"): + tma_cache_policy = cpasync.create_l2_evict_last_policy() + + if const_expr(self.direct_split_tma_pipelines): + if warp_idx == self.ab_load_warp_id: + if const_expr( + (not self.direct_cute_static_scheduler) + and self.use_pdl + and tile_sched_params.persistence_mode == PersistenceMode.CLC + ): + cute.arch.griddepcontrol_wait() + producer_state_mk = make_pipeline_state( + pipeline.PipelineUserType.Producer, self.ab_stage + ) + if const_expr(self.direct_cute_static_scheduler): + tile_scheduler = _sm120.make_mxf4nvf4_static_tile_scheduler( + tile_sched_params, + cute.arch.block_idx(), + cute.arch.grid_dim(), + ) + else: + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + tile_coord_mnkl = work_tile.tile_idx + cta_m = tile_coord_mnkl[0] + batch_idx = ( + tile_coord_mnkl[2] + if const_expr(self.direct_cute_static_scheduler) + else tile_coord_mnkl[3] + ) + for k_tile in cutlass.range(k_tile_count, unroll=self.direct_producer_unroll): + pipe_mk.producer_acquire( + producer_state_mk, + pipe_mk.producer_try_acquire(producer_state_mk), + ) + _sm120.issue_mxf4nvf4_partitioned_native_tma_mk_stage_for_tile( + tma_atom_a, + tAsA, + tAgA, + tma_atom_sfa, + tSFAs, + tSFAg, + pipe_mk.producer_get_barrier(producer_state_mk), + (cta_m, tile_coord_mnkl[1], batch_idx), + k_tile, + producer_state_mk.index, + scale_smem_format=self.direct_scale_smem_format, + cache_policy=tma_cache_policy, + ) + pipe_mk.producer_commit(producer_state_mk) + producer_state_mk.advance() + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + if const_expr(not self.direct_skip_split_tma_tail): + pipe_mk.producer_tail(producer_state_mk) + + if const_expr(not self.direct_cute_static_scheduler): + if warp_idx == self.ab_load_warp_id + 1: + if const_expr( + self.use_pdl and tile_sched_params.persistence_mode == PersistenceMode.CLC + ): + cute.arch.griddepcontrol_wait() + tile_scheduler = TileSchedulerCls(is_scheduler_warp=True) + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + tile_scheduler.advance_to_next_work(is_scheduler_warp=True) + work_tile = tile_scheduler.get_current_work() + if const_expr(self.direct_pingpong_split_tiles): + tile_scheduler.write_work_tile_to_smem(work_tile) + if const_expr(not self.direct_skip_scheduler_tail): + tile_scheduler.producer_tail() + + if warp_idx == self.ab_load_warp_id + 2: + if const_expr( + (not self.direct_cute_static_scheduler) + and self.use_pdl + and tile_sched_params.persistence_mode == PersistenceMode.CLC + ): + cute.arch.griddepcontrol_wait() + producer_state_nk = make_pipeline_state( + pipeline.PipelineUserType.Producer, self.ab_stage + ) + if const_expr(self.direct_cute_static_scheduler): + tile_scheduler = _sm120.make_mxf4nvf4_static_tile_scheduler( + tile_sched_params, + cute.arch.block_idx(), + cute.arch.grid_dim(), + ) + else: + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + tile_coord_mnkl = work_tile.tile_idx + cta_n = tile_coord_mnkl[1] + batch_idx = ( + tile_coord_mnkl[2] + if const_expr(self.direct_cute_static_scheduler) + else tile_coord_mnkl[3] + ) + for k_tile in cutlass.range(k_tile_count, unroll=self.direct_producer_unroll): + pipe_nk.producer_acquire( + producer_state_nk, + pipe_nk.producer_try_acquire(producer_state_nk), + ) + _sm120.issue_mxf4nvf4_partitioned_native_tma_nk_stage_for_tile( + tma_atom_b, + tBsB, + tBgB, + tma_atom_sfb, + tSFBs, + tSFBg, + pipe_nk.producer_get_barrier(producer_state_nk), + (tile_coord_mnkl[0], cta_n, batch_idx), + k_tile, + producer_state_nk.index, + scale_smem_format=self.direct_scale_smem_format, + cache_policy=tma_cache_policy, + ) + pipe_nk.producer_commit(producer_state_nk) + producer_state_nk.advance() + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + if const_expr(not self.direct_skip_split_tma_tail): + pipe_nk.producer_tail(producer_state_nk) + + elif warp_idx == self.ab_load_warp_id: + if const_expr( + (not self.direct_cute_static_scheduler) + and self.use_pdl + and tile_sched_params.persistence_mode == PersistenceMode.CLC + ): + cute.arch.griddepcontrol_wait() + producer_state = make_pipeline_state(pipeline.PipelineUserType.Producer, self.ab_stage) + if const_expr(self.direct_tile_scheduler): + if const_expr(self.direct_cute_static_scheduler): + tile_scheduler = _sm120.make_mxf4nvf4_static_tile_scheduler( + tile_sched_params, + cute.arch.block_idx(), + cute.arch.grid_dim(), + ) + else: + tile_scheduler = TileSchedulerCls(is_scheduler_warp=True) + work_tile = tile_scheduler.initial_work_tile_info() + else: + work_idx = cute.arch.block_idx()[2] + while ( + work_tile.is_valid_tile + if const_expr(self.direct_tile_scheduler) + else work_idx < total_work + ): + if const_expr(self.direct_tile_scheduler): + tile_coord_mnkl = work_tile.tile_idx + cta_m = tile_coord_mnkl[0] + cta_n = tile_coord_mnkl[1] + batch_idx = ( + tile_coord_mnkl[2] + if const_expr(self.direct_cute_static_scheduler) + else tile_coord_mnkl[3] + ) + else: + batch_idx = work_idx // mn_tile_count + mn_tile = work_idx - batch_idx * mn_tile_count + cta_n = mn_tile // m_tile_count + cta_m = mn_tile - cta_n * m_tile_count + if const_expr((not self.direct_tile_scheduler) and k_tile_count == 1): + pipe.producer_acquire( + producer_state, + pipe.producer_try_acquire(producer_state), + ) + _sm120.issue_mxf4nvf4_native_tma_stage_for_tile( + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_sfa, + tma_tensor_sfa, + tma_atom_sfb, + tma_tensor_sfb, + sA_consumer, + sB_consumer, + sSFA, + sSFB, + pipe.producer_get_barrier(producer_state), + (cta_m, cta_n, batch_idx), + 0, + producer_state.index, + scale_smem_format=self.direct_scale_smem_format, + cache_policy=tma_cache_policy, + ) + pipe.producer_commit(producer_state) + producer_state.advance() + else: + for k_tile in cutlass.range(k_tile_count, unroll=self.direct_producer_unroll): + if const_expr(self.direct_elected_tma): + with cute.arch.elect_one(): + pipe.producer_acquire_already_elected( + producer_state, + pipe.producer_try_acquire(producer_state), + ) + if const_expr( + (not self.direct_tile_scheduler) + or self.direct_scheduler_local_tma + ): + _sm120.issue_mxf4nvf4_native_tma_stage_for_tile( + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_sfa, + tma_tensor_sfa, + tma_atom_sfb, + tma_tensor_sfb, + sA_consumer, + sB_consumer, + sSFA, + sSFB, + pipe.producer_get_barrier(producer_state), + (cta_m, cta_n, batch_idx), + k_tile, + producer_state.index, + already_elected=True, + scale_smem_format=self.direct_scale_smem_format, + cache_policy=tma_cache_policy, + ) + else: + _sm120.issue_mxf4nvf4_partitioned_native_tma_stage_for_tile( + tma_atom_a, + tAsA, + tAgA, + tma_atom_b, + tBsB, + tBgB, + tma_atom_sfa, + tSFAs, + tSFAg, + tma_atom_sfb, + tSFBs, + tSFBg, + pipe.producer_get_barrier(producer_state), + (cta_m, cta_n, batch_idx), + k_tile, + producer_state.index, + already_elected=True, + scale_smem_format=self.direct_scale_smem_format, + cache_policy=tma_cache_policy, + ) + pipe.producer_commit(producer_state) + else: + pipe.producer_acquire( + producer_state, + pipe.producer_try_acquire(producer_state), + ) + if const_expr( + self.direct_tile_scheduler and not self.direct_scheduler_local_tma + ): + _sm120.issue_mxf4nvf4_partitioned_native_tma_stage_for_tile( + tma_atom_a, + tAsA, + tAgA, + tma_atom_b, + tBsB, + tBgB, + tma_atom_sfa, + tSFAs, + tSFAg, + tma_atom_sfb, + tSFBs, + tSFBg, + pipe.producer_get_barrier(producer_state), + (cta_m, cta_n, batch_idx), + k_tile, + producer_state.index, + cache_policy=tma_cache_policy, + scale_smem_format=self.direct_scale_smem_format, + ) + else: + _sm120.issue_mxf4nvf4_native_tma_stage_for_tile( + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_sfa, + tma_tensor_sfa, + tma_atom_sfb, + tma_tensor_sfb, + sA_consumer, + sB_consumer, + sSFA, + sSFB, + pipe.producer_get_barrier(producer_state), + (cta_m, cta_n, batch_idx), + k_tile, + producer_state.index, + scale_smem_format=self.direct_scale_smem_format, + cache_policy=tma_cache_policy, + ) + pipe.producer_commit(producer_state) + producer_state.advance() + if const_expr(self.direct_tile_scheduler): + if const_expr(self.direct_cute_static_scheduler): + tile_scheduler.advance_to_next_work() + else: + tile_scheduler.advance_to_next_work(is_scheduler_warp=True) + work_tile = tile_scheduler.get_current_work() + else: + work_idx += work_stride + if const_expr(self.direct_tile_scheduler): + if const_expr( + (not self.direct_cute_static_scheduler) + and (not self.direct_skip_scheduler_tail) + ): + tile_scheduler.producer_tail() + + if const_expr(tile_extent_m == 128 and tile_extent_n == 128): + if self.mma_warp_id_start <= warp_idx and warp_idx < self.mma_warp_id_end: + warp_group_idx = cute.arch.make_warp_uniform( + tidx // self.num_threads_per_warp_group + ) + pingpong_group_idx = warp_group_idx + tidx_mma = tidx + warp_idx_mma = warp_idx - self.mma_warp_id_start + if const_expr(self.pingpong): + pingpong_group_idx = warp_group_idx - self.mma_warp_id_start // 4 + if const_expr(self.direct_pingpong_split_tiles): + tidx_mma = tidx % self.num_threads_per_warp_group + warp_idx_mma = warp_idx % 4 + if const_expr(self.direct_pingpong_barriers) and pingpong_group_idx == 0: + self.pingpong_barrier_arrive(warp_group_idx=0, stage="mma") + self.pingpong_barrier_arrive(warp_group_idx=0, stage="epi") + elif const_expr(self.pingpong): + tidx_mma = tidx % self.num_threads_per_warp_group + warp_idx_mma = warp_idx % 4 + if ( + const_expr(self.direct_full_tma_pipeline and self.direct_pingpong_barriers) + and pingpong_group_idx == 0 + ): + self.pingpong_barrier_arrive(warp_group_idx=0, stage="mma") + self.pingpong_barrier_arrive(warp_group_idx=0, stage="epi") + if const_expr(self.direct_setmaxregister): + _sm120.setmaxregister_mxf4nvf4_consumer(self.num_regs_mma) + if const_expr(self.direct_split_tma_pipelines): + consumer_state_mk = make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.ab_stage + ) + consumer_state_nk = make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.ab_stage + ) + else: + consumer_state = make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.ab_stage + ) + if const_expr(self.direct_tile_scheduler): + if const_expr(self.direct_cute_static_scheduler): + tile_scheduler = _sm120.make_mxf4nvf4_static_tile_scheduler( + tile_sched_params, + cute.arch.block_idx(), + cute.arch.grid_dim(), + ) + else: + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + if const_expr(self.direct_pingpong_split_tiles): + if pingpong_group_idx == 1: + consumer_state_mk.advance_iters(k_tile_count) + consumer_state_nk.advance_iters(k_tile_count) + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + elif const_expr(self.direct_full_tma_pipeline): + if pingpong_group_idx == 1: + consumer_state.advance_iters(k_tile_count) + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + else: + work_idx = cute.arch.block_idx()[2] + if const_expr(self.direct_full_tma_pipeline): + if pingpong_group_idx == 1: + consumer_state.advance_iters(k_tile_count) + work_idx += work_stride + while ( + work_tile.is_valid_tile + if const_expr(self.direct_tile_scheduler) + else work_idx < total_work + ): + if const_expr(self.direct_tile_scheduler): + tile_coord_mnkl = work_tile.tile_idx + cta_m = tile_coord_mnkl[0] + cta_n = tile_coord_mnkl[1] + batch_idx = ( + tile_coord_mnkl[2] + if const_expr(self.direct_cute_static_scheduler) + else tile_coord_mnkl[3] + ) + else: + batch_idx = work_idx // mn_tile_count + mn_tile = work_idx - batch_idx * mn_tile_count + cta_n = mn_tile // m_tile_count + cta_m = mn_tile - cta_n * m_tile_count + if const_expr( + (self.direct_pingpong_split_tiles or self.direct_full_tma_pipeline) + and self.direct_pingpong_barriers + ): + if const_expr(self.direct_try_wait_before_pingpong_barrier): + initial_try_wait_mk = pipe_mk.consumer_try_wait(consumer_state_mk) + initial_try_wait_nk = initial_try_wait_mk + if const_expr(not self.direct_join_split_tma_barrier): + initial_try_wait_nk = pipe_nk.consumer_try_wait(consumer_state_nk) + self.pingpong_barrier_sync(pingpong_group_idx, stage="mma") + tDsD_work, tDgD_work = self._partition_blockscaled_epilogue_tma_store( + tma_atom_d, + tma_tensor_d, + sD_epi, + cta_m, + cta_n, + batch_idx, + ) + if const_expr(self.direct_split_tma_pipelines): + ( + consumer_state_mk, + consumer_state_nk, + ) = self._blockscaled_compute_store_full_tile_k_loop_split_tma( + tiled_mma, + pipe_mk, + pipe_nk, + consumer_state_mk, + consumer_state_nk, + sA_consumer, + sB_consumer, + sSFA, + sSFB, + epilogue_params, + sD_epi, + tma_atom_d, + tDsD_work, + tDgD_work, + gD, + (cta_m, cta_n, batch_idx), + epi_store_pipeline, + tidx_mma, + warp_idx_mma, + pingpong_group_idx, + lane_idx, + k_tile_count, + tile_extent_m, + tile_extent_n, + tile_extent_k, + ( + initial_try_wait_mk + if const_expr(self.direct_try_wait_before_pingpong_barrier) + else None + ), + ( + initial_try_wait_nk + if const_expr(self.direct_try_wait_before_pingpong_barrier) + else None + ), + ) + else: + consumer_state = self._blockscaled_compute_store_full_tile_k_loop( + tiled_mma, + pipe, + consumer_state, + sA_consumer, + sB_consumer, + sSFA, + sSFB, + epilogue_params, + sD_epi, + tma_atom_d, + tDsD_work, + tDgD_work, + gD, + (cta_m, cta_n, batch_idx), + epi_store_pipeline, + tidx_mma, + warp_idx_mma, + pingpong_group_idx, + lane_idx, + k_tile_count, + tile_extent_m, + tile_extent_n, + tile_extent_k, + ) + if const_expr(self.direct_tile_scheduler): + if const_expr(self.direct_pingpong_split_tiles): + consumer_state_mk.advance_iters(k_tile_count) + consumer_state_nk.advance_iters(k_tile_count) + tile_scheduler.advance_to_next_work(advance_count=self.mma_warp_groups) + elif const_expr(self.direct_full_tma_pipeline): + consumer_state.advance_iters(k_tile_count) + tile_scheduler.advance_to_next_work(advance_count=self.mma_warp_groups) + else: + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + else: + if const_expr(self.direct_full_tma_pipeline): + consumer_state.advance_iters(k_tile_count) + work_idx += work_stride * self.mma_warp_groups + else: + work_idx += work_stride + if const_expr( + (not self.direct_cute_static_scheduler) + and self.use_pdl + and tile_sched_params.persistence_mode == PersistenceMode.CLC + ): + cute.arch.griddepcontrol_launch_dependents() + else: + tDsD, tDgD = self._partition_blockscaled_epilogue_tma_store( + tma_atom_d, + tma_tensor_d, + sD_epi, + cta_m, + cta_n, + batch_idx, + ) + for m_atom in cutlass.range_constexpr(m_atoms): + if warp_idx == m_atom: + self._blockscaled_compute_store_m_atom_k_loop( + tiled_mma, + pipe, + sA_consumer, + sB_consumer, + sSFA, + sSFB, + epilogue_params, + sD_epi, + tma_atom_d, + tDsD, + tDgD, + epi_store_pipeline, + m_atom, + lane_idx, + k_tile_count, + n_atoms, + tile_extent_m, + tile_extent_n, + tile_extent_k, + 0, + 1, + m_atom == 0, + ) + + @cute.jit + def _blockscaled_store_full_tile_direct_global( + self, + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + gD: cute.Tensor, + tile_mnl, + tidx: Int32, + warp_group_idx: Int32, + ) -> None: + if const_expr( + (self.direct_full_tma_pipeline or self.direct_pingpong_split_tiles) + and self.direct_pingpong_barriers + ): + self.pingpong_barrier_arrive(1 - warp_group_idx, stage="mma") + self.pingpong_barrier_sync(warp_group_idx, stage="epi") + _sm120.store_mxf4nvf4_accumulator_fragment_D_for_tiled_mma_tile( + tiled_mma, + acc, + gD, + tile_mnl, + tidx, + ) + if const_expr( + (self.direct_full_tma_pipeline or self.direct_pingpong_split_tiles) + and self.direct_pingpong_barriers + ): + self.pingpong_barrier_arrive(1 - warp_group_idx, stage="epi") + + @cute.jit + def _blockscaled_compute_store_full_tile_k_loop( + self, + tiled_mma: cute.TiledMma, + pipe, + consumer_state: cutlass.pipeline.PipelineState, + sA_consumer: cute.Tensor, + sB_consumer: cute.Tensor, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + epilogue_params, + sD_epi: cute.Tensor, + tma_atom_d: cute.CopyAtom, + tDsD: cute.Tensor, + tDgD: cute.Tensor, + gD: Optional[cute.Tensor], + tile_mnl, + epi_store_pipeline: cutlass.pipeline.PipelineAsync, + tidx: Int32, + warp_idx: Int32, + warp_group_idx: Int32, + lane_idx: Int32, + k_tile_count: Int32, + tile_extent_m: cutlass.Constexpr[int], + tile_extent_n: cutlass.Constexpr[int], + tile_extent_k: cutlass.Constexpr[int], + ) -> cutlass.pipeline.PipelineState: + if const_expr(self.direct_pipelined_consumer): + return self._blockscaled_compute_store_full_tile_k_loop_pipelined( + tiled_mma, + pipe, + consumer_state, + sA_consumer, + sB_consumer, + sSFA, + sSFB, + epilogue_params, + sD_epi, + tma_atom_d, + tDsD, + tDgD, + gD, + tile_mnl, + epi_store_pipeline, + tidx, + warp_idx, + warp_group_idx, + lane_idx, + k_tile_count, + tile_extent_m, + tile_extent_n, + tile_extent_k, + ) + + a_frag = cute.make_rmem_tensor( + tiled_mma.partition_shape_A((tile_extent_m, tile_extent_k)), + cutlass.Float4E2M1FN, + ) + b_frag = cute.make_rmem_tensor( + tiled_mma.partition_shape_B((tile_extent_n, tile_extent_k)), + cutlass.Float4E2M1FN, + ) + if const_expr(self.direct_ab_tma_layout == "unpack"): + a_copy_frag = cute.recast_tensor(a_frag, cutlass.Uint8) + b_copy_frag = cute.recast_tensor(b_frag, cutlass.Uint8) + else: + a_copy_frag = a_frag + b_copy_frag = b_frag + acc = cute.make_rmem_tensor( + tiled_mma.partition_shape_C((tile_extent_m, tile_extent_n)), + cutlass.Float32, + ) + acc.fill(0.0) + if const_expr(self.direct_ab_tma_layout == "unpack"): + copy_atom_a = cute.make_copy_atom( + warp.LdMatrix8x16x8bOp(transpose=False, num_matrices=4, unpack_bits=4), + cutlass.Uint8, + ) + copy_atom_b = cute.make_copy_atom( + warp.LdMatrix8x16x8bOp(transpose=False, num_matrices=4, unpack_bits=4), + cutlass.Uint8, + ) + else: + copy_atom_a = cute.make_copy_atom( + warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4), + cutlass.Float4E2M1FN, + ) + copy_atom_b = cute.make_copy_atom( + warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4), + cutlass.Float4E2M1FN, + ) + tiled_copy_a = cute.make_tiled_copy_A(copy_atom_a, tiled_mma) + tiled_copy_b = cute.make_tiled_copy_B(copy_atom_b, tiled_mma) + thr_copy_a = tiled_copy_a.get_slice(tidx) + thr_copy_b = tiled_copy_b.get_slice(tidx) + tCsA = thr_copy_a.partition_S(cute.as_position_independent_swizzle_tensor(sA_consumer)) + tCsB = thr_copy_b.partition_S(cute.as_position_independent_swizzle_tensor(sB_consumer)) + if const_expr(self.direct_ab_tma_layout == "unpack"): + tCsA = cute.make_tensor(tCsA.iterator.align(16), tCsA.layout) + tCsB = cute.make_tensor(tCsB.iterator.align(16), tCsB.layout) + tCrA = thr_copy_a.retile(a_copy_frag) + tCrB = thr_copy_b.retile(b_copy_frag) + + sfa_frag, sfb_frag = _sm120.make_mxf4nvf4_direct_tma_scale_fragments( + tiled_mma, + sSFA, + sSFB, + tidx, + tile_shape_mnk=(tile_extent_m, tile_extent_n, tile_extent_k), + sf_vec_size=16, + ) + scale_stage_thread_idx = None + scale_stage_thread_count = cute.arch.WARP_SIZE + scale_barrier_id = Int32(8) + for _k_tile in cutlass.range(k_tile_count, unroll=self.direct_consumer_unroll): + stage = consumer_state.index + pipe.consumer_wait(consumer_state, pipe.consumer_try_wait(consumer_state)) + if const_expr(self.direct_cute_dsl_helpers): + _sm120.issue_mxf4nvf4_direct_tma_consumer_group( + tiled_mma, + tiled_copy_a, + tiled_copy_b, + tCsA, + tCsB, + tCrA, + tCrB, + a_frag, + b_frag, + sSFA, + sSFB, + sfa_frag, + sfb_frag, + acc, + tidx, + stage, + major_extent_sfa=tile_extent_m, + major_extent_sfb=tile_extent_n, + tile_k=tile_extent_k, + sf_vec_size=16, + ) + elif const_expr(self.direct_kblock_pipeline): + cute.copy( + tiled_copy_a, + tCsA[(None, None, 0, stage)], + tCrA[(None, None, 0)], + ) + cute.copy( + tiled_copy_b, + tCsB[(None, None, 0, stage)], + tCrB[(None, None, 0)], + ) + self._blockscaled_load_scale_fragments( + tiled_mma, + sSFA, + sSFB, + sfa_frag, + sfb_frag, + tidx, + 0, + stage, + tile_extent_m, + tile_extent_n, + tile_extent_k, + ) + cute.copy( + tiled_copy_a, + tCsA[(None, None, 1, stage)], + tCrA[(None, None, 1)], + ) + cute.copy( + tiled_copy_b, + tCsB[(None, None, 1, stage)], + tCrB[(None, None, 1)], + ) + self._blockscaled_load_scale_fragments( + tiled_mma, + sSFA, + sSFB, + sfa_frag, + sfb_frag, + tidx, + 1, + stage, + tile_extent_m, + tile_extent_n, + tile_extent_k, + ) + for k_block_idx in cutlass.range_constexpr(2): + cute.gemm( + tiled_mma, + acc, + ( + a_frag[(None, None, k_block_idx)], + sfa_frag[(None, None, k_block_idx)], + ), + ( + b_frag[(None, None, k_block_idx)], + sfb_frag[(None, None, k_block_idx)], + ), + acc, + ) + else: + cute.copy( + tiled_copy_a, + tCsA[(None, None, 0, stage)], + tCrA[(None, None, 0)], + ) + cute.copy( + tiled_copy_b, + tCsB[(None, None, 0, stage)], + tCrB[(None, None, 0)], + ) + cute.copy( + tiled_copy_a, + tCsA[(None, None, 1, stage)], + tCrA[(None, None, 1)], + ) + cute.copy( + tiled_copy_b, + tCsB[(None, None, 1, stage)], + tCrB[(None, None, 1)], + ) + for k_block_idx in cutlass.range_constexpr(2): + self._blockscaled_load_scale_fragments( + tiled_mma, + sSFA, + sSFB, + sfa_frag, + sfb_frag, + tidx, + k_block_idx, + stage, + tile_extent_m, + tile_extent_n, + tile_extent_k, + ) + cute.gemm( + tiled_mma, + acc, + (a_frag, sfa_frag), + (b_frag, sfb_frag), + acc, + ) + if const_expr(self.direct_release_before_sync): + pipe.consumer_release(consumer_state) + consumer_state.advance() + if const_expr(self.direct_consumer_barrier): + self._direct_kblock_barrier() + if const_expr(self.direct_consumer_fence): + cute.arch.fence_view_async_shared() + if const_expr(self.direct_consumer_warp_sync): + cute.arch.sync_warp() + if const_expr(not self.direct_release_before_sync): + pipe.consumer_release(consumer_state) + consumer_state.advance() + + if const_expr(self.direct_global_store or self.direct_global_store_probe): + self._blockscaled_store_full_tile_direct_global( + tiled_mma, + acc, + gD, + tile_mnl, + tidx, + warp_group_idx, + ) + return consumer_state + + if const_expr(self.direct_full_tma_pipeline and self.direct_pingpong_barriers): + self.pingpong_barrier_arrive(1 - warp_group_idx, stage="mma") + self.pingpong_barrier_sync(warp_group_idx, stage="epi") + if const_expr(self.direct_delay_tma_store): + if warp_idx == 0: + epi_store_pipeline.producer_acquire() + for epi_tile_idx in cutlass.range_constexpr(self.direct_epi_tiles): + epi_m = epi_tile_idx // self.direct_epi_n_tiles + epi_n = epi_tile_idx - epi_m * self.direct_epi_n_tiles + epi_stage_idx = epi_tile_idx % self.epi_stage + prev_epi_tile_idx = epi_tile_idx - 1 + prev_epi_m = prev_epi_tile_idx // self.direct_epi_n_tiles + prev_epi_n = prev_epi_tile_idx - prev_epi_m * self.direct_epi_n_tiles + prev_epi_stage_idx = prev_epi_tile_idx % self.epi_stage + acc_epi_m = epi_m + if const_expr(self.direct_delay_tma_store and epi_tile_idx != 0): + if warp_idx == 0: + self._copy_blockscaled_epilogue_tma_store( + tma_atom_d, + tDsD, + tDgD, + prev_epi_m, + prev_epi_n, + prev_epi_stage_idx, + ) + epi_store_pipeline.producer_commit() + epi_store_pipeline.producer_acquire() + self.epilogue_barrier.arrive_and_wait() + elif const_expr(not self.direct_delay_tma_store) and warp_idx == 0: + epi_store_pipeline.producer_acquire() + if const_expr(not self.direct_epi_barrier_trim): + self.epilogue_barrier.arrive_and_wait() + if const_expr( + (not self.pingpong) + or self.direct_pingpong_split_tiles + or self.direct_full_tma_pipeline + ): + self._blockscaled_stage_full_tile_to_epi_smem( + tiled_mma, + epilogue_params, + sD_epi, + acc, + tidx, + acc_epi_m, + epi_n, + epi_stage_idx, + ) + elif const_expr(epi_tile_idx == 0): + if warp_group_idx == 0: + self._blockscaled_stage_full_tile_to_epi_smem( + tiled_mma, + epilogue_params, + sD_epi, + acc, + tidx, + acc_epi_m, + epi_n, + epi_stage_idx, + ) + else: + if warp_group_idx == 1: + self._blockscaled_stage_full_tile_to_epi_smem( + tiled_mma, + epilogue_params, + sD_epi, + acc, + tidx, + acc_epi_m, + epi_n, + epi_stage_idx, + ) + cute.arch.fence_view_async_shared() + self.epilogue_barrier.arrive_and_wait() + if const_expr(not self.direct_delay_tma_store): + if warp_idx == 0: + self._copy_blockscaled_epilogue_tma_store( + tma_atom_d, + tDsD, + tDgD, + epi_m, + epi_n, + 0, + ) + epi_store_pipeline.producer_commit() + self.epilogue_barrier.arrive_and_wait() + if const_expr(self.direct_delay_tma_store): + last_epi_tile_idx = self.direct_epi_tiles - 1 + last_epi_m = last_epi_tile_idx // self.direct_epi_n_tiles + last_epi_n = last_epi_tile_idx - last_epi_m * self.direct_epi_n_tiles + last_epi_stage_idx = last_epi_tile_idx % self.epi_stage + if warp_idx == 0: + self._copy_blockscaled_epilogue_tma_store( + tma_atom_d, + tDsD, + tDgD, + last_epi_m, + last_epi_n, + last_epi_stage_idx, + ) + epi_store_pipeline.producer_commit() + self.epilogue_barrier.arrive_and_wait() + if warp_idx == 0: + epi_store_pipeline.producer_tail() + if const_expr(self.direct_full_tma_pipeline and self.direct_pingpong_barriers): + self.pingpong_barrier_arrive(1 - warp_group_idx, stage="epi") + return consumer_state + + @cute.jit + def _blockscaled_compute_store_full_tile_k_loop_pipelined( + self, + tiled_mma: cute.TiledMma, + pipe, + consumer_state: cutlass.pipeline.PipelineState, + sA_consumer: cute.Tensor, + sB_consumer: cute.Tensor, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + epilogue_params, + sD_epi: cute.Tensor, + tma_atom_d: cute.CopyAtom, + tDsD: cute.Tensor, + tDgD: cute.Tensor, + gD: Optional[cute.Tensor], + tile_mnl, + epi_store_pipeline: cutlass.pipeline.PipelineAsync, + tidx: Int32, + warp_idx: Int32, + warp_group_idx: Int32, + lane_idx: Int32, + k_tile_count: Int32, + tile_extent_m: cutlass.Constexpr[int], + tile_extent_n: cutlass.Constexpr[int], + tile_extent_k: cutlass.Constexpr[int], + ) -> cutlass.pipeline.PipelineState: + a_frag = cute.make_rmem_tensor( + tiled_mma.partition_shape_A((tile_extent_m, tile_extent_k)), + cutlass.Float4E2M1FN, + ) + b_frag = cute.make_rmem_tensor( + tiled_mma.partition_shape_B((tile_extent_n, tile_extent_k)), + cutlass.Float4E2M1FN, + ) + if const_expr(self.direct_ab_tma_layout == "unpack"): + a_copy_frag = cute.recast_tensor(a_frag, cutlass.Uint8) + b_copy_frag = cute.recast_tensor(b_frag, cutlass.Uint8) + else: + a_copy_frag = a_frag + b_copy_frag = b_frag + acc = cute.make_rmem_tensor( + tiled_mma.partition_shape_C((tile_extent_m, tile_extent_n)), + cutlass.Float32, + ) + acc.fill(0.0) + + if const_expr(self.direct_ab_tma_layout == "unpack"): + copy_atom_a = cute.make_copy_atom( + warp.LdMatrix8x16x8bOp(transpose=False, num_matrices=4, unpack_bits=4), + cutlass.Uint8, + ) + copy_atom_b = cute.make_copy_atom( + warp.LdMatrix8x16x8bOp(transpose=False, num_matrices=4, unpack_bits=4), + cutlass.Uint8, + ) + else: + copy_atom_a = cute.make_copy_atom( + warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4), + cutlass.Float4E2M1FN, + ) + copy_atom_b = cute.make_copy_atom( + warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4), + cutlass.Float4E2M1FN, + ) + tiled_copy_a = cute.make_tiled_copy_A(copy_atom_a, tiled_mma) + tiled_copy_b = cute.make_tiled_copy_B(copy_atom_b, tiled_mma) + thr_copy_a = tiled_copy_a.get_slice(tidx) + thr_copy_b = tiled_copy_b.get_slice(tidx) + tCsA = thr_copy_a.partition_S(cute.as_position_independent_swizzle_tensor(sA_consumer)) + tCsB = thr_copy_b.partition_S(cute.as_position_independent_swizzle_tensor(sB_consumer)) + if const_expr(self.direct_ab_tma_layout == "unpack"): + tCsA = cute.make_tensor(tCsA.iterator.align(16), tCsA.layout) + tCsB = cute.make_tensor(tCsB.iterator.align(16), tCsB.layout) + tCrA = thr_copy_a.retile(a_copy_frag) + tCrB = thr_copy_b.retile(b_copy_frag) + + sfa_frag, sfb_frag = _sm120.make_mxf4nvf4_direct_tma_scale_fragments( + tiled_mma, + sSFA, + sSFB, + tidx, + tile_shape_mnk=(tile_extent_m, tile_extent_n, tile_extent_k), + sf_vec_size=16, + ) + scale_stage_thread_idx = None + scale_stage_thread_count = cute.arch.WARP_SIZE + scale_barrier_id = Int32(8) + + stage = consumer_state.index + pipe.consumer_wait(consumer_state, pipe.consumer_try_wait(consumer_state)) + + if const_expr(self.direct_bgroup_pipeline): + self._blockscaled_load_direct_k_block_a_scale_fragments( + tiled_mma, + tiled_copy_a, + tCsA, + tCrA, + sSFA, + sSFB, + sfa_frag, + sfb_frag, + tidx, + 0, + stage, + tile_extent_m, + tile_extent_n, + tile_extent_k, + scale_stage_thread_idx, + scale_stage_thread_count, + scale_barrier_id, + ) + for b_group_idx in cutlass.range_constexpr(2): + self._blockscaled_load_direct_k_block_b_group_fragment( + tiled_copy_b, + tCsB, + tCrB, + 0, + b_group_idx, + stage, + ) + _sm120.load_mxf4nvf4_direct_tma_k_block_fragments( + tiled_mma, + tiled_copy_a, + tiled_copy_b, + tCsA, + tCsB, + tCrA, + tCrB, + sSFA, + sSFB, + sfa_frag, + sfb_frag, + tidx, + 0, + stage_idx=stage, + major_extent_sfa=tile_extent_m, + major_extent_sfb=tile_extent_n, + tile_k=tile_extent_k, + sf_vec_size=16, + scale_first=self.direct_scale_prefetch_first, + ) + + for _k_tile in cutlass.range(k_tile_count - 1, unroll=1): + for k_block_idx in cutlass.range_constexpr(2): + k_block_next = 0 if const_expr(k_block_idx == 1) else 1 + if const_expr(k_block_idx == 1): + if const_expr(self.direct_kblock_barrier): + self._direct_kblock_barrier() + pipe.consumer_release(consumer_state) + consumer_state.advance() + stage = consumer_state.index + pipe.consumer_wait(consumer_state, pipe.consumer_try_wait(consumer_state)) + if const_expr(self.direct_bgroup_pipeline): + self._blockscaled_load_direct_k_block_a_scale_fragments( + tiled_mma, + tiled_copy_a, + tCsA, + tCrA, + sSFA, + sSFB, + sfa_frag, + sfb_frag, + tidx, + k_block_next, + stage, + tile_extent_m, + tile_extent_n, + tile_extent_k, + scale_stage_thread_idx, + scale_stage_thread_count, + scale_barrier_id, + ) + for b_group_idx in cutlass.range_constexpr(2): + self._blockscaled_load_direct_k_block_b_group_fragment( + tiled_copy_b, + tCsB, + tCrB, + k_block_next, + b_group_idx, + stage, + ) + self._blockscaled_direct_gemm_k_block_bgroup_pipeline( + tiled_mma, + tiled_copy_b, + acc, + tCsB, + tCrB, + a_frag, + b_frag, + sfa_frag, + sfb_frag, + k_block_idx, + stage, + ) + _sm120.load_mxf4nvf4_direct_tma_k_block_fragments( + tiled_mma, + tiled_copy_a, + tiled_copy_b, + tCsA, + tCsB, + tCrA, + tCrB, + sSFA, + sSFB, + sfa_frag, + sfb_frag, + tidx, + k_block_next, + stage_idx=stage, + major_extent_sfa=tile_extent_m, + major_extent_sfb=tile_extent_n, + tile_k=tile_extent_k, + sf_vec_size=16, + scale_first=self.direct_scale_prefetch_first, + ) + self._blockscaled_direct_gemm_k_block( + tiled_mma, + acc, + a_frag, + b_frag, + sfa_frag, + sfb_frag, + k_block_idx, + ) + for k_block_idx in cutlass.range_constexpr(2): + k_block_next = 0 if const_expr(k_block_idx == 1) else 1 + if const_expr(k_block_idx == 1): + if const_expr(self.direct_kblock_barrier): + self._direct_kblock_barrier() + pipe.consumer_release(consumer_state) + consumer_state.advance() + if const_expr(k_block_idx == 0): + if const_expr(self.direct_bgroup_pipeline): + self._blockscaled_load_direct_k_block_a_scale_fragments( + tiled_mma, + tiled_copy_a, + tCsA, + tCrA, + sSFA, + sSFB, + sfa_frag, + sfb_frag, + tidx, + k_block_next, + stage, + tile_extent_m, + tile_extent_n, + tile_extent_k, + ) + for b_group_idx in cutlass.range_constexpr(2): + self._blockscaled_load_direct_k_block_b_group_fragment( + tiled_copy_b, + tCsB, + tCrB, + k_block_next, + b_group_idx, + stage, + ) + else: + _sm120.load_mxf4nvf4_direct_tma_k_block_fragments( + tiled_mma, + tiled_copy_a, + tiled_copy_b, + tCsA, + tCsB, + tCrA, + tCrB, + sSFA, + sSFB, + sfa_frag, + sfb_frag, + tidx, + k_block_next, + stage_idx=stage, + major_extent_sfa=tile_extent_m, + major_extent_sfb=tile_extent_n, + tile_k=tile_extent_k, + sf_vec_size=16, + scale_first=self.direct_scale_prefetch_first, + ) + if const_expr(self.direct_bgroup_pipeline): + self._blockscaled_direct_gemm_k_block_bgroup_pipeline( + tiled_mma, + tiled_copy_b, + acc, + tCsB, + tCrB, + a_frag, + b_frag, + sfa_frag, + sfb_frag, + k_block_idx, + stage, + ) + else: + self._blockscaled_direct_gemm_k_block( + tiled_mma, + acc, + a_frag, + b_frag, + sfa_frag, + sfb_frag, + k_block_idx, + ) + + if const_expr(self.direct_global_store or self.direct_global_store_probe): + self._blockscaled_store_full_tile_direct_global( + tiled_mma, + acc, + gD, + tile_mnl, + tidx, + warp_group_idx, + ) + return consumer_state + + if const_expr(self.direct_full_tma_pipeline and self.direct_pingpong_barriers): + self.pingpong_barrier_arrive(1 - warp_group_idx, stage="mma") + self.pingpong_barrier_sync(warp_group_idx, stage="epi") + if const_expr(self.direct_delay_tma_store): + if warp_idx == 0: + epi_store_pipeline.producer_acquire() + for epi_tile_idx in cutlass.range_constexpr(self.direct_epi_tiles): + epi_m = epi_tile_idx // self.direct_epi_n_tiles + epi_n = epi_tile_idx - epi_m * self.direct_epi_n_tiles + epi_stage_idx = epi_tile_idx % self.epi_stage + prev_epi_tile_idx = epi_tile_idx - 1 + prev_epi_m = prev_epi_tile_idx // self.direct_epi_n_tiles + prev_epi_n = prev_epi_tile_idx - prev_epi_m * self.direct_epi_n_tiles + prev_epi_stage_idx = prev_epi_tile_idx % self.epi_stage + acc_epi_m = epi_m + if const_expr(self.direct_delay_tma_store and epi_tile_idx != 0): + if warp_idx == 0: + self._copy_blockscaled_epilogue_tma_store( + tma_atom_d, + tDsD, + tDgD, + prev_epi_m, + prev_epi_n, + prev_epi_stage_idx, + ) + epi_store_pipeline.producer_commit() + epi_store_pipeline.producer_acquire() + self.epilogue_barrier.arrive_and_wait() + elif const_expr(not self.direct_delay_tma_store) and warp_idx == 0: + epi_store_pipeline.producer_acquire() + if const_expr(not self.direct_epi_barrier_trim): + self.epilogue_barrier.arrive_and_wait() + if const_expr( + (not self.pingpong) + or self.direct_pingpong_split_tiles + or self.direct_full_tma_pipeline + ): + self._blockscaled_stage_full_tile_to_epi_smem( + tiled_mma, + epilogue_params, + sD_epi, + acc, + tidx, + acc_epi_m, + epi_n, + epi_stage_idx, + ) + elif const_expr(epi_tile_idx == 0): + if warp_group_idx == 0: + self._blockscaled_stage_full_tile_to_epi_smem( + tiled_mma, + epilogue_params, + sD_epi, + acc, + tidx, + acc_epi_m, + epi_n, + epi_stage_idx, + ) + else: + if warp_group_idx == 1: + self._blockscaled_stage_full_tile_to_epi_smem( + tiled_mma, + epilogue_params, + sD_epi, + acc, + tidx, + acc_epi_m, + epi_n, + epi_stage_idx, + ) + cute.arch.fence_view_async_shared() + self.epilogue_barrier.arrive_and_wait() + if const_expr(not self.direct_delay_tma_store): + if warp_idx == 0: + self._copy_blockscaled_epilogue_tma_store( + tma_atom_d, + tDsD, + tDgD, + epi_m, + epi_n, + 0, + ) + epi_store_pipeline.producer_commit() + self.epilogue_barrier.arrive_and_wait() + if const_expr(self.direct_delay_tma_store): + last_epi_tile_idx = self.direct_epi_tiles - 1 + last_epi_m = last_epi_tile_idx // self.direct_epi_n_tiles + last_epi_n = last_epi_tile_idx - last_epi_m * self.direct_epi_n_tiles + last_epi_stage_idx = last_epi_tile_idx % self.epi_stage + if warp_idx == 0: + self._copy_blockscaled_epilogue_tma_store( + tma_atom_d, + tDsD, + tDgD, + last_epi_m, + last_epi_n, + last_epi_stage_idx, + ) + epi_store_pipeline.producer_commit() + self.epilogue_barrier.arrive_and_wait() + if warp_idx == 0: + epi_store_pipeline.producer_tail() + if const_expr(self.direct_full_tma_pipeline and self.direct_pingpong_barriers): + self.pingpong_barrier_arrive(1 - warp_group_idx, stage="epi") + return consumer_state + + @cute.jit + def _split_tma_consumer_wait( + self, + pipe_mk, + pipe_nk, + consumer_state_mk: cutlass.pipeline.PipelineState, + consumer_state_nk: cutlass.pipeline.PipelineState, + ) -> None: + _sm120.mxf4nvf4_split_tma_consumer_wait( + pipe_mk, + pipe_nk, + consumer_state_mk, + consumer_state_nk, + join_split_tma_barrier=self.direct_join_split_tma_barrier, + ) + + @cute.jit + def _split_tma_consumer_wait_with_tokens( + self, + pipe_mk, + pipe_nk, + consumer_state_mk: cutlass.pipeline.PipelineState, + consumer_state_nk: cutlass.pipeline.PipelineState, + try_wait_token_mk: Boolean, + try_wait_token_nk: Boolean, + ) -> None: + _sm120.mxf4nvf4_split_tma_consumer_wait( + pipe_mk, + pipe_nk, + consumer_state_mk, + consumer_state_nk, + join_split_tma_barrier=self.direct_join_split_tma_barrier, + try_wait_token_mk=try_wait_token_mk, + try_wait_token_nk=try_wait_token_nk, + ) + + @cute.jit + def _split_tma_consumer_release( + self, + pipe_mk, + pipe_nk, + consumer_state_mk: cutlass.pipeline.PipelineState, + consumer_state_nk: cutlass.pipeline.PipelineState, + ) -> None: + _sm120.mxf4nvf4_split_tma_consumer_release( + pipe_mk, + pipe_nk, + consumer_state_mk, + consumer_state_nk, + join_split_tma_barrier=self.direct_join_split_tma_barrier, + ) + + @cute.jit + def _direct_kblock_barrier(self) -> None: + _sm120.mxf4nvf4_mma_warpgroup_barrier_sync( + number_of_threads=self.num_mma_warps * cute.arch.WARP_SIZE, + ) + + def pingpong_barrier_sync(self, warp_group_idx: Int32, stage: str): + _sm120.mxf4nvf4_pingpong_barrier_sync(warp_group_idx, stage) + + def pingpong_barrier_arrive(self, warp_group_idx: Int32, stage: str): + _sm120.mxf4nvf4_pingpong_barrier_arrive(warp_group_idx, stage) + + @cute.jit + def _partition_blockscaled_epilogue_tma_store( + self, + tma_atom_d: cute.CopyAtom, + tma_tensor_d: cute.Tensor, + sD_epi: cute.Tensor, + cta_m: Int32, + cta_n: Int32, + batch_idx: Int32, + ): + return _sm120.partition_mxf4nvf4_epilogue_tma_store( + tma_atom_d, + tma_tensor_d, + sD_epi, + (cta_m, cta_n, batch_idx), + cta_tiler=self.cta_tile_shape_mnk, + epi_tile=self.epi_tile, + ) + + @cute.jit + def _copy_blockscaled_epilogue_tma_store( + self, + tma_atom_d: cute.CopyAtom, + tDsD: cute.Tensor, + tDgD: cute.Tensor, + epi_m: cutlass.Constexpr[int], + epi_n: cutlass.Constexpr[int], + epi_stage_idx: cutlass.Constexpr[int], + ) -> None: + if const_expr(self.direct_epi_tma_rank3): + # The rank-3 TMA partition keeps the epilogue stage in the source + # tensor basis. The copy source coordinate is the TMA vector mode. + cute.copy(tma_atom_d, tDsD[None, epi_stage_idx], tDgD[None, (epi_m, epi_n)]) + else: + _sm120.issue_mxf4nvf4_epilogue_tma_store( + tma_atom_d, + tDsD, + tDgD, + epi_m=epi_m, + epi_n=epi_n, + stage_idx=epi_stage_idx, + ) + + @cute.jit + def _blockscaled_compute_store_full_tile_k_loop_split_tma( + self, + tiled_mma: cute.TiledMma, + pipe_mk, + pipe_nk, + consumer_state_mk: cutlass.pipeline.PipelineState, + consumer_state_nk: cutlass.pipeline.PipelineState, + sA_consumer: cute.Tensor, + sB_consumer: cute.Tensor, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + epilogue_params, + sD_epi: cute.Tensor, + tma_atom_d: cute.CopyAtom, + tDsD: cute.Tensor, + tDgD: cute.Tensor, + gD: Optional[cute.Tensor], + tile_mnl, + epi_store_pipeline: cutlass.pipeline.PipelineAsync, + tidx: Int32, + warp_idx: Int32, + warp_group_idx: Int32, + lane_idx: Int32, + k_tile_count: Int32, + tile_extent_m: cutlass.Constexpr[int], + tile_extent_n: cutlass.Constexpr[int], + tile_extent_k: cutlass.Constexpr[int], + initial_try_wait_mk: Optional[Boolean] = None, + initial_try_wait_nk: Optional[Boolean] = None, + ) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]: + if const_expr(self.direct_fragment_contract == "thread_mma"): + thr_mma_for_frag = tiled_mma.get_slice(tidx) + tCsA_mma = thr_mma_for_frag.partition_A(sA_consumer) + tCsB_mma = thr_mma_for_frag.partition_B(sB_consumer) + a_frag = tiled_mma.make_fragment_A(tCsA_mma[(None, None, None, 0)]) + b_frag = tiled_mma.make_fragment_B(tCsB_mma[(None, None, None, 0)]) + else: + a_frag = cute.make_rmem_tensor( + tiled_mma.partition_shape_A((tile_extent_m, tile_extent_k)), + cutlass.Float4E2M1FN, + ) + b_frag = cute.make_rmem_tensor( + tiled_mma.partition_shape_B((tile_extent_n, tile_extent_k)), + cutlass.Float4E2M1FN, + ) + if const_expr(self.direct_ab_tma_layout == "unpack"): + a_copy_frag = cute.recast_tensor(a_frag, cutlass.Uint8) + b_copy_frag = cute.recast_tensor(b_frag, cutlass.Uint8) + else: + a_copy_frag = a_frag + b_copy_frag = b_frag + acc = cute.make_rmem_tensor( + tiled_mma.partition_shape_C((tile_extent_m, tile_extent_n)), + cutlass.Float32, + ) + acc.fill(0.0) + if const_expr(self.direct_ab_tma_layout == "unpack"): + copy_atom_a = cute.make_copy_atom( + warp.LdMatrix8x16x8bOp(transpose=False, num_matrices=4, unpack_bits=4), + cutlass.Uint8, + ) + copy_atom_b = cute.make_copy_atom( + warp.LdMatrix8x16x8bOp(transpose=False, num_matrices=4, unpack_bits=4), + cutlass.Uint8, + ) + else: + copy_atom_a = cute.make_copy_atom( + warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4), + cutlass.Float4E2M1FN, + ) + copy_atom_b = cute.make_copy_atom( + warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4), + cutlass.Float4E2M1FN, + ) + tiled_copy_a = cute.make_tiled_copy_A(copy_atom_a, tiled_mma) + tiled_copy_b = cute.make_tiled_copy_B(copy_atom_b, tiled_mma) + thr_copy_a = tiled_copy_a.get_slice(tidx) + thr_copy_b = tiled_copy_b.get_slice(tidx) + tCsA = thr_copy_a.partition_S(cute.as_position_independent_swizzle_tensor(sA_consumer)) + tCsB = thr_copy_b.partition_S(cute.as_position_independent_swizzle_tensor(sB_consumer)) + if const_expr(self.direct_ab_tma_layout == "unpack"): + tCsA = cute.make_tensor(tCsA.iterator.align(16), tCsA.layout) + tCsB = cute.make_tensor(tCsB.iterator.align(16), tCsB.layout) + tCrA = thr_copy_a.retile(a_copy_frag) + tCrB = thr_copy_b.retile(b_copy_frag) + + sSFA_f8 = cute.recast_tensor(sSFA, cutlass.Float8E4M3FN) + sSFB_f8 = cute.recast_tensor(sSFB, cutlass.Float8E4M3FN) + sfa_frag, sfb_frag = _sm120.make_mxf4nvf4_direct_tma_scale_fragments( + tiled_mma, + sSFA, + sSFB, + tidx, + tile_shape_mnk=(tile_extent_m, tile_extent_n, tile_extent_k), + sf_vec_size=16, + ) + scale_stage_thread_idx = None + scale_stage_thread_count = cute.arch.WARP_SIZE + scale_barrier_id = Int32(8) + + stage = consumer_state_mk.index + if const_expr(self.direct_try_wait_before_pingpong_barrier): + self._split_tma_consumer_wait_with_tokens( + pipe_mk, + pipe_nk, + consumer_state_mk, + consumer_state_nk, + initial_try_wait_mk, + initial_try_wait_nk, + ) + else: + self._split_tma_consumer_wait(pipe_mk, pipe_nk, consumer_state_mk, consumer_state_nk) + + if const_expr(self.direct_scale_prefetch_first): + self._blockscaled_load_scale_fragments( + tiled_mma, + sSFA, + sSFB, + sfa_frag, + sfb_frag, + tidx, + 0, + stage, + tile_extent_m, + tile_extent_n, + tile_extent_k, + ) + cute.copy( + tiled_copy_a, + tCsA[(None, None, 0, stage)], + tCrA[(None, None, 0)], + ) + cute.copy( + tiled_copy_b, + tCsB[(None, None, 0, stage)], + tCrB[(None, None, 0)], + ) + self._shift_unpack_ab_fragments(tCrA, tCrB, 0) + if const_expr(not self.direct_scale_prefetch_first): + self._blockscaled_load_scale_fragments( + tiled_mma, + sSFA, + sSFB, + sfa_frag, + sfb_frag, + tidx, + 0, + stage, + tile_extent_m, + tile_extent_n, + tile_extent_k, + ) + + for _k_tile in cutlass.range(k_tile_count - 1, unroll=1): + for k_block_idx in cutlass.range_constexpr(2): + k_block_next = 0 if const_expr(k_block_idx == 1) else 1 + if const_expr(k_block_idx == 1): + if const_expr(self.direct_kblock_barrier): + self._direct_kblock_barrier() + self._split_tma_consumer_release( + pipe_mk, pipe_nk, consumer_state_mk, consumer_state_nk + ) + consumer_state_mk.advance() + consumer_state_nk.advance() + stage = consumer_state_mk.index + self._split_tma_consumer_wait( + pipe_mk, pipe_nk, consumer_state_mk, consumer_state_nk + ) + if const_expr(self.direct_scale_prefetch_first): + self._blockscaled_load_scale_fragments( + tiled_mma, + sSFA, + sSFB, + sfa_frag, + sfb_frag, + tidx, + k_block_next, + stage, + tile_extent_m, + tile_extent_n, + tile_extent_k, + ) + cute.copy( + tiled_copy_a, + tCsA[(None, None, k_block_next, stage)], + tCrA[(None, None, k_block_next)], + ) + cute.copy( + tiled_copy_b, + tCsB[(None, None, k_block_next, stage)], + tCrB[(None, None, k_block_next)], + ) + self._shift_unpack_ab_fragments(tCrA, tCrB, k_block_next) + if const_expr(not self.direct_scale_prefetch_first): + self._blockscaled_load_scale_fragments( + tiled_mma, + sSFA, + sSFB, + sfa_frag, + sfb_frag, + tidx, + k_block_next, + stage, + tile_extent_m, + tile_extent_n, + tile_extent_k, + ) + self._blockscaled_direct_gemm_k_block( + tiled_mma, + acc, + a_frag, + b_frag, + sfa_frag, + sfb_frag, + k_block_idx, + ) + for k_block_idx in cutlass.range_constexpr(2): + k_block_next = 0 if const_expr(k_block_idx == 1) else 1 + if const_expr(k_block_idx == 1): + if const_expr(self.direct_kblock_barrier): + self._direct_kblock_barrier() + self._split_tma_consumer_release( + pipe_mk, pipe_nk, consumer_state_mk, consumer_state_nk + ) + consumer_state_mk.advance() + consumer_state_nk.advance() + if const_expr(k_block_idx == 0): + if const_expr(self.direct_scale_prefetch_first): + self._blockscaled_load_scale_fragments( + tiled_mma, + sSFA, + sSFB, + sfa_frag, + sfb_frag, + tidx, + k_block_next, + stage, + tile_extent_m, + tile_extent_n, + tile_extent_k, + ) + cute.copy( + tiled_copy_a, + tCsA[(None, None, k_block_next, stage)], + tCrA[(None, None, k_block_next)], + ) + cute.copy( + tiled_copy_b, + tCsB[(None, None, k_block_next, stage)], + tCrB[(None, None, k_block_next)], + ) + self._shift_unpack_ab_fragments(tCrA, tCrB, k_block_next) + if const_expr(not self.direct_scale_prefetch_first): + self._blockscaled_load_scale_fragments( + tiled_mma, + sSFA, + sSFB, + sfa_frag, + sfb_frag, + tidx, + k_block_next, + stage, + tile_extent_m, + tile_extent_n, + tile_extent_k, + ) + self._blockscaled_direct_gemm_k_block( + tiled_mma, + acc, + a_frag, + b_frag, + sfa_frag, + sfb_frag, + k_block_idx, + ) + if const_expr(self.direct_global_store or self.direct_global_store_probe): + self._blockscaled_store_full_tile_direct_global( + tiled_mma, + acc, + gD, + tile_mnl, + tidx, + warp_group_idx, + ) + return consumer_state_mk, consumer_state_nk + if const_expr(self.direct_pingpong_split_tiles and self.direct_pingpong_barriers): + self.pingpong_barrier_arrive(1 - warp_group_idx, stage="mma") + self.pingpong_barrier_sync(warp_group_idx, stage="epi") + if const_expr(self.direct_delay_tma_store): + if warp_idx == 0: + epi_store_pipeline.producer_acquire() + for epi_tile_idx in cutlass.range_constexpr(self.direct_epi_tiles): + epi_m = epi_tile_idx // self.direct_epi_n_tiles + epi_n = epi_tile_idx - epi_m * self.direct_epi_n_tiles + epi_stage_idx = epi_tile_idx % self.epi_stage + prev_epi_tile_idx = epi_tile_idx - 1 + prev_epi_m = prev_epi_tile_idx // self.direct_epi_n_tiles + prev_epi_n = prev_epi_tile_idx - prev_epi_m * self.direct_epi_n_tiles + prev_epi_stage_idx = prev_epi_tile_idx % self.epi_stage + acc_epi_m = epi_m + pingpong_epi_tile = warp_group_idx + if const_expr(self.direct_delay_tma_store and epi_tile_idx != 0): + if warp_idx == 0: + self._copy_blockscaled_epilogue_tma_store( + tma_atom_d, + tDsD, + tDgD, + prev_epi_m, + prev_epi_n, + prev_epi_stage_idx, + ) + epi_store_pipeline.producer_commit() + epi_store_pipeline.producer_acquire() + self.epilogue_barrier.arrive_and_wait() + elif const_expr(not self.direct_delay_tma_store) and warp_idx == 0: + epi_store_pipeline.producer_acquire() + if const_expr(not self.direct_epi_barrier_trim): + self.epilogue_barrier.arrive_and_wait() + if const_expr((not self.pingpong) or self.direct_pingpong_split_tiles): + self._blockscaled_stage_full_tile_to_epi_smem( + tiled_mma, + epilogue_params, + sD_epi, + acc, + tidx, + acc_epi_m, + epi_n, + epi_stage_idx, + ) + elif const_expr(epi_tile_idx == 0): + if pingpong_epi_tile == 0: + self._blockscaled_stage_full_tile_to_epi_smem( + tiled_mma, + epilogue_params, + sD_epi, + acc, + tidx, + acc_epi_m, + epi_n, + epi_stage_idx, + ) + else: + if pingpong_epi_tile == 1: + self._blockscaled_stage_full_tile_to_epi_smem( + tiled_mma, + epilogue_params, + sD_epi, + acc, + tidx, + acc_epi_m, + epi_n, + epi_stage_idx, + ) + cute.arch.fence_view_async_shared() + self.epilogue_barrier.arrive_and_wait() + if const_expr(not self.direct_delay_tma_store): + if warp_idx == 0: + self._copy_blockscaled_epilogue_tma_store( + tma_atom_d, + tDsD, + tDgD, + epi_m, + epi_n, + 0, + ) + epi_store_pipeline.producer_commit() + self.epilogue_barrier.arrive_and_wait() + if const_expr(self.direct_delay_tma_store): + last_epi_tile_idx = self.direct_epi_tiles - 1 + last_epi_m = last_epi_tile_idx // self.direct_epi_n_tiles + last_epi_n = last_epi_tile_idx - last_epi_m * self.direct_epi_n_tiles + last_epi_stage_idx = last_epi_tile_idx % self.epi_stage + if warp_idx == 0: + self._copy_blockscaled_epilogue_tma_store( + tma_atom_d, + tDsD, + tDgD, + last_epi_m, + last_epi_n, + last_epi_stage_idx, + ) + epi_store_pipeline.producer_commit() + self.epilogue_barrier.arrive_and_wait() + if warp_idx == 0: + epi_store_pipeline.producer_tail() + if const_expr(self.direct_pingpong_split_tiles and self.direct_pingpong_barriers): + self.pingpong_barrier_arrive(1 - warp_group_idx, stage="epi") + return consumer_state_mk, consumer_state_nk + + @cute.jit + def _shift_unpack_ab_fragments( + self, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + k_block_idx: cutlass.Constexpr[int], + ) -> None: + if const_expr(self.direct_ab_tma_layout == "unpack" and self.direct_unpack_shift): + _sm120.shift_mxf4nvf4_post_ldsm_fp4_fragment(tCrA[(None, None, k_block_idx)]) + _sm120.shift_mxf4nvf4_post_ldsm_fp4_fragment(tCrB[(None, None, k_block_idx)]) + + @cute.jit + def _blockscaled_load_scale_fragments( + self, + tiled_mma: cute.TiledMma, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + sfa_frag: cute.Tensor, + sfb_frag: cute.Tensor, + tidx: Int32, + k_block_idx: cutlass.Constexpr[int], + stage: Int32, + tile_extent_m: cutlass.Constexpr[int], + tile_extent_n: cutlass.Constexpr[int], + tile_extent_k: cutlass.Constexpr[int], + ) -> None: + _sm120.copy_mxf4nvf4_direct_tma_scale_fragments( + tiled_mma, + sSFA, + sSFB, + sfa_frag, + sfb_frag, + tidx, + k_block_idx, + stage_idx=stage, + major_extent_sfa=tile_extent_m, + major_extent_sfb=tile_extent_n, + tile_k=tile_extent_k, + sf_vec_size=16, + scale_smem_format=self.direct_scale_smem_format, + ) + + @cute.jit + def _blockscaled_load_direct_k_block_a_scale_fragments( + self, + tiled_mma: cute.TiledMma, + tiled_copy_a: cute.TiledCopy, + tCsA: cute.Tensor, + tCrA: cute.Tensor, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + sfa_frag: cute.Tensor, + sfb_frag: cute.Tensor, + tidx: Int32, + k_block_idx: cutlass.Constexpr[int], + stage: Int32, + tile_extent_m: cutlass.Constexpr[int], + tile_extent_n: cutlass.Constexpr[int], + tile_extent_k: cutlass.Constexpr[int], + ) -> None: + _sm120.load_mxf4nvf4_direct_tma_k_block_a_scale_fragments( + tiled_mma, + tiled_copy_a, + tCsA, + tCrA, + sSFA, + sSFB, + sfa_frag, + sfb_frag, + tidx, + k_block_idx, + stage_idx=stage, + major_extent_sfa=tile_extent_m, + major_extent_sfb=tile_extent_n, + tile_k=tile_extent_k, + sf_vec_size=16, + scale_first=self.direct_scale_prefetch_first, + ) + + @cute.jit + def _blockscaled_load_direct_k_block_b_group_fragment( + self, + tiled_copy_b: cute.TiledCopy, + tCsB: cute.Tensor, + tCrB: cute.Tensor, + k_block_idx: cutlass.Constexpr[int], + b_group_idx: cutlass.Constexpr[int], + stage: Int32, + ) -> None: + _sm120.load_mxf4nvf4_direct_tma_k_block_b_group_fragment( + tiled_copy_b, + tCsB, + tCrB, + k_block_idx, + b_group_idx, + stage_idx=stage, + ) + + @cute.jit + def _blockscaled_direct_gemm_k_block_b_group( + self, + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + a_frag: cute.Tensor, + b_frag: cute.Tensor, + sfa_frag: cute.Tensor, + sfb_frag: cute.Tensor, + k_block_idx: cutlass.Constexpr[int], + b_group_idx: cutlass.Constexpr[int], + ) -> None: + _sm120.gemm_mxf4nvf4_direct_tma_k_block_b_group( + tiled_mma, + acc, + a_frag, + b_frag, + sfa_frag, + sfb_frag, + k_block_idx, + b_group_idx, + ) + + @cute.jit + def _blockscaled_direct_gemm_k_block_bgroup_pipeline( + self, + tiled_mma: cute.TiledMma, + tiled_copy_b: cute.TiledCopy, + acc: cute.Tensor, + tCsB: cute.Tensor, + tCrB: cute.Tensor, + a_frag: cute.Tensor, + b_frag: cute.Tensor, + sfa_frag: cute.Tensor, + sfb_frag: cute.Tensor, + k_block_idx: cutlass.Constexpr[int], + stage: Int32, + ) -> None: + self._blockscaled_direct_gemm_k_block_b_group( + tiled_mma, + acc, + a_frag, + b_frag, + sfa_frag, + sfb_frag, + k_block_idx, + 0, + ) + self._blockscaled_load_direct_k_block_b_group_fragment( + tiled_copy_b, + tCsB, + tCrB, + k_block_idx, + 2, + stage, + ) + self._blockscaled_direct_gemm_k_block_b_group( + tiled_mma, + acc, + a_frag, + b_frag, + sfa_frag, + sfb_frag, + k_block_idx, + 1, + ) + self._blockscaled_load_direct_k_block_b_group_fragment( + tiled_copy_b, + tCsB, + tCrB, + k_block_idx, + 3, + stage, + ) + for b_group_idx in cutlass.range_constexpr(2, 4): + self._blockscaled_direct_gemm_k_block_b_group( + tiled_mma, + acc, + a_frag, + b_frag, + sfa_frag, + sfb_frag, + k_block_idx, + b_group_idx, + ) + + @cute.jit + def _blockscaled_direct_gemm_k_block( + self, + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + a_frag: cute.Tensor, + b_frag: cute.Tensor, + sfa_frag: cute.Tensor, + sfb_frag: cute.Tensor, + k_block_idx: cutlass.Constexpr[int], + ) -> None: + _sm120.gemm_mxf4nvf4_direct_tma_k_block( + tiled_mma, + acc, + a_frag, + b_frag, + sfa_frag, + sfb_frag, + k_block_idx, + ab_smem_format=self.direct_ab_tma_layout, + n_major=self.direct_mma_n_major, + sync_warp_before=self.direct_pre_mma_warp_sync, + ) + + @cute.jit + def _blockscaled_stage_full_tile_to_epi_smem( + self, + tiled_mma: cute.TiledMma, + epilogue_params, + sD_epi: cute.Tensor, + acc: cute.Tensor, + tidx: Int32, + epi_m: cutlass.Constexpr[int] = 0, + epi_n: cutlass.Constexpr[int] = 0, + epi_stage_idx: cutlass.Constexpr[int] = 0, + ) -> None: + sD_tile = sD_epi[(None, None, epi_stage_idx)] + if const_expr(self.direct_epi_m_tiles == 1 and self.direct_epi_n_tiles == 1): + rD_acc = cute.make_rmem_tensor(acc.shape, cutlass.Float32) + rD_acc.store(acc.load()) + self.epi_visit_subtile( + epilogue_params, + { + "alpha": None, + "beta": None, + "mRowVecBroadcast": None, + "mColVecBroadcast": None, + }, + rD_acc, + None, + ) + thr_mma = tiled_mma.get_slice(tidx) + tCsD = thr_mma.partition_C(sD_tile) + rD = cute.make_rmem_tensor(acc.shape, cutlass.BFloat16) + rD.store(rD_acc.load().to(cutlass.BFloat16)) + atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), cutlass.BFloat16) + cute.copy(atom, rD, tCsD) + return + + tiled_copy_r2s, tRS_rD, tRS_sD, tRS_rAcc = _sm120.make_mxf4nvf4_epilogue_stmatrix_views( + tiled_mma, + acc, + sD_tile, + tidx, + epi_tile_shape=self.epi_tile_shape, + num_matrices=self.direct_epi_stsm_matrices, + ) + _sm120.load_mxf4nvf4_accumulator_epilogue_subtile( + tRS_rAcc, + tRS_rD, + (epi_m, epi_n), + ) + self.epi_visit_subtile( + epilogue_params, + { + "alpha": None, + "beta": None, + "mRowVecBroadcast": None, + "mColVecBroadcast": None, + }, + tRS_rD, + None, + ) + _sm120.copy_mxf4nvf4_epilogue_registers_to_smem( + tiled_copy_r2s, + tRS_rD, + tRS_sD, + ) + + @cute.jit + def _blockscaled_compute_store_mn_group_k_loop( + self, + tiled_mma: cute.TiledMma, + pipe, + sA_consumer: cute.Tensor, + sB_consumer: cute.Tensor, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + epilogue_params, + sD_epi: cute.Tensor, + tma_atom_d: cute.CopyAtom, + tDsD: cute.Tensor, + tDgD: cute.Tensor, + epi_store_pipeline: cutlass.pipeline.PipelineAsync, + m_group: cutlass.Constexpr[int], + n_group: cutlass.Constexpr[int], + lane_idx: Int32, + k_tile_count: Int32, + n_atoms: cutlass.Constexpr[int], + tile_extent_m: cutlass.Constexpr[int], + tile_extent_n: cutlass.Constexpr[int], + tile_extent_k: cutlass.Constexpr[int], + ) -> None: + n_group_count = 2 + n_atoms_per_group = n_atoms // n_group_count + accs = [ + [ + cute.make_rmem_tensor(tiled_mma.partition_shape_C((16, 8)), cutlass.Float32) + for _ in range(n_atoms_per_group) + ] + for _ in range(2) + ] + for m_offset in cutlass.range_constexpr(2): + for n_atom_local in cutlass.range_constexpr(n_atoms_per_group): + accs[m_offset][n_atom_local].fill(0.0) + + consumer_state = make_pipeline_state(pipeline.PipelineUserType.Consumer, self.ab_stage) + for _k_tile in cutlass.range(k_tile_count, unroll=1): + stage = consumer_state.index + pipe.consumer_wait(consumer_state, pipe.consumer_try_wait(consumer_state)) + for m_offset in cutlass.range_constexpr(2): + m_atom = m_group + 4 * m_offset + for n_atom_local in cutlass.range_constexpr(n_atoms_per_group): + n_atom = n_atom_local * n_group_count + n_group + self._blockscaled_mma_n_atom( + tiled_mma, + sA_consumer, + sB_consumer, + sSFA, + sSFB, + accs[m_offset][n_atom_local], + m_atom, + n_atom, + lane_idx, + stage, + tile_extent_m, + tile_extent_n, + tile_extent_k, + ) + self._direct_kblock_barrier() + cute.arch.fence_view_async_shared() + cute.arch.sync_warp() + pipe.consumer_release(consumer_state) + consumer_state.advance() + + if const_expr(m_group == 0 and n_group == 0): + epi_store_pipeline.producer_acquire() + self.epilogue_barrier.arrive_and_wait() + for m_offset in cutlass.range_constexpr(2): + m_atom = m_group + 4 * m_offset + for n_atom_local in cutlass.range_constexpr(n_atoms_per_group): + n_atom = n_atom_local * n_group_count + n_group + self._blockscaled_stage_n_atom_to_epi_smem( + tiled_mma, + epilogue_params, + sD_epi, + accs[m_offset][n_atom_local], + m_atom, + n_atom, + lane_idx, + ) + cute.arch.fence_view_async_shared() + self.epilogue_barrier.arrive_and_wait() + if const_expr(m_group == 0 and n_group == 0): + cute.copy(tma_atom_d, tDsD[None, 0], tDgD[None, 0, 0]) + epi_store_pipeline.producer_commit() + self.epilogue_barrier.arrive_and_wait() + if const_expr(m_group == 0 and n_group == 0): + epi_store_pipeline.producer_tail() + + @cute.jit + def _blockscaled_compute_store_m_atom_k_loop( + self, + tiled_mma: cute.TiledMma, + pipe, + sA_consumer: cute.Tensor, + sB_consumer: cute.Tensor, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + epilogue_params, + sD_epi: cute.Tensor, + tma_atom_d: cute.CopyAtom, + tDsD: cute.Tensor, + tDgD: cute.Tensor, + epi_store_pipeline: cutlass.pipeline.PipelineAsync, + m_atom: cutlass.Constexpr[int], + lane_idx: Int32, + k_tile_count: Int32, + n_atoms: cutlass.Constexpr[int], + tile_extent_m: cutlass.Constexpr[int], + tile_extent_n: cutlass.Constexpr[int], + tile_extent_k: cutlass.Constexpr[int], + n_group: cutlass.Constexpr[int], + n_group_count: cutlass.Constexpr[int], + do_tma_store: cutlass.Constexpr[bool], + ) -> None: + n_atoms_per_group = n_atoms // n_group_count + accs = [ + cute.make_rmem_tensor(tiled_mma.partition_shape_C((16, 8)), cutlass.Float32) + for _ in range(n_atoms_per_group) + ] + for n_atom_local in cutlass.range_constexpr(n_atoms_per_group): + accs[n_atom_local].fill(0.0) + + consumer_state = make_pipeline_state(pipeline.PipelineUserType.Consumer, self.ab_stage) + for _k_tile in cutlass.range(k_tile_count, unroll=1): + stage = consumer_state.index + pipe.consumer_wait(consumer_state, pipe.consumer_try_wait(consumer_state)) + for n_atom_local in cutlass.range_constexpr(n_atoms_per_group): + n_atom = n_atom_local * n_group_count + n_group + self._blockscaled_mma_n_atom( + tiled_mma, + sA_consumer, + sB_consumer, + sSFA, + sSFB, + accs[n_atom_local], + m_atom, + n_atom, + lane_idx, + stage, + tile_extent_m, + tile_extent_n, + tile_extent_k, + ) + self._direct_kblock_barrier() + cute.arch.fence_view_async_shared() + cute.arch.sync_warp() + pipe.consumer_release(consumer_state) + consumer_state.advance() + + if const_expr(do_tma_store): + epi_store_pipeline.producer_acquire() + self.epilogue_barrier.arrive_and_wait() + for n_atom_local in cutlass.range_constexpr(n_atoms_per_group): + n_atom = n_atom_local * n_group_count + n_group + self._blockscaled_stage_n_atom_to_epi_smem( + tiled_mma, + epilogue_params, + sD_epi, + accs[n_atom_local], + m_atom, + n_atom, + lane_idx, + ) + cute.arch.fence_view_async_shared() + self.epilogue_barrier.arrive_and_wait() + if const_expr(do_tma_store): + cute.copy(tma_atom_d, tDsD[None, 0], tDgD[None, 0, 0]) + epi_store_pipeline.producer_commit() + self.epilogue_barrier.arrive_and_wait() + if const_expr(do_tma_store): + epi_store_pipeline.producer_tail() + + @cute.jit + def _blockscaled_mma_n_atom( + self, + tiled_mma: cute.TiledMma, + sA_consumer: cute.Tensor, + sB_consumer: cute.Tensor, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + acc: cute.Tensor, + m_atom: cutlass.Constexpr[int], + n_atom: cutlass.Constexpr[int], + lane_idx: Int32, + consumer_stage: Int32, + tile_extent_m: cutlass.Constexpr[int], + tile_extent_n: cutlass.Constexpr[int], + tile_extent_k: cutlass.Constexpr[int], + ) -> None: + sA_atom, sB_atom = _sm120.make_mxf4nvf4_ab_consumer_microtile_views( + sA_consumer, + sB_consumer, + m_atom=m_atom, + n_atom=n_atom, + ) + a_frag, b_frag = _sm120.make_mxf4nvf4_ab_fragments_from_consumer_smem( + tiled_mma, + sA_atom, + sB_atom, + lane_idx=lane_idx, + ) + sfa, sfb = _sm120.make_mxf4nvf4_scale_fragments() + sSFA_atom = cute.make_tensor( + sSFA.iterator + cutlass.Int32(m_atom * 16), + sSFA.layout, + ) + sSFB_atom = cute.make_tensor( + sSFB.iterator + cutlass.Int32(n_atom * 8), + sSFB.layout, + ) + for k_block_idx in cutlass.range_constexpr(2): + _sm120.load_mxf4nvf4_ab_fragments_from_consumer_smem( + tiled_mma, + sA_atom, + sB_atom, + a_frag, + b_frag, + lane_idx, + k_block_idx, + consumer_stage_idx=consumer_stage, + ) + sfa_src, sfb_src = _sm120.make_mxf4nvf4_scale_fragment_views_from_direct_tma( + sSFA_atom, + sSFB_atom, + k_block_idx, + stage_idx=consumer_stage, + major_extent_sfa=tile_extent_m, + major_extent_sfb=tile_extent_n, + tile_k=tile_extent_k, + sf_vec_size=16, + ) + _sm120.load_mxf4nvf4_sfa_fragment(sfa_src, sfa) + _sm120.load_mxf4nvf4_sfb_fragment(sfb_src, sfb) + cute.gemm( + tiled_mma, + acc, + (a_frag[(None, 0, k_block_idx)], sfa), + (b_frag[(None, 0, k_block_idx)], sfb), + acc, + ) + + @cute.jit + def _blockscaled_stage_n_atom_to_epi_smem( + self, + tiled_mma: cute.TiledMma, + epilogue_params, + sD_epi: cute.Tensor, + acc: cute.Tensor, + epi_m_atom: cutlass.Constexpr[int], + epi_n_atom: cutlass.Constexpr[int], + lane_idx: Int32, + ) -> None: + sD_tile = sD_epi[(None, None, 0)] + sD_atom = cute.local_tile(sD_tile, (16, 8), (epi_m_atom, epi_n_atom)) + thr_mma = tiled_mma.get_slice(lane_idx) + tCsD = thr_mma.partition_C(sD_atom) + rD_acc = cute.make_rmem_tensor(acc.shape, cutlass.Float32) + rD_acc.store(acc.load()) + self.epi_visit_subtile( + epilogue_params, + { + "alpha": None, + "beta": None, + "mRowVecBroadcast": None, + "mColVecBroadcast": None, + }, + rD_acc, + None, + ) + rD = cute.make_rmem_tensor(acc.shape, cutlass.BFloat16) + rD.store(rD_acc.load().to(cutlass.BFloat16)) + atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), cutlass.BFloat16) + cute.copy(atom, rD, tCsD) @cute.kernel def kernel( diff --git a/quack/sm120_blockscaled_utils.py b/quack/sm120_blockscaled_utils.py new file mode 100644 index 00000000..46a71d8d --- /dev/null +++ b/quack/sm120_blockscaled_utils.py @@ -0,0 +1,334 @@ +# Copyright (c) 2026, QuACK team. +"""Host-side helpers for SM120 NVFP4 blockscaled GEMM.""" + +from __future__ import annotations + +from typing import Tuple + +import torch + + +_FP4_E2M1FN_VALUES = ( + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, +) + + +def ceil_div(a: int, b: int) -> int: + return (a + b - 1) // b + + +def round_up(a: int, b: int) -> int: + return ceil_div(a, b) * b + + +def _validate_cuda_tensor(tensor: torch.Tensor, name: str) -> None: + if not tensor.is_cuda: + raise ValueError(f"{name} must be a CUDA tensor") + + +def validate_sm120_nvfp4_ab_storage( + packed_tensor: torch.Tensor, + *, + logical_k: int, + major_extent: int, + batch_extent: int, +) -> None: + _validate_cuda_tensor(packed_tensor, "packed_tensor") + if packed_tensor.dtype != torch.float4_e2m1fn_x2: + raise TypeError(f"packed_tensor must be torch.float4_e2m1fn_x2, got {packed_tensor.dtype}") + if logical_k <= 0 or logical_k % 128 != 0: + raise ValueError(f"logical_k must be a positive multiple of 128, got {logical_k}") + expected_shape = (major_extent, logical_k // 2, batch_extent) + if tuple(packed_tensor.shape) != expected_shape: + raise ValueError( + f"packed_tensor shape must be {expected_shape}, got {tuple(packed_tensor.shape)}" + ) + expected_stride = (logical_k // 2, 1, major_extent * logical_k // 2) + stride = tuple(packed_tensor.stride()) + stride_ok = stride == expected_stride or ( + batch_extent == 1 and stride[:2] == expected_stride[:2] + ) + if not stride_ok: + raise ValueError( + "packed_tensor must be K-major contiguous packed storage with stride " + f"{expected_stride}, got {stride}" + ) + if packed_tensor.data_ptr() % 32 != 0: + raise ValueError("SM120 NVFP4 A/B base pointer must be 32B aligned") + + +def validate_sm120_nvfp4_d_storage( + d: torch.Tensor, + *, + m: int, + n: int, + l: int, +) -> None: + _validate_cuda_tensor(d, "D") + if d.dtype != torch.bfloat16: + raise TypeError(f"D must be torch.bfloat16, got {d.dtype}") + expected_shape = (m, n, l) + if tuple(d.shape) != expected_shape: + raise ValueError(f"D shape must be {expected_shape}, got {tuple(d.shape)}") + expected_stride = (n, 1, m * n) + stride = tuple(d.stride()) + stride_ok = stride == expected_stride or (l == 1 and stride[:2] == expected_stride[:2]) + if not stride_ok: + raise ValueError(f"D must be N-major with stride {expected_stride}, got {stride}") + if d.data_ptr() % 16 != 0: + raise ValueError("SM120 NVFP4 D base pointer must be at least 16B aligned") + + +def _check_major_tile(major_extent: int, major_tile: int, tile_major: int) -> int: + major_offset = major_tile * tile_major + if major_tile < 0 or major_offset + tile_major > major_extent: + raise ValueError(f"major_tile={major_tile} is outside major_extent={major_extent}") + return major_offset + + +def sm120_nvfp4_scale_pages(logical_k: int, sf_vec_size: int = 16) -> tuple[int, int, int]: + if sf_vec_size != 16: + raise ValueError(f"SM120 NVFP4 requires sf_vec_size=16, got {sf_vec_size}") + if logical_k <= 0 or logical_k % 128 != 0: + raise ValueError(f"logical_k must be a positive multiple of 128, got {logical_k}") + logical_scale_cols = ceil_div(logical_k, sf_vec_size) + physical_scale_cols = round_up(logical_scale_cols, 16) + physical_scale_pages = max(physical_scale_cols // 16, 2) + return logical_scale_cols, physical_scale_cols, physical_scale_pages + + +def sm120_nvfp4_scale_physical_offset( + major: int, + scale_col: int, + major_extent: int, +) -> int: + physical_major_extent = max(major_extent, 128) + payload_idx = scale_col * physical_major_extent + major + payload_chunk = payload_idx // 16 + payload_byte = payload_idx % 16 + physical_chunk = payload_chunk ^ ((payload_chunk >> 3) & 0x7) + return physical_chunk * 16 + payload_byte + + +def sm120_nvfp4_scale_interleaved_size( + logical_k: int, + major_extent: int, + batch_extent: int, +) -> tuple[int, int, int]: + if logical_k <= 0 or logical_k % 16 != 0: + raise ValueError(f"logical_k must be a positive multiple of 16, got {logical_k}") + if major_extent % 128 != 0: + raise ValueError(f"major_extent must be a multiple of 128, got {major_extent}") + if batch_extent <= 0: + raise ValueError(f"batch_extent must be positive, got {batch_extent}") + logical_cols = logical_k // 16 + if logical_cols % 4 != 0: + raise ValueError(f"logical_k / 16 must be a multiple of 4, got {logical_cols}") + major_tiles = major_extent // 128 + scale_tiles = ceil_div(logical_cols, 4) + return logical_cols, scale_tiles, major_tiles * scale_tiles * 512 * batch_extent + + +def sm120_nvfp4_scale_interleaved_offset( + major: int, + scale_col: int, + *, + logical_k: int, + major_extent: int, + batch_idx: int = 0, +) -> int: + logical_cols, scale_tiles, _ = sm120_nvfp4_scale_interleaved_size( + logical_k, major_extent, batch_idx + 1 + ) + if scale_col < 0 or scale_col >= logical_cols: + raise ValueError(f"scale_col={scale_col} is outside logical_cols={logical_cols}") + if major < 0 or major >= major_extent: + raise ValueError(f"major={major} is outside major_extent={major_extent}") + major_tiles = major_extent // 128 + major_tile = major // 128 + major_in_tile = major - major_tile * 128 + major_row = major_in_tile % 32 + major_quad = major_in_tile // 32 + scale_tile = scale_col // 4 + scale_quad = scale_col - scale_tile * 4 + l_stride = scale_tiles * major_tiles * 512 + return ( + batch_idx * l_stride + + scale_tile * major_tiles * 512 + + major_tile * 512 + + major_row * 16 + + major_quad * 4 + + scale_quad + ) + + +def validate_sm120_nvfp4_scale_storage( + scale_tensor: torch.Tensor, + *, + logical_k: int, + major_extent: int, + batch_extent: int, +) -> tuple[int, int, int]: + _validate_cuda_tensor(scale_tensor, "scale_tensor") + if scale_tensor.dtype != torch.float8_e4m3fn: + raise TypeError(f"scale_tensor must be torch.float8_e4m3fn, got {scale_tensor.dtype}") + logical_cols, physical_cols, pages = sm120_nvfp4_scale_pages(logical_k) + _logical_cols, _scale_tiles, interleaved_size = sm120_nvfp4_scale_interleaved_size( + logical_k, major_extent, batch_extent + ) + if scale_tensor.ndim != 1: + raise ValueError("SM120 NVFP4 scale storage must be compact 1D interleaved FP8 storage") + if scale_tensor.numel() != interleaved_size: + raise ValueError( + "interleaved scale_tensor storage must have " + f"{interleaved_size} elements, got {scale_tensor.numel()}" + ) + if scale_tensor.stride() != (1,): + raise ValueError( + f"interleaved scale_tensor must be contiguous, got {scale_tensor.stride()}" + ) + if scale_tensor.data_ptr() % 16 != 0: + raise ValueError("SM120 NVFP4 scale base pointer must be 16B aligned") + return logical_cols, physical_cols, pages + + +def copy_sm120_nvfp4_scale_blocks_to_storage( + scale_tensor: torch.Tensor, + block_values: torch.Tensor, + *, + logical_k: int, +) -> None: + if not block_values.is_cuda: + raise ValueError("block_values must be a CUDA tensor") + major_extent, logical_cols, batch_extent = block_values.shape + expected_cols = ceil_div(logical_k, 16) + if logical_cols != expected_cols: + raise ValueError(f"block_values shape[1] must be {expected_cols}, got {logical_cols}") + validate_sm120_nvfp4_scale_storage( + scale_tensor, + logical_k=logical_k, + major_extent=major_extent, + batch_extent=batch_extent, + ) + ref_u8 = block_values.to(torch.float8_e4m3fn).view(torch.uint8) + storage_u8 = scale_tensor.view(torch.uint8) + storage_u8.zero_() + _logical_cols, scale_tiles, _storage_size = sm120_nvfp4_scale_interleaved_size( + logical_k, major_extent, batch_extent + ) + major_tiles = major_extent // 128 + storage_view = storage_u8.view(batch_extent, scale_tiles, major_tiles, 32, 4, 4) + for batch_idx in range(batch_extent): + for col in range(logical_cols): + scale_tile = col // 4 + scale_quad = col - scale_tile * 4 + for major_tile in range(major_tiles): + major_start = major_tile * 128 + src = ref_u8[major_start : major_start + 128, col, batch_idx] + storage_view[batch_idx, scale_tile, major_tile, :, :, scale_quad].copy_( + src.view(4, 32).transpose(0, 1) + ) + + +def make_sm120_nvfp4_ab_metadata_tensor( + *, device, logical_k: int = 128, major_extent: int = 128, batch_extent: int = 2 +) -> torch.Tensor: + """Create rank-preserving dummy metadata for SM120 A/B TMA atom construction.""" + return torch.empty( + (logical_k, major_extent, batch_extent), + dtype=torch.uint8, + device=device, + ) + + +def make_sm120_nvfp4_scale_metadata_tensor(*, device) -> torch.Tensor: + """Create rank-preserving dummy metadata for SM120 scale TMA atom construction.""" + return torch.empty((1024,), dtype=torch.uint8, device=device) + + +def create_sm120_nvfp4_ab_tensor( + l: int, major: int, k: int, *, fill_byte: int | None = None +) -> torch.Tensor: + """Create SM120 packed-K NVFP4 A/B storage shaped ``(major, k / 2, l)``.""" + if l <= 0: + raise ValueError(f"l must be positive, got {l}") + if major <= 0: + raise ValueError(f"major must be positive, got {major}") + if k <= 0 or k % 128 != 0: + raise ValueError(f"k must be a positive multiple of 128, got {k}") + storage = torch.empty((l, major, k // 2), dtype=torch.float4_e2m1fn_x2, device="cuda") + packed = storage.permute(1, 2, 0) + if fill_byte is not None: + if fill_byte < 0 or fill_byte > 255: + raise ValueError(f"fill_byte must fit in uint8, got {fill_byte}") + packed.view(torch.uint8).fill_(fill_byte) + return packed + + +def create_sm120_nvfp4_tensorfill_like_ab_tensor( + l: int, major: int, k: int +) -> Tuple[torch.Tensor, torch.Tensor]: + """Create bounded random non-zero FP4 A/B data close to 79a TensorFillRandomUniform. + + 79a uses TensorFillRandomUniform with the FP4 input scope [-2, 2]. Keep the + same value range but avoid FP4 zero codes so performance checks cannot turn + into accidental zero-skipping checks. + """ + packed = create_sm120_nvfp4_ab_tensor(l, major, k) + magnitudes = torch.randint(1, 5, (major, k, l), device="cuda", dtype=torch.uint8) + signs = torch.randint(0, 2, (major, k, l), device="cuda", dtype=torch.uint8) << 3 + codes = magnitudes | signs + packed.view(torch.uint8).copy_(codes[:, 0::2, :] | (codes[:, 1::2, :] << 4)) + table = torch.tensor(_FP4_E2M1FN_VALUES, device="cuda", dtype=torch.float32) + return table[codes.long()], packed + + +def create_sm120_nvfp4_scale_tensor(l: int, mn: int, k: int) -> Tuple[torch.Tensor, torch.Tensor]: + """Create random SM120 interleaved scale storage and expanded FP32 reference.""" + logical_cols, _scale_tiles, storage_size = sm120_nvfp4_scale_interleaved_size(k, mn, l) + ref_blocks = torch.randint(1, 4, (mn, logical_cols, l), device="cuda").float() + storage = torch.zeros(storage_size, dtype=torch.float8_e4m3fn, device="cuda") + copy_sm120_nvfp4_scale_blocks_to_storage(storage, ref_blocks, logical_k=k) + ref = ( + ref_blocks.permute(0, 2, 1) + .unsqueeze(-1) + .expand(mn, l, logical_cols, 16) + .reshape(mn, l, logical_cols * 16) + .permute(0, 2, 1) + )[:, :k, :] + return ref, storage + + +def create_sm120_nvfp4_tensorfill_like_scale_tensor( + l: int, mn: int, k: int +) -> Tuple[torch.Tensor, torch.Tensor]: + """Create bounded random positive non-zero E4M3 scale data.""" + logical_cols, _scale_tiles, storage_size = sm120_nvfp4_scale_interleaved_size(k, mn, l) + choices = torch.tensor([0.5, 1.0], device="cuda", dtype=torch.float32) + indices = torch.randint(0, choices.numel(), (mn, logical_cols, l), device="cuda") + ref_blocks = choices[indices] + storage = torch.zeros(storage_size, dtype=torch.float8_e4m3fn, device="cuda") + copy_sm120_nvfp4_scale_blocks_to_storage(storage, ref_blocks, logical_k=k) + ref = ( + ref_blocks.permute(0, 2, 1) + .unsqueeze(-1) + .expand(mn, l, logical_cols, 16) + .reshape(mn, l, logical_cols * 16) + .permute(0, 2, 1) + )[:, :k, :] + return ref, storage diff --git a/quack/sm120_pipeline.py b/quack/sm120_pipeline.py new file mode 100644 index 00000000..0c0a1844 --- /dev/null +++ b/quack/sm120_pipeline.py @@ -0,0 +1,80 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +"""SM120 CTA-local TMA pipeline helpers.""" + +from dataclasses import dataclass +from typing import Optional, Tuple + +import cutlass.cute as cute +from cutlass._mlir import ir +from cutlass.cutlass_dsl import Boolean, Int32, dsl_user_op, if_generate +from cutlass.cute.typing import Pointer +from cutlass.pipeline import CooperativeGroup, PipelineState +from cutlass.pipeline.sm90 import PipelineTmaAsync + + +@dataclass(frozen=True) +class PipelineTmaWarpMma(PipelineTmaAsync): + """SM120 CTA-local TMA pipeline for warp-level MMA consumers.""" + + @staticmethod + def create( + *, + num_stages: int, + producer_group: CooperativeGroup, + consumer_group: CooperativeGroup, + tx_count: int, + barrier_storage: Optional[Pointer] = None, + cta_layout_vmnk=None, + tidx: Optional[Int32] = None, + mcast_mode_mn: Tuple[int, int] = (1, 1), + defer_sync: bool = False, + ) -> "PipelineTmaWarpMma": + base = PipelineTmaAsync.create( + num_stages=num_stages, + producer_group=producer_group, + consumer_group=consumer_group, + tx_count=tx_count, + barrier_storage=barrier_storage, + cta_layout_vmnk=cta_layout_vmnk, + tidx=tidx, + mcast_mode_mn=mcast_mode_mn, + defer_sync=defer_sync, + ) + return PipelineTmaWarpMma( + base.sync_object_full, + base.sync_object_empty, + base.num_stages, + base.producer_mask, + base.consumer_mask, + base.is_signalling_thread, + ) + + @dsl_user_op + def producer_acquire_already_elected( + self, + state: PipelineState, + try_acquire_token: Optional[Boolean] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: + """Acquire a TMA load stage from inside an existing ``elect_one`` block.""" + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + cute.arch.mbarrier_arrive_and_expect_tx( + self.producer_get_barrier(state, loc=loc, ip=ip), + self.sync_object_full.tx_count, + loc=loc, + ip=ip, + ) + + +__all__ = [ + "PipelineTmaWarpMma", +] diff --git a/quack/sm120_utils.py b/quack/sm120_utils.py new file mode 100644 index 00000000..3cafe470 --- /dev/null +++ b/quack/sm120_utils.py @@ -0,0 +1,49 @@ +# Copyright (c) 2026, QuACK team. +"""Small public SM120 helper facade. + +The SM120 NVFP4 kernel implementation uses ``quack._sm120_nvfp4_utils`` directly. +Keep this module limited to stable inspection helpers used by tests and callers. +""" + +from quack import _sm120_nvfp4_utils as _sm120 + + +def get_ab_tma_tx_bytes( + tile_mn: int = 128, + tile_k: int = 128, + *, + smem_format: str = "packed", +) -> int: + return _sm120.mxf4nvf4_ab_tma_tx_bytes(tile_mn, tile_k, smem_format=smem_format) + + +def get_scale_tma_tx_bytes( + tile_mn: int = 128, + tile_k: int = 128, + sf_vec_size: int = _sm120.MXF4NVF4_SCALE_VEC_SIZE, +) -> int: + return _sm120.mxf4nvf4_scale_tma_tx_bytes(tile_mn, tile_k, sf_vec_size) + + +def get_full_tma_tx_bytes( + tile_m: int = 128, + tile_n: int = 128, + tile_k: int = 128, + sf_vec_size: int = _sm120.MXF4NVF4_SCALE_VEC_SIZE, + *, + ab_smem_format: str = "packed", +) -> int: + return _sm120.mxf4nvf4_full_tma_tx_bytes( + tile_m, + tile_n, + tile_k, + sf_vec_size, + ab_smem_format=ab_smem_format, + ) + + +__all__ = [ + "get_ab_tma_tx_bytes", + "get_full_tma_tx_bytes", + "get_scale_tma_tx_bytes", +] diff --git a/scripts/run_sm120_nvfp4_bench.sh b/scripts/run_sm120_nvfp4_bench.sh new file mode 100755 index 00000000..743eaced --- /dev/null +++ b/scripts/run_sm120_nvfp4_bench.sh @@ -0,0 +1,17 @@ +#!/usr/bin/env bash +set -euo pipefail + +export CUTE_DSL_ARCH="${CUTE_DSL_ARCH:-sm_120a}" + +python benchmarks/benchmark_gemm.py \ + --mnkl "${MNKL:-4096,4096,4096,1}" \ + --tile_shape_mnk "${TILE:-128,128,128}" \ + --cluster_shape_mnk 1,1,1 \ + --ab_dtype Float4E2M1FN \ + --sf_dtype Float8E4M3FN \ + --sf_vec_size 16 \ + --d_dtype BFloat16 \ + --sm120_nvfp4_path "${SM120_NVFP4_PATH:-validated}" \ + --warmup_iterations "${WARMUP:-5}" \ + --iterations "${ITERS:-10}" \ + --skip_ref_check diff --git a/tests/test_gemm_sm120_nvfp4_blockscaled.py b/tests/test_gemm_sm120_nvfp4_blockscaled.py new file mode 100644 index 00000000..ab07147a --- /dev/null +++ b/tests/test_gemm_sm120_nvfp4_blockscaled.py @@ -0,0 +1,168 @@ +import pytest +import torch + +import cutlass + +from quack.blockscaled_gemm_utils import blockscaled_gemm_reference +from quack.blockscaled_gemm_utils import compile_blockscaled_gemm_tvm_ffi +from quack.sm120_blockscaled_utils import ( + create_sm120_nvfp4_ab_tensor, + create_sm120_nvfp4_scale_tensor, + create_sm120_nvfp4_tensorfill_like_ab_tensor, + create_sm120_nvfp4_tensorfill_like_scale_tensor, + validate_sm120_nvfp4_scale_storage, +) + + +def _skip_if_not_sm120(): + if not torch.cuda.is_available(): + pytest.skip("CUDA device unavailable") + if torch.cuda.get_device_capability(0)[0] != 12: + pytest.skip("SM120 required") + + +def _make_d(m: int, n: int, l: int, dtype=torch.bfloat16) -> torch.Tensor: + return torch.empty((l, m, n), device="cuda", dtype=dtype).permute(1, 2, 0) + + +def _make_problem(m: int, n: int, k: int, l: int): + a_ref, a = create_sm120_nvfp4_tensorfill_like_ab_tensor(l, m, k) + b_ref, b = create_sm120_nvfp4_tensorfill_like_ab_tensor(l, n, k) + sfa_ref, sfa = create_sm120_nvfp4_tensorfill_like_scale_tensor(l, m, k) + sfb_ref, sfb = create_sm120_nvfp4_tensorfill_like_scale_tensor(l, n, k) + d = _make_d(m, n, l) + return a_ref, b_ref, sfa_ref, sfb_ref, a, b, d, sfa, sfb + + +def _compile_runner( + a, + b, + d, + sfa, + sfb, + *, + keep_ptx: bool = False, + sm120_nvfp4_path: str = "validated", + ab_dtype=cutlass.Float4E2M1FN, + sf_dtype=cutlass.Float8E4M3FN, + sf_vec_size: int = 16, + d_dtype=cutlass.BFloat16, + mma_tiler_mn=(128, 128), + cluster_shape_mn=(1, 1), + varlen_m: bool = False, + varlen_k: bool = False, +): + return compile_blockscaled_gemm_tvm_ffi( + ab_dtype, + sf_dtype, + sf_vec_size, + d_dtype, + mma_tiler_mn, + cluster_shape_mn, + a, + b, + d, + sfa, + sfb, + keep_ptx=keep_ptx, + sm120_nvfp4_path=sm120_nvfp4_path, + varlen_m=varlen_m, + varlen_k=varlen_k, + ) + + +@pytest.mark.parametrize("m,n,k,l", [(128, 128, 128, 1), (256, 256, 256, 1)]) +def test_sm120_nvfp4_validated_matches_tensorfill_like_reference(m, n, k, l): + _skip_if_not_sm120() + torch.manual_seed(20260525 + m + n + k) + a_ref, b_ref, sfa_ref, sfb_ref, a, b, d, sfa, sfb = _make_problem(m, n, k, l) + + assert torch.all(a_ref != 0) + assert torch.all(b_ref != 0) + assert torch.all(sfa_ref != 0) + assert torch.all(sfb_ref != 0) + + runner = _compile_runner(a, b, d, sfa, sfb) + d.zero_() + runner(a, b, d, sfa, sfb) + torch.cuda.synchronize() + + ref = blockscaled_gemm_reference(a_ref, b_ref, sfa_ref, sfb_ref) + torch.testing.assert_close(d.float(), ref, atol=0.25, rtol=2e-2) + + +def test_sm120_nvfp4_fast_small_shape_runs_and_matches_reference(): + _skip_if_not_sm120() + torch.manual_seed(20260526) + m = n = k = 128 + l = 1 + a_ref, b_ref, sfa_ref, sfb_ref, a, b, d, sfa, sfb = _make_problem(m, n, k, l) + + runner = _compile_runner(a, b, d, sfa, sfb, sm120_nvfp4_path="fast") + d.zero_() + runner(a, b, d, sfa, sfb) + torch.cuda.synchronize() + + ref = blockscaled_gemm_reference(a_ref, b_ref, sfa_ref, sfb_ref) + torch.testing.assert_close(d.float(), ref, atol=0.25, rtol=2e-2) + + +def test_sm120_nvfp4_validated_and_fast_ptx_paths(tmp_path, monkeypatch): + _skip_if_not_sm120() + monkeypatch.chdir(tmp_path) + m = n = k = 128 + l = 1 + a = create_sm120_nvfp4_ab_tensor(l, m, k, fill_byte=0x22) + b = create_sm120_nvfp4_ab_tensor(l, n, k, fill_byte=0x22) + d = _make_d(m, n, l) + _, sfa = create_sm120_nvfp4_scale_tensor(l, m, k) + _, sfb = create_sm120_nvfp4_scale_tensor(l, n, k) + + validated = _compile_runner(a, b, d, sfa, sfb, keep_ptx=True) + fast = _compile_runner(a, b, d, sfa, sfb, keep_ptx=True, sm120_nvfp4_path="fast") + validated_ptx = validated.compiled.__ptx__ + fast_ptx = fast.compiled.__ptx__ + assert validated_ptx is not None + assert fast_ptx is not None + + mma = "mma.sync.aligned.m16n8k64.row.col.kind::mxf4nvf4.block_scale.scale_vec::4X" + assert validated_ptx.count(mma) == 64 + assert fast_ptx.count(mma) == 64 + assert validated_ptx.count("cp.async.bulk.tensor.2d.shared::cta.global.tile.mbarrier") == 2 + assert validated_ptx.count("cp.async.bulk.tensor.3d.shared::cta.global.tile.mbarrier") == 2 + assert fast_ptx.count("cp.async.bulk.tensor.2d.shared::cta.global.tile.mbarrier") == 2 + assert fast_ptx.count("cp.async.bulk.tensor.3d.shared::cta.global.tile.mbarrier") == 2 + + assert validated_ptx.count("st.global.b") == 128 + assert fast_ptx.count("st.global.b") == 0 + assert "cp.async.bulk.tensor.3d.global.shared::cta" in fast_ptx + + +def test_sm120_nvfp4_rejects_unsupported_user_contracts(): + _skip_if_not_sm120() + m = n = k = 128 + l = 1 + a = create_sm120_nvfp4_ab_tensor(l, m, k, fill_byte=0x22) + b = create_sm120_nvfp4_ab_tensor(l, n, k, fill_byte=0x22) + d = _make_d(m, n, l) + _, sfa = create_sm120_nvfp4_scale_tensor(l, m, k) + _, sfb = create_sm120_nvfp4_scale_tensor(l, n, k) + + unsupported = "SM120 blockscaled GEMM currently supports only NVFP4|SM120 NVFP4" + with pytest.raises((RuntimeError, ValueError), match=unsupported): + _compile_runner(a, b, d, sfa, sfb, mma_tiler_mn=(64, 128)) + with pytest.raises((RuntimeError, ValueError), match=unsupported): + _compile_runner(a, b, d, sfa, sfb, cluster_shape_mn=(2, 1)) + with pytest.raises((RuntimeError, ValueError), match=unsupported): + _compile_runner(a, b, d, sfa, sfb, ab_dtype=cutlass.Float8E4M3FN) + with pytest.raises((RuntimeError, ValueError), match=unsupported): + _compile_runner(a, b, _make_d(m, n, l, torch.float16), sfa, sfb, d_dtype=cutlass.Float16) + with pytest.raises((RuntimeError, ValueError), match=unsupported): + _compile_runner(a, b, d, sfa, sfb, varlen_m=True) + + _logical_cols, _physical_cols, pages = validate_sm120_nvfp4_scale_storage( + sfa, logical_k=k, major_extent=m, batch_extent=l + ) + legacy_rank4_sfa = torch.empty((m, 16, pages, l), device="cuda", dtype=torch.float8_e4m3fn) + with pytest.raises(ValueError, match="compact 1D interleaved FP8"): + _compile_runner(a, b, d, legacy_rank4_sfa, sfb) diff --git a/tests/test_gemm_sm120_nvfp4_correctness.py b/tests/test_gemm_sm120_nvfp4_correctness.py new file mode 100644 index 00000000..d50b40f6 --- /dev/null +++ b/tests/test_gemm_sm120_nvfp4_correctness.py @@ -0,0 +1,207 @@ +import pytest +import torch + +import cutlass + +from quack.blockscaled_gemm_utils import blockscaled_gemm_reference +from quack.blockscaled_gemm_utils import compile_blockscaled_gemm_tvm_ffi +from quack.sm120_blockscaled_utils import ( + copy_sm120_nvfp4_scale_blocks_to_storage, + create_sm120_nvfp4_ab_tensor, + create_sm120_nvfp4_scale_tensor, + create_sm120_nvfp4_tensorfill_like_ab_tensor, + create_sm120_nvfp4_tensorfill_like_scale_tensor, +) + + +def _skip_if_not_sm120(): + if not torch.cuda.is_available(): + pytest.skip("CUDA device unavailable") + if torch.cuda.get_device_capability(0)[0] != 12: + pytest.skip("SM120 required") + + +def _make_d(m: int, n: int, l: int) -> torch.Tensor: + return torch.empty((l, m, n), device="cuda", dtype=torch.bfloat16).permute(1, 2, 0) + + +def _compile_runner(a, b, d, sfa, sfb): + return compile_blockscaled_gemm_tvm_ffi( + cutlass.Float4E2M1FN, + cutlass.Float8E4M3FN, + 16, + cutlass.BFloat16, + (128, 128), + (1, 1), + a, + b, + d, + sfa, + sfb, + ) + + +def _store_scale_blocks(storage, blocks, k): + copy_sm120_nvfp4_scale_blocks_to_storage(storage, blocks, logical_k=k) + + +def _expand_scale_blocks(blocks, k): + major, logical_cols, l = blocks.shape + return ( + blocks.permute(0, 2, 1) + .unsqueeze(-1) + .expand(major, l, logical_cols, 16) + .reshape(major, l, logical_cols * 16) + .permute(0, 2, 1) + )[:, :k, :] + + +def test_sm120_nvfp4_single_cta_uniform_and_k64_scale_split(): + _skip_if_not_sm120() + m = n = k = 128 + l = 1 + a = create_sm120_nvfp4_ab_tensor(l, m, k, fill_byte=0x22) + b = create_sm120_nvfp4_ab_tensor(l, n, k, fill_byte=0x22) + d = _make_d(m, n, l) + _, sfa = create_sm120_nvfp4_scale_tensor(l, m, k) + _, sfb = create_sm120_nvfp4_scale_tensor(l, n, k) + runner = _compile_runner(a, b, d, sfa, sfb) + + d.zero_() + _store_scale_blocks(sfa, torch.ones((m, k // 16, l), device="cuda"), k) + _store_scale_blocks(sfb, torch.ones((n, k // 16, l), device="cuda"), k) + runner(a, b, d, sfa, sfb) + torch.cuda.synchronize() + torch.testing.assert_close(d.float(), torch.full_like(d.float(), 128.0)) + + d.zero_() + sfa_blocks = torch.ones((m, k // 16, l), device="cuda") + sfa_blocks[:, 4:8, :].fill_(2.0) + _store_scale_blocks(sfa, sfa_blocks, k) + _store_scale_blocks(sfb, torch.ones((n, k // 16, l), device="cuda"), k) + runner(a, b, d, sfa, sfb) + torch.cuda.synchronize() + torch.testing.assert_close(d.float(), torch.full_like(d.float(), 192.0)) + assert not hasattr(runner, "descriptor_cache") + + +def test_sm120_nvfp4_k384_scale_page_crossing(): + _skip_if_not_sm120() + m = n = 128 + k = 384 + l = 2 + a = create_sm120_nvfp4_ab_tensor(l, m, k, fill_byte=0x22) + b = create_sm120_nvfp4_ab_tensor(l, n, k, fill_byte=0x22) + d = _make_d(m, n, l) + _, sfa = create_sm120_nvfp4_scale_tensor(l, m, k) + _, sfb = create_sm120_nvfp4_scale_tensor(l, n, k) + runner = _compile_runner(a, b, d, sfa, sfb) + + d.zero_() + sfa_blocks = torch.ones((m, k // 16, l), device="cuda") + sfa_blocks[:, 8:16, 1].fill_(2.0) + _store_scale_blocks(sfa, sfa_blocks, k) + _store_scale_blocks(sfb, torch.ones((n, k // 16, l), device="cuda"), k) + runner(a, b, d, sfa, sfb) + torch.cuda.synchronize() + + expected = torch.full_like(d.float(), 384.0) + expected[:, :, 1].fill_(512.0) + torch.testing.assert_close(d.float(), expected) + + +def test_sm120_nvfp4_nonzero_multi_tile_scale_layout_matches_reference(): + _skip_if_not_sm120() + torch.manual_seed(120) + m = n = 256 + k = 256 + l = 1 + a = create_sm120_nvfp4_ab_tensor(l, m, k, fill_byte=0x22) + b = create_sm120_nvfp4_ab_tensor(l, n, k, fill_byte=0x22) + a_ref = torch.ones((m, k, l), device="cuda") + b_ref = torch.ones((n, k, l), device="cuda") + _, sfa = create_sm120_nvfp4_scale_tensor(l, m, k) + _, sfb = create_sm120_nvfp4_scale_tensor(l, n, k) + sfa_blocks = torch.ones((m, k // 16, l), device="cuda") + sfb_blocks = torch.ones((n, k // 16, l), device="cuda") + sfa_blocks[128:, 0:8, :].fill_(2.0) + sfa_blocks[:, 8:16, :].fill_(3.0) + sfb_blocks[128:, 0:8, :].fill_(2.0) + sfb_blocks[:, 8:16, :].fill_(4.0) + _store_scale_blocks(sfa, sfa_blocks, k) + _store_scale_blocks(sfb, sfb_blocks, k) + sfa_ref = _expand_scale_blocks(sfa_blocks, k) + sfb_ref = _expand_scale_blocks(sfb_blocks, k) + d = _make_d(m, n, l) + + assert torch.all(a_ref != 0) + assert torch.all(b_ref != 0) + assert torch.all(sfa_ref != 0) + assert torch.all(sfb_ref != 0) + + runner = _compile_runner(a, b, d, sfa, sfb) + d.zero_() + runner(a, b, d, sfa, sfb) + torch.cuda.synchronize() + + ref = blockscaled_gemm_reference(a_ref, b_ref, sfa_ref, sfb_ref) + torch.testing.assert_close(d.float(), ref.to(torch.bfloat16).float()) + + +def test_sm120_nvfp4_row_random_nonzero_multi_tile_matches_reference(): + _skip_if_not_sm120() + torch.manual_seed(20260522) + m = n = 256 + k = 256 + l = 1 + a = create_sm120_nvfp4_ab_tensor(l, m, k, fill_byte=0x22) + b = create_sm120_nvfp4_ab_tensor(l, n, k, fill_byte=0x22) + a_ref = torch.ones((m, k, l), device="cuda") + b_ref = torch.ones((n, k, l), device="cuda") + _, sfa = create_sm120_nvfp4_scale_tensor(l, m, k) + _, sfb = create_sm120_nvfp4_scale_tensor(l, n, k) + sfa_blocks = torch.randint(1, 4, (m, k // 16, l), device="cuda").float() + sfb_blocks = torch.randint(1, 4, (n, k // 16, l), device="cuda").float() + _store_scale_blocks(sfa, sfa_blocks, k) + _store_scale_blocks(sfb, sfb_blocks, k) + sfa_ref = _expand_scale_blocks(sfa_blocks, k) + sfb_ref = _expand_scale_blocks(sfb_blocks, k) + d = _make_d(m, n, l) + + assert torch.all(a_ref != 0) + assert torch.all(b_ref != 0) + assert torch.all(sfa_ref != 0) + assert torch.all(sfb_ref != 0) + + runner = _compile_runner(a, b, d, sfa, sfb) + d.zero_() + runner(a, b, d, sfa, sfb) + torch.cuda.synchronize() + + ref = blockscaled_gemm_reference(a_ref, b_ref, sfa_ref, sfb_ref) + torch.testing.assert_close(d.float(), ref.to(torch.bfloat16).float()) + + +def test_sm120_nvfp4_tensorfill_like_6x6_tiles_matches_reference(): + _skip_if_not_sm120() + torch.manual_seed(20260524) + m = n = k = 768 + l = 1 + a_ref, a = create_sm120_nvfp4_tensorfill_like_ab_tensor(l, m, k) + b_ref, b = create_sm120_nvfp4_tensorfill_like_ab_tensor(l, n, k) + sfa_ref, sfa = create_sm120_nvfp4_tensorfill_like_scale_tensor(l, m, k) + sfb_ref, sfb = create_sm120_nvfp4_tensorfill_like_scale_tensor(l, n, k) + d = _make_d(m, n, l) + + assert torch.all(a_ref != 0) + assert torch.all(b_ref != 0) + assert torch.all(sfa_ref != 0) + assert torch.all(sfb_ref != 0) + + runner = _compile_runner(a, b, d, sfa, sfb) + d.zero_() + runner(a, b, d, sfa, sfb) + torch.cuda.synchronize() + + ref = blockscaled_gemm_reference(a_ref, b_ref, sfa_ref, sfb_ref) + torch.testing.assert_close(d.float(), ref, atol=0.25, rtol=2e-2) diff --git a/tests/test_gemm_sm120_nvfp4_ptx.py b/tests/test_gemm_sm120_nvfp4_ptx.py new file mode 100644 index 00000000..e1e32743 --- /dev/null +++ b/tests/test_gemm_sm120_nvfp4_ptx.py @@ -0,0 +1,58 @@ +import pytest +import torch + +import cutlass + +from quack.blockscaled_gemm_utils import compile_blockscaled_gemm_tvm_ffi +from quack.sm120_blockscaled_utils import create_sm120_nvfp4_ab_tensor +from quack.sm120_blockscaled_utils import create_sm120_nvfp4_scale_tensor + + +def _skip_if_not_sm120(): + if not torch.cuda.is_available(): + pytest.skip("CUDA device unavailable") + if torch.cuda.get_device_capability(0)[0] != 12: + pytest.skip("SM120 required") + + +def test_sm120_nvfp4_ptx_contains_compact_tma_mainloop(tmp_path, monkeypatch): + _skip_if_not_sm120() + monkeypatch.chdir(tmp_path) + m = n = k = 128 + l = 1 + a = create_sm120_nvfp4_ab_tensor(l, m, k, fill_byte=0x22) + b = create_sm120_nvfp4_ab_tensor(l, n, k, fill_byte=0x22) + d = torch.empty((l, m, n), device="cuda", dtype=torch.bfloat16).permute(1, 2, 0) + _, sfa = create_sm120_nvfp4_scale_tensor(l, m, k) + _, sfb = create_sm120_nvfp4_scale_tensor(l, n, k) + + runner = compile_blockscaled_gemm_tvm_ffi( + cutlass.Float4E2M1FN, + cutlass.Float8E4M3FN, + 16, + cutlass.BFloat16, + (128, 128), + (1, 1), + a, + b, + d, + sfa, + sfb, + keep_ptx=True, + ) + ptx = runner.compiled.__ptx__ + assert ptx is not None + + assert ptx.count("cp.async.bulk.tensor.2d.shared::cta.global.tile.mbarrier") == 2 + assert ptx.count("cp.async.bulk.tensor.3d.shared::cta.global.tile.mbarrier") == 2 + assert ptx.count("cp.async.bulk.tensor.3d.global.shared::cta.tile") == 0 + assert ptx.count("st.global.b") == 128 + assert ptx.count("ldmatrix.sync.aligned.m8n8.x4.shared.b16") == 16 + assert ( + ptx.count("mma.sync.aligned.m16n8k64.row.col.kind::mxf4nvf4.block_scale.scale_vec::4X") + == 64 + ) + assert "tcgen05" not in ptx + assert "shared::cluster" not in ptx + assert ".multicast" not in ptx + assert "fence.proxy.tensormap::generic.acquire.gpu" not in ptx diff --git a/tests/test_gemm_sm120_nvfp4_validation.py b/tests/test_gemm_sm120_nvfp4_validation.py new file mode 100644 index 00000000..42e5a4d8 --- /dev/null +++ b/tests/test_gemm_sm120_nvfp4_validation.py @@ -0,0 +1,147 @@ +from pathlib import Path + +import pytest +import torch + +import cutlass + +from quack.gemm_sm120 import GemmSm120 +from quack.sm120_blockscaled_utils import ( + create_sm120_nvfp4_ab_tensor, + create_sm120_nvfp4_scale_tensor, + validate_sm120_nvfp4_ab_storage, + validate_sm120_nvfp4_d_storage, + validate_sm120_nvfp4_scale_storage, +) + + +def _skip_if_not_sm120(): + if not torch.cuda.is_available(): + pytest.skip("CUDA device unavailable") + if torch.cuda.get_device_capability(0)[0] != 12: + pytest.skip("SM120 required") + + +def test_sm120_nvfp4_facade_and_config_validation(): + import quack.sm120_utils as sm120_utils + + assert sm120_utils.get_ab_tma_tx_bytes() == 8192 + assert sm120_utils.get_scale_tma_tx_bytes() == 1024 + assert sm120_utils.get_full_tma_tx_bytes() == 18432 + + valid = ( + cutlass.Float4E2M1FN, + cutlass.Float8E4M3FN, + 16, + cutlass.BFloat16, + (128, 128, 128), + (1, 1), + 128, + 128, + 128, + 1, + "k", + "k", + "n", + ) + assert GemmSm120.can_implement_blockscaled(*valid) + assert not GemmSm120.can_implement_blockscaled(*valid[:4], (64, 64, 128), *valid[5:]) + assert not GemmSm120.can_implement_blockscaled(*valid[:4], (128, 64, 128), *valid[5:]) + assert not GemmSm120.can_implement_blockscaled(*valid[:5], (2, 1), *valid[6:]) + assert not GemmSm120.can_implement_blockscaled(*valid[:10], "m", *valid[11:]) + assert not GemmSm120.can_implement_blockscaled(*valid[:11], "n", *valid[12:]) + assert not GemmSm120.can_implement_blockscaled(*valid[:12], "m") + + with pytest.raises(ValueError, match="\\(128,128,128\\)"): + GemmSm120( + cutlass.Float32, + cutlass.Float4E2M1FN, + (64, 64, 128), + (1, 1, 1), + pingpong=False, + sf_vec_size=16, + sf_dtype=cutlass.Float8E4M3FN, + ) + + +def test_sm120_dense_pingpong_constructor_still_works(): + GemmSm120( + cutlass.Float32, + cutlass.BFloat16, + (128, 128, 64), + (1, 1, 1), + pingpong=True, + ) + + +def test_sm120_nvfp4_path_policy_selects_scheduler_and_epilogue(): + validated = GemmSm120( + cutlass.Float32, + cutlass.Float4E2M1FN, + (128, 128, 128), + (1, 1, 1), + pingpong=True, + sf_vec_size=16, + sf_dtype=cutlass.Float8E4M3FN, + ) + assert validated.direct_global_store + assert validated.direct_cute_static_scheduler + + fast = GemmSm120( + cutlass.Float32, + cutlass.Float4E2M1FN, + (128, 128, 128), + (1, 1, 1), + pingpong=True, + sf_vec_size=16, + sf_dtype=cutlass.Float8E4M3FN, + sm120_nvfp4_path="fast", + ) + assert not fast.direct_global_store + assert not fast.direct_cute_static_scheduler + + with pytest.raises(ValueError, match="validated.*fast"): + GemmSm120( + cutlass.Float32, + cutlass.Float4E2M1FN, + (128, 128, 128), + (1, 1, 1), + pingpong=True, + sf_vec_size=16, + sf_dtype=cutlass.Float8E4M3FN, + sm120_nvfp4_path="unknown", + ) + + +def test_sm120_nvfp4_source_has_no_experimental_env_matrix(): + source = (Path(__file__).parents[1] / "quack/gemm_sm120.py").read_text() + + assert "os.environ" not in source + assert "QUACK_SM120_NVFP4" not in source + assert "blockscaled_kernel_legacy" not in source + assert "get_native_tma_desc_addr" not in source + + +def test_sm120_nvfp4_storage_validation(): + _skip_if_not_sm120() + m = n = k = 128 + l = 2 + a = create_sm120_nvfp4_ab_tensor(l, m, k, fill_byte=0x22) + d = torch.empty((l, m, n), device="cuda", dtype=torch.bfloat16).permute(1, 2, 0) + _, sfa = create_sm120_nvfp4_scale_tensor(l, m, k) + + validate_sm120_nvfp4_ab_storage(a, logical_k=k, major_extent=m, batch_extent=l) + validate_sm120_nvfp4_d_storage(d, m=m, n=n, l=l) + _logical_cols, _physical_cols, pages = validate_sm120_nvfp4_scale_storage( + sfa, logical_k=k, major_extent=m, batch_extent=l + ) + + with pytest.raises(ValueError, match="shape"): + validate_sm120_nvfp4_ab_storage( + a.transpose(0, 1), logical_k=k, major_extent=m, batch_extent=l + ) + legacy_physical_scale = torch.empty((m, 16, pages, l), device="cuda", dtype=torch.float8_e4m3fn) + with pytest.raises(ValueError, match="compact 1D interleaved FP8"): + validate_sm120_nvfp4_scale_storage( + legacy_physical_scale, logical_k=k, major_extent=m, batch_extent=l + )