Skip to content

[Cute,Sm80] Fix forward GQA correctness — pack_gqa flag without pack_…#2528

Open
kbkarthikeyan wants to merge 1 commit into
Dao-AILab:mainfrom
kbkarthikeyan:fix/sm80-fwd-pack-gqa-mask-blockinfo
Open

[Cute,Sm80] Fix forward GQA correctness — pack_gqa flag without pack_…#2528
kbkarthikeyan wants to merge 1 commit into
Dao-AILab:mainfrom
kbkarthikeyan:fix/sm80-fwd-pack-gqa-mask-blockinfo

Conversation

@kbkarthikeyan
Copy link
Copy Markdown

…gqa_layout call

FlashAttentionForwardSm80 set qhead_per_kvhead_packgqa = qhead_per_kvhead on BlockInfo and AttentionMask when pack_gqa=True, but it never actually calls pack_gqa_layout to fold qhead into the seqlen dimension (only flash_fwd_sm100.py and flash_fwd_mla_sm100.py do).

The downstream code in BlockInfo.get_n_block_min_max and AttentionMask.apply_mask then divides row indices by the ratio, treating un-packed rows as packed:

  • wrong causal column limit per row (row 0 = 0/r = 0 happens to be right; all other rows wrong by factor r)
  • wrong K-block iteration range past m_block >= 1

Fix: SM80 forward never packs, so pass 1 (constexpr) at both call sites.

Reproduces and is fixed on RTX PRO 6000 Blackwell Max-Q (sm_120a, which inherits from FlashAttentionForwardSm80).

…_gqa_layout

FlashAttentionForwardSm80 and FlashAttentionBackwardSm80 set
qhead_per_kvhead_packgqa = qhead_per_kvhead and use a pack_gqa shortcut
on head_idx_kv when pack_gqa=True, but neither class ever calls
pack_gqa_layout (only the SM100 paths do). Downstream code then treats
un-packed rows as packed — wrong causal column limit, wrong K-block range,
wrong head index for K/V/dK/dV reads.

Fix: both classes never pack, so always set qhead_per_kvhead_packgqa=1
and always compute head_idx_kv = head_idx // qhead_per_kvhead.

Symptom on SM80-class GPUs (notably SM120 / RTX PRO 6000 Blackwell):
  fwd GQA causal: rel diff 0.77-1.18 vs SDPA -> < 0.005
  bwd GQA all D<=256: dq rel 9-470, NaN at >=5x scale -> < 0.01, no NaN

Verified on Qwen3-1.7B, Qwen3.5-0.8B, Mistral 7B at 1x/5x/10x logit scale.
@kbkarthikeyan kbkarthikeyan force-pushed the fix/sm80-fwd-pack-gqa-mask-blockinfo branch from f0c8a8c to bcb170d Compare April 30, 2026 17:34
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant