[WIP] DeepSeek V4: packed request-aware chunked prefill layer #625
zhaozhaozz wants to merge 1 commit
Add prefill_chunked_layer.py, a Qwen-style packed prefill variant of
prefill_layer.py. prefill_layer_core expands the batch and sequence
dimensions internally: it loops over requests and fixed-T token tiles,
gathers tile-local [T, ...] inputs from the packed buffers, and calls the
existing fixed-T child kernels (prefill_attention_{swa,hca,csa} and moe)
per tile with a unique, gap-free moe_epoch, scattering valid rows back into
the packed x_next. config.py and the child kernels are untouched; all
batch-dependent dynamic shapes (pl.dynamic) are local to this file.
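
The request/tile loop described above can be sketched in plain Python. This is a toy model, not the PyPTO code in prefill_chunked_layer.py: `child_kernel`, `packed_prefill`, the value `T = 4`, and zero-padding of partial tiles are all assumptions made for illustration.

```python
from typing import Callable, List

T = 4  # hypothetical fixed tile length; the real T comes from the model config

Row = List[float]


def child_kernel(tile: List[Row], valid: int) -> List[Row]:
    """Stand-in for a fixed-T child kernel: it always receives exactly T rows."""
    assert len(tile) == T
    return [[v + 1.0 for v in row] for row in tile]


def packed_prefill(
    x: List[Row],
    chunk_lens: List[int],
    kernel: Callable[[List[Row], int], List[Row]] = child_kernel,
) -> List[Row]:
    """Loop over requests and fixed-T tiles of a packed [sum(chunk_lens), D] buffer."""
    hidden = len(x[0])
    x_next = [[0.0] * hidden for _ in x]
    chunk_offset = 0  # first packed row of the current request
    for chunk_len in chunk_lens:
        num_tiles = (chunk_len + T - 1) // T
        for tile_id in range(num_tiles):
            base = chunk_offset + tile_id * T
            valid = min(T, chunk_len - tile_id * T)
            # Gather: copy the valid rows and zero-pad the tail up to T (assumed padding).
            tile = [list(x[base + t]) if t < valid else [0.0] * hidden for t in range(T)]
            out = kernel(tile, valid)
            # Scatter: only the valid rows are written back to the packed output.
            for t in range(valid):
                x_next[base + t] = out[t]
        chunk_offset += chunk_len
    return x_next
```

With `chunk_lens=[5, 1]` the kernel is called three times (valid counts 4, 1, 1), and padded rows never reach `x_next`.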
Coordinate system, shared by JIT and golden: packed token buffers use
global packed rows, cache/state/tables are request-local, the sparse-index
overlay is tile-local, and position_ids are absolute. The host-side packed
metadata and golden reuse each child build_*_tensor_specs(start_pos,
num_tokens) per tile, so a batch=1 / chunk=T case reproduces the original
single-tile golden bit-exactly.
Validated on a2a3 (2-card EP2) against golden: single-tile (chunk_len <= T)
for swa/hca/csa, and multi-request packing with multiple sequential MoE
collective calls reusing the shared windows.
Multi-tile within one request (chunk_len > T) is still WIP: the cross-tile
request-local cache read-after-write is not yet ordered across the pl.range
tile loop. The children mutate the cache pl.Out in place rather than
returning it, and a loop-carried cache slice cannot be threaded through
pl.range init_values into a fixed-shape inline parameter (it loses inferred
tensor metadata). Fixing this needs a fresh static-shape scratch buffer
threaded across tiles with two-step materialization back to the packed
cache.
Code Review
This pull request introduces a packed, request-aware chunked prefill single-layer implementation for DeepSeek-V4 with MoE EP2, including the core JIT-compiled functions, host-side launchers, tensor spec builders, and golden reference tests. The feedback highlights a critical issue regarding dynamic dimension bindings: the same dynamic dimensions are reused for both state blocks and state block tables, which actually have different physical and logical sizes. To prevent runtime shape mismatches or compilation errors, separate dynamic dimensions should be defined and bound for the state block tables across the core and host-side functions.
PREFILL_HCA_STATE_BLOCKS_DYN = pl.dynamic("DEEPSEEK_PREFILL_HCA_STATE_BLOCKS_DYN")
PREFILL_CSA_STATE_BLOCKS_DYN = pl.dynamic("DEEPSEEK_PREFILL_CSA_STATE_BLOCKS_DYN")
PREFILL_INNER_STATE_BLOCKS_DYN = pl.dynamic("DEEPSEEK_PREFILL_INNER_STATE_BLOCKS_DYN")
The dynamic dimensions for the state block tables should be defined separately from the state block dimensions. HCA_STATE_BLOCK_NUM (physical blocks) and HCA_STATE_MAX_BLOCKS (logical blocks in the table) are different constants, meaning their packed runtime dimensions (batch * HCA_STATE_BLOCK_NUM vs batch * HCA_STATE_MAX_BLOCKS) will differ. Using the same dynamic dimension PREFILL_HCA_STATE_BLOCKS_DYN for both will cause a runtime shape mismatch or compilation error during dynamic dimension binding.
Please define separate dynamic dimensions for the state block tables.
Suggested change:
  PREFILL_HCA_STATE_BLOCKS_DYN = pl.dynamic("DEEPSEEK_PREFILL_HCA_STATE_BLOCKS_DYN")
  PREFILL_CSA_STATE_BLOCKS_DYN = pl.dynamic("DEEPSEEK_PREFILL_CSA_STATE_BLOCKS_DYN")
  PREFILL_INNER_STATE_BLOCKS_DYN = pl.dynamic("DEEPSEEK_PREFILL_INNER_STATE_BLOCKS_DYN")
+ PREFILL_HCA_STATE_TABLE_DYN = pl.dynamic("DEEPSEEK_PREFILL_HCA_STATE_TABLE_DYN")
+ PREFILL_CSA_STATE_TABLE_DYN = pl.dynamic("DEEPSEEK_PREFILL_CSA_STATE_TABLE_DYN")
+ PREFILL_INNER_STATE_TABLE_DYN = pl.dynamic("DEEPSEEK_PREFILL_INNER_STATE_TABLE_DYN")
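
The reviewer's concern can be illustrated with a toy model of dimension binding. The `DynDim` class below is a hypothetical stand-in and not the `pl.dynamic` API. It assumes, as the review does, that a dynamic symbol can hold only one runtime value. The block counts 8 and 16 are made-up numbers.

```python
class DynDim:
    """Hypothetical named dynamic dimension; NOT the real pl.dynamic implementation."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.value = None

    def bind(self, size: int) -> None:
        # A symbol may be bound many times, but only ever to the same runtime size.
        if self.value is not None and self.value != size:
            raise ValueError(f"{self.name}: already bound to {self.value}, got {size}")
        self.value = size


def bind_state_and_table(state_dim: DynDim, table_dim: DynDim,
                         batch: int, block_num: int, max_blocks: int) -> None:
    """Bind the packed state-block dim and the packed block-table dim."""
    state_dim.bind(batch * block_num)    # physical blocks, packed over the batch
    table_dim.bind(batch * max_blocks)   # logical table entries, packed over the batch


def shared_symbol_conflicts(batch: int, block_num: int, max_blocks: int) -> bool:
    """Return True if reusing one symbol for both tensors raises a binding conflict."""
    shared = DynDim("STATE_BLOCKS_DYN")
    try:
        bind_state_and_table(shared, shared, batch, block_num, max_blocks)
    except ValueError:
        return True
    return False
```

Under these assumptions, `shared_symbol_conflicts(2, 8, 16)` reports a conflict. Binding two distinct `DynDim` objects succeeds, which is the shape of the suggested change.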
[PREFILL_HCA_STATE_BLOCKS_DYN, HCA_STATE_BLOCK_SIZE, HCA_MAIN_OUT_DIM],
pl.FP32,
],
hca_compress_state_block_table: pl.Tensor[[PREFILL_HCA_STATE_BLOCKS_DYN], pl.INT32],
Use the newly defined PREFILL_HCA_STATE_TABLE_DYN dynamic dimension for hca_compress_state_block_table to avoid shape mismatch with hca_cmp_kv_state.
Suggested change:
- hca_compress_state_block_table: pl.Tensor[[PREFILL_HCA_STATE_BLOCKS_DYN], pl.INT32],
+ hca_compress_state_block_table: pl.Tensor[[PREFILL_HCA_STATE_TABLE_DYN], pl.INT32],
csa_cmp_norm_w: pl.Tensor[[HEAD_DIM], pl.BF16],
csa_cmp_kv_state: pl.Tensor[[PREFILL_CSA_STATE_BLOCKS_DYN, CSA_STATE_BLOCK_SIZE, CSA_MAIN_OUT_DIM], pl.FP32],
csa_cmp_score_state: pl.Tensor[[PREFILL_CSA_STATE_BLOCKS_DYN, CSA_STATE_BLOCK_SIZE, CSA_MAIN_OUT_DIM], pl.FP32],
csa_compress_state_block_table: pl.Tensor[[PREFILL_CSA_STATE_BLOCKS_DYN], pl.INT32],
Use the newly defined PREFILL_CSA_STATE_TABLE_DYN dynamic dimension for csa_compress_state_block_table to avoid shape mismatch with csa_cmp_kv_state.
Suggested change:
- csa_compress_state_block_table: pl.Tensor[[PREFILL_CSA_STATE_BLOCKS_DYN], pl.INT32],
+ csa_compress_state_block_table: pl.Tensor[[PREFILL_CSA_STATE_TABLE_DYN], pl.INT32],
csa_inner_norm_w: pl.Tensor[[IDX_HEAD_DIM], pl.BF16],
csa_inner_kv_state: pl.Tensor[[PREFILL_INNER_STATE_BLOCKS_DYN, INNER_STATE_BLOCK_SIZE, INNER_OUT_DIM], pl.FP32],
csa_inner_score_state: pl.Tensor[[PREFILL_INNER_STATE_BLOCKS_DYN, INNER_STATE_BLOCK_SIZE, INNER_OUT_DIM], pl.FP32],
csa_inner_compress_state_block_table: pl.Tensor[[PREFILL_INNER_STATE_BLOCKS_DYN], pl.INT32],
Use the newly defined PREFILL_INNER_STATE_TABLE_DYN dynamic dimension for csa_inner_compress_state_block_table to avoid shape mismatch with csa_inner_kv_state.
Suggested change:
- csa_inner_compress_state_block_table: pl.Tensor[[PREFILL_INNER_STATE_BLOCKS_DYN], pl.INT32],
+ csa_inner_compress_state_block_table: pl.Tensor[[PREFILL_INNER_STATE_TABLE_DYN], pl.INT32],
hca_compress_state_block_table.bind_dynamic(0, PREFILL_HCA_STATE_BLOCKS_DYN)
csa_cmp_kv_state.bind_dynamic(0, PREFILL_CSA_STATE_BLOCKS_DYN)
csa_cmp_score_state.bind_dynamic(0, PREFILL_CSA_STATE_BLOCKS_DYN)
csa_compress_state_block_table.bind_dynamic(0, PREFILL_CSA_STATE_BLOCKS_DYN)
csa_inner_kv_state.bind_dynamic(0, PREFILL_INNER_STATE_BLOCKS_DYN)
csa_inner_score_state.bind_dynamic(0, PREFILL_INNER_STATE_BLOCKS_DYN)
csa_inner_compress_state_block_table.bind_dynamic(0, PREFILL_INNER_STATE_BLOCKS_DYN)
Update the dynamic bindings in prefill_layer_core to use the correct table dynamic dimensions.
Suggested change:
- hca_compress_state_block_table.bind_dynamic(0, PREFILL_HCA_STATE_BLOCKS_DYN)
+ hca_compress_state_block_table.bind_dynamic(0, PREFILL_HCA_STATE_TABLE_DYN)
  csa_cmp_kv_state.bind_dynamic(0, PREFILL_CSA_STATE_BLOCKS_DYN)
  csa_cmp_score_state.bind_dynamic(0, PREFILL_CSA_STATE_BLOCKS_DYN)
- csa_compress_state_block_table.bind_dynamic(0, PREFILL_CSA_STATE_BLOCKS_DYN)
+ csa_compress_state_block_table.bind_dynamic(0, PREFILL_CSA_STATE_TABLE_DYN)
  csa_inner_kv_state.bind_dynamic(0, PREFILL_INNER_STATE_BLOCKS_DYN)
  csa_inner_score_state.bind_dynamic(0, PREFILL_INNER_STATE_BLOCKS_DYN)
- csa_inner_compress_state_block_table.bind_dynamic(0, PREFILL_INNER_STATE_BLOCKS_DYN)
+ csa_inner_compress_state_block_table.bind_dynamic(0, PREFILL_INNER_STATE_TABLE_DYN)
[N_RANKS, PREFILL_HCA_STATE_BLOCKS_DYN, HCA_STATE_BLOCK_SIZE, HCA_MAIN_OUT_DIM],
pl.FP32,
],
hca_compress_state_block_table: pl.Tensor[[N_RANKS, PREFILL_HCA_STATE_BLOCKS_DYN], pl.INT32],
Update the hca_compress_state_block_table parameter in l3_prefill_layer to use PREFILL_HCA_STATE_TABLE_DYN.
Suggested change:
- hca_compress_state_block_table: pl.Tensor[[N_RANKS, PREFILL_HCA_STATE_BLOCKS_DYN], pl.INT32],
+ hca_compress_state_block_table: pl.Tensor[[N_RANKS, PREFILL_HCA_STATE_TABLE_DYN], pl.INT32],
csa_cmp_norm_w: pl.Tensor[[N_RANKS, HEAD_DIM], pl.BF16],
csa_cmp_kv_state: pl.Tensor[[N_RANKS, PREFILL_CSA_STATE_BLOCKS_DYN, CSA_STATE_BLOCK_SIZE, CSA_MAIN_OUT_DIM], pl.FP32],
csa_cmp_score_state: pl.Tensor[[N_RANKS, PREFILL_CSA_STATE_BLOCKS_DYN, CSA_STATE_BLOCK_SIZE, CSA_MAIN_OUT_DIM], pl.FP32],
csa_compress_state_block_table: pl.Tensor[[N_RANKS, PREFILL_CSA_STATE_BLOCKS_DYN], pl.INT32],
Update the csa_compress_state_block_table parameter in l3_prefill_layer to use PREFILL_CSA_STATE_TABLE_DYN.
Suggested change:
- csa_compress_state_block_table: pl.Tensor[[N_RANKS, PREFILL_CSA_STATE_BLOCKS_DYN], pl.INT32],
+ csa_compress_state_block_table: pl.Tensor[[N_RANKS, PREFILL_CSA_STATE_TABLE_DYN], pl.INT32],
csa_inner_norm_w: pl.Tensor[[N_RANKS, IDX_HEAD_DIM], pl.BF16],
csa_inner_kv_state: pl.Tensor[[N_RANKS, PREFILL_INNER_STATE_BLOCKS_DYN, INNER_STATE_BLOCK_SIZE, INNER_OUT_DIM], pl.FP32],
csa_inner_score_state: pl.Tensor[[N_RANKS, PREFILL_INNER_STATE_BLOCKS_DYN, INNER_STATE_BLOCK_SIZE, INNER_OUT_DIM], pl.FP32],
csa_inner_compress_state_block_table: pl.Tensor[[N_RANKS, PREFILL_INNER_STATE_BLOCKS_DYN], pl.INT32],
Update the csa_inner_compress_state_block_table parameter in l3_prefill_layer to use PREFILL_INNER_STATE_TABLE_DYN.
Suggested change:
- csa_inner_compress_state_block_table: pl.Tensor[[N_RANKS, PREFILL_INNER_STATE_BLOCKS_DYN], pl.INT32],
+ csa_inner_compress_state_block_table: pl.Tensor[[N_RANKS, PREFILL_INNER_STATE_TABLE_DYN], pl.INT32],
Implements the DeepSeek V4 prefill refactor to Qwen-style packed, request-aware scheduling (issue #591): one fixed-T child-kernel tile at a time, scheduled over a packed batch.
What this adds
New file models/deepseek/v4/prefill_chunked_layer.py, a packed variant of prefill_layer.py. prefill_layer_core expands the batch and sequence dimensions internally: it loops over requests and fixed-T token tiles, gathers tile-local [T, ...] inputs from the packed buffers, calls the existing fixed-T child kernels per tile, and scatters valid rows back into the packed x_next.
- config.py and the child kernels (prefill_attention_*, moe, compressor/indexer) are unchanged. All batch-dependent dynamic shapes (pl.dynamic) are local to this file.
- moe_epoch is the global execution ordinal of the MoE call (chunk_tile_offsets[request_id] + tile_id + 1, where chunk_tile_offsets is the exclusive prefix sum of tile counts). It must be gap-free because the EP done-windows are monotonic counters; request_id * MAX_TILES + tile_id + 1 leaves gaps whenever a request uses fewer than MAX_TILES tiles, so it would deadlock.
- Packed token buffers hold sum(chunk_lens) rows; cache/state/tables are user_batch-concatenated request-local slices.
Coordinate system (shared by JIT and golden)
- packed token buffers: global packed rows (chunk_offsets[r] + tile_id*T + t)
- cache/state/tables: request-local
- sparse-index overlay: tile-local (WIN + t for the current tile)
- position_ids: absolute
The host-side packed metadata and the golden both reuse each child build_*_tensor_specs(start_pos=tile_ctx, num_tokens=valid) per tile, which already encodes the absolute-position ring / overlay / compressed / state formulas. As a result, a batch=1 / chunk_len=T case reproduces the original single-tile golden_prefill_layer bit-exactly (max abs diff 0 for swa/hca/csa), which anchors the request/tile machinery, child dispatch, MoE collective and scatter.
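
The epoch numbering and the packed-row formula can be checked with a small host-side sketch. `build_schedule`, `naive_epochs` and the `Tile` record are hypothetical names. Computing each tile's `start_pos` as `chunk_start + tile_id * T` is an assumption about how `tile_ctx` is derived.

```python
from itertools import accumulate
from typing import List, NamedTuple


class Tile(NamedTuple):
    request_id: int
    tile_id: int
    moe_epoch: int    # global, gap-free MoE call ordinal (1-based)
    packed_row0: int  # first global packed row of the tile
    start_pos: int    # absolute position of the tile's first token (assumed formula)
    valid: int        # number of valid rows in the tile (<= T)


def build_schedule(chunk_starts: List[int], chunk_lens: List[int], T: int) -> List[Tile]:
    tiles_per_req = [(n + T - 1) // T for n in chunk_lens]
    # Exclusive prefix sums: tiles before each request, packed rows before each request.
    chunk_tile_offsets = [0] + list(accumulate(tiles_per_req))[:-1]
    chunk_offsets = [0] + list(accumulate(chunk_lens))[:-1]
    schedule = []
    for r, (start, n) in enumerate(zip(chunk_starts, chunk_lens)):
        for tile_id in range(tiles_per_req[r]):
            schedule.append(Tile(
                request_id=r,
                tile_id=tile_id,
                moe_epoch=chunk_tile_offsets[r] + tile_id + 1,
                packed_row0=chunk_offsets[r] + tile_id * T,
                start_pos=start + tile_id * T,
                valid=min(T, n - tile_id * T),
            ))
    return schedule


def naive_epochs(chunk_lens: List[int], T: int, max_tiles: int) -> List[int]:
    """The rejected numbering request_id * MAX_TILES + tile_id + 1."""
    return [r * max_tiles + tile_id + 1
            for r, n in enumerate(chunk_lens)
            for tile_id in range((n + T - 1) // T)]
```

For `chunk_lens=[128, 256]` and `T=128`, the prefix-sum epochs are 1, 2, 3, while the naive numbering with `max_tiles=2` yields 1, 3, 4. The skipped value 2 is the kind of gap the description says a monotonic done-window cannot tolerate.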
Validation (a2a3, 2-card EP2, vs golden)
- chunk_lens=128: SWA (L0) / HCA (L9) / CSA (L8), validated
- chunk_lens=128,128 (multiple sequential MoE calls reusing the windows), validated
- chunk_lens=256 / 128,256 (chunk_len > T), not yet passing (see Known limitation)
Host-side golden is also validated in pure Python (single-tile bit-exact vs the original; multi-tile reference runs with cross-tile cache/state continuity).
Run example:
Note: a2a3sim is not a faithful platform for these EP2 kernels (the original prefill_layer.py also fails on a2a3sim); correctness must be checked on real a2a3.
Known limitation (the WIP part)
Multi-tile within a single request (chunk_len > T) is not yet correct: the cross-tile request-local cache read-after-write is not ordered across the pl.range tile loop. Tile N writes the ring/compressed/state cache in place; tile N+1 must read it, but the dependency is not enforced, so tile N+1's early tokens race on stale KV (error concentrates there; ~9.6% over the 1% rel-diff bar).
Attempts that did not fix it:
- [...] pl.scope() and use auto_scope=True: still ~10%.
- Threading the cache slice through pl.range(init_values=(...)): fails to compile, because a loop-carried slice passed to a fixed-shape inline parameter loses inferred tensor metadata (missing inferred tensor metadata for parameter).
Planned fix: copy each request's cache/state into static-shape scratch buffers (pl.create_tensor, which keep metadata), thread those through the tile loop via init_values, and materialize them back to the packed cache after the request (fresh-scratch + two-step materialization).
Dependencies
- [...] pl.dynamic dims; @pl.jit.host + pl.dynamic otherwise raises NameError in the generated host orchestration).
- [...] (task-submit --device auto --device-num 2, or an explicit healthy pair); the program's --device overrides -d to the locked cards.
Coverage TODO
- [...] chunk_len > T for swa/hca/csa.
- [...] (chunk_start > 0) and partial tiles (chunk_len not a multiple of T).