[WIP] DeepSeek V4: packed request-aware chunked prefill layer #625

Draft
zhaozhaozz wants to merge 1 commit into hw-native-sys:main from zhaozhaozz:feat/dsv4-prefill-chunked-packed

@zhaozhaozz

Status: WIP / Draft. Single-tile (all attention kinds) and
multi-request are validated on real NPU; multi-tile within one request
(chunk_len > T) is not yet correct (see Known limitation). Opening early
for visibility and review of the approach.

Implements the DeepSeek V4 prefill refactor to Qwen-style packed,
request-aware scheduling (issue #591): one fixed-T child-kernel tile at a
time, scheduled over a packed batch.

What this adds

New file models/deepseek/v4/prefill_chunked_layer.py, a packed variant of
prefill_layer.py. prefill_layer_core expands the batch and sequence
dimensions internally:

l3_prefill_layer (rank loop)
  -> prefill_layer_core (packed, per rank)
       for request_id in pl.range(user_batch):
         for tile_id in pl.range(ceil(chunk_len / T)):
           gather tile-local [T, ...] inputs from the packed buffers
           call prefill_attention_{swa,hca,csa} directly   (fixed T)
           call moe directly                                (fixed T)
           scatter valid rows back into packed x_next

  • config.py and the child kernels (prefill_attention_*, moe,
    compressor/indexer) are unchanged. All batch-dependent dynamic shapes
    (pl.dynamic) are local to this file.
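  • moe_epoch is the global execution ordinal of the MoE call:
    chunk_tile_offsets[request_id] + tile_id + 1, where chunk_tile_offsets
    is the exclusive prefix sum of per-request tile counts. It must be
    gap-free because the EP done-windows are monotonic counters; a sparse
    numbering such as request_id * MAX_TILES + tile_id + 1 would leave gaps
    whenever a request has fewer than MAX_TILES tiles, and would deadlock.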
  • Token buffers are packed to sum(chunk_lens); cache/state/tables are
    user_batch-concatenated request-local slices.
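
The gap-free requirement on moe_epoch can be checked with a few lines of plain Python. This is a minimal host-side sketch, not code from the PR: T = 128 is inferred from the single-tile chunk_lens=128 case, and the helper names (tile_counts, moe_epochs, sparse_epochs) and the max_tiles value are hypothetical.

```python
from itertools import accumulate

T = 128  # assumed tile size, inferred from the chunk_lens=128 single-tile case


def tile_counts(chunk_lens, tile=T):
    """Number of fixed-T tiles per request: ceil(chunk_len / T)."""
    return [-(-n // tile) for n in chunk_lens]


def exclusive_prefix_sum(values):
    """[a, b, c] -> [0, a, a + b]."""
    return list(accumulate(values, initial=0))[:-1]


def moe_epochs(chunk_lens, tile=T):
    """Epochs in execution order (request-major, then tile), numbered gap-free."""
    counts = tile_counts(chunk_lens, tile)
    chunk_tile_offsets = exclusive_prefix_sum(counts)
    return [
        chunk_tile_offsets[request_id] + tile_id + 1
        for request_id, n_tiles in enumerate(counts)
        for tile_id in range(n_tiles)
    ]


def sparse_epochs(chunk_lens, max_tiles, tile=T):
    """The rejected numbering: request_id * MAX_TILES + tile_id + 1."""
    counts = tile_counts(chunk_lens, tile)
    return [
        request_id * max_tiles + tile_id + 1
        for request_id, n_tiles in enumerate(counts)
        for tile_id in range(n_tiles)
    ]


print(moe_epochs([128, 256]))                   # [1, 2, 3]
print(sparse_epochs([128, 256], max_tiles=2))   # [1, 3, 4]
```

With chunk_lens=128,256 the sparse numbering never issues epoch 2, so a monotonic counter waiting for it would never advance.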

Coordinate system (shared by JIT and golden)

  • packed token buffers: global packed row (chunk_offsets[r] + tile_id*T + t)
  • cache / state / tables: request-local
  • sparse-index overlay: tile-local (WIN + t for the current tile)
  • position_ids: absolute
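
The packed-row formula above can be illustrated with a small pure-Python gather/scatter. This is a sketch under stated assumptions rather than the PR's code: the tile size is shrunk to 4 for readability, the function names are hypothetical, and clamping the valid-row count for a partial last tile is an assumption (partial tiles are still on the coverage TODO).

```python
T = 4  # tiny tile size for illustration only; the real T is fixed by the child kernels


def chunk_offsets(chunk_lens):
    """Exclusive prefix sum of chunk lengths: first packed row of each request."""
    offsets, total = [], 0
    for n in chunk_lens:
        offsets.append(total)
        total += n
    return offsets


def tile_rows(chunk_lens, request_id, tile_id, tile=T):
    """Global packed rows chunk_offsets[r] + tile_id*T + t for the valid t."""
    base = chunk_offsets(chunk_lens)[request_id] + tile_id * tile
    valid = min(tile, chunk_lens[request_id] - tile_id * tile)  # assumed clamp
    return [base + t for t in range(valid)]


def gather_tile(packed, rows, tile=T, pad=0):
    """Copy the valid rows into a fixed-T tile and pad the tail."""
    return [packed[r] for r in rows] + [pad] * (tile - len(rows))


def scatter_tile(packed_out, rows, tile_out):
    """Write only the valid rows of a tile back into the packed output."""
    for t, r in enumerate(rows):
        packed_out[r] = tile_out[t]


chunk_lens = [4, 6]               # packed buffer has sum(chunk_lens) = 10 rows
packed_x = list(range(10))
x_next = [None] * 10

rows = tile_rows(chunk_lens, request_id=1, tile_id=1)
tile_in = gather_tile(packed_x, rows)
scatter_tile(x_next, rows, [v * 10 for v in tile_in])

print(rows)      # [8, 9]
print(tile_in)   # [8, 9, 0, 0]
```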

The host-side packed metadata and the golden both reuse each child's
build_*_tensor_specs(start_pos=tile_ctx, num_tokens=valid) per tile, which
already encodes the absolute-position ring / overlay / compressed / state
formulas. As a result, a batch=1 / chunk_len=T case reproduces the original
single-tile golden_prefill_layer bit-exactly (max abs diff 0 for
swa/hca/csa), which anchors the request/tile machinery, child dispatch, MoE
collective, and scatter.

Validation (a2a3, 2-card EP2, vs golden)

| Case | Result |
| --- | --- |
| single-tile, chunk_lens=128: SWA (L0) / HCA (L9) / CSA (L8) | PASS |
| multi-request, chunk_lens=128,128 (multiple sequential MoE calls reusing the windows) | PASS |
| multi-tile, chunk_lens=256 / 128,256 (chunk_len > T) | FAIL, ~9.6% (see below) |

Host-side golden is also validated in pure Python (single-tile bit-exact vs
the original; multi-tile reference runs with cross-tile cache/state
continuity).

Run example:

python models/deepseek/v4/prefill_chunked_layer.py -p a2a3 --layer-id 0 --chunk-lens 128 -d <2 devices>

Note: a2a3sim is not a faithful platform for these EP2 kernels (the
original prefill_layer.py also fails on a2a3sim); correctness must be
checked on real a2a3.

Known limitation (the WIP part)

Multi-tile within a single request (chunk_len > T) is not yet correct: the
cross-tile request-local cache read-after-write is not ordered across the
pl.range tile loop. Tile N writes the ring/compressed/state cache in place;
tile N+1 must read it, but the dependency is not enforced, so tile N+1's
early tokens race on stale KV (error concentrates there; ~9.6% over the 1%
rel-diff bar).

Attempts that did not fix it:

  1. Drop the per-tile pl.scope() and use auto_scope=True — still ~10%.
  2. Re-slice the cache inside the tile loop — still ~9.6%.
  3. Thread the cache via pl.range(init_values=(...)) — fails to compile:
    a loop-carried slice passed to a fixed-shape inline parameter loses
    inferred tensor metadata (missing inferred tensor metadata for parameter).

Planned fix: copy each request's cache/state into static-shape scratch
buffers (pl.create_tensor, which keep metadata), thread those through the
tile loop via init_values, and materialize them back to the packed cache
after the request (fresh-scratch + two-step materialization).
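
The data flow of the planned fix can be sketched in plain Python. This is only an illustration of threading a request-local scratch through the tile loop as a loop-carried value and writing it back afterwards. It does not use the pl API, it cannot reproduce the race itself, and WIN, child_kernel, and run_request are hypothetical stand-ins.

```python
WIN = 4  # hypothetical request-local ring size


def child_kernel(ring, start_pos, tile_tokens):
    """Stand-in child: reads the ring left by earlier tiles, returns an updated copy."""
    seen = sum(ring)                      # depends on every earlier tile's writes
    out = [seen + tok for tok in tile_tokens]
    new_ring = list(ring)                 # no in-place mutation of the input
    for i, tok in enumerate(tile_tokens):
        new_ring[(start_pos + i) % WIN] = tok
    return out, new_ring


def run_request(packed_cache, request_id, tiles):
    lo = request_id * WIN
    scratch = packed_cache[lo:lo + WIN]   # copy the request slice into scratch
    outputs, pos = [], 0
    for tile_tokens in tiles:
        # scratch is carried from tile N to tile N+1, so the read-after-write
        # order is explicit in the data flow.
        out, scratch = child_kernel(scratch, pos, tile_tokens)
        outputs.append(out)
        pos += len(tile_tokens)
    packed_cache[lo:lo + WIN] = scratch   # materialize back after the request
    return outputs


packed_cache = [0] * (2 * WIN)            # two requests, concatenated
outputs = run_request(packed_cache, request_id=1, tiles=[[1, 2], [3, 4]])

print(outputs)        # [[1, 2], [6, 7]]
print(packed_cache)   # [0, 0, 0, 0, 1, 2, 3, 4]
```

The second tile's output depends on the first tile's writes (seen = 3), which is the dependency the current pl.range loop does not enforce.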

Dependencies

Coverage TODO

  • Fix and validate chunk_len > T for swa/hca/csa.
  • Optional history (chunk_start > 0) and partial tiles
    (chunk_len not a multiple of T).

Add prefill_chunked_layer.py, a Qwen-style packed prefill variant of
prefill_layer.py. prefill_layer_core expands the batch and sequence
dimensions internally: it loops over requests and fixed-T token tiles,
gathers tile-local [T, ...] inputs from the packed buffers, and calls the
existing fixed-T child kernels (prefill_attention_{swa,hca,csa} and moe)
per tile with a unique, gap-free moe_epoch, scattering valid rows back into
the packed x_next. config.py and the child kernels are untouched; all
batch-dependent dynamic shapes (pl.dynamic) are local to this file.

Coordinate system, shared by JIT and golden: packed token buffers use
global packed rows, cache/state/tables are request-local, the sparse-index
overlay is tile-local, and position_ids are absolute. The host-side packed
metadata and golden reuse each child build_*_tensor_specs(start_pos,
num_tokens) per tile, so a batch=1 / chunk=T case reproduces the original
single-tile golden bit-exactly.

Validated on a2a3 (2-card EP2) against golden: single-tile (chunk_len <= T)
for swa/hca/csa, and multi-request packing with multiple sequential MoE
collective calls reusing the shared windows.

Multi-tile within one request (chunk_len > T) is still WIP: the cross-tile
request-local cache read-after-write is not yet ordered across the pl.range
tile loop. The children mutate the cache pl.Out in place rather than
returning it, and a loop-carried cache slice cannot be threaded through
pl.range init_values into a fixed-shape inline parameter (it loses inferred
tensor metadata). Fixing this needs a fresh static-shape scratch buffer
threaded across tiles with two-step materialization back to the packed
cache.

@coderabbitai

coderabbitai Bot commented Jun 27, 2026

Important

Review skipped

Draft detected.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces a packed, request-aware chunked prefill single-layer implementation for DeepSeek-V4 with MoE EP2, including the core JIT-compiled functions, host-side launchers, tensor spec builders, and golden reference tests. The feedback highlights a critical issue regarding dynamic dimension bindings: the same dynamic dimensions are reused for both state blocks and state block tables, which actually have different physical and logical sizes. To prevent runtime shape mismatches or compilation errors, separate dynamic dimensions should be defined and bound for the state block tables across the core and host-side functions.

Comment on lines +145 to +147
PREFILL_HCA_STATE_BLOCKS_DYN = pl.dynamic("DEEPSEEK_PREFILL_HCA_STATE_BLOCKS_DYN")
PREFILL_CSA_STATE_BLOCKS_DYN = pl.dynamic("DEEPSEEK_PREFILL_CSA_STATE_BLOCKS_DYN")
PREFILL_INNER_STATE_BLOCKS_DYN = pl.dynamic("DEEPSEEK_PREFILL_INNER_STATE_BLOCKS_DYN")

high

The dynamic dimensions for the state block tables should be defined separately from the state block dimensions. HCA_STATE_BLOCK_NUM (physical blocks) and HCA_STATE_MAX_BLOCKS (logical blocks in the table) are different constants, meaning their packed runtime dimensions (batch * HCA_STATE_BLOCK_NUM vs batch * HCA_STATE_MAX_BLOCKS) will differ. Using the same dynamic dimension PREFILL_HCA_STATE_BLOCKS_DYN for both will cause a runtime shape mismatch or compilation error during dynamic dimension binding.

Please define separate dynamic dimensions for the state block tables.

Suggested change
- PREFILL_HCA_STATE_BLOCKS_DYN = pl.dynamic("DEEPSEEK_PREFILL_HCA_STATE_BLOCKS_DYN")
- PREFILL_CSA_STATE_BLOCKS_DYN = pl.dynamic("DEEPSEEK_PREFILL_CSA_STATE_BLOCKS_DYN")
- PREFILL_INNER_STATE_BLOCKS_DYN = pl.dynamic("DEEPSEEK_PREFILL_INNER_STATE_BLOCKS_DYN")
+ PREFILL_HCA_STATE_BLOCKS_DYN = pl.dynamic("DEEPSEEK_PREFILL_HCA_STATE_BLOCKS_DYN")
+ PREFILL_CSA_STATE_BLOCKS_DYN = pl.dynamic("DEEPSEEK_PREFILL_CSA_STATE_BLOCKS_DYN")
+ PREFILL_INNER_STATE_BLOCKS_DYN = pl.dynamic("DEEPSEEK_PREFILL_INNER_STATE_BLOCKS_DYN")
+ PREFILL_HCA_STATE_TABLE_DYN = pl.dynamic("DEEPSEEK_PREFILL_HCA_STATE_TABLE_DYN")
+ PREFILL_CSA_STATE_TABLE_DYN = pl.dynamic("DEEPSEEK_PREFILL_CSA_STATE_TABLE_DYN")
+ PREFILL_INNER_STATE_TABLE_DYN = pl.dynamic("DEEPSEEK_PREFILL_INNER_STATE_TABLE_DYN")
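
The reviewer's concern can be illustrated with a toy model of a symbolic dimension that accepts only one runtime value. This is a sketch, not the behavior of pl.dynamic or bind_dynamic (which is not confirmed here): DynDim is a hypothetical class, and the constant values are made up for the example.

```python
class DynDim:
    """Toy symbolic dimension: may be bound to only one runtime value."""

    def __init__(self, name):
        self.name = name
        self.value = None

    def bind(self, value):
        if self.value is not None and self.value != value:
            raise ValueError(
                f"{self.name}: already bound to {self.value}, got {value}"
            )
        self.value = value


def try_bind(pairs):
    """Bind each (dim, runtime size) pair and report the first conflict."""
    try:
        for dim, value in pairs:
            dim.bind(value)
        return "ok"
    except ValueError as exc:
        return f"conflict: {exc}"


# Hypothetical values; the real constants live in config.py.
batch = 2
HCA_STATE_BLOCK_NUM = 8    # physical blocks per request (assumed)
HCA_STATE_MAX_BLOCKS = 4   # logical table entries per request (assumed)

state_rows = batch * HCA_STATE_BLOCK_NUM    # packed dim 0 of hca_cmp_kv_state
table_rows = batch * HCA_STATE_MAX_BLOCKS   # packed dim 0 of the block table

shared = DynDim("BLOCKS_DYN")
shared_result = try_bind([(shared, state_rows), (shared, table_rows)])

blocks, table = DynDim("BLOCKS_DYN"), DynDim("TABLE_DYN")
separate_result = try_bind([(blocks, state_rows), (table, table_rows)])

print(shared_result)     # conflict: BLOCKS_DYN: already bound to 16, got 8
print(separate_result)   # ok
```

If the two constants happen to be equal, the shared dimension binds without conflict, so the mismatch only appears when HCA_STATE_BLOCK_NUM and HCA_STATE_MAX_BLOCKS differ.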

[PREFILL_HCA_STATE_BLOCKS_DYN, HCA_STATE_BLOCK_SIZE, HCA_MAIN_OUT_DIM],
pl.FP32,
],
hca_compress_state_block_table: pl.Tensor[[PREFILL_HCA_STATE_BLOCKS_DYN], pl.INT32],

high

Use the newly defined PREFILL_HCA_STATE_TABLE_DYN dynamic dimension for hca_compress_state_block_table to avoid shape mismatch with hca_cmp_kv_state.

Suggested change
- hca_compress_state_block_table: pl.Tensor[[PREFILL_HCA_STATE_BLOCKS_DYN], pl.INT32],
+ hca_compress_state_block_table: pl.Tensor[[PREFILL_HCA_STATE_TABLE_DYN], pl.INT32],

csa_cmp_norm_w: pl.Tensor[[HEAD_DIM], pl.BF16],
csa_cmp_kv_state: pl.Tensor[[PREFILL_CSA_STATE_BLOCKS_DYN, CSA_STATE_BLOCK_SIZE, CSA_MAIN_OUT_DIM], pl.FP32],
csa_cmp_score_state: pl.Tensor[[PREFILL_CSA_STATE_BLOCKS_DYN, CSA_STATE_BLOCK_SIZE, CSA_MAIN_OUT_DIM], pl.FP32],
csa_compress_state_block_table: pl.Tensor[[PREFILL_CSA_STATE_BLOCKS_DYN], pl.INT32],

high

Use the newly defined PREFILL_CSA_STATE_TABLE_DYN dynamic dimension for csa_compress_state_block_table to avoid shape mismatch with csa_cmp_kv_state.

Suggested change
- csa_compress_state_block_table: pl.Tensor[[PREFILL_CSA_STATE_BLOCKS_DYN], pl.INT32],
+ csa_compress_state_block_table: pl.Tensor[[PREFILL_CSA_STATE_TABLE_DYN], pl.INT32],

csa_inner_norm_w: pl.Tensor[[IDX_HEAD_DIM], pl.BF16],
csa_inner_kv_state: pl.Tensor[[PREFILL_INNER_STATE_BLOCKS_DYN, INNER_STATE_BLOCK_SIZE, INNER_OUT_DIM], pl.FP32],
csa_inner_score_state: pl.Tensor[[PREFILL_INNER_STATE_BLOCKS_DYN, INNER_STATE_BLOCK_SIZE, INNER_OUT_DIM], pl.FP32],
csa_inner_compress_state_block_table: pl.Tensor[[PREFILL_INNER_STATE_BLOCKS_DYN], pl.INT32],

high

Use the newly defined PREFILL_INNER_STATE_TABLE_DYN dynamic dimension for csa_inner_compress_state_block_table to avoid shape mismatch with csa_inner_kv_state.

Suggested change
- csa_inner_compress_state_block_table: pl.Tensor[[PREFILL_INNER_STATE_BLOCKS_DYN], pl.INT32],
+ csa_inner_compress_state_block_table: pl.Tensor[[PREFILL_INNER_STATE_TABLE_DYN], pl.INT32],

Comment on lines +278 to +284
hca_compress_state_block_table.bind_dynamic(0, PREFILL_HCA_STATE_BLOCKS_DYN)
csa_cmp_kv_state.bind_dynamic(0, PREFILL_CSA_STATE_BLOCKS_DYN)
csa_cmp_score_state.bind_dynamic(0, PREFILL_CSA_STATE_BLOCKS_DYN)
csa_compress_state_block_table.bind_dynamic(0, PREFILL_CSA_STATE_BLOCKS_DYN)
csa_inner_kv_state.bind_dynamic(0, PREFILL_INNER_STATE_BLOCKS_DYN)
csa_inner_score_state.bind_dynamic(0, PREFILL_INNER_STATE_BLOCKS_DYN)
csa_inner_compress_state_block_table.bind_dynamic(0, PREFILL_INNER_STATE_BLOCKS_DYN)

high

Update the dynamic bindings in prefill_layer_core to use the correct table dynamic dimensions.

Suggested change
- hca_compress_state_block_table.bind_dynamic(0, PREFILL_HCA_STATE_BLOCKS_DYN)
- csa_cmp_kv_state.bind_dynamic(0, PREFILL_CSA_STATE_BLOCKS_DYN)
- csa_cmp_score_state.bind_dynamic(0, PREFILL_CSA_STATE_BLOCKS_DYN)
- csa_compress_state_block_table.bind_dynamic(0, PREFILL_CSA_STATE_BLOCKS_DYN)
- csa_inner_kv_state.bind_dynamic(0, PREFILL_INNER_STATE_BLOCKS_DYN)
- csa_inner_score_state.bind_dynamic(0, PREFILL_INNER_STATE_BLOCKS_DYN)
- csa_inner_compress_state_block_table.bind_dynamic(0, PREFILL_INNER_STATE_BLOCKS_DYN)
+ hca_compress_state_block_table.bind_dynamic(0, PREFILL_HCA_STATE_TABLE_DYN)
+ csa_cmp_kv_state.bind_dynamic(0, PREFILL_CSA_STATE_BLOCKS_DYN)
+ csa_cmp_score_state.bind_dynamic(0, PREFILL_CSA_STATE_BLOCKS_DYN)
+ csa_compress_state_block_table.bind_dynamic(0, PREFILL_CSA_STATE_TABLE_DYN)
+ csa_inner_kv_state.bind_dynamic(0, PREFILL_INNER_STATE_BLOCKS_DYN)
+ csa_inner_score_state.bind_dynamic(0, PREFILL_INNER_STATE_BLOCKS_DYN)
+ csa_inner_compress_state_block_table.bind_dynamic(0, PREFILL_INNER_STATE_TABLE_DYN)

[N_RANKS, PREFILL_HCA_STATE_BLOCKS_DYN, HCA_STATE_BLOCK_SIZE, HCA_MAIN_OUT_DIM],
pl.FP32,
],
hca_compress_state_block_table: pl.Tensor[[N_RANKS, PREFILL_HCA_STATE_BLOCKS_DYN], pl.INT32],

high

Update the hca_compress_state_block_table parameter in l3_prefill_layer to use PREFILL_HCA_STATE_TABLE_DYN.

Suggested change
- hca_compress_state_block_table: pl.Tensor[[N_RANKS, PREFILL_HCA_STATE_BLOCKS_DYN], pl.INT32],
+ hca_compress_state_block_table: pl.Tensor[[N_RANKS, PREFILL_HCA_STATE_TABLE_DYN], pl.INT32],

csa_cmp_norm_w: pl.Tensor[[N_RANKS, HEAD_DIM], pl.BF16],
csa_cmp_kv_state: pl.Tensor[[N_RANKS, PREFILL_CSA_STATE_BLOCKS_DYN, CSA_STATE_BLOCK_SIZE, CSA_MAIN_OUT_DIM], pl.FP32],
csa_cmp_score_state: pl.Tensor[[N_RANKS, PREFILL_CSA_STATE_BLOCKS_DYN, CSA_STATE_BLOCK_SIZE, CSA_MAIN_OUT_DIM], pl.FP32],
csa_compress_state_block_table: pl.Tensor[[N_RANKS, PREFILL_CSA_STATE_BLOCKS_DYN], pl.INT32],

high

Update the csa_compress_state_block_table parameter in l3_prefill_layer to use PREFILL_CSA_STATE_TABLE_DYN.

Suggested change
- csa_compress_state_block_table: pl.Tensor[[N_RANKS, PREFILL_CSA_STATE_BLOCKS_DYN], pl.INT32],
+ csa_compress_state_block_table: pl.Tensor[[N_RANKS, PREFILL_CSA_STATE_TABLE_DYN], pl.INT32],

csa_inner_norm_w: pl.Tensor[[N_RANKS, IDX_HEAD_DIM], pl.BF16],
csa_inner_kv_state: pl.Tensor[[N_RANKS, PREFILL_INNER_STATE_BLOCKS_DYN, INNER_STATE_BLOCK_SIZE, INNER_OUT_DIM], pl.FP32],
csa_inner_score_state: pl.Tensor[[N_RANKS, PREFILL_INNER_STATE_BLOCKS_DYN, INNER_STATE_BLOCK_SIZE, INNER_OUT_DIM], pl.FP32],
csa_inner_compress_state_block_table: pl.Tensor[[N_RANKS, PREFILL_INNER_STATE_BLOCKS_DYN], pl.INT32],

high

Update the csa_inner_compress_state_block_table parameter in l3_prefill_layer to use PREFILL_INNER_STATE_TABLE_DYN.

Suggested change
- csa_inner_compress_state_block_table: pl.Tensor[[N_RANKS, PREFILL_INNER_STATE_BLOCKS_DYN], pl.INT32],
+ csa_inner_compress_state_block_table: pl.Tensor[[N_RANKS, PREFILL_INNER_STATE_TABLE_DYN], pl.INT32],
