Skip to content

Revert "feat(deepseek/v4): gather-free split-half RoPE for decode + prefill"#576

Closed
zhangqi-chen wants to merge 1 commit into
mainfrom
revert-570-research/pr564-splithalf-rope
Closed

Revert "feat(deepseek/v4): gather-free split-half RoPE for decode + prefill"#576
zhangqi-chen wants to merge 1 commit into
mainfrom
revert-570-research/pr564-splithalf-rope

Conversation

@zhangqi-chen

Copy link
Copy Markdown
Collaborator

Reverts #570

@coderabbitai

coderabbitai Bot commented Jun 22, 2026

Copy link
Copy Markdown

Review Change Stack

Caution

Review failed

Pull request was closed or merged during review

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 3c5b20c5-662c-426e-bd6a-bc567bf3dd6f

📥 Commits

Reviewing files that changed from the base of the PR and between cdb64e0 and b4d6d4a.

📒 Files selected for processing (16)
  • models/deepseek/v4/decode_attention_csa.py
  • models/deepseek/v4/decode_attention_hca.py
  • models/deepseek/v4/decode_attention_swa.py
  • models/deepseek/v4/decode_compressor_ratio128.py
  • models/deepseek/v4/decode_compressor_ratio4.py
  • models/deepseek/v4/decode_layer.py
  • models/deepseek/v4/decode_sparse_attn.py
  • models/deepseek/v4/decode_sparse_attn_hca.py
  • models/deepseek/v4/decode_sparse_attn_swa.py
  • models/deepseek/v4/prefill_attention_csa.py
  • models/deepseek/v4/prefill_attention_hca.py
  • models/deepseek/v4/prefill_attention_swa.py
  • models/deepseek/v4/prefill_compressor_ratio128.py
  • models/deepseek/v4/prefill_compressor_ratio4.py
  • models/deepseek/v4/prefill_sparse_attn.py
  • models/deepseek/v4/qkv_proj_rope.py
💤 Files with no reviewable changes (1)
  • models/deepseek/v4/decode_layer.py

📝 Walkthrough

Walkthrough

Replaces NeoX split-half RoPE rotation with an interleaved (A3 swap-gather) formulation across all DeepSeek-V4 decode and prefill modules. Half-width FP32 rope_cos_half/rope_sin_half tables are removed throughout; kernels now accept full-dimension BF16 freqs_cos/freqs_sin. Compressors, sparse-attention kernels, QKV projection, and all attention-layer callers are updated, including golden reference implementations.

Changes

DeepSeek-V4 RoPE Split-Half → Interleaved Refactor

Layer / File(s) Summary
Q/KV projection: interleaved swap-gather rotation
models/deepseek/v4/qkv_proj_rope.py
Replaces the split-half NeoX RoPE rotation for Q and KV with an interleaved A3 swap-gather formulation; in-kernel per-task index/sign tensors replace separate lo/hi writeback regions. Golden apply_rope updated to even/odd pair unflatten + torch.stack(...).flatten(-2).
Compressor kernels: even/odd gather/scatter rotation
models/deepseek/v4/prefill_compressor_ratio4.py, models/deepseek/v4/prefill_compressor_ratio128.py, models/deepseek/v4/decode_compressor_ratio4.py, models/deepseek/v4/decode_compressor_ratio128.py
Replaces split-half gamma_lo/gamma_hi + torch.cat rotation in all four compressor kernels with P0101/P1010 gather/scatter or pairwise unflatten approach; golden references updated to match.
Sparse attention kernels: freqs_cos/sin interface + interleaved inverse-RoPE
models/deepseek/v4/decode_sparse_attn.py, models/deepseek/v4/decode_sparse_attn_hca.py, models/deepseek/v4/decode_sparse_attn_swa.py, models/deepseek/v4/prefill_sparse_attn.py
Changes public signatures of sparse_attn, sparse_attn_hca, sparse_attn_swa, and prefill_sparse_attn to accept freqs_cos/freqs_sin (BF16, ROPE_DIM) instead of rope_cos_half/rope_sin_half (FP32, HALF_ROPE). In-kernel inverse-RoPE is rewritten to precompute head-invariant interleaved cosine/signed-sine and apply via per-task j^1 swap index. Test wrappers, golden references, tensor-spec builders, and new ROPE_TILE/ROPE_INTERLEAVE_TILE constants updated throughout. Adds get_standalone_cmp_valid helper.
Decode attention callers: remove half-table construction
models/deepseek/v4/decode_attention_csa.py, models/deepseek/v4/decode_attention_hca.py, models/deepseek/v4/decode_attention_swa.py, models/deepseek/v4/decode_layer.py
Removes rope_cos_half_t/rope_sin_half_t tensor allocation and per-token filling from all three decode attention paths; sparse_attn_* call sites updated to pass full rope_cos_t/rope_sin_t. Removes stale NeoX explanatory comment from decode_layer.py. Golden references updated.
Prefill attention callers: remove half-table construction
models/deepseek/v4/prefill_attention_csa.py, models/deepseek/v4/prefill_attention_hca.py, models/deepseek/v4/prefill_attention_swa.py
Removes rope_cos_half_t/rope_sin_half_t creation and loop population from all three prefill attention paths; prefill_sparse_attn call sites updated to pass rope_cos_t/rope_sin_t as freqs_cos/freqs_sin. Golden references updated.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related issues

Possibly related PRs

  • hw-native-sys/pypto-lib#480: Both PRs refactor the decode sparse-attention inverse-RoPE into an A3-style swap-gather interleaved form across overlapping files; this PR additionally changes the cosine/sine input contract from half-width tables to full freqs_cos/freqs_sin.
  • hw-native-sys/pypto-lib#525: Both PRs modify decode_sparse_attn.py's inverse-RoPE implementation — this PR changes the kernel interface to freqs_cos/freqs_sin while #525 rewrites the fused rope_cos_il/rope_sin_signed tiling that consumes those tables.
  • hw-native-sys/pypto-lib#532: Both PRs touch qkv_proj_rope.py and the downstream attention modules, with overlapping changes at the RoPE cos/sin input contract and the rotation interface consumed by callers.

Suggested labels

bug

🐇 No more halves to track,
The split-path won't come back!
Even, odd, interleaved in one,
BF16 freqs under the sun. 🌟
Swap and gather, gather and swap —
The rabbit sealed the RoPE gap! 🎉

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 56.67% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed The title accurately describes the main change: reverting a feature that implemented gather-free split-half RoPE for DeepSeek V4.
Description check ✅ Passed The description is related to the changeset—it identifies the PR being reverted (#570) and its general purpose (gather-free split-half RoPE for decode + prefill).
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request transitions the RoPE (Rotary Position Embedding) implementation from a split-half (NeoX) layout to an interleaved swap-gather (CANN A3 rotate_interleaved) layout across various decode, prefill, compressor, and sparse attention modules. This change eliminates the need for half-width unsigned inverse-RoPE tables by precomputing head-invariant interleaved cos and signed sin tables in-kernel. The review feedback focuses on micro-optimizations within the newly introduced in-kernel index and sign generation logic, specifically recommending the elimination of redundant cast operations (e.g., casting from INT32 to FP32 and back to INT32) and simplifying sign calculations to reduce compiled instruction counts.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +207 to +208
rope_dup_f = pl.cast(pl.cast(pl.mul(rope_col, 0.5), target_type=pl.INT32, mode="trunc"), target_type=pl.FP32)
rope_dup_idx = pl.cast(rope_dup_f, target_type=pl.INT32) # j>>1

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation performs a redundant cast from INT32 back to INT32 via FP32 for rope_dup_idx. We can optimize this by casting to INT32 first to get rope_dup_idx, and then casting that to FP32 to get rope_dup_f. This avoids one redundant cast instruction in the compiled kernel.

Suggested change
rope_dup_f = pl.cast(pl.cast(pl.mul(rope_col, 0.5), target_type=pl.INT32, mode="trunc"), target_type=pl.FP32)
rope_dup_idx = pl.cast(rope_dup_f, target_type=pl.INT32) # j>>1
rope_dup_idx = pl.cast(pl.mul(rope_col, 0.5), target_type=pl.INT32, mode="trunc")
rope_dup_f = pl.cast(rope_dup_idx, target_type=pl.FP32)

Comment on lines +208 to +209
rope_dup_f = pl.cast(pl.cast(pl.mul(rope_col, 0.5), target_type=pl.INT32, mode="trunc"), target_type=pl.FP32)
rope_dup_idx = pl.cast(rope_dup_f, target_type=pl.INT32) # j>>1

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation performs a redundant cast from INT32 back to INT32 via FP32 for rope_dup_idx. We can optimize this by casting to INT32 first to get rope_dup_idx, and then casting that to FP32 to get rope_dup_f. This avoids one redundant cast instruction in the compiled kernel.

Suggested change
rope_dup_f = pl.cast(pl.cast(pl.mul(rope_col, 0.5), target_type=pl.INT32, mode="trunc"), target_type=pl.FP32)
rope_dup_idx = pl.cast(rope_dup_f, target_type=pl.INT32) # j>>1
rope_dup_idx = pl.cast(pl.mul(rope_col, 0.5), target_type=pl.INT32, mode="trunc")
rope_dup_f = pl.cast(rope_dup_idx, target_type=pl.FP32)

Comment on lines +342 to +343
cs_dup_f = pl.cast(pl.cast(pl.mul(cs_col, 0.5), target_type=pl.INT32, mode="trunc"), target_type=pl.FP32)
cs_dup_idx = pl.cast(cs_dup_f, target_type=pl.INT32) # j>>1

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation performs a redundant cast from INT32 back to INT32 via FP32 for cs_dup_idx. We can optimize this by casting to INT32 first to get cs_dup_idx, and then casting that to FP32 to get cs_dup_f. This avoids one redundant cast instruction in the compiled kernel.

Suggested change
cs_dup_f = pl.cast(pl.cast(pl.mul(cs_col, 0.5), target_type=pl.INT32, mode="trunc"), target_type=pl.FP32)
cs_dup_idx = pl.cast(cs_dup_f, target_type=pl.INT32) # j>>1
cs_dup_idx = pl.cast(pl.mul(cs_col, 0.5), target_type=pl.INT32, mode="trunc")
cs_dup_f = pl.cast(cs_dup_idx, target_type=pl.FP32)

cs_dup_f = pl.cast(pl.cast(pl.mul(cs_col, 0.5), target_type=pl.INT32, mode="trunc"), target_type=pl.FP32)
cs_dup_idx = pl.cast(cs_dup_f, target_type=pl.INT32) # j>>1
cs_lane = pl.sub(cs_col, pl.mul(cs_dup_f, 2.0)) # j%2
cs_sign = pl.neg(pl.sub(pl.mul(cs_lane, 2.0), 1.0)) # [+1,-1,...] (conjugate)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

We can simplify the computation of cs_sign by avoiding the pl.neg operation. pl.sub(1.0, pl.mul(cs_lane, 2.0)) is mathematically equivalent to pl.neg(pl.sub(pl.mul(cs_lane, 2.0), 1.0)) but saves one instruction in the compiled kernel.

Suggested change
cs_sign = pl.neg(pl.sub(pl.mul(cs_lane, 2.0), 1.0)) # [+1,-1,...] (conjugate)
cs_sign = pl.sub(1.0, pl.mul(cs_lane, 2.0)) # [+1,-1,...] (conjugate)

Comment on lines +310 to +311
cs_dup_f = pl.cast(pl.cast(pl.mul(cs_col, 0.5), target_type=pl.INT32, mode="trunc"), target_type=pl.FP32)
cs_dup_idx = pl.cast(cs_dup_f, target_type=pl.INT32) # j>>1

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation performs a redundant cast from INT32 back to INT32 via FP32 for cs_dup_idx. We can optimize this by casting to INT32 first to get cs_dup_idx, and then casting that to FP32 to get cs_dup_f. This avoids one redundant cast instruction in the compiled kernel.

Suggested change
cs_dup_f = pl.cast(pl.cast(pl.mul(cs_col, 0.5), target_type=pl.INT32, mode="trunc"), target_type=pl.FP32)
cs_dup_idx = pl.cast(cs_dup_f, target_type=pl.INT32) # j>>1
cs_dup_idx = pl.cast(pl.mul(cs_col, 0.5), target_type=pl.INT32, mode="trunc")
cs_dup_f = pl.cast(cs_dup_idx, target_type=pl.FP32)

cs_dup_f = pl.cast(pl.cast(pl.mul(cs_col, 0.5), target_type=pl.INT32, mode="trunc"), target_type=pl.FP32)
cs_dup_idx = pl.cast(cs_dup_f, target_type=pl.INT32) # j>>1
cs_lane = pl.sub(cs_col, pl.mul(cs_dup_f, 2.0)) # j%2
cs_sign = pl.neg(pl.sub(pl.mul(cs_lane, 2.0), 1.0)) # [+1,-1,...] (conjugate)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

We can simplify the computation of cs_sign by avoiding the pl.neg operation. pl.sub(1.0, pl.mul(cs_lane, 2.0)) is mathematically equivalent to pl.neg(pl.sub(pl.mul(cs_lane, 2.0), 1.0)) but saves one instruction in the compiled kernel.

Suggested change
cs_sign = pl.neg(pl.sub(pl.mul(cs_lane, 2.0), 1.0)) # [+1,-1,...] (conjugate)
cs_sign = pl.sub(1.0, pl.mul(cs_lane, 2.0)) # [+1,-1,...] (conjugate)

Comment on lines +262 to +263
cs_dup_f = pl.cast(pl.cast(pl.mul(cs_col, 0.5), target_type=pl.INT32, mode="trunc"), target_type=pl.FP32)
cs_dup_idx = pl.cast(cs_dup_f, target_type=pl.INT32) # j>>1

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation performs a redundant cast from INT32 back to INT32 via FP32 for cs_dup_idx. We can optimize this by casting to INT32 first to get cs_dup_idx, and then casting that to FP32 to get cs_dup_f. This avoids one redundant cast instruction in the compiled kernel.

Suggested change
cs_dup_f = pl.cast(pl.cast(pl.mul(cs_col, 0.5), target_type=pl.INT32, mode="trunc"), target_type=pl.FP32)
cs_dup_idx = pl.cast(cs_dup_f, target_type=pl.INT32) # j>>1
cs_dup_idx = pl.cast(pl.mul(cs_col, 0.5), target_type=pl.INT32, mode="trunc")
cs_dup_f = pl.cast(cs_dup_idx, target_type=pl.FP32)

cs_dup_f = pl.cast(pl.cast(pl.mul(cs_col, 0.5), target_type=pl.INT32, mode="trunc"), target_type=pl.FP32)
cs_dup_idx = pl.cast(cs_dup_f, target_type=pl.INT32) # j>>1
cs_lane = pl.sub(cs_col, pl.mul(cs_dup_f, 2.0)) # j%2
cs_sign = pl.neg(pl.sub(pl.mul(cs_lane, 2.0), 1.0)) # [+1,-1,...]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

We can simplify the computation of cs_sign by avoiding the pl.neg operation. pl.sub(1.0, pl.mul(cs_lane, 2.0)) is mathematically equivalent to pl.neg(pl.sub(pl.mul(cs_lane, 2.0), 1.0)) but saves one instruction in the compiled kernel.

Suggested change
cs_sign = pl.neg(pl.sub(pl.mul(cs_lane, 2.0), 1.0)) # [+1,-1,...]
cs_sign = pl.sub(1.0, pl.mul(cs_lane, 2.0)) # [+1,-1,...]

Comment on lines +241 to +242
q_dup_f = pl.cast(pl.cast(pl.mul(q_col, 0.5), target_type=pl.INT32, mode="trunc"), target_type=pl.FP32)
q_dup_idx = pl.cast(q_dup_f, target_type=pl.INT32) # j>>1

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation performs a redundant cast from INT32 back to INT32 via FP32 for q_dup_idx. We can optimize this by casting to INT32 first to get q_dup_idx, and then casting that to FP32 to get q_dup_f. This avoids one redundant cast instruction in the compiled kernel.

Suggested change
q_dup_f = pl.cast(pl.cast(pl.mul(q_col, 0.5), target_type=pl.INT32, mode="trunc"), target_type=pl.FP32)
q_dup_idx = pl.cast(q_dup_f, target_type=pl.INT32) # j>>1
q_dup_idx = pl.cast(pl.mul(q_col, 0.5), target_type=pl.INT32, mode="trunc")
q_dup_f = pl.cast(q_dup_idx, target_type=pl.FP32)

Comment on lines +325 to +326
kv_dup_f = pl.cast(pl.cast(pl.mul(kv_col, 0.5), target_type=pl.INT32, mode="trunc"), target_type=pl.FP32)
kv_dup_idx = pl.cast(kv_dup_f, target_type=pl.INT32) # j>>1

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation performs a redundant cast from INT32 back to INT32 via FP32 for kv_dup_idx. We can optimize this by casting to INT32 first to get kv_dup_idx, and then casting that to FP32 to get kv_dup_f. This avoids one redundant cast instruction in the compiled kernel.

Suggested change
kv_dup_f = pl.cast(pl.cast(pl.mul(kv_col, 0.5), target_type=pl.INT32, mode="trunc"), target_type=pl.FP32)
kv_dup_idx = pl.cast(kv_dup_f, target_type=pl.INT32) # j>>1
kv_dup_idx = pl.cast(pl.mul(kv_col, 0.5), target_type=pl.INT32, mode="trunc")
kv_dup_f = pl.cast(kv_dup_idx, target_type=pl.FP32)

@zhangqi-chen zhangqi-chen deleted the revert-570-research/pr564-splithalf-rope branch June 22, 2026 12:20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant