DSA bwd SM100 fix dropped dKV gradients when topk width > total_S_kv #298

Open
Jie-Fang wants to merge 1 commit into NVIDIA:develop from Jie-Fang:fix/dsa-bwd-dkv-topk-gt-skv

@Jie-Fang Jie-Fang commented Jun 10, 2026

(global_row_idx) against max_seqlen_kv. A column position >= total_S_kv is not invalid -- with a non-compact topk_idxs layout (-1 sentinels, width > total_S_kv) valid indices can sit at any column. Entries past column total_S_kv were silently treated as -1 and their dKV contributions dropped, while dQ (whose load path correctly judges validity by the index value) stayed correct. With a [window | compressed] layout this zeroes the entire original-KV region of dkv bit-exactly.

Drop the position-vs-seqlen comparison; the < topk bound plus the topk_idx >= 0 sentinel check in the store helpers already match the load-side and FlashMLA-forward semantics. Remove the now-unused max_seqlen_kv parameter from reduce_dKV.
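The difference between the two validity rules can be sketched in plain Python. This is only an illustrative model under stated assumptions, not the SM100 kernel: `reduce_dkv_model`, the `legacy_position_check` flag, and the scalar per-slot contributions are hypothetical names introduced here, and the `< topk` tile bound is implicit in the loop over each row.

```python
def reduce_dkv_model(topk_idxs, contrib, s_kv, legacy_position_check=False):
    """Toy model of dKV reduction (hypothetical helper, not the real kernel).

    topk_idxs[q][col] is a KV row index, or -1 for an empty slot.
    contrib[q][col] is the scalar gradient contribution of that slot.
    """
    dkv = [0.0] * s_kv
    for q, row in enumerate(topk_idxs):
        for col, kv_idx in enumerate(row):
            # Validity is judged by the index value (sentinel check).
            valid = kv_idx >= 0
            if legacy_position_check:
                # Old rule: additionally required the column position
                # to be below the KV sequence length.
                valid = valid and col < s_kv
            if valid:
                dkv[kv_idx] += contrib[q][col]
    return dkv


# Non-compact layout: width 8 > s_kv 4, valid indices at columns 0, 5 and 6.
topk_idxs = [[0, -1, -1, -1, -1, 2, 3, -1]]
contrib = [[1.0] * 8]

fixed = reduce_dkv_model(topk_idxs, contrib, s_kv=4)
legacy = reduce_dkv_model(topk_idxs, contrib, s_kv=4, legacy_position_check=True)
print(fixed)   # [1.0, 0.0, 1.0, 1.0]
print(legacy)  # [1.0, 0.0, 0.0, 0.0]
```

In this model the legacy rule loses the contributions stored at columns 5 and 6, even though both hold in-range KV indices.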
Also fix the test reference _make_topk_mask: without topk_length it clamped -1 sentinels to index 0, spuriously marking KV row 0 as attended, which corrupted out/lse/gradient references for non-compact inputs.
Verified on B200: topk width 1024 > S_kv 256 now gives cos_sim(dkv) 0.9996 (was 0.498); wide non-compact layouts pass FP32 autograd checks; fe_api/dsa pytest suite passes (16 tests).

Co-Authored-By: Claude Fable 5 noreply@anthropic.com

Summary by CodeRabbit

  • Refactor
    • Optimized sparse attention backward computation to improve kernel efficiency by streamlining index validation logic during gradient reduction operations.


coderabbitai Bot commented Jun 10, 2026

Walkthrough

This PR removes the max_seqlen_kv parameter from the reduce_dKV kernel, eliminating redundant sequence-length bounds checking. The kernel now relies solely on the topk bound for validity conditions in top-k index preloading. The test reference implementation is updated to use boolean indexing to match the simplified logic.

Changes

Parameter removal and validity simplification

| Layer / File(s) | Summary |
|---|---|
| Kernel signature and parameter removal (python/cudnn/deepseek_sparse_attention/sparse_attention_backward/dsa_bwd_sm100.py) | reduce_dKV kernel signature drops max_seqlen_kv: Int32 parameter; bwd method call site is updated to not pass this argument; internal kernel computation of cur_seqlen_kv is removed. |
| Top-k preload validity checks (python/cudnn/deepseek_sparse_attention/sparse_attention_backward/dsa_bwd_sm100.py) | Validity conditions in main dKV0/dKV1 preload and dKV4 (64-wide tail) preload paths are simplified from global_row_idx < topk and global_row_idx < cur_seqlen_kv to global_row_idx < topk. |
| Test reference implementation update (test/python/fe_api/dsa/dsa_reference.py) | _make_topk_mask now uses boolean indexing on valid entries (row_idx[valid], topk_idxs[valid]) instead of clamping invalid indices. |
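
As a rough illustration of the reference-mask change, the sketch below contrasts clamping sentinels with filtering them out. It is a plain-Python stand-in under stated assumptions, not the repository's `_make_topk_mask`: the function name `make_topk_mask_model` and the `clamp_sentinels` flag are hypothetical.

```python
def make_topk_mask_model(topk_idxs, s_kv, clamp_sentinels=False):
    """Build a [t][s_kv] boolean attention mask from top-k indices.

    Hypothetical stand-in for the test reference; -1 marks an empty slot.
    """
    mask = [[False] * s_kv for _ in topk_idxs]
    for t, row in enumerate(topk_idxs):
        for kv_idx in row:
            if clamp_sentinels:
                # Old behaviour: -1 is clamped to 0, so KV row 0 is
                # marked as attended even when it was never selected.
                mask[t][max(kv_idx, 0)] = True
            elif kv_idx >= 0:
                # New behaviour: only valid entries set a mask bit.
                mask[t][kv_idx] = True
    return mask


topk_idxs = [[2, -1, 3, -1]]
clamped = make_topk_mask_model(topk_idxs, s_kv=4, clamp_sentinels=True)
filtered = make_topk_mask_model(topk_idxs, s_kv=4)
print(clamped)   # [[True, False, True, True]]
print(filtered)  # [[False, False, True, True]]
```

The spurious `True` at KV row 0 in the clamped variant is the kind of entry that would feed an extra key into the reference softmax.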

🎯 2 (Simple) | ⏱️ ~12 minutes

🐰 A kernel parameter hops away,
Max_seqlen_kv had its day,
Now topk alone stands guard so bright,
Validity checks are lean and tight,
Sparse attention takes its rightful flight! ✨

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title directly and specifically identifies the main bug fix: removing a problematic seqlen comparison in reduce_dKV that dropped dKV gradients when topk width exceeds total_S_kv, which aligns with the core changes and objectives. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |

Warning

Review ran into problems

🔥 Problems

Git: Failed to clone repository. Please run the @coderabbitai full review command to re-trigger a full review. If the issue persists, set path_filters to include or exclude specific files.


@coderabbitai coderabbitai Bot left a comment

🧹 Nitpick comments (1)
test/python/fe_api/dsa/dsa_reference.py (1)

35-36: 💤 Low value

Dead code: invalid_idx is unused.

The variable invalid_idx is assigned but never read. This appears to be leftover from the old clamping approach and can be removed.

♻️ Suggested cleanup
 if topk_length is not None:
     positions = torch.arange(topk, device=topk_idxs.device).unsqueeze(0).expand(t, -1)
     invalid = positions >= topk_length.unsqueeze(1)
-    invalid_idx = topk_idxs.clone()
-    invalid_idx[invalid] = 0
     # Recompute mask only where valid
     mask = torch.zeros(t, s_kv, dtype=torch.bool, device=topk_idxs.device)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@test/python/fe_api/dsa/dsa_reference.py` around lines 35 - 36, Remove the
dead assignment to invalid_idx which is never used: delete the lines creating
invalid_idx from topk_idxs.clone() and setting invalid_idx[invalid] = 0; instead
rely on the existing logic that uses topk_idxs and invalid directly (references:
invalid_idx, topk_idxs.clone(), invalid) so no other changes are needed.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 54d553c8-e190-489a-a53e-f6fb3b35185d

📥 Commits

Reviewing files that changed from the base of the PR and between 65f40b9 and 233f2c7.

📒 Files selected for processing (2)
  • python/cudnn/deepseek_sparse_attention/sparse_attention_backward/dsa_bwd_sm100.py
  • test/python/fe_api/dsa/dsa_reference.py

@vedaanta vedaanta requested review from saltyminty and vedaanta June 11, 2026 00:57
@Anerudhan

@cudnn-ci-bot run
