Add Sm120 blockscaled FP4 GEMM path #127

Draft
alecco wants to merge 2 commits into Dao-AILab:main from alecco:sm120-blockscaled

alecco commented Apr 30, 2026

Requires NVIDIA/cutlass#3185

Author: Alecco (& Codex) for Ologan

Summary

This PR adds the first guarded SM120 blockscaled GEMM path using the existing blockscaled GEMM interface, rather than adding a separate SM120-specific frontend.

This is a first, unoptimized implementation. Optimization will follow in a subsequent PR.

Supported SM120 scope:

  • A/B: Float4E2M1FN FP4 operands, packed as torch.float4_e2m1fn_x2
  • NVFP4 scales: Float8E4M3FN, sf_vec_size=16
  • MXFP4 scales: Float8E8M0FNU, sf_vec_size=32
  • Accumulator: Float32
  • Output: BFloat16
  • Shapes: M % 128 == 0, N % 128 == 0, K % 64 == 0, L == 1
  • Tile/cluster: tile_shape_mnk=(128,128,64), cluster_shape_mnk=(1,1,1)
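
The shape and scale-format constraints above can be sketched as a small host-side predicate. This is an illustration only: `sm120_fp4_shape_supported` and `SUPPORTED_SCALE_FORMATS` are hypothetical names, not the PR's `can_implement_blockscaled`, and the sketch checks only the divisibility and format-pairing rules listed above.

```python
# Hypothetical helper for illustration; not the PR's can_implement_blockscaled.
# Pairs of (scale dtype name, sf_vec_size) taken from the supported-scope list.
SUPPORTED_SCALE_FORMATS = {
    ("Float8E4M3FN", 16),   # NVFP4
    ("Float8E8M0FNU", 32),  # MXFP4
}


def sm120_fp4_shape_supported(m, n, k, l, sf_dtype, sf_vec_size):
    """Return True if the problem fits the narrow SM120 scope listed above."""
    if (sf_dtype, sf_vec_size) not in SUPPORTED_SCALE_FORMATS:
        return False
    # M and N must be multiples of 128, K a multiple of 64, and L exactly 1.
    return m % 128 == 0 and n % 128 == 0 and k % 64 == 0 and l == 1


print(sm120_fp4_shape_supported(256, 128, 384, 1, "Float8E4M3FN", 16))  # True
print(sm120_fp4_shape_supported(256, 128, 96, 1, "Float8E4M3FN", 16))   # False
```

The second call is rejected because K=96 is not a multiple of 64.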

Unsupported for this first path:

  • C / beta
  • varlen
  • gather_A
  • grouped/sparse kernels
  • multicast clusters
  • output FP4 quantization
  • generic autodispatch
  • SM120 FP8/INT8 blockscaled paths

Implementation Notes

SM120 uses a padded row-major physical scale layout:

logical_scale_cols = ceil(K / sf_vec_size)
physical_scale_cols = round_up(logical_scale_cols, 16)

SFA shape = (M, physical_scale_cols, 1)
SFB shape = (N, physical_scale_cols, 1)

Compact scale rows such as (M, 4) are rejected before launch: on SM120, TMA copies of compact row-major scale pages trap when the row stride is below 16 bytes. The kernel
ignores the padding columns.
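
As a minimal sketch of the sizing rule above, the following computes the padded scale shape for a given operand. The names `round_up` and `sm120_scale_shape` are chosen here for illustration and are not necessarily the helpers used in the PR.

```python
def round_up(x: int, multiple: int) -> int:
    """Round x up to the next multiple of `multiple`."""
    return -(-x // multiple) * multiple


def sm120_scale_shape(rows: int, k: int, sf_vec_size: int):
    """Return ((rows, physical_scale_cols, 1), logical_scale_cols).

    Illustrative only: mirrors the formulas in the PR description.
    """
    logical_cols = -(-k // sf_vec_size)         # ceil(K / sf_vec_size)
    physical_cols = round_up(logical_cols, 16)  # pad each row to 16 columns
    return (rows, physical_cols, 1), logical_cols


print(sm120_scale_shape(128, 64, 16))   # ((128, 16, 1), 4)
print(sm120_scale_shape(128, 320, 16))  # ((128, 32, 1), 20)
print(sm120_scale_shape(256, 384, 32))  # ((256, 16, 1), 12)
```

Because each scale factor is one byte, padding to 16 columns gives every row a stride of at least 16 bytes. The K=320 NVFP4 case needs 20 logical columns and therefore spills into a second 16-column page.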

For K > 64, K64 partials are accumulated through an FP32 shared-memory scratch tile and D is converted to BF16 only once on the final K tile. This preserves FP32
accumulation semantics instead of rounding through BF16 between K tiles.
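
To see why rounding through BF16 between K tiles matters, here is a small host-side emulation. It is not the kernel code: `round_to_bf16`, `chained_bf16`, and `fp32_then_bf16` are hypothetical names, and the per-tile partial sums are made-up values chosen so that the effect is visible.

```python
import struct


def round_to_bf16(x: float) -> float:
    """Round a finite value to BF16 (round-to-nearest-even); NaN/Inf not handled."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bias = 0x7FFF + ((bits >> 16) & 1)  # ties go to the even BF16 mantissa
    bits = (bits + bias) & 0xFFFF0000   # keep only the top 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def chained_bf16(partials):
    """Round the running sum to BF16 after every K tile."""
    acc = 0.0
    for p in partials:
        acc = round_to_bf16(acc + p)
    return acc


def fp32_then_bf16(partials):
    """Keep a full-precision running sum and round to BF16 once at the end."""
    acc = 0.0
    for p in partials:
        acc += p  # small integers, so the sum is exact in FP32 as well
    return round_to_bf16(acc)


# Six made-up K64 partial sums, as there would be for K=384.
partials = [256.0, 1.0, 1.0, 1.0, 1.0, 1.0]
print(chained_bf16(partials))    # 256.0
print(fp32_then_bf16(partials))  # 260.0
```

BF16 values between 256 and 512 are spaced 2 apart, so each `+1` is lost when the running sum is rounded per tile, while a single final rounding of 261 yields 260.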

The SM120 FFI compile path still calls the JIT-only blockscaled entry directly, after host-side validation, because it compiles logical fake CuTe tensor views over packed
torch storage. The class-level blockscaled_call keeps host validation for logical class-call tensors.

The benchmark path can launch the SM120 blockscaled kernel and build padded scale tensors, but cuBLAS comparison is skipped because the benchmark currently builds QuACK’s
padded row-major scale layout, not the cuBLAS/PyTorch scaled_mm scale layout.

Tests

Validated with:

python -m compileall quack tests benchmarks
ruff check quack tests benchmarks
CUTE_DSL_ARCH=sm_120a pytest -q tests/test_gemm_blockscaled.py -s -rs
pytest -q 'tests/test_linear.py::test_gemm[960-736-1504-input_dtype0-k-k-m-False]' -s -rs

Current results on SM120:

  • tests/test_gemm_blockscaled.py: 22 passed, 52 skipped
  • Dense GEMM smoke: 1 passed

The SM120 tests cover:

  • supported/rejected can_implement_blockscaled cases
  • padded scale layout validation
  • compact scale rejection
  • direct class-call validation
  • scale-sensitive runtime correctness
  • K=320 scale-page crossing
  • K=384 regression for FP32 accumulation across K64 tiles
  • non-constant FP4 K-lane patterns for A/B with scale padding poison present
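
The idea behind the scale-sensitive and padding-poison checks can be sketched with a scalar reference. This is not the PR's test code: `blockscaled_dot` is a hypothetical name, the operands and scales are assumed to be already decoded to Python floats, and the semantics assumed are the usual blockscaled ones (one scale per `sf_vec_size` elements along K for each operand row).

```python
import math


def blockscaled_dot(a_vals, b_vals, sfa_row, sfb_row, sf_vec_size):
    """Reference dot product of one A row and one B row with per-block scales."""
    k = len(a_vals)
    logical_cols = -(-k // sf_vec_size)  # ceil(K / sf_vec_size)
    acc = 0.0
    for col in range(logical_cols):      # padding columns are never indexed
        lo = col * sf_vec_size
        hi = min(lo + sf_vec_size, k)
        block = sum(a * b for a, b in zip(a_vals[lo:hi], b_vals[lo:hi]))
        acc += sfa_row[col] * sfb_row[col] * block
    return acc


# K=64 with sf_vec_size=16: 4 logical scale columns padded to 16 physical ones.
a = [1.0] * 64
b = [0.5] * 64
sfa = [1.0, 2.0, 4.0, 8.0] + [math.nan] * 12  # NaN poison in the padding
sfb = [1.0] * 4 + [math.nan] * 12
print(blockscaled_dot(a, b, sfa, sfb, 16))    # 120.0
```

Using a different scale in each block makes the result depend on every logical column, and the NaN padding would contaminate the output if any padding column were read.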

Notes

This PR does not yet claim coverage of arbitrary FP4 operand patterns. Broader row-and-K-varying FP4 correctness, cuBLAS-compatible SM120 scale conversion, and additional
blockscaled features are left for follow-up work.

Agent disclosure

This work was created with the OpenAI Codex CLI agent.

alecco changed the title from "[Sm120] add blockscaled FP4 GEMM path" to "Add Sm120 blockscaled FP4 GEMM path" on Apr 30, 2026
agent added 2 commits April 30, 2026 15:22
Add the first guarded SM120 blockscaled GEMM path to the existing blockscaled
GEMM interface instead of adding a separate SM120-specific frontend.  The new
path is intentionally narrow and mirrors the SM100 blockscaled entry points
where the existing abstractions fit.

Supported SM120 scope in this commit:
- A/B are Float4E2M1FN FP4 operands stored as packed torch.float4_e2m1fn_x2.
- Scale tensors are byte-sized FP8 scale factors:
  - NVFP4: Float8E4M3FN with sf_vec_size=16.
  - MXFP4: Float8E8M0FNU with sf_vec_size=32.
- Accumulation is Float32 and D is BFloat16.
- M and N must be multiples of 128, K must be a multiple of 64, L must be 1.
- tile_shape_mnk is fixed to 128x128x64 and cluster_shape_mnk is fixed to 1x1x1.
- C/beta, varlen, gather_A, sparse/grouped kernels, multicast clusters, output
  FP4 quantization, and generic autodispatch remain unsupported.

The SM120 scale layout is deliberately different from the compact SM100-style
scale layout.  CuTeDSL TMA copies for compact row-major scale pages such as
(M, 4) trap on SM120 because the row stride is smaller than the TMA-friendly
16-byte granularity.  The SM120 path therefore requires padded physical scale
pages:

  logical_scale_cols = ceil(K / sf_vec_size)
  physical_scale_cols = round_up(logical_scale_cols, 16)
  SFA shape = (M, physical_scale_cols, 1)
  SFB shape = (N, physical_scale_cols, 1)

Only logical columns are consumed; padding columns are ignored.  The helper used
by tests and benchmarks now validates K divisibility, sf_vec_size, and the
matching scale dtype so invalid tensors are rejected before they can reach TMA.
The compile-time entry also rejects compact SM120 scale tensors before launch.

The SM120 kernel uses non-multicast TMA to stage packed A/B bytes and padded
SFA/SFB pages, expands packed FP4 bytes into the padded Int8 shared-memory shape
required by SM120 FP4 ldmatrix helpers, and issues the CUTLASS DSL tuple-MMA
blockscaled path.  It keeps the proven selector-zero scale packet mapping:
SFA provider rows are group + 8 * (tid & 1), and SFB provider columns are group.
A/B FP4 ldmatrix still uses the local SM120 helpers instead of generic cute.copy
because source and destination element widths differ.
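
For readers unfamiliar with the packed operand format, the sketch below decodes packed FP4 bytes on the host. It only illustrates the storage format and is not what the kernel does (the kernel expands bytes into an Int8 shared-memory shape for ldmatrix). The E2M1 value table is the standard one; the nibble order (low nibble first) is an assumption made here, and `decode_e2m1` and `unpack_fp4_bytes` are hypothetical names.

```python
# Magnitudes of the 3 low bits of an E2M1 value (2 exponent bits, 1 mantissa bit).
E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def decode_e2m1(nibble: int) -> float:
    """Decode one 4-bit E2M1 value: bit 3 is the sign, bits 0-2 the magnitude."""
    sign = -1.0 if nibble & 0x8 else 1.0
    return sign * E2M1_MAGNITUDES[nibble & 0x7]


def unpack_fp4_bytes(packed: bytes):
    """Expand packed bytes into logical FP4 values, two per byte."""
    out = []
    for byte in packed:
        out.append(decode_e2m1(byte & 0xF))  # assumed: low nibble comes first
        out.append(decode_e2m1(byte >> 4))
    return out


print(unpack_fp4_bytes(bytes([0x21, 0xF7])))  # [0.5, 1.0, 6.0, -6.0]
```

Two logical K elements share one byte, so a logical K of 64 occupies 32 bytes of packed storage. This is the distinction behind the packed-vs-logical K validation mentioned in this commit.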

For K > 64, this commit keeps Float32 accumulation across K64 tiles for each
16x8 atom.  Intermediate K64 partial sums are written to an FP32 shared-memory
scratch tile, later K64 partials add that FP32 value back into registers, and D
is converted to BF16 only once on the final K64 tile.  This avoids the previous
BF16-chained partial accumulation behavior and preserves the advertised Float32
accumulation semantics for the supported K-multiple cases.

The public class-level blockscaled call now runs the SM120 validation path before
launching the JIT kernel.  The TVM/FFI compile helper still calls a JIT-only
internal entry because it creates logical fake CuTe tensors over packed torch
storage before compilation; the host validation path documents and rejects raw
packed-K class calls with a clear error.

Tests are consolidated under tests/test_gemm_blockscaled.py.  The SM120 coverage
checks:
- supported and rejected can_implement_blockscaled cases;
- direct class-call validation for missing scales, C/beta, and packed-vs-logical
  K misuse;
- padded scale layout sizing for K=64, 128, 256, 384 and MXFP4 examples;
- scale-helper validation for K, sf_vec_size, and scale dtype;
- compact scale tensors rejected before launch;
- SM120 scale-sensitive runtime correctness for K=64, multi-CTA K=128, and
  K=320 crossing a 16-column scale page;
- a K=384 regression that would fail if K64 partial sums were rounded through
  BF16 between tiles;
- non-constant FP4 K-lane patterns for both A and B with padded-scale poison
  present beyond the logical scale columns.

The benchmark smoke path can create the padded SM120 scale tensors and launch the
explicit SM120 blockscaled path, but numerical correctness is covered by the
pytest references rather than benchmark timing output.
Update the benchmark module examples to match the current CLI. The blockscaled path is selected by --sf_dtype and/or --sf_vec_size, so remove the stale --blockscaled flag and other old example-only flags that are not parsed by the script.

Use BF16 output in the FP4 SM120 examples because that is the supported SM120 benchmark path. The benchmark remains a launch/timing utility; correctness is checked against the local reference, and cuBLAS comparison remains skipped for SM120 padded row-major scale tensors.

alecco commented Apr 30, 2026

Fixed benchmark BF16 checks and docstring.
