Skip to content

perf(dsv4 decode): retile csa/indexer/expert_shared/hc_post/moe/qkv#672

Open
zhangqi-chen wants to merge 2 commits into
hw-native-sys:mainfrom
zhangqi-chen:csa-perf
Open

perf(dsv4 decode): retile csa/indexer/expert_shared/hc_post/moe/qkv#672
zhangqi-chen wants to merge 2 commits into
hw-native-sys:mainfrom
zhangqi-chen:csa-perf

Conversation

@zhangqi-chen

Copy link
Copy Markdown
Collaborator

Summary

  • attention_csa: split the sparse-index build into a 1-token/core SPMD overlay pass plus a single T-wide compressed-slot scope
  • decode_indexer: fan qr_rope over ROPE_SPMD_TILE-row blocks and fold the rope sign into sin_il (multiply by ±1 is exact)
  • expert_shared: decompose into separate gate/up/act/requant/w2 scopes with larger 256-wide matmul tiles and pipelined vector stages
  • hc_post: T_TILE 8 → 4 to double decode SPMD blocks
  • moe.combine: per-token SPMD reduction
  • qkv_proj_rope: load each q head once, reuse the resident tile for the inv_rms reduction and NOPE/RoPE writeback
  • switch --enable-l2-swimlane to an int level (0/1/2) across these scripts

All changes are numerically equivalent to the prior form; the only non-bit-exact reordering is qkv_proj_rope's inv_rms sum-of-squares, bounded at ≤1 BF16 ULP and within the existing golden tolerance (rtol=1/128).

Fan out and retile several dsv4 decode kernels for higher core occupancy,
all numerically equivalent to the prior form:

- attention_csa: split the sparse-index build into a 1-token/core SPMD
  overlay pass plus a single T-wide compressed-slot scope.
- decode_indexer: fan qr_rope over ROPE_SPMD_TILE-row blocks and fold the
  rope sign into sin_il (x by +/-1 is exact).
- expert_shared: decompose into separate gate/up/act/requant/w2 scopes with
  larger 256-wide matmul tiles and pipelined vector stages.
- hc_post: T_TILE 8 -> 4 to double decode SPMD blocks.
- moe.combine: per-token SPMD reduction.
- qkv_proj_rope: load each q head once, reuse the resident tile for the
  inv_rms reduction and NOPE/RoPE writeback.

Also switch --enable-l2-swimlane to an int level (0/1/2) across these scripts.
@coderabbitai

coderabbitai Bot commented Jul 2, 2026

Copy link
Copy Markdown

Review Change Stack

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 0ec11749-0429-43f2-ba71-952985c5c109

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Use the checkbox below for a quick retry:

  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

This PR refactors tiling and SPMD loop structures across several DeepSeek v4 kernels: CSA sparse-index/validity-mask construction, indexer ROPE rotation, shared-expert matmul/quant pipeline, hc_post tiling, MoE combine reduction, and Q-head RMSNorm computation. Several test CLI harnesses also change --enable-l2-swimlane from a boolean flag to an integer option.

Changes

DeepSeek v4 kernel tiling and compute refactors

Layer / File(s) Summary
CSA sparse index and validity mask refactor
models/deepseek/v4/decode_attention_csa.py
CSA_TOPK_TOKEN_TILE is replaced by CSA_WB_TOKEN_TILE; cmp_sparse_indices/valid_block_mask construction moves from tiled SPMD loops to per-token pl.spmd(T) with vectorized masked-copy logic.
Indexer ROPE tiling and rotation math
models/deepseek/v4/decode_indexer.py
New ROPE_SPMD_TILE drives coarser qr_rope blocks and recomputed batch_idx; sign is folded into sin_il_signed; CLI flag becomes an integer with choices.
Shared-expert kernel M-tile restructuring
models/deepseek/v4/expert_shared.py
Tiling constants replaced; expert_shared restructured into a single pl.parallel(N_MTILES) loop covering gate/up matmul, dequant+SwiGLU, INT8 requant, and w2 down-projection dequant into a padded output.
hc_post tiling and CLI default changes
models/deepseek/v4/hc_post.py
T_TILE reduced from 8 to 4; --mode default changes to decode; --enable-l2-swimlane becomes an integer option.
MoE combine loop simplification
models/deepseek/v4/moe.py
combine_reduce switches from chunked T // 4 + inner loop to direct pl.spmd(T) per-token loop.
Q-head RMSNorm and NOPE writeback refactor
models/deepseek/v4/qkv_proj_rope.py
RMSNorm pass-1 loads NOPE/RoPE columns once with a two-term squared-sum instead of chunked accumulation; --enable-l2-swimlane becomes an integer option.

