Revert "feat(deepseek/v4): gather-free split-half RoPE for decode + prefill"#576
Revert "feat(deepseek/v4): gather-free split-half RoPE for decode + prefill"#576zhangqi-chen wants to merge 1 commit into
Conversation
…refill (…" This reverts commit cdb64e0.
|
Caution Review failedPull request was closed or merged during review No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: Organization UI Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (16)
💤 Files with no reviewable changes (1)
📝 WalkthroughWalkthroughReplaces NeoX split-half RoPE rotation with an interleaved (A3 swap-gather) formulation across all DeepSeek-V4 decode and prefill modules. Half-width FP32 ChangesDeepSeek-V4 RoPE Split-Half → Interleaved Refactor
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Possibly related issues
Possibly related PRs
Suggested labels
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request transitions the RoPE (Rotary Position Embedding) implementation from a split-half (NeoX) layout to an interleaved swap-gather (CANN A3 rotate_interleaved) layout across various decode, prefill, compressor, and sparse attention modules. This change eliminates the need for half-width unsigned inverse-RoPE tables by precomputing head-invariant interleaved cos and signed sin tables in-kernel. The review feedback focuses on micro-optimizations within the newly introduced in-kernel index and sign generation logic, specifically recommending the elimination of redundant cast operations (e.g., casting from INT32 to FP32 and back to INT32) and simplifying sign calculations to reduce compiled instruction counts.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| rope_dup_f = pl.cast(pl.cast(pl.mul(rope_col, 0.5), target_type=pl.INT32, mode="trunc"), target_type=pl.FP32) | ||
| rope_dup_idx = pl.cast(rope_dup_f, target_type=pl.INT32) # j>>1 |
There was a problem hiding this comment.
The current implementation performs a redundant cast from INT32 back to INT32 via FP32 for rope_dup_idx. We can optimize this by casting to INT32 first to get rope_dup_idx, and then casting that to FP32 to get rope_dup_f. This avoids one redundant cast instruction in the compiled kernel.
| rope_dup_f = pl.cast(pl.cast(pl.mul(rope_col, 0.5), target_type=pl.INT32, mode="trunc"), target_type=pl.FP32) | |
| rope_dup_idx = pl.cast(rope_dup_f, target_type=pl.INT32) # j>>1 | |
| rope_dup_idx = pl.cast(pl.mul(rope_col, 0.5), target_type=pl.INT32, mode="trunc") | |
| rope_dup_f = pl.cast(rope_dup_idx, target_type=pl.FP32) |
| rope_dup_f = pl.cast(pl.cast(pl.mul(rope_col, 0.5), target_type=pl.INT32, mode="trunc"), target_type=pl.FP32) | ||
| rope_dup_idx = pl.cast(rope_dup_f, target_type=pl.INT32) # j>>1 |
There was a problem hiding this comment.
The current implementation performs a redundant cast from INT32 back to INT32 via FP32 for rope_dup_idx. We can optimize this by casting to INT32 first to get rope_dup_idx, and then casting that to FP32 to get rope_dup_f. This avoids one redundant cast instruction in the compiled kernel.
| rope_dup_f = pl.cast(pl.cast(pl.mul(rope_col, 0.5), target_type=pl.INT32, mode="trunc"), target_type=pl.FP32) | |
| rope_dup_idx = pl.cast(rope_dup_f, target_type=pl.INT32) # j>>1 | |
| rope_dup_idx = pl.cast(pl.mul(rope_col, 0.5), target_type=pl.INT32, mode="trunc") | |
| rope_dup_f = pl.cast(rope_dup_idx, target_type=pl.FP32) |
| cs_dup_f = pl.cast(pl.cast(pl.mul(cs_col, 0.5), target_type=pl.INT32, mode="trunc"), target_type=pl.FP32) | ||
| cs_dup_idx = pl.cast(cs_dup_f, target_type=pl.INT32) # j>>1 |
There was a problem hiding this comment.
The current implementation performs a redundant cast from INT32 back to INT32 via FP32 for cs_dup_idx. We can optimize this by casting to INT32 first to get cs_dup_idx, and then casting that to FP32 to get cs_dup_f. This avoids one redundant cast instruction in the compiled kernel.
| cs_dup_f = pl.cast(pl.cast(pl.mul(cs_col, 0.5), target_type=pl.INT32, mode="trunc"), target_type=pl.FP32) | |
| cs_dup_idx = pl.cast(cs_dup_f, target_type=pl.INT32) # j>>1 | |
| cs_dup_idx = pl.cast(pl.mul(cs_col, 0.5), target_type=pl.INT32, mode="trunc") | |
| cs_dup_f = pl.cast(cs_dup_idx, target_type=pl.FP32) |
| cs_dup_f = pl.cast(pl.cast(pl.mul(cs_col, 0.5), target_type=pl.INT32, mode="trunc"), target_type=pl.FP32) | ||
| cs_dup_idx = pl.cast(cs_dup_f, target_type=pl.INT32) # j>>1 | ||
| cs_lane = pl.sub(cs_col, pl.mul(cs_dup_f, 2.0)) # j%2 | ||
| cs_sign = pl.neg(pl.sub(pl.mul(cs_lane, 2.0), 1.0)) # [+1,-1,...] (conjugate) |
There was a problem hiding this comment.
We can simplify the computation of cs_sign by avoiding the pl.neg operation. pl.sub(1.0, pl.mul(cs_lane, 2.0)) is mathematically equivalent to pl.neg(pl.sub(pl.mul(cs_lane, 2.0), 1.0)) but saves one instruction in the compiled kernel.
| cs_sign = pl.neg(pl.sub(pl.mul(cs_lane, 2.0), 1.0)) # [+1,-1,...] (conjugate) | |
| cs_sign = pl.sub(1.0, pl.mul(cs_lane, 2.0)) # [+1,-1,...] (conjugate) |
| cs_dup_f = pl.cast(pl.cast(pl.mul(cs_col, 0.5), target_type=pl.INT32, mode="trunc"), target_type=pl.FP32) | ||
| cs_dup_idx = pl.cast(cs_dup_f, target_type=pl.INT32) # j>>1 |
There was a problem hiding this comment.
The current implementation performs a redundant cast from INT32 back to INT32 via FP32 for cs_dup_idx. We can optimize this by casting to INT32 first to get cs_dup_idx, and then casting that to FP32 to get cs_dup_f. This avoids one redundant cast instruction in the compiled kernel.
| cs_dup_f = pl.cast(pl.cast(pl.mul(cs_col, 0.5), target_type=pl.INT32, mode="trunc"), target_type=pl.FP32) | |
| cs_dup_idx = pl.cast(cs_dup_f, target_type=pl.INT32) # j>>1 | |
| cs_dup_idx = pl.cast(pl.mul(cs_col, 0.5), target_type=pl.INT32, mode="trunc") | |
| cs_dup_f = pl.cast(cs_dup_idx, target_type=pl.FP32) |
| cs_dup_f = pl.cast(pl.cast(pl.mul(cs_col, 0.5), target_type=pl.INT32, mode="trunc"), target_type=pl.FP32) | ||
| cs_dup_idx = pl.cast(cs_dup_f, target_type=pl.INT32) # j>>1 | ||
| cs_lane = pl.sub(cs_col, pl.mul(cs_dup_f, 2.0)) # j%2 | ||
| cs_sign = pl.neg(pl.sub(pl.mul(cs_lane, 2.0), 1.0)) # [+1,-1,...] (conjugate) |
There was a problem hiding this comment.
We can simplify the computation of cs_sign by avoiding the pl.neg operation. pl.sub(1.0, pl.mul(cs_lane, 2.0)) is mathematically equivalent to pl.neg(pl.sub(pl.mul(cs_lane, 2.0), 1.0)) but saves one instruction in the compiled kernel.
| cs_sign = pl.neg(pl.sub(pl.mul(cs_lane, 2.0), 1.0)) # [+1,-1,...] (conjugate) | |
| cs_sign = pl.sub(1.0, pl.mul(cs_lane, 2.0)) # [+1,-1,...] (conjugate) |
| cs_dup_f = pl.cast(pl.cast(pl.mul(cs_col, 0.5), target_type=pl.INT32, mode="trunc"), target_type=pl.FP32) | ||
| cs_dup_idx = pl.cast(cs_dup_f, target_type=pl.INT32) # j>>1 |
There was a problem hiding this comment.
The current implementation performs a redundant cast from INT32 back to INT32 via FP32 for cs_dup_idx. We can optimize this by casting to INT32 first to get cs_dup_idx, and then casting that to FP32 to get cs_dup_f. This avoids one redundant cast instruction in the compiled kernel.
| cs_dup_f = pl.cast(pl.cast(pl.mul(cs_col, 0.5), target_type=pl.INT32, mode="trunc"), target_type=pl.FP32) | |
| cs_dup_idx = pl.cast(cs_dup_f, target_type=pl.INT32) # j>>1 | |
| cs_dup_idx = pl.cast(pl.mul(cs_col, 0.5), target_type=pl.INT32, mode="trunc") | |
| cs_dup_f = pl.cast(cs_dup_idx, target_type=pl.FP32) |
| cs_dup_f = pl.cast(pl.cast(pl.mul(cs_col, 0.5), target_type=pl.INT32, mode="trunc"), target_type=pl.FP32) | ||
| cs_dup_idx = pl.cast(cs_dup_f, target_type=pl.INT32) # j>>1 | ||
| cs_lane = pl.sub(cs_col, pl.mul(cs_dup_f, 2.0)) # j%2 | ||
| cs_sign = pl.neg(pl.sub(pl.mul(cs_lane, 2.0), 1.0)) # [+1,-1,...] |
There was a problem hiding this comment.
We can simplify the computation of cs_sign by avoiding the pl.neg operation. pl.sub(1.0, pl.mul(cs_lane, 2.0)) is mathematically equivalent to pl.neg(pl.sub(pl.mul(cs_lane, 2.0), 1.0)) but saves one instruction in the compiled kernel.
| cs_sign = pl.neg(pl.sub(pl.mul(cs_lane, 2.0), 1.0)) # [+1,-1,...] | |
| cs_sign = pl.sub(1.0, pl.mul(cs_lane, 2.0)) # [+1,-1,...] |
| q_dup_f = pl.cast(pl.cast(pl.mul(q_col, 0.5), target_type=pl.INT32, mode="trunc"), target_type=pl.FP32) | ||
| q_dup_idx = pl.cast(q_dup_f, target_type=pl.INT32) # j>>1 |
There was a problem hiding this comment.
The current implementation performs a redundant cast from INT32 back to INT32 via FP32 for q_dup_idx. We can optimize this by casting to INT32 first to get q_dup_idx, and then casting that to FP32 to get q_dup_f. This avoids one redundant cast instruction in the compiled kernel.
| q_dup_f = pl.cast(pl.cast(pl.mul(q_col, 0.5), target_type=pl.INT32, mode="trunc"), target_type=pl.FP32) | |
| q_dup_idx = pl.cast(q_dup_f, target_type=pl.INT32) # j>>1 | |
| q_dup_idx = pl.cast(pl.mul(q_col, 0.5), target_type=pl.INT32, mode="trunc") | |
| q_dup_f = pl.cast(q_dup_idx, target_type=pl.FP32) |
| kv_dup_f = pl.cast(pl.cast(pl.mul(kv_col, 0.5), target_type=pl.INT32, mode="trunc"), target_type=pl.FP32) | ||
| kv_dup_idx = pl.cast(kv_dup_f, target_type=pl.INT32) # j>>1 |
There was a problem hiding this comment.
The current implementation performs a redundant cast from INT32 back to INT32 via FP32 for kv_dup_idx. We can optimize this by casting to INT32 first to get kv_dup_idx, and then casting that to FP32 to get kv_dup_f. This avoids one redundant cast instruction in the compiled kernel.
| kv_dup_f = pl.cast(pl.cast(pl.mul(kv_col, 0.5), target_type=pl.INT32, mode="trunc"), target_type=pl.FP32) | |
| kv_dup_idx = pl.cast(kv_dup_f, target_type=pl.INT32) # j>>1 | |
| kv_dup_idx = pl.cast(pl.mul(kv_col, 0.5), target_type=pl.INT32, mode="trunc") | |
| kv_dup_f = pl.cast(kv_dup_idx, target_type=pl.FP32) |
Reverts #570