[Cute,Fwd,Sm120] Fix three forward-pass correctness bugs (use_tma_O, smem-store atom, pack_gqa) #2553
Open
jganbar wants to merge 3 commits into
`FlashAttentionForwardSm80.__call__` sets
    self.use_tma_O = self.arch >= Arch.sm_90
but the base class kernel never constructs a TMA atom for O — it
passes None as `tma_atom_O` to `self.epilogue` (line ~1069). The check
exists because the file once intended to support a Hopper-style TMA-O
path that was never wired up here.
On Hopper / SM_100 hardware this is dead code because those archs use
their own forward classes (`FlashAttentionForwardSm90` /
`FlashAttentionForwardSm100`) with their own `__call__`. But
`FlashAttentionForwardSm120` inherits from this class, and
`FlashAttentionForwardBase.__init__` reads `self.arch` from the DSL,
which is `Arch.sm_120` on consumer Blackwell. The epilogue then takes
the TMA-output branch and crashes inside
`quack.copy_utils.tma_get_copy_fn` -> `cpasync.tma_partition` with
`AttributeError: 'NoneType' object has no attribute '_trait'`.
Static `arch = 80` on `FlashAttentionForwardSm120` was intended to
prevent this but is overwritten by `__init__`.
Force `use_tma_O = False` here; SM90 and SM100 are unaffected because
they have their own `__call__`.
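The attribute shadowing described above can be reproduced in plain Python. The sketch below is not the FA4 code: `Base`, `Sm120Like`, and `detect_arch` are hypothetical stand-ins, and the DSL lookup is assumed to behave like an ordinary assignment inside `__init__`.

```python
def detect_arch():
    """Hypothetical stand-in for the DSL query; pretend we run on SM_120."""
    return 120


class Base:
    def __init__(self):
        # Instance attribute assigned at construction time.
        self.arch = detect_arch()

    def configure(self):
        # Mirrors the buggy check: any arch >= 90 enables the TMA-O path.
        self.use_tma_O = self.arch >= 90
        return self.use_tma_O


class Sm120Like(Base):
    # Class attribute intended to pin the arch to 80 ...
    arch = 80


kernel = Sm120Like()
# ... but the instance attribute set in Base.__init__ shadows it.
print(Sm120Like.arch, kernel.arch, kernel.configure())  # 80 120 True
```

Because instance attributes take precedence over class attributes on lookup, pinning the value at class level cannot win; forcing the flag where it is consumed, as this commit does, avoids the ordering problem.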
Reproduced on RTX 5090 (SM_120, cuTeDSL 4.4.2, torch 2.10.0+cu128).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`FlashAttentionForwardSm80.epilogue` chose the rmem->smem store atom
via `get_smem_store_atom(self.arch.major*10 + self.arch.minor, ...)`,
which returns
- `CopyUniversalOp` for arch < 90 (or non-16-bit data), and
- `StMatrix8x8x16bOp(num_matrices=4)` (Hopper `stmatrix`) for
arch >= 90 on 16-bit data.
`stmatrix` is hardware-paired with WGMMA's output register layout. The
SM80 base class uses `mma.sync.aligned.m16n8k16` whose output register
layout is *not* what `stmatrix` consumes. With WGMMA-output the atom
permutes bytes from a fixed pattern of threads/registers; feeding it
SM80-MMA-output silently scrambles values across nearby register
lanes during the store.
On native SM_80 hardware this branch never fires because the DSL arch
is sm_80 < sm_90. The bug only surfaces when this class is reused on
SM_120 via `FlashAttentionForwardSm120`, where `self.arch` is read
from the DSL as `sm_120` and the >= 90 branch picks `stmatrix`.
Symptom: the kernel completes without error and returns the correct
output shape and a roughly correct output norm (each scrambled value
is replaced by a same-magnitude neighbour), but element-wise diffs vs
fp32 SDPA are 0.5-1.2 (non-causal) and 3.4-3.9 (causal), versus
SDPA-bf16's own ~0.003 and ~0.008. Determinism still holds and error
scales linearly with input magnitude — the precision/permutation
signature, not a logic bug.
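The symptom profile above (norm preserved, large element-wise error, error linear in input magnitude) is what a pure permutation of values produces. The sketch below illustrates that on the CPU with a toy model; `scramble_neighbours` is a hypothetical stand-in that rotates values within groups of four and is not the actual `stmatrix` lane mapping.

```python
import math
import random


def l2_norm(xs):
    return math.sqrt(sum(x * x for x in xs))


def max_abs_diff(a, b):
    return max(abs(x - y) for x, y in zip(a, b))


def scramble_neighbours(xs, group=4):
    """Rotate values by one slot within each group of `group` neighbours."""
    out = []
    for start in range(0, len(xs), group):
        chunk = xs[start:start + group]
        out.extend(chunk[1:] + chunk[:1])
    return out


rng = random.Random(0)
ref = [rng.gauss(0.0, 1.0) for _ in range(1024)]
bad = scramble_neighbours(ref)

# Norm is preserved (same values, different positions) ...
norm_ratio = l2_norm(bad) / l2_norm(ref)
# ... while the element-wise error is on the order of the data itself.
err_1x = max_abs_diff(ref, bad)

# Scaling the input by 10 scales the error by 10 as well.
ref_10x = [10.0 * x for x in ref]
err_10x = max_abs_diff(ref_10x, scramble_neighbours(ref_10x))

print(f"norm ratio {norm_ratio:.6f}, err {err_1x:.3f}, err at 10x {err_10x:.3f}")
```

A check that only compares output norms or shapes would pass on `bad`; only an element-wise comparison against a reference exposes the scramble.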
Fix: pass a fixed `80` to `get_smem_store_atom` here so the SM80 base
class always takes the universal-copy path, matching its actual MMA
output layout.
Verified on RTX 5090 (SM_120, cuTeDSL 4.4.2, torch 2.10.0+cu128):
240/240 correctness configs pass against fp32 SDPA reference across
{fp16, bf16} x {causal, full} x B in {1,2} x S in {128..4096} x
{MHA, GQA, MQA} x D in {64, 128}, with max abs diff matching
SDPA-bf16's own (~0.012 worst case, ~0.0027 mean).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`pack_gqa.compute_ptr` calls
    utils.elem_pointer(tensor, ((h_idx, m_idx),))
with a tensor whose layout is supposed to keep a composite
`(qhead_per_kvhead, seqlen_q)` first mode (created by
`pack_gqa_layout`). The slice `mO[None, 0]` that lands in
`compute_ptr` is meant to preserve that compositeness so the rank-2
coord matches.
On SM_120 with `cuTeDSL==4.4.2` the slice collapses the composite
mode into a rank-1 layout. `cute.crd2idx` then refuses the rank-2
coord and raises at trace time with
    unable to compute crd2idx with
    '!cute.layout<"(?):(?{i64 div=8})">'
    and '!cute.coord<"((?,?))">'
resulting in a `ValueError: Operation creation failed` before the
kernel can run. Every default-policy GQA / MQA shape on consumer
Blackwell hits this.
The non-packed GQA path is numerically identical (pack_gqa is a
perf-only optimization for the GQA Q-load / O-store), so flipping
the auto-default to False on SM_120 makes GQA / MQA work out of the
box while a deeper cuTeDSL fix is investigated. Explicit
`pack_gqa=True` from the caller is still honoured.
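For reference, the index that a rank-2 coordinate into the composite `(qhead_per_kvhead, seqlen_q)` mode is meant to resolve to can be written as plain stride arithmetic. This is a minimal sketch, not the `pack_gqa` implementation: the helper names, the example tensor shape, and the contiguous `(seqlen_q, num_q_heads, head_dim)` storage order are all assumptions made for illustration.

```python
def split_packed_row(row, qhead_per_kvhead):
    """Split a packed row index into (h_idx, m_idx), head index fastest."""
    m_idx, h_idx = divmod(row, qhead_per_kvhead)
    return h_idx, m_idx


def flat_offset(h_idx, m_idx, stride_h, stride_m):
    """Offset of coordinate (h_idx, m_idx) in a mode with strides (stride_h, stride_m)."""
    return h_idx * stride_h + m_idx * stride_m


# Assumed example layout: O stored as (seqlen_q, num_q_heads, head_dim), contiguous.
head_dim = 64
num_q_heads = 8
qhead_per_kvhead = 4
stride_h = head_dim                 # step between adjacent query heads
stride_m = num_q_heads * head_dim   # step between adjacent sequence positions

h_idx, m_idx = split_packed_row(row=9, qhead_per_kvhead=qhead_per_kvhead)
print(h_idx, m_idx, flat_offset(h_idx, m_idx, stride_h, stride_m))  # 1 2 1088
```

Once the slice collapses the mode to rank 1, only a single stride is left, so a two-component coordinate has nothing to map onto; that is the mismatch the trace-time error reports.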
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
thad0ctor added a commit to thad0ctor/flash-attention that referenced this pull request on May 24, 2026:
Phase 4-B2 re-benchmark caught a real regression: all 12 qwen2.5-7b benchmark cells failed with 'failed to perform a valid division of ((7,?),?) by [128:1;128:1]' at flash_fwd.py:920 (cute.local_tile of the packed Q layout by blkQ_shape=(128, 128)). Root cause: qwen2.5-7b has 28 Q heads / 4 KV heads = 7-way GQA. pack_gqa_layout produces a composite-mode-0 Q layout ((qhead_per_kvhead, seqlen_q), head_dim); cute.local_tile must divide this by (tile_m, tile_hdim). For the division to succeed at the qhead_per_kvhead boundary, tile_m must be divisible by qhead_per_kvhead. SM120 uses tile_m=128 unconditionally, which divides 1/2/4/8/16-way GQA (llama3 32q/8kv, mistral 32q/8kv, hd64-small-gqa 16q/4kv) but does NOT divide 7-way GQA (28q/4kv). Phase 4-T (commit 56a8639) lifted the SM120 pack_gqa=False default override that PR Dao-AILab#2553 had as a stopgap, and that exposed this latent qwen2.5-style failure that the original PRs never exercised. Fix: auto-downgrade pack_gqa=False on SM120 when qhead_per_kvhead does not divide 128. Other arches (SM80/SM90/SM100) choose tile_m differently and are not affected; this gate is SM120-only. Verified: qwen2.5-7b (28/4) now works (norm 97.58); llama3-8b (32/8) still uses pack_gqa=True (norm 106.19); regression smokes pass. A future tile-tuning pass (Phase 5) could pick tile_m that's a multiple of qhead_per_kvhead per-shape and re-enable pack_gqa for these unusual ratios; the auto-downgrade keeps qwen-family models working in the meantime.
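The divisibility condition in this commit message is easy to check by hand. The sketch below is a hypothetical gate written for illustration (the function name is an assumption, not the code in the referenced commit); it only encodes the stated rule that `tile_m=128` must be divisible by `qhead_per_kvhead`.

```python
def pack_gqa_allowed(num_q_heads, num_kv_heads, tile_m=128):
    """True when the packed Q layout can be tiled by tile_m rows."""
    qhead_per_kvhead = num_q_heads // num_kv_heads
    return tile_m % qhead_per_kvhead == 0


shapes = {"llama3-8b": (32, 8), "qwen2.5-7b": (28, 4)}
for name, (q_heads, kv_heads) in shapes.items():
    print(name, q_heads // kv_heads, pack_gqa_allowed(q_heads, kv_heads))
# llama3-8b 4 True
# qwen2.5-7b 7 False
```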
thad0ctor added a commit to thad0ctor/flash-attention that referenced this pull request on May 25, 2026:
The hardcoded get_smem_store_atom(80, ...) introduced by upstream PR Dao-AILab#2553 (commit bc67a9c) lives on FlashAttentionForwardBase, not on FlashAttentionForwardSm80. FlashAttentionForwardSm90 inherits this method without override and calls self.epilogue(), so SM90 forward was silently switched from the WGMMA-paired stmatrix path to the universal copy — a perf regression on Hopper. Mirror the arch-gating pattern from flash_bwd_postprocess.py's dQ store atom: force 80 only when arch//10 in {8, 12}; SM90 keeps stmatrix (matches its WGMMA register layout). self.arch in flash_fwd.py is an Arch enum (vs flash_bwd_postprocess where it's an int), so unpack via .major*10+.minor. Phase 1 SM120 smokes still pass; SM90 forward perf restored.
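The gating rule described here can be sketched as follows. Both helpers are hypothetical: `store_atom_arch` encodes the "force 80 only when arch//10 in {8, 12}" rule from the message, and `pick_store_atom` is a toy selector returning strings rather than the real `get_smem_store_atom`.

```python
def store_atom_arch(major, minor):
    """Arch value to hand to the store-atom selector (hypothetical helper)."""
    arch = major * 10 + minor
    # SM8x and SM12x kernels use mma.sync register layouts: force the SM80 path.
    return 80 if arch // 10 in (8, 12) else arch


def pick_store_atom(arch, elem_bits=16):
    """Toy selector mirroring the two branches described in the commit message."""
    return "stmatrix" if arch >= 90 and elem_bits == 16 else "universal_copy"


for major, minor in [(8, 0), (9, 0), (12, 0)]:
    print(major, minor, pick_store_atom(store_atom_arch(major, minor)))
# 8 0 universal_copy
# 9 0 stmatrix
# 12 0 universal_copy
```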
thad0ctor added a commit to thad0ctor/flash-attention that referenced this pull request on May 25, 2026:
… + tile tuning

Makes the three cherry-picked upstream SM120 PRs (Dao-AILab#2553, Dao-AILab#2349, Dao-AILab#2389) actually usable end-to-end on consumer Blackwell (RTX 5090, RTX PRO 6000 Blackwell). The upstream PRs alone leave SM120 forward dispatcher-buggy and backward broken; this commit adds the integration glue, backward support, real paged-KV + pack_gqa implementations, a subprocess-isolated per-shape tile lookup, and the test coverage to back it up.

# Dispatcher fixes (SM120 forward + backward couldn't compile or run end-to-end without these)

- Initialize dQ_single_wg in the SM120 backward setup (was unbound)
- Keep softmax_scale non-None for SM80/SM120 backward dK epilogue (inline log2 computation like SM90 does)
- atomic_add_fp32: adopt the new keyword-only nvvm.atomicrmw signature in nvidia-cutlass-dsl >= 4.x
- Drop unsupported is_split_kv kwarg on the SM120 forward path
- Pass split_idx=0, num_splits=1, seqlen_info=seqlen at the SM80/SM120 get_total_block_count call site (arity mismatch fix)
- Rename vec_size -> score_vec_size on the SM120 TMA forward (typo in upstream Dao-AILab#2349; was AttributeError on softcap/learnable_sink/score_mod)
- Auto-downgrade pack_gqa=False when qhead_per_kvhead doesn't divide tile_m=128 (qwen2.5-7b's 7-way GQA otherwise fails cute.local_tile division)
- Auto-downgrade pack_gqa=False when paged-KV is used (cross-feature interaction with PagedKVManager's K/V indexing)
- Route head_dim > head_dim_v to the non-TMA SM120 path: bisection showed the hang lives in FlashAttentionForwardSm120Tma, not in the SM80-base mainloop as upstream diagnosed. Non-TMA can_implement accepts d > dv; the TMA path still rejects it so the dispatcher falls through. d > dv shapes now work (verified bitwise-identical to SDPA on the minimum repro).
- Route SM120 through the shared _validate_head_dims helper (invalid head_dim was reaching the kernel and faulting with cudaErrorMisalignedAddress)
- Clamp the cu_seqlens[batch_idx+1] read in SeqlenInfoQK.create so SM80/SM120 over-launched varlen tiles don't fault on a non-resident page
- arch-gate FlashAttentionForwardBase.epilogue smem store atom: SM80/SM120 force the universal copy, SM90 keeps WGMMA-paired stmatrix (upstream PR Dao-AILab#2553's bc67a9c unconditionally forced 80, which silently switched SM90 forward through the universal-copy path)
- Include sm120_num_stages in the forward compile cache key (different ns values with the same tile would otherwise share a key and the second call would reuse the first-compiled kernel)
- Document why deterministic backward can't be lifted on SM120 (the SM80 base kernel itself lacks the dQ_semaphore code path; a feature gap shared with SM80)

# SM120-specific kernel work

- Real paged-KV forward via PagedKVManager on the SM80-base kernel, supported through head_dim <= 128. A paged-specific tile override (128, 128, ns=1) gates on page_table is not None and head_dim <= 128 so PagedKVManager's tile_n >= num_threads invariant holds. SMEM math fits: 48 KB at d=64, 72 KB at d=96, 96 KB at d=128 (cap 99 KB).
- Real pack_gqa=True support: rewrite PackGQA.compute_ptr to compute the flat offset arithmetically from stride[0][0] and stride[0][1] rather than cute.crd2idx (which cuTeDSL 4.4-4.5 collapses through trailing slices). Call pack_gqa_layout in the SM80-base forward so packed Q is actually materialized (was missing; would have produced wrong output even after the crd2idx workaround).
- Backward postprocess dQ smem-store atom: force universal copy on SM80/SM120 (same class of bug as the upstream Dao-AILab#2553 forward fix but in the dQ postprocess; left silent rmem->smem scrambling otherwise). Permanent regression test with a white-box source-inspection guard against reintroduction.
- New D > 128 SM120 tile bracket (64, 64, ns=1) that fits the 99 KB SMEM cap for head_dim=256.

# Forward tile selection (per-shape lookup)

The SM120 forward dispatch now consults a tile + num_stages lookup keyed on (head_dim, qhead_per_kvhead, seqlen, causal). Shapes outside the lookup fall back to the head_dim-only brackets that match the pre-tuning defaults.

The lookup was built from a subprocess-isolated sweep: each (cell, candidate) pair is measured in a fresh python process so JIT-cache pollution can't bias the rankings (a single-process sweep silently reuses compiled kernels across candidates with subtly different shapes). The top-3 candidates per cell get a reproducibility re-measurement; variance > 10% excludes a candidate. A candidate ships only when its mean TFLOPS beats the baseline tile by >= 2%; otherwise the cell falls back to baseline.

# Test coverage added

- tests/cute/test_paged_kv_sm120.py (38 cases): paged-KV correctness across page_size {16, 64, 256}, identity / permuted / shared page tables, GQA + MQA, d in {64, 96, 128}; expected NotImplementedError for d in {192, 256}; expected correctness (now, not rejection) for the paged + d > dv + varlen cross-feature combination.
- tests/cute/test_flash_attn_bwd_sm120_postprocess.py (10 cases): backward dQ postprocess regression suite, combines numeric vs fp32-SDPA comparison with a white-box source-inspection guard against the buggy literal pattern.
- tests/cute/test_flash_attn_sm120_dgtdv.py (11 cases): regression test for the Bug E d > dv non-TMA routing. 8 kernel-launch parametrizations plus 3 unit probes (TMA rejection, non-TMA acceptance, SMEM constraint). All kernel tests carry pytest-timeout(30) with --timeout-method=signal so a future TMA gate widening that re-introduces the GPU hang fails as a timeout instead of wedging the GPU.

# What this is NOT

- Real paged-KV at head_dim > 128: rejected at dispatch with a clear NotImplementedError. Lifting would require either a refactor of PagedKVManager (per-thread page-table fragment > 0 at tile_n < num_threads) or a separate kernel; the 99 KB SMEM cap precludes the simple (128, 128, ns=1) approach used for d <= 128.
- Real fix for the TMA path d > dv hang: the kernel-level root cause needs cuda-gdb or instrumented bisection; the routing fix makes user-visible shapes correct today, but the TMA kernel itself is still latent-broken for d > dv. The can_implement gate ensures the TMA path is never selected for d > dv.
- Deterministic backward on SM120: asserts off because the SM80 base kernel itself lacks the dQ_semaphore code path. Lift would need a feature port from SM90 into the SM80 base; out of scope here.
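The shared-memory figures quoted in this commit message (48 KB, 72 KB, 96 KB against a 99 KB cap) can be reproduced with simple arithmetic. The formula below is an assumption made for illustration, not taken from the kernel source: it supposes one Q tile plus `num_stages` K tiles and `num_stages` V tiles resident in shared memory, with 2-byte elements. It happens to match the quoted numbers.

```python
SMEM_CAP_KB = 99  # cap quoted in the commit message


def smem_kb(tile_m, tile_n, head_dim, num_stages=1, bytes_per_elem=2):
    """Assumed footprint: one Q tile plus num_stages K and V tiles."""
    elems = (tile_m + 2 * num_stages * tile_n) * head_dim
    return elems * bytes_per_elem / 1024


# Paged-KV tile override (128, 128, ns=1) at several head dims.
for d in (64, 96, 128, 192, 256):
    kb = smem_kb(128, 128, d)
    print(f"d={d}: {kb:.0f} KB, fits={kb <= SMEM_CAP_KB}")

# The D > 128 bracket (64, 64, ns=1) at head_dim=256.
print(f"(64, 64) at d=256: {smem_kb(64, 64, 256):.0f} KB")
```

Under this assumed model the (128, 128) tile exceeds the cap at d=192 and d=256, which is consistent with paged-KV being rejected above head_dim 128, while the (64, 64) bracket at d=256 lands at 96 KB.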
Summary
Three independent bugs prevent FA4's forward kernel from producing correct output on consumer Blackwell (SM_120) GPUs as of `flash-attn-4==4.0.0b12`. All three reproduce on `main` at the current HEAD.

1. `use_tma_O` was True on SM_120, but the SM80 base class never wires a TMA atom for O: it passes `None` to `self.epilogue`, and the epilogue crashes inside `quack.copy_utils.tma_get_copy_fn`.
2. The rmem->smem store atom was `stmatrix` (chosen by `get_smem_store_atom` for any `arch >= 90`), but the SM80 base class's `mma.sync.aligned.m16n8k16` output register layout is not compatible with it. The kernel completed, returned the correct shape, and the output norm matched SDPA, but element-wise values were wrong by 0.5–3.9 (bf16/fp16, vs SDPA-bf16's own ~0.003).
3. `pack_gqa.compute_ptr` trips `cute.crd2idx` because cuTeDSL 4.4.2 collapses the composite `(qhead_per_kvhead, seqlen_q)` mode in `mO[None, 0]` to a rank-1 layout. Pragmatic mitigation here: default `pack_gqa=False` on SM_120 until the deeper cuTeDSL issue is fixed; explicit `pack_gqa=True` is still honoured.

After this PR, the SM_120 forward kernel matches PyTorch SDPA-bf16's own numerical accuracy across a 240-shape grid (240/240 pass, max abs diffs bit-identical to SDPA's on 219/240 shapes and within `3.5e-5` on the rest) and is ~1.15× faster on average (range 0.74×–1.95×).

Bugs and fixes at a glance
Each commit has a focused message with the full root-cause writeup; this is the summary.

| # | Location | Symptom | Fix |
|---|----------|---------|-----|
| 1 | `flash_fwd.py:652` | `AttributeError: 'NoneType' object has no attribute '_trait'` at trace time | `self.use_tma_O = False` unconditionally in the SM80 base class |
| 2 | `flash_fwd.py:347` | Silent element-wise errors of 0.5–3.9 vs fp32 SDPA; no exception raised | `get_smem_store_atom(80, ...)` so the SM80 base class always uses the universal copy, not the SM90 stmatrix |
| 3 | `interface.py:430` | `ValueError: Operation creation failed` (unable to compute crd2idx) on every default-policy GQA/MQA shape | `pack_gqa=False` on SM_120 |

Bug 2 (the stmatrix mismatch) is the headline fix: silent numerical wrongness is the worst kind of bug to ship, and this one would propagate into training loss / eval quality for any model using FA4 on consumer Blackwell.
How the headline bug was found
The error profile was very specific and ruled out most hypotheses up-front: the kernel completed without error, the output shape and norm were right, results were deterministic, and the error scaled linearly with input magnitude.
That pattern points at a memory transfer using the wrong layout. Grepping `self.arch` for arch-conditional branches in the SM80 base class surfaces exactly one candidate: `get_smem_store_atom` in the epilogue. The branch picks the SM90 `stmatrix` path for any arch ≥ 90, but `stmatrix` is hardware-paired with WGMMA's output register layout, not SM80 MMA's. The patch changes one literal argument.

Test plan
Hardware: NVIDIA GeForce RTX 5090, capability (12, 0), driver 590.48.01, CUDA 12.8.
Software: Python 3.11.15, `torch==2.10.0+cu128`, `nvidia-cutlass-dsl==4.4.2`, this branch installed editable.

Correctness (240 shapes × 3 backends = 720 evaluations, all pass)
FA4's per-shape `max_abs_diff` vs the fp32 SDPA reference is bit-identical to SDPA-bf16's own on 219/240 shapes; the remaining 21 differ by ≤ 3.5e-5 (FP reordering noise).

Before this PR the analogous run was 240 SDPA pass, 80 FA4 "pass" (MHA only: the kernel reported success but actually produced element-wise diffs of 0.5–3.9 against the fp32 reference, which the harness flagged as `fail`), and 160 FA4 error (GQA/MQA, the `pack_gqa` cute crash).

Performance
Same harness, 10 warmup + 50 measured iters per cell, `torch.cuda.Event` timing, hard per-process memory cap so we don't disturb other GPU tenants.

- `sdpa_bshd` (PyTorch eager + repeat_interleave for GQA)
- `fa4` (this PR)

FA4 speedup vs `sdpa_bshd` across 240 shapes: mean 1.15×, median 1.07×, max 1.95×, min 0.74×. Biggest FA4 wins are GQA at small S (SDPA pays for `repeat_interleave`); the few shapes where FA4 still loses are S=512 D=128 causal MHA, which is a tile-size tuning opportunity orthogonal to this PR.

Minimal reproducer for the silent-wrongness bug
Limitations / out of scope
- `flash_bwd_postprocess.py:537` calls `get_smem_store_atom(self.arch.major*10+self.arch.minor, ...)` and almost certainly needs the analogous fix when going through `FlashAttentionBackwardSm120`. A backward-correctness check (autograd + `grad.allclose(grad_ref)`) should follow as a separate PR.
- This PR only defaults `pack_gqa=False` on SM_120; it does not fix `pack_gqa.compute_ptr` or the cuTeDSL behaviour itself. A proper fix lives in either `pack_gqa.compute_ptr` (compute the flat index directly, without relying on the slice preserving compositeness) or in cuTeDSL.

🤖 Generated with Claude Code