Skip to content

[Cute,Fwd,Sm120] Fix three forward-pass correctness bugs (use_tma_O, smem-store atom, pack_gqa)#2553

Open
jganbar wants to merge 3 commits into
Dao-AILab:mainfrom
jganbar:sm120-fwd-correctness
Open

[Cute,Fwd,Sm120] Fix three forward-pass correctness bugs (use_tma_O, smem-store atom, pack_gqa)#2553
jganbar wants to merge 3 commits into
Dao-AILab:mainfrom
jganbar:sm120-fwd-correctness

Conversation

@jganbar
Copy link
Copy Markdown

@jganbar jganbar commented May 11, 2026

Summary

Three independent bugs prevent FA4's forward kernel from producing correct output on consumer Blackwell (SM_120) GPUs as of flash-attn-4==4.0.0b12. All three reproduce on main at the current HEAD.

  • Compile crash in the epilogue. use_tma_O was True on SM_120 but the SM80 base class never wires a TMA atom for O — it passes None to self.epilogue, and the epilogue crashes inside quack.copy_utils.tma_get_copy_fn.
  • Silent numerical wrongness. The rmem→smem store atom in the epilogue was the SM90 stmatrix (chosen by get_smem_store_atom for any arch >= 90), but the SM80 base class's mma.sync.aligned.m16n8k16 output register layout is not compatible with it. The kernel completed, returned the correct shape, and the output norm matched SDPA — but element-wise values were wrong by 0.5–3.9 (bf16/fp16, vs SDPA-bf16's own ~0.003).
  • Compile crash on every default-policy GQA/MQA shape. pack_gqa.compute_ptr trips cute.crd2idx because cuTeDSL 4.4.2 collapses the composite (qhead_per_kvhead, seqlen_q) mode in mO[None, 0] to a rank-1 layout. Pragmatic mitigation here: default pack_gqa=False on SM_120 until the deeper cuTeDSL issue is fixed; explicit pack_gqa=True is still honoured.

After this PR, the SM_120 forward kernel matches PyTorch SDPA-bf16's own numerical accuracy across a 240-shape grid (240/240 pass, max abs diffs bit-identical to SDPA's on 219/240 shapes and within 3.5e-5 on the rest) and is ~1.15× faster on average (range 0.74×–1.95×).

Bugs and fixes at a glance

Each commit has a focused message with the full root-cause writeup; this is the summary.

# File Symptom Fix
1 flash_fwd.py:652 AttributeError: 'NoneType' object has no attribute '_trait' at trace time self.use_tma_O = False unconditionally in the SM80 base class
2 flash_fwd.py:347 Kernel runs to completion, returns plausible output, but max abs diff vs fp32 reference is 0.5–3.9 (vs ~0.003 for SDPA-bf16) Force get_smem_store_atom(80, ...) so the SM80 base class always uses the universal copy, not the SM90 stmatrix
3 interface.py:430 ValueError: Operation creation failed (unable to compute crd2idx) on every default-policy GQA/MQA shape Default pack_gqa=False on SM_120

Bug 2 (the stmatrix mismatch) is the headline fix — silent numerical wrongness is the worst kind of bug to ship and this one would propagate into training loss / eval quality for any model using FA4 on consumer Blackwell.

How the headline bug was found

The error profile was very specific and ruled out most hypotheses up-front:

  • Bitwise-deterministic across reruns → not a race / scheduler issue.
  • Uniform across all M-tiles (not concentrated in m_block=0) → not a kernel prologue / initialisation issue.
  • Scales linearly with input magnitude → precision / permutation signature, not a logic bug.
  • Output norm preserved to ~3 significant digits → values are reshuffled with their same-magnitude neighbours, not zeroed.

That pattern points at a memory transfer using the wrong layout. Grepping self.arch for arch-conditional branches in the SM80 base class surfaces exactly one candidate: get_smem_store_atom in the epilogue. The branch picks the SM90 stmatrix path for any arch ≥ 90, but stmatrix is hardware-paired with WGMMA's output register layout, not SM80 MMA's. Patch is one literal arg change.

Test plan

Hardware: NVIDIA GeForce RTX 5090, capability (12, 0), driver 590.48.01, CUDA 12.8.

Software: Python 3.11.15, torch==2.10.0+cu128, nvidia-cutlass-dsl==4.4.2, this branch installed editable.

Correctness (240 shapes × 3 backends = 720 evaluations, all pass)