Estimated code review effort: 4 (Complex) | ~60 minutes

Possibly related PRs

Suggested labels: enhancement

Poem

A rabbit hopped through tiles anew,
Fewer loops, more work to do,
Sparse indices dance in spmd(T),
SwiGLU glows in quantized glee,
Flags now count from zero to two—
Hop hop hooray, the kernels flew! 🐇✨

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
Check name Status Explanation
Title check ✅ Passed The title is concise and accurately summarizes the main retile/performance changes across the dsv4 decode kernels.
Description check ✅ Passed The description matches the kernel tiling and CLI changes and is clearly aligned with the changeset.
Docstring Coverage ✅ Passed No functions found in the changed files to evaluate docstring coverage. Skipping docstring coverage check.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors and optimizes several DeepSeek v4 model components, including decode attention CSA, decode indexer, expert shared, HC post, MoE, and QKV projection. Key optimizations include restructuring SPMD loops, vectorizing operations, folding sign multiplication in RoPE, and refactoring the expert shared kernel to use parallel M-tiles with pipelined matmuls. Feedback suggests further optimizing the decode attention CSA by replacing a loop over T with vectorized slice assignment, and simplifying the decode indexer by removing a single-iteration loop while adding an explicit assertion for the tile size invariant.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +269 to +276
for c_t0 in pl.range(T):
pl.write(valid_block_mask, [c_t0, 0], pl.cast(1, pl.INT32))
for c_sb in pl.range(1, SPARSE_BLOCKS):
c_s0 = (c_sb - 1) * ATTN_K_TILE
c_blk_valid = pl.row_max(c_mask[:, c_s0 : c_s0 + ATTN_K_TILE])
for c_dt in pl.range(CSA_TOPK_TOKEN_TILE):
c_t = topk_t0 + c_dt
if c_t < T:
c_valid = pl.cast(pl.read(c_blk_valid, [c_dt, 0]), target_type=pl.INT32)
pl.write(valid_block_mask, [c_t, c_sb], c_valid)
for c_dt in pl.range(T):
c_valid = pl.cast(pl.read(c_blk_valid, [c_dt, 0]), target_type=pl.INT32)
pl.write(valid_block_mask, [c_dt, c_sb], c_valid)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Since valid_block_mask and c_blk_valid are tensors, we can avoid the loop over T (which is a small static constant) by using vectorized slice assignment. This is more idiomatic in PyPTO and avoids loop overhead.

Suggested change
for c_t0 in pl.range(T):
pl.write(valid_block_mask, [c_t0, 0], pl.cast(1, pl.INT32))
for c_sb in pl.range(1, SPARSE_BLOCKS):
c_s0 = (c_sb - 1) * ATTN_K_TILE
c_blk_valid = pl.row_max(c_mask[:, c_s0 : c_s0 + ATTN_K_TILE])
for c_dt in pl.range(CSA_TOPK_TOKEN_TILE):
c_t = topk_t0 + c_dt
if c_t < T:
c_valid = pl.cast(pl.read(c_blk_valid, [c_dt, 0]), target_type=pl.INT32)
pl.write(valid_block_mask, [c_t, c_sb], c_valid)
for c_dt in pl.range(T):
c_valid = pl.cast(pl.read(c_blk_valid, [c_dt, 0]), target_type=pl.INT32)
pl.write(valid_block_mask, [c_dt, c_sb], c_valid)
valid_block_mask[0 : T, 0 : 1] = pl.full([T, 1], dtype=pl.INT32, value=1)
for c_sb in pl.range(1, SPARSE_BLOCKS):
c_s0 = (c_sb - 1) * ATTN_K_TILE
c_blk_valid = pl.row_max(c_mask[:, c_s0 : c_s0 + ATTN_K_TILE])
valid_block_mask[0 : T, c_sb : c_sb + 1] = pl.cast(c_blk_valid, target_type=pl.INT32)

Comment on lines +143 to 148
for ro in pl.range(0, ROPE_SPMD_TILE, ROPE_ROW_TILE):
r0 = o0 + ro
qr_rope_slice = qr_proj_flat[r0 : r0 + ROPE_ROW_TILE, IDX_NOPE_HEAD_DIM : IDX_HEAD_DIM]
qr_swapped = pl.gather(qr_rope_slice, dim=-1, index=rope_swap_idx)
rope_rot = pl.add(pl.mul(qr_rope_slice, cos_il), pl.mul(pl.mul(qr_swapped, rope_sign), sin_il))
rope_rot = pl.add(pl.mul(qr_rope_slice, cos_il), pl.mul(qr_swapped, sin_il_signed))
qr_rope_out[r0 : r0 + ROPE_ROW_TILE, :] = pl.cast(rope_rot, target_type=pl.BF16, mode="rint")

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Since ROPE_SPMD_TILE and ROPE_ROW_TILE are both hardcoded to 32, the loop for ro in pl.range(0, ROPE_SPMD_TILE, ROPE_ROW_TILE) will always execute for exactly one iteration (ro = 0). We can simplify the code and eliminate the loop overhead by removing the loop entirely. However, when assuming single-tile coverage, we must make this invariant explicit with an assertion (e.g., assert ROPE_SPMD_TILE == ROPE_ROW_TILE) to prevent silent correctness issues if configurations change in the future.

Suggested change
for ro in pl.range(0, ROPE_SPMD_TILE, ROPE_ROW_TILE):
r0 = o0 + ro
qr_rope_slice = qr_proj_flat[r0 : r0 + ROPE_ROW_TILE, IDX_NOPE_HEAD_DIM : IDX_HEAD_DIM]
qr_swapped = pl.gather(qr_rope_slice, dim=-1, index=rope_swap_idx)
rope_rot = pl.add(pl.mul(qr_rope_slice, cos_il), pl.mul(pl.mul(qr_swapped, rope_sign), sin_il))
rope_rot = pl.add(pl.mul(qr_rope_slice, cos_il), pl.mul(qr_swapped, sin_il_signed))
qr_rope_out[r0 : r0 + ROPE_ROW_TILE, :] = pl.cast(rope_rot, target_type=pl.BF16, mode="rint")
assert ROPE_SPMD_TILE == ROPE_ROW_TILE, "ROPE_SPMD_TILE must match ROPE_ROW_TILE for single-tile coverage"
r0 = o0
qr_rope_slice = qr_proj_flat[r0 : r0 + ROPE_ROW_TILE, IDX_NOPE_HEAD_DIM : IDX_HEAD_DIM]
qr_swapped = pl.gather(qr_rope_slice, dim=-1, index=rope_swap_idx)
rope_rot = pl.add(pl.mul(qr_rope_slice, cos_il), pl.mul(qr_swapped, sin_il_signed))
qr_rope_out[r0 : r0 + ROPE_ROW_TILE, :] = pl.cast(rope_rot, target_type=pl.BF16, mode="rint")
References
  1. When assuming a single-tile coverage (e.g., T_PAD <= MM_ROW_TILE) in a kernel, make this invariant explicit with a module-level assertion (e.g., assert T_PAD == MM_ROW_TILE) to prevent silent correctness issues if configurations change.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (1)
models/deepseek/v4/decode_attention_csa.py (1)

239-248: 🚀 Performance & Scalability | 🔵 Trivial | ⚡ Quick win

Hoist the overlay-position reads out of the topk_k loop.

topk_overlay_pos (and thus topk_overlay_pos % WIN) depends only on topk_os, not on topk_k, yet it is re-read from position_ids for every one of the WIN window slots. That is WIN * S scalar reads per SPMD token where S distinct values would suffice. Precomputing the S overlay positions (or their % WIN) once per token, before the topk_k loop, removes the redundant pl.reads on this per-token-per-core hot path without changing results.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@models/deepseek/v4/decode_attention_csa.py` around lines 239 - 248, The
nested loop in `decode_attention_csa.py` repeatedly reads `position_ids` inside
the `topk_k` loop even though `topk_overlay_pos` only depends on `topk_os`.
Refactor the logic around `cmp_sparse_indices` so the overlay positions (or
their `% WIN` values) are precomputed once per token before iterating over
`topk_k`, then reuse those cached values in the inner comparison. Keep the
behavior of `topk_out` and `pl.write` unchanged while removing the redundant
`pl.read` calls.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@models/deepseek/v4/expert_shared.py`:
- Around line 52-54: The W2 dequant path in expert_shared.py currently only
handles full 512-channel tiles, so non-4096 hidden sizes like 7168 leave the
tail of y_i32 uninitialized before copying into sh. Update the logic around the
W2 dequant loop (the code using D_OUT_TILE_ACT, W2_INNER, and W2_ACT_INNER) to
iterate by D_OUT_TILE_ACT directly and add explicit remainder handling for any
final partial block before writing sh_pad.

---

Nitpick comments:
In `@models/deepseek/v4/decode_attention_csa.py`:
- Around line 239-248: The nested loop in `decode_attention_csa.py` repeatedly
reads `position_ids` inside the `topk_k` loop even though `topk_overlay_pos`
only depends on `topk_os`. Refactor the logic around `cmp_sparse_indices` so the
overlay positions (or their `% WIN` values) are precomputed once per token
before iterating over `topk_k`, then reuse those cached values in the inner
comparison. Keep the behavior of `topk_out` and `pl.write` unchanged while
removing the redundant `pl.read` calls.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 7734b03b-9e27-4c48-90c4-b6d2377dfbe0

📥 Commits

Reviewing files that changed from the base of the PR and between 1b9a997 and bc8979f.

📒 Files selected for processing (6)
  • models/deepseek/v4/decode_attention_csa.py
  • models/deepseek/v4/decode_indexer.py
  • models/deepseek/v4/expert_shared.py
  • models/deepseek/v4/hc_post.py
  • models/deepseek/v4/moe.py
  • models/deepseek/v4/qkv_proj_rope.py

Comment on lines +52 to +54
D_OUT_TILE_ACT = 512
W2_INNER = 4
W2_ACT_INNER = 8

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎯 Functional Correctness | 🟠 Major | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

file="models/deepseek/v4/expert_shared.py"

echo "== outline =="
ast-grep outline "$file" --view expanded || true

echo
echo "== relevant lines =="
nl -ba "$file" | sed -n '1,240p'

echo
echo "== searches =="
rg -n "D_OUT_TILE_ACT|W2_INNER|W2_ACT_INNER|assert .*4096|D % D_OUT_TILE_ACT|spmd\\(D //" models/deepseek/v4 -n

Repository: hw-native-sys/pypto-lib

Length of output: 1114


🏁 Script executed:

#!/bin/bash
set -euo pipefail

file="models/deepseek/v4/expert_shared.py"

sed -n '150,190p' "$file"
echo
rg -n "assert .*D|D %|hidden_size|moe_intermediate_size|swiglu_limit" models/deepseek/v4 -n

Repository: hw-native-sys/pypto-lib

Length of output: 11370


Cover the W2 dequant tail for non-4096 hidden sizes
models/deepseek/v4/expert_shared.py:169-177 only processes D // (8 * 512) blocks, so hidden_size=7168 leaves the last 3072 channels in y_i32 unwritten before the copy to sh. Iterate by D_OUT_TILE_ACT directly, or otherwise handle the remainder before writing sh_pad.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@models/deepseek/v4/expert_shared.py` around lines 52 - 54, The W2 dequant
path in expert_shared.py currently only handles full 512-channel tiles, so
non-4096 hidden sizes like 7168 leave the tail of y_i32 uninitialized before
copying into sh. Update the logic around the W2 dequant loop (the code using
D_OUT_TILE_ACT, W2_INNER, and W2_ACT_INNER) to iterate by D_OUT_TILE_ACT
directly and add explicit remainder handling for any final partial block before
writing sh_pad.

Source: Learnings

- Set allow_early_resolve=True on the csa compressed-slots scope,
  sparse_attn / sparse_attn_hca build_valid scopes, and the qkv
  q_head_rms_nope_rope scope.
- decode_attention_hca: wire --runtime-dir / --golden-data through to
  run_jit for reproducible golden-data replay.
- decode_layer / prefill_layer: refresh the real-weight x_next
  over-threshold fraction annotations.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant