DSA bwd SM100 fix dropped dKV gradients when topk width > total_S_kv#298
DSA bwd SM100 fix dropped dKV gradients when topk width > total_S_kv#298Jie-Fang wants to merge 1 commit into
Conversation
(global_row_idx) against max_seqlen_kv. A column position >= total_S_kv is not invalid -- with a non-compact topk_idxs layout (-1 sentinels, width > total_S_kv) valid indices can sit at any column. Entries past column total_S_kv were silently treated as -1 and their dKV contributions dropped, while dQ (whose load path correctly judges validity by the index value) stayed correct. With a [window | compressed] layout this zeroes the entire original-KV region of dkv bit-exactly. Drop the position-vs-seqlen comparison; the < topk bound plus the topk_idx >= 0 sentinel check in the store helpers already match the load-side and FlashMLA-forward semantics. Remove the now-unused max_seqlen_kv parameter from reduce_dKV. Also fix the test reference _make_topk_mask: without topk_length it clamped -1 sentinels to index 0, spuriously marking KV row 0 as attended, which corrupted out/lse/gradient references for non-compact inputs. Verified on B200: topk width 1024 > S_kv 256 now gives cos_sim(dkv) 0.9996 (was 0.498); wide non-compact layouts pass FP32 autograd checks; fe_api/dsa pytest suite passes (16 tests). Co-Authored-By: Claude Fable 5 noreply@anthropic.com
📝 WalkthroughWalkthroughThis PR removes the ChangesParameter removal and validity simplification
🎯 2 (Simple) | ⏱️ ~12 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Warning Review ran into problems🔥 ProblemsGit: Failed to clone repository. Please run the Comment |
There was a problem hiding this comment.
🧹 Nitpick comments (1)
test/python/fe_api/dsa/dsa_reference.py (1)
35-36: 💤 Low valueDead code:
invalid_idxis unused.The variable
invalid_idxis assigned but never read. This appears to be leftover from the old clamping approach and can be removed.♻️ Suggested cleanup
if topk_length is not None: positions = torch.arange(topk, device=topk_idxs.device).unsqueeze(0).expand(t, -1) invalid = positions >= topk_length.unsqueeze(1) - invalid_idx = topk_idxs.clone() - invalid_idx[invalid] = 0 # Recompute mask only where valid mask = torch.zeros(t, s_kv, dtype=torch.bool, device=topk_idxs.device)🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@test/python/fe_api/dsa/dsa_reference.py` around lines 35 - 36, Remove the dead assignment to invalid_idx which is never used: delete the lines creating invalid_idx from topk_idxs.clone() and setting invalid_idx[invalid] = 0; instead rely on the existing logic that uses topk_idxs and invalid directly (references: invalid_idx, topk_idxs.clone(), invalid) so no other changes are needed.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Nitpick comments:
In `@test/python/fe_api/dsa/dsa_reference.py`:
- Around line 35-36: Remove the dead assignment to invalid_idx which is never
used: delete the lines creating invalid_idx from topk_idxs.clone() and setting
invalid_idx[invalid] = 0; instead rely on the existing logic that uses topk_idxs
and invalid directly (references: invalid_idx, topk_idxs.clone(), invalid) so no
other changes are needed.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 54d553c8-e190-489a-a53e-f6fb3b35185d
📒 Files selected for processing (2)
python/cudnn/deepseek_sparse_attention/sparse_attention_backward/dsa_bwd_sm100.pytest/python/fe_api/dsa/dsa_reference.py
|
@cudnn-ci-bot run |
(global_row_idx) against max_seqlen_kv. A column position >= total_S_kv is not invalid -- with a non-compact topk_idxs layout (-1 sentinels, width > total_S_kv) valid indices can sit at any column. Entries past column total_S_kv were silently treated as -1 and their dKV contributions dropped, while dQ (whose load path correctly judges validity by the index value) stayed correct. With a [window | compressed] layout this zeroes the entire original-KV region of dkv bit-exactly. Drop the position-vs-seqlen comparison; the < topk bound plus the topk_idx >= 0 sentinel check in the store helpers already match the load-side and FlashMLA-forward semantics. Remove the now-unused max_seqlen_kv parameter from reduce_dKV.
Also fix the test reference _make_topk_mask: without topk_length it clamped -1 sentinels to index 0, spuriously marking KV row 0 as attended, which corrupted out/lse/gradient references for non-compact inputs.
Verified on B200: topk width 1024 > S_kv 256 now gives cos_sim(dkv) 0.9996 (was 0.498); wide non-compact layouts pass FP32 autograd checks; fe_api/dsa pytest suite passes (16 tests). Co-Authored-By: Claude Fable 5 noreply@anthropic.com
Summary by CodeRabbit