dtypes    : fp16, bf16
causal    : True, False
B         : 1, 2
S         : 128, 512, 1024, 2048, 4096
heads     : MHA(Hq=Hkv=16), GQA(Hq=16,Hkv=8), MQA(Hq=16,Hkv=1)
D         : 64, 128
backends  : sdpa_bshd (reference), fa4 (default pack_gqa), fa4_no_packgqa
reference : PyTorch SDPA in fp32 on the same inputs

status by backend:
  sdpa_bshd       240 / 240 pass
  fa4             240 / 240 pass
  fa4_no_packgqa  240 / 240 pass

FA4's per-shape max_abs_diff vs the fp32 SDPA reference is bit-identical to SDPA-bf16's own on 219/240 shapes; the remaining 21 differ by ≤ 3.5e-5 (FP reordering noise).

Before this PR the analogous run was 240 SDPA pass, 80 FA4 "pass" (MHA only — actually producing element-wise diffs of 0.5–3.9 against fp32 reference, flagged fail by the harness; the kernel had reported success), 160 FA4 error (GQA/MQA, the pack_gqa cute crash).

Performance

Same harness, 10 warmup + 50 measured iters per cell, torch.cuda.Event timing, hard per-process memory cap so we don't disturb other GPU tenants.

backend mean TFLOPs/s peak
sdpa_bshd (PyTorch eager + repeat_interleave for GQA) 87.9 193.7
fa4 (this PR) 94.8 196.2

FA4 speedup vs sdpa_bshd across 240 shapes: mean 1.15×, median 1.07×, max 1.95×, min 0.74×. Biggest FA4 wins are GQA at small S (SDPA pays for repeat_interleave); the few shapes where FA4 still loses are S=512 D=128 causal MHA, which is a tile-size tuning opportunity orthogonal to this PR.

Minimal reproducer for the silent-wrongness bug

import torch
import torch.nn.functional as F
from flash_attn.cute.interface import flash_attn_func

torch.manual_seed(0)
B, S, H, D = 1, 256, 4, 64
dt = torch.bfloat16
q = torch.randn(B, S, H, D, device="cuda", dtype=dt)
k = torch.randn(B, S, H, D, device="cuda", dtype=dt)
v = torch.randn(B, S, H, D, device="cuda", dtype=dt)

out = flash_attn_func(q, k, v, causal=False)
if isinstance(out, tuple):
    out = out[0]

ref = F.scaled_dot_product_attention(
    q.float().transpose(1, 2),
    k.float().transpose(1, 2),
    v.float().transpose(1, 2),
).transpose(1, 2).contiguous()

print(f"max abs diff: {(out.float() - ref).abs().max().item():.4f}")
# before this PR: ~0.61   (bf16 elementwise tolerance should be ~0.003)
# after this PR:  ~0.002

Limitations / out of scope

  • Backward pass. This PR is forward-only. flash_bwd_postprocess.py:537 calls get_smem_store_atom(self.arch.major*10+self.arch.minor, ...) and almost certainly needs the analogous fix when going through FlashAttentionBackwardSm120. A backward-correctness check (autograd + grad.allclose(grad_ref)) should follow as a separate PR.
  • The cuTeDSL composite-mode flattening (deeper root cause behind commit 3). Commit 3 only flips the default to pack_gqa=False on SM_120; it does not fix pack_gqa.compute_ptr or the cuTeDSL behaviour itself. A proper fix lives in either pack_gqa.compute_ptr (compute the flat index directly, without relying on the slice preserving compositeness) or in cuTeDSL.
  • No new tile-size tuning for SM_120. The few shapes where FA4 loses to SDPA after this PR are tile-tuning candidates but out of scope.

🤖 Generated with Claude Code

jganbar and others added 3 commits May 11, 2026 18:04
`FlashAttentionForwardSm80.__call__` sets

    self.use_tma_O = self.arch >= Arch.sm_90

but the base class kernel never constructs a TMA atom for O — it
passes None as `tma_atom_O` to `self.epilogue` (line ~1069). The check
exists because the file once intended to support a Hopper-style TMA-O
path that was never wired up here.

On Hopper / SM_100 hardware this is dead code because those archs use
their own forward classes (`FlashAttentionForwardSm90` /
`FlashAttentionForwardSm100`) with their own `__call__`. But
`FlashAttentionForwardSm120` inherits from this class, and
`FlashAttentionForwardBase.__init__` reads `self.arch` from the DSL,
which is `Arch.sm_120` on consumer Blackwell. The epilogue then takes
the TMA-output branch and crashes inside
`quack.copy_utils.tma_get_copy_fn` -> `cpasync.tma_partition` with
`AttributeError: 'NoneType' object has no attribute '_trait'`.

Static `arch = 80` on `FlashAttentionForwardSm120` was intended to
prevent this but is overwritten by `__init__`.

Force `use_tma_O = False` here; SM90 and SM100 are unaffected because
they have their own `__call__`.

Reproduced on RTX 5090 (SM_120, cuTeDSL 4.4.2, torch 2.10.0+cu128).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`FlashAttentionForwardSm80.epilogue` chose the rmem->smem store atom
via `get_smem_store_atom(self.arch.major*10 + self.arch.minor, ...)`,
which returns

  - `CopyUniversalOp` for arch < 90 (or non-16-bit data), and
  - `StMatrix8x8x16bOp(num_matrices=4)` (Hopper `stmatrix`) for
    arch >= 90 on 16-bit data.

`stmatrix` is hardware-paired with WGMMA's output register layout. The
SM80 base class uses `mma.sync.aligned.m16n8k16` whose output register
layout is *not* what `stmatrix` consumes. With WGMMA-output the atom
permutes bytes from a fixed pattern of threads/registers; feeding it
SM80-MMA-output silently scrambles values across nearby register
lanes during the store.

On native SM_80 hardware this branch never fires because the DSL arch
is sm_80 < sm_90. The bug only surfaces when this class is reused on
SM_120 via `FlashAttentionForwardSm120`, where `self.arch` is read
from the DSL as `sm_120` and the >= 90 branch picks `stmatrix`.

Symptom: the kernel completes without error and returns the correct
output shape and a roughly correct output norm (each scrambled value
is replaced by a same-magnitude neighbour), but element-wise diffs vs
fp32 SDPA are 0.5-1.2 (non-causal) and 3.4-3.9 (causal), versus
SDPA-bf16's own ~0.003 and ~0.008. Determinism still holds and error
scales linearly with input magnitude — the precision/permutation
signature, not a logic bug.

Fix: pass a fixed `80` to `get_smem_store_atom` here so the SM80 base
class always takes the universal-copy path, matching its actual MMA
output layout.

Verified on RTX 5090 (SM_120, cuTeDSL 4.4.2, torch 2.10.0+cu128):
240/240 correctness configs pass against fp32 SDPA reference across
{fp16, bf16} x {causal, full} x B in {1,2} x S in {128..4096} x
{MHA, GQA, MQA} x D in {64, 128}, with max abs diff matching
SDPA-bf16's own (~0.012 worst case, ~0.0027 mean).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`pack_gqa.compute_ptr` calls

    utils.elem_pointer(tensor, ((h_idx, m_idx),))

with a tensor whose layout is supposed to keep a composite
`(qhead_per_kvhead, seqlen_q)` first mode (created by
`pack_gqa_layout`). The slice `mO[None, 0]` that lands in
`compute_ptr` is meant to preserve that compositeness so the rank-2
coord matches.

On SM_120 with `cuTeDSL==4.4.2` the slice collapses the composite
mode into a rank-1 layout. `cute.crd2idx` then refuses the rank-2
coord and raises at trace time with

    unable to compute crd2idx with
      '!cute.layout<"(?):(?{i64 div=8})">'
      and  '!cute.coord<"((?,?))">'

resulting in a `ValueError: Operation creation failed` before the
kernel can run. Every default-policy GQA / MQA shape on consumer
Blackwell hits this.

The non-packed GQA path is numerically identical (pack_gqa is a
perf-only optimization for the GQA Q-load / O-store), so flipping
the auto-default to False on SM_120 makes GQA / MQA work out of the
box while a deeper cuTeDSL fix is investigated. Explicit
`pack_gqa=True` from the caller is still honoured.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
thad0ctor added a commit to thad0ctor/flash-attention that referenced this pull request May 24, 2026
Phase 4-B2 re-benchmark caught a real regression: all 12 qwen2.5-7b
benchmark cells failed with 'failed to perform a valid division of
((7,?),?) by [128:1;128:1]' at flash_fwd.py:920 (cute.local_tile of
the packed Q layout by blkQ_shape=(128, 128)).

Root cause: qwen2.5-7b has 28 Q heads / 4 KV heads = 7-way GQA.
pack_gqa_layout produces a composite-mode-0 Q layout
((qhead_per_kvhead, seqlen_q), head_dim); cute.local_tile must divide
this by (tile_m, tile_hdim). For the division to succeed at the
qhead_per_kvhead boundary, tile_m must be divisible by
qhead_per_kvhead. SM120 uses tile_m=128 unconditionally, which divides
1/2/4/8/16-way GQA (llama3 32q/8kv, mistral 32q/8kv, hd64-small-gqa
16q/4kv) but does NOT divide 7-way GQA (28q/4kv).

