diff --git a/configs/qwen3omni_talker_pd_disag.yaml b/configs/qwen3omni_talker_pd_disag.yaml new file mode 100644 index 00000000..7b6265de --- /dev/null +++ b/configs/qwen3omni_talker_pd_disag.yaml @@ -0,0 +1,24 @@ +model: "qwen3_omni" +max_seq_len: 32768 +node_groups: + - node_names: [audio_encoder, vision_encoder] + ranks: [0] + + - node_names: [Thinker] + ranks: [0] + graph_walks: [prefill_text, prefill_audio, prefill_vision] + + - node_names: [Thinker] + ranks: [1] + graph_walks: [thinker_decode] + + - node_names: [Talker] + ranks: [2] + graph_walks: [talker_prefill, talker_last_prefill] + + - node_names: [Talker] + ranks: [3] + graph_walks: [talker_decode] + + - node_names: [Code2Wav] + ranks: [3] diff --git a/mstar/conductor/conductor.py b/mstar/conductor/conductor.py index 6a6471a4..c9d5b55d 100644 --- a/mstar/conductor/conductor.py +++ b/mstar/conductor/conductor.py @@ -20,6 +20,7 @@ PartitionDefinition, PartitionState, StreamingConnectionState, + TransitionSource, ) from mstar.distributed.base import ShardingConfig from mstar.distributed.communication import GlobalTPConfig, WorkerTPGroups @@ -33,6 +34,7 @@ InputSignals, NewRequest, NewRequestConductor, + ProducerDone, RemoveRequest, UnpersistTensors, WorkerGraphsDone, @@ -189,6 +191,7 @@ def __init__( ): self.requests: dict[str, RequestData] = {} self.model = model + self._validate_transition_sources() self.hostname = hostname self.socket_path_prefix = socket_path_prefix self.log_level = log_level @@ -215,9 +218,24 @@ def __init__( # (1) Set up worker graph TP ranks # (2) Assert that streaming consumers don't have graph-walk-specific sharding config - self.streaming_consumers = set() + self.state_machine_streaming_consumers = set() + state_machine_controlled_walks = set() + self.gw_to_part = {} + self.part_to_walks: dict[str, set[str]] = {} + for part_def in self.model.get_partitions(): + if part_def.transition_source == TransitionSource.STATE_MACHINE: + state_machine_controlled_walks.update(part_def.graph_walks) + self.gw_to_part.update({ + gw: part_def.name for gw in part_def.graph_walks + }) + self.part_to_walks[part_def.name] = set(part_def.graph_walks) + self.node_walk_to_wg: dict[tuple[str, str], WorkerGraph] = {} + + # partition name -> producer-triggered streaming consumer node names + self.producer_triggered_nodes: dict[str, set[str]] = dict() + # (worker idx) -> {tp_group_str: tp_rank} self.worker_tp_group_to_tp_rank: dict[int, dict[str, int]] = {} @@ -226,10 +244,20 @@ def __init__( for walk in wg.graph_walks: graph_walks.add(walk) for name, node in wg.section.get_nodes().items(): + state_machine_controlled = False for walk in wg.graph_walks: self.node_walk_to_wg[(name, walk)] = wg - if node.consumes_stream: - self.streaming_consumers.add(name) + if walk in state_machine_controlled_walks: + state_machine_controlled = True + if node.consumes_stream and state_machine_controlled: + self.state_machine_streaming_consumers.add(name) + elif node.consumes_stream and wg.graph_walks: + # Producer-triggered streaming consumer (e.g. PD-disaggregated + # Talker). All walks of a wg belong to one partition, so any + # of them resolves the partition. + self.producer_triggered_nodes.setdefault( + self.gw_to_part[next(iter(wg.graph_walks))], set() + ).add(name) # v1: one sharding group per worker graph. Track which group "owns" # each wg so we can assert single-group-per-wg. @@ -237,7 +265,7 @@ def __init__( for group in self.default_sharding_config.groups: if group.graph_walks is not None and any([ - node in self.streaming_consumers for node in group.nodes + node in self.state_machine_streaming_consumers for node in group.nodes ]): raise RuntimeError(( f"Sharding group with nodes {group.nodes} includes a streaming consumer but " @@ -459,7 +487,7 @@ def _build_request_sharding_config( for node_name in wg.section.get_nodes(): node_to_workers[NodeAndGraphWalk(node_name, walk)] = worker_ids cfg.setup(node_to_workers) - cfg.assert_stream_consumer_compatibility(self.streaming_consumers) + cfg.assert_stream_consumer_compatibility(self.state_machine_streaming_consumers) return cfg def _split_inputs_to_workers( @@ -476,6 +504,19 @@ def _split_inputs_to_workers( per dest, which the consumer's fan-in path consolidates. """ inputs_per_worker: dict[str, list[GraphEdge]] = defaultdict(list) + + # Seed empty inputs for producer-triggered streaming consumers (PD disag) + # across ALL the partition's walks, not just the conductor's current + # (lagging) walk — otherwise only the prefill worker is reached, never + # decode. set_index is monotonic, so re-seeding is a no-op. + part = self.gw_to_part.get(graph_walk) + for node in self.producer_triggered_nodes.get(part, set()): + for walk in self.part_to_walks.get(part, set()): + for worker in sharding_config.node_to_worker.get( + NodeAndGraphWalk(node, walk), [], + ): + inputs_per_worker.setdefault(worker, []) + for edge in inputs: if not edge.tensor_info: # Signal-only — broadcast to every dest worker. @@ -564,6 +605,52 @@ def _try_admit_waiting(self): ) self._do_ingest_request(body) + def _validate_transition_sources(self): + """Validate that graph-walk transition authority is unambiguous. + + Each partition's walk is moved by exactly one source. A + ``PRODUCER_TRIGGERED`` partition must be driven by at least one streaming + connection carrying a ``consumer_walk_transition``; a ``STATE_MACHINE`` + partition must not be driven by any. Fails fast at conductor startup. + + (A producer-triggered partition may have several driving connections — + e.g. qwen3omni's thinker_states + thinker_mask both carry the marker so + each lands in the right walk.) + """ + pdefs = {p.name: p for p in self.model.get_partitions()} + topology = self.model.get_partition_topology() + + # Collect transition-carrying connections per consumer partition. + driving_conns: dict[str, list[str]] = {} + for conn in topology.connections: + if conn.consumer_walk_transition is None: + continue + consumer = pdefs.get(conn.to_partition) + if consumer is None: + raise ValueError( + f"Connection edge {conn.edge_name!r} targets unknown " + f"partition {conn.to_partition!r}." + ) + if consumer.transition_source != TransitionSource.PRODUCER_TRIGGERED: + raise ValueError( + f"Connection edge {conn.edge_name!r} defines a " + f"consumer_walk_transition for partition " + f"{conn.to_partition!r}, but that partition's " + f"transition_source is {consumer.transition_source.value!r}, " + f"not 'producer_triggered'. The two transition mechanisms " + f"must not be mixed." + ) + driving_conns.setdefault(conn.to_partition, []).append(conn.edge_name) + + for name, pdef in pdefs.items(): + if (pdef.transition_source == TransitionSource.PRODUCER_TRIGGERED + and not driving_conns.get(name)): + raise ValueError( + f"Producer-triggered partition {name!r} has no incoming " + f"connection with a consumer_walk_transition to drive its " + f"graph-walk transitions." + ) + def _ingest_request( self, body: NewRequestConductor ): @@ -673,6 +760,14 @@ def _do_ingest_request( ) partition_fwd_args[p.name] = fwd_args + # set up tracked_consumer_graph_walks + for conn in topology.connections: + producer = partition_states[conn.from_partition] + consumer = partition_states[conn.to_partition] + if conn.consumer_walk_transition is not None: + producer.tracked_consumer_graph_walks[conn.to_partition] \ + = consumer.metadata.graph_walk + # Send NewRequest to each worker with the appropriate partition's inputs for worker_id, worker_graph_ids in worker_to_worker_graph_ids.items(): # Determine which partition this worker serves @@ -706,7 +801,10 @@ def _do_ingest_request( requires_cfg=fwd_args.full_metadata.requires_cfg, partition_name=partition_name, max_tokens=request_data.max_output_tokens, - sampling_config=request_data.sampling_config + sampling_config=request_data.sampling_config, + produced_edge_idx=pstate.produced_edge_idx, + next_stream_index=pstate.next_stream_index, + tracked_consumer_graph_walks=pstate.tracked_consumer_graph_walks, ), ) self.communicator.send( @@ -793,6 +891,20 @@ def _process_worker_graphs_done( ) return [] + # Stream indices are monotonic, but only the emitting TP rank advances + # them; sibling ranks report the (stale) seed they were given. With + # last-writer-wins .update(), a sibling landing after the emitter would + # rewind the index, re-emitting an already-used one. Merge with max. + for name, idx in body.new_produced_edge_idx.items(): + pstate.produced_edge_idx[name] = max(pstate.produced_edge_idx.get(name, 0), idx) + for name, idx in body.new_next_stream_index.items(): + pstate.next_stream_index[name] = max(pstate.next_stream_index.get(name, 0), idx) + # Unlike the indices above, graph walks have no ordering to max-merge on, + # so this is a plain last-writer-wins update. Contract: a reporting rank + # must send either the correct (just-applied) consumer walk or nothing — + # a non-emitting TP sibling sends an empty dict, never a stale walk. + pstate.tracked_consumer_graph_walks.update(body.consumer_graph_walk_transitions) + # Persist signals: every rank contributes its shard (different uuid + # source_tp_rank); accumulate across ranks, do not dedup. if body.persist_signals: @@ -824,6 +936,22 @@ def _process_worker_graphs_done( body.output_signal_names, list ) else [] + # Producer-triggered partitions self-apply their walk from the stream. + # Adopt the walk the consumer actually ran (reported back) before the + # completion check, and realign the completion-tracking worker-graph set + # to it — otherwise current_worker_graph_ids (a state-machine-derived + # walk) diverges and the subset check below breaks. The conductor does + # not drive this partition's transitions; this only mirrors them. + pdef = request_data.partition_definitions.get(partition_name) + if (pdef is not None + and pdef.transition_source == TransitionSource.PRODUCER_TRIGGERED + and body.partition_graph_walk + and body.partition_graph_walk != pstate.metadata.graph_walk): + pstate.metadata.graph_walk = body.partition_graph_walk + self._set_partition_worker_graph_ids( + body.request_id, partition_name, body.partition_graph_walk, + ) + # Each wg is only marked complete when all its TP ranks have reported. for wg_id in body.worker_graph_ids: count = pstate.wg_rank_completions.get(wg_id, 0) + 1 @@ -893,10 +1021,11 @@ def _process_done_forward( if conn.from_partition == partition_name: conn.producer_done = True self._send_producer_done(request_id, conn.from_partition, conn.to_partition) - elif fwd_args.inputs: - # Partition has inputs to send — conductor-driven + else: + # Call even with no inputs: a no-op for self-triggering partitions + # (empty inputs_per_worker → no messages), but lets producer-triggered + # consumers get seeded (PD) / their next_stream_index propagated. self._send_partition_inputs(request_id, partition_name, fwd_args) - # else: no inputs — partition self-triggers via StreamBuffer self._un_persist_tensors(request_id, fwd_args.unpersist_tensors) @@ -947,6 +1076,9 @@ def _send_partition_inputs( partition_name=partition_name, max_tokens=request_data.max_output_tokens, sampling_config=request_data.sampling_config, + produced_edge_idx=pstate.produced_edge_idx, + next_stream_index=pstate.next_stream_index, + tracked_consumer_graph_walks=pstate.tracked_consumer_graph_walks, ), partition_name=partition_name ), @@ -959,7 +1091,7 @@ def _send_producer_done( ): """Send producer_done signal to the consumer partition's worker(s).""" request_data = self.requests[request_id] - pstate = request_data.partition_states[consumer_partition_name] + pstate = request_data.partition_states[producer_partition] # Find which workers handle this consumer partition consumer_workers = set() @@ -971,21 +1103,11 @@ def _send_producer_done( for worker_id in consumer_workers: message = WorkerMessage( - message_type=WorkerMessageType.INPUT_SIGNALS, - body=InputSignals( + message_type=WorkerMessageType.PRODUCER_DONE, + body=ProducerDone( request_id=request_id, - inputs=[], - request_info=CurrentForwardPassInfo( - request_id=request_id, - graph_walk=pstate.metadata.graph_walk or "", - fwd_index=pstate.fwd_pass_number, - random_seed=pstate.random_seed, - requires_cfg=False, - partition_name=consumer_partition_name, - max_tokens=request_data.max_output_tokens, - sampling_config=request_data.sampling_config - ), partition_name=consumer_partition_name, + last_produced_edge_idx=pstate.produced_edge_idx, producer_done=set([producer_partition]), ), ) diff --git a/mstar/conductor/request_info.py b/mstar/conductor/request_info.py index ae7cc696..1434b550 100644 --- a/mstar/conductor/request_info.py +++ b/mstar/conductor/request_info.py @@ -1,4 +1,5 @@ from dataclasses import dataclass, field +from enum import Enum from typing import Any from mstar.graph.loop_indices import NestedLoopIndices @@ -47,6 +48,23 @@ def update(self, other: "PerLabelSeqInfo"): **val } + def merge_keep_longest(self, other: "PerLabelSeqInfo"): + """Merge ``other`` into self, keeping — per (kv_cache_string, rank, + label) — the ``SequenceInfo`` with the larger ``seq_len``. + + Used on the worker to stop a lagging conductor copy from rewinding a + request's locally-advanced KV sequence positions: a request's KV length + only ever grows, so the longer entry is always the more recent one. + """ + for key, label_info in other.info.items(): + dst = self.info.setdefault(key, {}) + for label, seq_info in label_info.items(): + cur = dst.get(label) + if cur is None or seq_info.seq_len > cur.seq_len: + dst[label] = seq_info + for kv_cache_str, ws in other.world_size.items(): + self.world_size.setdefault(kv_cache_str, ws) + def get(self, kv_cache_str: str, rank: int) -> dict: return self.info.get((kv_cache_str, rank), {}) @@ -75,6 +93,9 @@ class CurrentForwardPassInfo: step_metadata: dict = field(default_factory=dict) per_label_seq_info: PerLabelSeqInfo = field(default_factory=PerLabelSeqInfo) partition_name: str = field(default="default") + produced_edge_idx: dict[str, int] = field(default_factory=dict) + next_stream_index: dict[str, int] = field(default_factory=dict) + tracked_consumer_graph_walks: dict[str, str] = field(default_factory=dict) # Per-loop stop indices; stop decisions come from each submodule's check_stop. loop_stop_times: dict[str, NestedLoopIndices] = field(default_factory=dict) @@ -88,6 +109,20 @@ def clear_loop_stop_info(self): # Partition types for async graph partitions # --------------------------------------------------------------------------- +class TransitionSource(str, Enum): + """Who owns a partition's graph-walk transitions. + + A partition's walk is moved by exactly one authority: + - STATE_MACHINE: the conductor advances the walk via the model's + ``get_partition_forward_pass_args``. + - PRODUCER_TRIGGERED: a producer partition drives the walk via streaming + transition markers (``Connection.consumer_walk_transition``); the + conductor does not advance it. + """ + STATE_MACHINE = "state_machine" + PRODUCER_TRIGGERED = "producer_triggered" + + @dataclass class PartitionDefinition: """Defines a partition within a model's computation graph. @@ -99,6 +134,7 @@ class PartitionDefinition: graph_walks: set[str] # walks this partition uses initial_walk: str | None = None # first walk, or None = triggered later producer_partitions: list[str] = field(default_factory=list) # partitions feeding tokens to this one + transition_source: TransitionSource = TransitionSource.STATE_MACHINE @dataclass @@ -120,6 +156,14 @@ class PartitionState: fwd_pass_number: int = 0 random_seed: int = 0 is_done: bool = False + # streaming edge name -> number of times this edge has been emitted + produced_edge_idx: dict[str, int] = field(default_factory=dict) + # streaming edge name -> next stream index the consumer should drain + # (one past the last index a completed consumer pass has consumed) + next_stream_index: dict[str, int] = field(default_factory=dict) + # only for producer-driven graph walk transitions + tracked_consumer_graph_walks: dict[str, str] = field(default_factory=dict) + new_tokens: dict[str, list[int]] = field(default_factory=dict) completed_worker_graph_ids: set[str] = field(default_factory=set) current_worker_graph_ids: set[str] = field(default_factory=set) # wg_id -> count of distinct TP ranks that have reported completion diff --git a/mstar/engine/kv_store.py b/mstar/engine/kv_store.py index 380d5755..9cbfcbfc 100644 --- a/mstar/engine/kv_store.py +++ b/mstar/engine/kv_store.py @@ -378,11 +378,11 @@ def _do_read( for info in read_info: slice = tensor[ - info.layer_idx, info.remote_page_idx, + info.layer_idx, info.remote_page_idx, :, info.token_start:info.token_end ].to(self._device) self._kv_cache[ - info.layer_idx, info.local_page_idx, + info.layer_idx, info.local_page_idx, :, info.token_start:info.token_end ] = slice diff --git a/mstar/graph/base.py b/mstar/graph/base.py index 93b85e08..f1e509ce 100644 --- a/mstar/graph/base.py +++ b/mstar/graph/base.py @@ -63,6 +63,11 @@ class GraphEdge: # only for EMIT_TO_CLIENT output_modality: str = field(default="") # text | image | video | audio _persist_for_loop: bool = field(default=False) + + # set on a synthetic streaming-input edge: the next stream index the + # consumer should drain after the pass that consumes this edge (None on + # non-streaming edges, which carry no stream position) + _next_stream_index: int | None = field(default=None) # set on a synthetic streaming-input edge carrying the final chunk, so the # consuming pass (not the earlier ingest) reports the partition done _final_stream_chunk: bool = field(default=False) @@ -71,6 +76,9 @@ class GraphEdge: _total_fanin: int = 1 _shard_dim: int | None = None + # set for non-streaming inputs + _target_graph_walk: str | None = None + def clone(self): return GraphEdge( next_node=self.next_node, @@ -81,7 +89,9 @@ def clone(self): is_streaming=self.is_streaming, output_modality=self.output_modality, _persist_for_loop=self._persist_for_loop, + _next_stream_index=self._next_stream_index, _final_stream_chunk=self._final_stream_chunk, + _target_graph_walk=self._target_graph_walk, ) @@ -717,8 +727,10 @@ def __init__(self, graph_section: GraphSection): self.ready_for_streaming = set(self.only_streaming_inputs) self.ready_streaming_next_iter = set(self.only_streaming_inputs) + self.clean = True def register_ingested_input(self, graph_edge: GraphEdge): + self.clean = False node = self.nodes[graph_edge.next_node] # If node._speculatively_scheduled, the node is already executing as # a spec batch. We don't want to double-queue it (either for the @@ -766,6 +778,7 @@ def reset_for_iter(self): ) def clear(self): + self.clean = True super().clear() self.ready_names.clear() self.ready_for_streaming = set(self.only_streaming_inputs) diff --git a/mstar/model/qwen3_omni/qwen3_omni_model.py b/mstar/model/qwen3_omni/qwen3_omni_model.py index c90444bc..04e028c7 100644 --- a/mstar/model/qwen3_omni/qwen3_omni_model.py +++ b/mstar/model/qwen3_omni/qwen3_omni_model.py @@ -15,13 +15,14 @@ Thinker --[thinker_states, FixedChunkPolicy(1)]--> Talker Talker --[codec_tokens, FixedChunkPolicy(25)]--> Code2Wav -Conductor-triggered pipelined prefill (Approach C): - After each Thinker walk completes (prefill_text, prefill_audio, - prefill_vision, thinker_decode), the conductor sends a - ``talker_trigger`` to the Talker partition. During prefill each - trigger extends the Talker KV cache with the new Thinker hidden - states. The final trigger (when thinker_decode starts) tells the - Talker to sample its first codec token and transition to decode. +Producer-triggered Talker prefill: + The Talker partition is PRODUCER_TRIGGERED: its graph walk is driven + by the Thinker's streamed thinker_states/thinker_mask rather than by + the conductor. Each streamed item carries a walk-transition marker + (see ``talker_state_transition``): items from a Thinker prefill walk + keep the Talker in talker_prefill (extend KV cache only); the first + item from thinker_decode flips it to talker_last_prefill (sample the + first codec token), after which it self-transitions to talker_decode. Text-only mode: When output_modalities does not include "audio", only the Thinker @@ -39,6 +40,7 @@ CurrentForwardConductorMetadata, PartitionDefinition, StreamingConnectionState, + TransitionSource, ) from mstar.engine.base import EngineType from mstar.engine.kv_store import KVCacheConfig @@ -49,7 +51,13 @@ from mstar.model.submodule_base import NodeSubmodule from mstar.model.utils import Operation, WeightConverter from mstar.streaming.chunk_policy import FixedChunkPolicy, LeftContextChunkPolicy -from mstar.streaming.topology import Connection, PartitionTopology, StreamingGraphEdge +from mstar.streaming.topology import ( + Connection, + ConsumerTransitionCtx, + PartitionTopology, + StreamingGraphEdge, + WalkTransition, +) from mstar.utils.sampling import SamplingConfig logger = logging.getLogger(__name__) @@ -334,13 +342,10 @@ def get_graph_walk_graphs(self) -> dict[str, GraphNode | Sequential]: outputs=[], ) - # -- Talker prefill: receives thinker_states + talker_trigger -- - # Dual-input gating: both thinker_states from streaming and - # talker_trigger from conductor cross-partition trigger must be - # present for a prefill step. + # -- Talker prefill: receives thinker_states talker_prefill = GraphNode( name="Talker", - input_names=["thinker_states", "thinker_mask", "talker_trigger"], + input_names=["thinker_states", "thinker_mask"], outputs=[], ) @@ -348,7 +353,7 @@ def get_graph_walk_graphs(self) -> dict[str, GraphNode | Sequential]: sections=[ GraphNode( name="Talker", - input_names=["thinker_states", "thinker_mask", "talker_trigger"], + input_names=["thinker_states", "thinker_mask"], outputs=[ GraphEdge( next_node=EMPTY_DESTINATION, @@ -435,6 +440,7 @@ def get_partitions(self) -> list[PartitionDefinition]: graph_walks={"talker_prefill", "talker_last_prefill", "talker_decode"}, initial_walk="talker_prefill", producer_partitions=["Thinker"], + transition_source=TransitionSource.PRODUCER_TRIGGERED, ), PartitionDefinition( name="Code2Wav", @@ -445,6 +451,13 @@ def get_partitions(self) -> list[PartitionDefinition]: ] def get_partition_topology(self) -> PartitionTopology: + def talker_state_transition(ctx: ConsumerTransitionCtx) -> WalkTransition: + if ctx.producer_walk != "thinker_decode": + return WalkTransition("talker_prefill") + if ctx.consumer_walk == "talker_prefill": + return WalkTransition("talker_last_prefill") + return WalkTransition("talker_decode") + return PartitionTopology( partitions=["Thinker", "Talker", "Code2Wav"], connections=[ @@ -452,13 +465,19 @@ def get_partition_topology(self) -> PartitionTopology: from_partition="Thinker", to_partition="Talker", edge_name="thinker_states", - chunk_policy_factory=lambda: FixedChunkPolicy(chunk_size=1, continue_after_done=True), + chunk_policy_factory=lambda: FixedChunkPolicy( + chunk_size=1, continue_after_done={"talker_decode"} + ), + consumer_walk_transition=talker_state_transition ), Connection( from_partition="Thinker", to_partition="Talker", edge_name="thinker_mask", - chunk_policy_factory=lambda: FixedChunkPolicy(chunk_size=1, continue_after_done=True), + chunk_policy_factory=lambda: FixedChunkPolicy( + chunk_size=1, continue_after_done={"talker_decode"} + ), + consumer_walk_transition=talker_state_transition ), Connection( from_partition="Talker", @@ -540,16 +559,13 @@ def get_initial_forward_pass_args( is_prefill=True, kwargs={ "audio_output": audio_output, - "talker_prefill_done": False, - "num_thinker_prefill_steps": len(input_modalities), - "prefill_chunks_processed": 0, "voice": model_kwargs.get("voice", "Ethan"), "talker_max_tokens": self.get_max_talker_output_tokens(**model_kwargs), }, ) return ForwardPassArgs( full_metadata=full_metadata, - inputs=[GraphEdge(next_node="Talker", name="talker_trigger")] if audio_output else [], + inputs=[], unpersist_tensors=[], request_done="audio" not in output_modalities, step_metadata={ @@ -827,6 +843,8 @@ def _get_thinker_forward( step_metadata = { "is_prefill": metadata.is_prefill, + # The last prefill step must emit logits so the first decode token + # can be sampled (read by the non-batched Thinker submodule). "is_last_prefill": is_last_prefill, # Persist the audio_output flag across every Thinker step so # the submodule can gate thinker_states emission. Default True @@ -849,41 +867,24 @@ def _get_talker_forward( persist_signals: dict[str, list[TensorPointerInfo]], incoming_connections: list[StreamingConnectionState] | None = None, ) -> ForwardPassArgs: - """Talker partition state machine. - - 1. While prefill: return empty inputs (wait for cross-partition trigger) - - When trigger arrives with is_last_prefill=False: - extend KV cache only, no outputs - - When trigger arrives with is_last_prefill=True: - sample first codec token, produce all_codes - 2. After last prefill produces all_codes: transition to talker_decode - - Set graph_walk="talker_decode", is_prefill=False - - Return all_codes as input edge (conductor-driven) - 3. Each decode step: check all_codes for codec_eos - - If codec_eos: request_done=True for Talker - - Else: return all_codes as input again (loop) - """ - if metadata.graph_walk == "talker_prefill": - metadata.kwargs["prefill_chunks_processed"] += 1 - is_last_prefill = metadata.kwargs["num_thinker_prefill_steps"] == \ - metadata.kwargs["prefill_chunks_processed"] - metadata.graph_walk = "talker_last_prefill" if is_last_prefill else "talker_prefill" + if metadata.graph_walk == "talker_decode": + # If the decode dynamic loop reaches the conductor, we can end the request. return ForwardPassArgs( full_metadata=metadata, - inputs=[GraphEdge(next_node="Talker", name="talker_trigger")], + inputs=[], unpersist_tensors=[], - step_metadata={ - "is_prefill": True, - # voice is used for the last prefill - "voice": metadata.kwargs.get("voice", "Ethan"), - "talker_max_tokens": metadata.kwargs.get("talker_max_tokens") - }, + request_done=True, ) - elif metadata.graph_walk == "talker_last_prefill": - metadata.is_prefill = False - metadata.graph_walk = "talker_decode" - metadata.kwargs["talker_prefill_done"] = True + step_metadata = { + # voice is used for the last prefill + "voice": metadata.kwargs.get("voice", "Ethan"), + "talker_max_tokens": metadata.kwargs.get("talker_max_tokens") + } + inputs = [] + unpersist_tensors = [] + + if metadata.graph_walk == "talker_last_prefill": # Feed talker_input_embeds back as input for first decode step edge = GraphEdge(next_node="Talker", name="talker_input_embeds") edge.tensor_info = persist_signals["talker_input_embeds"] @@ -891,29 +892,13 @@ def _get_talker_forward( unpersist_tensors = sum( [inp.tensor_info for inp in inputs], start=[] ) + metadata.graph_walk = "talker_decode" - return ForwardPassArgs( - full_metadata=metadata, - inputs=inputs, - unpersist_tensors=unpersist_tensors, - step_metadata={ - "is_prefill": False, - "talker_max_tokens": metadata.kwargs.get("talker_max_tokens") - }, - ) - - elif metadata.graph_walk == "talker_decode": - # If the decode dynamic loop reaches the conductor, we can end the request. - return ForwardPassArgs( - full_metadata=metadata, - inputs=[], - unpersist_tensors=[], - request_done=True, - ) - - raise ValueError( - f"Talker in unexpected state: walk={metadata.graph_walk!r}, " - f"is_prefill={metadata.is_prefill}" + return ForwardPassArgs( + full_metadata=metadata, + inputs=inputs, + unpersist_tensors=unpersist_tensors, + step_metadata=step_metadata, ) # -- Code2Wav state machine -------------------------------------------- diff --git a/mstar/streaming/chunk_policy.py b/mstar/streaming/chunk_policy.py index 6c527fe4..1ca1bee1 100644 --- a/mstar/streaming/chunk_policy.py +++ b/mstar/streaming/chunk_policy.py @@ -12,12 +12,12 @@ def register_chunk(self, chunk_size: int): self.items_consumed += chunk_size @abstractmethod - def is_ready(self, buffer_len: int) -> bool: + def is_ready(self, buffer_len: int, graph_walk: str | None=None) -> bool: """Return True if the buffer has enough items for a chunk.""" ... @abstractmethod - def next_chunk_size(self, buffer_len: int) -> int: + def next_chunk_size(self, buffer_len: int, graph_walk: str | None=None) -> int: """Return the number of items to consume for the next chunk. Only called when is_ready() returns True. @@ -26,7 +26,7 @@ def next_chunk_size(self, buffer_len: int) -> int: ... @abstractmethod - def window_size(self) -> int: + def window_size(self, graph_walk: str | None=None) -> int: """Return the full window of items to include in the chunk. For non-overlapping policies, equals next_chunk_size. @@ -35,7 +35,7 @@ def window_size(self) -> int: """ ... - def continue_after_producer_done(self) -> bool: + def continue_after_producer_done(self, graph_walk: str | None=None) -> bool: """Whether the buffer should keep producing (empty) chunks after the producer signals done and all buffered items have been consumed. @@ -66,13 +66,13 @@ def __init__(self, window: int, stride: int): self._window = window self._stride = stride - def is_ready(self, buffer_len: int) -> bool: + def is_ready(self, buffer_len: int, graph_walk: str | None = None) -> bool: return buffer_len >= self._window - def next_chunk_size(self, buffer_len: int) -> int: + def next_chunk_size(self, buffer_len: int, graph_walk: str | None = None) -> int: return self._stride - def window_size(self) -> int: + def window_size(self, graph_walk: str | None = None) -> int: return self._window @@ -104,19 +104,19 @@ def __init__(self, chunk: int, left_context: int): self._left_context = left_context self._window = chunk + left_context - def is_ready(self, buffer_len: int) -> bool: + def is_ready(self, buffer_len: int, graph_walk: str | None = None) -> bool: if not self.first_chunk_read: return buffer_len >= self._chunk return buffer_len >= self._window - def next_chunk_size(self, buffer_len: int) -> int: + def next_chunk_size(self, buffer_len: int, graph_walk: str | None = None) -> int: # First pop: advance by (chunk - left_context) so the tail of the # first chunk stays in the buffer as overlap for the next pop. if not self.first_chunk_read: return self._chunk - self._left_context return self._chunk - def window_size(self) -> int: + def window_size(self, graph_walk: str | None = None) -> int: if not self.first_chunk_read: return self._chunk return self._window @@ -134,19 +134,20 @@ class FixedChunkPolicy(ChunkPolicy): the producer finishes and all buffered items are consumed. """ - def __init__(self, chunk_size: int, continue_after_done: bool = False): + def __init__(self, chunk_size: int, continue_after_done: set[str] | None=None): super().__init__() self._chunk_size = chunk_size self._continue_after_done = continue_after_done - def is_ready(self, buffer_len) -> bool: + def is_ready(self, buffer_len, graph_walk: str | None = None) -> bool: return buffer_len >= self._chunk_size - def next_chunk_size(self, buffer_len: int) -> int: + def next_chunk_size(self, buffer_len: int, graph_walk: str | None = None) -> int: return self._chunk_size - def window_size(self) -> int: + def window_size(self, graph_walk: str | None = None) -> int: return self._chunk_size - def continue_after_producer_done(self) -> bool: - return self._continue_after_done + def continue_after_producer_done(self, graph_walk: str | None = None) -> bool: + return self._continue_after_done is not None \ + and graph_walk in self._continue_after_done diff --git a/mstar/streaming/stream_buffer.py b/mstar/streaming/stream_buffer.py index 1f8bed9f..82385b91 100644 --- a/mstar/streaming/stream_buffer.py +++ b/mstar/streaming/stream_buffer.py @@ -12,8 +12,25 @@ class StreamChunk: """A chunk of data popped from a StreamBuffer.""" data: dict[str, torch.Tensor | None] chunk_index: int + # the next stream index to drain after this chunk is consumed (one past the + # last item in the chunk); used to re-seed a worker that takes over the + # partition (e.g. prefill->decode PD handoff) + next_stream_index: int start_offset: int = 0 # global position of the first item in this chunk is_final: bool = False + graph_walk_transition: str | None = None + + +@dataclass +class StreamingTensor: + index: int + tensor: torch.Tensor + graph_walk: str | None = None + +@dataclass +class WaitingEdge: + edge: GraphEdge + walk_transition: str | None = None @dataclass @@ -32,48 +49,81 @@ class StreamBuffer: from_partition: str policy: ChunkPolicy + # graph edges of chunks that have been popped but not ingested _waiting_graph_edges: deque = field(default_factory=deque) - _buffer: list = field(default_factory=list) - _tensor_ids_in_order: deque = field(default_factory=deque) - _id_to_tensor: dict = field(default_factory=dict) + # edge index -> tensor and metadata + _tensors: dict[int, StreamingTensor] = field(default_factory=dict) + _buffer: list[StreamingTensor] = field(default_factory=list) + _current_index: int = 0 _consumed: int = 0 _chunks_popped: int = 0 producer_done: bool = False + _producer_edge_idx: int | None = None _num_tensors_registered = 0 _num_buffer_writes = 0 - def pre_read_register(self, tensor_id: str): + def pre_read_register(self): + """ + Register that we are reading a tensor so we don't prematurely declare + the producer as done. + """ self._num_tensors_registered += 1 - self._tensor_ids_in_order.append(tensor_id) - - def put(self, tensor_id: str, item: torch.Tensor) -> None: - """Called when a tensor arrives via normal RDMA routing.""" - self._id_to_tensor[tensor_id] = item def _update_buffer(self): - while len(self._tensor_ids_in_order) > 0: - tensor_id = self._tensor_ids_in_order[0] - if tensor_id not in self._id_to_tensor: - return - self._tensor_ids_in_order.popleft() - self._buffer.append(self._id_to_tensor[tensor_id]) - self._num_buffer_writes += 1 - del self._id_to_tensor[tensor_id] - - def signal_done(self) -> None: + while self._current_index in self._tensors: + self._buffer.append(self._tensors.pop(self._current_index)) + self._current_index += 1 + + def put(self, item: torch.Tensor, index: int, graph_walk: str | None = None) -> None: + """Called when a tensor arrives via normal RDMA routing. + + Idempotent by index: if this index has already been buffered or has + already been drained into ``_buffer`` (``index < _current_index``), + the duplicate is dropped (first-arrival-wins). This handles the case + where multiple colocated producer ranks emit the same streaming item. + """ + # Counts put attempts, not unique items: incremented before the dedup + # return so it stays balanced with ``_num_tensors_registered`` (which + # ``pre_read_register`` also bumps once per registered tensor, including + # duplicates). ``_producer_done_and_all_read`` relies on that symmetry. + self._num_buffer_writes += 1 + if index < self._current_index or index in self._tensors: + return + self._tensors[index] = StreamingTensor( + index=index, + tensor=item, + graph_walk=graph_walk, + ) + + def signal_done(self, producer_edge_idx: int | None=None) -> None: """Producer signals no more items will arrive.""" self.producer_done = True + self._producer_edge_idx = producer_edge_idx + + def set_index(self, index: int): + """Seed the next index to drain (e.g. when a new consumer worker takes + over a partition after a prefill->decode handoff). + + Only ever advances: indices are monotonic, and the conductor's tracked + value lags (it is refreshed only at WorkerGraphsDone). Rewinding a live + buffer would point ``_current_index`` at items already popped out of + ``_tensors``, deadlocking the drain. + """ + self._current_index = max(self._current_index, index) def _producer_done_and_all_read(self) -> bool: - return self.producer_done and self._num_buffer_writes >= self._num_tensors_registered + return self.producer_done and ( + self._producer_edge_idx is None or \ + self._current_index >= self._producer_edge_idx + ) and self._num_buffer_writes >= self._num_tensors_registered - def pop_waiting_edge(self) -> GraphEdge | None: + def pop_waiting_edge(self) -> WaitingEdge | None: if len(self._waiting_graph_edges) > 0: return self._waiting_graph_edges.popleft() - def has_chunk_ready(self) -> bool: + def has_chunk_ready(self, graph_walk: str) -> bool: self._update_buffer() buf_len = len(self._buffer) if self._producer_done_and_all_read() and buf_len > 0: @@ -84,32 +134,79 @@ def has_chunk_ready(self) -> bool: # generating codec tokens after the Thinker hits text EOS). if (self._producer_done_and_all_read() and buf_len == 0 - and self.policy.continue_after_producer_done()): + and self.policy.continue_after_producer_done(graph_walk)): return True - return self.policy.is_ready(buf_len) - - def pop_chunk(self) -> StreamChunk: + return self.policy.is_ready(buf_len, graph_walk) + + def _chunk_boundary( + self, current_walk: str, max_len: int + ) -> tuple[int, str | None]: + """Bound a chunk so a producer-triggered walk transition starts fresh. + + A producer-triggered graph-walk transition must mark a chunk boundary: + the consumer runs one forward pass per walk, so a chunk cannot straddle + two walks. + + Returns ``(boundary, transition)`` where: + - ``transition`` is the walk this chunk runs under, taken from the + *first* buffered item if it carries a transition to a walk other + than ``current_walk`` (else ``None`` — walk unchanged). + - ``boundary`` is the number of leading items that share this chunk's + walk, clamped to ``max_len``. Any later item carrying a transition + to a *different* walk forces the boundary before it, so it becomes + the leading item of the next chunk. + """ + if not self._buffer or max_len <= 0: + return max(max_len, 0), None + + # (1) Leading transition: the walk this chunk runs under. + transition = None + chunk_walk = current_walk + first = self._buffer[0].graph_walk + if first is not None and first != current_walk: + transition = first + chunk_walk = first + + # (2) Cut before any later item that transitions to a different walk. + boundary = min(max_len, len(self._buffer)) + for j in range(1, boundary): + gw = self._buffer[j].graph_walk + if gw is not None and gw != chunk_walk: + boundary = j + break + return boundary, transition + + def pop_chunk(self, graph_walk: str) -> StreamChunk: """Pop the next chunk. Only call when has_chunk_ready() is True. For sliding-window: returns `window_size` items, advances by `stride` items, discards items that have fallen out of the window. start_offset is the global position of the first item in the chunk. + + A producer-triggered walk transition forces a chunk boundary (see + ``_chunk_boundary``): the returned chunk never straddles two walks, and + ``graph_walk_transition`` carries the walk this chunk runs under. """ self._update_buffer() buf_len = len(self._buffer) - window = self.policy.window_size() offset = self._consumed # global position of buffer[0] - if self._producer_done_and_all_read() and not self.policy.is_ready(buf_len): - # Flush remainder — return whatever is left (may be empty) - items = list(self._buffer) - self._buffer.clear() - self._consumed += len(items) - stride = len(items) + if self._producer_done_and_all_read() and not self.policy.is_ready(buf_len, graph_walk): + # Flush remainder — return whatever is left (may be empty), still + # cut at the first walk transition so each walk gets its own pass. + boundary, transition = self._chunk_boundary(graph_walk, len(self._buffer)) + items = self._buffer[:boundary] + self._buffer = self._buffer[boundary:] + self._consumed += boundary + stride = boundary else: - stride = self.policy.next_chunk_size(buf_len) - # Return the first `window` items (overlapping sliding window) - items = self._buffer[:window] + stride = self.policy.next_chunk_size(buf_len, graph_walk) + window = self.policy.window_size(graph_walk) + # Bound the window/stride so a transition starts a fresh chunk. + boundary, transition = self._chunk_boundary(graph_walk, window) + stride = min(stride, boundary) + # Return the first `boundary` items (overlapping sliding window). + items = self._buffer[:boundary] # Advance by stride — discard items that fell out of the window self._buffer = self._buffer[stride:] self._consumed += stride @@ -118,20 +215,25 @@ def pop_chunk(self) -> StreamChunk: is_final = self._producer_done_and_all_read() and len(self._buffer) == 0 # When continue_after_producer_done, never mark as final — the # consumer decides when it's done via its own model logic. - if self.policy.continue_after_producer_done(): + if self.policy.continue_after_producer_done(graph_walk): is_final = False chunk = StreamChunk( - data=self._collate(items), + data=self._collate([it.tensor for it in items]), chunk_index=self._chunks_popped, + next_stream_index=items[-1].index + 1 if items else self._current_index, start_offset=offset, is_final=is_final, + graph_walk_transition=transition, ) self._chunks_popped += 1 return chunk - def store_uningested_edge(self, edge: GraphEdge): - self._waiting_graph_edges.append(edge) + def store_uningested_edge(self, edge: GraphEdge, walk_transition: str | None=None): + self._waiting_graph_edges.append(WaitingEdge( + edge=edge, + walk_transition=walk_transition + )) def _collate(self, items: list) -> dict[str, torch.Tensor | None]: if not items: diff --git a/mstar/streaming/topology.py b/mstar/streaming/topology.py index d227679d..12b8939a 100644 --- a/mstar/streaming/topology.py +++ b/mstar/streaming/topology.py @@ -1,6 +1,7 @@ from dataclasses import dataclass, field from typing import Callable +from mstar.conductor.request_info import CurrentForwardPassInfo from mstar.graph.base import GraphEdge from mstar.streaming.chunk_policy import ChunkPolicy @@ -15,10 +16,46 @@ class StreamingGraphEdge(GraphEdge): consuming node's input. """ target_partition: str = "" + _index: int = 0 + _graph_walk_transition: str | None = None def __post_init__(self): self.is_streaming = True + def clone(self): + return StreamingGraphEdge( + next_node=self.next_node, + name=self.name, + tensor_info=self.tensor_info[:], + persist=self.persist, + conductor_new_token=self.conductor_new_token, + is_streaming=self.is_streaming, + output_modality=self.output_modality, + _persist_for_loop=self._persist_for_loop, + _total_fanin=self._total_fanin, + _shard_dim=self._shard_dim, + _target_graph_walk=self._target_graph_walk, + target_partition=self.target_partition, + _index=self._index, + _graph_walk_transition=self._graph_walk_transition, + ) + + +@dataclass(frozen=True) +class ConsumerTransitionCtx: + producer_walk: str + consumer_walk: str | None # None on the very first trigger + # Unused by qwen3omni's transition fn, but exposed so a future model can + # base its transition on the producer's full forward-pass state. This is + # the sole reason streaming depends on conductor.request_info; revisit + # (e.g. a lighter type) if that dependency direction becomes a problem. + producer_fwd: CurrentForwardPassInfo + + +@dataclass +class WalkTransition: + graph_walk: str | None = None + # TODO [FUTURE]: hook up metadata if needed @dataclass class Connection: @@ -27,6 +64,7 @@ class Connection: to_partition: str edge_name: str chunk_policy_factory: Callable[[], ChunkPolicy] + consumer_walk_transition: Callable[[ConsumerTransitionCtx], WalkTransition] | None = None @dataclass diff --git a/mstar/utils/ipc_format.py b/mstar/utils/ipc_format.py index 1eb6ae77..984e273d 100644 --- a/mstar/utils/ipc_format.py +++ b/mstar/utils/ipc_format.py @@ -30,6 +30,7 @@ class WorkerMessageType(Enum): NEW_REQUEST = "new_request" REMOVE_REQUEST = "remove_request" INPUT_SIGNALS = "input_signals" + PRODUCER_DONE = "producer_done" UNPERSIST_TENSORS = "unpersist" TENSOR_RECEIVED = "tensor_received" SCHEDULE_TP = "schedule_tp" @@ -56,6 +57,13 @@ class InputSignals(MessageBody): inputs: list[GraphEdge] request_info: CurrentForwardPassInfo partition_name: str = "default" + + +@dataclass +class ProducerDone(MessageBody): + request_id: str + partition_name: str + last_produced_edge_idx: dict[str, int] producer_done: set = field(default_factory=set) @@ -119,9 +127,15 @@ class WorkerGraphsDone(MessageBody): persist_signals: dict[str, list[TensorPointerInfo]] = field(default_factory=dict) new_tokens: dict[str, list[int]] = field(default_factory=dict) # name to tokens output_signal_names: int = field(default=0) + new_produced_edge_idx: dict[str, int] = field(default_factory=dict) + new_next_stream_index: dict[str, int] = field(default_factory=dict) + consumer_graph_walk_transitions: dict[str, str] = field(default_factory=dict) per_label_seq_info: PerLabelSeqInfo = field(default_factory=PerLabelSeqInfo) partition_name: str = field(default="default") partition_done: bool = field(default=False) + # the graph walk this partition's just-completed forward pass ran under; + # used by the conductor to track a producer-triggered partition's walk + partition_graph_walk: str | None = field(default=None) stream_tokens_consumed: dict[str, int] = field(default_factory=dict) # edge_name -> tokens consumed from stream output_loop_indices: dict[str, NestedLoopIndices] = field(default_factory=dict) diff --git a/mstar/worker/node_manager_utils.py b/mstar/worker/node_manager_utils.py index bb8ef000..4e4ca382 100644 --- a/mstar/worker/node_manager_utils.py +++ b/mstar/worker/node_manager_utils.py @@ -193,6 +193,10 @@ class PerPartitionInfo: graph_walk_worker_graph_ids: list[str] = field(default_factory=list) # for this worker stream_partition_done: bool = False # set True when last chunk pops with is_final + # edges that are pending a producer-triggered graph walk transition + # {graph walk -> edges} + pending_edges: dict[str, list[GraphEdge]] = field(default_factory=dict) + @dataclass class PerRequestInfo: @@ -268,28 +272,74 @@ def __post_init__(self): # ambiguity is moot for routing. self.walk_node_to_worker_graph_id[(walk, node)] = wg_id + def buffer_pending_edge( + self, request_id: str, graph_walk: str, edge: GraphEdge + ): + partition_name = self.get_partition_for_node(edge.next_node) + req_info = self.per_request_info[request_id] + part_info = req_info.per_partition_info[partition_name] + part_info.pending_edges.setdefault(graph_walk, []).append(edge) + def update_request_info( self, request_id: str, partition_name, current_fwd_info: CurrentForwardPassInfo | None=None, per_label_seq_info: PerLabelSeqInfo | None=None, + allow_graph_walk_transition: bool=True ): req_info = self.per_request_info[request_id] part_info = req_info.per_partition_info[partition_name] if current_fwd_info is not None: - graph_walk = current_fwd_info.graph_walk - if self.get_graph_walk(request_id, partition_name) != graph_walk: - part_info.graph_walk_worker_graph_ids = [ - graph_id for graph_id in self.per_request_info[request_id].worker_graph_ids \ - if graph_walk in self.all_worker_graph_ids_to_graph_walks[graph_id] - ] + # produced_edge_idx is a monotonic local counter for streaming + # output. The conductor round-trips it but its copy lags (refreshed + # only at WorkerGraphsDone), so a stale value here would rewind the + # counter and re-emit an already-used stream index — which the + # consumer drops as a duplicate, taking its walk-transition tag with + # it. Never let the conductor rewind it (mirrors the set_index() + # max-guard on the consumer side). + old_fwd_info = part_info.current_fwd_info + if old_fwd_info is not None: + for name, idx in old_fwd_info.produced_edge_idx.items(): + current_fwd_info.produced_edge_idx[name] = max( + current_fwd_info.produced_edge_idx.get(name, 0), idx + ) + # The conductor's per_label_seq_info lags (refreshed only at + # WorkerGraphsDone), so the wholesale replace below could rewind + # the KV write position mid-decode and overwrite live KV. Keep + # the locally-advanced (longer) per-(cache, label) seq info. + current_fwd_info.per_label_seq_info.merge_keep_longest( + old_fwd_info.per_label_seq_info + ) + if allow_graph_walk_transition: + self.update_graph_walk(request_id, partition_name, current_fwd_info.graph_walk) + else: + # Producer-triggered partition: the stream owns the walk. Absorb + # everything else from the conductor but keep the locally + # (stream-)applied walk so the wholesale replace below doesn't + # clobber it. (Inputs were already tagged with the conductor's + # intended walk before this call.) + current_fwd_info.graph_walk = self.get_graph_walk(request_id, partition_name) part_info.current_fwd_info = current_fwd_info if per_label_seq_info is not None: fwd_info = self.get_fwd_info(request_id, partition_name) fwd_info.per_label_seq_info.update(per_label_seq_info) + def update_graph_walk(self, request_id: str, partition_name: str, graph_walk: str): + part_info = self.per_request_info[request_id].per_partition_info[partition_name] + if self.get_graph_walk(request_id, partition_name) != graph_walk: + part_info.graph_walk_worker_graph_ids = [ + graph_id for graph_id in self.per_request_info[request_id].worker_graph_ids \ + if graph_walk in self.all_worker_graph_ids_to_graph_walks[graph_id] + ] + part_info.current_fwd_info.graph_walk = graph_walk + if graph_walk in part_info.pending_edges: + edges = part_info.pending_edges.pop(graph_walk) + self.process_new_inputs( + request_id, inputs=edges + ) + def get_graph_walk(self, request_id: str, partition_name: str): return self.get_fwd_info(request_id, partition_name).graph_walk @@ -310,6 +360,18 @@ def get_partition_for_node(self, node_name: str) -> str | None: """Look up which partition a node belongs to.""" return self.node_to_partition.get(node_name) + def partition_clean(self, request_id: str, partition_name: str) -> bool: + if request_id not in self.per_request_info: + return True + wgids = self.per_request_info[request_id].per_partition_info[partition_name].graph_walk_worker_graph_ids + for wgid in wgids: + queue = self.queues[wgid].per_request_queues.get(request_id) + if queue is None: + continue + if not queue.wg_state_registry.clean: + return False + return True + def process_new_inputs( self, request_id: str, @@ -323,6 +385,20 @@ def process_new_inputs( ``next_node`` lives on a different worker). Caller uses these for cross-worker routing. """ + wrong_graph_walk = [ + edge for edge in inputs if edge._target_graph_walk is not None \ + and edge._target_graph_walk != self.get_graph_walk( + request_id, partition_name=self.get_partition_for_node(edge.next_node) + ) + ] + for edge in wrong_graph_walk: + self.buffer_pending_edge(request_id, edge._target_graph_walk, edge) + inputs = [ + edge for edge in inputs if edge._target_graph_walk is None \ + or edge._target_graph_walk == self.get_graph_walk( + request_id, partition_name=self.get_partition_for_node(edge.next_node) + ) + ] for part_info in self.per_request_info[request_id].per_partition_info.values(): worker_graph_ids = part_info.graph_walk_worker_graph_ids for worker_graph_id in worker_graph_ids: @@ -496,7 +572,7 @@ def process_node_outputs( fanout = sharding_config.fanout_graph_edges( edge, source_node=node_name, source_graph_walk=graph_walk, - dest_graph_walk=None + dest_graph_walk=edge._graph_walk_transition ) this_worker_edge = fanout.pop(self.worker_id, None) if this_worker_edge: diff --git a/mstar/worker/worker.py b/mstar/worker/worker.py index 2e85ec53..61318679 100644 --- a/mstar/worker/worker.py +++ b/mstar/worker/worker.py @@ -15,7 +15,7 @@ from mstar.communication.communicator import CommProtocol, ZMQCommunicator from mstar.communication.event import EventWakeup from mstar.communication.tensors import NameToTensorList, create_tensor_communication_manager -from mstar.conductor.request_info import CurrentForwardPassInfo +from mstar.conductor.request_info import CurrentForwardPassInfo, TransitionSource from mstar.distributed.base import ShardingConfig from mstar.distributed.communication import WorkerTPGroups from mstar.engine.base import EngineType, NodeBatch, NodeOutput @@ -25,11 +25,13 @@ from mstar.graph.loop_indices import NestedLoopIndices from mstar.model.base import Model, WorkerGraph from mstar.streaming.stream_buffer import StreamBuffer +from mstar.streaming.topology import Connection, ConsumerTransitionCtx, WalkTransition from mstar.utils.ipc_format import ( ConductorMessage, ConductorMessageType, InputSignals, NewRequest, + ProducerDone, RemoveRequest, ScheduleTPNode, SetupDone, @@ -279,16 +281,14 @@ def __init__( tuple[str, torch.dtype, tuple[int, ...]], list[torch.Tensor] ] = defaultdict(list) - # Streaming buffers: request_id -> edge_name -> list of tensors - # (Legacy path — kept for models without PartitionTopology) - self.streaming_buffers: dict[str, dict[str, list[torch.Tensor]]] = {} - # New streaming path: PartitionTopology + StreamBuffer on consumer worker self.partition_topology = model.get_partition_topology() if model else None + self.partitions = model.get_partitions() if model else [] # Determine which partition this worker serves (by checking which node names # appear in my_worker_graphs vs the topology connections) - self._my_consumer_connections = [] + self._my_consumer_connections: list[Connection] = [] + self.edge_to_connection: dict[str, Connection] = {} if self.partition_topology: my_node_names = set() for wg in my_worker_graphs: @@ -298,6 +298,7 @@ def __init__( # by checking if the streaming edge's next_node is in my nodes if any(n in my_node_names for n in self._get_node_names_for_partition(conn.to_partition, model)): self._my_consumer_connections.append(conn) + self.edge_to_connection[conn.edge_name] = conn # Set of edge names that arrive via streaming (used to distinguish # streaming inputs from conductor-triggered non-streaming inputs @@ -306,6 +307,10 @@ def __init__( conn.edge_name for conn in self._my_consumer_connections } + self._producer_triggered_partitions: set[str] = set([ + part.name for part in self.partitions if part.transition_source == TransitionSource.PRODUCER_TRIGGERED + ]) + # Build consumer node cache: edge_name -> next_node name self._consumer_node_cache: dict[str, str] = {} if self._my_consumer_connections and model: @@ -446,7 +451,6 @@ def _remove_request(self, body: RemoveRequest) -> None: self.engine_manager.remove_request(body.request_id) self.worker_graphs_manager.remove_request(body.request_id) self.tensor_manager.cleanup_request(body.request_id) - self.streaming_buffers.pop(body.request_id, None) for node_name in self.engine_manager.lru_tracked_nodes(): self._last_active.pop((body.request_id, node_name), None) @@ -458,6 +462,18 @@ def _handle_tensor_received(self, body: TensorReceived) -> None: body.request_id, uuid, n=ref_cnt ) + def _handle_producer_done(self, body: ProducerDone) -> None: + # Handle producer_done signal: mark all StreamBuffers for this request as done + req_info = self.worker_graphs_manager.per_request_info.get(body.request_id) + if req_info: + for sbuf in req_info.stream_buffers.values(): + if sbuf.from_partition in body.producer_done: + # If we have multiple consumer partitions colocated, we need to signal + # the right one + sbuf.signal_done( + body.last_produced_edge_idx.get(sbuf.edge_name) + ) + def _process_new_inputs(self, body: InputSignals) -> None: logger.debug( "Received new signals %s at worker %s for request %s", @@ -467,14 +483,6 @@ def _process_new_inputs(self, body: InputSignals) -> None: if self.enable_nvtx: range_push("process_new_inputs.routing_update") - # Handle producer_done signal: mark all StreamBuffers for this request as done - if body.producer_done: - if req_info: - for sbuf in req_info.stream_buffers.values(): - if sbuf.from_partition in body.producer_done: - # If we have multiple consumer partitions colocated, we need to signal - # the right one - sbuf.signal_done() # Separate streaming edges — they'll be handled when tensors are ready # (streaming edges with tensor_info go through RDMA, handled in _check_ready_tensors) @@ -485,11 +493,23 @@ def _process_new_inputs(self, body: InputSignals) -> None: # a conductor-triggered forward pass, not just streaming data from another # partition). Streaming-only InputSignals must not overwrite the current # partition's fwd_info. - if non_streaming: + if non_streaming or not body.inputs: + # Tag each input with the walk the conductor intends it for. Must be + # done BEFORE update_request_info: for producer-triggered partitions + # that call rewrites body.request_info.graph_walk to the stream walk. + # If the tag differs from the partition's current (stream-applied) + # walk, the ingest-time gate buffers it until the transition. + for edge in non_streaming: + edge._target_graph_walk = body.request_info.graph_walk self.worker_graphs_manager.update_request_info( body.request_id, current_fwd_info=body.request_info, - partition_name=body.partition_name + partition_name=body.partition_name, + allow_graph_walk_transition=body.partition_name not in self._producer_triggered_partitions ) + for edge, idx in body.request_info.next_stream_index.items(): + req_info = self.worker_graphs_manager.per_request_info.get(body.request_id) + if edge in req_info.stream_buffers: + req_info.stream_buffers[edge].set_index(idx) if self.enable_nvtx: range_pop(synchronize=False) @@ -508,8 +528,8 @@ def _process_new_inputs(self, body: InputSignals) -> None: self.wakeup_event.register_futures(futures) for edge in streaming_with_tensors: stream_buf = req_info.stream_buffers[edge.name] - for info in edge.tensor_info: - stream_buf.pre_read_register(info.uuid) + for _ in edge.tensor_info: + stream_buf.pre_read_register() if self.enable_nvtx: range_pop(synchronize=False) range_push("process_new_inputs.process_inputs") @@ -560,7 +580,8 @@ def _process_message_list(self, messages: list[WorkerMessage]): msg_types_needing_active_request = [ WorkerMessageType.REMOVE_REQUEST, WorkerMessageType.INPUT_SIGNALS, - WorkerMessageType.STOP_LOOPS + WorkerMessageType.STOP_LOOPS, + WorkerMessageType.PRODUCER_DONE ] for message in messages: if ( @@ -578,6 +599,8 @@ def _process_message_list(self, messages: list[WorkerMessage]): self._remove_request(message.body) elif message.message_type == WorkerMessageType.INPUT_SIGNALS: self._process_new_inputs(message.body) + elif message.message_type == WorkerMessageType.PRODUCER_DONE: + self._handle_producer_done(message.body) elif message.message_type == WorkerMessageType.TENSOR_RECEIVED: self._handle_tensor_received(message.body) elif message.message_type == WorkerMessageType.UNPERSIST_TENSORS: @@ -604,16 +627,38 @@ def _route_streaming_tensor(self, request_id: str, edge: GraphEdge) -> None: request_id=request_id, uuid=info.uuid, ) - stream_buf.put(info.uuid, tensor.clone()) + stream_buf.put( + tensor.clone(), + index=edge._index, + graph_walk=edge._graph_walk_transition, + ) self.tensor_manager.dereference(request_id, info.uuid) def _pop_streaming_edge( - self, sbuf: StreamBuffer, edge_name: str, request_id: str + self, sbuf: StreamBuffer, edge_name: str, request_id: str, + # for speculation, it is messy to allow graph walk transitions here + allow_graph_walk_transition: bool=True ) -> GraphEdge | None: consumer_node = self._consumer_node_cache.get(edge_name, "") - synthetic_edge = sbuf.pop_waiting_edge() - if synthetic_edge is None and sbuf.has_chunk_ready(): - chunk = sbuf.pop_chunk() + consumer_partition = self.worker_graphs_manager.get_partition_for_node(consumer_node) + graph_walk = self.worker_graphs_manager.get_graph_walk( + request_id, consumer_partition + ) + + waiting_edge = sbuf.pop_waiting_edge() + if waiting_edge is not None and waiting_edge.walk_transition is not None \ + and waiting_edge.walk_transition != graph_walk: + if not allow_graph_walk_transition: + sbuf.store_uningested_edge(waiting_edge.edge, waiting_edge.walk_transition) + return + self.worker_graphs_manager.update_graph_walk( + request_id, consumer_partition, + waiting_edge.walk_transition + ) + synthetic_edge = waiting_edge.edge if waiting_edge is not None else None + + if synthetic_edge is None and sbuf.has_chunk_ready(graph_walk): + chunk = sbuf.pop_chunk(graph_walk) chunk_tensor = chunk.data.get("data") if chunk_tensor is None: # Empty chunk — producer done, no more data. @@ -622,6 +667,7 @@ def _pop_streaming_edge( next_node=consumer_node, name=edge_name, tensor_info=[], + _next_stream_index=chunk.next_stream_index, _final_stream_chunk=chunk.is_final, ) else: @@ -640,8 +686,17 @@ def _pop_streaming_edge( next_node=consumer_node, name=edge_name, tensor_info=tensor_infos.get(edge_name, []), + _next_stream_index=chunk.next_stream_index, _final_stream_chunk=chunk.is_final, ) + if chunk.graph_walk_transition is not None and chunk.graph_walk_transition != graph_walk: + if not allow_graph_walk_transition: + sbuf.store_uningested_edge(synthetic_edge, chunk.graph_walk_transition) + return + self.worker_graphs_manager.update_graph_walk( + request_id, consumer_partition, + chunk.graph_walk_transition + ) return synthetic_edge def _poll_stream_buffers_for_speculation( @@ -655,7 +710,10 @@ def _poll_stream_buffers_for_speculation( consumer_node = self._consumer_node_cache.get(edge_name, "") if consumer_node != node_name: continue - edge = self._pop_streaming_edge(sbuf, edge_name, request_id) + edge = self._pop_streaming_edge( + sbuf, edge_name, request_id, + allow_graph_walk_transition=False + ) if edge is not None: result.append(edge) return result @@ -674,7 +732,24 @@ def _poll_stream_buffers(self) -> None: """Check all active StreamBuffers; when a chunk is ready, feed it as a normal input.""" for request_id, req_info in list(self.worker_graphs_manager.per_request_info.items()): for edge_name, sbuf in req_info.stream_buffers.items(): - synthetic_edge = self._pop_streaming_edge(sbuf, edge_name, request_id) + consumer_node = self._consumer_node_cache.get(edge_name, "") + partition_name = self.worker_graphs_manager.get_partition_for_node(consumer_node) + + if not self.worker_graphs_manager.has_partition(request_id, partition_name): + continue + + # Only transition the walk when the worker graph is in a fully + # reset (clean) state — no partial progress that the scheduler + # would otherwise dispatch under the new walk. A request with no + # queue entry yet has no progress, so allow the transition. + allow_graph_walk_transition = self.worker_graphs_manager.partition_clean( + request_id, partition_name + ) + + synthetic_edge = self._pop_streaming_edge( + sbuf, edge_name, request_id, + allow_graph_walk_transition=allow_graph_walk_transition + ) if synthetic_edge is not None: # Streaming edges go through the same path as regular ones — @@ -900,6 +975,7 @@ def _register_outputs( def _send_outputs( self, request_id: str, outputs: NodeOutputRouting, nested_loop_indices: NestedLoopIndices, + consumer_graph_walk_transitions, graph_walk: str | None = None, partition_name: str | None = None, prematerialized_new_tokens: dict[str, list[int]] | None = None, @@ -925,7 +1001,7 @@ def _send_outputs( message_type=WorkerMessageType.INPUT_SIGNALS, body=InputSignals( request_id=request_id, - inputs=edges, + inputs=[edge.clone() for edge in edges], request_info=self.worker_graphs_manager.get_fwd_info(request_id, partition_name), partition_name=partition_name ), @@ -983,13 +1059,13 @@ def _send_outputs( ) self.communicator.send("api_server", message) - # Handle streaming edges - # Local streaming: route to StreamBuffer req_info = self.worker_graphs_manager.per_request_info[request_id] + fwd_info = self.worker_graphs_manager.get_fwd_info(request_id, partition_name) + for edge in outputs.streaming_local: stream_buf = req_info.stream_buffers[edge.name] - for info in edge.tensor_info: - stream_buf.pre_read_register(info.uuid) + for _ in edge.tensor_info: + stream_buf.pre_read_register() self._route_streaming_tensor(request_id, edge) # Remote streaming: send to destination workers @@ -999,13 +1075,12 @@ def _send_outputs( body=InputSignals( request_id=request_id, inputs=edges, - request_info=self.worker_graphs_manager.get_fwd_info(request_id, partition_name), + request_info=fwd_info, partition_name=partition_name ), ) self.communicator.send(worker_id, message) if outputs.completed_worker_graph_ids: - fwd_info = self.worker_graphs_manager.get_fwd_info(request_id, partition_name) if partition_name is None: partition_name = getattr(fwd_info, 'partition_name', 'default') req_info = self.worker_graphs_manager.per_request_info.get(request_id) @@ -1029,9 +1104,23 @@ def _send_outputs( persist_signals=self.worker_graphs_manager.flush_persist_signals(request_id), new_tokens=self.worker_graphs_manager.flush_new_tokens(request_id), output_signal_names=self.worker_graphs_manager.flush_output_signals(request_id), - per_label_seq_info=self.worker_graphs_manager.get_seq_info(request_id, partition_name), + per_label_seq_info=fwd_info.per_label_seq_info, + new_produced_edge_idx=fwd_info.produced_edge_idx, + new_next_stream_index=fwd_info.next_stream_index, + # Key by consumer partition (not edge name) so it merges + # into the producer pstate's partition-keyed + # tracked_consumer_graph_walks on the conductor. + consumer_graph_walk_transitions={ + self.edge_to_connection[edge_name].to_partition: wt.graph_walk + for edge_name, wt in consumer_graph_walk_transitions.items() + }, partition_name=partition_name, partition_done=p_done, + # the walk this just-completed forward pass ran under (the + # batch's walk, NOT fwd_info.graph_walk which may have + # already advanced via a stream-buffer transition before + # this report is built — that race drops talker_input_embeds) + partition_graph_walk=graph_walk, stream_tokens_consumed=stream_consumed, output_loop_indices=self.worker_graphs_manager.get_output_loop_indices(request_id), ), @@ -1535,9 +1624,9 @@ def _thread_outputs_to_speculative( speculation.node_batch.per_request_info.pop(r, None) speculation.scheduled_batch.request_to_worker_graph.pop(r, None) speculation.scheduled_batch.node_objects.pop(r, None) - for edge in speculation.consumed_streaming_edges.get(rid, []): - self._return_speculative_streaming_edge(rid, edge) - speculation.consumed_streaming_edges.pop(rid, None) + for edge in speculation.consumed_streaming_edges.get(r, []): + self._return_speculative_streaming_edge(r, edge) + speculation.consumed_streaming_edges.pop(r, None) speculation.continuing_rids = threaded_continuing speculation.dropped = dropped @@ -1554,6 +1643,17 @@ def _postprocess_batch( self, batch_N: PendingBatch, output: NodeOutput, ): + for rid, node in batch_N.batch.node_objects.items(): + for edge_name, edge in node.ready_signals.ready_inputs.items(): + # Only synthetic streaming-input edges carry a next stream index; + # skip non-streaming inputs (None) so they don't write spurious + # zero entries into the reported next_stream_index. + if edge._next_stream_index is None: + continue + self.worker_graphs_manager.get_fwd_info( + rid, batch_N.partition + ).next_stream_index[edge_name] = edge._next_stream_index + if self.enable_nvtx: range_push("worker.postprocess.cleanup_inputs", synchronize=False) self._cleanup_consumed_inputs(batch_N.batch) @@ -1677,6 +1777,7 @@ def _postprocess_batch( # Mark nodes complete and route routing_per_request: dict[str, NodeOutputRouting] = {} per_request_uuids: dict[str, set[str]] = {} + consumer_walk_transitions_per_rid: dict[str, dict[str, WalkTransition]] = {} for rid, wg_id in batch_N.batch.request_to_worker_graph.items(): # Store output tensors before marking the node as complete so that # loop outputs can be buffered properly. @@ -1702,6 +1803,52 @@ def _postprocess_batch( ) real_outputs = [edge.clone() for edge in completion_output.output_edges] + # Handle streaming edges + # Local streaming: route to StreamBuffer + req_info = self.worker_graphs_manager.per_request_info[rid] + fwd_info = self.worker_graphs_manager.get_fwd_info(rid, batch_N.partition) + produced_streaming_edges: set[str] = set() + consumer_graph_walk_transitions: dict[str, WalkTransition] = {} + + # One stream index per logical edge per pass. A single edge may fan out + # to several physical copies (local + one per consumer-shard worker); + # they all carry the same _index so the consumer's dedup/PD-reseed sees + # one item, not N. Assign the index once, reuse it for every copy. + edge_indices: dict[str, int] = {} + for edge in real_outputs: + if not edge.is_streaming: + continue + if edge.name not in produced_streaming_edges: + produced_streaming_edges.add(edge.name) + edge_indices[edge.name] = fwd_info.produced_edge_idx.get(edge.name, 0) + fwd_info.produced_edge_idx[edge.name] = edge_indices[edge.name] + 1 + conn = self.edge_to_connection.get(edge.name) + if conn and conn.consumer_walk_transition: + consumer_graph_walk_transitions[edge.name] = conn.consumer_walk_transition( + ConsumerTransitionCtx( + producer_walk=batch_N.graph_walk, + consumer_walk=fwd_info.tracked_consumer_graph_walks.get( + conn.to_partition + ), + producer_fwd=fwd_info + ) + ) + + edge._index = edge_indices[edge.name] + walk_transition = consumer_graph_walk_transitions.get(edge.name) + edge._graph_walk_transition = walk_transition.graph_walk if walk_transition else \ + fwd_info.tracked_consumer_graph_walks.get( + self.edge_to_connection[edge.name].to_partition + ) + # Advance our local view of each consumer's walk by the transition just + # emitted: inside a dynamic loop there is no conductor round-trip to + # refresh tracked_consumer_graph_walks. After the edge loop so all edges + # of a pass see the same (pre-update) consumer walk. + for edge_name, wt in consumer_graph_walk_transitions.items(): + to_partition = self.edge_to_connection[edge_name].to_partition + fwd_info.tracked_consumer_graph_walks[to_partition] = wt.graph_walk + consumer_walk_transitions_per_rid[rid] = consumer_graph_walk_transitions + routing_per_request[rid] = self.worker_graphs_manager.process_node_outputs( rid, node_name=batch_N.node_name, outputs=real_outputs, graph_walk=batch_N.graph_walk @@ -1742,6 +1889,7 @@ def _postprocess_batch( self._send_outputs( rid, routing, nested_loop_indices=per_req_nested_idxs[rid], + consumer_graph_walk_transitions=consumer_walk_transitions_per_rid.get(rid, {}), graph_walk=batch_N.graph_walk, partition_name=batch_N.partition, node_speculatively_scheduled=batch_N.batch.node_objects[rid]._speculatively_scheduled