[Cute,Sm80] Fix forward GQA correctness — pack_gqa flag without pack_…#2528
Open
kbkarthikeyan wants to merge 1 commit into
…_gqa_layout

FlashAttentionForwardSm80 and FlashAttentionBackwardSm80 set qhead_per_kvhead_packgqa = qhead_per_kvhead and use a pack_gqa shortcut on head_idx_kv when pack_gqa=True, but neither class ever calls pack_gqa_layout (only the SM100 paths do). Downstream code then treats un-packed rows as packed: wrong causal column limit, wrong K-block range, wrong head index for K/V/dK/dV reads.

Fix: both classes never pack, so always set qhead_per_kvhead_packgqa=1 and always compute head_idx_kv = head_idx // qhead_per_kvhead.

Symptom on SM80-class GPUs (notably SM120 / RTX PRO 6000 Blackwell):
  fwd GQA causal: rel diff 0.77-1.18 vs SDPA -> < 0.005
  bwd GQA all D<=256: dq rel 9-470, NaN at >=5x scale -> < 0.01, no NaN

Verified on Qwen3-1.7B, Qwen3.5-0.8B, Mistral 7B at 1x/5x/10x logit scale.
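The head-index part of the bug can be illustrated with a small pure-Python model. This is only a sketch under stated assumptions, not the kernel code: `kv_head_index` is a hypothetical helper, and it assumes the "pack_gqa shortcut" means using `head_idx` directly as the K/V head index, which is only valid when query heads have actually been folded into the row dimension.

```python
def kv_head_index(head_idx: int, qhead_per_kvhead: int, packed: bool) -> int:
    """Hypothetical helper: map a launch head index to the K/V head it reads."""
    if packed:
        # Assumed shortcut: in a packed layout the head axis already
        # enumerates KV heads, so no division is needed.
        return head_idx
    # Un-packed layout: the head axis enumerates query heads, and each
    # group of qhead_per_kvhead consecutive query heads shares one KV head.
    return head_idx // qhead_per_kvhead


num_q_heads, num_kv_heads = 8, 2
ratio = num_q_heads // num_kv_heads  # qhead_per_kvhead = 4

# Layout is un-packed in both cases; only the index computation differs.
correct = [kv_head_index(h, ratio, packed=False) for h in range(num_q_heads)]
buggy = [kv_head_index(h, ratio, packed=True) for h in range(num_q_heads)]

# Query heads whose K/V index falls outside the valid range [0, num_kv_heads).
out_of_range = [h for h, kv in enumerate(buggy) if kv >= num_kv_heads]
```

In this toy configuration the shortcut sends query heads 2 through 7 to K/V head indices that do not exist, while the division keeps every index within range.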
f0c8a8c to
bcb170d
Compare
…gqa_layout call
FlashAttentionForwardSm80 sets `qhead_per_kvhead_packgqa = qhead_per_kvhead` on BlockInfo and AttentionMask when `pack_gqa=True`, but it never actually calls `pack_gqa_layout` to fold qhead into the seqlen dimension (only flash_fwd_sm100.py and flash_fwd_mla_sm100.py do).

The downstream code in BlockInfo.get_n_block_min_max and AttentionMask.apply_mask then divides row indices by the ratio, treating un-packed rows as packed.
Fix: SM80 forward never packs, so pass 1 (constexpr) at both call sites.
Reproduces and is fixed on RTX PRO 6000 Blackwell Max-Q (sm_120a, which inherits from FlashAttentionForwardSm80).
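The effect on the causal mask can be sketched with a simplified model. This is an illustration, not the code in BlockInfo or AttentionMask: `causal_col_limit` is a hypothetical function, and it assumes the usual bottom-right-aligned causal rule in which the query at position `q_pos` may attend to columns below `q_pos + 1 + seqlen_k - seqlen_q`.

```python
def causal_col_limit(row_idx: int, seqlen_q: int, seqlen_k: int,
                     qhead_per_kvhead_packgqa: int = 1) -> int:
    """Hypothetical model: exclusive upper bound on visible key columns."""
    # In a packed layout several query heads share one sequence position,
    # so the row index is divided by the ratio to recover that position.
    # With un-packed rows the ratio must be 1, otherwise positions shrink.
    q_pos = row_idx // qhead_per_kvhead_packgqa
    return q_pos + 1 + seqlen_k - seqlen_q


seqlen = 8
ratio = 4  # qhead_per_kvhead in this toy example

# Rows are un-packed here, so ratio 1 is the correct setting.
fixed = [causal_col_limit(r, seqlen, seqlen, 1) for r in range(seqlen)]
buggy = [causal_col_limit(r, seqlen, seqlen, ratio) for r in range(seqlen)]
```

With the ratio applied to un-packed rows, every row past the first sees fewer keys than the causal mask should allow (the last row sees 2 columns instead of 8), which is the kind of error that would produce the large forward differences reported above.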