Phase 4-T (commit 56a8639) lifted the SM120 pack_gqa=False default
override that PR Dao-AILab#2553 had as a stopgap, and that exposed this latent
qwen2.5-style failure that the original PRs never exercised.

Fix: auto-downgrade pack_gqa=False on SM120 when qhead_per_kvhead does
not divide 128. Other arches (SM80/SM90/SM100) choose tile_m
differently and are not affected; this gate is SM120-only. Verified:
qwen2.5-7b (28/4) now works (norm 97.58); llama3-8b (32/8) still uses
pack_gqa=True (norm 106.19); regression smokes pass.

A future tile-tuning pass (Phase 5) could pick tile_m that's a
multiple of qhead_per_kvhead per-shape and re-enable pack_gqa for
these unusual ratios; the auto-downgrade keeps qwen-family models
working in the meantime.
thad0ctor added a commit to thad0ctor/flash-attention that referenced this pull request May 25, 2026
The hardcoded get_smem_store_atom(80, ...) introduced by upstream PR
Dao-AILab#2553 (commit bc67a9c) lives on FlashAttentionForwardBase, not on
FlashAttentionForwardSm80. FlashAttentionForwardSm90 inherits this
method without override and calls self.epilogue(), so SM90 forward
was silently switched from the WGMMA-paired stmatrix path to the
universal copy — a perf regression on Hopper.

Mirror the arch-gating pattern from flash_bwd_postprocess.py's dQ
store atom: force 80 only when arch//10 in {8, 12}; SM90 keeps
stmatrix (matches its WGMMA register layout).

self.arch in flash_fwd.py is an Arch enum (vs flash_bwd_postprocess
where it's an int), so unpack via .major*10+.minor.

Phase 1 SM120 smokes still pass; SM90 forward perf restored.
thad0ctor added a commit to thad0ctor/flash-attention that referenced this pull request May 25, 2026
… + tile tuning

Makes the three cherry-picked upstream SM120 PRs (Dao-AILab#2553, Dao-AILab#2349, Dao-AILab#2389) actually
usable end-to-end on consumer Blackwell (RTX 5090, RTX PRO 6000 Blackwell). The
upstream PRs alone leave SM120 forward dispatcher-buggy and backward broken;
this commit adds the integration glue, backward support, real paged-KV +
pack_gqa implementations, a subprocess-isolated per-shape tile lookup, and the
test coverage to back it up.

# Dispatcher fixes (SM120 forward + backward couldn't compile or run end-to-end
# without these)

- Initialize dQ_single_wg in the SM120 backward setup (was unbound)
- Keep softmax_scale non-None for SM80/SM120 backward dK epilogue (inline log2
  computation like SM90 does)
- atomic_add_fp32: adopt the new keyword-only nvvm.atomicrmw signature in
  nvidia-cutlass-dsl >= 4.x
- Drop unsupported is_split_kv kwarg on the SM120 forward path
- Pass split_idx=0, num_splits=1, seqlen_info=seqlen at the SM80/SM120
  get_total_block_count call site (arity mismatch fix)
- Rename vec_size -> score_vec_size on the SM120 TMA forward (typo in upstream
  Dao-AILab#2349; was AttributeError on softcap/learnable_sink/score_mod)
- Auto-downgrade pack_gqa=False when qhead_per_kvhead doesn't divide
  tile_m=128 (qwen2.5-7b's 7-way GQA otherwise fails cute.local_tile division)
- Auto-downgrade pack_gqa=False when paged-KV is used (cross-feature
  interaction with PagedKVManager's K/V indexing)
- Route head_dim > head_dim_v to the non-TMA SM120 path: bisection showed the
  hang lives in FlashAttentionForwardSm120Tma, not in the SM80-base mainloop
  as upstream diagnosed. Non-TMA can_implement accepts d > dv; the TMA path
  still rejects it so the dispatcher falls through. d > dv shapes now work
  (verified bitwise-identical to SDPA on the minimum repro).
- Route SM120 through the shared _validate_head_dims helper (invalid head_dim
  was reaching the kernel and faulting with cudaErrorMisalignedAddress)
- Clamp the cu_seqlens[batch_idx+1] read in SeqlenInfoQK.create so SM80/SM120
  over-launched varlen tiles don't fault on a non-resident page
- arch-gate FlashAttentionForwardBase.epilogue smem store atom: SM80/SM120
  force the universal copy, SM90 keeps WGMMA-paired stmatrix (upstream PR
  Dao-AILab#2553's bc67a9c unconditionally forced 80, which silently switched SM90
  forward through the universal-copy path)
- Include sm120_num_stages in the forward compile cache key (different ns
  values with the same tile would otherwise share a key and the second call
  would reuse the first-compiled kernel)
- Document why deterministic backward can't be lifted on SM120 (the SM80
  base kernel itself lacks the dQ_semaphore code path; a feature gap shared
  with SM80)

# SM120-specific kernel work

- Real paged-KV forward via PagedKVManager on the SM80-base kernel,
  supported through head_dim <= 128. A paged-specific tile override
  (128, 128, ns=1) gates on page_table is not None and head_dim <= 128 so
  PagedKVManager's tile_n >= num_threads invariant holds. SMEM math fits:
  48 KB at d=64, 72 KB at d=96, 96 KB at d=128 (cap 99 KB).
- Real pack_gqa=True support: rewrite PackGQA.compute_ptr to compute the
  flat offset arithmetically from stride[0][0] and stride[0][1] rather than
  cute.crd2idx (which cuTeDSL 4.4-4.5 collapses through trailing slices).
  Call pack_gqa_layout in the SM80-base forward so packed Q is actually
  materialized (was missing — would have produced wrong output even after
  the crd2idx workaround).
- Backward postprocess dQ smem-store atom: force universal copy on SM80/SM120
  (same class of bug as the upstream Dao-AILab#2553 forward fix but in the dQ
  postprocess; left silent rmem->smem scrambling otherwise). Permanent
  regression test with a white-box source-inspection guard against
  reintroduction.
- New D > 128 SM120 tile bracket (64, 64, ns=1) that fits the 99 KB SMEM cap
  for head_dim=256.

# Forward tile selection (per-shape lookup)

The SM120 forward dispatch now consults a tile + num_stages lookup keyed on
(head_dim, qhead_per_kvhead, seqlen, causal). Shapes outside the lookup fall
back to the head_dim-only brackets that match the pre-tuning defaults.

The lookup was built from a subprocess-isolated sweep: each (cell, candidate)
pair is measured in a fresh python process so JIT-cache pollution can't bias
the rankings (a single-process sweep silently reuses compiled kernels across
candidates with subtly different shapes). The top-3 candidates per cell get a
reproducibility re-measurement; variance > 10% excludes a candidate. A
candidate ships only when its mean TFLOPS beats the baseline tile by >= 2%;
otherwise the cell falls back to baseline.

# Test coverage added

- tests/cute/test_paged_kv_sm120.py (38 cases): paged-KV correctness across
  page_size {16, 64, 256}, identity / permuted / shared page tables, GQA + MQA,
  d in {64, 96, 128}; expected NotImplementedError for d in {192, 256};
  expected correctness (now, not rejection) for the paged + d > dv + varlen
  cross-feature combination.
- tests/cute/test_flash_attn_bwd_sm120_postprocess.py (10 cases): backward dQ
  postprocess regression suite, combines numeric vs fp32-SDPA comparison with
  a white-box source-inspection guard against the buggy literal pattern.
- tests/cute/test_flash_attn_sm120_dgtdv.py (11 cases): regression test for
  the Bug E d > dv non-TMA routing. 8 kernel-launch parametrizations plus
  3 unit probes (TMA rejection, non-TMA acceptance, SMEM constraint). All
  kernel tests carry pytest-timeout(30) with --timeout-method=signal so a
  future TMA gate widening that re-introduces the GPU hang fails as a timeout
  instead of wedging the GPU.

# What this is NOT

- Real paged-KV at head_dim > 128: rejected at dispatch with a clear
  NotImplementedError. Lifting would require either a refactor of
  PagedKVManager (per-thread page-table fragment > 0 at tile_n < num_threads)
  or a separate kernel; the 99 KB SMEM cap precludes the simple (128, 128, ns=1)
  approach used for d <= 128.
- Real fix for the TMA path d > dv hang: the kernel-level root cause needs
  cuda-gdb or instrumented bisection; the routing fix makes user-visible
  shapes correct today, but the TMA kernel itself is still latent-broken for
  d > dv. The can_implement gate ensures the TMA path is never selected for
  d > dv.
- Deterministic backward on SM120: asserts off because the SM80 base kernel
  itself lacks the dQ_semaphore code path. Lift would need a feature port
  from SM90 into the SM80 base; out of scope here.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant