SM120 blockscaled: add opt-in packed-LDSM performance path (2.3 to 103.4 TFLOP/s)#128
Draft
alecco wants to merge 4 commits into
Draft
SM120 blockscaled: add opt-in packed-LDSM performance path (2.3 to 103.4 TFLOP/s)#128alecco wants to merge 4 commits into
alecco wants to merge 4 commits into
Conversation
added 2 commits
April 30, 2026 15:22
Add the first guarded SM120 blockscaled GEMM path to the existing blockscaled GEMM interface instead of adding a separate SM120-specific frontend. The new path is intentionally narrow and mirrors the SM100 blockscaled entry points where the existing abstractions fit. Supported SM120 scope in this commit: - A/B are Float4E2M1FN FP4 operands stored as packed torch.float4_e2m1fn_x2. - Scale tensors are byte-sized FP8 scale factors: - NVFP4: Float8E4M3FN with sf_vec_size=16. - MXFP4: Float8E8M0FNU with sf_vec_size=32. - Accumulation is Float32 and D is BFloat16. - M and N must be multiples of 128, K must be a multiple of 64, L must be 1. - tile_shape_mnk is fixed to 128x128x64 and cluster_shape_mnk is fixed to 1x1x1. - C/beta, varlen, gather_A, sparse/grouped kernels, multicast clusters, output FP4 quantization, and generic autodispatch remain unsupported. The SM120 scale layout is deliberately different from the compact SM100-style scale layout. CuTeDSL TMA copies for compact row-major scale pages such as (M, 4) trap on SM120 because the row stride is smaller than the TMA-friendly 16-byte granularity. The SM120 path therefore requires padded physical scale pages: logical_scale_cols = ceil(K / sf_vec_size) physical_scale_cols = round_up(logical_scale_cols, 16) SFA shape = (M, physical_scale_cols, 1) SFB shape = (N, physical_scale_cols, 1) Only logical columns are consumed; padding columns are ignored. The helper used by tests and benchmarks now validates K divisibility, sf_vec_size, and the matching scale dtype so invalid tensors are rejected before they can reach TMA. The compile-time entry also rejects compact SM120 scale tensors before launch. The SM120 kernel uses non-multicast TMA to stage packed A/B bytes and padded SFA/SFB pages, expands packed FP4 bytes into the padded Int8 shared-memory shape required by SM120 FP4 ldmatrix helpers, and issues the CUTLASS DSL tuple-MMA blockscaled path. It keeps the proven selector-zero scale packet mapping: SFA provider rows are group + 8 * (tid & 1), and SFB provider columns are group. A/B FP4 ldmatrix still uses the local SM120 helpers instead of generic cute.copy because source and destination element widths differ. For K > 64, this commit keeps Float32 accumulation across K64 tiles for each 16x8 atom. Intermediate K64 partial sums are written to an FP32 shared-memory scratch tile, later K64 partials add that FP32 value back into registers, and D is converted to BF16 only once on the final K64 tile. This avoids the previous BF16-chained partial accumulation behavior and preserves the advertised Float32 accumulation semantics for the supported K-multiple cases. The public class-level blockscaled call now runs the SM120 validation path before launching the JIT kernel. The TVM/FFI compile helper still calls a JIT-only internal entry because it creates logical fake CuTe tensors over packed torch storage before compilation; the host validation path documents and rejects raw packed-K class calls with a clear error. Tests are consolidated under tests/test_gemm_blockscaled.py. The SM120 coverage checks: - supported and rejected can_implement_blockscaled cases; - direct class-call validation for missing scales, C/beta, and packed-vs-logical K misuse; - padded scale layout sizing for K=64, 128, 256, 384 and MXFP4 examples; - scale-helper validation for K, sf_vec_size, and scale dtype; - compact scale tensors rejected before launch; - SM120 scale-sensitive runtime correctness for K=64, multi-CTA K=128, and K=320 crossing a 16-column scale page; - a K=384 regression that would fail if K64 partial sums were rounded through BF16 between tiles; - non-constant FP4 K-lane patterns for both A and B with padded-scale poison present beyond the logical scale columns. The benchmark smoke path can create the padded SM120 scale tensors and launch the explicit SM120 blockscaled path, but numerical correctness is covered by the pytest references rather than benchmark timing output.
Update the benchmark module examples to match the current CLI. The blockscaled path is selected by --sf_dtype and/or --sf_vec_size, so remove the stale --blockscaled flag and other old example-only flags that are not parsed by the script.\n\nUse BF16 output in the FP4 SM120 examples because that is the supported SM120 benchmark path. The benchmark remains a launch/timing utility; correctness is checked against the local reference, and cuBLAS comparison remains skipped for SM120 padded row-major scale tensors.
added 2 commits
May 1, 2026 16:04
Add an opt-in SM120 blockscaled path that uses packed subbyte shared-memory fragments with ordinary m8n8.x4 ldmatrix instead of the original byte-expanded b4x16_p64 unpack path. The new path is gated by QUACK_SM120_BLOCKSCALED_PACKED_LDSM=1 so the existing correctness-first path remains available while the packed path is reviewed and tuned. The direction follows the local CUTLASS GeForce reference in examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu: CUTLASS uses a packed subbyte shared-memory consumer with m8n8 ldmatrix for the SM120 NVFP4 path rather than QuACK's earlier padded b4x16_p64 unpack route. This commit ports the relevant consumer-side idea while deliberately keeping QuACK's narrower, already-proven per-atom TMA producer instead of importing the full CUTLASS mainloop. The implementation keeps the scope narrow: Float4E2M1FN A/B, Float8E4M3FN NVFP4 or Float8E8M0FNU MXFP4 scales, BF16 D, C None, beta 0, cluster 1x1x1, and tile-aligned SM120 shapes. can_implement_blockscaled and direct class-call validation now share the same advertised tile contract: the existing 128x128x64 correctness path, 64x64x64, and 64x64x128 only when the packed path is explicitly enabled. Other mixed 64/128 tile shapes are rejected before launch. A globally set QUACK_SM120_BLOCKSCALED_PACKED_LDSM no longer hijacks the existing 128x128x64 fallback path. Packed mode is only activated for the two packed-supported tiles, so the env var can be left set while callers still compile or validate the correctness-first 128x128x64 path. For 64x64 tiles, four consumer warps each own one 16-row band and keep eight 16x8 FP32 accumulators live across all K tiles before storing BF16 once. This preserves Float32 accumulation semantics without the generic sAcc scratch traffic used by the wider fallback path. The 64x64x128 path also avoids allocating the full sAcc scratch because it uses the same register-resident accumulator path. This PR intentionally ships only one runtime knob: QUACK_SM120_BLOCKSCALED_PACKED_LDSM. Earlier profiling-only controls for stage count, consumer-warp count, and sync mode are omitted from the production PR surface so reviewers do not have to reason about untested scheduling variants. Full-tile and grouped TMA are deliberately not included here. Local experiments showed the current CuTe DSL subbyte/swizzled full-tile TMA layout either fails legalization, degenerates into many tiny TMA sites, or times out for nested raw FP4 layouts. This commit keeps the production path on the proven per-atom TMA mechanism and uses tile_K=128 to amortize producer/barrier overhead while leaving the full-tile layout issue for a separate repro/upstream track. Tests cover capability-contract negatives for unsupported K64 and K128 tile shapes, direct class-call rejection of unadvertised shapes, global packed-env compatibility with the 128x128x64 fallback path, packed K64/K128 NVFP4 correctness, packed K64/K128 MXFP4 correctness, asymmetric FP4 data, scale-page crossing with poisoned scale padding, a K128 scale-offset regression where the first and second K64 halves use different scales, pre-launch rejection for tile_K=128 without the packed path, and PTX regression checks requiring m8n8.x4.shared.b16 plus m16n8k64 mxf4nvf4 while rejecting b4x16_p64, m8n16, multicast, and shared::cluster.
Document the narrow performance path added for SM120 blockscaled GEMM: opt-in packed LDSM, supported FP4 blockscaled formats, the recommended correctness gate, the benchmark command, and the current 64x64x128 tile target. The note records the reviewer-relevant experiment outcome without carrying experimental code into the PR. The previous correctness-first path was useful for proving tuple MMA and scale handling, but it used the b4x16_p64 unpack route and was shared-load bound. This PR follows the CUTLASS 79a GeForce NVFP4 example's packed LDSM direction while keeping QuACK's smaller per-atom TMA producer. Full-tile/grouped TMA remains deliberately out of scope because local CuTe DSL layout-lowering experiments either failed legalization, generated many tiny TMA sites, hung, or timed out. That work belongs in a separate minimal repro/upstream track. Update the benchmark module docstring with the SM120 packed-LDSM NVFP4 command and make the benchmark CLI report expected configuration errors without a Python traceback. The benchmark wording stays conservative: numbers are local RTX 5060 workstation timing signals with reference checking skipped, while numerical validation is covered by pytest.
bcf6ab7 to
a1696b5
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Requires NVIDIA/cutlass#3185
Author: Alecco (& Codex) for Ologan
Summary
This PR builds on top of #127 and adds the first SM120 blockscaled performance path for FP4 GEMM.
The previous PR established the correctness-first SM120 blockscaled path:
This PR keeps that scope narrow, but replaces the slow correctness-first A/B shared-memory consumer path with an opt-in packed-LDSM path.
What Changes
Adds an opt-in SM120 packed-LDSM path behind:
QUACK_SM120_BLOCKSCALED_PACKED_LDSM=1
Supported scope:
For tile_K=128, logical K must be divisible by 128 and QUACK_SM120_BLOCKSCALED_PACKED_LDSM=1 must be set.
A globally set QUACK_SM120_BLOCKSCALED_PACKED_LDSM=1 does not force packed mode for every SM120 blockscaled tile. The existing 128x128x64 path still uses the correctness-
first fallback path.
The only new runtime knob in this PR is:
QUACK_SM120_BLOCKSCALED_PACKED_LDSM=1
No stage-count, consumer-warp-count, or sync-mode tuning knobs are exposed.
Why
The correctness-first path expanded compact FP4 A/B bytes into the padded SM120 .b4x16_p64 ldmatrix shared-memory format. That path was useful for proving correctness, but
it was very slow.
This PR adds the packed shared-memory path instead:
ldmatrix.sync.aligned.m8n8.x4.shared.b16
mma.sync.aligned.m16n8k64.kind::mxf4nvf4
The direction follows the local CUTLASS GeForce reference:
examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu
The new tests assert that the packed path does not regress back to:
For 64x64 tiles, four consumer warps each own one 16-row band and keep eight 16x8 FP32 accumulators live across K before storing BF16 once. This preserves FP32
accumulation semantics while avoiding the generic FP32 shared-memory scratch path.
Why 64x64x128
The packed 64x64x64 path removed the original shared-load bottleneck, but it was still producer/barrier heavy. 64x64x128 keeps the same accumulator ownership and
correctness surface while doubling K work per producer/barrier cycle.
Benchmark runs on an RTX 5060 workstation, with reference checking skipped because pytest owns numerical validation:
base sm120-blockscaled, correctness-first path:
4096x4096x4096 NVFP4 -> BF16, 128x128x64: 60.571 ms, 2.3 TFLOP/s
this PR, packed-LDSM path:
4096x4096x4096 NVFP4 -> BF16, 64x64x64: 2.988 ms, 46.0 TFLOP/s
4096x4096x4096 NVFP4 -> BF16, 64x64x128: 1.329 ms, 103.4 TFLOP/s
These numbers are meant as a local performance signal, not a stable benchmark guarantee.
What Is Deliberately Not Included
This PR does not add full-tile or grouped A/B TMA into the final packed/swizzled shared-memory layout.
That route was investigated separately, but the current CuTe DSL layout-lowering path is not clean enough for production:
So this PR keeps the proven per-atom TMA mechanism and uses tile_K=128 as the low-risk amortization step. Full/grouped TMA should be handled as a separate minimal repro /
upstream issue track.
Tests
Ran:
python -m compileall
quack/gemm_sm120.py
quack/blockscaled_gemm_utils.py
tests/test_gemm_blockscaled.py
benchmarks/benchmark_gemm.py
ruff check
quack/gemm_sm120.py
quack/blockscaled_gemm_utils.py
tests/test_gemm_blockscaled.py
benchmarks/benchmark_gemm.py
CUTE_DSL_ARCH=sm_120a pytest -q
tests/test_gemm_blockscaled.py
-k sm120_blockscaled
-n 16
-s -rs
Result:
30 passed
Also ran the documented packed-path selector:
QUACK_SM120_BLOCKSCALED_PACKED_LDSM=1 CUTE_DSL_ARCH=sm_120a pytest -q
tests/test_gemm_blockscaled.py
-k "sm120 and packed_ldsm"
-n 16
-s -rs
Result:
10 passed
Also ran focused regression tests for the global-env fallback and PTX helper path:
CUTE_DSL_ARCH=sm_120a pytest -q
tests/test_gemm_blockscaled.py::test_sm120_blockscaled_packed_env_keeps_128x128x64_class_call
tests/test_gemm_blockscaled.py::test_sm120_blockscaled_packed_ldsm_ptx_regression
-s -rs
Result:
2 passed
Small reference-checked benchmark smoke:
QUACK_SM120_BLOCKSCALED_PACKED_LDSM=1 CUTE_DSL_ARCH=sm_120a
python benchmarks/benchmark_gemm.py
--mnkl 1024,1024,1024,1
--tile_shape_mnk 64,64,128
--cluster_shape_mnk 1,1,1
--ab_dtype Float4E2M1FN
--sf_dtype Float8E4M3FN
--sf_vec_size 16
--d_dtype BFloat16
--warmup_iterations 2
--iterations 5
Result:
Ref check PASSED
PASS
Large benchmark command used for the timing snapshot:
QUACK_SM120_BLOCKSCALED_PACKED_LDSM=1 CUTE_DSL_ARCH=sm_120a
python benchmarks/benchmark_gemm.py
--mnkl 4096,4096,4096,1
--tile_shape_mnk 64,64,128
--cluster_shape_mnk 1,1,1
--ab_dtype Float4E2M1FN
--sf_dtype Float8E4M3FN
--sf_vec_size 16
--d_dtype BFloat16
--warmup_iterations 5
--iterations 30
--skip_ref_check
Agent disclosure
This work was created with OpenAI Codex CLI agent.