feat(deepseek/v4): gather-free split-half RoPE for the decode path (draft, decode-only) #564
…ecode-only)

Convert decode RoPE from interleaved (GPT-J, gather-based) to split-half (GPT-NeoX, gather-free): forward in qkv_proj_rope + decode compressors, inverse in sparse_attn hca/swa/csa. Removes the j^1 swap-gather, j>>1 dup-gather, and the rope_cs pre-pass; the rotation partner is now a contiguous lo/hi half-slice.

Bit-exact on a2a3 (HCA/SWA standalone+composed+layer, CSA standalone+composed). L2 swimlane: rope compute -70.5%, HCA attention-module wall-clock -9.6%; all VGATHERs eliminated.

Decode-only: qkv_proj_rope is shared with prefill, which is left interleaved (latently half-converted) -- prefill conversion is a tracked follow-up.

A deployment-time offline weight permutation (q-proj/k_pe/wo_a rope columns) is required for real checkpoints; synthetic tests need none.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
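The layout change described in this commit message can be illustrated with a minimal pure-Python sketch. It is not the pypto kernel code: the function names, test vector, and angles are made up for illustration. It shows that the interleaved rotation (partner `j ^ 1`, table index `j >> 1`) and the split-half rotation (partner `k + HALF`) produce the same values once the lanes are reordered as even lanes followed by odd lanes.

```python
import math


def rope_interleaved(x, cos_half, sin_half):
    """GPT-J layout: lanes (2i, 2i+1) are a pair, so lane j's partner is j ^ 1."""
    out = []
    for j, xj in enumerate(x):
        c, s = cos_half[j >> 1], sin_half[j >> 1]  # j >> 1: duplicated table lookup
        partner = x[j ^ 1]                          # j ^ 1: swap within the pair
        # Even lanes take the "minus" branch, odd lanes the "plus" branch.
        out.append(xj * c - partner * s if j % 2 == 0 else xj * c + partner * s)
    return out


def rope_split_half(x, cos_half, sin_half):
    """GPT-NeoX layout: lane k's partner is k + HALF, i.e. two contiguous slices."""
    half = len(x) // 2
    lo, hi = x[:half], x[half:]
    out_lo = [a * c - b * s for a, b, c, s in zip(lo, hi, cos_half, sin_half)]
    out_hi = [a * s + b * c for a, b, c, s in zip(lo, hi, cos_half, sin_half)]
    return out_lo + out_hi


def to_split_half(v):
    """Reorder interleaved lanes as [even lanes..., odd lanes...]."""
    return v[0::2] + v[1::2]


angles = [0.3 * (i + 1) for i in range(4)]  # arbitrary test angles (assumption)
cos_half = [math.cos(a) for a in angles]
sin_half = [math.sin(a) for a in angles]
x = [0.5, -1.0, 2.0, 0.25, -0.75, 1.5, 3.0, -2.0]

# Rotating in the interleaved layout and then reordering should match
# reordering first and then rotating in the split-half layout.
via_interleaved = to_split_half(rope_interleaved(x, cos_half, sin_half))
via_split_half = rope_split_half(to_split_half(x), cos_half, sin_half)
max_err = max(abs(a - b) for a, b in zip(via_interleaved, via_split_half))
print(max_err)
```

The split-half version touches only contiguous slices, which is the property the commit relies on to drop the gathers.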
Code Review
This pull request refactors the Rotary Position Embedding (RoPE) implementation across DeepSeek v4 decode attention, compressor, and projection modules from an interleaved layout to a split-half (NeoX) layout. This change eliminates the need for in-kernel index generation and gather operations by utilizing half-width, pre-sliced FP32 cosine and sine tables. The review feedback highlights opportunities to optimize performance and reduce redundancy in decode_attention_csa.py and decode_attention_hca.py by reusing already sliced and cast row tensors instead of re-slicing and re-casting the base frequency tables.
    rope_cos_half_t = pl.assemble(
        rope_cos_half_t, pl.cast(pl.slice(freqs_cos, [1, HALF_ROPE], [pos_b, 0]), target_type=pl.FP32), [t, 0])
    rope_sin_half_t = pl.assemble(
        rope_sin_half_t, pl.cast(pl.slice(freqs_sin, [1, HALF_ROPE], [pos_b, 0]), target_type=pl.FP32), [t, 0])
Instead of performing redundant pl.slice and pl.cast operations on freqs_cos and freqs_sin again, you can slice cos_row and sin_row directly, as they are already sliced and cast to pl.FP32 on lines 203-204. This avoids unnecessary overhead and is consistent with the implementation in decode_attention_swa.py.
Suggested change:

    rope_cos_half_t = pl.assemble(
        rope_cos_half_t, cos_row[0 : 1, 0 : HALF_ROPE], [t, 0])
    rope_sin_half_t = pl.assemble(
        rope_sin_half_t, sin_row[0 : 1, 0 : HALF_ROPE], [t, 0])
    rope_cos_half_t[t : t + 1, 0 : ROPE_HEAD_DIM // 2] = pl.cast(
        freqs_cos[pos_b : pos_b + 1, 0 : ROPE_HEAD_DIM // 2], target_type=pl.FP32)
    rope_sin_half_t[t : t + 1, 0 : ROPE_HEAD_DIM // 2] = pl.cast(
        freqs_sin[pos_b : pos_b + 1, 0 : ROPE_HEAD_DIM // 2], target_type=pl.FP32)
Instead of performing redundant slicing and casting on freqs_cos and freqs_sin again, you can slice step_cos_row and step_sin_row directly, as they are already sliced and cast to pl.FP32 on lines 165-166. This avoids unnecessary overhead.
Suggested change:

    rope_cos_half_t[t : t + 1, 0 : ROPE_HEAD_DIM // 2] = step_cos_row[0 : 1, 0 : ROPE_HEAD_DIM // 2]
    rope_sin_half_t[t : t + 1, 0 : ROPE_HEAD_DIM // 2] = step_sin_row[0 : 1, 0 : ROPE_HEAD_DIM // 2]
… main

Cherry-pick of c5aa9c5 (feat: gather-free split-half RoPE for the decode path) resolved onto main (ea299a1). 5 of 10 files auto-merged; 5 conflicted (both sides fully rewrote the same RoPE block).

Resolution:
- All 5 conflicts: adopt the PR split-half (NeoX) side, discard main's interleaved (GPT-J) side.
- decode_compressor_ratio{4,128}: restore `gamma_rope = pl.cast(..., pl.FP32)` on the split-half branch. hw-native-sys#568 flipped norm_w FP32->BF16 and added the cast at every apply site; the PR (branched pre-hw-native-sys#568) dropped it. Without the cast the per-column gamma fold would run in BF16, asymmetric with the NOPE branch and the float golden rmsnorm.
- decode_sparse_attn{,_hca,_swa}: drop the PR's `pl.cast(r_tile, FP32)`. The PR added it when attn_rope_stage was BF16; hw-native-sys#568 made the stage FP32, so the cast is an FP32->FP32 identity that the pypto op registry rejects (trace-time failure). Read the already-FP32 stage directly, matching main's interleaved version.

Validated on a2a3sim (golden): decode_compressor_ratio4/128 PASS, qkv_proj_rope PASS, decode_sparse_attn PASS, decode_sparse_attn_hca PASS. decode_sparse_attn_swa is flaky (4/6) but identically so on clean main (FAIL PASS FAIL PASS PASS PASS on both) -- pre-existing unseeded-input tolerance flakiness from hw-native-sys#563, not a merge regression.

Decode-only, inheriting the PR's caveat: qkv_proj_rope is shared with prefill, whose sparse_attn/compressors stay interleaved, so prefill is latently half-converted. Landing requires the prefill split-half follow-up + offline weight permutation for real checkpoints.
…570)

## Summary

Converts the DeepSeek-V4 RoPE path from the interleaved (GPT-J, gather-based) layout to split-half (GPT-NeoX, gather-free) for **both decode and prefill**, so the whole chain is layout-consistent. The rotation partner of lane `k` becomes lane `k+HALF`: a contiguous `lo=[:HALF]`/`hi=[HALF:]` slice instead of a `j^1` swap-gather + `j>>1` cos/sin dup-gather. Every RoPE rotation becomes contiguous slices + plain FMAs, with no cross-lane op and no in-kernel `rope_cs` dup-gather pre-pass.

- **Decode** (commit 1): adopts the split-half conversion across `qkv_proj_rope` (forward, shared), `decode_compressor_ratio{4,128}` (forward), and `decode_sparse_attn{,_hca,_swa}` (inverse), with the callers feeding half-width FP32 `rope_cos_half`/`rope_sin_half`. Two precision details vs current `main`: the compressor `gamma_rope` is kept cast to FP32 (since `norm_w` is BF16), and the inverse-rope stage (already FP32) is read directly, with no FP32→FP32 identity cast.
- **Prefill** (commit 2): mirrors the same conversion onto `prefill_compressor_ratio{4,128}` (forward) and `prefill_sparse_attn` (inverse), with `prefill_attention_{csa,hca,swa}` building and passing the half-width tables. This removes the latent "half-converted" hazard from converting the shared `qkv_proj_rope` without converting prefill's own compressors/inverse.

Rebased on top of #569; its per-head NOPE fix and the `prefill_sparse_attn_padded_indices` removal are preserved.

- Forward: `out_lo = x_lo*cos − x_hi*sin`, `out_hi = x_lo*sin + x_hi*cos`.
- Inverse (conjugate): `out_lo = x_lo*cos + x_hi*sin`, `out_hi = x_hi*cos − x_lo*sin`.

The lightning **indexer** (`{decode,prefill}_indexer{,_compressor}`) is intentionally left interleaved: it is a self-contained RoPE subsystem (own query/KV rope from the freqs tables) that feeds sparse attention only integer top-k indices, so it is decoupled from the main path.
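As a quick sanity check of the forward and inverse formulas quoted in the summary, the following pure-Python sketch applies the forward rotation and then the inverse and confirms the input is recovered. The helper `half_tables`, its `base=10000.0` default, and the sample position are illustrative assumptions, not values taken from the PR.

```python
import math


def half_tables(pos, half, base=10000.0):
    """Half-width cos/sin rows for one position (frequency base is an assumption)."""
    angles = [pos * base ** (-i / half) for i in range(half)]
    return [math.cos(a) for a in angles], [math.sin(a) for a in angles]


def rope_forward(x, cos_half, sin_half):
    """out_lo = x_lo*cos - x_hi*sin, out_hi = x_lo*sin + x_hi*cos."""
    half = len(x) // 2
    lo, hi = x[:half], x[half:]
    out_lo = [a * c - b * s for a, b, c, s in zip(lo, hi, cos_half, sin_half)]
    out_hi = [a * s + b * c for a, b, c, s in zip(lo, hi, cos_half, sin_half)]
    return out_lo + out_hi


def rope_inverse(x, cos_half, sin_half):
    """Conjugate rotation: out_lo = x_lo*cos + x_hi*sin, out_hi = x_hi*cos - x_lo*sin."""
    half = len(x) // 2
    lo, hi = x[:half], x[half:]
    out_lo = [a * c + b * s for a, b, c, s in zip(lo, hi, cos_half, sin_half)]
    out_hi = [b * c - a * s for a, b, c, s in zip(lo, hi, cos_half, sin_half)]
    return out_lo + out_hi


cos_half, sin_half = half_tables(pos=7, half=4)
x = [0.5, -1.0, 2.0, 0.25, -0.75, 1.5, 3.0, -2.0]

y = rope_forward(x, cos_half, sin_half)
x_back = rope_inverse(y, cos_half, sin_half)

# The inverse should undo the forward rotation up to floating-point rounding.
roundtrip_err = max(abs(a - b) for a, b in zip(x, x_back))
print(roundtrip_err)
```

Both directions read the same half-width tables and differ only in the sign of the `sin` terms, so no second table layout is needed for the inverse.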
### Why

On-device profiling showed that the per-element gather (`j^1` swap + `j>>1` dup), not the arithmetic, is the dominant RoPE cost. The earlier L2 swimlane on the HCA attention module measured rope compute going from 2970 µs (interleaved) to 877 µs (split-half), i.e. **−70.5%**, and module wall-clock −9.6% from this change.

### Validation (a2a3sim, golden, all PASS)

`decode_compressor_ratio4/128`, `decode_sparse_attn{,_hca,_swa}`, `qkv_proj_rope`, `decode_attention_{csa,hca,swa}`, `decode_layer`; `prefill_compressor_ratio4/128`, `prefill_sparse_attn`, `prefill_attention_{csa,hca,swa}`, `prefill_layer`.

The `*_sparse_attn_swa` standalone test is occasionally flaky (~4/6) but identically so on `main` (unseeded fixtures from #563 against a tight tolerance); it is not introduced by this change.

### Caveat

Real checkpoints need an offline interleaved→split-half permutation of the trained `q-proj`/`k_pe`/`wo_a` rope columns to stay bit-identical to the trained model. Synthetic tests need none. Not yet validated on-device or end-to-end by the serving system.

## Related Issues

- Builds on / supersedes #564 (decode-only draft) by completing the prefill conversion.
- Rebased onto #569 (prefill_sparse_attn NOPE fix), which it preserves.
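The offline permutation mentioned in the Caveat can be sketched on a toy query/key pair. This is a hypothetical illustration only: the matrices, shapes, positions, and helper names are invented, and it covers just a q/k dot product, not the real checkpoint tensors or the `wo_a` side. The idea is that reordering the rope columns of both projections as even lanes followed by odd lanes lets the split-half kernel reproduce the score that the interleaved kernel gave with the original weights.

```python
import math
import random


def split_half_perm(d):
    """Column order that moves interleaved lanes to [even..., odd...]."""
    return list(range(0, d, 2)) + list(range(1, d, 2))


def permute_columns(w, perm):
    return [[row[j] for j in perm] for row in w]


def matvec(x, w):
    """x has len d_in; w is d_in x d_out (list of rows)."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]


def tables(pos, half, base=10000.0):  # frequency base is an assumption
    ang = [pos * base ** (-i / half) for i in range(half)]
    return [math.cos(a) for a in ang], [math.sin(a) for a in ang]


def rope_interleaved(x, cos_half, sin_half):
    out = []
    for j, xj in enumerate(x):
        c, s, p = cos_half[j >> 1], sin_half[j >> 1], x[j ^ 1]
        out.append(xj * c - p * s if j % 2 == 0 else xj * c + p * s)
    return out


def rope_split_half(x, cos_half, sin_half):
    half = len(x) // 2
    lo, hi = x[:half], x[half:]
    return ([a * c - b * s for a, b, c, s in zip(lo, hi, cos_half, sin_half)]
            + [a * s + b * c for a, b, c, s in zip(lo, hi, cos_half, sin_half)])


def dot(a, b):
    return sum(p * q for p, q in zip(a, b))


rng = random.Random(0)
D_IN, D_ROPE = 6, 8  # toy sizes
w_q = [[rng.uniform(-1, 1) for _ in range(D_ROPE)] for _ in range(D_IN)]
w_k = [[rng.uniform(-1, 1) for _ in range(D_ROPE)] for _ in range(D_IN)]
x_q = [rng.uniform(-1, 1) for _ in range(D_IN)]
x_k = [rng.uniform(-1, 1) for _ in range(D_IN)]
cq, sq = tables(5, D_ROPE // 2)  # query at position 5
ck, sk = tables(2, D_ROPE // 2)  # key at position 2

# Reference: original weights with the interleaved rotation.
ref = dot(rope_interleaved(matvec(x_q, w_q), cq, sq),
          rope_interleaved(matvec(x_k, w_k), ck, sk))

# Converted: permuted weight columns with the split-half rotation.
perm = split_half_perm(D_ROPE)
new = dot(rope_split_half(matvec(x_q, permute_columns(w_q, perm)), cq, sq),
          rope_split_half(matvec(x_k, permute_columns(w_k, perm)), ck, sk))

score_gap = abs(ref - new)
print(score_gap)
```

Because the same permutation is applied to both operands of the dot product, the score is preserved; skipping the permutation on either side would pair the wrong lanes.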
Summary
Draft / direction proposal: decode-only, not yet validated by the serving system.

Converts the DeepSeek-V4 decode RoPE path from the interleaved (GPT-J, gather-based) layout to split-half (GPT-NeoX, gather-free). The rotation partner of lane `k` becomes lane `k+HALF`: a contiguous `lo=[:HALF]`/`hi=[HALF:]` slice instead of a `j^1` swap-gather.

What
- Forward (`qkv_proj_rope` q+kv, `decode_compressor_ratio{4,128}`): `out_lo = x_lo*cos − x_hi*sin`, `out_hi = x_lo*sin + x_hi*cos`. The compressors fold the per-column gamma per-half (gamma does not commute with the rotation).
- Inverse (`decode_sparse_attn{,_hca,_swa}`): `out_lo = x_lo*cos + x_hi*sin`, `out_hi = x_hi*cos − x_lo*sin` (the conjugate of forward).
- Callers (`decode_attention_{hca,swa,csa}`) feed half-width FP32 `rope_cos_half`/`rope_sin_half`; `decode_layer_dp_ep` drops the interleaved-table plumbing.
- Removes the `rope_cs` dup-gather pre-pass and all `VGATHER`s from the RoPE kernels. Net −10 lines.

Why
The interleaved layout forces a per-element gather (`j^1` swap + `j>>1` dup) that on-device profiling showed to be the dominant RoPE cost (an earlier mask-pattern alternative measured +61%, confirming that the gather, not the arithmetic, is the bottleneck). Split-half makes every rotation contiguous slices plus plain FMAs, with no cross-lane op.

Validation (a2a3, on-device)
- Bit-exact: HCA standalone + composed + layer (`--layer-id 3`); SWA standalone + composed + layer (`--layer-id 0`); CSA standalone + composed; compressors 128 + 4 standalone. The tight standalone `ratio_allclose(1e-4, 1/128)` checks discriminate sign errors (a negative-control flipped sign produced 1.59 max-abs error).
- L2 swimlane: `q_head_rope_fused` −77%, inverse `rope` −71%, `kv_rope_fused` −69%, `rmsnorm_rope` −38%, `rope_cs` eliminated. Non-rope kernels within ±2% noise.

Caveats (why this is a draft)
- `qkv_proj_rope` is shared with prefill; prefill's own sparse_attn / compressors stay interleaved, so prefill is latently half-converted: its kernel==golden tests still pass, but the output is wrong vs the true model. Prefill conversion (~10 files, same pattern) is a tracked follow-up.
- Real checkpoints need an offline permutation of the `q-proj`/`k_pe`/`wo_a` rope columns (interleaved → split-half) to stay bit-identical to the trained model. Synthetic tests need none; this is documented but not included here.
- `decode_layer_dp_ep --layer-id 10` (CSA, 2-device) fails with AICore `507018`, but `main` fails identically at the same point, so it is pre-existing and unrelated to this change.

🤖 Generated with Claude Code
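The remark that gamma does not commute with the rotation can be checked with a two-lane example. This is an illustrative sketch with invented function names and values, and it reflects one reading of "fold the per-column gamma per-half": slice gamma into its `lo`/`hi` halves and scale each half before rotating.

```python
def rotate(x, cos_half, sin_half):
    """Split-half forward rotation."""
    half = len(x) // 2
    lo, hi = x[:half], x[half:]
    out_lo = [a * c - b * s for a, b, c, s in zip(lo, hi, cos_half, sin_half)]
    out_hi = [a * s + b * c for a, b, c, s in zip(lo, hi, cos_half, sin_half)]
    return out_lo + out_hi


def scale(x, gamma):
    """Per-column gamma, as in an RMSNorm weight."""
    return [a * g for a, g in zip(x, gamma)]


def scaled_rotate_fused(x, gamma, cos_half, sin_half):
    """Scale each half by its own gamma slice, then rotate (assumed fold order)."""
    half = len(x) // 2
    lo = [a * g for a, g in zip(x[:half], gamma[:half])]
    hi = [a * g for a, g in zip(x[half:], gamma[half:])]
    out_lo = [a * c - b * s for a, b, c, s in zip(lo, hi, cos_half, sin_half)]
    out_hi = [a * s + b * c for a, b, c, s in zip(lo, hi, cos_half, sin_half)]
    return out_lo + out_hi


x = [1.0, 0.0]          # one lo lane, one hi lane
gamma = [2.0, 3.0]      # different weight on each half
cos_half, sin_half = [0.0], [1.0]  # exact quarter-turn rotation

scale_then_rotate = rotate(scale(x, gamma), cos_half, sin_half)  # [0.0, 2.0]
rotate_then_scale = scale(rotate(x, cos_half, sin_half), gamma)  # [0.0, 3.0]
fused = scaled_rotate_fused(x, gamma, cos_half, sin_half)        # [0.0, 2.0]
print(scale_then_rotate, rotate_then_scale, fused)
```

The two orders disagree whenever the `lo` and `hi` gamma values differ, so the gamma slices have to follow the same lane layout as the data they scale.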