perf(dsv4 decode): retile csa/indexer/expert_shared/hc_post/moe/qkv#672
perf(dsv4 decode): retile csa/indexer/expert_shared/hc_post/moe/qkv#672zhangqi-chen wants to merge 2 commits into
Conversation
Fan out and retile several dsv4 decode kernels for higher core occupancy, all numerically equivalent to the prior form: - attention_csa: split the sparse-index build into a 1-token/core SPMD overlay pass plus a single T-wide compressed-slot scope. - decode_indexer: fan qr_rope over ROPE_SPMD_TILE-row blocks and fold the rope sign into sin_il (x by +/-1 is exact). - expert_shared: decompose into separate gate/up/act/requant/w2 scopes with larger 256-wide matmul tiles and pipelined vector stages. - hc_post: T_TILE 8 -> 4 to double decode SPMD blocks. - moe.combine: per-token SPMD reduction. - qkv_proj_rope: load each q head once, reuse the resident tile for the inv_rms reduction and NOPE/RoPE writeback. Also switch --enable-l2-swimlane to an int level (0/1/2) across these scripts.
|
Important Review skippedAuto incremental reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the ⚙️ Run configurationConfiguration used: Organization UI Review profile: CHILL Plan: Pro Run ID: You can disable this status message by setting the Use the checkbox below for a quick retry:
📝 WalkthroughWalkthroughThis PR refactors tiling and SPMD loop structures across several DeepSeek v4 kernels: CSA sparse-index/validity-mask construction, indexer ROPE rotation, shared-expert matmul/quant pipeline, hc_post tiling, MoE combine reduction, and Q-head RMSNorm computation. Several test CLI harnesses also change --enable-l2-swimlane from a boolean flag to an integer option. ChangesDeepSeek v4 kernel tiling and compute refactors
Estimated code review effort: 4 (Complex) | ~60 minutes Possibly related PRs
Suggested labels: Poem
🚥 Pre-merge checks | ✅ 5✅ Passed checks (5 passed)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request refactors and optimizes several DeepSeek v4 model components, including decode attention CSA, decode indexer, expert shared, HC post, MoE, and QKV projection. Key optimizations include restructuring SPMD loops, vectorizing operations, folding sign multiplication in RoPE, and refactoring the expert shared kernel to use parallel M-tiles with pipelined matmuls. Feedback suggests further optimizing the decode attention CSA by replacing a loop over T with vectorized slice assignment, and simplifying the decode indexer by removing a single-iteration loop while adding an explicit assertion for the tile size invariant.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| for c_t0 in pl.range(T): | ||
| pl.write(valid_block_mask, [c_t0, 0], pl.cast(1, pl.INT32)) | ||
| for c_sb in pl.range(1, SPARSE_BLOCKS): | ||
| c_s0 = (c_sb - 1) * ATTN_K_TILE | ||
| c_blk_valid = pl.row_max(c_mask[:, c_s0 : c_s0 + ATTN_K_TILE]) | ||
| for c_dt in pl.range(CSA_TOPK_TOKEN_TILE): | ||
| c_t = topk_t0 + c_dt | ||
| if c_t < T: | ||
| c_valid = pl.cast(pl.read(c_blk_valid, [c_dt, 0]), target_type=pl.INT32) | ||
| pl.write(valid_block_mask, [c_t, c_sb], c_valid) | ||
| for c_dt in pl.range(T): | ||
| c_valid = pl.cast(pl.read(c_blk_valid, [c_dt, 0]), target_type=pl.INT32) | ||
| pl.write(valid_block_mask, [c_dt, c_sb], c_valid) |
There was a problem hiding this comment.
Since valid_block_mask and c_blk_valid are tensors, we can avoid the loop over T (which is a small static constant) by using vectorized slice assignment. This is more idiomatic in PyPTO and avoids loop overhead.
| for c_t0 in pl.range(T): | |
| pl.write(valid_block_mask, [c_t0, 0], pl.cast(1, pl.INT32)) | |
| for c_sb in pl.range(1, SPARSE_BLOCKS): | |
| c_s0 = (c_sb - 1) * ATTN_K_TILE | |
| c_blk_valid = pl.row_max(c_mask[:, c_s0 : c_s0 + ATTN_K_TILE]) | |
| for c_dt in pl.range(CSA_TOPK_TOKEN_TILE): | |
| c_t = topk_t0 + c_dt | |
| if c_t < T: | |
| c_valid = pl.cast(pl.read(c_blk_valid, [c_dt, 0]), target_type=pl.INT32) | |
| pl.write(valid_block_mask, [c_t, c_sb], c_valid) | |
| for c_dt in pl.range(T): | |
| c_valid = pl.cast(pl.read(c_blk_valid, [c_dt, 0]), target_type=pl.INT32) | |
| pl.write(valid_block_mask, [c_dt, c_sb], c_valid) | |
| valid_block_mask[0 : T, 0 : 1] = pl.full([T, 1], dtype=pl.INT32, value=1) | |
| for c_sb in pl.range(1, SPARSE_BLOCKS): | |
| c_s0 = (c_sb - 1) * ATTN_K_TILE | |
| c_blk_valid = pl.row_max(c_mask[:, c_s0 : c_s0 + ATTN_K_TILE]) | |
| valid_block_mask[0 : T, c_sb : c_sb + 1] = pl.cast(c_blk_valid, target_type=pl.INT32) |
| for ro in pl.range(0, ROPE_SPMD_TILE, ROPE_ROW_TILE): | ||
| r0 = o0 + ro | ||
| qr_rope_slice = qr_proj_flat[r0 : r0 + ROPE_ROW_TILE, IDX_NOPE_HEAD_DIM : IDX_HEAD_DIM] | ||
| qr_swapped = pl.gather(qr_rope_slice, dim=-1, index=rope_swap_idx) | ||
| rope_rot = pl.add(pl.mul(qr_rope_slice, cos_il), pl.mul(pl.mul(qr_swapped, rope_sign), sin_il)) | ||
| rope_rot = pl.add(pl.mul(qr_rope_slice, cos_il), pl.mul(qr_swapped, sin_il_signed)) | ||
| qr_rope_out[r0 : r0 + ROPE_ROW_TILE, :] = pl.cast(rope_rot, target_type=pl.BF16, mode="rint") |
There was a problem hiding this comment.
Since ROPE_SPMD_TILE and ROPE_ROW_TILE are both hardcoded to 32, the loop for ro in pl.range(0, ROPE_SPMD_TILE, ROPE_ROW_TILE) will always execute for exactly one iteration (ro = 0). We can simplify the code and eliminate the loop overhead by removing the loop entirely. However, when assuming single-tile coverage, we must make this invariant explicit with an assertion (e.g., assert ROPE_SPMD_TILE == ROPE_ROW_TILE) to prevent silent correctness issues if configurations change in the future.
| for ro in pl.range(0, ROPE_SPMD_TILE, ROPE_ROW_TILE): | |
| r0 = o0 + ro | |
| qr_rope_slice = qr_proj_flat[r0 : r0 + ROPE_ROW_TILE, IDX_NOPE_HEAD_DIM : IDX_HEAD_DIM] | |
| qr_swapped = pl.gather(qr_rope_slice, dim=-1, index=rope_swap_idx) | |
| rope_rot = pl.add(pl.mul(qr_rope_slice, cos_il), pl.mul(pl.mul(qr_swapped, rope_sign), sin_il)) | |
| rope_rot = pl.add(pl.mul(qr_rope_slice, cos_il), pl.mul(qr_swapped, sin_il_signed)) | |
| qr_rope_out[r0 : r0 + ROPE_ROW_TILE, :] = pl.cast(rope_rot, target_type=pl.BF16, mode="rint") | |
| assert ROPE_SPMD_TILE == ROPE_ROW_TILE, "ROPE_SPMD_TILE must match ROPE_ROW_TILE for single-tile coverage" | |
| r0 = o0 | |
| qr_rope_slice = qr_proj_flat[r0 : r0 + ROPE_ROW_TILE, IDX_NOPE_HEAD_DIM : IDX_HEAD_DIM] | |
| qr_swapped = pl.gather(qr_rope_slice, dim=-1, index=rope_swap_idx) | |
| rope_rot = pl.add(pl.mul(qr_rope_slice, cos_il), pl.mul(qr_swapped, sin_il_signed)) | |
| qr_rope_out[r0 : r0 + ROPE_ROW_TILE, :] = pl.cast(rope_rot, target_type=pl.BF16, mode="rint") |
References
- When assuming a single-tile coverage (e.g., T_PAD <= MM_ROW_TILE) in a kernel, make this invariant explicit with a module-level assertion (e.g., assert T_PAD == MM_ROW_TILE) to prevent silent correctness issues if configurations change.
There was a problem hiding this comment.
Actionable comments posted: 1
🧹 Nitpick comments (1)
models/deepseek/v4/decode_attention_csa.py (1)
239-248: 🚀 Performance & Scalability | 🔵 Trivial | ⚡ Quick winHoist the overlay-position reads out of the
topk_kloop.
topk_overlay_pos(and thustopk_overlay_pos % WIN) depends only ontopk_os, not ontopk_k, yet it is re-read fromposition_idsfor every one of theWINwindow slots. That isWIN * Sscalar reads per SPMD token whereSdistinct values would suffice. Precomputing theSoverlay positions (or their% WIN) once per token, before thetopk_kloop, removes the redundantpl.reads on this per-token-per-core hot path without changing results.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@models/deepseek/v4/decode_attention_csa.py` around lines 239 - 248, The nested loop in `decode_attention_csa.py` repeatedly reads `position_ids` inside the `topk_k` loop even though `topk_overlay_pos` only depends on `topk_os`. Refactor the logic around `cmp_sparse_indices` so the overlay positions (or their `% WIN` values) are precomputed once per token before iterating over `topk_k`, then reuse those cached values in the inner comparison. Keep the behavior of `topk_out` and `pl.write` unchanged while removing the redundant `pl.read` calls.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@models/deepseek/v4/expert_shared.py`:
- Around line 52-54: The W2 dequant path in expert_shared.py currently only
handles full 512-channel tiles, so non-4096 hidden sizes like 7168 leave the
tail of y_i32 uninitialized before copying into sh. Update the logic around the
W2 dequant loop (the code using D_OUT_TILE_ACT, W2_INNER, and W2_ACT_INNER) to
iterate by D_OUT_TILE_ACT directly and add explicit remainder handling for any
final partial block before writing sh_pad.
---
Nitpick comments:
In `@models/deepseek/v4/decode_attention_csa.py`:
- Around line 239-248: The nested loop in `decode_attention_csa.py` repeatedly
reads `position_ids` inside the `topk_k` loop even though `topk_overlay_pos`
only depends on `topk_os`. Refactor the logic around `cmp_sparse_indices` so the
overlay positions (or their `% WIN` values) are precomputed once per token
before iterating over `topk_k`, then reuse those cached values in the inner
comparison. Keep the behavior of `topk_out` and `pl.write` unchanged while
removing the redundant `pl.read` calls.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: 7734b03b-9e27-4c48-90c4-b6d2377dfbe0
📒 Files selected for processing (6)
models/deepseek/v4/decode_attention_csa.pymodels/deepseek/v4/decode_indexer.pymodels/deepseek/v4/expert_shared.pymodels/deepseek/v4/hc_post.pymodels/deepseek/v4/moe.pymodels/deepseek/v4/qkv_proj_rope.py
| D_OUT_TILE_ACT = 512 | ||
| W2_INNER = 4 | ||
| W2_ACT_INNER = 8 |
There was a problem hiding this comment.
🎯 Functional Correctness | 🟠 Major | ⚡ Quick win
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
set -euo pipefail
file="models/deepseek/v4/expert_shared.py"
echo "== outline =="
ast-grep outline "$file" --view expanded || true
echo
echo "== relevant lines =="
nl -ba "$file" | sed -n '1,240p'
echo
echo "== searches =="
rg -n "D_OUT_TILE_ACT|W2_INNER|W2_ACT_INNER|assert .*4096|D % D_OUT_TILE_ACT|spmd\\(D //" models/deepseek/v4 -nRepository: hw-native-sys/pypto-lib
Length of output: 1114
🏁 Script executed:
#!/bin/bash
set -euo pipefail
file="models/deepseek/v4/expert_shared.py"
sed -n '150,190p' "$file"
echo
rg -n "assert .*D|D %|hidden_size|moe_intermediate_size|swiglu_limit" models/deepseek/v4 -nRepository: hw-native-sys/pypto-lib
Length of output: 11370
Cover the W2 dequant tail for non-4096 hidden sizes
models/deepseek/v4/expert_shared.py:169-177 only processes D // (8 * 512) blocks, so hidden_size=7168 leaves the last 3072 channels in y_i32 unwritten before the copy to sh. Iterate by D_OUT_TILE_ACT directly, or otherwise handle the remainder before writing sh_pad.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@models/deepseek/v4/expert_shared.py` around lines 52 - 54, The W2 dequant
path in expert_shared.py currently only handles full 512-channel tiles, so
non-4096 hidden sizes like 7168 leave the tail of y_i32 uninitialized before
copying into sh. Update the logic around the W2 dequant loop (the code using
D_OUT_TILE_ACT, W2_INNER, and W2_ACT_INNER) to iterate by D_OUT_TILE_ACT
directly and add explicit remainder handling for any final partial block before
writing sh_pad.
Source: Learnings
- Set allow_early_resolve=True on the csa compressed-slots scope, sparse_attn / sparse_attn_hca build_valid scopes, and the qkv q_head_rms_nope_rope scope. - decode_attention_hca: wire --runtime-dir / --golden-data through to run_jit for reproducible golden-data replay. - decode_layer / prefill_layer: refresh the real-weight x_next over-threshold fraction annotations.
Summary
ROPE_SPMD_TILE-row blocks and fold the rope sign intosin_il(multiply by ±1 is exact)T_TILE8 → 4 to double decode SPMD blocks--enable-l2-swimlaneto an int level (0/1/2) across these scriptsAll changes are numerically equivalent to the prior form; the only non-bit-exact reordering is qkv_proj_rope's inv_rms sum-of-squares, bounded at ≤1 BF16 ULP and within the existing golden tolerance (
rtol=1/128).