FA4 consumer Blackwell (sm_120) integration: forward + backward + dispatcher fixes #1
thad0ctor wants to merge 16 commits into
`FlashAttentionForwardSm80.__call__` sets
self.use_tma_O = self.arch >= Arch.sm_90
but the base class kernel never constructs a TMA atom for O — it
passes None as `tma_atom_O` to `self.epilogue` (line ~1069). The check
exists because the file once intended to support a Hopper-style TMA-O
path that was never wired up here.
On Hopper / SM_100 hardware this is dead code because those archs use
their own forward classes (`FlashAttentionForwardSm90` /
`FlashAttentionForwardSm100`) with their own `__call__`. But
`FlashAttentionForwardSm120` inherits from this class, and
`FlashAttentionForwardBase.__init__` reads `self.arch` from the DSL,
which is `Arch.sm_120` on consumer Blackwell. The epilogue then takes
the TMA-output branch and crashes inside
`quack.copy_utils.tma_get_copy_fn` -> `cpasync.tma_partition` with
`AttributeError: 'NoneType' object has no attribute '_trait'`.
Static `arch = 80` on `FlashAttentionForwardSm120` was intended to
prevent this but is overwritten by `__init__`.
Force `use_tma_O = False` here; SM90 and SM100 are unaffected because
they have their own `__call__`.
Reproduced on RTX 5090 (SM_120, cuTeDSL 4.4.2, torch 2.10.0+cu128).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
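The attribute shadowing described in this commit message can be reproduced in plain Python. The sketch below is illustrative only: `Base`, `Sm120Broken`, `Sm120Fixed`, and `detect_arch` are hypothetical stand-ins rather than the FA4 classes, and the integers 80 and 120 stand in for `Arch` enum values.

```python
def detect_arch():
    """Hypothetical stand-in for the DSL's arch query on consumer Blackwell."""
    return 120


class Base:
    arch = 80  # class-level default

    def __init__(self):
        # The instance attribute set here shadows any class-level `arch`.
        self.arch = detect_arch()

    def uses_tma_o(self):
        return self.arch >= 90


class Sm120Broken(Base):
    arch = 80  # no effect: Base.__init__ overwrites it on every instance


class Sm120Fixed(Base):
    def __init__(self):
        super().__init__()
        self.arch = 80  # re-pin after the parent initializer has run


print(Sm120Broken().uses_tma_o())  # True: the SM90+ branch is taken
print(Sm120Fixed().uses_tma_o())   # False
```

Re-pinning the attribute is only one way to avoid the branch; this commit takes the narrower route of forcing `use_tma_O = False` at the call site.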
`FlashAttentionForwardSm80.epilogue` chose the rmem->smem store atom
via `get_smem_store_atom(self.arch.major*10 + self.arch.minor, ...)`,
which returns
- `CopyUniversalOp` for arch < 90 (or non-16-bit data), and
- `StMatrix8x8x16bOp(num_matrices=4)` (Hopper `stmatrix`) for
arch >= 90 on 16-bit data.
`stmatrix` is hardware-paired with WGMMA's output register layout. The
SM80 base class uses `mma.sync.aligned.m16n8k16` whose output register
layout is *not* what `stmatrix` consumes. With WGMMA-output the atom
permutes bytes from a fixed pattern of threads/registers; feeding it
SM80-MMA-output silently scrambles values across nearby register
lanes during the store.
On native SM_80 hardware this branch never fires because the DSL arch
is sm_80 < sm_90. The bug only surfaces when this class is reused on
SM_120 via `FlashAttentionForwardSm120`, where `self.arch` is read
from the DSL as `sm_120` and the >= 90 branch picks `stmatrix`.
Symptom: the kernel completes without error and returns the correct
output shape and a roughly correct output norm (each scrambled value
is replaced by a same-magnitude neighbour), but element-wise diffs vs
fp32 SDPA are 0.5-1.2 (non-causal) and 3.4-3.9 (causal), versus
SDPA-bf16's own ~0.003 and ~0.008. Determinism still holds and error
scales linearly with input magnitude — the precision/permutation
signature, not a logic bug.
Fix: pass a fixed `80` to `get_smem_store_atom` here so the SM80 base
class always takes the universal-copy path, matching its actual MMA
output layout.
Verified on RTX 5090 (SM_120, cuTeDSL 4.4.2, torch 2.10.0+cu128):
240/240 correctness configs pass against fp32 SDPA reference across
{fp16, bf16} x {causal, full} x B in {1,2} x S in {128..4096} x
{MHA, GQA, MQA} x D in {64, 128}, with max abs diff matching
SDPA-bf16's own (~0.012 worst case, ~0.0027 mean).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
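The symptom pattern above (norm preserved, large element-wise error, error linear in input magnitude) is what any value permutation produces. The toy sketch below is not the `stmatrix` layout; it assumes a simple adjacent-pair swap as a stand-in for the register scramble, purely to show why a norm check alone would not catch this class of bug.

```python
import math
import random


def swap_adjacent(values):
    """Swap each pair of neighbours, a toy stand-in for a layout mismatch."""
    out = list(values)
    for i in range(0, len(out) - 1, 2):
        out[i], out[i + 1] = out[i + 1], out[i]
    return out


def l2_norm(values):
    return math.sqrt(sum(v * v for v in values))


def max_abs_diff(a, b):
    return max(abs(x - y) for x, y in zip(a, b))


random.seed(0)
reference = [random.gauss(0.0, 1.0) for _ in range(1024)]
scrambled = swap_adjacent(reference)

# Same multiset of values, so the norm matches up to rounding.
norm_ref, norm_bad = l2_norm(reference), l2_norm(scrambled)
# Element-wise comparison exposes the permutation.
diff = max_abs_diff(reference, scrambled)

# Scaling the input scales the error by the same factor.
scaled = [10.0 * v for v in reference]
diff_scaled = max_abs_diff(scaled, swap_adjacent(scaled))
```

A correctness test for this kind of kernel therefore needs an element-wise comparison against a reference, not just shape and norm checks.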
`pack_gqa.compute_ptr` calls
utils.elem_pointer(tensor, ((h_idx, m_idx),))
with a tensor whose layout is supposed to keep a composite
`(qhead_per_kvhead, seqlen_q)` first mode (created by
`pack_gqa_layout`). The slice `mO[None, 0]` that lands in
`compute_ptr` is meant to preserve that compositeness so the rank-2
coord matches.
On SM_120 with `cuTeDSL==4.4.2` the slice collapses the composite
mode into a rank-1 layout. `cute.crd2idx` then refuses the rank-2
coord and raises at trace time with
unable to compute crd2idx with
'!cute.layout<"(?):(?{i64 div=8})">'
and '!cute.coord<"((?,?))">'
resulting in a `ValueError: Operation creation failed` before the
kernel can run. Every default-policy GQA / MQA shape on consumer
Blackwell hits this.
The non-packed GQA path is numerically identical (pack_gqa is a
perf-only optimization for the GQA Q-load / O-store), so flipping
the auto-default to False on SM_120 makes GQA / MQA work out of the
box while a deeper cuTeDSL fix is investigated. Explicit
`pack_gqa=True` from the caller is still honoured.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
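The default-resolution rule described above can be sketched as a small pure function. `resolve_pack_gqa` is a hypothetical helper, not the code in `interface.py`, and the non-SM_120 default (pack whenever several Q heads share a KV head) is an assumption made for illustration.

```python
from typing import Optional


def resolve_pack_gqa(pack_gqa: Optional[bool], arch: int, qhead_per_kvhead: int) -> bool:
    """Hypothetical helper: pick the effective pack_gqa setting."""
    if pack_gqa is not None:
        return pack_gqa  # an explicit caller choice is always honoured
    if arch // 10 == 12:
        return False  # auto-default is off on SM_120 (workaround)
    # Assumed default elsewhere: pack only when Q heads share a KV head.
    return qhead_per_kvhead > 1


print(resolve_pack_gqa(None, 120, 4))  # False
print(resolve_pack_gqa(True, 120, 4))  # True
```

Keeping `None` as the "auto" sentinel lets the workaround change the default without overriding callers who opt in explicitly.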
Add FlashAttentionForwardSm120Tma class that uses TMA (cp.async.bulk) for Q/K/V loads with 1 DMA warp + 4 MMA warps, enabling producer/consumer overlap via PipelineTmaAsync with mbarrier synchronization.

Key design:
- TMA-compatible SMEM swizzle: Swizzle(B, 4, 3) instead of (B, 3, 3)
- KV double-buffering (kv_stages=2), 160 threads (5 warps), 99KB SMEM
- All pipeline operations inlined in the mainloop (not delegated to a separate @cute.jit method), which avoids CuTe DSL compiler hangs when pipeline states flow through method boundaries
- is_first=False with pre-reset softmax state eliminates the need for a compile-time is_first flag in the single-loop mainloop
- Dispatch: TMA default for SM120 non-paged, non-varlen. Falls back to CpAsync for paged KV and varlen (TMA addressing constraints).

Validated on SM121a (DGX Spark):
- 8/8 configs pass: non-causal + causal, B=1/2, Sq=64/128/256, Sk=128/256/512, H=4/8, D=128
- All diffs 0.002-0.008 vs reference

Contributed by Second Nature Computing (https://joinsecondnature.com)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Override self.arch = Arch.sm_80 after parent __init__ to prevent base class code paths from seeing the runtime arch (12.x) and enabling SM90+ features. The parent __init__ overwrites the class-level arch=80 attribute with the actual GPU arch. This was found by @2imi9 in Dao-AILab#2420 for the CpAsync kernel; the same bug applies here.

Add can_implement() check before TMA dispatch in interface.py so that configs exceeding SM120's 99KB SMEM (e.g. hdim=192 with kv_stages=2) fall back to the CpAsync kernel instead of failing at instantiation.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Blake Ledden <blake@secondnaturecomputing.com>
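A rough sketch of the check-then-fall-back dispatch follows. It is not the real `can_implement()`: the 64x64 tile size, the assumption that `head_dim_v == head_dim`, and the byte accounting (one Q tile plus `kv_stages` buffers each for K and V, 16-bit elements, no barrier or epilogue storage) are all illustrative assumptions. Under those assumptions the arithmetic is consistent with the hdim=192 example in the commit message.

```python
SM120_SMEM_CAP_BYTES = 99 * 1024  # the 99KB cap quoted in the commit message


def tma_smem_bytes(head_dim, tile_m, tile_n, kv_stages, elem_bytes=2):
    """Rough estimate: one Q tile plus kv_stages buffers each for K and V."""
    q_bytes = tile_m * head_dim * elem_bytes
    kv_bytes = 2 * kv_stages * tile_n * head_dim * elem_bytes
    return q_bytes + kv_bytes


def pick_kernel(head_dim, tile_m=64, tile_n=64, kv_stages=2):
    """Hypothetical dispatcher: prefer TMA, fall back when SMEM does not fit."""
    if tma_smem_bytes(head_dim, tile_m, tile_n, kv_stages) <= SM120_SMEM_CAP_BYTES:
        return "tma"
    return "cp_async"


print(pick_kernel(128))  # tma      (81920 bytes estimated)
print(pick_kernel(192))  # cp_async (122880 bytes estimated)
```

Checking before instantiation turns a hard failure into a slower but working path, which is the behaviour the commit describes.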
The base FlashAttentionForwardSm80.__call__ and FlashAttentionForwardSm100.__call__ both keep `stream` as the final parameter, with a comment: "Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI)". cute.compile binds arguments positionally against the compile_args list in interface.py, which ends with `current_stream`.

The TMA kernel had `stream` at position 7 (right after softmax_scale). On this branch alone the kernel still works as advertised, but the mismatch breaks when composed with PRs that append further positional arguments to the compile path, most visibly when combined with Dao-AILab#2348's paged-KV plumbing or Dao-AILab#2439's dropout seeds, where the extra positions push `current_stream` onto a parameter that no longer exists or has the wrong type.

Aligning with the base-class convention is mechanical and preserves correctness in isolation:

Validation on SM121a (DGX Spark GB10), causal ∈ {False, True}, dtype ∈ {bf16, fp16}, B=1 S=256 H=4 D=64:
causal=False bf16: max_diff=0.0020 PASS
causal=False fp16: max_diff=0.0002 PASS
causal=True bf16: max_diff=0.0078 PASS
causal=True fp16: max_diff=0.0010 PASS

Signed-off-by: Blake Ledden <blake@secondnaturecomputing.com>
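The positional-binding hazard can be shown with ordinary Python signatures. Both function signatures and the `compile_args` list below are hypothetical and much shorter than the real ones; the only property carried over from the commit message is that the stream is the last compile argument and that the TMA kernel had `stream` in 7th position.

```python
import inspect


def kernel_stream_mid(q, k, v, o, lse, softmax_scale, stream, window=None):
    """Hypothetical signature with `stream` at position 7."""


def kernel_stream_last(q, k, v, o, lse, softmax_scale, window=None, stream=None):
    """Hypothetical signature following the stream-last convention."""


# Hypothetical compile-time argument list with one extra positional
# argument appended before the stream, as a composed PR might do.
compile_args = ["q", "k", "v", "o", "lse", 0.125, "window_arg", "current_stream"]

bound_mid = inspect.signature(kernel_stream_mid).bind(*compile_args).arguments
bound_last = inspect.signature(kernel_stream_last).bind(*compile_args).arguments

print(bound_mid["stream"])   # window_arg (wrong object bound to stream)
print(bound_last["stream"])  # current_stream
```

With only seven compile arguments both signatures bind the stream correctly, which matches the observation that the mismatch is harmless on this branch alone.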
Block-sparse attention processes only the KV blocks specified by block_sparse_tensors rather than the full KV sequence. Two block types are supported: mask_blocks (partially masked, apply mask_mod per element) and full_blocks (fully unmasked, skip masking entirely).

Design follows the same mma_one_n_block callback pattern as SM90/SM100. The SM80 base class gets a new mma_one_n_block_bs method (load K, load V, wait, GEMM QK, score_mod, mask, softmax, GEMM PV) and a corresponding run_block_sparse_mainloop_sm80 utility in block_sparse_utils.py that iterates mask blocks then full blocks, mirroring consume_block_sparse_loads.

Key implementation details:
- run_block_sparse_mainloop_sm80: iterate mask_blocks first (highest n), then full_blocks. First full block always gets mask_seqlen=True since it may be at a higher n position than any mask block.
- mma_one_n_block_bs: no pipeline overlap (block address unknown ahead of time), load K then V with separate cp_async_wait_group(1)/wait_group(0).
- SM120 inherits SM80 base class and gets block sparsity for free.
- FlashAttentionForwardSm120.__init__ forces self.arch = Arch.sm_80 to prevent the SM80 epilogue from using TMA-O (which would crash on SM121a since tma_atom_O is None in this kernel variant).
- SM120: num_splits clamped to 1 in interface.py (no SplitKV support yet).
- Block sparsity assert removed from SM120 interface path.

Validated on SM121a (DGX Spark GB10):
- test_block_sparsity.py: 4621 passed, 40 skipped
- causal/non-causal, various head dims and sequence lengths

Contributed by Second Nature Computing (https://joinsecondnature.com)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
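The iteration order described in this commit can be sketched as a plain generator. This is a host-side illustration, not the `@cute.jit` utility: `block_sparse_schedule` is a hypothetical name, the index lists are assumed to be already ordered as the kernel would visit them, and setting `mask_seqlen` only on the first mask block is an assumption (the commit message states the rule for full blocks only).

```python
def block_sparse_schedule(mask_block_idx, full_block_idx):
    """Yield (n_block, apply_mask_mod, mask_seqlen) in processing order.

    Assumption: among mask blocks, only the first one gets mask_seqlen.
    """
    for i, n_block in enumerate(mask_block_idx):
        # Partially masked blocks apply mask_mod per element.
        yield n_block, True, i == 0
    for i, n_block in enumerate(full_block_idx):
        # Full blocks skip mask_mod; the first still gets the seqlen mask
        # because it may sit at a higher n than any mask block.
        yield n_block, False, i == 0


schedule = list(block_sparse_schedule(mask_block_idx=[5, 2], full_block_idx=[7, 3]))
for entry in schedule:
    print(entry)
```

In this example the first full block (n=7) lies beyond every mask block, which is the case the `mask_seqlen=True` rule is meant to cover.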
Replace inline blocksparse_tensors[0]/[2] index access with the get_total_block_count() utility from block_sparse_utils.py. This keeps variable naming consistent with the rest of the block-sparse codebase (which unpacks by name, not by index position).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
📝 Walkthrough
Adds an SM120 TMA forward path, SM80 block‑sparse mainloop and per‑block compute, updates SM120/SM80 dispatch and epilogue/backward glue, fixes PackGQA pointer math and seqlen safety, updates atomic API usage, and adds SM120 regression tests for paged‑KV and backward postprocess.

Changes
- Block-sparse mainloop for SM80/SM120
- SM120 TMA forward kernel with warp specialization
- SM120 dispatch, interface, and forward-sm120 changes
- SM80 forward kernel adjustments and paged‑KV notes
- SM80/SM120 hardening, backward, and utilities
- PackGQA and seqlen safety fixes
- SM120-focused regression tests
🎯 4 (Complex) | ⏱️ ~60 minutes
Actionable comments posted: 2
🧹 Nitpick comments (4)
flash_attn/cute/block_sparse_utils.py (1)
708-758: ⚡ Quick win
Varlen path not supported – direct 4D indexing duplicates helper logic.
`run_block_sparse_mainloop_sm80` hard-codes the non-varlen 4D indexing pattern (lines 750-758) instead of using `get_curr_blocksparse_tensors`, which handles both varlen (2D) and non-varlen (4D) layouts. Other consumers like `consume_block_sparse_loads` (line 403) properly delegate to the helper.
If varlen + block-sparse on SM80/SM120 is intended to be supported later, the function should accept `seqlen_info` and use the existing helper. If not supported, consider adding a compile-time assertion.
♻️ Suggested refactor to use helper

     @cute.jit
     def run_block_sparse_mainloop_sm80(
         blocksparse_tensors: BlockSparseTensors,
         batch_idx,
         head_idx,
         m_block,
         mma_one_n_block,
         mask_fn,
         mask_mod,
         fastdiv_mods,
    +    seqlen_info: SeqlenInfoQK,
         qhead_per_kvhead: cutlass.Constexpr[int] = 1,
         q_subtile_factor: cutlass.Constexpr[int] = 1,
     ):
         ...
    -    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx, *_ = blocksparse_tensors
    -    m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor)
    -
    -    curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse]
    -    curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]
    -
    -    if const_expr(full_block_cnt is not None):
    -        curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse]
    -        curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None]
    -    else:
    -        curr_full_block_cnt = Int32(0)
    -        curr_full_block_idx = None
    +    (
    +        curr_mask_block_cnt,
    +        curr_mask_block_idx,
    +        curr_full_block_cnt,
    +        curr_full_block_idx,
    +    ) = get_curr_blocksparse_tensors(
    +        batch_idx, head_idx, m_block_sparse, blocksparse_tensors, seqlen_info
    +    )

flash_attn/cute/flash_fwd_sm120_tma.py (3)
566-566: 💤 Low value
Remove unused variable `n_block`.
This variable is computed but never used in the kernel. The loop at line 784 computes `cur_n_block = n_block_max - n_tile - 1` directly.
🧹 Suggested cleanup

     n_block_min, n_block_max = block_info.get_n_block_min_max(
         seqlen, m_block, split_idx, num_splits
     )
    -n_block = cutlass.max(n_block_max - 1, 0)

23-44: 💤 Low value
Remove unused imports.
Several imports are flagged by static analysis as unused:
- `Constexpr` (line 23): used as `cutlass.Constexpr` instead of the bare name
- `PackGQA` (line 37)
- `NamedBarrierFwd` (line 38)
- `FastDivmodDivisor` (line 44)
🧹 Suggested cleanup

    -from cutlass import Constexpr, Float32, Int32, const_expr
    +from cutlass import Float32, Int32, const_expr

    -from flash_attn.cute.pack_gqa import PackGQA
    -from flash_attn.cute.named_barrier import NamedBarrierFwd

    -from cutlass.cute import FastDivmodDivisor

914-1003: 💤 Low value
`mma_one_n_block` method is unused.
This method is defined but never called. The kernel at lines 784–837 has the same logic inlined directly, with a comment at lines 780–783 explaining this is intentional to avoid CuTe DSL compiler hangs when pipeline states flow through method boundaries.
If this is dead code from development, consider removing it to reduce maintenance burden. If it's intended for future use, consider adding a comment indicating that.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@flash_attn/cute/flash_fwd.py`:
- Around line 1384-1396: The call to apply_score_mod is passing seqlen as the
positional argument where softmax_scale belongs, causing an argument-order bug;
update the call in the const_expr(score_mod is not None) branch to pass
softmax_scale and seqlen by keyword (e.g., softmax_scale=softmax.softmax_scale,
seqlen=seqlen) or otherwise ensure softmax_scale is the 7th positional and
seqlen the 8th; reference the apply_score_mod invocation that currently uses
mma_params.thr_mma_qk, batch_idx, head_idx, m_block, acc_S, n_block, seqlen,
softmax_scale=... and mirror the correct ordering used in compute_one_n_block.
In `@flash_attn/cute/interface.py`:
- Around line 954-960: The local reassignment of is_varlen near the TMA kernel
selection overrides the earlier definition that includes seqused_q/seqused_k;
remove the redefinition (the lines that set is_varlen = cu_seqlens_q is not None
or cu_seqlens_k is not None) and use the previously computed is_varlen when
evaluating use_tma_sm120 (alongside page_table and use_block_sparsity) so
callers providing seqused_q/seqused_k are correctly treated as varlen.
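The argument-order problem in the first inline comment is easy to reproduce in plain Python. The `apply_score_mod` below is a hypothetical stand-in that only mirrors the parameter order stated in the review (`softmax_scale` 7th, `seqlen` 8th); the remaining arguments are placeholders.

```python
def apply_score_mod(thr_mma, batch_idx, head_idx, m_block, acc_S, n_block,
                    softmax_scale, seqlen):
    """Hypothetical stand-in with softmax_scale 7th and seqlen 8th."""
    return {"softmax_scale": softmax_scale, "seqlen": seqlen}


args = ("thr_mma", 0, 0, 0, "acc_S", 0)

try:
    # Buggy pattern: seqlen lands in the softmax_scale slot positionally,
    # then softmax_scale is passed again by keyword.
    apply_score_mod(*args, 512, softmax_scale=0.125)
    error = None
except TypeError as exc:
    error = str(exc)

# Passing both by keyword is immune to the parameter ordering.
fixed = apply_score_mod(*args, softmax_scale=0.125, seqlen=512)

print(error)
print(fixed)
```

Standard Python rejects the buggy call outright; whether a DSL tracer does the same is a separate question that this sketch does not answer.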
---
Nitpick comments:
In `@flash_attn/cute/block_sparse_utils.py`:
- Around line 708-758: run_block_sparse_mainloop_sm80 duplicates non-varlen 4D
indexing for mask/full block tensors instead of using the shared helper; either
update run_block_sparse_mainloop_sm80 to accept seqlen_info and call
get_curr_blocksparse_tensors(...) (the same helper used by
consume_block_sparse_loads) so it correctly handles both varlen (2D) and
non-varlen (4D) layouts, or add a compile-time assertion in
run_block_sparse_mainloop_sm80 that varlen layouts are not supported; locate the
logic around curr_mask_block_cnt/curr_mask_block_idx and
curr_full_block_cnt/curr_full_block_idx and replace it with the helper call (or
assert) accordingly.
In `@flash_attn/cute/flash_fwd_sm120_tma.py`:
- Line 566: Remove the unused local variable n_block (the expression n_block =
cutlass.max(n_block_max - 1, 0)) since it is never referenced later; update the
surrounding code to rely on the existing computation cur_n_block = n_block_max -
n_tile - 1 (and any uses of n_block) so only n_block_max, n_tile, and
cur_n_block remain; ensure no other code paths reference n_block before deleting
the assignment and related dead-code.
- Around line 23-44: Remove the unused imports to clean up the module: delete
the bare imports Constexpr, PackGQA, NamedBarrierFwd, and FastDivmodDivisor from
the top import block; ensure code that currently refers to cutlass.Constexpr
continues to use the qualified name (cutlass.Constexpr) instead of the removed
bare Constexpr, and verify there are no references to PackGQA, NamedBarrierFwd,
or FastDivmodDivisor elsewhere in this file (e.g., search for PackGQA,
NamedBarrierFwd, FastDivmodDivisor) before removing them to avoid breaking
references.
- Around line 914-1003: The method mma_one_n_block is defined but never used
(its logic is inlined inside the kernel to avoid CuTe DSL compiler hangs);
either remove this dead function to reduce maintenance or keep it but add a
clear comment above mma_one_n_block stating it is intentionally unused and kept
for reference/future use (mention the kernel that inlines the logic and the
compiler-hang rationale), so future readers know why it remains in the codebase.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 9c771396-cc50-48ee-b2a8-55805d806e5e
📒 Files selected for processing (7)
- flash_attn/cute/block_sparse_utils.py
- flash_attn/cute/flash_bwd.py
- flash_attn/cute/flash_fwd.py
- flash_attn/cute/flash_fwd_sm120.py
- flash_attn/cute/flash_fwd_sm120_tma.py
- flash_attn/cute/interface.py
- flash_attn/cute/utils.py
…shadowing

1. flash_fwd.py:1392: apply_score_mod() was called with seqlen as the 7th positional argument, where the signature places softmax_scale. The original call then passed softmax_scale=... as a keyword, which would have raised 'multiple values for argument softmax_scale' under strict Python, or silently bound seqlen as the softmax_scale under CuTeDSL's relaxed semantics (producing wildly wrong attention scores in the score_mod path). Fix by passing both softmax_scale and seqlen by keyword to match the correct call pattern in compute_one_n_block at lines 1254-1266.

2. interface.py:955: the SM120 dispatch block redefined is_varlen with a narrower check (cu_seqlens_q/cu_seqlens_k only), shadowing the outer-scope is_varlen defined at lines 626-631 which correctly includes seqused_q and seqused_k. A caller passing seqused_q/seqused_k without cu_seqlens would have been silently routed to the TMA kernel, which does not support varlen, producing wrong output or a crash. Remove the local redefinition so the outer is_varlen is used; add a comment so this doesn't get re-introduced.

Both flagged by CodeRabbit on PR #1.

Regression smokes pass: fwd: max abs diff 0.003906 vs SDPA, bwd: max abs diff 0.007812.
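The effect of the narrower `is_varlen` redefinition on dispatch can be sketched with three small functions. All three are hypothetical simplifications: the real checks live inline in `interface.py`, and the exact form of the outer definition is assumed from the description (any of the four tensors being present marks the call as varlen).

```python
def is_varlen_outer(cu_seqlens_q=None, cu_seqlens_k=None, seqused_q=None, seqused_k=None):
    """Broad check: any of the four tensors marks the call as varlen."""
    return any(t is not None for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k))


def is_varlen_narrow(cu_seqlens_q=None, cu_seqlens_k=None, **_ignored):
    """The shadowing redefinition: looks at cu_seqlens only."""
    return cu_seqlens_q is not None or cu_seqlens_k is not None


def use_tma_sm120(is_varlen, page_table=None, use_block_sparsity=False):
    """TMA path is only eligible for non-varlen, non-paged, dense calls."""
    return not is_varlen and page_table is None and not use_block_sparsity


# A caller that passes seqused_k without any cu_seqlens tensor.
call = {"seqused_k": [3, 7]}

print(use_tma_sm120(is_varlen_narrow(**call)))  # True: wrongly routed to TMA
print(use_tma_sm120(is_varlen_outer(**call)))   # False: falls back as intended
```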
flash_fwd_sm120_tma.py:
- Remove unused imports: Constexpr (the file uses cutlass.Constexpr via qualified name), PackGQA, NamedBarrierFwd, FastDivmodDivisor.
- Remove unused local n_block at line 566: the kernel's compute loop computes cur_n_block = n_block_max - n_tile - 1 directly and never references n_block.
- Add comment above mma_one_n_block explaining why the method is kept despite being unused: its logic is inlined into kernel() (around lines 784-837) to avoid a CuTe DSL compiler hang when k_pipeline/v_pipeline consumer states flow through a method boundary. Kept as a reference template.

block_sparse_utils.py:
- run_block_sparse_mainloop_sm80 hard-codes non-varlen 4D indexing into the blocksparse_tensors NamedTuple. Add a docstring note documenting why this is intentionally narrow (the SM120 dispatcher in interface.py does not enable varlen + block-sparse together) and what would need to change to lift it (route through get_curr_blocksparse_tensors with seqlen_info threaded in).

Regression smokes pass: fwd max abs diff 0.003906, bwd 0.007812 vs SDPA.
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tests/cute/test_flash_attn_bwd_sm120_postprocess.py`:
- Around line 71-73: The inline continuation comment for the variable S in the
test (around the assignment "S = 1024") is misaligned and triggers Flake8
E114/E116; update the comment so its indentation matches the wrapped statement's
indentation (align the comment with the start of the "S = 1024" line or the
wrapped continuation indent) to satisfy linting rules, touching the
test_flash_attn_bwd_sm120_postprocess.py section containing S = 1024 and the
following comment.
In `@tests/cute/test_paged_kv_sm120.py`:
- Around line 36-38: Remove the unused imports causing Flake8 F401 in
tests/cute/test_paged_kv_sm120.py by deleting the import lines for math and os
(leave the necessary import for sys if used elsewhere); ensure only required
modules are imported so the top-of-file import block no longer includes math or
os.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: b8475edd-852f-436c-8947-c425cf37be30
📒 Files selected for processing (10)
- flash_attn/cute/block_sparse_utils.py
- flash_attn/cute/flash_bwd_postprocess.py
- flash_attn/cute/flash_fwd.py
- flash_attn/cute/flash_fwd_sm120.py
- flash_attn/cute/flash_fwd_sm120_tma.py
- flash_attn/cute/interface.py
- flash_attn/cute/pack_gqa.py
- flash_attn/cute/seqlen_info.py
- tests/cute/test_flash_attn_bwd_sm120_postprocess.py
- tests/cute/test_paged_kv_sm120.py
CodeRabbit's second review on PR #1 flagged two lint issues introduced by the Phase 4-S bwd-postprocess test and the Phase 4-R paged-KV test:

* tests/cute/test_flash_attn_bwd_sm120_postprocess.py:72
  The wrapped comment on the second line of `S = 1024 # ...` was indented to align with the value (column 15), which trips E114 ("indentation is not a multiple of x (comment)") and E116 ("unexpected indentation (comment)"). Convert it to a normal 4-space-indented comment block above the assignment.

* tests/cute/test_paged_kv_sm120.py:36-37
  `import math` and `import os` were left over from an earlier draft of the regression test and are not referenced anywhere in the file (F401). Drop both.

No functional change to either test. flake8 with --select=E114,E116,F401 is now clean on both files.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flash_attn/cute/interface.py (1)
1105-1105: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win
Phase 5c tuned `num_stages` values are never applied to SM120 kernels.
The `sm120_num_stages` variable is computed at line 585 from the `_SM120_TILE_LOOKUP` (e.g., `(64, 1, 2048, 1): (64, 64, 2)` sets `num_stages=2`), but line 1105 hardcodes `num_stages=1` instead of using the tuned value.
This means the Phase 5c performance tuning for SM120 is ineffective.
🐛 Proposed fix

     fa_fwd = FlashAttentionForwardSm120(
         dtype,
         head_dim,
         head_dim_v,
         qhead_per_kvhead,
         is_causal=causal,
         is_local=local,
         pack_gqa=pack_gqa,
         tile_m=tile_m,
         tile_n=tile_n,
    -    num_stages=1,
    +    num_stages=sm120_num_stages,
         num_threads=num_threads,
         Q_in_regs=False,
         score_mod=score_mod,
         mask_mod=mask_mod,
         has_aux_tensors=aux_tensors is not None,
     )

🧹 Nitpick comments (1)
flash_attn/cute/interface.py (1)
899-901: 💤 Low value
Comment is misleading: Phase 5c lookup does not run for SM80.
The comment claims `num_stages` is set by "Phase 5c per-shape lookup," but the `_SM120_TILE_LOOKUP` logic (lines 535–592) is guarded by `if arch // 10 == 12:`, so it never executes for SM80. For SM80, `sm120_num_stages` stays at the default value of 1.
Consider updating the comment to reflect the actual behavior:

    -    # num_stages set by Phase 5c per-shape lookup above (defaults
    -    # to 1; bumped to 2 for shapes where K/V pipelining wins).
    +    # SM80 uses 1 stage; SM120 overrides this in its own branch.
         num_stages=sm120_num_stages,

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Outside diff comments:
In `@flash_attn/cute/interface.py`:
- Line 1105: The code hardcodes num_stages=1 where SM120 kernels should use the
tuned sm120_num_stages computed from _SM120_TILE_LOOKUP; update the call/site
that currently passes num_stages=1 to instead pass the variable sm120_num_stages
(the value computed at line ~585) so Phase 5c SM120 tuning is applied, ensuring
any downstream uses (function/class where num_stages is supplied) accept and
propagate that variable rather than the literal 1.
---
Nitpick comments:
In `@flash_attn/cute/interface.py`:
- Around line 899-901: The comment is misleading about Phase 5c affecting
num_stages; update the comment near the num_stages=sm120_num_stages
assignment to say that the Phase 5c per-shape lookup (_SM120_TILE_LOOKUP) runs
only when arch // 10 == 12 (SM120), so for SM80 sm120_num_stages remains the
default 1; mention that num_stages is therefore conditionally adjusted only for
SM120 and defaults to 1 on SM80 to avoid confusion.
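The guarded lookup that both comments refer to can be sketched as follows. Only the single table entry and the `arch // 10 == 12` guard come from the review text; `select_num_stages` is a hypothetical helper, the meaning of the key tuple is not stated in the review, and reading the value as `(tile_m, tile_n, num_stages)` is an assumption.

```python
# One entry quoted in the review; the key tuple's meaning is not given there.
_SM120_TILE_LOOKUP = {(64, 1, 2048, 1): (64, 64, 2)}


def select_num_stages(arch, shape_key, default=1):
    """Hypothetical helper: the lookup only applies when arch // 10 == 12."""
    if arch // 10 == 12 and shape_key in _SM120_TILE_LOOKUP:
        # Assumed value layout: (tile_m, tile_n, num_stages).
        return _SM120_TILE_LOOKUP[shape_key][2]
    return default


print(select_num_stages(120, (64, 1, 2048, 1)))  # 2
print(select_num_stages(80, (64, 1, 2048, 1)))   # 1
```

The tuned value only has an effect if the caller actually forwards it to the kernel constructor, which is the gap the outside-diff comment points out.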
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: a5f532f7-b808-4704-ac78-59bf83310c26
📒 Files selected for processing (3)
- flash_attn/cute/interface.py
- tests/cute/test_flash_attn_bwd_sm120_postprocess.py
- tests/cute/test_paged_kv_sm120.py
💤 Files with no reviewable changes (1)
- tests/cute/test_paged_kv_sm120.py
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@flash_attn/cute/interface.py`:
- Around line 1104-1107: The forward JIT cache key (compile_key) currently omits
sm120_num_stages while the SM120 kernel selection uses it
(num_stages=sm120_num_stages), leading to cache collisions across shapes; update
the compile_key construction (the same place that builds the forward compile
cache key) to include sm120_num_stages so that compile_key distinguishes
variants by num_stages and prevents reusing a compiled kernel with the wrong
sm120_num_stages.
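The cache collision described in this comment can be reproduced with a dictionary-backed cache. Everything here is a toy: `compile_kernel`, `get_kernel`, and the two-field key are hypothetical, and the "kernel" is just a closure that reports the `num_stages` it was built with.

```python
compile_cache = {}


def compile_kernel(num_stages):
    """Stand-in for an expensive JIT compile; returns a tagged callable."""
    return lambda: num_stages


def get_kernel(dtype, head_dim, num_stages, include_stages_in_key):
    """Fetch a cached kernel, compiling on a miss."""
    if include_stages_in_key:
        key = (dtype, head_dim, num_stages)
    else:
        key = (dtype, head_dim)
    if key not in compile_cache:
        compile_cache[key] = compile_kernel(num_stages)
    return compile_cache[key]


# Key without num_stages: the second request reuses the 1-stage kernel.
stale = [get_kernel("bf16", 128, s, include_stages_in_key=False)() for s in (1, 2)]

compile_cache.clear()
# Key with num_stages: each variant gets its own compiled kernel.
fresh = [get_kernel("bf16", 128, s, include_stages_in_key=True)() for s in (1, 2)]

print(stale)  # [1, 1]
print(fresh)  # [1, 2]
```

Any parameter that changes the generated kernel belongs in the key; otherwise the first shape to compile decides what later shapes run.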
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 8d4bc294-813d-427f-b7d8-eb453f985493
📒 Files selected for processing (1)
flash_attn/cute/interface.py
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flash_attn/cute/interface.py (1)
1028-1042: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win
`can_implement` check uses hardcoded `num_stages=1` instead of `sm120_num_stages`.
The `can_implement` call at line 1033 passes `num_stages=1`, but the kernel is instantiated at line 1053 with `num_stages=sm120_num_stages`. Since `sm120_num_stages` can be 2 from the Phase 5c lookup table (e.g., `(64, 1, 2048, 1): (64, 64, 2)`), the SMEM check validates for 1 stage but the kernel runs with 2 stages, potentially overflowing the 99 KB SMEM cap.
🐛 Proposed fix

     assert FlashAttentionForwardSm120.can_implement(
         dtype, head_dim, head_dim_v, tile_m, tile_n,
    -    num_stages=1, num_threads=num_threads, is_causal=causal,
    +    num_stages=sm120_num_stages, num_threads=num_threads, is_causal=causal,
         Q_in_regs=False,
     ), (

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@flash_attn/cute/interface.py` around lines 1028 - 1042, The SM 12.0 validation wrongly hardcodes num_stages=1 when calling FlashAttentionForwardSm120.can_implement, which can mis-validate SMEM usage vs the actual kernel instantiation using sm120_num_stages; update the can_implement call to pass num_stages=sm120_num_stages (preserving the other parameters like dtype, head_dim, head_dim_v, tile_m, tile_n, num_threads, is_causal, Q_in_regs) so the SMEM/hang/divisibility checks reflect the real kernel configuration used later when creating the FlashAttentionForwardSm120 kernel.
🧹 Nitpick comments (2)
flash_attn/cute/flash_fwd.py (1)
1700-1701: 💤 Low value
Dead variables: `smem_pipe_read` and `smem_pipe_write` are unused.
These variables are initialized but never read. Since `num_stages == 1` is enforced by the assertion at line 1655, all SMEM accesses use hardcoded index `0` (e.g., `sK[None, None, 0]`). Consider removing them.
♻️ Proposed fix to remove dead variables
-        smem_pipe_read = Int32(0)
-        smem_pipe_write = Int32(0)
         nb = n_block
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@flash_attn/cute/flash_fwd.py` around lines 1700 - 1701, Remove the dead variables smem_pipe_read and smem_pipe_write which are initialized but never used; since the assertion enforcing num_stages == 1 (see assertion around num_stages) makes all SMEM accesses use the fixed index 0 (e.g., sK[None, None, 0]), delete the Int32(0) declarations for smem_pipe_read and smem_pipe_write and any related unused references so there is no unused state left in the flash_fwd.py top-level scope.
flash_attn/cute/block_sparse_utils.py (1)
788-811: 💤 Low value
Consider explicitly passing `mask_mod=None` for full blocks for consistency.
In `consume_block_sparse_loads` (lines 470, 480), when transitioning from mask blocks to full blocks, `mask_mod=None` is explicitly passed. Here, `mask_mod` is omitted for full blocks, relying on a default value in `apply_mask`. While this likely works if `apply_mask` defaults `mask_mod=None`, explicit passing would be clearer and consistent with the SM90/SM100 consumer path.
♻️ Proposed fix for explicit mask_mod=None
 if const_expr(full_block_cnt is not None):
     if curr_full_block_cnt > 0:
         n_block = curr_full_block_idx[curr_full_block_cnt - 1]
         if curr_mask_block_cnt == 0:
             mma_one_n_block(
                 n_block=n_block,
-                mask_fn=partial(mask_fn, mask_seqlen=True),
+                mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True),
                 is_first_n_block=True,
             )
         else:
             mma_one_n_block(
                 n_block=n_block,
-                mask_fn=partial(mask_fn, mask_seqlen=True),
+                mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True),
                 is_first_n_block=False,
             )
         for j in cutlass.range(1, curr_full_block_cnt):
             n_block = curr_full_block_idx[curr_full_block_cnt - 1 - j]
             mma_one_n_block(
                 n_block=n_block,
-                mask_fn=partial(mask_fn, mask_seqlen=False),
+                mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=False),
                 is_first_n_block=False,
             )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@flash_attn/cute/block_sparse_utils.py` around lines 788 - 811, The full-block handling in the block-sparse consumer omits an explicit mask_mod, relying on apply_mask's default; change the two mma_one_n_block calls in the full-block section so they pass mask_mod=None (i.e., the first call with mask_fn=partial(mask_fn, mask_seqlen=True), is_first_n_block=True/False should also include mask_mod=None, and the subsequent calls using mask_seqlen=False should likewise include mask_mod=None) to match the consume_block_sparse_loads behavior and the SM90/SM100 path and make intent explicit.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 4f343344-27a3-4116-b112-e6513be015c2
📒 Files selected for processing (9)
flash_attn/cute/block_sparse_utils.py
flash_attn/cute/flash_bwd_postprocess.py
flash_attn/cute/flash_fwd.py
flash_attn/cute/flash_fwd_sm120.py
flash_attn/cute/flash_fwd_sm120_tma.py
flash_attn/cute/interface.py
flash_attn/cute/pack_gqa.py
flash_attn/cute/seqlen_info.py
flash_attn/cute/utils.py
… + tile tuning
Makes the three cherry-picked upstream SM120 PRs (Dao-AILab#2553, Dao-AILab#2349, Dao-AILab#2389) actually usable end-to-end on consumer Blackwell (RTX 5090, RTX PRO 6000 Blackwell). The upstream PRs alone leave SM120 forward dispatcher-buggy and backward broken; this commit adds the integration glue, backward support, real paged-KV + pack_gqa implementations, a subprocess-isolated per-shape tile lookup, and the test coverage to back it up.
# Dispatcher fixes (SM120 forward + backward couldn't compile or run end-to-end
# without these)
- Initialize dQ_single_wg in the SM120 backward setup (was unbound)
- Keep softmax_scale non-None for SM80/SM120 backward dK epilogue (inline log2 computation like SM90 does)
- atomic_add_fp32: adopt the new keyword-only nvvm.atomicrmw signature in nvidia-cutlass-dsl >= 4.x
- Drop unsupported is_split_kv kwarg on the SM120 forward path
- Pass split_idx=0, num_splits=1, seqlen_info=seqlen at the SM80/SM120 get_total_block_count call site (arity mismatch fix)
- Rename vec_size -> score_vec_size on the SM120 TMA forward (typo in upstream Dao-AILab#2349; was AttributeError on softcap/learnable_sink/score_mod)
- Auto-downgrade pack_gqa=False when qhead_per_kvhead doesn't divide tile_m=128 (qwen2.5-7b's 7-way GQA otherwise fails cute.local_tile division)
- Auto-downgrade pack_gqa=False when paged-KV is used (cross-feature interaction with PagedKVManager's K/V indexing)
- Route head_dim > head_dim_v to the non-TMA SM120 path: bisection showed the hang lives in FlashAttentionForwardSm120Tma, not in the SM80-base mainloop as upstream diagnosed. Non-TMA can_implement accepts d > dv; the TMA path still rejects it so the dispatcher falls through. d > dv shapes now work (verified bitwise-identical to SDPA on the minimum repro).
- Route SM120 through the shared _validate_head_dims helper (invalid head_dim was reaching the kernel and faulting with cudaErrorMisalignedAddress)
- Clamp the cu_seqlens[batch_idx+1] read in SeqlenInfoQK.create so SM80/SM120 over-launched varlen tiles don't fault on a non-resident page
- arch-gate FlashAttentionForwardBase.epilogue smem store atom: SM80/SM120 force the universal copy, SM90 keeps WGMMA-paired stmatrix (upstream PR Dao-AILab#2553's bc67a9c unconditionally forced 80, which silently switched SM90 forward through the universal-copy path)
- Include sm120_num_stages in the forward compile cache key (different ns values with the same tile would otherwise share a key and the second call would reuse the first-compiled kernel)
- Document why deterministic backward can't be lifted on SM120 (the SM80 base kernel itself lacks the dQ_semaphore code path; a feature gap shared with SM80)
# SM120-specific kernel work
- Real paged-KV forward via PagedKVManager on the SM80-base kernel, supported through head_dim <= 128. A paged-specific tile override (128, 128, ns=1) gates on page_table is not None and head_dim <= 128 so PagedKVManager's tile_n >= num_threads invariant holds. SMEM math fits: 48 KB at d=64, 72 KB at d=96, 96 KB at d=128 (cap 99 KB).
- Real pack_gqa=True support: rewrite PackGQA.compute_ptr to compute the flat offset arithmetically from stride[0][0] and stride[0][1] rather than cute.crd2idx (which cuTeDSL 4.4-4.5 collapses through trailing slices). Call pack_gqa_layout in the SM80-base forward so packed Q is actually materialized (was missing; would have produced wrong output even after the crd2idx workaround).
- Backward postprocess dQ smem-store atom: force universal copy on SM80/SM120 (same class of bug as the upstream Dao-AILab#2553 forward fix but in the dQ postprocess; left silent rmem->smem scrambling otherwise). Permanent regression test with a white-box source-inspection guard against reintroduction.
- New D > 128 SM120 tile bracket (64, 64, ns=1) that fits the 99 KB SMEM cap for head_dim=256.
# Forward tile selection (per-shape lookup)
The SM120 forward dispatch now consults a tile + num_stages lookup keyed on (head_dim, qhead_per_kvhead, seqlen, causal). Shapes outside the lookup fall back to the head_dim-only brackets that match the pre-tuning defaults.
The lookup was built from a subprocess-isolated sweep: each (cell, candidate) pair is measured in a fresh python process so JIT-cache pollution can't bias the rankings (a single-process sweep silently reuses compiled kernels across candidates with subtly different shapes). The top-3 candidates per cell get a reproducibility re-measurement; variance > 10% excludes a candidate. A candidate ships only when its mean TFLOPS beats the baseline tile by >= 2%; otherwise the cell falls back to baseline.
# Test coverage added
- tests/cute/test_paged_kv_sm120.py (38 cases): paged-KV correctness across page_size {16, 64, 256}, identity / permuted / shared page tables, GQA + MQA, d in {64, 96, 128}; expected NotImplementedError for d in {192, 256}; expected correctness (now, not rejection) for the paged + d > dv + varlen cross-feature combination.
- tests/cute/test_flash_attn_bwd_sm120_postprocess.py (10 cases): backward dQ postprocess regression suite, combines numeric vs fp32-SDPA comparison with a white-box source-inspection guard against the buggy literal pattern.
- tests/cute/test_flash_attn_sm120_dgtdv.py (11 cases): regression test for the Bug E d > dv non-TMA routing. 8 kernel-launch parametrizations plus 3 unit probes (TMA rejection, non-TMA acceptance, SMEM constraint). All kernel tests carry pytest-timeout(30) with --timeout-method=signal so a future TMA gate widening that re-introduces the GPU hang fails as a timeout instead of wedging the GPU.
# What this is NOT
- Real paged-KV at head_dim > 128: rejected at dispatch with a clear NotImplementedError. Lifting would require either a refactor of PagedKVManager (per-thread page-table fragment > 0 at tile_n < num_threads) or a separate kernel; the 99 KB SMEM cap precludes the simple (128, 128, ns=1) approach used for d <= 128.
- Real fix for the TMA path d > dv hang: the kernel-level root cause needs cuda-gdb or instrumented bisection; the routing fix makes user-visible shapes correct today, but the TMA kernel itself is still latent-broken for d > dv. The can_implement gate ensures the TMA path is never selected for d > dv.
- Deterministic backward on SM120: asserts off because the SM80 base kernel itself lacks the dQ_semaphore code path. Lift would need a feature port from SM90 into the SM80 base; out of scope here.
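The SMEM figures quoted in this commit message (48 KB at d=64, 72 KB at d=96, 96 KB at d=128, cap 99 KB) are consistent with a simple footprint model. The sketch below is an assumption, not the kernel's own accounting: it counts one Q tile plus `num_stages` K and V tiles of 16-bit elements, takes `head_dim == head_dim_v`, and uses the hypothetical helper name `smem_bytes`.

```python
SMEM_CAP_BYTES = 99 * 1024  # the 99 KB cap quoted in the commit message


def smem_bytes(tile_m, tile_n, head_dim, num_stages, elem_bytes=2):
    """Assumed footprint: one Q tile plus num_stages K tiles and num_stages V tiles."""
    q_bytes = tile_m * head_dim * elem_bytes
    kv_bytes = 2 * num_stages * tile_n * head_dim * elem_bytes
    return q_bytes + kv_bytes


# Paged-KV tile (128, 128, ns=1) at the three supported head dims.
paged_kb = {d: smem_bytes(128, 128, d, 1) // 1024 for d in (64, 96, 128)}

# The D > 128 bracket (64, 64, ns=1) at head_dim=256.
d256_bytes = smem_bytes(64, 64, 256, 1)

# The same paged tile at head_dim=192 exceeds the cap under this model.
d192_paged_bytes = smem_bytes(128, 128, 192, 1)
```

Under this model the `(64, 64, ns=1)` bracket at head_dim=256 comes to 96 KB and fits, while the `(128, 128, ns=1)` paged tile at head_dim=192 comes to 144 KB, which matches the statement that the cap precludes the simple paged approach above d=128.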
53b1ad5 to
cee3b54
Compare
|
@coderabbitai review Force-pushed: squashed all thad0ctor commits into one (
Please do a full review on the squashed diff. |
|
✅ Actions performed
Full review triggered.
|
@coderabbitai full review |
✅ Actions performed
Full review triggered.
|
Independent verification of squashed tip ( Re-ran the full test plan from a worktree-isolated agent on a fresh branch off
GPU pinned by UUID to the RTX 5090 ( |
|
@coderabbitai full review |
✅ Actions performed
Full review triggered.
|
FA4 (this PR) vs FA2 2.8.3 benchmark — RTX 5090, sm_120, bf16, 160 paired cells + 8 FA4-only cells Head-to-head on the same RTX 5090, same torch 2.11.0+cu130. Fresh subprocess per cell, 3 warmup + 10 timed CUDA-event iterations, median latency, Dao FLOPs convention. All 168 cells returned Forward (40 paired cells):
Backward (40 paired cells):
FA4-only shapes (no FA2 reference):
These shapes either hung the GPU ( Full per-cell table now in the PR body under "Performance vs FlashAttention 2". |
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tests/cute/test_flash_attn_bwd_sm120_postprocess.py`:
- Around line 151-156: The test's assertion in
tests/cute/test_flash_attn_bwd_sm120_postprocess.py is brittle because it checks
an exact multi-line string for "get_smem_store_atom(\n
self.arch,"; update the guard to use a regex that ignores whitespace (e.g., use
re.search with a pattern like r"get_smem_store_atom\s*\(\s*self\.arch\s*,") so
it fails if self.arch is still passed as the first positional argument but
survives formatting changes; locate the assertion referencing
flash_bwd_postprocess.FlashAttentionBackwardPostprocess and replace the string
containment check with a whitespace-agnostic regex match that ensures self.arch
is not the first positional parameter (consider also allowing store_atom_arch to
appear).
In `@tests/cute/test_paged_kv_sm120.py`:
- Around line 223-230: The tests use Python's built-in hash(page_table_pattern)
for seeding (in test_page_table_patterns and the other paged-case test that
calls _run_paged_case), which is randomized per process; replace that with a
deterministic hash function or mapping (e.g., compute a stable integer from
page_table_pattern using hashlib.sha256 or zlib.crc32 and then mask to 16 bits)
and pass that stable seed into _run_paged_case so runs are reproducible; update
both occurrences where seed=hash(page_table_pattern) & 0xFFFF to use the
deterministic conversion instead.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 40870c62-2a0d-4695-a0c5-d30a2d82106c
📒 Files selected for processing (13)
flash_attn/cute/block_sparse_utils.py
flash_attn/cute/flash_bwd.py
flash_attn/cute/flash_bwd_postprocess.py
flash_attn/cute/flash_fwd.py
flash_attn/cute/flash_fwd_sm120.py
flash_attn/cute/flash_fwd_sm120_tma.py
flash_attn/cute/interface.py
flash_attn/cute/pack_gqa.py
flash_attn/cute/seqlen_info.py
flash_attn/cute/utils.py
tests/cute/test_flash_attn_bwd_sm120_postprocess.py
tests/cute/test_flash_attn_sm120_dgtdv.py
tests/cute/test_paged_kv_sm120.py
CodeRabbit (PR #1) flagged the exact-string membership check on `get_smem_store_atom(\n self.arch,` as brittle: any reformat of the source line would silently disarm the regression guard even if `self.arch` were still passed as the first positional arg. Switch the assertion to `re.search(r"get_smem_store_atom\(\s*self\.arch\s*,", src)` so the guard fails on the buggy pattern regardless of whitespace or line-wrap formatting. Closes-Upstream-Comment: #1
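The behaviour of the whitespace-agnostic guard can be checked in isolation. The pattern below is the one quoted in the commit message; the three source snippets and the helper `has_buggy_call` are hypothetical stand-ins for the inspected source, not the actual contents of `flash_bwd_postprocess.py`.

```python
import re

# Pattern quoted in the commit message above.
GUARD = re.compile(r"get_smem_store_atom\(\s*self\.arch\s*,")

# Hypothetical source snippets for illustration only.
buggy_wrapped = "atom = get_smem_store_atom(\n            self.arch,\n            dtype,\n        )"
buggy_inline = "atom = get_smem_store_atom(self.arch, dtype)"
fixed = "atom = get_smem_store_atom(store_atom_arch, dtype)"


def has_buggy_call(src: str) -> bool:
    """True when self.arch is passed as the first positional argument."""
    return GUARD.search(src) is not None


results = [has_buggy_call(s) for s in (buggy_wrapped, buggy_inline, fixed)]
```

Because `\s*` matches newlines as well as spaces, the guard fires on both the wrapped and the single-line form of the call, which is the property the exact-string check lacked.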
CodeRabbit (PR #1) flagged that `seed=hash(page_table_pattern) & 0xFFFF` is non-reproducible across runs: Python's builtin `hash(str)` is randomized per process via PYTHONHASHSEED, so a tolerance-band failure on one run cannot be repro'd from the recorded seed. Replace both occurrences (test_page_table_patterns and test_d_gt64_page_table_patterns) with a fixed PATTERN_SEEDS mapping ("identity"->101, "permuted"->202, "shared"->303). The second test keeps its `d * 1000 + seed` derivation so d=96 and d=128 sweeps remain seed-disjoint. All 9 affected parametrized cases pass on RTX 5090 (sm_120). Closes-Upstream-Comment: #1
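A small sketch of the two deterministic options discussed here. The `PATTERN_SEEDS` values and the `d * 1000 + seed` derivation are taken from the commit message; the function names `seed_for` and `crc_seed` are hypothetical, and the `zlib.crc32` variant is the alternative the review suggested rather than what the commit adopted.

```python
import zlib
from typing import Optional

# Fixed mapping described in the commit message.
PATTERN_SEEDS = {"identity": 101, "permuted": 202, "shared": 303}


def seed_for(pattern: str, head_dim: Optional[int] = None) -> int:
    """Stable seed; with head_dim, apply the d * 1000 + seed derivation."""
    base = PATTERN_SEEDS[pattern]
    return base if head_dim is None else head_dim * 1000 + base


def crc_seed(pattern: str) -> int:
    """Alternative from the review: a stable 16-bit seed that does not depend on PYTHONHASHSEED."""
    return zlib.crc32(pattern.encode("utf-8")) & 0xFFFF


seeds_d96 = {seed_for(p, 96) for p in PATTERN_SEEDS}
seeds_d128 = {seed_for(p, 128) for p in PATTERN_SEEDS}
```

Either approach yields the same seed in every process, so a failure recorded with a given seed can be rerun exactly; the per-`head_dim` offsets keep the d=96 and d=128 sweeps disjoint.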
Add a subprocess-isolated benchmark harness for tuning the FA4 CuTe
backward kernel tile sizes on consumer Blackwell (sm_120) GPUs.
benchmarks/sm120_bwd_tuning/
measure_one_bwd.py - one cell + one tile config, prints JSON
bench_master_bwd.py - orchestrator: sweep + repro + analyze + validate
README.md - usage, env vars, output schema
For each (preset, seqlen, causal) cell, spawns a fresh Python process per
(tile_m, tile_n, num_stages) candidate so the CuTe JIT cache and CUDA
context start cold for every measurement.
The measurement script monkey-patches the SM120 branch of
flash_attn.cute.interface._flash_attn_bwd so the hard-coded tile
constants can be overridden per run without touching kernel sources.
No source changes to flash_attn/. Pure tooling; the harness is opt-in
via env vars (FA_BENCH_GPU_UUID, FA_BENCH_OUT_DIR, FA_BENCH_PYTHON,
FA_BENCH_CUTE_OVERRIDE) and falls back to inheriting CUDA_VISIBLE_DEVICES
and sys.executable when unset.
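The fresh-process-per-candidate idea can be sketched with the standard library alone. This is an illustrative outline, not `measure_one_bwd.py`: the child script, the `measure_one` helper, and the placeholder workload are assumptions, and a real harness would launch the backward kernel where the placeholder loop sits.

```python
import json
import subprocess
import sys

# Child script run in a fresh interpreter; the workload is a placeholder.
CHILD = r"""
import json, sys, time
tile_m, tile_n, num_stages = map(int, sys.argv[1:4])
t0 = time.perf_counter()
sum(i * i for i in range(10000))  # placeholder for the kernel launch
elapsed_ms = (time.perf_counter() - t0) * 1e3
print(json.dumps({"tile": [tile_m, tile_n, num_stages], "ms": elapsed_ms}))
"""


def measure_one(tile_m, tile_n, num_stages, timeout_s=60.0):
    """Run one candidate in its own process and parse the JSON line it prints."""
    proc = subprocess.run(
        [sys.executable, "-c", CHILD, str(tile_m), str(tile_n), str(num_stages)],
        capture_output=True,
        text=True,
        timeout=timeout_s,
        check=True,
    )
    return json.loads(proc.stdout.strip().splitlines()[-1])


result = measure_one(64, 64, 1)
```

Since every call starts a new interpreter, no in-process cache or CUDA context from a previous candidate can carry over, at the cost of paying startup and compile time for each measurement.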
The 0.05 threshold rejected legitimate bf16 noise on backward gradients at longer sequences (causal=1, sl=2048 cells with dV diff=0.0625 -- ~8 bf16 ULPs at magnitude 1.0), causing the sweep to mis-flag the working baseline as a numerical_fail. 0.1 absolute is roughly 12 bf16 ULPs and still catches genuine corruption: when an MMA-unsupported tile (e.g. tile_n=32 on the SM80-base backward) compiles but produces garbage, the dK/dV diff is in the 0.4-0.75 range -- well past the 0.1 cutoff.
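The ULP figures above follow from bfloat16 having 7 explicit mantissa bits, so one ULP at magnitude 1.0 is 2^-7. The helper name `bf16_ulp` is hypothetical; the thresholds and the 0.0625 diff are the values quoted in the message.

```python
import math

BF16_MANTISSA_BITS = 7  # explicit mantissa bits in bfloat16


def bf16_ulp(magnitude: float) -> float:
    """Spacing between adjacent bf16 values in the binade containing `magnitude`."""
    _, exp = math.frexp(abs(magnitude))  # magnitude = m * 2**exp with 0.5 <= m < 1
    return 2.0 ** (exp - 1 - BF16_MANTISSA_BITS)


ulp_at_one = bf16_ulp(1.0)
observed_diff_ulps = 0.0625 / ulp_at_one   # dV diff reported for the sl=2048 causal cells
old_threshold_ulps = 0.05 / ulp_at_one
new_threshold_ulps = 0.1 / ulp_at_one
```

The observed 0.0625 diff is exactly 8 ULPs, which sits above the old 0.05 cutoff (6.4 ULPs) and below the new 0.1 cutoff (12.8 ULPs), while the 0.4-0.75 diffs from a corrupted tile remain far outside either.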
|
Update — four new commits on
Phase 14 backward tile sweep — negative result, by design: A subprocess-isolated sweep across the same 5 model presets × 4 seqlens × {causal, non-causal} = 40 cells used by the Phase 13 FA4-vs-FA2 bench. Acceptance gate: 0 cells regress >2% AND geomean ≥ 1.02×.
Root cause (documented in So no dQ atomic audit (W2): also negative. The existing docstrings already accurately frame the SM80-only Verified: PR body's "Performance vs FlashAttention 2" section updated with this negative finding. |
Phase 16c sweep + Phase 17C tightened paired validation (RTX 5090,
n_measure=30, interleaved trials) confirm ns=1 wins on d=64 backward.
ns=2 was inherited from the SM80-base default but the async pipeline
overhead exceeds the latency-hiding benefit at the small d=64 tile
size on consumer Blackwell.
Paired validation across 19 d=64 cells (5 model presets x 4 seqlens x
{causal, non-causal} from the qhead_per_kvhead = {1,4,7} matrix):
geomean ratio ns=2/ns=1 = 1.0558x
0 cells regress >2%
Code change is minimal: the SM120 backward branch now uses ns=1 for
all head_dim (the d>64 branch already used ns=1, so this is purely a
flip of the d<=64 case from 2 to 1). No effect on forward, non-sm_120
arches, or any other parameter.
Validation:
- E2E phase11_e2e/e2e.py: 34/34 pass
- tests/cute/test_paged_kv_sm120.py: 38/38 pass
- tests/cute/test_flash_attn_sm120_dgtdv.py: 11/11 pass
- tests/cute/test_flash_attn_bwd_sm120_postprocess.py: 10/10 pass
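The acceptance gate quoted for these sweeps (0 cells regress >2% and geomean ≥ 1.02×) can be written down in a few lines. This is a sketch under assumptions: the ratios are taken as candidate speedup over baseline per cell (above 1 means the candidate is faster), the function names are hypothetical, and the sample ratios are made-up inputs rather than measured data.

```python
import math


def geomean(ratios):
    """Geometric mean of positive per-cell ratios."""
    return math.exp(sum(math.log(r) for r in ratios) / len(ratios))


def accept(ratios, max_regress=0.02, min_geomean=1.02):
    """Accept only if no cell regresses by more than max_regress and the geomean clears the bar."""
    regressions = [r for r in ratios if r < 1.0 - max_regress]
    return not regressions and geomean(ratios) >= min_geomean


# Illustrative inputs only.
clean_win = accept([1.05, 1.06, 1.04])
one_regression = accept([1.10, 0.97, 1.10])
too_small = accept([1.01, 1.00, 1.01])
```

Requiring both conditions means a large win on a few cells cannot hide a regression on another, and a uniformly tiny gain is treated as noise.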
Phase 16a NCU profile attributed the 0.93x FA4/FA2 backward geomean to parallelism: FA4 SM120 backward ran 4 warps/SM, FA2 runs 8 warps/SM, both clamped to 1 block by ~82 KB SMEM. The 9pp compute-throughput gap (80.12% vs 88.57%) was the dominant explanation.
This commit repartitions the SM120 backward kernel to 256-thread / 8-warp blocks at the SAME SMEM footprint.
The bug surfaced by the prior 17A-config attempt (causal-path dQ/dK/dV errors of 5-20 vs the 0.004 noise floor) was traced to the R2P bitmask fast-path in flash_attn/cute/mask.py:r2p_bitmask_below + sm90_col_to_r2p_idx assuming the standard SM80/SM90 per-thread column pattern (col-pairs at stride 8). With AtomLayoutSdP = (4, 2, 1) the SM120 256-thread configuration has 2 N-warps and the per-thread cols interleave at stride 16 instead, so the R2P bitmask kept cells beyond the causal boundary.
Fixed by adding an optional r2p_compatible field to AttentionMask (default True; preserves all SM80/SM90/SM100 behaviour) and gating the 4-warps-per-tile detection to sm_120 only via a new `arch = 120` marker on FlashAttentionBackwardSm120. The SM80 path is unaffected (no `arch` attribute -> getattr falls back to 80).
Also fixes a postprocess invariant: the dq_accum / dk_accum / dv_accum byte buffers are written by the main kernel via a thread-major partition whose stride is num_threads; the postprocess reader must use the same num_threads or per-thread element->address mapping diverges. SM120 branch now mirrors num_threads_post_dQ/dKV = 256. SM80, SM90, SM100 paths untouched.
NCU after on mistral-7b sl=4096 c=1 bwd: theoretical occupancy 8.33% -> 16.67% (matches FA2), compute throughput 80.12% -> 87.55% (1pp short of FA2's 88.57%). Phase 13 40-cell backward paired bench: geomean ~1.05x; the 6 flagged regressions are all in the documented small-seqlen / GQA bench-noise floor (10% CV) and flicker across re-runs.
Validation: E2E 34/34, sm120 pytest 59/59, forward unchanged within bench noise.
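The occupancy figures in this message follow from the warp counts. The sketch below assumes a limit of 48 resident warps per SM, which is inferred from the reported 8.33% and 16.67% values rather than stated in the commit, and uses the hypothetical helper name `theoretical_occupancy`.

```python
MAX_WARPS_PER_SM = 48  # assumption inferred from the reported percentages
WARP_SIZE = 32


def theoretical_occupancy(threads_per_block, blocks_per_sm, max_warps=MAX_WARPS_PER_SM):
    """Resident warps divided by the per-SM warp limit."""
    warps = (threads_per_block // WARP_SIZE) * blocks_per_sm
    return warps / max_warps


# One block per SM in both cases, since SMEM limits residency to a single block.
before_pct = round(100 * theoretical_occupancy(128, 1), 2)  # 4 warps
after_pct = round(100 * theoretical_occupancy(256, 1), 2)   # 8 warps
```

With residency pinned at one block by the SMEM footprint, doubling the block from 128 to 256 threads is the only way to double resident warps, which is why the repartition keeps the footprint unchanged.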
…layout
Phase 17D-lite-v3: switch gmem_tiled_copy_dQaccum to a 128-bit copy atom
with val_layout=4 on SM120 only, so each thread owns 4 contiguous fp32 in
gdQaccum. With the post-17A-config 256-thread / 8-warp partition this is
the prerequisite for emitting red.global.add.v4.f32 in the dQ accumulator
write loop, cutting the atomic instruction count by 4x. The dK/dV GQA
atomic-add path (qhead_per_kvhead > 1) reuses the same V=4 copy and the
same v4 atomic helper.
The MMA m16n8k16 C-fragment for thread t holds 4 fp32 (c0,c1,c2,c3) in
2-contig col pairs at row offsets {r, r+8}. retile(acc_dQ) flattens
((2,2),1,N):(...) -> ((4,1),1,N):(...) without reordering registers — the
register-storage flat order is identity in both layouts. So writing the
4 per-atom registers to 4 contiguous gmem positions yields a consistent
flat encoding as long as the postprocess reads back in the same order;
flash_bwd_postprocess.py is updated to use num_s2r_copy_elems=4 for
SM120 to satisfy that invariant.
Architectural gating:
- SM80 path: untouched. The dQaccum gmem copy and atomic loop both
remain V=1 / scalar atomic.
- SM90/SM100: unaffected (separate files, not subclasses of
FlashAttentionBackwardSm80).
- SM120: V=4 / v4 atomic, gated via getattr(self, "arch", 80) == 120.
Adds utils.atomic_add_fp32_v4 as a v4 inline-asm wrapper (mirrors the
existing copy_utils.atomic_add_fp32x4 but lives next to atomic_add_fp32
so the backward kernel can import both from the same module).
Validation (RTX 5090 / sm_120, bf16):
- phase11_e2e/e2e.py: 34/34 pass
- tests/cute/test_flash_attn_bwd_sm120_postprocess.py: 8/8 pass
- tests/cute/test_flash_attn_sm120_dgtdv.py: 11/11 pass
- tests/cute/test_paged_kv_sm120.py: 40/40 pass (2 skipped)
- Smoke (mistral-7b/qwen2.5-7b sl=1024 c=1, llama2-7b sl=1024 c=0):
max diff vs SDPA <= 0.05 on all 3 cells.
|
Phase 17 backward optimization landed. Please re-review the 3 new kernel commits + 2 test-file commits since the squash base
Summary of changes since the last review:
Result on 40-cell Phase 13 backward bench (RTX 5090, bf16): FA4 / FA2 = 1.017× geomean (was 0.93×). Cells where FA4 ≥ FA2: 29 / 40 (was 10/40). Peak FA4 TFLOPS: 180.8 (was 167.0). MHA cells now beat FA2 by 6.4%; GQA cells at par.
All changes arch-gated to sm_120 only. E2E
@coderabbitai review
|
Tip
For best results, initiate chat on the files or code changes.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Get the list of commits since cee3b54 on the PR branch
gh pr view 1 --repo thad0ctor/flash-attention --json headRefName,headRefOid,baseRefName
Length of output: 312
🏁 Script executed:
#!/bin/bash
# Get the diff for commits since cee3b54
gh api repos/thad0ctor/flash-attention/compare/cee3b54...sm120-integrate \
  --jq '.commits[] | {sha: .sha, message: .commit.message}' 2>/dev/null | head -100
Length of output: 8884
🏁 Script executed:
#!/bin/bash
# Get the files changed in the new commits
gh api repos/thad0ctor/flash-attention/compare/cee3b54...sm120-integrate \
  --jq '.files[] | {filename: .filename, additions: .additions, deletions: .deletions, patch: .patch}' 2>/dev/null
Length of output: 62616
Test fixes (e06bc15, 4670859) ✅
Both prior-review items are cleanly addressed:
Bench harness (a0dd865, 47aa883) — MAJOR: monkey-patch needle broken
needle = (
" if arch // 10 == 12:\n"
" # SM120: uses SM80 MMA with 99 KB SMEM, 128 threads (4 warps).\n"
" m_block_size = 64\n"
" n_block_size = 64\n"
" if head_dim <= 64:\n" # <-- removed by 4d59090
" num_stages_Q = 2\n" # <-- removed by 4d59090
" num_stages_dO = 2\n" # <-- removed by 4d59090
" else:\n" # <-- removed by 4d59090
" num_stages_Q = 1\n"
" num_stages_dO = 1\n"
)
But commit
The fix is straightforward; update the needle to match the post-4d59090 source:
needle = (
" if arch // 10 == 12:\n"
" # SM120: uses SM80 MMA with 99 KB SMEM, 128 threads (4 warps).\n"
" m_block_size = 64\n"
" n_block_size = 64\n"
" # num_stages=1 across all head_dim on consumer Blackwell.\n"
...
" num_stages_Q = 1\n"
" num_stages_dO = 1\n"
)
(This is bench-only tooling, so it doesn't affect correctness of the kernel, but the harness is currently non-functional against this codebase.)
Phase 17C:
|
| Area | Status | Notes |
|---|---|---|
| e06bc15 test regex fix | ✅ | Correctly addresses prior CR MAJOR |
| 4670859 PATTERN_SEEDS | ✅ | Correctly addresses prior CR MINOR |
| a0dd865/47aa883 bench harness | | needle mismatch vs 4d59090 → RuntimeError on import |
| 4d59090 ns=1 flip | ✅ | Well-validated, clean |
| 362a65a 8-warp + R2P fix | ✅ | R2P logic correct, mask.py change minimal |
| 55ab672 v4 atomic | ✅ | PTX correct, alignment guaranteed structurally |
The only blocking item before merging is the bench harness needle — since it's out-of-runtime-path it won't affect users, but if you plan to use or share the sweep tooling it needs the needle updated to match the current interface.py. All kernel-path changes look correct.
🐇 (solid set of changes overall)=^._.^=
Integrates three open upstream PRs adding FA4 forward support for consumer Blackwell (sm_120) and adds the dispatcher fixes required to make sm_120 forward AND backward actually run end-to-end. Also tunes the SM120 forward tile selection per shape from a subprocess-isolated sweep.
What's here
Upstream PRs cherry-picked
`pack_gqa=False` on consumer Blackwell)
`get_total_block_count` utility
What this PR adds on top of Dao-AILab#2553 + Dao-AILab#2349 + Dao-AILab#2389
The three upstream PRs are forward-only correctness/feature fixes. Even with all three merged, SM120 forward still doesn't work end-to-end (latent dispatcher and kernel bugs surface immediately) and backward hangs or asserts. This PR adds the integration glue + backward support + tile tuning + real paged-KV/pack_gqa that you don't get from cherry-picking the upstream PRs alone:
`apply_score_mod` arg-order bug, `vec_size` typo, `is_split_kv` kwarg, etc.)
`dQ_single_wg` unbound, `softmax_scale` None→Float, `atomic_add_fp32` signature mismatch)
`pack_gqa=True` on SM120
`crd2idx` composite-mode collapse at `pack_gqa.py:139`)
`pack_gqa=False` via flat-offset arithmetic in `PackGQA.compute_ptr`
`cute.local_tile` ("failed to perform a valid division of `((7,?),?)` by `[128:1;128:1]`")
`mPageTable` kwarg but ignores it)
`head_dim=128`; 38-case regression suite
`head_dim` bracket (conservative: 128×128 for d≤64, 128×64 for d≤128)
`(head_dim, qhead_per_kvhead, seqlen, causal)` tuned via subprocess-isolated sweep (60 cells × 5 candidates, top-3 reproducibility re-pass, strict ≥2% paired-validation acceptance, +num_stages override). Geomean 1.06× over the head_dim-only default, peak win +31% (hd64-small-gqa sl=2048 causal)
`head_dim > head_dim_v` (e.g. MLA absorbed)
`(64, 64, ns=1)` tile bracket
`cudaErrorIllegalAddress` (out-of-bounds `cu_seqlens` read on over-launched grid tiles)
`test_block_sparsity.py`
`AttributeError` on `vec_size` (typo in Dao-AILab#2349's TMA forward)
`cudaErrorMisalignedAddress` from device `AssertionError`
`is_varlen` that includes `seqused_q/k`
`apply_score_mod` arg order
Cherry-picking Dao-AILab#2553/Dao-AILab#2349/Dao-AILab#2389 alone gives you SM120 forward kernels that compile but fail at first dispatch. This PR is what makes them actually usable end-to-end for training and inference.
Performance vs PyTorch SDPA (RTX 5090, sm_120, bf16, batch=1, forward, this PR's tuned tip)
Paired validation: tuned FA4 averaged over multiple bench runs vs the best of (cuDNN, Flash-2) SDPA backends, measured via a standard per-cell harness (5 warmup + 30 iters, `torch.cuda.Event` median timing).
Aggregate across 60 cells (5 models × 6 seqlens × 2 causal):
Notable: the GQA causal/mid-seqlen band where 7B-class training and decode spend most of their time consistently shows FA4 +10-30% over best-SDPA. The cells where FA4 loses concentrate at sl=16384 head_dim=128 (cuDNN's long-seqlen path is well-tuned there). PyTorch 2.11's SDPA also has a documented kernel-selector pothole at sl=2048-4096 head_dim=128 causal+GQA where it drops to 12-50 TFLOPS; FA4 doesn't.
Dispatcher fixes for SM120
- Initialize `dQ_single_wg` in the SM120 backward setup (was missing → UnboundLocalError)
- Keep `softmax_scale` non-None for SM80/SM120 backward dK epilogue (inline log2 computation like SM90 does)
- `atomic_add_fp32`: adopt the new keyword-only `atomicrmw` signature in `nvidia-cutlass-dsl >= 4.x`
- Drop unsupported `is_split_kv` kwarg on the SM120 forward path
- Pass `split_idx=0, num_splits=1, seqlen_info=seqlen` at the SM80/SM120 `get_total_block_count` call site (arity mismatch fix)
- Rename `vec_size` → `score_vec_size` on the SM120 TMA forward (was a typo, unblocks softcap/learnable_sink/score_mod)
- Auto-downgrade `pack_gqa=False` on SM120 when `qhead_per_kvhead` doesn't divide `tile_m=128` (qwen2.5-7b's 7-way GQA fail), and when paged-KV is enabled (path interaction)
- Route `head_dim > head_dim_v` to the SM120 non-TMA path: bisection showed the hang lives in `FlashAttentionForwardSm120Tma`, not in the SM80-base mainloop as previously diagnosed. The non-TMA `can_implement` now accepts d > dv; the TMA `can_implement` still rejects it, so the dispatcher routes those shapes to the SM80-base kernel which handles them correctly (verified bitwise-identical to SDPA on the minimum repro and within bf16 noise across larger shapes)
- Route SM120 through the shared `_validate_head_dims` helper (invalid head_dim was reaching the kernel and faulting with `cudaErrorMisalignedAddress`)
- Clamp the `cu_seqlens[batch_idx+1]` read in `SeqlenInfoQK.create` so SM80/SM120 over-launched varlen tiles don't fault on a non-resident page
- Document why deterministic backward can't be lifted on SM120 (the SM80 base kernel itself lacks the `dQ_semaphore` code path; it's a feature gap shared with SM80, not an SM120 SMEM limit)
SM120-specific kernel work
- Real paged-KV forward, supported through `head_dim = 128` using a paged-specific `(128, 128, ns=1)` tile (validated for d in {64, 96, 128} against SDPA on reconstructed K/V; max abs diff ~0.004 bf16). `head_dim > 128` rejected at dispatch.
- Real `pack_gqa=True` support: rewrite `PackGQA.compute_ptr` to compute the flat offset arithmetically (avoid the cuTeDSL `crd2idx` collapse on composite mode 0); call `pack_gqa_layout` in the SM80-base forward
- New `D > 128` tile bracket `(64, 64)` for SM120 that fits the 99 KB SMEM cap (was overflowing); wire `FlashAttentionForwardSm120.can_implement` as a dispatch safety net
Per-shape SM120 forward tile +
num_stageslookupThe SM120 forward dispatch consults a tile +
num_stageslookup keyed on(head_dim, qhead_per_kvhead, seqlen, causal). Shapes outside the lookup fall back to the head_dim-only brackets that match the pre-tuning defaults.The lookup was built from a subprocess-isolated sweep: each
(cell, candidate)pair runs in a freshpythonprocess so the JIT compile cache cannot silently bias rankings across candidates. The top-3 candidates per cell get a reproducibility re-measurement; variance > 10% excludes the candidate. A candidate replaces baseline only when its mean TFLOPS beats the baseline tile by ≥ 2% on paired bench runs; otherwise the cell falls back to baseline.Verified
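As a side note on the per-shape lookup described in the previous section, the dispatch-with-fallback and the acceptance rule can be sketched roughly as follows. This is a minimal illustration, not the dispatcher's code: `TUNED`, `baseline_config`, `select_config` and `accept_candidate` are hypothetical names, the table entry and most bracket values are made up, and "variance > 10%" is interpreted here as the relative spread of repeated measurements, which is an assumption.

```python
from statistics import mean
from typing import Dict, List, Tuple

# Hypothetical types: the key mirrors the one named in the description.
Key = Tuple[int, int, int, bool]  # (head_dim, qhead_per_kvhead, seqlen, causal)
Config = Tuple[int, int, int]     # (tile_m, tile_n, num_stages)

# Made-up tuned entry, for illustration only.
TUNED: Dict[Key, Config] = {
    (128, 4, 4096, True): (128, 64, 2),
}


def baseline_config(head_dim: int) -> Config:
    """head_dim-only bracket. Only the (64, 64) tile for head_dim > 128 comes
    from the description; the other values are placeholders."""
    if head_dim <= 64:
        return (128, 128, 1)
    if head_dim <= 128:
        return (128, 64, 1)
    return (64, 64, 1)


def select_config(head_dim: int, qhead_per_kvhead: int, seqlen: int, causal: bool) -> Config:
    """Use the tuned entry when the exact key exists, else the baseline bracket."""
    key = (head_dim, qhead_per_kvhead, seqlen, causal)
    return TUNED.get(key, baseline_config(head_dim))


def accept_candidate(candidate_tflops: List[float], baseline_tflops: List[float],
                     max_spread: float = 0.10, min_gain: float = 0.02) -> bool:
    """Reject noisy candidates, then require a >= 2% mean gain over baseline."""
    cand_mean = mean(candidate_tflops)
    spread = (max(candidate_tflops) - min(candidate_tflops)) / cand_mean
    if spread > max_spread:
        return False
    return cand_mean >= mean(baseline_tflops) * (1.0 + min_gain)
```

Keeping the fallback inside `select_config` means a shape that was never swept behaves exactly as it did before tuning, which matches the "pre-tuning defaults" wording.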
E2E suite on RTX 5090, sm_120, bf16 — 34 / 34 pass through the public `flash_attn.cute.flash_attn_func` / `flash_attn_varlen_func` API, max abs diff vs PyTorch SDPA math backend (with KV repeat-interleave for GQA reference):

- `head_dim > head_dim_v` (previously hung the GPU): `(B, S, H, d, dv)` ∈ {(1,64,1,128,64), (2,2048,4,128,64), (1,1024,8,128,64)}, causal + non-causal
- `pack_gqa=True` vs `pack_gqa=False` on 4-way GQA (llama-style)
- `pack_gqa` auto-downgrade
- `head_dim=128` via `flash_attn_varlen_func` (page_size=64, permuted page table)
- `cudaErrorIllegalAddress`)
- `head_dim=256` forward (new tile bracket)

Additional pytest coverage in `tests/cute/`:

- `test_paged_kv_sm120.py`: 38 cases pass (page_size in {16, 64, 256}, identity / permuted / shared page tables, GQA + MQA, `head_dim ∈ {64, 96, 128}`, plus expected-rejection for `head_dim ∈ {192, 256}` and a cross-feature paged + `d > dv` + varlen correctness case)
- `test_flash_attn_sm120_dgtdv.py`: 11 cases pass (`head_dim > head_dim_v` routing, with `pytest-timeout(30)` so a future TMA-gate widening that re-introduces the kernel hang fails as a timeout rather than wedging the GPU)
- `test_flash_attn_bwd_sm120_postprocess.py`: 10 cases pass (backward dQ postprocess; combines a numeric vs-fp32-SDPA comparison with a white-box source-inspection guard against re-introduction of the buggy literal pattern)
- `test_block_sparsity.py`: 624 of the ~4900 collected cases run as a representative slice on sm_120 (all 624 pass)

Tile lookup paired validation (RTX 5090, multiple baseline runs + multiple tuned runs averaged, same bench harness):
- `hd64-small-gqa sl=2048` causal +31%, `sl=4096` causal +21%; mid-seqlen causal d=128 GQA shapes (llama3 / mistral / qwen) consistently +13–15%

Performance vs FlashAttention 2 (RTX 5090, sm_120, bf16)
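The TFLOPS numbers in this section follow the Dao FLOPs convention (`4*B*H_q*S*S*D` non-causal, halved for causal, 2.5× for backward). A minimal sketch of that conversion is below; `attention_flops` and `tflops` are hypothetical helper names and are not claimed to match the bench harness.

```python
def attention_flops(batch: int, heads_q: int, seqlen: int, head_dim: int,
                    causal: bool = False, backward: bool = False) -> float:
    """FLOP count under the convention quoted above."""
    flops = 4.0 * batch * heads_q * seqlen * seqlen * head_dim
    if causal:
        flops /= 2.0   # causal masking skips roughly half of the score matrix
    if backward:
        flops *= 2.5   # backward is counted as 2.5x the forward work
    return flops


def tflops(flops: float, latency_ms: float) -> float:
    """Convert a FLOP count and a median latency in milliseconds to TFLOPS."""
    return flops / (latency_ms * 1e9)
```

Because both kernels are scored with the same FLOP count per cell, the reported speedup ratios reduce to latency ratios; the convention only affects the absolute TFLOPS figures.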
Head-to-head bench against `flash_attn 2.8.3` (Dao-AILab official wheel) on the same RTX 5090, same torch 2.11.0 + cu130 build. Each cell runs in a fresh subprocess (to defeat JIT cache pollution), with 3 untimed warmups + 10 timed CUDA-event iterations, median latency reported, and the Dao FLOPs convention (`4*B*H_q*S*S*D` non-causal, halved for causal, 2.5× for backward). Matrix: 5 model presets (llama3-8b, mistral-7b, qwen2.5-7b, llama2-7b, mixtral-8x7b) × `sl ∈ {1024, 2048, 4096, 8192}` × `causal ∈ {False, True}` × `direction ∈ {fwd, bwd}` = 80 paired cells (160 runs counting FA4 and FA2 separately), all returning `status: ok`.

Forward (40 paired cells):
Largest forward wins: `qwen2.5-7b sl=1024 causal` 1.13×, `mistral-7b sl=2048` 1.10×, `mixtral-8x7b sl=2048` 1.08×. Largest forward regression: `llama2-7b sl=8192 causal` 0.91× (MHA tail).

Backward (40 paired cells, timing backward pass only):
Backward is now at or above FA2 parity on geomean across the 40-cell matrix. MHA (`llama2-7b`) cells run 1.064× geomean (FA4 beats FA2 by 6.4%), GQA cells (qwen2.5-7b 7-way + llama3-8b/mistral-7b/mixtral-8x7b 4-way) at 1.005× geomean. Top wins: `llama2-7b sl=1024 causal` 1.159×, `llama3-8b sl=1024 causal` 1.143×, `mistral-7b sl=8192 causal` 1.129×. Largest remaining regression: `qwen2.5-7b sl=1024 non-causal` 0.803×.

Phase 17 backward optimization — three shipped levers (separate commits on top of the squash base, not coalesced):
- `cute: flip SM120 backward d<=64 default num_stages 2->1` — Phase 16c discovery + tightened paired-validation bench (n_measure=30, interleaved trials) confirmed `ns=1` wins on 31/40 d=64 backward cells. `ns=2` was inherited from the SM80-base default but the async pipeline overhead exceeds the latency-hiding benefit at small tile size on consumer Blackwell. Arch-gated to `arch==120` + `head_dim<=64`. +5.6% geomean on the d=64 sub-sweep. (Note: the Phase 13 head-to-head matrix is all d=128 so this commit doesn't appear in the numbers above.)
- `cute: SM120 backward repartition 4 warps -> 8 warps per block` — NCU profile attributed the original 0.93× gap to parallelism: FA4 ran 4 warps/SM vs FA2's 8 warps/SM, both clamped to 1 block by ~82 KB SMEM. This commit repartitions the SM120 backward kernel to 256-thread / 8-warp blocks at the SAME SMEM footprint. The non-trivial bug surfaced by a first attempt (`num_threads=128→256` works on non-causal but breaks all causal shapes with dQ/dK/dV errors of 5-20 vs the 0.004 baseline) was traced to the R2P bitmask fast-path in `flash_attn/cute/mask.py:r2p_bitmask_below` + `sm90_col_to_r2p_idx` assuming the standard SM80/SM90 per-thread column pattern (col-pairs at stride 8). With `AtomLayoutSdP = (4,2,1)` the SM120 256-thread configuration has 2 N-warps and the per-thread cols interleave at stride 16, so the R2P bitmask kept cells beyond the causal boundary. Fix: new `r2p_compatible=False` marker on `FlashAttentionBackwardSm120`. Arch-gated to sm_120 only. NCU after on `mistral-7b sl=4096 c=1 bwd`: theoretical occupancy 8.33% → 16.67% (now matches FA2), compute throughput 80.12% → 87.55% (within 1pp of FA2's 88.57%).
- `cute: SM120 backward v4 atomic dQ/dK/dV via 4-contig per-thread gmem layout` — Phase 16a NCU profile showed FA4 emitted 192 `REDG.E.ADD.F32` per CTA vs FA2's 32 (the SM120 SASS uses per-element `red.global.add.f32`). After the 256-thread repartition, restructuring `gmem_tiled_copy_dQaccum` from `val_layout=(1)` to `val_layout=(4)` with a 128-bit copy atom gives each thread 4 contiguous fp32 in `gdQaccum`. The MMA m16n8k16 C-fragment is flat-compact in register storage, so `retile()` to `val_layout=(4)` doesn't reorder data — `red.global.add.v4.f32` lands at the correct 4 gmem positions. Wired for the dQ atomic loop AND the GQA dK/dV epilogue. Arch-gated to sm_120. Bench shows +1.76× over the prior tip on the 20-cell focused subset.

The original Phase 13 message about the backward tile tuning being a "negative result at d=128" still holds for the per-shape tile lookup approach explored in Phase 14 (only `(64,64,ns=1)` fits at d=128 + SMEM cap). The bench harness from that work is preserved at `benchmarks/sm120_bwd_tuning/`. The real gains came from the three commits above — restructuring the kernel itself, not picking among tile permutations.

FA4-only shapes (no FA2 reference — FA2 doesn't support these):
- `d=128, dv=64` (Bug-E shape, B=2, S=4096, H=32)
- `d=192, dv=128` (B=2, S=2048, H=32)
- `d=256, dv=128` (B=2, S=2048, H=32)
- `head_dim=128`, `page_size=64` (B=4, S=2048, H=8, permuted)

These cells previously hung the GPU (`d > dv` Bug E) or were not implemented (paged-KV beyond the FA3 path); this PR makes them functional on sm_120 with no FA2 equivalent.

Raw data: `bench/results.jsonl`, per-cell subprocess logs preserved for replay.

Not in scope
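The SMEM figures quoted for paged-KV in this section (144 KB at d=192, 192 KB at d=256, against a 99 KB cap) are reproduced by one simple accounting: a single-stage Q tile plus K and V tiles, each `128 × head_dim` in bf16. That accounting is an assumption made for illustration, not something taken from the kernel source, and `paged_tile_smem_kb` is a hypothetical helper.

```python
SMEM_CAP_KB = 99  # SM120 per-block shared memory cap quoted in this description


def paged_tile_smem_kb(head_dim: int, tile_m: int = 128, tile_n: int = 128,
                       num_stages: int = 1, bytes_per_elem: int = 2) -> float:
    """Rough SMEM estimate in KB: one Q tile plus staged K and V tiles.

    Assumes head_dim_v == head_dim and ignores any other buffers.
    """
    q_elems = tile_m * head_dim
    kv_elems = 2 * tile_n * head_dim * num_stages
    return (q_elems + kv_elems) * bytes_per_elem / 1024


for d in (128, 192, 256):
    need = paged_tile_smem_kb(d)
    verdict = "fits" if need <= SMEM_CAP_KB else "exceeds cap"
    print(f"head_dim={d}: {need:.0f} KB ({verdict})")
```

Under these assumptions d=128 lands at 96 KB, just under the cap, which is consistent with `head_dim = 128` being the largest paged-KV shape accepted at dispatch.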
- `dQ_semaphore` code path (shared gap with SM80; lifting it on SM120 alone is not a small change).
- `head_dim > 128` — SM120's 99 KB SMEM cap can't fit the `tile_n >= 128` PagedKVManager requirement plus d > 128 (would need 144 KB at d=192, 192 KB at d=256). Use SM100 / SM90 for those shapes.

Behavior changes from upstream (silent argument downgrades)
These auto-downgrades happen with no warning emitted. They keep callers correct but may surprise users who explicitly opt into the downgraded feature:
- `pack_gqa=True` + `page_table != None` → `pack_gqa` silently set to `False`. Reason: paged-KV's `num_head_kv` accounting conflicts with the packed mQ layout. Affects all arches that go through the SM120/SM80-base path.
- `pack_gqa=True` + `num_splits != 1` + non-varlen → `pack_gqa` silently set to `False`. Pre-existing upstream behavior, no change.
- `pack_gqa=True` on SM120 when `qhead_per_kvhead` does not divide `tile_m=128` (e.g. qwen2.5-7b's 7-way GQA) → silently downgraded so the `cute.local_tile` division succeeds. SM120-only.

Test plan
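When writing tests around the silent `pack_gqa` downgrades listed in the previous section, it can help to model them as a small decision function. The sketch below is a hypothetical pure-Python restatement of those three rules, not the dispatcher's actual code; `effective_pack_gqa` and its parameter names are assumptions introduced for illustration.

```python
def effective_pack_gqa(pack_gqa: bool, *, qhead_per_kvhead: int,
                       has_page_table: bool = False, num_splits: int = 1,
                       is_varlen: bool = False, is_sm120: bool = False,
                       tile_m: int = 128) -> bool:
    """Return the pack_gqa value a caller would effectively get."""
    if not pack_gqa:
        return False
    # Rule 1: paged-KV disables the packed layout.
    if has_page_table:
        return False
    # Rule 2: split-KV without varlen disables it (pre-existing upstream rule).
    if num_splits != 1 and not is_varlen:
        return False
    # Rule 3 (SM120 only): the GQA ratio must divide tile_m.
    if is_sm120 and tile_m % qhead_per_kvhead != 0:
        return False
    return True


# 4-way GQA keeps packing on SM120; 7-way GQA is downgraded (128 % 7 != 0).
print(effective_pack_gqa(True, qhead_per_kvhead=4, is_sm120=True))
print(effective_pack_gqa(True, qhead_per_kvhead=7, is_sm120=True))
```

Since no warning is emitted on downgrade, a test can only observe the effect indirectly, for example by checking that outputs with `pack_gqa=True` and `pack_gqa=False` agree for the downgraded shapes.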
Run on an sm_120 GPU (RTX 5090 / RTX PRO 6000 Blackwell). All test files in `tests/cute/` gate themselves on `torch.cuda.get_device_capability() == (12, 0)` and skip cleanly on other arches. Invoke `pytest` from a directory other than the repo root so the namespace `flash_attn` package isn't shadowed by the local FA2 `__init__.py`.

- `tests/cute/test_paged_kv_sm120.py` — paged-KV correctness through `head_dim = 128`, multiple page sizes / patterns, GQA + MQA, expected rejection for d > 128, plus a cross-feature paged + d > dv + varlen case (38 cases)
- `tests/cute/test_flash_attn_sm120_dgtdv.py` — `head_dim > head_dim_v` non-TMA routing regression suite, with `pytest-timeout(30)` so a future TMA-gate widening fails as a timeout instead of wedging the GPU (11 cases)
- `tests/cute/test_flash_attn_bwd_sm120_postprocess.py` — backward dQ postprocess regression: numeric vs fp32-SDPA comparison + white-box source-inspection guard against re-introduction of the buggy literal pattern (10 cases)
- `tests/cute/test_block_sparsity.py` — representative slice (~624 of ~4900 collected cases run on sm_120; all 624 expected to pass)
- `flash_attn_func` call vs `F.scaled_dot_product_attention` with `SDPBackend.MATH`, max abs diff < 0.05
- `requires_grad=True`, compare dq/dk/dv against SDPA reference, max abs diff < 0.05

cc @coderabbitai
Summary by CodeRabbit