Skip to content

Align DSA indexer kernels and fix dense score-grad clipping#297

Open
jiayus-nvidia wants to merge 2 commits into
NVIDIA:developfrom
jiayus-nvidia:jiayus/sm100-scoregrad-log-clip-mask
Open

Align DSA indexer kernels and fix dense score-grad clipping#297
jiayus-nvidia wants to merge 2 commits into
NVIDIA:developfrom
jiayus-nvidia:jiayus/sm100-scoregrad-log-clip-mask

Conversation

@jiayus-nvidia

@jiayus-nvidia jiayus-nvidia commented Jun 10, 2026

Copy link
Copy Markdown
Contributor

Summary

  • Align SM90/SM100 indexer backward launch grids to use (seqlen, batch) ordering.
  • Add topk_indices_global plumbing so backward can handle local per-batch top-k indices by default.
  • Rework SM90 indexer forward to write reduced logits directly to global memory, removing the separate score-store path.
  • Fix dense score-grad clipping to use log-domain mask checks, and apply the same behavior to SM90 dense score-grad.
  • Keep SM90 dense GEMM compile caching independent of runtime seqlen_k.
  • Update the dense DSA reference to backprop through logits with an explicit grad signal matching clipped-log KL behavior.

Summary by CodeRabbit

  • New Features

    • Added topk_indices_global parameter for flexible top-k indexing control
    • Introduced clean_logits option for improved masking configuration
  • Performance Optimizations

    • Optimized kernel grid scheduling for better memory efficiency
    • Replaced torch-based computation with specialized kernel implementation
    • Streamlined forward indexer pipeline

@coderabbitai

coderabbitai Bot commented Jun 10, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

📝 Walkthrough

Walkthrough

This PR refactors sparse attention backward and forward kernels across multiple GPU architectures. Changes include: adding topk_indices_global flag to the backward API; systematically reordering CUDA grids from (batch, seqlen) to (seqlen, batch) across SM90/SM100 kernels; introducing a new CuTe-DSL score-grad kernel for dense SM90 backward; simplifying forward SM90 to use direct global-memory score output; and updating test references to match new kernel semantics.

Changes

Backward API and Grid Reordering

Layer / File(s) Summary
Backward API topk_indices_global support
python/cudnn/deepseek_sparse_attention/indexer_backward/api.py
IndexerBackward.__init__ and indexer_backward_wrapper accept and propagate topk_indices_global flag through initialization, compilation, and cache keying to support both local and global top-k indexing semantics.
Dense SM100 backward grid reordering and clipped-log masking
python/cudnn/deepseek_sparse_attention/indexer_backward/dense_indexer_backward_sm100.py
Dense SM100 kernels swap CUDA grid from (batch_size, seqlen) to (seqlen, batch_size) with matching block_idx interpretation swaps; clipped-log masking refactored to derive log_clip_mask from score_minus_lse >= CLIP_LOG_MIN comparison in both Phase 1 and Phase 2.
Sparse SM100/SM90 backward grid reordering
python/cudnn/deepseek_sparse_attention/indexer_backward/indexer_backward_sm100.py, python/cudnn/deepseek_sparse_attention/indexer_backward/indexer_backward_sm90.py
Sparse backward kernels reorder grids from (batch_size, seqlen, 1) to (seqlen, batch_size, 1) with corresponding updates to kernel-side block_idx interpretation for both GEMM and score-grad kernels.

Dense SM90 Backward Score-Grad DSL Implementation

Layer / File(s) Summary
ScoreGradDenseSm90 CuTe-DSL kernel and helper
python/cudnn/deepseek_sparse_attention/indexer_backward/dense_indexer_backward_sm90.py
Introduces _dense_seqlen_info helper and ScoreGradDenseSm90 CuTe-DSL kernel class that replaces torch-based score-grad computation, applying clip masking and reduction to both dense and varlen cases.
Score-grad and GEMM compilation caching
python/cudnn/deepseek_sparse_attention/indexer_backward/dense_indexer_backward_sm90.py
Adds separate compile caches for score-grad and GEMM stages with distinct keying strategies; implements _ensure_compiled_score_grad and _run_score_grad_only to execute the DSL kernel before GEMM.
Dense SM90 backward execution flow
python/cudnn/deepseek_sparse_attention/indexer_backward/dense_indexer_backward_sm90.py
Reworked _run function invokes DSL score-grad kernel first to produce grad_signal in idx_scores_raw, then executes GEMM stage with dIndexK dtype handling wrapped in stream contexts.

Forward SM90 Kernel Simplification

Layer / File(s) Summary
Forward kernel interface and cache update
python/cudnn/deepseek_sparse_attention/indexer_forward/_interface_sm90.py, python/cudnn/deepseek_sparse_attention/indexer_forward/indexer_fwd_sm90.py
Removes use_tma_store from SM90 forward compilation cache key and kernel construction, eliminating backend-specific storage-mode specialization.
IndexerForwardSm90 clean_logits parameter
python/cudnn/deepseek_sparse_attention/indexer_forward/indexer_fwd_sm90.py
Constructor introduces clean_logits: bool = True parameter to control invalid-logits masking; shared-storage layout reorganized to remove sScore shared-memory region.
Forward kernel epilogue global-memory store
python/cudnn/deepseek_sparse_attention/indexer_forward/indexer_fwd_sm90.py
Replaces shared-memory score staging epilogue with new global-memory store that directly writes reduced scores to mOut with bottom-right causal masking and clean_logits logic; consumer warp-group flow simplified by eliminating score TMA/barrier synchronization.

Test Reference Implementation

Layer / File(s) Summary
Reference predict and scores retrieval
test/python/fe_api/dsa/dsa_reference.py
_dense_indexer_predict_distribution extended with optional return_scores flag to return either predict distribution or (predict, scores) tuple for flexible gradient computation.
Reference gradient computation refactor
test/python/fe_api/dsa/dsa_reference.py
Dense backward reference refactored to compute grad_signal from clipped targets and predict distribution, avoiding explicit scalar loss computation and using analytic gradients for backpropagation through scores.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested labels

mod-frontend, orig-nv-eng, cat-enhancements

Suggested reviewers

  • saltyminty
  • Anerudhan

Poem

🐰 Grids dance a new quadrille,
From batch-first to seqlen's sway,
Score-grads bloom in DSL's quill,
Epilogues write to global arrays today,
Clean logits gleam, the forward way!

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 16.67% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately summarizes the main changes: grid alignment (SM90/SM100 indexer kernels), topk_indices_global support, SM90 forward rework, and dense score-grad clipping fixes.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (1)
test/python/fe_api/dsa/dsa_reference.py (1)

573-596: 💤 Low value

Return type annotation is now incorrect.

When return_scores=True, this function returns Tuple[torch.Tensor, torch.Tensor], but the annotation on line 580 declares only -> torch.Tensor.

📝 Suggested fix
+from typing import Optional, Tuple, Union
...
 def _dense_indexer_predict_distribution(
     q_indexer: torch.Tensor,  # (B, S_q, H, D)
     k_indexer: torch.Tensor,  # (B, S_k, D)
     weights: torch.Tensor,  # (B, S_q, H)
     sm_scale: float,
     ratio: int,
     return_scores: bool = False,
-) -> torch.Tensor:
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@test/python/fe_api/dsa/dsa_reference.py` around lines 573 - 596, The return
type annotation for _dense_indexer_predict_distribution is wrong: when
return_scores=True it returns a tuple (predict, scores). Update the function
annotation to reflect both possibilities (e.g. -> Union[torch.Tensor,
Tuple[torch.Tensor, torch.Tensor]]) and add the necessary typing import(s)
(Union and/or Tuple) at top of the file; ensure the annotation matches the
function's behavior controlled by the return_scores parameter.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@test/python/fe_api/dsa/dsa_reference.py`:
- Around line 573-596: The return type annotation for
_dense_indexer_predict_distribution is wrong: when return_scores=True it returns
a tuple (predict, scores). Update the function annotation to reflect both
possibilities (e.g. -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]])
and add the necessary typing import(s) (Union and/or Tuple) at top of the file;
ensure the annotation matches the function's behavior controlled by the
return_scores parameter.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: b0cf009f-5d93-476c-b793-e3309b594591

📥 Commits

Reviewing files that changed from the base of the PR and between 65f40b9 and cfebc97.

📒 Files selected for processing (8)
  • python/cudnn/deepseek_sparse_attention/indexer_backward/api.py
  • python/cudnn/deepseek_sparse_attention/indexer_backward/dense_indexer_backward_sm100.py
  • python/cudnn/deepseek_sparse_attention/indexer_backward/dense_indexer_backward_sm90.py
  • python/cudnn/deepseek_sparse_attention/indexer_backward/indexer_backward_sm100.py
  • python/cudnn/deepseek_sparse_attention/indexer_backward/indexer_backward_sm90.py
  • python/cudnn/deepseek_sparse_attention/indexer_forward/_interface_sm90.py
  • python/cudnn/deepseek_sparse_attention/indexer_forward/indexer_fwd_sm90.py
  • test/python/fe_api/dsa/dsa_reference.py
💤 Files with no reviewable changes (1)
  • python/cudnn/deepseek_sparse_attention/indexer_forward/_interface_sm90.py

@Anerudhan

Copy link
Copy Markdown
Collaborator

@cudnn-ci-bot run

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants