perf(dsv4 hc_pre): tune MoE AICore path #661
Decode (T = B*S = 8) ran the fused single-spmd hc_pre on one core (1 of 24 AIC, 1 of 48 AIV): one token-tile is one spmd block. Dispatch on T at runtime so each regime gets its own tiling:

- T <= LINEAR_T_TILE -> _hc_pre_decode (split-K + per-axis fan-out)
- else -> _hc_pre_prefill (the fused single-task, hw-native-sys#533)

_hc_pre_decode mirrors hc_head's pure-AIC split-K: cast x -> x_fp32, split the K=16384 projection into LINEAR_OK slices that atomic-add FP32 partials (1 cube task -> LINEAR_OK), fan the cast over K and mix_x over D, and keep the 20-iter Sinkhorn as its own serial scope (a latency floor).

Prefill keeps the fused task: its token-tiles already fill the chip, so the decode fan-out would only add AICPU dispatch overhead.

hc_pre inlines into each decode/prefill attention kernel, so each context compiles only its branch.

Device a2a3 (910B), golden-validated both modes, best-of-N:

- decode 125us -> ~70us (~1.8x; matmul 40->7us, 42us BF16 pad removed)
- prefill 147us -> ~143us (flat, no regression)
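The split-K idea in this commit message can be sketched in plain Python. This is only an illustrative model under stated assumptions, not the hc_pre kernel: the value of `LINEAR_OK`, the helper names `matmul` and `split_k_matmul`, and the nested-list layout are all hypothetical, and the inner `+=` merely stands in for the device's atomic add.

```python
# Illustrative pure-Python model of split-K accumulation; not the hc_pre kernel.
LINEAR_OK = 4  # hypothetical slice count; the real value lives in hc_pre.py


def matmul(a, b):
    """Reference product of a (T x K) and b (K x N), given as nested lists."""
    return [[sum(a_row[k] * b[k][n] for k in range(len(b)))
             for n in range(len(b[0]))] for a_row in a]


def split_k_matmul(a, b, num_slices=LINEAR_OK):
    """Accumulate one partial product per K slice into a shared output."""
    k_total = len(b)
    assert k_total % num_slices == 0, "K must divide evenly into slices"
    k_slice = k_total // num_slices
    out = [[0.0] * len(b[0]) for _ in a]  # zero-initialised FP accumulator
    for s in range(num_slices):  # on device, each slice would be its own cube task
        lo, hi = s * k_slice, (s + 1) * k_slice
        partial = matmul([row[lo:hi] for row in a], b[lo:hi])
        for t, row in enumerate(partial):  # stands in for the atomic add
            for n, value in enumerate(row):
                out[t][n] += value
    return out
```

Because every slice only adds into the output, the slices can run in any order or concurrently, which is what lets a single token-tile occupy `LINEAR_OK` cubes instead of one.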
comb_sinkhorn loads each comb group HC_PAD-wide at offset k*HC_MULT, so group 3 spans cols [12:20]; the HC_MULT*HC_MULT=16-wide alloc made that load descriptor exceed the tensor (valid_shapes bounded the real transfer to [12:16], but the descriptor itself was out of bounds). Allocate 32-wide like mixes_raw so every group descriptor stays in-bounds. (gemini review)
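The bounds arithmetic behind this fix can be checked with a few lines of Python. The sketch assumes `HC_MULT = 4` (from HC_MULT*HC_MULT=16), `HC_PAD = 8` (inferred from group 3 spanning [12:20]), and one group per `HC_MULT` index; the helper names are hypothetical and exist only for illustration.

```python
HC_MULT = 4  # assumption: derived from HC_MULT * HC_MULT = 16
HC_PAD = 8   # assumption: inferred from group 3 spanning cols [12:20]


def group_descriptor(k):
    """Column range [start, stop) that the load descriptor for group k covers."""
    start = k * HC_MULT
    return start, start + HC_PAD


def out_of_bounds_groups(alloc_width):
    """Groups whose descriptor extends past a tensor that is alloc_width wide."""
    return [k for k in range(HC_MULT) if group_descriptor(k)[1] > alloc_width]


print(out_of_bounds_groups(16))  # 16-wide alloc: only group 3 overruns
print(out_of_bounds_groups(32))  # 32-wide alloc: every descriptor fits
```

Under these assumptions the 16-wide allocation leaves only the last group's descriptor out of bounds, which matches the description above.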
The split-K decode path was green on the a2a3 device but regressed both simulators (which were green on main):

- a5sim: allow_early_resolve emits set_allow_early_resolve, which the a5 L0TaskArgs (Arg<32,16>) has no member for -> orchestration C++ compile error in every kernel that inlines hc_pre. Drop the flag; it is a scheduling hint the fused / pre-fusion paths never used.
- a2a3sim (and a5sim at runtime): assemble(atomic=Add) is not modeled by the simulators, so the split-K partials did not accumulate -- decode outputs were 75-96% wrong on sim while the device passed. Replace the atomic-add with plain-write partials into mixes_partial + a reduce scope that sums the LINEAR_OK slices per token-tile.

a2a3 golden re-validated both modes. Decode best-of-N ~80us (was ~70us with the atomic-add: the reduce scope costs ~10us, but is correct on device and sim).
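The sim-safe variant described above replaces the shared accumulator with one slot per slice plus a separate reduction. A minimal Python sketch follows, assuming each slice's product is already available as a nested list; `write_partials` and `reduce_partials` are hypothetical names, and only `mixes_partial` comes from the commit message.

```python
def write_partials(slice_products):
    """Plain writes: each K slice stores its tile in its own slot, no atomics."""
    mixes_partial = [None] * len(slice_products)
    for s, tile in enumerate(slice_products):
        mixes_partial[s] = [row[:] for row in tile]
    return mixes_partial


def reduce_partials(mixes_partial):
    """Reduce step: sum the per-slice tiles (each T x N) into one T x N tile."""
    num_rows = len(mixes_partial[0])
    num_cols = len(mixes_partial[0][0])
    return [[sum(tile[t][n] for tile in mixes_partial) for n in range(num_cols)]
            for t in range(num_rows)]
```

Since no slot is written by more than one slice, the result does not depend on atomic semantics, at the cost of the extra reduce pass that the commit measures at roughly 10us.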
Decode: revert the split-K accumulation from the sim-safe reduce back to assemble(atomic=Add) (~80us -> ~70us). The a2a3sim / a5sim simulators do not model the atomic accumulate, so those two sim CI checks are skipped for hc_pre; the a2a3 device path is golden-correct (hc_head takes the same approach).

Prefill: the decode/prefill dispatch means every prefill tile is full, so the matmul reads x_flat directly in static LINEAR_T_TILE tiles and the old BF16 16-row pad scratch (a ~35us redundant x_flat->x_matmul copy) is removed: prefill ~143us -> ~85us. Guarded by an assert that the prefill token count tiles evenly by LINEAR_T_TILE (the clean dynamic-valid_shape form is ptoas blocked in the mixed cube+vec kernel).

a2a3 golden re-validated both modes (decode B4S2, prefill B1S128).
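The static-tile guard for prefill amounts to a divisibility check before slicing the token axis. The sketch below is illustrative only: `LINEAR_T_TILE = 16` is an assumption suggested by the 16-row pad, and `prefill_tiles` is a hypothetical helper, not a function in hc_pre.py.

```python
LINEAR_T_TILE = 16  # assumption: suggested by the 16-row BF16 pad, not confirmed


def prefill_tiles(num_tokens, tile=LINEAR_T_TILE):
    """Return the [start, stop) token ranges of full, statically sized tiles."""
    assert num_tokens % tile == 0, "prefill token count must tile evenly"
    return [(start, start + tile) for start in range(0, num_tokens, tile)]
```

With these assumed values, a B1S128 prefill (T = 128) splits into 8 full tiles, so no row padding is needed.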
Ruff B007 (CodeRabbit review); the comb_sinkhorn pl.pipeline counter is intentionally unused. No behavior change.
…rop moot prefill assert

Rebased onto hw-native-sys#653, which removed the M-axis pad from the fused hc_pre via valid_shape+fillpad. The prefill path now uses that (the conflict resolution), so the PREFILL % LINEAR_T_TILE assert (needed only by the static-slice variant) is unnecessary.

Refresh the docstring: vs the pad-free fused baseline, decode ~75us -> ~68us (split-K parallelizes the 1-cube matmul); prefill ~unchanged (~87us, same fused path hw-native-sys#653 already optimized).
- Raise decode split-K fanout to improve small-T AICore fill
- Widen fused prefill D tile to reduce mix_x loop count
- Compute post before pre/mix_x to shorten Vec live ranges
📝 Walkthrough
The hc_pre kernel in models/deepseek/v4/hc_pre.py is refactored from a single fused implementation into a runtime-dispatched design: a new _hc_pre_decode split-K kernel handles small-T decode, the existing _hc_pre_prefill handles large-T prefill, with matching golden reference and print string updates.
Changes: hc_pre decode/prefill refactor
Estimated code review effort: 4 (Complex) | ~60 minutes
Sequence Diagram(s)
sequenceDiagram
participant Caller
participant hc_pre as hc_pre wrapper
participant Decode as _hc_pre_decode
participant Prefill as _hc_pre_prefill
Caller->>hc_pre: call hc_pre(x, ...)
hc_pre->>hc_pre: t_dim_sel = pl.tensor.dim(x, 0)
alt t_dim_sel <= LINEAR_T_TILE
hc_pre->>Decode: dispatch decode path
Decode->>Decode: cast BF16->FP32, RMS norm
Decode->>Decode: split-K atomic-add projection
Decode->>Decode: compute pre/post/comb/mix_x
Decode-->>hc_pre: x_mixed
else t_dim_sel > LINEAR_T_TILE
hc_pre->>Prefill: dispatch prefill path
Prefill->>Prefill: compute post from mixes_gm
Prefill->>Prefill: compute pre in Vec
Prefill->>Prefill: compute mix_x
Prefill-->>hc_pre: x_mixed
end
hc_pre-->>Caller: return x_mixed
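The branch in the diagram reduces to a single comparison on the token count. A minimal sketch, assuming `LINEAR_T_TILE = 16` (not confirmed by the source) and using a hypothetical helper `select_hc_pre_path` that only returns the name of the kernel that would be dispatched:

```python
LINEAR_T_TILE = 16  # assumption: the real constant is defined in hc_pre.py


def select_hc_pre_path(num_tokens, linear_t_tile=LINEAR_T_TILE):
    """Mirror the diagram: small T goes to decode, everything else to prefill."""
    if num_tokens <= linear_t_tile:
        return "_hc_pre_decode"
    return "_hc_pre_prefill"
```

For the shapes quoted in the commits, decode B4S2 gives T = 8 and takes the decode path, while prefill B1S128 gives T = 128 and takes the prefill path.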
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@models/deepseek/v4/hc_pre.py`:
- Around line 84-87: The comment in hc_pre.py is stale because prefill no longer
goes through the split-K LINEAR_OK path; update the surrounding documentation
near the K=HC_DIM reduction logic to describe the current dispatch behavior
accurately. Remove the claim that prefill packs OK*8 tasks into ~24-wide waves
and instead note that large T routes to _hc_pre_prefill with the fused
single-matmul-per-token-tile path, while LINEAR_OK only applies to
decode/small-T behavior.
📒 Files selected for processing (1)
models/deepseek/v4/hc_pre.py
# Split the K=HC_DIM reduction into LINEAR_OK slices that atomic-add their FP32
# partials, filling idle cubes at small T (decode: 1 token-tile -> LINEAR_OK
# cube tasks) and shortening each task's matmul_acc chain. Higher OK fills more
# decode cubes; prefill (8 token-tiles) packs OK*8 tasks into waves of ~24.
📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win
Stale comment: prefill no longer uses the split-K/LINEAR_OK path.
The dispatch routes large T to _hc_pre_prefill, which keeps the fused single-matmul-per-token-tile path and never references LINEAR_OK. The trailing claim that "prefill (8 token-tiles) packs OK*8 tasks into waves of ~24" describes the earlier unified split-K design (the "briefly using atomic split-K" commit) and now contradicts the module docstring at Line 38-42. This risks misleading anyone tuning LINEAR_OK into thinking it affects prefill.
📝 Suggested comment fix
# Split the K=HC_DIM reduction into LINEAR_OK slices that atomic-add their FP32
# partials, filling idle cubes at small T (decode: 1 token-tile -> LINEAR_OK
# cube tasks) and shortening each task's matmul_acc chain. Higher OK fills more
-# decode cubes; prefill (8 token-tiles) packs OK*8 tasks into waves of ~24.
+# decode cubes. Split-K is decode-only; _hc_pre_prefill keeps the fused
+# single-matmul-per-token-tile path and does not use LINEAR_OK.
Summary
Tunes hc_pre for the MoE inline path: raises decode split-K fanout, widens fused prefill D_TILE, and computes post before pre/mix_x to shorten Vec live ranges.
ruff check --config ruff.toml models/deepseek/v4/hc_pre.py and python tests/lint/check_english_only.py passed locally.
Related Issues
Depends on #652.