Add Sm120 blockscaled FP4 GEMM path#127
Draft
alecco wants to merge 2 commits into
Draft
Conversation
added 2 commits
April 30, 2026 15:22
Add the first guarded SM120 blockscaled GEMM path to the existing blockscaled GEMM interface instead of adding a separate SM120-specific frontend. The new path is intentionally narrow and mirrors the SM100 blockscaled entry points where the existing abstractions fit. Supported SM120 scope in this commit: - A/B are Float4E2M1FN FP4 operands stored as packed torch.float4_e2m1fn_x2. - Scale tensors are byte-sized FP8 scale factors: - NVFP4: Float8E4M3FN with sf_vec_size=16. - MXFP4: Float8E8M0FNU with sf_vec_size=32. - Accumulation is Float32 and D is BFloat16. - M and N must be multiples of 128, K must be a multiple of 64, L must be 1. - tile_shape_mnk is fixed to 128x128x64 and cluster_shape_mnk is fixed to 1x1x1. - C/beta, varlen, gather_A, sparse/grouped kernels, multicast clusters, output FP4 quantization, and generic autodispatch remain unsupported. The SM120 scale layout is deliberately different from the compact SM100-style scale layout. CuTeDSL TMA copies for compact row-major scale pages such as (M, 4) trap on SM120 because the row stride is smaller than the TMA-friendly 16-byte granularity. The SM120 path therefore requires padded physical scale pages: logical_scale_cols = ceil(K / sf_vec_size) physical_scale_cols = round_up(logical_scale_cols, 16) SFA shape = (M, physical_scale_cols, 1) SFB shape = (N, physical_scale_cols, 1) Only logical columns are consumed; padding columns are ignored. The helper used by tests and benchmarks now validates K divisibility, sf_vec_size, and the matching scale dtype so invalid tensors are rejected before they can reach TMA. The compile-time entry also rejects compact SM120 scale tensors before launch. The SM120 kernel uses non-multicast TMA to stage packed A/B bytes and padded SFA/SFB pages, expands packed FP4 bytes into the padded Int8 shared-memory shape required by SM120 FP4 ldmatrix helpers, and issues the CUTLASS DSL tuple-MMA blockscaled path. It keeps the proven selector-zero scale packet mapping: SFA provider rows are group + 8 * (tid & 1), and SFB provider columns are group. A/B FP4 ldmatrix still uses the local SM120 helpers instead of generic cute.copy because source and destination element widths differ. For K > 64, this commit keeps Float32 accumulation across K64 tiles for each 16x8 atom. Intermediate K64 partial sums are written to an FP32 shared-memory scratch tile, later K64 partials add that FP32 value back into registers, and D is converted to BF16 only once on the final K64 tile. This avoids the previous BF16-chained partial accumulation behavior and preserves the advertised Float32 accumulation semantics for the supported K-multiple cases. The public class-level blockscaled call now runs the SM120 validation path before launching the JIT kernel. The TVM/FFI compile helper still calls a JIT-only internal entry because it creates logical fake CuTe tensors over packed torch storage before compilation; the host validation path documents and rejects raw packed-K class calls with a clear error. Tests are consolidated under tests/test_gemm_blockscaled.py. The SM120 coverage checks: - supported and rejected can_implement_blockscaled cases; - direct class-call validation for missing scales, C/beta, and packed-vs-logical K misuse; - padded scale layout sizing for K=64, 128, 256, 384 and MXFP4 examples; - scale-helper validation for K, sf_vec_size, and scale dtype; - compact scale tensors rejected before launch; - SM120 scale-sensitive runtime correctness for K=64, multi-CTA K=128, and K=320 crossing a 16-column scale page; - a K=384 regression that would fail if K64 partial sums were rounded through BF16 between tiles; - non-constant FP4 K-lane patterns for both A and B with padded-scale poison present beyond the logical scale columns. The benchmark smoke path can create the padded SM120 scale tensors and launch the explicit SM120 blockscaled path, but numerical correctness is covered by the pytest references rather than benchmark timing output.
Update the benchmark module examples to match the current CLI. The blockscaled path is selected by --sf_dtype and/or --sf_vec_size, so remove the stale --blockscaled flag and other old example-only flags that are not parsed by the script.\n\nUse BF16 output in the FP4 SM120 examples because that is the supported SM120 benchmark path. The benchmark remains a launch/timing utility; correctness is checked against the local reference, and cuBLAS comparison remains skipped for SM120 padded row-major scale tensors.
d2272ee to
09dbc28
Compare
Contributor
Author
|
Fixed benchmark BF16 checks and docstring. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Requires NVIDIA/cutlass#3185
Author: Alecco (& Codex) for Ologan
Summary
This PR adds the first guarded SM120 blockscaled GEMM path using the existing blockscaled GEMM interface, rather than adding a separate SM120-specific frontend.
This is a first implementation not optimized. Optimization will be done in a subsequent PR.
Supported SM120 scope:
Unsupported for this first path:
Implementation Notes
SM120 uses a padded row-major physical scale layout:
logical_scale_cols = ceil(K / sf_vec_size)
physical_scale_cols = round_up(logical_scale_cols, 16)
SFA shape = (M, physical_scale_cols, 1)
SFB shape = (N, physical_scale_cols, 1)
Compact scale rows such as (M, 4) are rejected before launch because SM120 TMA on compact row-major scale pages with row stride below 16 bytes traps. Padding columns are
ignored by the kernel.
For K > 64, K64 partials are accumulated through an FP32 shared-memory scratch tile and D is converted to BF16 only once on the final K tile. This preserves FP32
accumulation semantics instead of rounding through BF16 between K tiles.
The SM120 FFI compile path still calls the JIT-only blockscaled entry directly, after host-side validation, because it compiles logical fake CuTe tensor views over packed
torch storage. The class-level blockscaled_call keeps host validation for logical class-call tensors.
The benchmark path can launch the SM120 blockscaled kernel and build padded scale tensors, but cuBLAS comparison is skipped because the benchmark currently builds QuACK’s
padded row-major scale layout, not the cuBLAS/PyTorch scaled_mm scale layout.
Tests
Validated with:
python -m compileall quack tests benchmarks
ruff check quack tests benchmarks
CUTE_DSL_ARCH=sm_120a pytest -q tests/test_gemm_blockscaled.py -s -rs
pytest -q 'tests/test_linear.py::test_gemm[960-736-1504-input_dtype0-k-k-m-False]' -s -rs
Current results on SM120:
The SM120 tests cover:
Notes
This PR does not claim full arbitrary FP4 operand-pattern coverage yet. Broader row-and-K-varying FP4 correctness, cuBLAS-compatible SM120 scale conversion, and additional
blockscaled features should be follow-up work.
Agent disclosure
This work was created with OpenAI Codex CLI agent.