perf(dsv4 prefill sparse-attn): per-group o_proj decoupling + col_expand_add/A_K_TILE (-15%) #649
📝 Walkthrough: DeepSeek-V4 Prefill Sparse Attention Refactor
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 5 passed
Code Review
This pull request refactors the prefill sparse attention kernel to decouple the grouped INT8 output projection per group, mirroring the decode sparse attention pipeline. Key changes include updating tile sizes, optimizing bias addition via `pl.col_expand_add`, establishing explicit task dependencies with `pl.spmd` contexts, and rewriting the projection, quantization, and activation steps under a manual scope so that their execution overlaps. The golden reference implementation is updated to match. The review feedback suggests optimizing the pipelined loops in both `proj_a_mm` and `proj_b_mm`: instead of peeling the first iteration, use a conditional check inside the `pl.pipeline` loop, which allows better load overlap on CANN/Ascend hardware.
🧹 Nitpick comments (1)
models/deepseek/v4/prefill_sparse_attn.py (1)
84-88: 🎯 Functional Correctness | 🔵 Trivial | ⚡ Quick win

Add the missing `A_K_TILE` divisibility guard.

`proj_a_mm` peels `0:A_K_TILE` and then covers `O_GROUP_IN // A_K_TILE` chunks, so a non-divisible `O_GROUP_IN` would leave a K-tail uncovered. Add this next to the other tile guards:

Proposed guard

     assert T % QUANT_TOKEN_TILE == 0
    +assert O_GROUP_IN % A_K_TILE == 0
     assert D % PROJ_B_MM_N_TILE == 0 and D % PROJ_B_D_CHUNK == 0 and PROJ_B_D_CHUNK % PROJ_B_MM_N_TILE == 0

Based on learnings, DeepSeek v4 module-level correctness guards using `assert` are intentional.

Also applies to: 356-363
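To make the uncovered K-tail concrete, here is a minimal sketch of the peel-then-loop coverage the comment describes. The loop shape and the helper names `k_covered` and `check_tile_guard` are assumptions for illustration; they are not taken from `prefill_sparse_attn.py`.

```python
def k_covered(o_group_in: int, a_k_tile: int) -> int:
    """Count the K elements touched by a peel-then-loop schedule.

    Hypothetical model of the review comment: chunk 0 is peeled, then
    chunks 1..n-1 are looped, with n = o_group_in // a_k_tile. Floor
    division silently drops any partial trailing chunk.
    """
    n_chunks = o_group_in // a_k_tile
    covered = a_k_tile                 # peeled slice 0:A_K_TILE
    for _ in range(1, n_chunks):       # remaining full chunks
        covered += a_k_tile
    return covered


def check_tile_guard(o_group_in: int, a_k_tile: int) -> None:
    """Module-level style guard in the spirit of the proposed assert."""
    assert o_group_in % a_k_tile == 0, "K-tail would be left uncovered"


divisible = k_covered(4096, 256)           # every K element is covered
with_tail = k_covered(4096 + 128, 256)     # the trailing 128 elements are skipped
print(divisible, with_tail)
```

With `O_GROUP_IN = 4096` and `A_K_TILE = 256` the guard passes; the second call shows that without the guard a 128-element tail would be dropped silently rather than raising an error.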
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Nitpick comments:
In `@models/deepseek/v4/prefill_sparse_attn.py`:
- Around line 84-88: Add the missing divisibility assert for A_K_TILE in the
module-level guard block around the existing tensor tile checks in
prefill_sparse_attn.py. The issue is that proj_a_mm peels 0:A_K_TILE and then
iterates over O_GROUP_IN // A_K_TILE, so O_GROUP_IN must be divisible by
A_K_TILE to avoid an uncovered K-tail; place this next to the other assert
guards and keep the same correctness-check style used throughout the DeepSeek v4
module.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: c985c0a0-a5cc-405b-b6ef-63052c4ea130
📒 Files selected for processing (1)
models/deepseek/v4/prefill_sparse_attn.py
17a1435 to 1042228
… -7.1%)

Mirror two numerically-safe decode sparse-attn idioms into prefill:

- qk_pv: fold the per-block softmax bias into pl.col_expand_add instead of col_expand into a dead pl.full(0) base + a separate add. Bit-exact; removes two vector ops and a [QK_M_TILE, PREFILL_ATTN_TILE] FP32 dead UB tile in the attn_qk_softmax band.
- proj_a: A_K_TILE 128 -> 256. With b_trans, wo_a is K-contiguous, so K*2B = 512B fills a full a2a3 L2 cache line (128 wastes half). O_GROUP_IN = 4096, so the cube K-frag loop just halves (32 -> 16 frags); output is identical.

Real-NPU a2a3 swimlane sweep: proj_a_aic 29.06 -> 26.99us (-7.1%) at 256, 512 returns to 29.05 (no gain). All values golden PASS (ratio_allclose).
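The two changes in this commit can be checked with a small plain-Python model. This is a sketch under assumptions: `full`, `col_expand`, `add`, and `col_expand_add` below are hypothetical list-based stand-ins whose broadcast semantics are inferred from the commit message, not the real `pl` API, and the 2-byte element size is the BF16 width implied by `K*2B`.

```python
from typing import List, Tuple

Tile = List[List[float]]


def full(rows: int, cols: int, value: float) -> Tile:
    return [[value] * cols for _ in range(rows)]


def col_expand(base: Tile, col: List[float]) -> Tile:
    # Assumed semantics: broadcast col[i] across row i; `base` only supplies
    # the shape, which is why the commit calls it a dead tile.
    return [[col[i] for _ in row] for i, row in enumerate(base)]


def add(a: Tile, b: Tile) -> Tile:
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def col_expand_add(tile: Tile, col: List[float]) -> Tile:
    # Fused form: one pass, no intermediate tile.
    return [[x + col[i] for x in row] for i, row in enumerate(tile)]


scores = [[0.5, -1.25, 3.0], [2.0, 0.0, -0.75]]
bias = [-4.0, 1.5]
two_step = add(scores, col_expand(full(2, 3, 0.0), bias))
fused = col_expand_add(scores, bias)
assert two_step == fused  # same additions in the same order, so bit-exact

BF16_BYTES = 2
O_GROUP_IN = 4096


def tile_stats(a_k_tile: int) -> Tuple[int, int]:
    """Return (bytes per K-contiguous row segment, number of K fragments)."""
    return a_k_tile * BF16_BYTES, O_GROUP_IN // a_k_tile


for a_k_tile in (128, 256, 512):
    print(a_k_tile, tile_stats(a_k_tile))
```

At `A_K_TILE = 256` the row segment is 512 bytes, matching the cache-line size the commit message cites, and the K loop drops from 32 to 16 fragments. The model says nothing about why 512 gives no further gain; that comes from the measured sweep.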
1042228 to c13a93d
…pan)

Mirror decode's per-group output-projection pipeline into prefill, replacing the 3-spmd proj_a -> quant -> proj_b form -- whose per-row GLOBAL amax barrier between proj_a and proj_b serialized cube and vector -- with a manual_scope 4-segment pipeline:

- proj_a_mm (pure cube): BF16 grouped GEMM -> o_r.
- quant (pure vector): PER-GROUP amax (one per O_LORA group, not the full O_GROUPS*O_LORA-row reduction) + INT8 quant -> act_scale_dq[G, T]. The per-group reduction removes the proj_a<->proj_b barrier.
- proj_b_mm (pure cube): INT8 GEMM -> INT32 partials[:, g*D+n], deps on quant[g] only, so group g's proj_b overlaps proj_a/quant of later groups.
- proj_b_act (pure vector): sum the O_GROUPS partials by their per-group act scales, apply the per-channel weight scale -> BF16.

Fine-grained deps via pl.manual_scope + pl.array(TASK_ID); merge_norm/rope become with-form spmd so proj_a can depend on the o_packed producers. Golden switched to per-group amax to match the kernel.

Real-NPU a2a3 swimlane: golden PASS; makespan 787us vs 903-968us baseline. The decoupling moves the dominant proj_b vector band off the critical path (proj_b_aiv ~112us fused dequant -> proj_b_act ~29us) and lets proj_b cube overlap later groups -- the T=128 e2e win the per-row barrier forbade.
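The effect of replacing the global barrier with per-group dependencies can be illustrated with a toy two-engine list scheduler. Everything here is an assumption made for illustration: the task costs are arbitrary units rather than the measured swimlane times, the greedy policy is not the real runtime's scheduler, and `proj_b_act` is left out to keep the model small.

```python
from typing import Dict, List, Tuple

Task = Tuple[str, str, int, List[str]]  # (name, engine, cost, dependency names)


def build_tasks(n_groups: int, per_group_deps: bool,
                cost_a: int = 3, cost_q: int = 2, cost_b: int = 3) -> List[Task]:
    all_a = [f"proj_a[{g}]" for g in range(n_groups)]
    all_q = [f"quant[{g}]" for g in range(n_groups)]
    tasks: List[Task] = []
    for g in range(n_groups):
        # Barrier form: quant waits for every proj_a, proj_b for every quant.
        q_deps = [all_a[g]] if per_group_deps else list(all_a)
        b_deps = [all_q[g]] if per_group_deps else list(all_q)
        tasks.append((all_a[g], "cube", cost_a, []))
        tasks.append((all_q[g], "vector", cost_q, q_deps))
        tasks.append((f"proj_b[{g}]", "cube", cost_b, b_deps))
    return tasks


def makespan(tasks: List[Task]) -> int:
    """Greedy list schedule: always run the ready task that can start earliest."""
    finish: Dict[str, int] = {}
    engine_free = {"cube": 0, "vector": 0}
    pending = list(tasks)
    while pending:
        best = None
        for task in pending:
            _, engine, _, deps = task
            if all(d in finish for d in deps):
                start = max([engine_free[engine]] + [finish[d] for d in deps])
                if best is None or start < best[0]:  # ties keep list order
                    best = (start, task)
        start, chosen = best
        name, engine, cost, _ = chosen
        finish[name] = start + cost
        engine_free[engine] = finish[name]
        pending.remove(chosen)
    return max(finish.values())


barrier = makespan(build_tasks(4, per_group_deps=False))
decoupled = makespan(build_tasks(4, per_group_deps=True))
print(barrier, decoupled)
```

In this toy setup the barrier form runs the three phases back to back, while the per-group form lets the cube engine start `proj_b` for early groups while later groups are still in `proj_a` or `quant`, so the cube engine stays busy for the whole schedule.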
c13a93d to 6dd3c70
Summary
Mirror two decode `sparse_attn` o_proj idioms into prefill `sparse_attn`, bringing prefill's grouped output projection in line with decode. Two commits, smaller-then-bigger:

1. `col_expand_add` + `A_K_TILE` 256 -- numerically safe
- `qk_pv`: fold the per-block softmax bias via `pl.col_expand_add` instead of `col_expand` onto a dead `pl.full(0)` base + a separate `add`. Bit-exact; drops two vector ops and a dead UB tile in the softmax band.
- `proj_a`: `A_K_TILE` 128 -> 256. Under `b_trans` the `wo_a` operand is K-contiguous, so `K*2B = 512B` fills a full a2a3 L2 line (128 wasted half); `O_GROUP_IN = 4096`, so the cube K-frag loop just halves (32 -> 16). Output identical.

2. Per-group o_proj decoupling -- mirrors decode
Replace the 3-spmd `proj_a -> quant -> proj_b` form, whose per-row GLOBAL amax barrier serialized cube and vector, with decode's `pl.manual_scope` 4-segment pipeline: `proj_a_mm` (pure cube) -> `quant` (PER-GROUP amax, no global barrier) -> `proj_b_mm` (pure cube INT32 partials) -> `proj_b_act` (pure vector dequant+sum), with fine-grained `pl.array(TASK_ID)` deps. `merge_norm`/`rope` become with-form spmd so `proj_a` can depend on the `o_packed` producers.

Validation (real NPU, a2a3)
- Golden PASS (`ratio_allclose`) for both commits.
- Makespan 787 us vs 903-968 us baseline: the decoupling moves the dominant proj_b vector band off the critical path (`proj_b_aiv` ~112 us fused dequant -> `proj_b_act` ~29 us) and lets proj_b cube overlap later groups' proj_a/quant.
- `A_K_TILE` swept 128/256/512: `proj_a_aic` 29.06 -> 26.99 us (-7.1%) at 256; 512 returns to 29.05 -- a clean cache-line optimum.

Reproduce:
Golden change -- why, and precision
The per-group commit changes the INT8 activation quant of `o_r` from per-row-global amax to per-group amax, so the golden was updated to model the same scheme. This aligns prefill with decode, which already uses per-group (kernel and golden) -- before this PR prefill was the outlier (per-row-global). Precision is not regressed: per-group is a finer activation quant, so it is closer to the official DeepSeek V4 o_proj path (`act_quant(block_size=128)`, finer still). Measured INT8 quant error vs full precision: per-row-global 0.9-1.2% -> per-group ~0.78% (official-granularity per-128-block ~0.64%).

Risks / notes
- `[T, O_GROUPS*D]` INT32 `partials` intermediate (~16 MB at T=128); L2-swimlane profiling needs a larger `PTO2_RING_HEAP` than the default.
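For the golden change discussed above, the difference between per-row-global and per-group amax can be seen on a toy row. This is a minimal sketch, assuming symmetric INT8 quantization with a scale of `amax / 127` and Python's built-in rounding; the kernel's actual rounding mode and the helper names are not taken from the PR, and the toy data does not reproduce the measured error figures.

```python
from typing import List

INT8_MAX = 127


def quant_dequant(values: List[float]) -> List[float]:
    """Symmetric INT8 round trip with one amax-derived scale for `values`."""
    amax = max(abs(v) for v in values)
    scale = amax / INT8_MAX if amax > 0.0 else 1.0
    quantized = [max(-INT8_MAX, min(INT8_MAX, round(v / scale))) for v in values]
    return [q * scale for q in quantized]


def per_row_global(row: List[float]) -> List[float]:
    # One scale for the whole row: small-magnitude groups lose resolution.
    return quant_dequant(row)


def per_group(row: List[float], n_groups: int) -> List[float]:
    # One scale per contiguous group of the row.
    size = len(row) // n_groups
    out: List[float] = []
    for g in range(n_groups):
        out.extend(quant_dequant(row[g * size:(g + 1) * size]))
    return out


def rel_error(ref: List[float], approx: List[float]) -> float:
    return sum(abs(a - b) for a, b in zip(ref, approx)) / sum(abs(a) for a in ref)


# Two groups of four values whose magnitudes differ by 100x.
row = [8.0, -4.0, 2.0, 1.0, 0.08, -0.04, 0.02, 0.01]
err_global = rel_error(row, per_row_global(row))
err_group = rel_error(row, per_group(row, n_groups=2))
print(f"per-row-global: {err_global:.4%}, per-group: {err_group:.4%}")
```

The first group sees the same scale in both schemes, so the whole difference comes from the small-magnitude group, which the row-wide scale rounds to 0 or ±1. That is the sense in which per-group is the finer of the two schemes.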