diff --git a/configs/qwen3omni.yaml b/configs/qwen3omni.yaml index f18c645e..d2f07f43 100644 --- a/configs/qwen3omni.yaml +++ b/configs/qwen3omni.yaml @@ -1,5 +1,15 @@ model: "qwen3_omni" max_seq_len: 32768 +# Engine: chunked prefill. Splits long prefills into 512-token chunks. +# Set to null (or remove) to disable. Only applies to qwen3_omni Thinker +# (the LLM submodule) — other submodules opt in individually. +max_prefill_chunk_size: 512 +# Phase 2: scheduler-driven chunked prefill. When true, the MicroScheduler +# packs mixed batches (decodes + prefill chunks across requests) up to +# max_step_tokens. When false (default), the engine handles single-request +# chunking internally (Phase 1). +scheduler_owns_chunking: false +max_step_tokens: 2048 node_groups: - node_names: [audio_encoder, vision_encoder, Code2Wav] ranks: [0] diff --git a/mminf/conductor/request_info.py b/mminf/conductor/request_info.py index f8a57e4c..b65ac39a 100644 --- a/mminf/conductor/request_info.py +++ b/mminf/conductor/request_info.py @@ -77,6 +77,18 @@ class CurrentForwardPassInfo: loop_stop_times: dict[str, IterIndexTree] = field(default_factory=dict) dynamic_loop_iter_counts: dict[str, int] = field(default_factory=dict) + # chunked prefill progress. + # Set at request admission; advanced by the MicroScheduler each step + # as chunks complete. Derived `is_prefill_complete` gates the + # prefill→decode transition. Default values (0, 0) mean a request not + # in chunked-prefill mode. + prefill_tokens_total: int = 0 + prefill_tokens_consumed: int = 0 + + @property + def is_prefill_complete(self) -> bool: + return self.prefill_tokens_consumed >= self.prefill_tokens_total + def register_loop_stop(self, loop_name: str): self.dynamic_loop_stop_signals.add(loop_name) diff --git a/mminf/engine/ar_engine.py b/mminf/engine/ar_engine.py index 5cc65d8a..ac56eb8d 100644 --- a/mminf/engine/ar_engine.py +++ b/mminf/engine/ar_engine.py @@ -1,5 +1,6 @@ import logging from dataclasses import asdict, dataclass, field +from typing import Callable import torch @@ -34,6 +35,173 @@ class SubmoduleManagement: cuda_graph_runner: CudaGraphRunner | None = None +# ---------------------------------------------------------------------- +# Chunked-prefill orchestrator. +# +# Splits a single-request prefill batch into back-to-back forward passes +# of ``chunk_size`` tokens. The paged KV cache carries state across chunks +# via ``plan_attention(seq_lens=...)`` — no cache-side changes needed. +# Pure orchestration: stateless, depends only on ARNodeInputs and a +# caller-supplied ``inner_pass`` callable. +# ---------------------------------------------------------------------- + + +@dataclass(frozen=True) +class ChunkSlice: + """One chunk of a single-request prefill, in token-axis coordinates.""" + index: int + start: int + end: int + is_last: bool + + +def _plan_chunks(seq_len: int, chunk_size: int) -> list[ChunkSlice]: + """Cover [0, seq_len) at ``chunk_size`` granularity. Last chunk may be shorter.""" + if seq_len <= 0: + raise ValueError(f"seq_len must be positive, got {seq_len}") + if chunk_size <= 0: + raise ValueError(f"chunk_size must be positive, got {chunk_size}") + + plans: list[ChunkSlice] = [] + n_chunks = (seq_len + chunk_size - 1) // chunk_size + for i in range(n_chunks): + start = i * chunk_size + end = min(start + chunk_size, seq_len) + plans.append( + ChunkSlice(index=i, start=start, end=end, is_last=(i == n_chunks - 1)) + ) + return plans + + +def _slice_ar_inputs(inp: ARNodeInputs, start: int, end: int) -> ARNodeInputs: + """Return a new ARNodeInputs covering token range [start, end). + + Slices token-axis tensors (input_ids, input_embeds, custom_pos_ids). + tensor_inputs and kwargs are passed through by reference — they hold + non-token-axis state (e.g. flags) that the chunked path must not mutate. + + Per-tensor token-axis convention: + - ``input_ids``: shape ``(batch, seq)`` — slice dim 1. + - ``input_embeds``: shape varies by model (``[seq_len, hidden]`` for + qwen3_omni, ``[bs, seq_len, hidden]`` for others) — locate the seq + axis by matching ``inp.input_seq_len``; assert it is found. + - ``custom_pos_ids``: ``inp.input_seq_len`` lives on whichever axis + matches its size. qwen3_omni packs MRoPE as ``[3, seq_len]`` so + the token axis is the LAST one; plain text models use 1D. + """ + chunk_len = end - start + seq_len = inp.input_seq_len + + if inp.input_ids is not None: + input_ids = inp.input_ids[:, start:end] + else: + input_ids = None + + if inp.input_embeds is not None: + seq_axis = next( + (d for d in range(inp.input_embeds.dim()) if inp.input_embeds.shape[d] == seq_len), + None, + ) + assert seq_axis is not None, ( + f"input_embeds shape {tuple(inp.input_embeds.shape)} has no axis " + f"matching input_seq_len={seq_len}" + ) + input_embeds = inp.input_embeds.narrow(seq_axis, start, chunk_len) + else: + input_embeds = None + + def _slice_token(t: torch.Tensor) -> torch.Tensor: + token_axis = next( + (dim for dim in range(t.dim()) if t.shape[dim] == seq_len), + None, + ) + assert token_axis is not None, ( + f"tensor shape {tuple(t.shape)} has no axis matching input_seq_len={seq_len}" + ) + return t.narrow(token_axis, start, chunk_len) + + custom_pos_ids = inp.custom_pos_ids + if isinstance(custom_pos_ids, torch.Tensor): + custom_pos_ids = _slice_token(custom_pos_ids) + elif isinstance(custom_pos_ids, dict): + custom_pos_ids = {k: _slice_token(v) for k, v in custom_pos_ids.items()} + + return ARNodeInputs( + input_seq_len=chunk_len, + input_ids=input_ids, + input_embeds=input_embeds, + custom_pos_ids=custom_pos_ids, + # Aliased (not cloned): downstream must not mutate. + tensor_inputs=inp.tensor_inputs, + kwargs=inp.kwargs, + ) + + +def execute_chunked_prefill( + batch: NodeBatch, + node_inputs: list[ARNodeInputs], + chunk_size: int, + inner_pass: Callable[[NodeBatch, list[ARNodeInputs]], NodeOutput], + *, + enable_nvtx: bool = False, +) -> NodeOutput: + """Drive a single-request prefill as N forward passes of ``chunk_size`` tokens. + + ``inner_pass`` is the engine's existing one-pass dispatch (batched / + sequential / CUDA-graph). It is called once per chunk with a sliced + ARNodeInputs whose ``input_seq_len`` equals the chunk's token count. + The KV-cache manager (read inside ``inner_pass``) carries state across + calls via its existing ``plan_attention(seq_lens=...)`` semantics. + + Only the final chunk's NodeOutput is returned; intermediate outputs + are discarded. This matches the semantics of an unchunked prefill, + where the model produces sampled tokens / final-position logits only + once per request. + """ + if len(batch.request_ids) != 1: + raise ValueError( + f"execute_chunked_prefill requires a single-request batch, " + f"got {len(batch.request_ids)}" + ) + if len(node_inputs) != 1: + raise ValueError( + f"execute_chunked_prefill requires len(node_inputs) == 1, " + f"got {len(node_inputs)}" + ) + + inp = node_inputs[0] + plans = _plan_chunks(seq_len=inp.input_seq_len, chunk_size=chunk_size) + + if enable_nvtx: + range_push( + f"chunked_prefill rid={batch.request_ids[0]} " + f"walk={batch.graph_walk} total={inp.input_seq_len} " + f"chunks={len(plans)}", + synchronize=False, + ) + try: + last_output: NodeOutput | None = None + for plan in plans: + if enable_nvtx: + range_push( + f"chunk {plan.index}/{len(plans) - 1} " + f"[{plan.start}:{plan.end}] last={plan.is_last}", + synchronize=False, + ) + try: + chunk_inputs = [_slice_ar_inputs(inp, plan.start, plan.end)] + last_output = inner_pass(batch, chunk_inputs) + finally: + if enable_nvtx: + range_pop(synchronize=False) + finally: + if enable_nvtx: + range_pop(synchronize=False) + + assert last_output is not None + return last_output + + class AREngine(BaseEngine): """ Autoregressive engine with paged KV cache. @@ -49,6 +217,8 @@ def __init__( self, autocast_dtype=torch.bfloat16, enable_nvtx: bool = False, + max_prefill_chunk_size: int | None = None, + scheduler_owns_chunking: bool = False, ): super().__init__(enable_nvtx=enable_nvtx) @@ -57,6 +227,8 @@ def __init__( self.device = None self.autocast_dtype = autocast_dtype + self.max_prefill_chunk_size = max_prefill_chunk_size + self.scheduler_owns_chunking = scheduler_owns_chunking def engine_type(self) -> EngineType: return EngineType.AR @@ -179,7 +351,9 @@ def _compile_submodules(self) -> None: def warmup(self) -> None: """Compile submodules and capture CUDA graphs.""" from mminf.engine.cuda_graph_runner import ( - CudaGraphRunner, PiecewiseCudaGraphRunner, DEFAULT_AR_CAPTURE_BATCH_SIZES, + DEFAULT_AR_CAPTURE_BATCH_SIZES, + CudaGraphRunner, + PiecewiseCudaGraphRunner, ) for node_name, submodule_mgmt in self.submodule_management.items(): @@ -273,7 +447,8 @@ def _execute_batched( engine_inputs = ModelInputsFromEngine( request_ids=batch.request_ids, per_request_info=batch.per_request_info, - cache_manager=cache_manager + cache_manager=cache_manager, + is_terminal_per_request=batch.is_terminal_per_request, ) if self.enable_nvtx: range_push("ar.batched.preprocess", synchronize=False) @@ -310,8 +485,13 @@ def _execute_batched( sampled = sampler.sample(batch.request_ids, batched_logits) for rid, view in zip(batch.request_ids, sampled.split(1), strict=True): rid_out = batched_output[rid] - rid_out["new_token"] = [view] - del rid_out["logits"] + # skip new_token for non-terminal prefill chunks. Default + # empty is_terminal_per_request → all terminal (single-walk + # batches preserve their existing behavior). + if batch.is_terminal_per_request.get(rid, True): + rid_out["new_token"] = [view] + if "logits" in rid_out: + del rid_out["logits"] output = NodeOutput(per_request_output_tensors=batched_output) else: output = NodeOutput(per_request_output_tensors=batched_output) @@ -348,7 +528,10 @@ def _execute_sequential( per_request_info={ rid: batch.per_request_info[rid] }, - cache_manager=cache_manager + cache_manager=cache_manager, + is_terminal_per_request={ + rid: batch.is_terminal_per_request.get(rid, True) + } if batch.is_terminal_per_request else {}, ) if self.enable_nvtx: @@ -414,6 +597,63 @@ def _can_use_cuda_graph(self, batch: NodeBatch, inputs: list[ARNodeInputs]) -> b requires_cfg=has_cfg, ) + def _should_chunk_prefill( + self, + batch: NodeBatch, + inputs: list[ARNodeInputs], + submodule: ARNodeSubmodule, + ) -> bool: + """Decide whether to route this batch through the chunked-prefill path.""" + if self.scheduler_owns_chunking: + # scheduler is orchestrating chunks. Engine doesn't + # intervene — it just runs whatever (mixed) batch arrives. + return False + if self.max_prefill_chunk_size is None: + return False + if batch.graph_walk not in submodule.get_chunked_prefill_walks(): + return False + if len(batch.request_ids) != 1: + return False + if inputs[0].input_seq_len <= self.max_prefill_chunk_size: + return False + return True + + def _dispatch_one_pass( + self, + batch: NodeBatch, + submodule: ARNodeSubmodule, + node_inputs: list[ARNodeInputs], + allow_cuda_graph: bool = True, + ) -> NodeOutput: + """Run one forward pass via the existing CUDA-graph / batched / sequential priority. + + Extracted so the chunked-prefill orchestrator can call it once per + chunk. ``allow_cuda_graph=False`` is used for chunked-path callers. + """ + if allow_cuda_graph and self._can_use_cuda_graph(batch, node_inputs): + if self.enable_nvtx: + range_push("ar.cuda_graph_path", synchronize=False) + try: + return self._execute_with_cuda_graph(batch, submodule, node_inputs) + finally: + if self.enable_nvtx: + range_pop(synchronize=False) + if submodule.can_batch(batch, node_inputs): + if self.enable_nvtx: + range_push("ar.batched_path", synchronize=False) + try: + return self._execute_batched(batch, submodule, node_inputs) + finally: + if self.enable_nvtx: + range_pop(synchronize=False) + if self.enable_nvtx: + range_push("ar.sequential_path", synchronize=False) + try: + return self._execute_sequential(batch, submodule, node_inputs) + finally: + if self.enable_nvtx: + range_pop(synchronize=False) + def _execute_with_cuda_graph( self, batch: NodeBatch, submodule: ARNodeSubmodule, inputs: list[ARNodeInputs] @@ -441,6 +681,7 @@ def _execute_with_cuda_graph( inputs=inputs, per_request_info=batch.per_request_info, submodule=submodule, + is_terminal_per_request=batch.is_terminal_per_request, ) return NodeOutput(per_request_output_tensors=batched_output) @@ -500,37 +741,26 @@ def execute_batch(self, batch: NodeBatch) -> NodeOutput: ) ) - # Priority: CUDA graph > batched > sequential - if self._can_use_cuda_graph(batch, node_inputs): - if self.enable_nvtx: - range_push("ar.cuda_graph_path", synchronize=False) - try: - output = self._execute_with_cuda_graph( - batch, submodule, node_inputs - ) - finally: - if self.enable_nvtx: - range_pop(synchronize=False) - elif submodule.can_batch(batch, node_inputs): + if self._should_chunk_prefill(batch, node_inputs, submodule): if self.enable_nvtx: - range_push("ar.batched_path", synchronize=False) + range_push("ar.chunked_prefill_path", synchronize=False) try: - output = self._execute_batched( - batch, submodule, node_inputs + output = execute_chunked_prefill( + batch=batch, + node_inputs=node_inputs, + chunk_size=self.max_prefill_chunk_size, + inner_pass=lambda b, ins: self._dispatch_one_pass( + b, submodule, ins, allow_cuda_graph=False + ), + enable_nvtx=self.enable_nvtx, ) finally: if self.enable_nvtx: range_pop(synchronize=False) else: - if self.enable_nvtx: - range_push("ar.sequential_path", synchronize=False) - try: - output = self._execute_sequential( - batch, submodule, node_inputs - ) - finally: - if self.enable_nvtx: - range_pop(synchronize=False) + output = self._dispatch_one_pass( + batch, submodule, node_inputs, allow_cuda_graph=True + ) for rid, info in batch.per_request_info.items(): submodule.postprocess( request_id=rid, diff --git a/mminf/engine/base.py b/mminf/engine/base.py index 700dec0c..1cd5f2ce 100644 --- a/mminf/engine/base.py +++ b/mminf/engine/base.py @@ -31,6 +31,13 @@ class NodeBatch: # unused for now metadata: dict = field(default_factory=dict) + # per-request flag indicating whether this request's slice + # should produce sampled output this step. True for: decode tokens, + # last-chunk prefill (transitions to decode). False for: non-terminal + # prefill chunks (mid-prefill, skip lm_head + sampling). Default empty + # dict means "all terminal" (backwards compat with single-walk batches). + is_terminal_per_request: dict[str, bool] = field(default_factory=dict) + @dataclass class NodeOutput: diff --git a/mminf/engine/cache_manager.py b/mminf/engine/cache_manager.py index 272fd0bd..fbba1cbd 100644 --- a/mminf/engine/cache_manager.py +++ b/mminf/engine/cache_manager.py @@ -134,6 +134,7 @@ def plan_attention( is_causal=True, write_store: bool=True, label: str | None = None, + mode: str | None = None, ): """Pre-compute FlashInfer plan and page positions for a cache label. @@ -151,6 +152,9 @@ def plan_attention( dtype: query data type for FlashInfer. is_causal: whether attention is causal. label: cache label to plan for. If None, uses the current active label. + mode: Optional explicit "prefill" or "decode" hint. When None + (legacy callers), fall back to the seq_lens heuristic + (``all(sl == 1)`` -> decode). """ from mminf.utils.profiler import range_pop, range_push @@ -163,6 +167,7 @@ def plan_attention( is_causal=is_causal, write_store=write_store, label=label, + mode=mode, ) finally: if self.enable_nvtx: @@ -175,6 +180,7 @@ def _plan_attention_impl( is_causal=True, write_store: bool=True, label: str | None = None, + mode: str | None = None, ): from mminf.utils.profiler import range_pop, range_push @@ -267,7 +273,18 @@ def _plan_attention_impl( range_pop(synchronize=False) - is_decode = all([sl == 1 for sl in seq_lens]) + if mode is not None: + if mode not in ("prefill", "decode"): + raise ValueError( + f"plan_attention mode must be 'prefill' or 'decode', got {mode!r}" + ) + is_decode = (mode == "decode") + else: + # Legacy heuristic for callers that don't pass explicit mode. + # Note: unreliable for chunked-prefill last chunks of 1 token + # (the chunk is logically still prefill but every seq_len is 1, + # so the heuristic would incorrectly pick the decode wrapper). + is_decode = all([sl == 1 for sl in seq_lens]) ps = self._plan_states.get(effective_label) if ps is not None and ps.wrapper is not None: wrapper = ps.wrapper diff --git a/mminf/engine/cuda_graph_runner.py b/mminf/engine/cuda_graph_runner.py index 1089a810..bdd65c69 100644 --- a/mminf/engine/cuda_graph_runner.py +++ b/mminf/engine/cuda_graph_runner.py @@ -738,6 +738,7 @@ def run( inputs: list[ARNodeInputs], per_request_info: dict[str, CurrentForwardPassInfo], submodule: ARNodeSubmodule, + is_terminal_per_request: dict[str, bool] | None = None, ) -> dict: """Look up the matching captured graph and dispatch on config type. @@ -770,10 +771,12 @@ def run( if cfg_type == CudaGraphConfigType.BASIC_BATCHED: return self._run_basic_batched( key, graph_data, request_ids, inputs, per_request_info, submodule, + is_terminal_per_request=is_terminal_per_request, ) if cfg_type == CudaGraphConfigType.FLASH_INFER_PACKED: return self._run_flashinfer_packed( key, graph_data, request_ids, inputs, per_request_info, submodule, + is_terminal_per_request=is_terminal_per_request, ) raise ValueError(f"Unknown CudaGraphConfigType: {cfg_type}") @@ -785,6 +788,7 @@ def _run_basic_batched( inputs: list[ARNodeInputs], per_request_info: dict[str, CurrentForwardPassInfo], submodule: ARNodeSubmodule, + is_terminal_per_request: dict[str, bool] | None = None, ) -> dict: """Decode-style replay. Pads real inputs to padded_bs by cloning the capture template, then routes through submodule.preprocess (which re-plans attention @@ -914,6 +918,7 @@ def _run_basic_batched( graph_data=graph_data, submodule=submodule, inputs=inputs, + is_terminal_per_request=is_terminal_per_request, ) if self.enable_nvtx: range_pop(synchronize=False) @@ -940,6 +945,7 @@ def _run_flashinfer_packed( inputs: list[ARNodeInputs], per_request_info: dict[str, CurrentForwardPassInfo], submodule: ARNodeSubmodule, + is_terminal_per_request: dict[str, bool] | None = None, ) -> dict: """Prefill-style replay (vox-serve pattern). @@ -1075,6 +1081,7 @@ def _run_flashinfer_packed( graph_data=graph_data, submodule=submodule, inputs=inputs, + is_terminal_per_request=is_terminal_per_request, ) if self.enable_nvtx: range_pop(synchronize=False) @@ -1197,6 +1204,7 @@ def _sample_and_remap( graph_data: CudaGraphData, submodule: ARNodeSubmodule, inputs: list[ARNodeInputs] | None = None, + is_terminal_per_request: dict[str, bool] | None = None, ) -> dict: """Sample logits + copy non-logit per-rid outputs, remapping dummy → real rids. @@ -1225,8 +1233,11 @@ def _sample_and_remap( # Python reference — no .clone() needed. sampled = self.sampler.sample(request_ids, stacked_logits) sampled_views = sampled.split(1) + # skip new_token assignment for non-terminal prefill chunks. + # Default empty/None is_terminal_per_request → all terminal + terminal = is_terminal_per_request or {} outputs = { - rid: {"new_token": [view]} + rid: ({"new_token": [view]} if terminal.get(rid, True) else {}) for rid, view in zip(request_ids, sampled_views, strict=True) } @@ -1278,8 +1289,11 @@ def _sample_and_remap( if all_logits: stacked_logits = torch.cat(all_logits, dim=0) sampled = self.sampler.sample(request_ids, stacked_logits) + terminal = is_terminal_per_request or {} for i, rid in enumerate(request_ids): - outputs[rid] = {"new_token": [sampled[i:i+1]]} + outputs[rid] = ( + {"new_token": [sampled[i:i+1]]} if terminal.get(rid, True) else {} + ) else: for rid in request_ids: outputs[rid] = {} diff --git a/mminf/model/qwen3_omni/qwen3_omni_model.py b/mminf/model/qwen3_omni/qwen3_omni_model.py index 1f620895..c1394ae3 100644 --- a/mminf/model/qwen3_omni/qwen3_omni_model.py +++ b/mminf/model/qwen3_omni/qwen3_omni_model.py @@ -343,6 +343,41 @@ def get_graph_walk_graphs(self) -> dict[str, GraphNode | Sequential]: outputs=[], ) + # -- mixed-batch walk: handles both prefill chunks and decode + # tokens of different requests in a single forward pass. The + # ThinkerSubmodule routes attention planning to FlashInfer's + # prefill wrapper (which handles arbitrary per-request seq_lens, + # including seq_len=1) and gates lm_head per-request based on + # ``NodeBatch.is_terminal_per_request`` so non-terminal prefill + # chunks skip sampling. The walk-level wiring mirrors prefill_text: + # a single GraphNode targeting the Thinker that consumes + # ``text_inputs`` and emits the same outputs (new_token + + # streaming thinker_states/thinker_mask) — the difference between + # walks lives entirely inside the submodule's preprocess + + # forward_batched. + thinker_step = GraphNode( + name="Thinker", + input_ids=["text_inputs"], + outputs=[ + GraphEdge( + next_node=EMIT_TO_CLIENT, + name="new_token", + output_modality="text", + persist=True, + ), + StreamingGraphEdge( + next_node="Talker_LLM", + name="thinker_states", + target_partition="Talker", + ), + StreamingGraphEdge( + next_node="Talker_LLM", + name="thinker_mask", + target_partition="Talker", + ), + ], + ) + # -- Talker prefill: receives thinker_states + talker_trigger -- # Dual-input gating: both thinker_states from streaming and # talker_trigger from conductor cross-partition trigger must be @@ -444,6 +479,7 @@ def get_graph_walk_graphs(self) -> dict[str, GraphNode | Sequential]: "prefill_audio": prefill_audio, "prefill_vision": prefill_vision, "thinker_decode": thinker_decode, + "thinker_step": thinker_step, "talker_prefill": talker_prefill, "talker_last_prefill": talker_last_prefill, "talker_decode": talker_decode, @@ -461,6 +497,7 @@ def get_partitions(self) -> list[PartitionDefinition]: graph_walks={ "prefill_text", "prefill_audio", "prefill_vision", "thinker_decode", + "thinker_step", }, initial_walk="prefill_text", producer_partitions=[], diff --git a/mminf/model/qwen3_omni/submodules.py b/mminf/model/qwen3_omni/submodules.py index 6a9b37d4..eaf55483 100644 --- a/mminf/model/qwen3_omni/submodules.py +++ b/mminf/model/qwen3_omni/submodules.py @@ -229,6 +229,9 @@ class ThinkerSubmodule(ARNodeSubmodule): # Default MRoPE section for head_dim=128: [24, 20, 20] MROPE_SECTION = [24, 20, 20] + def get_chunked_prefill_walks(self) -> list[str]: + return ["prefill_text"] + def __init__( self, thinker_model: nn.Module, @@ -341,6 +344,167 @@ def _wrap_vision_input(self, vision_embeds: torch.Tensor): self._vision_eos_embed ], dim=0) + def _prepare_decode_input( + self, inputs: NameToTensorList, start_pos: float, device: torch.device, + ) -> ARNodeInputs: + # Get previous token ID from text_inputs + token_id = inputs["text_inputs"][0].to(device) # (1,) or scalar + if token_id.dim() == 0: + token_id = token_id.unsqueeze(0) + embeds = self.model.model.embed_tokens(token_id) + + # Next MRoPE position for all 3 components: read from the + # per-request cache-manager state (kept in sync by the + # post-forward ``advance_seq_lens`` call in ``thinker.py``). + pos_ids = torch.tensor( + [[start_pos], [start_pos], [start_pos]], + dtype=torch.float, + device=device, + ) # (3, 1) + + return ARNodeInputs( + input_seq_len=1, + input_embeds=embeds, + custom_pos_ids=pos_ids, + tensor_inputs={ + "masks_for_talker": self._get_decode_thinker_mask(device) + } # no additional tensors for decode step + ) + + def _prepare_text_input( + self, inputs: NameToTensorList, start_pos: float, device: torch.device, + ) -> ARNodeInputs: + # Embed a text-only token span (prefill chunk or single decode token) + # and compute 3D MRoPE position IDs starting at start_pos. + text_ids = inputs["text_inputs"][0].to(device) # (seq_len,) + embeds = self.model.model.embed_tokens(text_ids) + seq_len = text_ids.shape[0] + + # Compute 3D MRoPE position IDs for a pure-text span. Each + # prefill graph walk is single-modality so we use the simple + # per-modality helper instead of the full HF parser. + # + # ``start_pos`` is the next MRoPE position for this request, + # carried forward across walks by ``state.position_id_start`` + # (advanced post-forward by ``advance_seq_lens``). + pos_ids = get_rope_index_text(seq_len, start_pos, device) + masks_for_talker = torch.stack([ + torch.zeros(text_ids.shape, dtype=torch.bool, device=device), # multimodal + self._get_talker_text_mask(text_ids) # text inclusion + ]) + return ARNodeInputs( + input_seq_len=seq_len, + input_embeds=embeds, + custom_pos_ids=pos_ids, + tensor_inputs={ + "masks_for_talker": masks_for_talker + } + ) + + def _prepare_audio_input( + self, inputs: NameToTensorList, start_pos: float, device: torch.device, + ) -> ARNodeInputs: + audio_embeds = inputs["audio_embeds"][0].to(device) # (audio_tokens, hidden) + audio_len = audio_embeds.shape[0] + + mm_mask = torch.ones(audio_len + 2, dtype=torch.bool, device=device) + mm_mask[[0, -1]] = 0 + masks_for_talker = torch.stack([ + mm_mask, + ~mm_mask + ]) + + wrapped_embeds = self._wrap_audio_input(audio_embeds) + seq_len = audio_len + 2 + # Position IDs: + # - audio_start_token: text-like position at start_pos + # - audio tokens: temporal increments per frame, + # h/w = start_pos (handled by helper) + # - audio_end_token: text-like position right after + start_pos_ids = get_rope_index_text(1, start_pos, device) + audio_pos_ids = get_rope_index_audio( + audio_len, + start_pos + 1, + device, + self.config.thinker.position_id_per_seconds, + ) + end_pos_ids = get_rope_index_text( + 1, start_pos + 1 + audio_len, device + ) + pos_ids = torch.cat( + [start_pos_ids, audio_pos_ids, end_pos_ids], dim=1 + ) + return ARNodeInputs( + input_seq_len=seq_len, + input_embeds=wrapped_embeds, + custom_pos_ids=pos_ids, + tensor_inputs={ + "masks_for_talker": masks_for_talker + } + ) + + def _prepare_vision_input( + self, inputs: NameToTensorList, start_pos: float, device: torch.device, + ) -> ARNodeInputs: + vision_embeds = inputs["vision_embeds"][0].to(device) + vision_len = vision_embeds.shape[0] + + mm_mask = torch.ones(vision_len + 2, dtype=torch.bool, device=device) + mm_mask[[0, -1]] = 0 + masks_for_talker = torch.stack([ + mm_mask, + ~mm_mask + ]) + + wrapped_embeds = self._wrap_vision_input(vision_embeds) + total_len = vision_len + 2 + # Vision tokens use spatial 3D positions (temporal constant, + # h/w from the spatial grid after merging). If a proper + # ``image_grid_thw`` is available, use ``get_rope_index_vision``; + # otherwise fall back to a 1-D sequence (test path without + # AutoImageProcessor). + grid_thw = inputs.get("image_grid_thw", [None])[0] + seconds_per_grid = inputs.get("video_second_per_grid", []) + seconds_per_grid = seconds_per_grid[0].item() if seconds_per_grid else None + vision_pos_ids = get_rope_index_vision( + grid_thw.to(device), + start_pos + 1, # leave room for the BOS token + position_id_per_seconds=self.config.thinker.position_id_per_seconds, + device=device, + spatial_merge_size=self.config.vision.spatial_merge_size, + seconds_per_grid=seconds_per_grid + ) + + # Sentinel token positions (text-like). + start_pos_ids = get_rope_index_text(1, start_pos, device) + end_pos_base = float(vision_pos_ids.max().item()) + 1 + end_pos_ids = get_rope_index_text(1, end_pos_base, device) + + pos_ids = torch.cat( + [start_pos_ids, vision_pos_ids, end_pos_ids], dim=1 + ) + + # Next MRoPE position after this vision block is ``end_pos_base + # + 1`` (one past the EOS token). ``advance_seq_lens`` by + # default advances ``position_id_start`` by ``seq_len``, which + # for vision (= vision_len + 2) is typically smaller than the + # 3D-grid span. Emit the correct per-request advance so the + # Thinker forward can pass ``pos_id_ns`` through. + mrope_pos_advance = int(end_pos_base + 1 - start_pos) + deepstack = inputs["deepstack"] + + return ARNodeInputs( + input_seq_len=total_len, + input_embeds=wrapped_embeds, + custom_pos_ids=pos_ids, + tensor_inputs={ + "masks_for_talker": masks_for_talker, + "mrope_pos_advance": mrope_pos_advance, + "deepstack": deepstack, + "visual_pos_masks": mm_mask + } + ) + def prepare_inputs( self, graph_walk: str, @@ -351,155 +515,33 @@ def prepare_inputs( device = self.get_device() start_pos = pos_info.get("main", PositionInfo()).position_id_start if graph_walk == "thinker_decode": - # Get previous token ID from text_inputs - token_id = inputs["text_inputs"][0].to(device) # (1,) or scalar - if token_id.dim() == 0: - token_id = token_id.unsqueeze(0) - embeds = self.model.model.embed_tokens(token_id) - - # Next MRoPE position for all 3 components: read from the - # per-request cache-manager state (kept in sync by the - # post-forward ``advance_seq_lens`` call in ``thinker.py``). - pos_ids = torch.tensor( - [[start_pos], [start_pos], [start_pos]], - dtype=torch.float, - device=device, - ) # (3, 1) - - return ARNodeInputs( - input_seq_len=1, - input_embeds=embeds, - custom_pos_ids=pos_ids, - tensor_inputs={ - "masks_for_talker": self._get_decode_thinker_mask(device) - } # no additional tensors for decode step - ) + return self._prepare_decode_input(inputs, start_pos, device) + + if graph_walk == "thinker_step": + # ``thinker_step`` is the mixed-batch walk where + # each rid contributes a slice of its own modality + # (text-prefill chunk, decode token, or atomic audio/vision + # prefill). Dispatch by per-rid input keys to the right + # modality prep helper. ``forward_batched`` still routes the + # whole batch through its ``is_thinker_step`` branch — only + # the per-rid input embedding/position-id construction differs + # by modality. Audio/vision prefills cannot be chunked (their + # start/end sentinel wrappers are atomic), so they appear as + # a single non-chunked rid in the batch. + if "audio_embeds" in inputs: + return self._prepare_audio_input(inputs, start_pos, device) + if "vision_embeds" in inputs: + return self._prepare_vision_input(inputs, start_pos, device) + return self._prepare_text_input(inputs, start_pos, device) if graph_walk == "prefill_text": - text_ids = inputs["text_inputs"][0].to(device) # (seq_len,) - embeds = self.model.model.embed_tokens(text_ids) - seq_len = text_ids.shape[0] - - # Compute 3D MRoPE position IDs for a pure-text span. Each - # prefill graph walk is single-modality so we use the simple - # per-modality helper instead of the full HF parser. - # - # ``start_pos`` is the next MRoPE position for this request, - # carried forward across walks by ``state.position_id_start`` - # (advanced post-forward by ``advance_seq_lens``). - pos_ids = get_rope_index_text(seq_len, start_pos, device) - masks_for_talker = torch.stack([ - torch.zeros(text_ids.shape, dtype=torch.bool, device=device), # multimodal - self._get_talker_text_mask(text_ids) # text inclusion - ]) - return ARNodeInputs( - input_seq_len=seq_len, - input_embeds=embeds, - custom_pos_ids=pos_ids, - tensor_inputs={ - "masks_for_talker": masks_for_talker - } - ) + return self._prepare_text_input(inputs, start_pos, device) if graph_walk == "prefill_audio": - audio_embeds = inputs["audio_embeds"][0].to(device) # (audio_tokens, hidden) - audio_len = audio_embeds.shape[0] - - mm_mask = torch.ones(audio_len + 2, dtype=torch.bool, device=device) - mm_mask[[0, -1]] = 0 - masks_for_talker = torch.stack([ - mm_mask, - ~mm_mask - ]) - - wrapped_embeds = self._wrap_audio_input(audio_embeds) - seq_len = audio_len + 2 - # Position IDs: - # - audio_start_token: text-like position at start_pos - # - audio tokens: temporal increments per frame, - # h/w = start_pos (handled by helper) - # - audio_end_token: text-like position right after - start_pos_ids = get_rope_index_text(1, start_pos, device) - audio_pos_ids = get_rope_index_audio( - audio_len, - start_pos + 1, - device, - self.config.thinker.position_id_per_seconds, - ) - end_pos_ids = get_rope_index_text( - 1, start_pos + 1 + audio_len, device - ) - pos_ids = torch.cat( - [start_pos_ids, audio_pos_ids, end_pos_ids], dim=1 - ) - return ARNodeInputs( - input_seq_len=seq_len, - input_embeds=wrapped_embeds, - custom_pos_ids=pos_ids, - tensor_inputs={ - "masks_for_talker": masks_for_talker - } - ) + return self._prepare_audio_input(inputs, start_pos, device) if graph_walk == "prefill_vision": - vision_embeds = inputs["vision_embeds"][0].to(device) - vision_len = vision_embeds.shape[0] - - mm_mask = torch.ones(vision_len + 2, dtype=torch.bool, device=device) - mm_mask[[0, -1]] = 0 - masks_for_talker = torch.stack([ - mm_mask, - ~mm_mask - ]) - - wrapped_embeds = self._wrap_vision_input(vision_embeds) - total_len = vision_len + 2 - # Vision tokens use spatial 3D positions (temporal constant, - # h/w from the spatial grid after merging). If a proper - # ``image_grid_thw`` is available, use ``get_rope_index_vision``; - # otherwise fall back to a 1-D sequence (test path without - # AutoImageProcessor). - grid_thw = inputs.get("image_grid_thw", [None])[0] - seconds_per_grid = inputs.get("video_second_per_grid", []) - seconds_per_grid = seconds_per_grid[0].item() if seconds_per_grid else None - vision_pos_ids = get_rope_index_vision( - grid_thw.to(device), - start_pos + 1, # leave room for the BOS token - position_id_per_seconds=self.config.thinker.position_id_per_seconds, - device=device, - spatial_merge_size=self.config.vision.spatial_merge_size, - seconds_per_grid=seconds_per_grid - ) - - # Sentinel token positions (text-like). - start_pos_ids = get_rope_index_text(1, start_pos, device) - end_pos_base = float(vision_pos_ids.max().item()) + 1 - end_pos_ids = get_rope_index_text(1, end_pos_base, device) - - pos_ids = torch.cat( - [start_pos_ids, vision_pos_ids, end_pos_ids], dim=1 - ) - - # Next MRoPE position after this vision block is ``end_pos_base - # + 1`` (one past the EOS token). ``advance_seq_lens`` by - # default advances ``position_id_start`` by ``seq_len``, which - # for vision (= vision_len + 2) is typically smaller than the - # 3D-grid span. Emit the correct per-request advance so the - # Thinker forward can pass ``pos_id_ns`` through. - mrope_pos_advance = int(end_pos_base + 1 - start_pos) - deepstack = inputs["deepstack"] - - return ARNodeInputs( - input_seq_len=total_len, - input_embeds=wrapped_embeds, - custom_pos_ids=pos_ids, - tensor_inputs={ - "masks_for_talker": masks_for_talker, - "mrope_pos_advance": mrope_pos_advance, - "deepstack": deepstack, - "visual_pos_masks": mm_mask - } - ) + return self._prepare_vision_input(inputs, start_pos, device) def preprocess( self, @@ -529,12 +571,19 @@ def preprocess( target_dtype=input_embeds.dtype, ) - # Plan FlashInfer attention and rope for the main cache label + # Plan FlashInfer attention and rope for the main cache label. + # Explicit mode prevents the chunked-prefill last chunk (seq_len=1 + # per request) from being misclassified as decode by the seq_lens + # heuristic. ``thinker_step`` mixes decode (seq_len=1) and prefill + # (seq_len>=1) rids in one batch; routing to mode="prefill" picks + # FlashInfer's prefill wrapper, which handles arbitrary per-request + # seq_lens including seq_len=1. cache_manager = engine_inputs.cache_manager cache_manager.set_active_label("main") assert cache_manager is not None + mode = "decode" if graph_walk == "thinker_decode" else "prefill" cache_manager.plan_attention( - seq_lens=seq_lens, is_causal=True, label="main" + seq_lens=seq_lens, is_causal=True, label="main", mode=mode ) cache_manager.plan_rope(seq_lens=seq_lens, pos_ids=None, label="main") @@ -587,6 +636,12 @@ def forward( ``True`` for backwards compatibility with callers that do not set the flag (e.g. unit tests). """ + assert graph_walk != "thinker_step", ( + "thinker_step walk should always go through forward_batched, never " + "the eager path. If can_batch returns False for thinker_step in the " + "future, extend forward to mirror forward_batched's per-rid lm_head " + "gating logic." + ) request_info = engine_inputs.single_request_info audio_output = request_info.step_metadata.get( "audio_output", True, @@ -634,10 +689,12 @@ def forward( # ---- batching ---- def can_batch(self, batch: NodeBatch, model_inputs: list[NodeInputs]) -> bool: - return batch.graph_walk == "thinker_decode" + return batch.graph_walk in ("thinker_decode", "thinker_step") PREFILL_TOKEN_BUCKETS = [128, 256, 512, 1024, 2048] - PREFILL_CAPTURE_BATCH_SIZES = [1, 2, 4] + # bs=8 covers the typical thinker_step mixed-batch shape (4-7 decodes + # + 1 prefill chunk); below it batches fall through to eager. + PREFILL_CAPTURE_BATCH_SIZES = [1, 2, 4, 8] def _build_prefill_text_packed( self, num_tokens: int, device: torch.device, @@ -724,7 +781,7 @@ def get_cuda_graph_configs(self, device: torch.device): ), FlashInferPackedCudaGraphConfig( capture_graph_walk="prefill_text", - replay_graph_walks=["prefill_text", "prefill_audio"], + replay_graph_walks=["prefill_text", "prefill_audio", "thinker_step"], packed_seq_len_to_inputs=prefill_text_packed, requires_cfg=False, labels=["main"], @@ -763,6 +820,7 @@ def forward_batched( mrope_section: list[int] | None = None, mrope_pos_advance: list[int] | None = None, masks_for_talker: dict[str, torch.Tensor] | None = None, + seq_lens: list[int] | None = None, **kwargs, ) -> dict[str, NameToTensorList]: """Batched Thinker forward shared between ``thinker_decode`` and the prefill walks. @@ -796,8 +854,22 @@ def forward_batched( NOT included — its preprocess emits ``deepstack`` / ``visual_pos_masks`` / ``mrope_pos_advance`` extras that the model forward also consumes; it is kept on the eager path. + + ``thinker_step`` (mixed-batch walk, eager-only): + The batch carries a mix of decode tokens (seq_len=1) and prefill + chunks (seq_len>=1). Emits ``__batched_logits__`` (single + ``(bs, V)`` tensor) at the top level regardless of terminal-flag + distribution so the output dict shape is fixed across batches — + a precondition for CUDA graph capture. Per-rid dicts contain + ONLY ``thinker_states`` (and optionally ``thinker_mask``); + per-rid ``new_token`` assignment + non-terminal filtering moved + to ``AREngine._execute_batched``'s batched-logits sampling fast + path, which consults ``is_terminal_per_request`` to skip + sampling for non-terminal prefill chunks. """ - assert graph_walk in ("thinker_decode", "prefill_text", "prefill_audio") + assert graph_walk in ( + "thinker_decode", "prefill_text", "prefill_audio", "thinker_step", + ) # Packed dict from FlashInferPackedCudaGraphConfig is tensor-only by # design (the runner's static-buffer interning skips non-tensor @@ -805,7 +877,8 @@ def forward_batched( # class constant when the kwarg is missing. Decode goes through # preprocess which does pass it explicitly. is_prefill = graph_walk in ("prefill_text", "prefill_audio") - if mrope_section is None and is_prefill: + is_thinker_step = graph_walk == "thinker_step" + if mrope_section is None and (is_prefill or is_thinker_step): mrope_section = self.MROPE_SECTION cos_sin_3d = (cos_3d, sin_3d) if cos_3d is not None else None @@ -840,6 +913,74 @@ def forward_batched( "__batched_thinker_states__": thinker_states, } + if is_thinker_step: + # Mixed prefill + decode batch. Emit __batched_logits__ at the + # top level regardless of terminal-flag distribution so the + # output shape is fixed (CUDA graph capture precondition). + # Per-request gating of new_token assignment moves to the + # engine's batched-logits sampling fast path. + # + # seq_lens comes from preprocess (one entry per request, each + # request's contiguous slice in `hidden`). + assert seq_lens is not None, ( + "thinker_step requires seq_lens from preprocess to compute " + "per-request last-token indices." + ) + request_ids = cache_manager.request_ids + assert len(request_ids) == len(seq_lens), ( + f"thinker_step: request_ids ({len(request_ids)}) and " + f"seq_lens ({len(seq_lens)}) length mismatch" + ) + + # Compute last-token-per-request indices from cumulative seq_lens + # and run lm_head on the gathered last-token hidden states. This + # mirrors the prefill branch's qo_indptr-based gather pattern but + # uses the engine-provided seq_lens (thinker_step is eager-only; + # no qo_indptr_buf static buffer is required here). + seq_lens_t = torch.as_tensor( + seq_lens, dtype=torch.long, device=hidden.device, + ) + last_token_indices = torch.cumsum(seq_lens_t, dim=0) - 1 + last_hidden = hidden.index_select(0, last_token_indices) + batched_logits = self.model.lm_head(last_hidden) # (bs, vocab) + + # Pack thinker_states once for the whole batch (per-request + # slicing happens outside this function; non-audio rids are + # filtered out there as well). + if layer_n_hidden is not None: + thinker_states = torch.cat( + [layer_0_embed, layer_n_hidden], dim=-1, + ) + else: + thinker_states = torch.cat( + [layer_0_embed, layer_0_embed], dim=-1, + ) + + outputs: dict[str, NameToTensorList] = {} + cum = 0 + for rid, sl in zip(request_ids, seq_lens, strict=True): + slice_start, slice_end = cum, cum + sl + cum = slice_end + + req_out: NameToTensorList = {} + # Always emit thinker_states per-rid (Talker conditioning is + # independent of sampling — it consumes the full slice for + # every request, terminal or not). NEVER emit per-rid + # logits or new_token here — the engine's batched-logits + # sampling fast path owns that, gated on + # is_terminal_per_request. + req_out["thinker_states"] = [ + thinker_states[slice_start:slice_end] + ] + if masks_for_talker is not None and rid in masks_for_talker: + mask = masks_for_talker[rid] + if mask is not None: + req_out["thinker_mask"] = [mask] + + outputs[rid] = req_out + outputs["__batched_logits__"] = batched_logits + return outputs + # thinker_decode (existing behavior) logits = self.model.lm_head(hidden) # (batch, vocab) diff --git a/mminf/model/submodule_base.py b/mminf/model/submodule_base.py index 5bdeda22..a30eea59 100644 --- a/mminf/model/submodule_base.py +++ b/mminf/model/submodule_base.py @@ -134,6 +134,14 @@ class ModelInputsFromEngine: per_request_info: dict[str, CurrentForwardPassInfo] cache_manager: BatchedCacheManager | None = None + # Chunked-prefill: per-request terminal flag carried over from + # ``NodeBatch.is_terminal_per_request``. True means this request's slice + # should produce sampled output this step (decode token OR final prefill + # chunk that transitions to decode); False means it's a non-terminal + # prefill chunk and lm_head/sampling should be skipped. Default empty + # dict means "all terminal" — backwards compat with non-mixed batches. + is_terminal_per_request: dict[str, bool] = field(default_factory=dict) + @property def single_request_info(self): """ @@ -141,7 +149,7 @@ def single_request_info(self): """ assert len(self.per_request_info) == 1 return self.per_request_info[self.request_ids[0]] - + @property def first_request_info(self): """ @@ -230,14 +238,23 @@ def can_use_cuda_graphs( """Return True if this submodule supports CUDA graphs for ``batch``. Default: derives from ``get_cuda_graph_configs`` — if the submodule - declared a capture for this batch's graph_walk, CUDA graphs are - supported. Subclasses can override to reject on batch shape / - metadata (e.g. codec submodules that need homogeneous frame counts). + declared a capture (or replay alias) for this batch's graph_walk, + CUDA graphs are supported. Subclasses can override to reject on + batch shape / metadata (e.g. codec submodules that need + homogeneous frame counts). + + Walk eligibility: a walk is eligible if it appears in EITHER + ``capture_graph_walk`` (the walk a graph was captured under) OR + ``replay_graph_walks`` (additional walks that share the same + captured graph — e.g. ``prefill_audio`` and ``thinker_step`` + replay the ``prefill_text`` capture). """ if not hasattr(self, "_cached_cuda_graph_walks"): - self._cached_cuda_graph_walks = { - cfg.capture_graph_walk for cfg in self.get_cuda_graph_configs(device=torch.device("cpu")) - } + walks: set[str] = set() + for cfg in self.get_cuda_graph_configs(device=torch.device("cpu")): + walks.add(cfg.capture_graph_walk) + walks.update(cfg.replay_graph_walks) + self._cached_cuda_graph_walks = walks return batch.graph_walk in self._cached_cuda_graph_walks def postprocess( @@ -293,6 +310,22 @@ def cleanup_request(self, request_id: str): """Remove per-request state when a request completes.""" return + def get_chunked_prefill_walks(self) -> list[str]: + """Return the graph walks for which this submodule's forward tolerates chunking. + + For each walk in the returned list, AREngine may split a + single-request prefill into multiple forward passes of + ``max_prefill_chunk_size`` tokens each, with KV cache state carried + across via the existing paged cache manager. + + Default empty list — submodules must opt in per walk. Walks whose + inputs aren't sliceable along the token axis (e.g. fixed image-token + blocks emitted by an encoder, sentinel-wrapped audio/vision embeds) + must be omitted. Mirrors the per-walk eligibility pattern used by + ``can_use_cuda_graphs`` / ``get_cuda_graph_configs``. + """ + return [] + class ARNodeSubmodule(NodeSubmodule): @abstractmethod diff --git a/mminf/worker/engine_manager.py b/mminf/worker/engine_manager.py index ebe29618..b4ed0b31 100644 --- a/mminf/worker/engine_manager.py +++ b/mminf/worker/engine_manager.py @@ -65,10 +65,18 @@ def build( for engine_type_str, engine_node_names in type_to_nodes.items(): engine_cls = ENGINE_TYPE_TO_CLASS[engine_type_str] - engine = engine_cls( + engine_kwargs = dict( autocast_dtype=autocast_dtype, enable_nvtx=enable_nvtx, ) + if engine_cls is AREngine: + engine_kwargs["max_prefill_chunk_size"] = model_config.get( + "max_prefill_chunk_size" + ) + engine_kwargs["scheduler_owns_chunking"] = model_config.get( + "scheduler_owns_chunking", False + ) + engine = engine_cls(**engine_kwargs) # Extract submodules from the Model for this engine's nodes submodules: dict[str, torch.nn.Module] = {} diff --git a/mminf/worker/micro_scheduler.py b/mminf/worker/micro_scheduler.py index 01846946..57326a2d 100644 --- a/mminf/worker/micro_scheduler.py +++ b/mminf/worker/micro_scheduler.py @@ -1,8 +1,9 @@ import logging import time -from dataclasses import dataclass +from dataclasses import dataclass, field from enum import Enum +from mminf.conductor.request_info import CurrentForwardPassInfo from mminf.engine.base import EngineType from mminf.graph.base import GraphNode from mminf.worker.engine_manager import EngineManager @@ -28,6 +29,109 @@ class ScheduledBatch: # request_id -> worker_graph_id (for push-back on OOM) request_to_worker_graph: dict[str, str] = None + # Per-rid: should this request's slice produce sampled output this + # step? Empty dict means "all terminal" (no mid-prefill rids in batch). + is_terminal_per_request: dict[str, bool] = field(default_factory=dict) + + # Per-rid chunk size for in-flight prefill chunks. Empty dict means + # "no chunked prefill in this batch" — slicing and consumed-token + # advancement are skipped on the worker side. + prefill_chunk_sizes: dict[str, int] = field(default_factory=dict) + + +# ---------------------------------------------------------------------- +# Chunked-prefill mixed-batch packing. +# +# Decode-first packing under a per-step token budget. Each decode is 1 +# token; prefill chunks fill remaining budget. If a prefill's remaining +# tokens fit in budget, that chunk is "terminal" — the request transitions +# to decode after this step, so we sample its output. Non-terminal chunks +# skip lm_head + sampling. +# ---------------------------------------------------------------------- + + +@dataclass(frozen=True) +class DecodeReadyRequest: + """A request that has 1 token to decode this step.""" + + rid: str + + +@dataclass(frozen=True) +class PrefillReadyRequest: + """A request with chunked prefill in progress.""" + + rid: str + tokens_remaining: int + # If True, must be packed in full this step or deferred to the next step. + # Audio and vision prefills are atomic — sentinel wrappers prevent slicing + # through the embedding block. False (default) for chunkable text prefills. + atomic: bool = False + + +@dataclass +class ChunkedStepPlan: + """The scheduler's verdict for one mixed-batch step. + + decode_rids: requests that should each contribute 1 token (decode). + prefill_allocations: rid → number of tokens to feed this step. + terminal_prefills: rids whose prefill completes this step (last chunk). + These need lm_head + sampling to produce the first decode token. + """ + + decode_rids: list[str] = field(default_factory=list) + prefill_allocations: dict[str, int] = field(default_factory=dict) + terminal_prefills: set[str] = field(default_factory=set) + + @property + def total_tokens(self) -> int: + return len(self.decode_rids) + sum(self.prefill_allocations.values()) + + +def plan_chunked_step( + ready_decodes: list[DecodeReadyRequest], + ready_prefills: list[PrefillReadyRequest], + max_step_tokens: int, +) -> ChunkedStepPlan: + """Pack one step under the token budget. + + Decode-first because each decode is 1 token; running them keeps tail + latency stable. Prefill fills remaining budget. If a prefill request's + remaining tokens fit in the budget, the chunk is terminal (transitions + the request to decode after this step). + """ + if max_step_tokens <= 0: + raise ValueError(f"max_step_tokens must be positive, got {max_step_tokens}") + + plan = ChunkedStepPlan() + budget = max_step_tokens + + # Decodes first. + for req in ready_decodes: + if budget <= 0: + break + plan.decode_rids.append(req.rid) + budget -= 1 + + # Prefill fills remaining budget. + for req in ready_prefills: + if budget <= 0: + break + if req.tokens_remaining <= 0: + continue + if req.atomic and req.tokens_remaining > budget: + # Atomic prefill doesn't fit this step's remaining budget; + # defer to a later step. Don't partial-chunk (would break + # multimodal sentinel wrappers). + continue + chunk = min(req.tokens_remaining, budget) + plan.prefill_allocations[req.rid] = chunk + if chunk == req.tokens_remaining: + plan.terminal_prefills.add(req.rid) + budget -= chunk + + return plan + # Priority: lower value = higher priority # AR decode is most latency-sensitive @@ -55,7 +159,8 @@ class MicroScheduler: def __init__( self, engine_manager: EngineManager, - sched_type=SchedulingType.ROUND_ROBIN + sched_type=SchedulingType.ROUND_ROBIN, + max_step_tokens: int = 2048, ): self.engine_manager = engine_manager self.batch_number = 0 @@ -64,6 +169,10 @@ def __init__( # request_id -> monotonic time until which the request is held self.held_until: dict[str, float] = {} + # Only consulted when an AR engine has scheduler_owns_chunking=True; + # otherwise the existing single-walk batching path is used. + self.max_step_tokens = max_step_tokens + def _select_node_priority( self, node_name_to_requests: dict[str, list[ReadyNodeEntry]] ): @@ -117,6 +226,173 @@ def hold_requests(self, request_ids: list[str]) -> None: for rid in request_ids: self.held_until[rid] = deadline + # ------------------------------------------------------------------ + # Chunked-prefill mixed-batch packing. + # ------------------------------------------------------------------ + + def _ar_engine_owns_chunking(self) -> bool: + """True iff this scheduler should pack mixed thinker_step batches. + + The flag lives on the AREngine. Non-AR-only workers (e.g. Talker / + Code2Wav) return False and use the single-walk batching path. + """ + ar_engine = self.engine_manager.get_ar_engine() + if ar_engine is None: + return False + return getattr(ar_engine, "scheduler_owns_chunking", False) + + def _get_chunked_step_batch( + self, + worker_graphs_manager: WorkerGraphsManager, + target_node_name: str | None = None, + exclude_target: tuple[str, str] | None = None, + ) -> ScheduledBatch | None: + """Pack a single ``thinker_step`` batch from ready AR-engine requests. + + Walks every ready AR node, classifying each request as decode-ready + (``is_prefill_complete=True``) or prefill-ready (mid-chunked-prefill). + Calls ``plan_chunked_step`` with the worker's max-step budget, then + pops the popped nodes' GraphNodes and returns a single ``ScheduledBatch`` + whose ``graph_walk`` is ``thinker_step`` and whose + ``is_terminal_per_request`` map encodes the plan. + + Returns None when no AR requests are ready (caller falls back to the + non-chunked scheduling path). + + The per-request prompt-token slicing and post-step + ``prefill_tokens_consumed`` advance are handled separately on the + worker side; this method only produces the batch + metadata. + """ + now = time.monotonic() + # Expire stale hold entries (mirrors get_next_batch). + self.held_until = { + rid: t for rid, t in self.held_until.items() if t > now + } + + # rid -> (worker_graph_id, node_name, graph_walk, fwd_info) + ready: dict[str, tuple[str, str, str, CurrentForwardPassInfo]] = {} + + for worker_graph_id, queue in worker_graphs_manager.queues.items(): + ready_map = queue.get_ready_node_names() + for request_id, node_names in ready_map.items(): + if request_id not in worker_graphs_manager.per_request_info: + continue + if request_id in self.held_until: + continue + for sname in node_names: + if target_node_name is not None and sname != target_node_name: + continue + if sname not in self.engine_manager.node_to_engine: + continue + engine = self.engine_manager.get_engine(sname) + if engine.engine_type() != EngineType.AR: + continue + node_partition = worker_graphs_manager.get_partition_for_node(sname) + graph_walk = worker_graphs_manager.get_graph_walk( + request_id, node_partition + ) + if exclude_target is not None and (sname, graph_walk) == exclude_target: + continue + fwd_info = worker_graphs_manager.get_fwd_info(request_id, node_partition) + if not engine.check_ready(sname, request_id, fwd_info): + continue + # Take the first eligible (rid, node_name) pair per request. + if request_id not in ready: + ready[request_id] = (worker_graph_id, sname, graph_walk, fwd_info) + + if not ready: + return None + + # Classify each ready request. + decode_ready: list[DecodeReadyRequest] = [] + prefill_ready: list[PrefillReadyRequest] = [] + for rid, (_wg_id, _sname, walk, fwd_info) in ready.items(): + if fwd_info.is_prefill_complete: + decode_ready.append(DecodeReadyRequest(rid=rid)) + else: + tokens_remaining = max( + 0, + fwd_info.prefill_tokens_total - fwd_info.prefill_tokens_consumed, + ) + # Audio/vision prefills can't be chunked safely (sentinel-wrapped + # blocks). Mark them atomic so the planner skips them when budget + # is too small instead of partial-chunking. + atomic = walk in ("prefill_audio", "prefill_vision") + prefill_ready.append( + PrefillReadyRequest( + rid=rid, tokens_remaining=tokens_remaining, atomic=atomic + ) + ) + + plan = plan_chunked_step(decode_ready, prefill_ready, self.max_step_tokens) + if plan.total_tokens == 0: + return None + + # Build the unified batch. Order: decodes first, then prefills. + batch_rids = list(plan.decode_rids) + list(plan.prefill_allocations.keys()) + node_objects: dict[str, GraphNode] = {} + request_to_worker_graph: dict[str, str] = {} + is_terminal_per_request: dict[str, bool] = {} + prefill_chunk_sizes: dict[str, int] = {} + + # Pop ready nodes for each rid; choose the same node name across rids + # (the scheduler's _select_node helpers normally enforce this; here we + # accept whatever node was ready since all are AR. In practice on a + # qwen3-omni-style worker the AR node is "Thinker" for all rids.) + node_name_for_batch: str | None = None + for rid in batch_rids: + wg_id, sname, _walk, _fwd = ready[rid] + queue = worker_graphs_manager.queues[wg_id] + popped = queue.pop_ready_nodes(rid, [sname]) + if not popped: + continue + assert len(popped) == 1 + node_objects[rid] = popped[0] + request_to_worker_graph[rid] = wg_id + if node_name_for_batch is None: + node_name_for_batch = sname + + if rid in plan.decode_rids: + is_terminal_per_request[rid] = True + else: + # prefill chunk: terminal iff this is the last chunk + is_terminal_per_request[rid] = rid in plan.terminal_prefills + prefill_chunk_sizes[rid] = plan.prefill_allocations[rid] + + if not node_objects or node_name_for_batch is None: + return None + + logger.debug( + "MicroScheduler chunked-step: node=%s rids=%d decodes=%d prefills=%d budget=%d", + node_name_for_batch, len(node_objects), + len(plan.decode_rids), len(plan.prefill_allocations), + self.max_step_tokens, + ) + # Pure-decode batches use the dedicated ``thinker_decode`` walk so + # the existing ``(bs, num_tokens=bs)`` decode CUDA-graph captures + # fire. ``thinker_step`` captures are prefill-shaped + # (num_tokens >= 128) and don't match a pure-decode batch's + # num_tokens=bs*1 — falling back to eager would cost ~2x per-token + # latency vs the decode captures. Mixed batches (decodes + + # prefill chunks) keep the ``thinker_step`` walk, which is where + # Phase 2's mixed-batch packing actually pays off. + is_pure_decode = bool(plan.decode_rids) and not plan.prefill_allocations + batch_graph_walk = "thinker_decode" if is_pure_decode else "thinker_step" + + self.batch_number += 1 + self.node_and_walk_to_last_batch_num[( + node_name_for_batch, batch_graph_walk + )] = self.batch_number + + return ScheduledBatch( + node_name=node_name_for_batch, + graph_walk=batch_graph_walk, + node_objects=node_objects, + request_to_worker_graph=request_to_worker_graph, + is_terminal_per_request=is_terminal_per_request, + prefill_chunk_sizes=prefill_chunk_sizes, + ) + def get_next_batch( self, worker_graphs_manager: WorkerGraphsManager, @@ -136,6 +412,25 @@ def get_next_batch( target_graph_walk: If set, only schedule this graph walk. exclude_target: If set, skip this (node_name, graph_walk) pair. """ + # When the AR engine has opted into scheduler-driven chunking, + # dispatch through the mixed-batch packer first. None ⇒ AR queue + # empty this tick — fall through so non-AR engines still schedule. + # ``target_graph_walk`` skips this path so callers explicitly + # asking for a specific walk get the single-walk batching semantics. + if ( + target_graph_walk is None + and self._ar_engine_owns_chunking() + ): + chunked = self._get_chunked_step_batch( + worker_graphs_manager, + target_node_name=target_node_name, + exclude_target=exclude_target, + ) + if chunked is not None: + return chunked + # Fall through: AR queue empty this tick, but other engines + # (e.g., Talker) may still have ready work. + # Collect all ready (node_name, request_id, graph_walk) tuples # grouped by node name node_name_to_requests: dict[str, list[ReadyNodeEntry]] = {} diff --git a/mminf/worker/node_manager_utils.py b/mminf/worker/node_manager_utils.py index 16a42d41..d0ff7740 100644 --- a/mminf/worker/node_manager_utils.py +++ b/mminf/worker/node_manager_utils.py @@ -352,6 +352,7 @@ def process_node_outputs( self, request_id: str, outputs: list[GraphEdge], graph_walk: str, + worker_graph_id_hint: str | None = None, ) -> NodeOutputRouting: """ After a node has finished processing, use its outputs to update @@ -361,6 +362,13 @@ def process_node_outputs( I.e., it updates ready/waiting queues for worker graphs on this current worker, and directs external outputs to worker graphs on the appropriate (different) worker. + + ``worker_graph_id_hint``: when provided, the caller knows exactly which + worker_graph the popped GraphNode came from (e.g., the chunked-prefill + scheduler relabels ``batch.graph_walk`` to ``thinker_step`` but pops + the GraphNode from a different walk's worker_graph). Use the hint + directly instead of filtering by ``graph_walk``, which would route to + the wrong queue. """ # (0) separate streaming edges — they bypass the queue system streaming_edges = [edge for edge in outputs if edge.is_streaming] @@ -371,11 +379,14 @@ def process_node_outputs( new_token_outputs = [edge for edge in non_streaming_outputs if edge.conductor_new_token] # (2) process all internal-facing outputs - worker_graph_ids = [ - gid - for gid in self.per_request_info[request_id].worker_graph_ids - if graph_walk in self.all_worker_graph_ids_to_graph_walks[gid] - ] + if worker_graph_id_hint is not None: + worker_graph_ids = [worker_graph_id_hint] + else: + worker_graph_ids = [ + gid + for gid in self.per_request_info[request_id].worker_graph_ids + if graph_walk in self.all_worker_graph_ids_to_graph_walks[gid] + ] completed_worker_graph_ids = [] routed_to_this_worker: list[GraphEdge] = [] # list of graph edges diff --git a/mminf/worker/worker.py b/mminf/worker/worker.py index 0a743119..2b0b1407 100644 --- a/mminf/worker/worker.py +++ b/mminf/worker/worker.py @@ -148,7 +148,11 @@ def __init__( node_to_partition=node_to_partition, ) - self.scheduler = MicroScheduler(self.engine_manager) + # Only consulted when an AR engine has scheduler_owns_chunking=True. + max_step_tokens = model_config.get("max_step_tokens", 2048) if model_config else 2048 + self.scheduler = MicroScheduler( + self.engine_manager, max_step_tokens=max_step_tokens + ) # Determine store write policy based on worker graph topology node_engine_types = model.get_node_engine_types() if model is not None else {} @@ -303,6 +307,34 @@ def _add_new_request(self, body: NewRequest) -> None: for node_name in ar_engine.submodule_management.keys(): self._last_active[(body.request_id, node_name)] = _time.monotonic() + # When scheduler-driven chunking is on, prime ``prefill_tokens_total`` + # from the prompt tensor's leading dimension so the MicroScheduler's + # mixed-batch packer can classify this request as prefill-ready. + # Audio/vision use embed_len + 2 to account for the start/end + # sentinels added by the Thinker's _wrap_audio_input / _wrap_vision_input + # helpers. When chunking is off, total stays 0 and + # ``is_prefill_complete`` is trivially True. + if ( + ar_engine is not None + and getattr(ar_engine, "scheduler_owns_chunking", False) + ): + for edge in body.initial_inputs: + total: int | None = None + if edge.name == "text_inputs" and edge.tensor_info: + prompt_len = edge.tensor_info[0].dims[0] if edge.tensor_info[0].dims else 0 + total = int(prompt_len) if prompt_len > 0 else None + elif edge.name == "audio_embeds" and edge.tensor_info: + audio_len = edge.tensor_info[0].dims[0] if edge.tensor_info[0].dims else 0 + # +2 for the start/end sentinel tokens added at Thinker prefill time. + total = int(audio_len) + 2 if audio_len > 0 else None + elif edge.name == "vision_embeds" and edge.tensor_info: + vision_len = edge.tensor_info[0].dims[0] if edge.tensor_info[0].dims else 0 + total = int(vision_len) + 2 if vision_len > 0 else None + if total is not None: + body.request_info.prefill_tokens_total = total + body.request_info.prefill_tokens_consumed = 0 + break + self.worker_graphs_manager.add_request( request_id=body.request_id, partition_worker_graph_ids=body.partition_worker_graph_ids, @@ -667,29 +699,88 @@ def _try_reload_request(self, node_name: str, request_id: str) -> bool: # Batch building # ------------------------------------------------------------------ + @staticmethod + def _slice_prompt_chunk( + tensors: NameToTensorList, + prefill_total: int, + start: int, + end: int, + ) -> NameToTensorList: + """Return a new ``NameToTensorList`` with token-axis tensors sliced to ``[start, end)``. + + Per-key token-axis rules (explicit, not dynamic): + - ``text_inputs``: 1D ``(seq_len,)`` — slice dim 0. + - All other keys: pass through unchanged. Worker-side non-token + tensors (e.g. fixed-size image or audio embeddings) are already + sized by modality length, not prompt_total; the engine-side + ``_slice_ar_inputs`` in ``ar_engine.py`` handles their sequence + axis after ``prepare_inputs`` constructs ARNodeInputs. + """ + chunk_len = end - start + sliced: NameToTensorList = {} + for name, tensor_list in tensors.items(): + new_list: list[torch.Tensor] = [] + for t in tensor_list: + if not isinstance(t, torch.Tensor): + new_list.append(t) + continue + if name == "text_inputs": + # text_inputs: (seq_len,) — matches _prepare_text_input expectation. + new_list.append(t[start:end]) + else: + # Non-token-axis tensors propagate unchanged; the engine-side + # _slice_ar_inputs handles any sequence-axis slicing post-prepare_inputs. + new_list.append(t) + sliced[name] = new_list + return sliced + def _build_node_batch(self, batch: ScheduledBatch) -> NodeBatch: """Gather input tensors from tensor_manager for all requests in the batch.""" per_request_inputs: dict[str, NameToTensorList] = {} per_request_info: dict[CurrentForwardPassInfo] = {} batch_partition = self.worker_graphs_manager.get_partition_for_node(batch.node_name) + # When ``prefill_chunk_sizes`` is populated, slice each prefill + # rid's token-axis tensors to ``[consumed : consumed + chunk_size]`` + # so the engine only sees this step's slice. Decode rids (absent + # from the dict) and empty-dict batches pass through unchanged. + chunk_sizes = batch.prefill_chunk_sizes or {} + for request_id, node in batch.node_objects.items(): - tensors = {} + tensors: NameToTensorList = {} for input_name in node.ready_inputs: tensors[input_name] = [ self.tensor_manager.get_tensor( request_id=request_id, uuid=info.uuid ) for info in node.ready_inputs[input_name].tensor_info ] + + if request_id in chunk_sizes: + fwd_info = self.worker_graphs_manager.get_fwd_info(request_id, batch_partition) + consumed = fwd_info.prefill_tokens_consumed + total = fwd_info.prefill_tokens_total + chunk = int(chunk_sizes[request_id]) + # Defensive: clamp end to total so the last chunk's narrow() + # never overruns the prompt tensor. + end = min(consumed + chunk, total) + if total > 0 and end > consumed: + tensors = self._slice_prompt_chunk( + tensors, prefill_total=total, start=consumed, end=end, + ) + per_request_inputs[request_id] = tensors per_request_info[request_id] = self.worker_graphs_manager.get_fwd_info(request_id, batch_partition) + # Empty dict ⇒ "all terminal" — preserves single-walk batch behavior. + is_terminal_per_request = batch.is_terminal_per_request or {} + return NodeBatch( node_name=batch.node_name, graph_walk=batch.graph_walk, request_ids=list(batch.node_objects.keys()), per_request_input_tensors=per_request_inputs, - per_request_info=per_request_info + per_request_info=per_request_info, + is_terminal_per_request=is_terminal_per_request, ) # ------------------------------------------------------------------ @@ -760,7 +851,21 @@ def _store_outputs_and_finish_loops( ) if not request_output_tensors: - continue # Node produced no outputs (e.g., KV-cache-only prefill step) + # Node produced no outputs (e.g., KV-cache-only prefill step, + # Talker non-last prefill). For non-terminal chunked-prefill + # rids, the popped GraphNode must be re-queued so the next + # chunk can run on it; otherwise the rid's ready queue stays + # empty and the scheduler can't pick it up next step, + # hanging the request. Empty is_terminal_per_request dict + # (legacy path) ⇒ treat all rids as terminal, preserving + # the prior skip-only behavior for Talker etc. + if not batch.is_terminal_per_request.get(request_id, True): + worker_graph_id = batch.request_to_worker_graph.get(request_id) + if worker_graph_id is not None: + self.worker_graphs_manager.queues[worker_graph_id].push_back_node( + request_id, node, + ) + continue output_tensor_info = self.tensor_manager.store_and_populate_graph_edges( request_id=request_id, @@ -1361,6 +1466,20 @@ def _fast_postprocess( per_label_seq_info=req_info.per_label_seq_info, partition_name=batch_partition, ) + + # Advance prefill_tokens_consumed for each prefill chunk that just + # completed. Only fires when the scheduler populated + # ``prefill_chunk_sizes`` on the batch; non-chunked batches skip + # this entirely. + if batch.prefill_chunk_sizes: + for rid, chunk in batch.prefill_chunk_sizes.items(): + if rid not in node_batch.per_request_info: + continue + fwd_info = self.worker_graphs_manager.get_fwd_info(rid, batch_partition) + fwd_info.prefill_tokens_consumed = min( + fwd_info.prefill_tokens_total, + fwd_info.prefill_tokens_consumed + int(chunk), + ) if self.enable_nvtx: range_pop(synchronize=False) @@ -1410,8 +1529,19 @@ def _fast_postprocess( ] else: kept_for_routing = kept + # When the chunked-prefill scheduler relabels the batch's + # graph_walk (e.g. ``thinker_step``), filtering by graph_walk + # would route outputs to the wrong worker_graph. The scheduler + # populates ``request_to_worker_graph`` with the actual id the + # GraphNode was popped from — pass that as a hint. + wg_id_hint = ( + batch.request_to_worker_graph.get(request_id) + if batch.request_to_worker_graph else None + ) routing = self.worker_graphs_manager.process_node_outputs( - request_id, kept_for_routing, graph_walk=batch.graph_walk + request_id, kept_for_routing, + graph_walk=batch.graph_walk, + worker_graph_id_hint=wg_id_hint, ) routing_per_request[request_id] = routing if self.enable_nvtx: