Align DSA indexer kernels and fix dense score-grad clipping #297
jiayus-nvidia wants to merge 2 commits into
📝 Walkthrough
This PR refactors sparse attention backward and forward kernels across multiple GPU architectures. Changes include: adding
Changes
Backward API and Grid Reordering
Dense SM90 Backward Score-Grad DSL Implementation
Forward SM90 Kernel Simplification
Test Reference Implementation
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
🧹 Nitpick comments (1)
test/python/fe_api/dsa/dsa_reference.py (1)
573-596: 💤 Low value
Return type annotation is now incorrect.
When `return_scores=True`, this function returns `Tuple[torch.Tensor, torch.Tensor]`, but the annotation on line 580 declares only `-> torch.Tensor`.
📝 Suggested fix
```diff
+from typing import Optional, Tuple, Union
 ...
 def _dense_indexer_predict_distribution(
     q_indexer: torch.Tensor,  # (B, S_q, H, D)
     k_indexer: torch.Tensor,  # (B, S_k, D)
     weights: torch.Tensor,  # (B, S_q, H)
     sm_scale: float,
     ratio: int,
     return_scores: bool = False,
-) -> torch.Tensor:
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
```
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@test/python/fe_api/dsa/dsa_reference.py` around lines 573 - 596, the return type annotation for `_dense_indexer_predict_distribution` is wrong: when `return_scores=True` it returns a tuple `(predict, scores)`. Update the function annotation to reflect both possibilities (e.g. `-> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]`) and add the necessary typing import(s) (`Union` and/or `Tuple`) at the top of the file; ensure the annotation matches the function's behavior controlled by the `return_scores` parameter.
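The annotation pattern this nitpick asks for can be sketched in isolation. The example below is a minimal illustration and not the code from `dsa_reference.py`: `predict_distribution` is a hypothetical stand-in that works on plain Python lists instead of `torch.Tensor`, and its softmax body is an assumption made only so the example runs without PyTorch. What it shows is a `Union[...]` return annotation that covers both branches of a `return_scores` flag.

```python
import math
from typing import List, Tuple, Union

# Hypothetical stand-in for a tensor type, so the sketch runs without torch.
Scores = List[float]


def predict_distribution(
    logits: Scores,
    return_scores: bool = False,
) -> Union[Scores, Tuple[Scores, Scores]]:
    """Return a softmax over `logits`, optionally paired with the raw scores.

    The annotation lists both shapes of the result, so it stays accurate
    whichever value `return_scores` takes.
    """
    # Subtract the maximum before exponentiating for numerical stability.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    predict = [e / total for e in exps]

    if return_scores:
        # Tuple branch: (distribution, copy of the raw scores).
        return predict, list(logits)
    # Single-value branch: distribution only.
    return predict


if __name__ == "__main__":
    print(predict_distribution([0.0, 0.0]))
    print(predict_distribution([1.0, 2.0], return_scores=True))
```

A `Union` return keeps the change minimal, which matches the review's request, but callers still have to narrow the result themselves. If stricter checking is wanted, `typing.overload` combined with `Literal[True]` / `Literal[False]` on the flag is a possible alternative, at the cost of more boilerplate.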
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: b0cf009f-5d93-476c-b793-e3309b594591
📒 Files selected for processing (8)
- python/cudnn/deepseek_sparse_attention/indexer_backward/api.py
- python/cudnn/deepseek_sparse_attention/indexer_backward/dense_indexer_backward_sm100.py
- python/cudnn/deepseek_sparse_attention/indexer_backward/dense_indexer_backward_sm90.py
- python/cudnn/deepseek_sparse_attention/indexer_backward/indexer_backward_sm100.py
- python/cudnn/deepseek_sparse_attention/indexer_backward/indexer_backward_sm90.py
- python/cudnn/deepseek_sparse_attention/indexer_forward/_interface_sm90.py
- python/cudnn/deepseek_sparse_attention/indexer_forward/indexer_fwd_sm90.py
- test/python/fe_api/dsa/dsa_reference.py
💤 Files with no reviewable changes (1)
- python/cudnn/deepseek_sparse_attention/indexer_forward/_interface_sm90.py
@cudnn-ci-bot run
Summary
- `(seqlen, batch)` ordering.
- `topk_indices_global` plumbing so backward can handle local per-batch top-k indices by default.
- `seqlen_k`.
Summary by CodeRabbit
New Features
- `topk_indices_global` parameter for flexible top-k indexing control
- `clean_logits` option for improved masking configuration
Performance Optimizations