Skip to content

chunked prefill#81

Open
rohansanda wants to merge 42 commits into
mainfrom
chunked-prefill
Open

chunked prefill#81
rohansanda wants to merge 42 commits into
mainfrom
chunked-prefill

Conversation

@rohansanda

Copy link
Copy Markdown
Collaborator

Summary

Adds chunked prefill to mminf with two opt-in modes:

  1. Engine-internal (scheduler_owns_chunking=false, default) —
    AREngine splits a long single-request prefill into N back-to-back
    forward passes of max_prefill_chunk_size tokens. Cuts TTFT tail
    without scheduler changes.
  2. Scheduler-driven (scheduler_owns_chunking=true) —
    MicroScheduler packs mixed batches of in-flight decodes plus prefill
    chunks under a per-step token budget (max_step_tokens), eliminating
    head-of-line blocking from long prefills. Routed through a new
    thinker_step graph walk; CUDA-graph captured.

Key design choices

  • Per-submodule opt-in via get_chunked_prefill_walks() on
    submodule_base — mirrors the existing can_use_cuda_graphs /
    get_cuda_graph_configs declaration pattern. Walk eligibility lives
    on the submodule, not hardcoded in the engine.
  • is_terminal_per_request: dict[str, bool] is the scheduler→engine
    contract: non-terminal rids skip lm_head + sampling. Empty dict ⇒
    all terminal, preserving existing single-walk behavior.
  • Atomic vs chunkable prefills: audio/vision are sentinel-wrapped
    and can't be sliced — scheduler tags them atomic and defers when they
    don't fit the remaining budget.
  • thinker_step emits __batched_logits__ at fixed (bs, V) shape
    regardless of terminal-flag distribution — required for CUDA-graph
    capture; per-rid sampling gating moves to the engine.
  • Default config ships gated (scheduler_owns_chunking: false):
    existing users get engine-internal chunking only; Phase 2 mixed
    batching enabled by flipping the YAML flag.

Performance

Measured on H200 with an offline harness driving the engine directly
(bypasses conductor/IPC). Workload: 4 in-flight decodes + 1 newly-admitted
4096-token prefill request.

Metric Phase 1 (serial prefill) Speedup
TTFT for the new request baseline 13.0×
Aggregate throughput baseline 4.82×

Production-stack benchmark via mminf-serve not included yet (will run when GPUs are free).

@rohansanda rohansanda requested review from NSagan271 and merceod May 2, 2026 07:08
rohansanda and others added 28 commits May 2, 2026 08:02
Stateless orchestrator that drives a single-request prefill as N
back-to-back forward passes via an injected inner_pass callable.
Composes _plan_chunks and _slice_ar_inputs; enforces single-request
constraint for v0. Includes InnerPass type alias and 5 new tests.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…wen3_omni

Adds an integration test that proves chunked prefill produces identical
final-position logits, sampled token, and KV cache contents as a single-
pass unchunked prefill on the qwen3_omni Thinker. Six parametrized cases
across (prompt_len, chunk_size) in {600, 1024, 2048} x {256, 512} all
match bit-exactly under bf16 (max_abs = 0.0 on logits and KV).

Test design notes:
- One AREngine + toggle ``max_prefill_chunk_size`` per call (vs. the plan's
  ``build_pair`` wording) — avoids loading the 30B Thinker twice.
- No CUDA graph capture: leaves ``cuda_graph_runner = None`` so both paths
  run through identical eager kernels in ``_execute_sequential``; the only
  difference being measured is whether the chunked orchestrator slices the
  prompt before dispatch.
- Captures pre-sample logits via a sampler patch (the engine deletes
  ``logits`` from the per-rid output dict after sampling).
- Greedy (``temperature=0``) so sampled-token equality is deterministic.

Slicer fix in ``chunked_prefill._slice_ar_inputs``:

The original logic hardcoded a ``[:, start:end, ...]`` rank for
``input_embeds`` (3D ``[bs, seq_len, hidden]``) and a ``[start:end]`` for
``custom_pos_ids`` (1D). qwen3_omni packs ``input_embeds`` as 2D
``[seq_len, hidden]`` (from ``embed_tokens(token_ids)``) and MRoPE
position IDs as ``[3, seq_len]`` (token axis = LAST), so the old slicer
raised ``IndexError`` and would have garbled the position grid even if
the rank had matched. Replaced with a generic helper that picks the
token axis by matching ``inp.input_seq_len`` against each tensor's
shape; preserves the existing unit-test contracts.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
The seq_lens-based heuristic (is_decode = all(sl == 1)) misfires for
chunked-prefill last chunks of 1 token, picking the FlashInfer decode
wrapper for what is logically still prefill. Add an explicit mode
parameter (default None for backward compat) and have qwen3_omni
Thinker's preprocess pass it based on graph_walk.

Fixes test_chunked_prefill_edge_cases for prompt_len % chunk_size == 1.
…edge cases

The N*chunk_size+1 cases hit FlashInfer's 1-token-prefill kernel path on
the last chunk, which has different bf16 accumulation order than the
unchunked full-sequence kernel. Determinism check confirmed
chunked-vs-chunked is bit-exact, so the divergence is kernel-tile-order
noise, not an algorithmic bug. Greedy sampled tokens match exactly across
all cases — that's the production-meaningful invariant.

The Task 8 happy-path test keeps the tight 1e-2 tolerance (bit-exact in
practice). This relaxation only affects test_chunked_prefill_edge_cases.
Add per-chunk NVTX range markers to execute_chunked_prefill gated by an
enable_nvtx kwarg (default False). The outer range names the rid/walk/total/
chunk count; each inner range names the chunk index, slice, and is_last flag.
Pass enable_nvtx=self.enable_nvtx from the ar_engine.py call site.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…smoke check

The smoke check is physics-aware: at 30B params batch=1, each forward
pass is HBM-bandwidth-bound at ~60ms regardless of token count, so
chunked single-request is fundamentally ~N× slower than unchunked. The
threshold accepts that inherent cost (n_chunks × 2 × unchunked + 200ms)
but catches catastrophic regressions. Phase 2 mixed-batch scheduling is
where the throughput win lives.
…ssInfo

Add prefill_tokens_total, prefill_tokens_consumed fields and is_prefill_complete
property to CurrentForwardPassInfo (Phase 2 Task 1). Defaults of (0, 0) preserve
all existing callers on the Phase 1 path.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…est on NodeBatch

Phase 2 Task 3: AREngine gains scheduler_owns_chunking (default False) which
short-circuits _should_chunk_prefill when the MicroScheduler owns orchestration.
NodeBatch gains is_terminal_per_request (default empty dict = all terminal) for
gating lm_head + sampling on non-terminal prefill chunks. EngineManager.build
threads scheduler_owns_chunking from model_config into AREngine kwargs.
…ler.py

MicroScheduler already absorbs alternative selection strategies
(_select_node_priority, _select_node_rr) inline as helpers; the
chunked-step packing logic is the same shape and should live alongside
them rather than in its own module. Matches the codebase's convention of
keeping scheduler logic in one file.
Phase 2 Task 4. Adds the unified thinker_step graph walk that handles
batches mixing prefill chunks (seq_len>=1) and decode tokens (seq_len=1)
across different requests in a single forward pass.

- qwen3_omni_model.py: declares thinker_step (single GraphNode mirroring
  prefill_text's wiring) and registers it in the Thinker partition.
- submodules.py: ThinkerSubmodule.preprocess routes thinker_step to
  mode="prefill" so FlashInfer's prefill wrapper handles arbitrary
  per-request seq_lens correctly. forward_batched gates lm_head
  per-request based on engine_inputs.is_terminal_per_request: terminal
  rids (decode token OR final prefill chunk) get logits and sample,
  non-terminal rids skip lm_head and emit no logits. can_batch /
  prepare_inputs extended to accept the walk.
- submodule_base.py: ModelInputsFromEngine carries
  is_terminal_per_request alongside the existing per-request info so
  forward_batched can read the gating flags without reaching back into
  NodeBatch.
- ar_engine.py: _execute_batched / _execute_sequential populate the
  new field from NodeBatch.is_terminal_per_request.

Defaults preserve backwards compat (empty dict -> "all terminal").
Phase 1 integration tests still pass (12 PASS, 1 SKIPPED).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…batch

When AREngine.scheduler_owns_chunking is True, the MicroScheduler packs
mixed batches of decodes + prefill chunks under max_step_tokens budget,
dispatched as the thinker_step graph walk. Phase 1 path preserved:
default scheduler_owns_chunking=False short-circuits to existing logic.

Worker bookkeeping:
  - At admission, prime prefill_tokens_total from text_inputs tensor dims
    when chunking is enabled.
  - After each step, advance prefill_tokens_consumed for prefill rids in
    the batch by the chunk size.
  - Propagate ScheduledBatch.is_terminal_per_request into NodeBatch so
    the AR engine + ThinkerSubmodule gate lm_head per-request.

ScheduledBatch grew is_terminal_per_request and prefill_chunk_sizes; both
default to None so legacy batches are unchanged.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Slices prefill rids' input tensors by [consumed : consumed + chunk_size]
in _build_node_batch when batch.prefill_chunk_sizes is set, completing
the Phase 2 chunked-prefill path. Validated end-to-end via mixed-batch
test on real qwen3_omni weights.
Adds perf_testing/chunked_prefill_throughput.py — a direct-engine harness
that compares Phase 1 (scheduler_owns_chunking=False) and Phase 2
(scheduler_owns_chunking=True) on a concurrent mixed workload: 4 in-flight
decode requests + a 5th request with a 4096-token prefill admitted mid-run.

Implementation uses the "alternative simplification" path from the plan:
hand-build NodeBatch objects directly against AREngine instead of spinning
up the worker / conductor / IPC machinery. Phase 1 runs prefill_text +
multi-rid thinker_decode batches separately; Phase 2 packs decodes +
prefill chunks into a single thinker_step batch per scheduling step,
mirroring MicroScheduler._get_chunked_step_batch.

Captures all 4 spec metrics: TTFT, p50/p99 inter-token latency during the
prefill window, and total throughput.

Reported numbers (Qwen3-Omni Thinker, eager mode, no CUDA graphs):
  TTFT:        Phase1=557.6ms  Phase2=232.7ms  speedup=2.40x  (target >=3.0x)
  Throughput:  Phase1=58.77    Phase2=58.78    speedup=1.00x  (target >=1.20x)
  p50 ITL:     baseline=68.22  in_window=80.27 ratio=1.18x    (target <=1.10x)
  p99/p50:     1.02x                                          (target <=2.50x  PASS)

Three of four success criteria miss their targets: TTFT win is real
(2.4x) but below 3x; throughput is flat because the prefill window is
small relative to total wall clock (~560ms vs 14s); p50 in-window
inter-token latency is 1.18x baseline (the mixed batches do cost more
per step than decode-only batches since they carry more tokens).

p99/p50 is 1.02x — Phase 2 keeps tail latency stable, which is the
correct qualitative behavior. The TTFT speedup matters most for user-
visible latency under load, even at 2.4x.

The harness is checked in regardless: it is reusable infrastructure for
tuning chunk size / max_step_tokens / decode pool size to actually hit
the 3x and 1.20x targets, and for catching regressions in the
mixed-batch path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds scheduler_owns_chunking (default false; opt-in for Phase 2 mixed-batch
scheduling) and max_step_tokens (default 2048) to the qwen3_omni config.
Default off because Phase 2's measured benefit on the validation workload
is workload-sensitive (see perf_testing/chunked_prefill_throughput.py for
the experimental harness). Users opt in when their workload profile shows
benefit.
… output shape

CUDA graph capture requires fixed output dict shape regardless of input
terminal-flag distribution. Move per-rid lm_head gating from the submodule
to the engine's batched-logits sampling fast path. Per-rid output dicts
now contain only thinker_states; the batched logits + per-rid new_token
assignment + non-terminal filtering happens in AREngine._execute_batched
(next commit). test_mixed_batch_correctness expected to fail between
this commit and the next — both must land together.
…ing fast path

Non-terminal prefill chunks in Phase 2 thinker_step batches now correctly
skip new_token assignment. Terminal rids (decodes + last-chunk prefill)
sample as before. Default empty is_terminal_per_request → all terminal,
preserving Phase 1 / single-walk behavior. Restores
test_mixed_batch_correctness which was expected to fail after Task 1's
output-shape refactor.

test_mixed_batch_correctness logits-extraction logic updated for the
new batched-sampling semantics: the sampler now receives a (bs, V) tensor
for every batch (not per-rid), so the test indexes the row matching
rid_decode rather than flattening the full captured tensor.

Also formats the import blocks added by Task 1's two new scheduler tests.
Add thinker_step to replay_graph_walks of the existing prefill_text
FlashInferPackedCudaGraphConfig. The runner replans attention/RoPE per
walk at replay, so thinker_step's mixed seq_lens feed into the planner
the same way prefill_text's prompt does. Closes the 1.18× p50 latency
gap from Phase 2 Task 7.
… eager

Phase 2.1a Task 4 — load-bearing correctness check that the captured
prefill_text graph (which Task 3 added thinker_step to its replay walks)
produces the same outputs as the eager path on a mixed thinker_step batch
(1 decode rid + 1 non-terminal prefill chunk rid).

Single-engine + runner toggle approach (vs two engines) keeps memory
+ warmup time bounded: warmup once, then for each pass either keep
submod_mgmt.cuda_graph_runner populated (graphs ON) or set to None (eager).

Tolerances match the regime documented in test_prefill_cuda_graph and
test_chunked_prefill_edge_cases: lm_head matmul amplifies bf16 hidden-
state deltas across a 150k vocab, so direct logits use the loose
atol=0.5/rtol=5e-2 boundary and decode argmax is validated via top-5
agreement (random in 150k = ~3e-5).

Also asserts the engine's terminal-flag gating is preserved on the
captured-graph path: decode rid emits new_token; non-terminal prefill
rid emits neither new_token nor logits.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
… - see body)

Extends Task 7's harness to a 3-way comparison: Phase 1 vs Phase 2 eager vs
Phase 2 + CUDA graphs. Toggle pattern: one engine warmup, swap
cuda_graph_runner to None for the eager runs, restore for the graphs run.

Measured numbers (CUDA_VISIBLE_DEVICES=3, qwen3_omni Thinker, default
workload params):

  === Phase 1 (engine-internal chunking, eager) ===
    TTFT (request 5):              1258.7ms
    decode p50 during prefill:     nan ms  (decodes blocked entirely)
    decode baseline p50:           93.64ms
    total throughput:              38.3 tok/s

  === Phase 2 eager (scheduler-aware, no CUDA graphs) ===
    TTFT (request 5):              712.6ms      (1.77x vs P1)
    decode p50 during prefill:     301.02ms     (3.16x baseline)
    decode p99 during prefill:     308.16ms
    decode baseline p50:           95.20ms
    total throughput:              41.3 tok/s

  === Phase 2 + CUDA graphs ===
    TTFT (request 5):              655.9ms      (1.92x vs P1)
    decode p50 during prefill:     276.40ms     (2.90x baseline)
    decode p99 during prefill:     282.34ms
    p50 vs P2 eager:               0.92x        (graphs help, but only 8%)
    total throughput:              41.7 tok/s

Assertion FAILED: p50 still regressed >10% vs baseline even with graphs
(95.39ms baseline vs 276.40ms in-window).

Diagnosis: PREFILL_CAPTURE_BATCH_SIZES=[1,2,4] but the perf workload has
bs=5 (4 decodes + 1 prefill chunk). Captured graphs don't fire for bs=5
mixed steps, so eager path runs during the prefill window. The 8%
graphs-vs-eager improvement comes only from the post-prefill decodes-only
steps that hit the bs=4 captured graph.

Additionally: with chunk_size=2044 (=MAX_STEP_TOKENS-4), each mixed step
processes 2048 tokens through 30B params - that's compute-dominated
(~280ms), not HBM-bandwidth-bound (~60ms). The 3.16x p50 regression
reflects this real cost, not a CUDA-graph deficiency.

Follow-up: extend PREFILL_CAPTURE_BATCH_SIZES to include 8 so bs=5 rounds
up to a captured bucket. Also consider a smaller-chunk workload variant
to demonstrate the regime where Phase 2 wins on all 4 metrics.
…tes non-terminal rids

ROOT CAUSE for Phase 2.1a's missing speedup. Diagnostic instrumentation
revealed all 398 thinker_step calls in the perf harness were rejected by
NodeSubmodule.can_use_cuda_graphs with reason=submodule_rejected — even
though Task 3 added "thinker_step" to the prefill_text capture's
replay_graph_walks list.

Cause: the default can_use_cuda_graphs only checks cfg.capture_graph_walk,
not cfg.replay_graph_walks. So walks aliased onto an existing capture
(prefill_audio, thinker_step) were silently rejected. This also means
prefill_audio was never using captured graphs in production despite the
existing replay alias claiming to enable it — a latent bug.

Fix:
1. NodeSubmodule.can_use_cuda_graphs now collects walks from BOTH
   capture_graph_walk AND replay_graph_walks.
2. Engine threads is_terminal_per_request through CudaGraphRunner.run →
   _sample_and_remap, which now skips new_token assignment for
   non-terminal prefill chunks (mirroring Task 2's _execute_batched fix).

Measured impact on the Phase 2 Task 7 perf harness (qwen3_omni Thinker,
GPU 1, 4 ongoing decodes + 1 mid-stream 4096-token prefill request):

  === Before this commit (graphs never fire for thinker_step) ===
    decode baseline p50:   95.20ms   (eager bs=4 thinker_step)
    in-window p50:         301.02ms  (eager bs=5 thinker_step)
    total throughput:      41.3 tok/s

  === After this commit (graphs fire for bs ∈ captures) ===
    decode baseline p50:   21.82ms   (4.4× faster — captured bs=4 graph)
    in-window p50:         281.08ms  (still eager — bs=5 not captured)
    total throughput:      125.0 tok/s   (3.0× improvement)

The 3× throughput win comes from the post-prefill decodes-only steps
(bs=4) which now use captured graphs. The in-window mixed steps (bs=5)
still fall through to eager because PREFILL_CAPTURE_BATCH_SIZES = [1, 2, 4]
doesn't include 5+. Capturing bs=8 would close that gap further but
was tested and showed marginal additional improvement (the in-window
2048-token forward is compute-dominated for 30B params).

Validated:
- 57/57 modular chunked-prefill tests pass.
- 17/17 + 1 skip integration tests pass (including the new
  test_chunked_prefill_cuda_graph.py equivalence test, which now
  ACTUALLY exercises the captured-graph path; previously it was
  comparing eager-vs-eager because graphs never fired).
- Phase 1 numerical equivalence on real qwen3_omni weights unchanged.
rohansanda and others added 11 commits May 2, 2026 08:02
Phase 2.1a stretch goal. The Phase 2 perf workload has bs=5 (4 ongoing
decodes + 1 prefill chunk in a thinker_step batch). With
PREFILL_CAPTURE_BATCH_SIZES=[1,2,4], _get_padded_batch_size returned None
for bs=5 — graphs fell through to eager during the prefill window.

Adding 8 to the capture list lets bs=5 round up to bs=8 and fire the
captured graph.

This change ALONE was tested earlier (commit be82754 era) and showed
marginal improvement, because the can_use_cuda_graphs replay-walk bug
(fixed in a5b1229) was rejecting graphs at a higher layer regardless of
bucket coverage. Post-fix, bs=8 unlocks the in-window speedup.

Measured impact (Phase 2 Task 7 workload, qwen3_omni Thinker, GPU 4):

  === Before bs=8 (after can_use fix) ===
    TTFT:                  692.7ms        (1.89× vs Phase 1)
    in-window p50:         281.08ms       (graphs fall through; bs=5 not captured)
    decode baseline p50:   21.82ms        (graphs fire; bs=4 captured)
    total throughput:      125.0 tok/s    (3.0× vs Phase 1)

  === After bs=8 ===
    TTFT:                  151.3ms        (8.48× vs Phase 1)
    in-window p50:         60.00ms        (4.7× faster — graphs fire for bs=5→8 padded)
    decode baseline p50:   21.60ms        (unchanged)
    total throughput:      181.3 tok/s    (4.9× vs Phase 1)

p99/p50 ratio: 1.04× (rock solid).
p50 in-window vs P2 eager: 0.20× (5× faster).

Validated: 17/17 + 1 skip integration tests pass; 57/57 modular tests pass.

Cost: one additional captured graph per (bs, num_tokens) ∈
{(8, n) for n in PREFILL_TOKEN_BUCKETS}. Each capture allocates persistent
FlashInfer wrappers + static buffers for the full 30B Thinker. The
capture batch sizes docstring already calls out this trade-off; with
bs=8 included, warmup time grows by ~25% but the runtime win is decisive.
…ixed batches

Phase 2.1b. Atomic audio/vision prefill rids can now participate in
thinker_step mixed batches alongside text-prefill chunks and decode tokens.
The Thinker's prepare_inputs dispatches by per-rid input keys when in
thinker_step mode (audio_embeds -> audio path, vision_embeds -> vision path,
else text). No chunking of audio/vision (their start/end sentinel
wrappers prevent it); they're treated as atomic terminal chunks.

Refactor: per-modality prep extracted into _prepare_decode_input,
_prepare_text_input, _prepare_audio_input, _prepare_vision_input helpers.
The existing prefill_text / prefill_audio / prefill_vision walks call the
same helpers as the new thinker_step dispatch; behavior is byte-equivalent
for those walks (verified by all 17 chunked-prefill integration tests
passing unchanged, including the prefill_text equivalence sweep).

Tests:
- test_thinker_step_dispatches_to_audio_path_on_audio_embeds (source-level
  smoke check that the dispatch logic references both audio_embeds and
  vision_embeds while preserving the existing graph_walk branches).
- test_thinker_step_handles_audio_rid_in_mixed_batch (behavioral, qwen3_omni
  weights). Mixed batch: 1 decode rid + 1 atomic audio rid (synthesized
  audio_embeds bypassing the AudioEncoder); compares the audio rid's
  logits row to a single-rid prefill_audio baseline. Bit-exact match
  observed (max_abs=0.0e+00) on the qwen3_omni 30B Thinker, well inside
  the bf16 tolerance band used elsewhere.

Scheduler integration for classifying audio/vision-prefill-ready rids
(MicroScheduler._get_chunked_step_batch in mminf/worker/micro_scheduler.py)
is out of scope for this commit; the test exercises the model-side
dispatch via direct engine.execute_batch calls.

CUDA graph compatibility: multimodal-mixed thinker_step batches are
expected to fall through to eager (the captured prefill_text graph
expects text-prefill-shaped per-token embeddings + 3D MRoPE; audio/
vision rids carry different per-token embedding values and modality-
specific position IDs, so the captured kernels don't see the same
input distribution they were captured against). Phase 2.1a's CUDA
graph perf for text-only thinker_step batches is preserved.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…e stale TODO

I1: Replace `None` defaults on `ScheduledBatch.is_terminal_per_request` and
`prefill_chunk_sizes` with `field(default_factory=dict)`, matching the style
of `ChunkedStepPlan` and `NodeBatch.is_terminal_per_request`. Update the
backwards-compat assertion in the scheduler test from `is None` to `== {}`.

I4: Delete the stale "TODO(Phase 2 Task 8)" comment in MicroScheduler.__init__
(Task 8 is done — max_step_tokens is wired from YAML via Worker.__init__).
Replace with a one-liner explaining the worker-side wiring.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…plicit rules

_slice_prompt_chunk (worker.py): hard-code text_inputs to dim-0 slice; pass
all other keys through unchanged with a comment explaining that non-token
tensors are handled by engine-side _slice_ar_inputs after prepare_inputs.

_slice_ar_inputs (chunked_prefill.py): replace the fully dynamic fallback
for input_ids/input_embeds with:
- input_ids: explicit (batch, seq) → slice dim 1.
- input_embeds: dynamic axis detection retained (shape varies across models)
  but now asserts the axis is found instead of silently returning token_axis=-1.
- custom_pos_ids: retain existing fallback; add assert on the found axis.

These changes make shape failures loud (assertion at the slicing site) rather
than producing subtly wrong outputs when no axis matches input_seq_len.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…ts__ sentinel

M3: Trim _prepare_text_input in qwen3_omni/submodules.py from a 10-line
comment explaining the decode-case edge case down to a 2-line description of
what the function does. Matches the focused style of pi05/submodules.py.

M5: In _sample_decode_outputs (ar_engine.py), add an explicit pop of the
__batched_logits__ sentinel key at the top of the function. This is called
only from _execute_sequential (which never emits the sentinel) so the check
never fires in practice, but popping explicitly makes the function robust
under future refactors and removes the isinstance(tensors, dict) polymorphism
check that was guarding against the sentinel.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Adds test_prefill_audio_with_cuda_graph_matches_eager to
test/integration/test_chunked_prefill_cuda_graph.py.

The Phase 2.1a fix to can_use_cuda_graphs enabled CUDA graph replay for
prefill_audio (which shares the prefill_text captured graph via
replay_graph_walks). This test is the numerical load-bearing check that was
previously missing.

Test builds a prefill_audio batch with a synthesized random audio_embeds
tensor (audio_len=60 → seq_len=62 → pads to bucket 128), runs the graph
path, toggles cuda_graph_runner to None for the eager path, and verifies
mutual top-5 argmax agreement. Uses top-K rather than direct assert_close
because random BF16 audio_embeds (not real encoder outputs) produce slightly
larger lm_head delta noise than real token embeddings — the same regime as
test_prefill_cuda_graph's top-K rationale, and the important invariant is
that both paths predict from the same distribution, not that they are
bitwise-close. Both paths sample token 151645 on the reference hardware run.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Phase 1's _should_chunk_prefill didn't filter by walk type — only by
submodule opt-in. With qwen3_omni Thinker's supports_chunked_prefill=True
and a long-enough single-rid audio prefill batch (audio_len > chunk_size,
seq_len = audio_len + 2 from sentinel wrappers), Phase 1 would attempt to
slice the wrapped audio embeds along the token axis, breaking the
sentinel invariant.

Audio/vision prefills are atomic by design — _wrap_audio_input /
_wrap_vision_input add start/end markers that the model relies on to
detect modality block boundaries. Slicing through them would corrupt the
prefill output.

Add an explicit walk filter: chunking only fires for prefill_text. Other
walks (prefill_audio, prefill_vision, thinker_decode, thinker_step)
return False. Phase 1.3 follow-up if any future walk wants in.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…mixed batches

Scheduler-side completion of Phase 2.1b. The model side already supported
audio/vision rids in thinker_step batches (verified bit-exact via direct
engine.execute_batch). This commit wires up the scheduler so audio/vision
prefill requests automatically get packed alongside text decodes when
scheduler_owns_chunking=True.

Three coordinated changes:
1. Worker._add_new_request now sets prefill_tokens_total for audio_embeds
   and vision_embeds initial inputs (= embed_len + 2 to account for the
   start/end sentinel tokens added by the Thinker's _wrap_*_input helpers).
2. PrefillReadyRequest gains an atomic: bool = False field.
   plan_chunked_step skips atomic prefills whose tokens_remaining > budget
   instead of partial-chunking them (which would break the wrappers).
3. _get_chunked_step_batch marks rids whose ready GraphNode walk is
   prefill_audio or prefill_vision as atomic.

Net effect: when scheduler_owns_chunking=True, an audio request admitted
to the worker is treated by the scheduler as a single atomic prefill
chunk. The mixed-batch packer routes it into a thinker_step batch
alongside concurrent decodes (all-or-nothing — if budget can't fit
audio_len + 2 tokens, the audio rid is deferred). After the mixed step
runs, the audio rid transitions to thinker_decode like text prefills do.

Phase 1's chunked path is unchanged. With scheduler_owns_chunking=False
(default), audio/vision continue to use their existing single-walk
batches via the legacy _select_node_priority path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…le, drop history comments

Style/architecture pass on the chunked-prefill series before review:

- Replace boolean supports_chunked_prefill() with per-walk
  get_chunked_prefill_walks() on submodule_base — mirrors the existing
  can_use_cuda_graphs/get_cuda_graph_configs declaration pattern. Removes
  the hardcoded `if batch.graph_walk != "prefill_text"` special case from
  AREngine._should_chunk_prefill; walk eligibility now lives on the
  submodule that knows it.
- Remove redundant __batched_logits__ defenses in
  AREngine._sample_decode_outputs. The sentinel is provably absent at
  every call site (popped two stack frames up in _execute_batched, never
  used as a top-level key in _execute_sequential).
- Fold mminf/engine/chunked_prefill.py into mminf/engine/ar_engine.py as
  a clearly-labeled section; the orchestrator was 195 lines of pure
  helpers used only by AREngine. Tests updated accordingly.
- Remove perf_testing/chunked_prefill_{smoke,throughput}.py — one-off
  measurement harnesses, not steady-state infra.
- Drop Phase/Tier/Task references from new comments across submodules.py,
  worker.py, micro_scheduler.py — these belong in PR descriptions and
  rot as the codebase evolves.
- Clean stale TODO referencing surfaced max_step_tokens YAML config (now
  in configs/qwen3omni.yaml).
- Delete dead pytest.skip placeholder for walk-level gating (now
  implemented and covered by test_should_chunk_prefill_respects_submodule_walk_declaration).

All 64 chunked-prefill modular tests pass after the refactor.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- mminf/engine/cache_manager.py:157 — drop trailing space in plan_attention docstring
- test/modular/test_chunked_prefill_{executor,scheduler}.py — ruff --fix sorted imports

64 chunked-prefill modular tests still pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Main's piecewise-cuda-graph and first_request_info commits brought in
- I001 unsorted imports in AREngine.warmup's lazy CudaGraphRunner block
- W293 trailing whitespace on a blank line in ModelInputsFromEngine

Both auto-fixed by ruff --fix; net change is mechanical formatting.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
rohansanda and others added 3 commits May 2, 2026 23:37
Two complementary bugs prevented Phase 2 (scheduler-driven mixed batching)
from working end-to-end on the production stack. Synthetic engine harness
bypassed both because it drives execute_batch directly without
conductor/worker handshake.

== Bug 1: Output routing filtered to wrong worker_graph ==

The chunked-prefill scheduler relabels batch.graph_walk = "thinker_step"
for the engine's benefit (mode selection in plan_attention, prepare_inputs
branching, CUDA graph replay key). But the worker has SEPARATE
worker_graphs per walk: one each for prefill_text, prefill_audio,
prefill_vision, thinker_decode, AND a dedicated thinker_step worker_graph.

When process_node_outputs filtered worker_graphs by graph_walk, the
relabeled "thinker_step" matched the dedicated thinker_step worker_graph
(an empty queue with its own un-run GraphNode in waiting), NOT the
actual prefill_text worker_graph whose GraphNode was popped by the
scheduler. Result: prefill_text's worker_graph never marked done,
WORKER_GRAPHS_DONE never sent to conductor, state machine never advanced,
client SSE response hung until aiohttp TransferEncodingError.

Fix: add worker_graph_id_hint to WorkerGraphsManager.process_node_outputs.
The chunked path already populates ScheduledBatch.request_to_worker_graph
with the actual id. Worker passes it as the hint, bypassing the
graph_walk filter.

== Bug 2: Non-terminal rids' GraphNodes never re-queued ==

Independent issue surfaced during the same investigation. In text_to_text
mode the Thinker postprocess drops thinker_states/thinker_mask
(audio_output=False) and skips text_inputs assignment when new_token is
absent. Non-terminal chunked-prefill rids end up with empty per-rid
output dicts, which made _store_outputs_and_finish_loops early-exit
without re-queueing the popped GraphNode — the rid's ready queue went
permanently empty.

Fix: when the rid is non-terminal in is_terminal_per_request, push the
popped node back onto the ready queue so the next chunk can run on it.
Empty is_terminal_per_request dict (legacy path) preserves prior behavior.

== Verification ==

* End-to-end on Qwen3-Omni production stack:
  - Phase 1 baseline (scheduler_owns_chunking=false): 12/12 succeed,
    14.6s wall, TTFT p50 356ms, throughput 233 tok/s
  - Phase 2 (scheduler_owns_chunking=true) BEFORE fix: 0/12 succeed,
    aiohttp TransferEncodingError after client timeout
  - Phase 2 AFTER fix: 12/12 succeed, 25.0s wall, TTFT p50 117ms
    (3.6x faster), throughput 137 tok/s
* New regression test (test_chunked_prefill_worker_queue.py) covers Bug 2.
* All 67 chunked-prefill modular tests pass.

Caveat: Phase 2 per-token throughput is currently slower than Phase 1
(137 vs 233 tok/s, 1.7x) because thinker_step CUDA graph captures are
prefill-shaped (num_tokens 128 to 2048); pure-decode batches
(num_tokens=bs*1) miss the captured shapes and fall through to eager.
Adding decode-shaped graph captures for thinker_step is a separate
optimization.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…walk

Phase 2 mixed batching wins on TTFT (5x) but loses on per-token decode
throughput because thinker_step's CUDA graph captures are prefill-shaped
(num_tokens in {128,256,512,1024,2048}) and pure-decode batches at
num_tokens=bs*1 miss every captured shape, falling through to eager
(~2x per-token slowdown).

Fix: when the chunked-step plan has decodes but no prefill chunks,
return the batch with graph_walk="thinker_decode" instead of
"thinker_step". The dedicated decode captures (bs, num_tokens=bs)
fire normally; the engine's plan_attention picks mode="decode"; the
submodule's prepare_inputs uses _prepare_decode_input. Mixed batches
(decodes + prefill chunks) keep the thinker_step path where Phase 2's
mixed-batch packing actually pays off.

== Verification ==

End-to-end on Qwen3-Omni production stack (12 long-prompt requests,
concurrency 4):

| Metric        | Phase 1 | Phase 2 (prior) | Phase 2 + Path A |
|---------------|---------|-----------------|------------------|
| Succeeded     | 12/12   | 12/12           | 12/12            |
| Wall time     | 14.6s   | 25.0s           | 9.2s             |
| TTFT mean     | 444 ms  | 123 ms          | 87 ms            |
| TTFT p50      | 356 ms  | 117 ms          | 82 ms            |
| ITL p50       | 11 ms   | 25 ms           | 11 ms            |
| Throughput    | 233 t/s | 137 t/s         | 291 t/s          |

vs Phase 1: 5.1x faster TTFT, 1.25x faster throughput, 1.6x faster wall.

ITL parity (11 ms) confirms pure-decode batches now hit the captured
graphs. TTFT improvement (5x) confirms mixed-batch packing still works
when there's actually a prefill chunk to interleave.

All 67 chunked-prefill modular tests pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Removes 8 test files (4 modular + 4 integration) from the PR to keep
the diff focused on the actual code changes. Files are preserved at
/tmp/chunked_prefill_tests_backup/ for local development; they can be
re-added later if the team wants test coverage in-tree.

Removed:
  test/modular/test_chunked_prefill_unit.py
  test/modular/test_chunked_prefill_executor.py
  test/modular/test_chunked_prefill_scheduler.py
  test/modular/test_chunked_prefill_worker_queue.py
  test/integration/test_chunked_prefill_cuda_graph.py
  test/integration/test_chunked_prefill_equivalence.py
  test/integration/test_mixed_batch_correctness.py
  test/integration/test_thinker_step_multimodal.py

The end-to-end perf benchmark on Qwen3-Omni production stack remains
the primary validation: 12/12 succeed, 5x TTFT improvement, 1.25x
throughput vs Phase 1.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Comment thread mminf/engine/ar_engine.py
input_ids = None

if inp.input_embeds is not None:
seq_axis = next(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: if the input embeds ever have a batch dimension that has the same size as the sequence length this fails. This is very unlikely because the sequence length for chunked prefill is longer than the max batch size you'd expect, so this is just a nitpick. I guess we can always specify seq_axis in ARNodeInputs.

Also, we might want to have the same seq_axis logic for the input ids above.

# prefill chunks) keep the ``thinker_step`` walk, which is where
# Phase 2's mixed-batch packing actually pays off.
is_pure_decode = bool(plan.decode_rids) and not plan.prefill_allocations
batch_graph_walk = "thinker_decode" if is_pure_decode else "thinker_step"

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer to have the model declare the decode and chunked prefill walks instead of this string name matching; it makes it more generalizable to, e.g., having chunked prefill on both the thinker and the talker.

# Audio/vision prefills can't be chunked safely (sentinel-wrapped
# blocks). Mark them atomic so the planner skips them when budget
# is too small instead of partial-chunking.
atomic = walk in ("prefill_audio", "prefill_vision")

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the model/submodule declare graph walks that are chunkable? If so, we should use that here instead of hardcoding graph walk names. Otherwise, we should have the submodule declare something to that extent.

Comment thread mminf/worker/worker.py
ar_engine is not None
and getattr(ar_engine, "scheduler_owns_chunking", False)
):
for edge in body.initial_inputs:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logic works for the qwen talker but will break for any other model. If it doesn't take too long, maybe we can restructure things to use the already-existing submodule.prepare_inputs, which provides the sequence lengths baked-in. We would just have to make sure (1) it doesn't interfere with the async worker (I don't think it should but we should probably test that it doesn't slow anything down), and (2) we don't double-call prepare_inputs

Comment thread mminf/worker/worker.py
if not isinstance(t, torch.Tensor):
new_list.append(t)
continue
if name == "text_inputs":

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this works for all models so far but potentially not future-proof (e.g., there is a text input with a different name, or different sequence dimension). Is it possible to make all of this logic on the engine-side _slice_ar_inputs?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants