From ea49c3d28b0b2d05807719671a97cb7c8eb0a82c Mon Sep 17 00:00:00 2001 From: NSagan271 Date: Sun, 31 May 2026 20:32:37 +0000 Subject: [PATCH 01/16] first cut of producer-driven state transitions --- mstar/conductor/conductor.py | 73 +++++++++- mstar/conductor/request_info.py | 25 ++++ mstar/model/qwen3_omni/qwen3_omni_model.py | 108 ++++++--------- mstar/streaming/chunk_policy.py | 35 ++--- mstar/streaming/stream_buffer.py | 151 ++++++++++++++++----- mstar/streaming/topology.py | 16 +++ mstar/utils/ipc_format.py | 3 + mstar/worker/node_manager_utils.py | 9 ++ mstar/worker/worker.py | 99 +++++++++++--- 9 files changed, 382 insertions(+), 137 deletions(-) diff --git a/mstar/conductor/conductor.py b/mstar/conductor/conductor.py index 6a6471a4..17c62ca9 100644 --- a/mstar/conductor/conductor.py +++ b/mstar/conductor/conductor.py @@ -20,6 +20,7 @@ PartitionDefinition, PartitionState, StreamingConnectionState, + TransitionSource, ) from mstar.distributed.base import ShardingConfig from mstar.distributed.communication import GlobalTPConfig, WorkerTPGroups @@ -189,6 +190,7 @@ def __init__( ): self.requests: dict[str, RequestData] = {} self.model = model + self._validate_transition_sources() self.hostname = hostname self.socket_path_prefix = socket_path_prefix self.log_level = log_level @@ -564,6 +566,52 @@ def _try_admit_waiting(self): ) self._do_ingest_request(body) + def _validate_transition_sources(self): + """Validate that graph-walk transition authority is unambiguous. + + Each partition's walk is moved by exactly one source. A + ``PRODUCER_TRIGGERED`` partition must be driven by at least one streaming + connection carrying a ``consumer_walk_transition``; a ``STATE_MACHINE`` + partition must not be driven by any. Fails fast at conductor startup. + + (A producer-triggered partition may have several driving connections — + e.g. qwen3omni's thinker_states + thinker_mask both carry the marker so + each lands in the right walk.) + """ + pdefs = {p.name: p for p in self.model.get_partitions()} + topology = self.model.get_partition_topology() + + # Collect transition-carrying connections per consumer partition. + driving_conns: dict[str, list[str]] = {} + for conn in topology.connections: + if conn.consumer_walk_transition is None: + continue + consumer = pdefs.get(conn.to_partition) + if consumer is None: + raise ValueError( + f"Connection edge {conn.edge_name!r} targets unknown " + f"partition {conn.to_partition!r}." + ) + if consumer.transition_source != TransitionSource.PRODUCER_TRIGGERED: + raise ValueError( + f"Connection edge {conn.edge_name!r} defines a " + f"consumer_walk_transition for partition " + f"{conn.to_partition!r}, but that partition's " + f"transition_source is {consumer.transition_source.value!r}, " + f"not 'producer_triggered'. The two transition mechanisms " + f"must not be mixed." + ) + driving_conns.setdefault(conn.to_partition, []).append(conn.edge_name) + + for name, pdef in pdefs.items(): + if (pdef.transition_source == TransitionSource.PRODUCER_TRIGGERED + and not driving_conns.get(name)): + raise ValueError( + f"Producer-triggered partition {name!r} has no incoming " + f"connection with a consumer_walk_transition to drive its " + f"graph-walk transitions." + ) + def _ingest_request( self, body: NewRequestConductor ): @@ -672,6 +720,14 @@ def _do_ingest_request( body.request_id, p.name, fwd_args.full_metadata.graph_walk, ) partition_fwd_args[p.name] = fwd_args + + # set up tracked_consumer_graph_walks + for conn in topology.connections: + producer = partition_states[conn.from_partition] + consumer = partition_states[conn.to_partition] + if conn.consumer_walk_transition is not None: + producer.tracked_consumer_graph_walks[conn.to_partition] \ + = consumer.metadata.graph_walk # Send NewRequest to each worker with the appropriate partition's inputs for worker_id, worker_graph_ids in worker_to_worker_graph_ids.items(): @@ -706,7 +762,10 @@ def _do_ingest_request( requires_cfg=fwd_args.full_metadata.requires_cfg, partition_name=partition_name, max_tokens=request_data.max_output_tokens, - sampling_config=request_data.sampling_config + sampling_config=request_data.sampling_config, + produced_edge_idx=pstate.produced_edge_idx, + consumed_edge_idx=pstate.consumed_edge_idx, + tracked_consumer_graph_walks=pstate.tracked_consumer_graph_walks, ), ) self.communicator.send( @@ -792,6 +851,10 @@ def _process_worker_graphs_done( partition_name, body.request_id, ) return [] + + pstate.produced_edge_idx.update(body.new_produced_edge_idx) + pstate.consumed_edge_idx.update(body.new_consumed_edge_idx) + pstate.tracked_consumer_graph_walks.update(body.consumer_graph_walk_transitions) # Persist signals: every rank contributes its shard (different uuid + # source_tp_rank); accumulate across ranks, do not dedup. @@ -947,6 +1010,9 @@ def _send_partition_inputs( partition_name=partition_name, max_tokens=request_data.max_output_tokens, sampling_config=request_data.sampling_config, + produced_edge_idx=pstate.produced_edge_idx, + consumed_edge_idx=pstate.consumed_edge_idx, + tracked_consumer_graph_walks=pstate.tracked_consumer_graph_walks, ), partition_name=partition_name ), @@ -983,7 +1049,10 @@ def _send_producer_done( requires_cfg=False, partition_name=consumer_partition_name, max_tokens=request_data.max_output_tokens, - sampling_config=request_data.sampling_config + sampling_config=request_data.sampling_config, + produced_edge_idx=pstate.produced_edge_idx, + consumed_edge_idx=pstate.consumed_edge_idx, + tracked_consumer_graph_walks=pstate.tracked_consumer_graph_walks, ), partition_name=consumer_partition_name, producer_done=set([producer_partition]), diff --git a/mstar/conductor/request_info.py b/mstar/conductor/request_info.py index ae7cc696..3a0023bb 100644 --- a/mstar/conductor/request_info.py +++ b/mstar/conductor/request_info.py @@ -1,4 +1,5 @@ from dataclasses import dataclass, field +from enum import Enum from typing import Any from mstar.graph.loop_indices import NestedLoopIndices @@ -75,6 +76,9 @@ class CurrentForwardPassInfo: step_metadata: dict = field(default_factory=dict) per_label_seq_info: PerLabelSeqInfo = field(default_factory=PerLabelSeqInfo) partition_name: str = field(default="default") + produced_edge_idx: dict[str, int] = field(default_factory=dict) + consumed_edge_idx: dict[str, int] = field(default_factory=dict) + tracked_consumer_graph_walks: dict[str, str] = field(default_factory=dict) # Per-loop stop indices; stop decisions come from each submodule's check_stop. loop_stop_times: dict[str, NestedLoopIndices] = field(default_factory=dict) @@ -88,6 +92,20 @@ def clear_loop_stop_info(self): # Partition types for async graph partitions # --------------------------------------------------------------------------- +class TransitionSource(str, Enum): + """Who owns a partition's graph-walk transitions. + + A partition's walk is moved by exactly one authority: + - STATE_MACHINE: the conductor advances the walk via the model's + ``get_partition_forward_pass_args``. + - PRODUCER_TRIGGERED: a producer partition drives the walk via streaming + transition markers (``Connection.consumer_walk_transition``); the + conductor does not advance it. + """ + STATE_MACHINE = "state_machine" + PRODUCER_TRIGGERED = "producer_triggered" + + @dataclass class PartitionDefinition: """Defines a partition within a model's computation graph. @@ -99,6 +117,7 @@ class PartitionDefinition: graph_walks: set[str] # walks this partition uses initial_walk: str | None = None # first walk, or None = triggered later producer_partitions: list[str] = field(default_factory=list) # partitions feeding tokens to this one + transition_source: TransitionSource = TransitionSource.STATE_MACHINE @dataclass @@ -120,6 +139,12 @@ class PartitionState: fwd_pass_number: int = 0 random_seed: int = 0 is_done: bool = False + # streaming edge name -> number of times this edge has been emitted + produced_edge_idx: dict[str, int] = field(default_factory=dict) + consumed_edge_idx: dict[str, int] = field(default_factory=dict) + # only for producer-driven graph walk transitions + tracked_consumer_graph_walks: dict[str, str] = field(default_factory=dict) + new_tokens: dict[str, list[int]] = field(default_factory=dict) completed_worker_graph_ids: set[str] = field(default_factory=set) current_worker_graph_ids: set[str] = field(default_factory=set) # wg_id -> count of distinct TP ranks that have reported completion diff --git a/mstar/model/qwen3_omni/qwen3_omni_model.py b/mstar/model/qwen3_omni/qwen3_omni_model.py index c90444bc..e3a45e98 100644 --- a/mstar/model/qwen3_omni/qwen3_omni_model.py +++ b/mstar/model/qwen3_omni/qwen3_omni_model.py @@ -39,17 +39,18 @@ CurrentForwardConductorMetadata, PartitionDefinition, StreamingConnectionState, + TransitionSource, ) from mstar.engine.base import EngineType from mstar.engine.kv_store import KVCacheConfig from mstar.graph.base import GraphEdge, GraphNode, Loop, Sequential, TensorPointerInfo from mstar.graph.special_destinations import EMIT_TO_CLIENT, EMPTY_DESTINATION -from mstar.model.base import MAX_OUTPUT_TOKENS, ForwardPassArgs, Model, TensorAndMetadata +from mstar.model.base import ForwardPassArgs, MAX_OUTPUT_TOKENS, Model, TensorAndMetadata from mstar.model.qwen3_omni.components.talker import Qwen3OmniCodePredictor from mstar.model.submodule_base import NodeSubmodule from mstar.model.utils import Operation, WeightConverter from mstar.streaming.chunk_policy import FixedChunkPolicy, LeftContextChunkPolicy -from mstar.streaming.topology import Connection, PartitionTopology, StreamingGraphEdge +from mstar.streaming.topology import Connection, ConsumerTransitionCtx, PartitionTopology, StreamingGraphEdge, WalkTransition from mstar.utils.sampling import SamplingConfig logger = logging.getLogger(__name__) @@ -334,13 +335,10 @@ def get_graph_walk_graphs(self) -> dict[str, GraphNode | Sequential]: outputs=[], ) - # -- Talker prefill: receives thinker_states + talker_trigger -- - # Dual-input gating: both thinker_states from streaming and - # talker_trigger from conductor cross-partition trigger must be - # present for a prefill step. + # -- Talker prefill: receives thinker_states talker_prefill = GraphNode( name="Talker", - input_names=["thinker_states", "thinker_mask", "talker_trigger"], + input_names=["thinker_states", "thinker_mask"], outputs=[], ) @@ -348,7 +346,7 @@ def get_graph_walk_graphs(self) -> dict[str, GraphNode | Sequential]: sections=[ GraphNode( name="Talker", - input_names=["thinker_states", "thinker_mask", "talker_trigger"], + input_names=["thinker_states", "thinker_mask"], outputs=[ GraphEdge( next_node=EMPTY_DESTINATION, @@ -435,6 +433,7 @@ def get_partitions(self) -> list[PartitionDefinition]: graph_walks={"talker_prefill", "talker_last_prefill", "talker_decode"}, initial_walk="talker_prefill", producer_partitions=["Thinker"], + transition_source=TransitionSource.PRODUCER_TRIGGERED, ), PartitionDefinition( name="Code2Wav", @@ -445,6 +444,13 @@ def get_partitions(self) -> list[PartitionDefinition]: ] def get_partition_topology(self) -> PartitionTopology: + def talker_state_transition(ctx: ConsumerTransitionCtx) -> WalkTransition: + if ctx.producer_walk != "thinker_decode": + return WalkTransition("talker_prefill") + if ctx.consumer_walk == "talker_prefill": + return WalkTransition("talker_last_prefill") + return WalkTransition("talker_decode") + return PartitionTopology( partitions=["Thinker", "Talker", "Code2Wav"], connections=[ @@ -452,13 +458,19 @@ def get_partition_topology(self) -> PartitionTopology: from_partition="Thinker", to_partition="Talker", edge_name="thinker_states", - chunk_policy_factory=lambda: FixedChunkPolicy(chunk_size=1, continue_after_done=True), + chunk_policy_factory=lambda: FixedChunkPolicy( + chunk_size=1, continue_after_done={"talker_decode"} + ), + consumer_walk_transition=talker_state_transition ), Connection( from_partition="Thinker", to_partition="Talker", edge_name="thinker_mask", - chunk_policy_factory=lambda: FixedChunkPolicy(chunk_size=1, continue_after_done=True), + chunk_policy_factory=lambda: FixedChunkPolicy( + chunk_size=1, continue_after_done={"talker_decode"} + ), + consumer_walk_transition=talker_state_transition ), Connection( from_partition="Talker", @@ -549,7 +561,7 @@ def get_initial_forward_pass_args( ) return ForwardPassArgs( full_metadata=full_metadata, - inputs=[GraphEdge(next_node="Talker", name="talker_trigger")] if audio_output else [], + inputs=[], unpersist_tensors=[], request_done="audio" not in output_modalities, step_metadata={ @@ -812,11 +824,9 @@ def _get_thinker_forward( # alongside the primary feature tensor). schedule = metadata.kwargs["prefill_schedule"] step = metadata.kwargs["prefill_step"] - is_last_prefill = (step == len(schedule) - 1) inputs = self._get_thinker_prefill_inputs(metadata, persist_signals) else: # Decode: previous token feeds back as text_inputs - is_last_prefill = False edge = GraphEdge(next_node="Thinker", name="text_inputs") edge.tensor_info = persist_signals.get("new_token", []) inputs = [edge] @@ -827,7 +837,6 @@ def _get_thinker_forward( step_metadata = { "is_prefill": metadata.is_prefill, - "is_last_prefill": is_last_prefill, # Persist the audio_output flag across every Thinker step so # the submodule can gate thinker_states emission. Default True # for backwards compatibility with callers that never set it. @@ -849,41 +858,24 @@ def _get_talker_forward( persist_signals: dict[str, list[TensorPointerInfo]], incoming_connections: list[StreamingConnectionState] | None = None, ) -> ForwardPassArgs: - """Talker partition state machine. - - 1. While prefill: return empty inputs (wait for cross-partition trigger) - - When trigger arrives with is_last_prefill=False: - extend KV cache only, no outputs - - When trigger arrives with is_last_prefill=True: - sample first codec token, produce all_codes - 2. After last prefill produces all_codes: transition to talker_decode - - Set graph_walk="talker_decode", is_prefill=False - - Return all_codes as input edge (conductor-driven) - 3. Each decode step: check all_codes for codec_eos - - If codec_eos: request_done=True for Talker - - Else: return all_codes as input again (loop) - """ - if metadata.graph_walk == "talker_prefill": - metadata.kwargs["prefill_chunks_processed"] += 1 - is_last_prefill = metadata.kwargs["num_thinker_prefill_steps"] == \ - metadata.kwargs["prefill_chunks_processed"] - metadata.graph_walk = "talker_last_prefill" if is_last_prefill else "talker_prefill" + if metadata.graph_walk == "talker_decode": + # If the decode dynamic loop reaches the conductor, we can end the request. return ForwardPassArgs( full_metadata=metadata, - inputs=[GraphEdge(next_node="Talker", name="talker_trigger")], + inputs=[], unpersist_tensors=[], - step_metadata={ - "is_prefill": True, - # voice is used for the last prefill - "voice": metadata.kwargs.get("voice", "Ethan"), - "talker_max_tokens": metadata.kwargs.get("talker_max_tokens") - }, + request_done=True, ) - elif metadata.graph_walk == "talker_last_prefill": - metadata.is_prefill = False - metadata.graph_walk = "talker_decode" - metadata.kwargs["talker_prefill_done"] = True + step_metadata = { + # voice is used for the last prefill + "voice": metadata.kwargs.get("voice", "Ethan"), + "talker_max_tokens": metadata.kwargs.get("talker_max_tokens") + } + inputs = [] + unpersist_tensors = [] + + if metadata.graph_walk == "talker_last_prefill": # Feed talker_input_embeds back as input for first decode step edge = GraphEdge(next_node="Talker", name="talker_input_embeds") edge.tensor_info = persist_signals["talker_input_embeds"] @@ -891,29 +883,13 @@ def _get_talker_forward( unpersist_tensors = sum( [inp.tensor_info for inp in inputs], start=[] ) + metadata.graph_walk = "talker_decode" - return ForwardPassArgs( - full_metadata=metadata, - inputs=inputs, - unpersist_tensors=unpersist_tensors, - step_metadata={ - "is_prefill": False, - "talker_max_tokens": metadata.kwargs.get("talker_max_tokens") - }, - ) - - elif metadata.graph_walk == "talker_decode": - # If the decode dynamic loop reaches the conductor, we can end the request. - return ForwardPassArgs( - full_metadata=metadata, - inputs=[], - unpersist_tensors=[], - request_done=True, - ) - - raise ValueError( - f"Talker in unexpected state: walk={metadata.graph_walk!r}, " - f"is_prefill={metadata.is_prefill}" + return ForwardPassArgs( + full_metadata=metadata, + inputs=inputs, + unpersist_tensors=unpersist_tensors, + step_metadata=step_metadata, ) # -- Code2Wav state machine -------------------------------------------- diff --git a/mstar/streaming/chunk_policy.py b/mstar/streaming/chunk_policy.py index 6c527fe4..1964b46a 100644 --- a/mstar/streaming/chunk_policy.py +++ b/mstar/streaming/chunk_policy.py @@ -11,13 +11,14 @@ def register_chunk(self, chunk_size: int): self.first_chunk_read = True self.items_consumed += chunk_size + # TODO: add graph walks to the methods in the implementors... @abstractmethod - def is_ready(self, buffer_len: int) -> bool: + def is_ready(self, buffer_len: int, graph_walk: str | None=None) -> bool: """Return True if the buffer has enough items for a chunk.""" ... @abstractmethod - def next_chunk_size(self, buffer_len: int) -> int: + def next_chunk_size(self, buffer_len: int, graph_walk: str | None=None) -> int: """Return the number of items to consume for the next chunk. Only called when is_ready() returns True. @@ -26,7 +27,7 @@ def next_chunk_size(self, buffer_len: int) -> int: ... @abstractmethod - def window_size(self) -> int: + def window_size(self, graph_walk: str | None=None) -> int: """Return the full window of items to include in the chunk. For non-overlapping policies, equals next_chunk_size. @@ -35,7 +36,7 @@ def window_size(self) -> int: """ ... - def continue_after_producer_done(self) -> bool: + def continue_after_producer_done(self, graph_walk: str | None=None) -> bool: """Whether the buffer should keep producing (empty) chunks after the producer signals done and all buffered items have been consumed. @@ -66,13 +67,13 @@ def __init__(self, window: int, stride: int): self._window = window self._stride = stride - def is_ready(self, buffer_len: int) -> bool: + def is_ready(self, buffer_len: int, graph_walk: str | None = None) -> bool: return buffer_len >= self._window - def next_chunk_size(self, buffer_len: int) -> int: + def next_chunk_size(self, buffer_len: int, graph_walk: str | None = None) -> int: return self._stride - def window_size(self) -> int: + def window_size(self, graph_walk: str | None = None) -> int: return self._window @@ -104,19 +105,19 @@ def __init__(self, chunk: int, left_context: int): self._left_context = left_context self._window = chunk + left_context - def is_ready(self, buffer_len: int) -> bool: + def is_ready(self, buffer_len: int, graph_walk: str | None = None) -> bool: if not self.first_chunk_read: return buffer_len >= self._chunk return buffer_len >= self._window - def next_chunk_size(self, buffer_len: int) -> int: + def next_chunk_size(self, buffer_len: int, graph_walk: str | None = None) -> int: # First pop: advance by (chunk - left_context) so the tail of the # first chunk stays in the buffer as overlap for the next pop. if not self.first_chunk_read: return self._chunk - self._left_context return self._chunk - def window_size(self) -> int: + def window_size(self, graph_walk: str | None = None) -> int: if not self.first_chunk_read: return self._chunk return self._window @@ -134,19 +135,21 @@ class FixedChunkPolicy(ChunkPolicy): the producer finishes and all buffered items are consumed. """ - def __init__(self, chunk_size: int, continue_after_done: bool = False): + # TODO: make continue_after_done a set of graph walks + def __init__(self, chunk_size: int, continue_after_done: set[str] | None=None): super().__init__() self._chunk_size = chunk_size self._continue_after_done = continue_after_done - def is_ready(self, buffer_len) -> bool: + def is_ready(self, buffer_len, graph_walk: str | None = None) -> bool: return buffer_len >= self._chunk_size - def next_chunk_size(self, buffer_len: int) -> int: + def next_chunk_size(self, buffer_len: int, graph_walk: str | None = None) -> int: return self._chunk_size - def window_size(self) -> int: + def window_size(self, graph_walk: str | None = None) -> int: return self._chunk_size - def continue_after_producer_done(self) -> bool: - return self._continue_after_done + def continue_after_producer_done(self, graph_walk: str | None = None) -> bool: + return self._continue_after_done is not None \ + and graph_walk in self._continue_after_done diff --git a/mstar/streaming/stream_buffer.py b/mstar/streaming/stream_buffer.py index 1f8bed9f..c905c139 100644 --- a/mstar/streaming/stream_buffer.py +++ b/mstar/streaming/stream_buffer.py @@ -14,6 +14,14 @@ class StreamChunk: chunk_index: int start_offset: int = 0 # global position of the first item in this chunk is_final: bool = False + graph_walk_transition: str | None = None + + +@dataclass +class StreamingTensor: + index: int + tensor: torch.Tensor + graph_walk: str | None = None @dataclass @@ -32,11 +40,13 @@ class StreamBuffer: from_partition: str policy: ChunkPolicy + # graph edges of chunks that have been popped but not ingested _waiting_graph_edges: deque = field(default_factory=deque) - _buffer: list = field(default_factory=list) - _tensor_ids_in_order: deque = field(default_factory=deque) - _id_to_tensor: dict = field(default_factory=dict) + # edge index -> tensor and metadata + _tensors: dict[int, StreamingTensor] = field(default_factory=dict) + _buffer: list[StreamingTensor] = field(default_factory=list) + _current_index: int = 0 _consumed: int = 0 _chunks_popped: int = 0 producer_done: bool = False @@ -44,36 +54,59 @@ class StreamBuffer: _num_tensors_registered = 0 _num_buffer_writes = 0 - def pre_read_register(self, tensor_id: str): + def pre_read_register(self): + """ + Register that we are reading a tensor so we don't prematurely declare + the producer as done. + """ self._num_tensors_registered += 1 - self._tensor_ids_in_order.append(tensor_id) - - def put(self, tensor_id: str, item: torch.Tensor) -> None: - """Called when a tensor arrives via normal RDMA routing.""" - self._id_to_tensor[tensor_id] = item def _update_buffer(self): - while len(self._tensor_ids_in_order) > 0: - tensor_id = self._tensor_ids_in_order[0] - if tensor_id not in self._id_to_tensor: - return - self._tensor_ids_in_order.popleft() - self._buffer.append(self._id_to_tensor[tensor_id]) - self._num_buffer_writes += 1 - del self._id_to_tensor[tensor_id] + while self._current_index in self._tensors: + self._buffer.append(self._tensors.pop(self._current_index)) + self._current_index += 1 + + def put(self, item: torch.Tensor, index: int, graph_walk: str | None = None) -> None: + """Called when a tensor arrives via normal RDMA routing. + + Idempotent by index: if this index has already been buffered or has + already been drained into ``_buffer`` (``index < _current_index``), + the duplicate is dropped (first-arrival-wins). This handles the case + where multiple colocated producer ranks emit the same streaming item. + """ + if index < self._current_index or index in self._tensors: + return + self._tensors[index] = StreamingTensor( + index=index, + tensor=item, + graph_walk=graph_walk, + ) + self._num_buffer_writes += 1 def signal_done(self) -> None: """Producer signals no more items will arrive.""" self.producer_done = True + def set_index(self, index: int): + """Seed the next index to drain (e.g. when a new consumer worker takes + over a partition after a prefill->decode handoff). + + Only ever advances: indices are monotonic, and the conductor's tracked + value lags (it is refreshed only at WorkerGraphsDone). Rewinding a live + buffer would point ``_current_index`` at items already popped out of + ``_tensors``, deadlocking the drain. + """ + self._current_index = max(self._current_index, index) + def _producer_done_and_all_read(self) -> bool: - return self.producer_done and self._num_buffer_writes >= self._num_tensors_registered + return self.producer_done and \ + self._num_buffer_writes >= self._num_tensors_registered def pop_waiting_edge(self) -> GraphEdge | None: if len(self._waiting_graph_edges) > 0: return self._waiting_graph_edges.popleft() - def has_chunk_ready(self) -> bool: + def has_chunk_ready(self, graph_walk: str) -> bool: self._update_buffer() buf_len = len(self._buffer) if self._producer_done_and_all_read() and buf_len > 0: @@ -84,32 +117,79 @@ def has_chunk_ready(self) -> bool: # generating codec tokens after the Thinker hits text EOS). if (self._producer_done_and_all_read() and buf_len == 0 - and self.policy.continue_after_producer_done()): + and self.policy.continue_after_producer_done(graph_walk)): return True - return self.policy.is_ready(buf_len) - - def pop_chunk(self) -> StreamChunk: + return self.policy.is_ready(buf_len, graph_walk) + + def _chunk_boundary( + self, current_walk: str, max_len: int + ) -> tuple[int, str | None]: + """Bound a chunk so a producer-triggered walk transition starts fresh. + + A producer-triggered graph-walk transition must mark a chunk boundary: + the consumer runs one forward pass per walk, so a chunk cannot straddle + two walks. + + Returns ``(boundary, transition)`` where: + - ``transition`` is the walk this chunk runs under, taken from the + *first* buffered item if it carries a transition to a walk other + than ``current_walk`` (else ``None`` — walk unchanged). + - ``boundary`` is the number of leading items that share this chunk's + walk, clamped to ``max_len``. Any later item carrying a transition + to a *different* walk forces the boundary before it, so it becomes + the leading item of the next chunk. + """ + if not self._buffer or max_len <= 0: + return max(max_len, 0), None + + # (1) Leading transition: the walk this chunk runs under. + transition = None + chunk_walk = current_walk + first = self._buffer[0].graph_walk + if first is not None and first != current_walk: + transition = first + chunk_walk = first + + # (2) Cut before any later item that transitions to a different walk. + boundary = min(max_len, len(self._buffer)) + for j in range(1, boundary): + gw = self._buffer[j].graph_walk + if gw is not None and gw != chunk_walk: + boundary = j + break + return boundary, transition + + def pop_chunk(self, graph_walk: str) -> StreamChunk: """Pop the next chunk. Only call when has_chunk_ready() is True. For sliding-window: returns `window_size` items, advances by `stride` items, discards items that have fallen out of the window. start_offset is the global position of the first item in the chunk. + + A producer-triggered walk transition forces a chunk boundary (see + ``_chunk_boundary``): the returned chunk never straddles two walks, and + ``graph_walk_transition`` carries the walk this chunk runs under. """ self._update_buffer() buf_len = len(self._buffer) - window = self.policy.window_size() offset = self._consumed # global position of buffer[0] - if self._producer_done_and_all_read() and not self.policy.is_ready(buf_len): - # Flush remainder — return whatever is left (may be empty) - items = list(self._buffer) - self._buffer.clear() - self._consumed += len(items) - stride = len(items) + if self._producer_done_and_all_read() and not self.policy.is_ready(buf_len, graph_walk): + # Flush remainder — return whatever is left (may be empty), still + # cut at the first walk transition so each walk gets its own pass. + boundary, transition = self._chunk_boundary(graph_walk, len(self._buffer)) + items = self._buffer[:boundary] + self._buffer = self._buffer[boundary:] + self._consumed += boundary + stride = boundary else: - stride = self.policy.next_chunk_size(buf_len) - # Return the first `window` items (overlapping sliding window) - items = self._buffer[:window] + stride = self.policy.next_chunk_size(buf_len, graph_walk) + window = self.policy.window_size(graph_walk) + # Bound the window/stride so a transition starts a fresh chunk. + boundary, transition = self._chunk_boundary(graph_walk, window) + stride = min(stride, boundary) + # Return the first `boundary` items (overlapping sliding window). + items = self._buffer[:boundary] # Advance by stride — discard items that fell out of the window self._buffer = self._buffer[stride:] self._consumed += stride @@ -118,14 +198,15 @@ def pop_chunk(self) -> StreamChunk: is_final = self._producer_done_and_all_read() and len(self._buffer) == 0 # When continue_after_producer_done, never mark as final — the # consumer decides when it's done via its own model logic. - if self.policy.continue_after_producer_done(): + if self.policy.continue_after_producer_done(graph_walk): is_final = False chunk = StreamChunk( - data=self._collate(items), + data=self._collate([it.tensor for it in items]), chunk_index=self._chunks_popped, start_offset=offset, is_final=is_final, + graph_walk_transition=transition, ) self._chunks_popped += 1 return chunk diff --git a/mstar/streaming/topology.py b/mstar/streaming/topology.py index d227679d..d32b3153 100644 --- a/mstar/streaming/topology.py +++ b/mstar/streaming/topology.py @@ -1,6 +1,7 @@ from dataclasses import dataclass, field from typing import Callable +from mstar.conductor.request_info import CurrentForwardPassInfo from mstar.graph.base import GraphEdge from mstar.streaming.chunk_policy import ChunkPolicy @@ -15,11 +16,25 @@ class StreamingGraphEdge(GraphEdge): consuming node's input. """ target_partition: str = "" + _index: int = 0 + _graph_walk_transition: str | None = None def __post_init__(self): self.is_streaming = True +@dataclass(frozen=True) +class ConsumerTransitionCtx: + producer_walk: str + consumer_walk: str | None # None on the very first trigger + producer_fwd: CurrentForwardPassInfo + + +@dataclass +class WalkTransition: + graph_walk: str | None = None + # TODO: hook up metadata if needed + @dataclass class Connection: """Defines a streaming connection between two partitions.""" @@ -27,6 +42,7 @@ class Connection: to_partition: str edge_name: str chunk_policy_factory: Callable[[], ChunkPolicy] + consumer_walk_transition: Callable[[ConsumerTransitionCtx], WalkTransition] | None = None @dataclass diff --git a/mstar/utils/ipc_format.py b/mstar/utils/ipc_format.py index 1eb6ae77..b52064b7 100644 --- a/mstar/utils/ipc_format.py +++ b/mstar/utils/ipc_format.py @@ -119,6 +119,9 @@ class WorkerGraphsDone(MessageBody): persist_signals: dict[str, list[TensorPointerInfo]] = field(default_factory=dict) new_tokens: dict[str, list[int]] = field(default_factory=dict) # name to tokens output_signal_names: int = field(default=0) + new_produced_edge_idx: dict[str, int] = field(default_factory=dict) + new_consumed_edge_idx: dict[str, int] = field(default_factory=dict) + consumer_graph_walk_transitions: dict[str, str] = field(default_factory=dict) per_label_seq_info: PerLabelSeqInfo = field(default_factory=PerLabelSeqInfo) partition_name: str = field(default="default") partition_done: bool = field(default=False) diff --git a/mstar/worker/node_manager_utils.py b/mstar/worker/node_manager_utils.py index bb8ef000..9389bd87 100644 --- a/mstar/worker/node_manager_utils.py +++ b/mstar/worker/node_manager_utils.py @@ -289,6 +289,15 @@ def update_request_info( if per_label_seq_info is not None: fwd_info = self.get_fwd_info(request_id, partition_name) fwd_info.per_label_seq_info.update(per_label_seq_info) + + def update_graph_walk(self, request_id: str, partition_name: str, graph_walk: str): + part_info = self.per_request_info[request_id].per_partition_info[partition_name] + if self.get_graph_walk(request_id, partition_name) != graph_walk: + part_info.graph_walk_worker_graph_ids = [ + graph_id for graph_id in self.per_request_info[request_id].worker_graph_ids \ + if graph_walk in self.all_worker_graph_ids_to_graph_walks[graph_id] + ] + part_info.current_fwd_info.graph_walk = graph_walk def get_graph_walk(self, request_id: str, partition_name: str): return self.get_fwd_info(request_id, partition_name).graph_walk diff --git a/mstar/worker/worker.py b/mstar/worker/worker.py index 2e85ec53..1195220c 100644 --- a/mstar/worker/worker.py +++ b/mstar/worker/worker.py @@ -1,3 +1,5 @@ +from copy import deepcopy +from itertools import chain import logging import os import sys @@ -25,6 +27,7 @@ from mstar.graph.loop_indices import NestedLoopIndices from mstar.model.base import Model, WorkerGraph from mstar.streaming.stream_buffer import StreamBuffer +from mstar.streaming.topology import Connection, ConsumerTransitionCtx, WalkTransition from mstar.utils.ipc_format import ( ConductorMessage, ConductorMessageType, @@ -279,16 +282,13 @@ def __init__( tuple[str, torch.dtype, tuple[int, ...]], list[torch.Tensor] ] = defaultdict(list) - # Streaming buffers: request_id -> edge_name -> list of tensors - # (Legacy path — kept for models without PartitionTopology) - self.streaming_buffers: dict[str, dict[str, list[torch.Tensor]]] = {} - # New streaming path: PartitionTopology + StreamBuffer on consumer worker self.partition_topology = model.get_partition_topology() if model else None # Determine which partition this worker serves (by checking which node names # appear in my_worker_graphs vs the topology connections) - self._my_consumer_connections = [] + self._my_consumer_connections: list[Connection] = [] + self.edge_to_connection: dict[str, Connection] = {} if self.partition_topology: my_node_names = set() for wg in my_worker_graphs: @@ -298,6 +298,7 @@ def __init__( # by checking if the streaming edge's next_node is in my nodes if any(n in my_node_names for n in self._get_node_names_for_partition(conn.to_partition, model)): self._my_consumer_connections.append(conn) + self.edge_to_connection[conn.edge_name] = conn # Set of edge names that arrive via streaming (used to distinguish # streaming inputs from conductor-triggered non-streaming inputs @@ -446,7 +447,6 @@ def _remove_request(self, body: RemoveRequest) -> None: self.engine_manager.remove_request(body.request_id) self.worker_graphs_manager.remove_request(body.request_id) self.tensor_manager.cleanup_request(body.request_id) - self.streaming_buffers.pop(body.request_id, None) for node_name in self.engine_manager.lru_tracked_nodes(): self._last_active.pop((body.request_id, node_name), None) @@ -490,6 +490,11 @@ def _process_new_inputs(self, body: InputSignals) -> None: body.request_id, current_fwd_info=body.request_info, partition_name=body.partition_name ) + for edge, idx in body.request_info.consumed_edge_idx.items(): + req_info = self.worker_graphs_manager.per_request_info.get(body.request_id) + if edge in req_info.stream_buffers: + req_info.stream_buffers[edge].set_index(idx) + if self.enable_nvtx: range_pop(synchronize=False) @@ -508,8 +513,8 @@ def _process_new_inputs(self, body: InputSignals) -> None: self.wakeup_event.register_futures(futures) for edge in streaming_with_tensors: stream_buf = req_info.stream_buffers[edge.name] - for info in edge.tensor_info: - stream_buf.pre_read_register(info.uuid) + for _ in edge.tensor_info: + stream_buf.pre_read_register() if self.enable_nvtx: range_pop(synchronize=False) range_push("process_new_inputs.process_inputs") @@ -604,16 +609,26 @@ def _route_streaming_tensor(self, request_id: str, edge: GraphEdge) -> None: request_id=request_id, uuid=info.uuid, ) - stream_buf.put(info.uuid, tensor.clone()) + stream_buf.put( + tensor.clone(), + index=edge._index, + graph_walk=edge._graph_walk_transition, + ) self.tensor_manager.dereference(request_id, info.uuid) def _pop_streaming_edge( - self, sbuf: StreamBuffer, edge_name: str, request_id: str + self, sbuf: StreamBuffer, edge_name: str, request_id: str, + # for speculation, it is messy to allow graph walk transitions here + allow_graph_walk_transition: bool=True ) -> GraphEdge | None: consumer_node = self._consumer_node_cache.get(edge_name, "") synthetic_edge = sbuf.pop_waiting_edge() - if synthetic_edge is None and sbuf.has_chunk_ready(): - chunk = sbuf.pop_chunk() + consumer_partition = self.worker_graphs_manager.get_partition_for_node(consumer_node) + graph_walk = self.worker_graphs_manager.get_graph_walk( + request_id, consumer_partition + ) + if synthetic_edge is None and sbuf.has_chunk_ready(graph_walk): + chunk = sbuf.pop_chunk(graph_walk) chunk_tensor = chunk.data.get("data") if chunk_tensor is None: # Empty chunk — producer done, no more data. @@ -642,6 +657,17 @@ def _pop_streaming_edge( tensor_info=tensor_infos.get(edge_name, []), _final_stream_chunk=chunk.is_final, ) + if chunk.graph_walk_transition is not None and chunk.graph_walk_transition != graph_walk: + if not allow_graph_walk_transition: + sbuf.store_uningested_edge(synthetic_edge) + return + self.worker_graphs_manager.update_graph_walk( + request_id, consumer_partition, + chunk.graph_walk_transition + ) + self.worker_graphs_manager.get_fwd_info( + request_id, consumer_partition + ).consumed_edge_idx[edge_name] = sbuf._current_index return synthetic_edge def _poll_stream_buffers_for_speculation( @@ -655,7 +681,10 @@ def _poll_stream_buffers_for_speculation( consumer_node = self._consumer_node_cache.get(edge_name, "") if consumer_node != node_name: continue - edge = self._pop_streaming_edge(sbuf, edge_name, request_id) + edge = self._pop_streaming_edge( + sbuf, edge_name, request_id, + allow_graph_walk_transition=False + ) if edge is not None: result.append(edge) return result @@ -983,13 +1012,39 @@ def _send_outputs( ) self.communicator.send("api_server", message) + produced_streaming_edges: set[str] = set() + consumer_graph_walk_transitions: dict[str, WalkTransition] = {} + # Handle streaming edges # Local streaming: route to StreamBuffer req_info = self.worker_graphs_manager.per_request_info[request_id] + fwd_info = self.worker_graphs_manager.get_fwd_info(request_id, partition_name) + + for edge in chain(outputs.streaming_local, *outputs.streaming_to_workers.values()): + edge._index = fwd_info.produced_edge_idx.get(edge.name, 0) + if edge.name not in produced_streaming_edges: + produced_streaming_edges.add(edge.name) + fwd_info.produced_edge_idx[edge.name] = \ + fwd_info.produced_edge_idx.get(edge.name, 0) + 1 + conn = self.edge_to_connection.get(edge.name) + if conn and conn.consumer_walk_transition: + consumer_graph_walk_transitions[edge.name] = conn.consumer_walk_transition( + ConsumerTransitionCtx( + producer_walk=graph_walk, + consumer_walk=fwd_info.tracked_consumer_graph_walks.get( + conn.to_partition + ), + producer_fwd=fwd_info + ) + ) + + walk_transition = consumer_graph_walk_transitions.get(edge.name) + edge._graph_walk_transition = walk_transition.graph_walk if walk_transition else None + for edge in outputs.streaming_local: stream_buf = req_info.stream_buffers[edge.name] - for info in edge.tensor_info: - stream_buf.pre_read_register(info.uuid) + for _ in edge.tensor_info: + stream_buf.pre_read_register() self._route_streaming_tensor(request_id, edge) # Remote streaming: send to destination workers @@ -999,13 +1054,12 @@ def _send_outputs( body=InputSignals( request_id=request_id, inputs=edges, - request_info=self.worker_graphs_manager.get_fwd_info(request_id, partition_name), + request_info=fwd_info, partition_name=partition_name ), ) self.communicator.send(worker_id, message) if outputs.completed_worker_graph_ids: - fwd_info = self.worker_graphs_manager.get_fwd_info(request_id, partition_name) if partition_name is None: partition_name = getattr(fwd_info, 'partition_name', 'default') req_info = self.worker_graphs_manager.per_request_info.get(request_id) @@ -1029,7 +1083,16 @@ def _send_outputs( persist_signals=self.worker_graphs_manager.flush_persist_signals(request_id), new_tokens=self.worker_graphs_manager.flush_new_tokens(request_id), output_signal_names=self.worker_graphs_manager.flush_output_signals(request_id), - per_label_seq_info=self.worker_graphs_manager.get_seq_info(request_id, partition_name), + per_label_seq_info=fwd_info.per_label_seq_info, + new_produced_edge_idx=fwd_info.produced_edge_idx, + new_consumed_edge_idx=fwd_info.consumed_edge_idx, + # Key by consumer partition (not edge name) so it merges + # into the producer pstate's partition-keyed + # tracked_consumer_graph_walks on the conductor. + consumer_graph_walk_transitions={ + self.edge_to_connection[edge_name].to_partition: wt.graph_walk + for edge_name, wt in consumer_graph_walk_transitions.items() + }, partition_name=partition_name, partition_done=p_done, stream_tokens_consumed=stream_consumed, From 41b4d9717b1300c56a8e83d7966d1a6c20232379 Mon Sep 17 00:00:00 2001 From: NSagan271 Date: Mon, 1 Jun 2026 05:15:32 +0000 Subject: [PATCH 02/16] finish up implementation, ready to test --- mstar/conductor/conductor.py | 16 ++++++++++ mstar/graph/base.py | 4 +++ mstar/utils/ipc_format.py | 3 ++ mstar/worker/node_manager_utils.py | 47 ++++++++++++++++++++++++++---- mstar/worker/worker.py | 22 ++++++++++++-- 5 files changed, 83 insertions(+), 9 deletions(-) diff --git a/mstar/conductor/conductor.py b/mstar/conductor/conductor.py index 17c62ca9..230cd1cf 100644 --- a/mstar/conductor/conductor.py +++ b/mstar/conductor/conductor.py @@ -887,6 +887,22 @@ def _process_worker_graphs_done( body.output_signal_names, list ) else [] + # Producer-triggered partitions self-apply their walk from the stream. + # Adopt the walk the consumer actually ran (reported back) before the + # completion check, and realign the completion-tracking worker-graph set + # to it — otherwise current_worker_graph_ids (a state-machine-derived + # walk) diverges and the subset check below breaks. The conductor does + # not drive this partition's transitions; this only mirrors them. + pdef = request_data.partition_definitions.get(partition_name) + if (pdef is not None + and pdef.transition_source == TransitionSource.PRODUCER_TRIGGERED + and body.partition_graph_walk + and body.partition_graph_walk != pstate.metadata.graph_walk): + pstate.metadata.graph_walk = body.partition_graph_walk + self._set_partition_worker_graph_ids( + body.request_id, partition_name, body.partition_graph_walk, + ) + # Each wg is only marked complete when all its TP ranks have reported. for wg_id in body.worker_graph_ids: count = pstate.wg_rank_completions.get(wg_id, 0) + 1 diff --git a/mstar/graph/base.py b/mstar/graph/base.py index 93b85e08..619c9e8d 100644 --- a/mstar/graph/base.py +++ b/mstar/graph/base.py @@ -71,6 +71,9 @@ class GraphEdge: _total_fanin: int = 1 _shard_dim: int | None = None + # set for non-streaming inputs + _target_graph_walk: str | None = None + def clone(self): return GraphEdge( next_node=self.next_node, @@ -82,6 +85,7 @@ def clone(self): output_modality=self.output_modality, _persist_for_loop=self._persist_for_loop, _final_stream_chunk=self._final_stream_chunk, + _target_graph_walk=self._target_graph_walk, ) diff --git a/mstar/utils/ipc_format.py b/mstar/utils/ipc_format.py index b52064b7..837b81b6 100644 --- a/mstar/utils/ipc_format.py +++ b/mstar/utils/ipc_format.py @@ -125,6 +125,9 @@ class WorkerGraphsDone(MessageBody): per_label_seq_info: PerLabelSeqInfo = field(default_factory=PerLabelSeqInfo) partition_name: str = field(default="default") partition_done: bool = field(default=False) + # the graph walk this partition's just-completed forward pass ran under; + # used by the conductor to track a producer-triggered partition's walk + partition_graph_walk: str | None = field(default=None) stream_tokens_consumed: dict[str, int] = field(default_factory=dict) # edge_name -> tokens consumed from stream output_loop_indices: dict[str, NestedLoopIndices] = field(default_factory=dict) diff --git a/mstar/worker/node_manager_utils.py b/mstar/worker/node_manager_utils.py index 9389bd87..aedebe42 100644 --- a/mstar/worker/node_manager_utils.py +++ b/mstar/worker/node_manager_utils.py @@ -192,6 +192,10 @@ class PerPartitionInfo: # graph_walk_worker_graph_ids = worker graphs for current graph walk graph_walk_worker_graph_ids: list[str] = field(default_factory=list) # for this worker stream_partition_done: bool = False # set True when last chunk pops with is_final + + # edges that are pending a producer-triggered graph walk transition + # {graph walk -> edges} + pending_edges: dict[str, list[GraphEdge]] = field(default_factory=dict) @dataclass @@ -268,22 +272,34 @@ def __post_init__(self): # ambiguity is moot for routing. self.walk_node_to_worker_graph_id[(walk, node)] = wg_id + def buffer_pending_edge( + self, request_id: str, graph_walk: str, edge: GraphEdge + ): + partition_name = self.get_partition_for_node(edge.next_node) + req_info = self.per_request_info[request_id] + part_info = req_info.per_partition_info[partition_name] + part_info.pending_edges.setdefault(graph_walk, []).append(edge) + def update_request_info( self, request_id: str, partition_name, current_fwd_info: CurrentForwardPassInfo | None=None, per_label_seq_info: PerLabelSeqInfo | None=None, + allow_graph_walk_transition: bool=True ): req_info = self.per_request_info[request_id] part_info = req_info.per_partition_info[partition_name] if current_fwd_info is not None: - graph_walk = current_fwd_info.graph_walk - if self.get_graph_walk(request_id, partition_name) != graph_walk: - part_info.graph_walk_worker_graph_ids = [ - graph_id for graph_id in self.per_request_info[request_id].worker_graph_ids \ - if graph_walk in self.all_worker_graph_ids_to_graph_walks[graph_id] - ] + if allow_graph_walk_transition: + self.update_graph_walk(request_id, partition_name, current_fwd_info.graph_walk) + else: + # Producer-triggered partition: the stream owns the walk. Absorb + # everything else from the conductor but keep the locally + # (stream-)applied walk so the wholesale replace below doesn't + # clobber it. (Inputs were already tagged with the conductor's + # intended walk before this call.) + current_fwd_info.graph_walk = self.get_graph_walk(request_id, partition_name) part_info.current_fwd_info = current_fwd_info if per_label_seq_info is not None: @@ -298,6 +314,11 @@ def update_graph_walk(self, request_id: str, partition_name: str, graph_walk: st if graph_walk in self.all_worker_graph_ids_to_graph_walks[graph_id] ] part_info.current_fwd_info.graph_walk = graph_walk + if graph_walk in part_info.pending_edges: + edges = part_info.pending_edges.pop(graph_walk) + self.process_new_inputs( + request_id, inputs=edges + ) def get_graph_walk(self, request_id: str, partition_name: str): return self.get_fwd_info(request_id, partition_name).graph_walk @@ -332,6 +353,20 @@ def process_new_inputs( ``next_node`` lives on a different worker). Caller uses these for cross-worker routing. """ + wrong_graph_walk = [ + edge for edge in inputs if edge._target_graph_walk is not None \ + and edge._target_graph_walk != self.get_graph_walk( + request_id, partition_name=self.get_partition_for_node(edge.next_node) + ) + ] + for edge in wrong_graph_walk: + self.buffer_pending_edge(request_id, edge._target_graph_walk, edge) + inputs = [ + edge for edge in inputs if edge._target_graph_walk is None \ + or edge._target_graph_walk == self.get_graph_walk( + request_id, partition_name=self.get_partition_for_node(edge.next_node) + ) + ] for part_info in self.per_request_info[request_id].per_partition_info.values(): worker_graph_ids = part_info.graph_walk_worker_graph_ids for worker_graph_id in worker_graph_ids: diff --git a/mstar/worker/worker.py b/mstar/worker/worker.py index 1195220c..375a604d 100644 --- a/mstar/worker/worker.py +++ b/mstar/worker/worker.py @@ -17,7 +17,7 @@ from mstar.communication.communicator import CommProtocol, ZMQCommunicator from mstar.communication.event import EventWakeup from mstar.communication.tensors import NameToTensorList, create_tensor_communication_manager -from mstar.conductor.request_info import CurrentForwardPassInfo +from mstar.conductor.request_info import CurrentForwardPassInfo, TransitionSource from mstar.distributed.base import ShardingConfig from mstar.distributed.communication import WorkerTPGroups from mstar.engine.base import EngineType, NodeBatch, NodeOutput @@ -284,6 +284,7 @@ def __init__( # New streaming path: PartitionTopology + StreamBuffer on consumer worker self.partition_topology = model.get_partition_topology() if model else None + self.partitions = model.get_partitions() if model else [] # Determine which partition this worker serves (by checking which node names # appear in my_worker_graphs vs the topology connections) @@ -307,6 +308,10 @@ def __init__( conn.edge_name for conn in self._my_consumer_connections } + self._producer_triggered_partitions: set[str] = set([ + part.name for part in self.partitions if part.transition_source == TransitionSource.PRODUCER_TRIGGERED + ]) + # Build consumer node cache: edge_name -> next_node name self._consumer_node_cache: dict[str, str] = {} if self._my_consumer_connections and model: @@ -486,15 +491,22 @@ def _process_new_inputs(self, body: InputSignals) -> None: # partition). Streaming-only InputSignals must not overwrite the current # partition's fwd_info. if non_streaming: + # Tag each input with the walk the conductor intends it for. Must be + # done BEFORE update_request_info: for producer-triggered partitions + # that call rewrites body.request_info.graph_walk to the stream walk. + # If the tag differs from the partition's current (stream-applied) + # walk, the ingest-time gate buffers it until the transition. + for edge in non_streaming: + edge._target_graph_walk = body.request_info.graph_walk self.worker_graphs_manager.update_request_info( body.request_id, current_fwd_info=body.request_info, - partition_name=body.partition_name + partition_name=body.partition_name, + allow_graph_walk_transition=body.partition_name not in self._producer_triggered_partitions ) for edge, idx in body.request_info.consumed_edge_idx.items(): req_info = self.worker_graphs_manager.per_request_info.get(body.request_id) if edge in req_info.stream_buffers: req_info.stream_buffers[edge].set_index(idx) - if self.enable_nvtx: range_pop(synchronize=False) @@ -1095,6 +1107,10 @@ def _send_outputs( }, partition_name=partition_name, partition_done=p_done, + # the walk this just-completed forward pass ran under, so + # the conductor can track a producer-triggered partition's + # stream-applied walk + partition_graph_walk=fwd_info.graph_walk, stream_tokens_consumed=stream_consumed, output_loop_indices=self.worker_graphs_manager.get_output_loop_indices(request_id), ), From 41a206df06aa05a25eda0de9e07c245a33fdbef1 Mon Sep 17 00:00:00 2001 From: NSagan271 Date: Mon, 1 Jun 2026 08:59:24 +0000 Subject: [PATCH 03/16] bug fixes: only do graph walk transition when no ready nodes --- mstar/streaming/stream_buffer.py | 14 +++++++-- mstar/streaming/topology.py | 16 ++++++++++ mstar/worker/worker.py | 52 ++++++++++++++++++++++++++------ 3 files changed, 69 insertions(+), 13 deletions(-) diff --git a/mstar/streaming/stream_buffer.py b/mstar/streaming/stream_buffer.py index c905c139..09a052bb 100644 --- a/mstar/streaming/stream_buffer.py +++ b/mstar/streaming/stream_buffer.py @@ -23,6 +23,11 @@ class StreamingTensor: tensor: torch.Tensor graph_walk: str | None = None +@dataclass +class WaitingEdge: + edge: GraphEdge + walk_transition: str | None = None + @dataclass class StreamBuffer: @@ -102,7 +107,7 @@ def _producer_done_and_all_read(self) -> bool: return self.producer_done and \ self._num_buffer_writes >= self._num_tensors_registered - def pop_waiting_edge(self) -> GraphEdge | None: + def pop_waiting_edge(self) -> WaitingEdge | None: if len(self._waiting_graph_edges) > 0: return self._waiting_graph_edges.popleft() @@ -211,8 +216,11 @@ def pop_chunk(self, graph_walk: str) -> StreamChunk: self._chunks_popped += 1 return chunk - def store_uningested_edge(self, edge: GraphEdge): - self._waiting_graph_edges.append(edge) + def store_uningested_edge(self, edge: GraphEdge, walk_transition: str | None=None): + self._waiting_graph_edges.append(WaitingEdge( + edge=edge, + walk_transition=walk_transition + )) def _collate(self, items: list) -> dict[str, torch.Tensor | None]: if not items: diff --git a/mstar/streaming/topology.py b/mstar/streaming/topology.py index d32b3153..bc331e3e 100644 --- a/mstar/streaming/topology.py +++ b/mstar/streaming/topology.py @@ -21,6 +21,22 @@ class StreamingGraphEdge(GraphEdge): def __post_init__(self): self.is_streaming = True + + def clone(self): + return StreamingGraphEdge( + next_node=self.next_node, + name=self.name, + tensor_info=self.tensor_info[:], + persist=self.persist, + conductor_new_token=self.conductor_new_token, + is_streaming=self.is_streaming, + output_modality=self.output_modality, + _persist_for_loop=self._persist_for_loop, + _target_graph_walk=self._target_graph_walk, + target_partition=self.target_partition, + _index=self._index, + _graph_walk_transition=self._graph_walk_transition, + ) @dataclass(frozen=True) diff --git a/mstar/worker/worker.py b/mstar/worker/worker.py index 375a604d..36789860 100644 --- a/mstar/worker/worker.py +++ b/mstar/worker/worker.py @@ -634,11 +634,23 @@ def _pop_streaming_edge( allow_graph_walk_transition: bool=True ) -> GraphEdge | None: consumer_node = self._consumer_node_cache.get(edge_name, "") - synthetic_edge = sbuf.pop_waiting_edge() consumer_partition = self.worker_graphs_manager.get_partition_for_node(consumer_node) graph_walk = self.worker_graphs_manager.get_graph_walk( request_id, consumer_partition ) + + waiting_edge = sbuf.pop_waiting_edge() + if waiting_edge is not None and waiting_edge.walk_transition is not None \ + and waiting_edge.walk_transition != graph_walk: + if not allow_graph_walk_transition: + sbuf.store_uningested_edge(waiting_edge.edge, waiting_edge.walk_transition) + return + self.worker_graphs_manager.update_graph_walk( + request_id, consumer_partition, + waiting_edge.walk_transition + ) + synthetic_edge = waiting_edge.edge if waiting_edge is not None else None + if synthetic_edge is None and sbuf.has_chunk_ready(graph_walk): chunk = sbuf.pop_chunk(graph_walk) chunk_tensor = chunk.data.get("data") @@ -671,7 +683,7 @@ def _pop_streaming_edge( ) if chunk.graph_walk_transition is not None and chunk.graph_walk_transition != graph_walk: if not allow_graph_walk_transition: - sbuf.store_uningested_edge(synthetic_edge) + sbuf.store_uningested_edge(synthetic_edge, chunk.graph_walk_transition) return self.worker_graphs_manager.update_graph_walk( request_id, consumer_partition, @@ -715,7 +727,18 @@ def _poll_stream_buffers(self) -> None: """Check all active StreamBuffers; when a chunk is ready, feed it as a normal input.""" for request_id, req_info in list(self.worker_graphs_manager.per_request_info.items()): for edge_name, sbuf in req_info.stream_buffers.items(): - synthetic_edge = self._pop_streaming_edge(sbuf, edge_name, request_id) + consumer_node = self._consumer_node_cache.get(edge_name, "") + partition_name = self.worker_graphs_manager.get_partition_for_node(consumer_node) + wgid = self.worker_graphs_manager.get_worker_graph_id_for_node(request_id, consumer_node) + + allow_graph_walk_transition = len( + self.worker_graphs_manager.queues[wgid].per_request_queues[request_id].ready_node_names + ) == 0 + + synthetic_edge = self._pop_streaming_edge( + sbuf, edge_name, request_id, + allow_graph_walk_transition=allow_graph_walk_transition + ) if synthetic_edge is not None: # Streaming edges go through the same path as regular ones — @@ -1053,6 +1076,14 @@ def _send_outputs( walk_transition = consumer_graph_walk_transitions.get(edge.name) edge._graph_walk_transition = walk_transition.graph_walk if walk_transition else None + # Advance our local view of each consumer's walk by the transition just + # emitted: inside a dynamic loop there is no conductor round-trip to + # refresh tracked_consumer_graph_walks. After the edge loop so all edges + # of a pass see the same (pre-update) consumer walk. + for edge_name, wt in consumer_graph_walk_transitions.items(): + to_partition = self.edge_to_connection[edge_name].to_partition + fwd_info.tracked_consumer_graph_walks[to_partition] = wt.graph_walk + for edge in outputs.streaming_local: stream_buf = req_info.stream_buffers[edge.name] for _ in edge.tensor_info: @@ -1107,10 +1138,11 @@ def _send_outputs( }, partition_name=partition_name, partition_done=p_done, - # the walk this just-completed forward pass ran under, so - # the conductor can track a producer-triggered partition's - # stream-applied walk - partition_graph_walk=fwd_info.graph_walk, + # the walk this just-completed forward pass ran under (the + # batch's walk, NOT fwd_info.graph_walk which may have + # already advanced via a stream-buffer transition before + # this report is built — that race drops talker_input_embeds) + partition_graph_walk=graph_walk, stream_tokens_consumed=stream_consumed, output_loop_indices=self.worker_graphs_manager.get_output_loop_indices(request_id), ), @@ -1614,9 +1646,9 @@ def _thread_outputs_to_speculative( speculation.node_batch.per_request_info.pop(r, None) speculation.scheduled_batch.request_to_worker_graph.pop(r, None) speculation.scheduled_batch.node_objects.pop(r, None) - for edge in speculation.consumed_streaming_edges.get(rid, []): - self._return_speculative_streaming_edge(rid, edge) - speculation.consumed_streaming_edges.pop(rid, None) + for edge in speculation.consumed_streaming_edges.get(r, []): + self._return_speculative_streaming_edge(r, edge) + speculation.consumed_streaming_edges.pop(r, None) speculation.continuing_rids = threaded_continuing speculation.dropped = dropped From 392ba33693b464ba7fc06bd589f8e36ced515109 Mon Sep 17 00:00:00 2001 From: NSagan271 Date: Tue, 2 Jun 2026 00:21:48 +0000 Subject: [PATCH 04/16] bug fixes --- mstar/model/qwen3_omni/qwen3_omni_model.py | 23 ++++++++++-------- mstar/worker/worker.py | 28 +++++++++++++++++----- 2 files changed, 35 insertions(+), 16 deletions(-) diff --git a/mstar/model/qwen3_omni/qwen3_omni_model.py b/mstar/model/qwen3_omni/qwen3_omni_model.py index e3a45e98..5ab7032f 100644 --- a/mstar/model/qwen3_omni/qwen3_omni_model.py +++ b/mstar/model/qwen3_omni/qwen3_omni_model.py @@ -15,13 +15,14 @@ Thinker --[thinker_states, FixedChunkPolicy(1)]--> Talker Talker --[codec_tokens, FixedChunkPolicy(25)]--> Code2Wav -Conductor-triggered pipelined prefill (Approach C): - After each Thinker walk completes (prefill_text, prefill_audio, - prefill_vision, thinker_decode), the conductor sends a - ``talker_trigger`` to the Talker partition. During prefill each - trigger extends the Talker KV cache with the new Thinker hidden - states. The final trigger (when thinker_decode starts) tells the - Talker to sample its first codec token and transition to decode. +Producer-triggered Talker prefill: + The Talker partition is PRODUCER_TRIGGERED: its graph walk is driven + by the Thinker's streamed thinker_states/thinker_mask rather than by + the conductor. Each streamed item carries a walk-transition marker + (see ``talker_state_transition``): items from a Thinker prefill walk + keep the Talker in talker_prefill (extend KV cache only); the first + item from thinker_decode flips it to talker_last_prefill (sample the + first codec token), after which it self-transitions to talker_decode. Text-only mode: When output_modalities does not include "audio", only the Thinker @@ -552,9 +553,6 @@ def get_initial_forward_pass_args( is_prefill=True, kwargs={ "audio_output": audio_output, - "talker_prefill_done": False, - "num_thinker_prefill_steps": len(input_modalities), - "prefill_chunks_processed": 0, "voice": model_kwargs.get("voice", "Ethan"), "talker_max_tokens": self.get_max_talker_output_tokens(**model_kwargs), }, @@ -824,9 +822,11 @@ def _get_thinker_forward( # alongside the primary feature tensor). schedule = metadata.kwargs["prefill_schedule"] step = metadata.kwargs["prefill_step"] + is_last_prefill = (step == len(schedule) - 1) inputs = self._get_thinker_prefill_inputs(metadata, persist_signals) else: # Decode: previous token feeds back as text_inputs + is_last_prefill = False edge = GraphEdge(next_node="Thinker", name="text_inputs") edge.tensor_info = persist_signals.get("new_token", []) inputs = [edge] @@ -837,6 +837,9 @@ def _get_thinker_forward( step_metadata = { "is_prefill": metadata.is_prefill, + # The last prefill step must emit logits so the first decode token + # can be sampled (read by the non-batched Thinker submodule). + "is_last_prefill": is_last_prefill, # Persist the audio_output flag across every Thinker step so # the submodule can gate thinker_states emission. Default True # for backwards compatibility with callers that never set it. diff --git a/mstar/worker/worker.py b/mstar/worker/worker.py index 36789860..e17caef6 100644 --- a/mstar/worker/worker.py +++ b/mstar/worker/worker.py @@ -689,6 +689,12 @@ def _pop_streaming_edge( request_id, consumer_partition, chunk.graph_walk_transition ) + # "Consumed" here means drained into the buffer (high-water index), + # not popped out of it. These coincide for non-overlapping policies. + # For left-context/sliding-window policies the buffer retains items + # past this index as overlap, so a PD re-seed via set_index() would + # skip them — fine today since no sliding-window consumer spans + # multiple graph walks; revisit if that changes. self.worker_graphs_manager.get_fwd_info( request_id, consumer_partition ).consumed_edge_idx[edge_name] = sbuf._current_index @@ -731,9 +737,14 @@ def _poll_stream_buffers(self) -> None: partition_name = self.worker_graphs_manager.get_partition_for_node(consumer_node) wgid = self.worker_graphs_manager.get_worker_graph_id_for_node(request_id, consumer_node) - allow_graph_walk_transition = len( - self.worker_graphs_manager.queues[wgid].per_request_queues[request_id].ready_node_names - ) == 0 + # Defer a walk transition while a node is still ready (and so + # about to be scheduled) under the current walk — otherwise the + # scheduler would dispatch it under the new walk. A request with + # no queue entry yet has no ready nodes, so allow the transition. + per_req_queue = self.worker_graphs_manager.queues[wgid].per_request_queues.get(request_id) + allow_graph_walk_transition = ( + per_req_queue is None or len(per_req_queue.ready_node_names) == 0 + ) synthetic_edge = self._pop_streaming_edge( sbuf, edge_name, request_id, @@ -1055,12 +1066,16 @@ def _send_outputs( req_info = self.worker_graphs_manager.per_request_info[request_id] fwd_info = self.worker_graphs_manager.get_fwd_info(request_id, partition_name) + # One stream index per logical edge per pass. A single edge may fan out + # to several physical copies (local + one per consumer-shard worker); + # they all carry the same _index so the consumer's dedup/PD-reseed sees + # one item, not N. Assign the index once, reuse it for every copy. + edge_indices: dict[str, int] = {} for edge in chain(outputs.streaming_local, *outputs.streaming_to_workers.values()): - edge._index = fwd_info.produced_edge_idx.get(edge.name, 0) if edge.name not in produced_streaming_edges: produced_streaming_edges.add(edge.name) - fwd_info.produced_edge_idx[edge.name] = \ - fwd_info.produced_edge_idx.get(edge.name, 0) + 1 + edge_indices[edge.name] = fwd_info.produced_edge_idx.get(edge.name, 0) + fwd_info.produced_edge_idx[edge.name] = edge_indices[edge.name] + 1 conn = self.edge_to_connection.get(edge.name) if conn and conn.consumer_walk_transition: consumer_graph_walk_transitions[edge.name] = conn.consumer_walk_transition( @@ -1073,6 +1088,7 @@ def _send_outputs( ) ) + edge._index = edge_indices[edge.name] walk_transition = consumer_graph_walk_transitions.get(edge.name) edge._graph_walk_transition = walk_transition.graph_walk if walk_transition else None From 203ca1fd3009bc67d2c630fbae16da31f3db6404 Mon Sep 17 00:00:00 2001 From: NSagan271 Date: Tue, 2 Jun 2026 05:30:28 +0000 Subject: [PATCH 05/16] guard against stale index in TP case --- mstar/conductor/conductor.py | 10 ++++++++-- mstar/worker/node_manager_utils.py | 14 ++++++++++++++ mstar/worker/worker.py | 7 +++++-- 3 files changed, 27 insertions(+), 4 deletions(-) diff --git a/mstar/conductor/conductor.py b/mstar/conductor/conductor.py index 230cd1cf..8b701a91 100644 --- a/mstar/conductor/conductor.py +++ b/mstar/conductor/conductor.py @@ -852,8 +852,14 @@ def _process_worker_graphs_done( ) return [] - pstate.produced_edge_idx.update(body.new_produced_edge_idx) - pstate.consumed_edge_idx.update(body.new_consumed_edge_idx) + # Stream indices are monotonic, but only the emitting TP rank advances + # them; sibling ranks report the (stale) seed they were given. With + # last-writer-wins .update(), a sibling landing after the emitter would + # rewind the index, re-emitting an already-used one. Merge with max. + for name, idx in body.new_produced_edge_idx.items(): + pstate.produced_edge_idx[name] = max(pstate.produced_edge_idx.get(name, 0), idx) + for name, idx in body.new_consumed_edge_idx.items(): + pstate.consumed_edge_idx[name] = max(pstate.consumed_edge_idx.get(name, 0), idx) pstate.tracked_consumer_graph_walks.update(body.consumer_graph_walk_transitions) # Persist signals: every rank contributes its shard (different uuid + diff --git a/mstar/worker/node_manager_utils.py b/mstar/worker/node_manager_utils.py index aedebe42..af8c8127 100644 --- a/mstar/worker/node_manager_utils.py +++ b/mstar/worker/node_manager_utils.py @@ -291,6 +291,19 @@ def update_request_info( part_info = req_info.per_partition_info[partition_name] if current_fwd_info is not None: + # produced_edge_idx is a monotonic local counter for streaming + # output. The conductor round-trips it but its copy lags (refreshed + # only at WorkerGraphsDone), so a stale value here would rewind the + # counter and re-emit an already-used stream index — which the + # consumer drops as a duplicate, taking its walk-transition tag with + # it. Never let the conductor rewind it (mirrors the set_index() + # max-guard on the consumer side). + old_fwd_info = part_info.current_fwd_info + if old_fwd_info is not None: + for name, idx in old_fwd_info.produced_edge_idx.items(): + current_fwd_info.produced_edge_idx[name] = max( + current_fwd_info.produced_edge_idx.get(name, 0), idx + ) if allow_graph_walk_transition: self.update_graph_walk(request_id, partition_name, current_fwd_info.graph_walk) else: @@ -307,6 +320,7 @@ def update_request_info( fwd_info.per_label_seq_info.update(per_label_seq_info) def update_graph_walk(self, request_id: str, partition_name: str, graph_walk: str): + print(f"update graph walk {self.get_graph_walk(request_id, partition_name)} -> {graph_walk}") part_info = self.per_request_info[request_id].per_partition_info[partition_name] if self.get_graph_walk(request_id, partition_name) != graph_walk: part_info.graph_walk_worker_graph_ids = [ diff --git a/mstar/worker/worker.py b/mstar/worker/worker.py index e17caef6..b4de4618 100644 --- a/mstar/worker/worker.py +++ b/mstar/worker/worker.py @@ -735,6 +735,9 @@ def _poll_stream_buffers(self) -> None: for edge_name, sbuf in req_info.stream_buffers.items(): consumer_node = self._consumer_node_cache.get(edge_name, "") partition_name = self.worker_graphs_manager.get_partition_for_node(consumer_node) + + if not self.worker_graphs_manager.has_partition(request_id, partition_name): + continue wgid = self.worker_graphs_manager.get_worker_graph_id_for_node(request_id, consumer_node) # Defer a walk transition while a node is still ready (and so @@ -744,7 +747,7 @@ def _poll_stream_buffers(self) -> None: per_req_queue = self.worker_graphs_manager.queues[wgid].per_request_queues.get(request_id) allow_graph_walk_transition = ( per_req_queue is None or len(per_req_queue.ready_node_names) == 0 - ) + ) and request_id not in self._in_flight_rids synthetic_edge = self._pop_streaming_edge( sbuf, edge_name, request_id, @@ -1000,7 +1003,7 @@ def _send_outputs( message_type=WorkerMessageType.INPUT_SIGNALS, body=InputSignals( request_id=request_id, - inputs=edges, + inputs=[edge.clone() for edge in edges], request_info=self.worker_graphs_manager.get_fwd_info(request_id, partition_name), partition_name=partition_name ), From 0511495c67b7c7f7831cc1552fae0b75e2836638 Mon Sep 17 00:00:00 2001 From: NSagan271 Date: Tue, 2 Jun 2026 05:52:50 +0000 Subject: [PATCH 06/16] only make graph walk transitions on a freshly reset worker graph --- mstar/graph/base.py | 3 +++ mstar/worker/node_manager_utils.py | 1 - mstar/worker/worker.py | 12 ++++++------ 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/mstar/graph/base.py b/mstar/graph/base.py index 619c9e8d..56198f41 100644 --- a/mstar/graph/base.py +++ b/mstar/graph/base.py @@ -721,8 +721,10 @@ def __init__(self, graph_section: GraphSection): self.ready_for_streaming = set(self.only_streaming_inputs) self.ready_streaming_next_iter = set(self.only_streaming_inputs) + self.clean = True def register_ingested_input(self, graph_edge: GraphEdge): + self.clean = False node = self.nodes[graph_edge.next_node] # If node._speculatively_scheduled, the node is already executing as # a spec batch. We don't want to double-queue it (either for the @@ -770,6 +772,7 @@ def reset_for_iter(self): ) def clear(self): + self.clean = True super().clear() self.ready_names.clear() self.ready_for_streaming = set(self.only_streaming_inputs) diff --git a/mstar/worker/node_manager_utils.py b/mstar/worker/node_manager_utils.py index af8c8127..9cf73e27 100644 --- a/mstar/worker/node_manager_utils.py +++ b/mstar/worker/node_manager_utils.py @@ -320,7 +320,6 @@ def update_request_info( fwd_info.per_label_seq_info.update(per_label_seq_info) def update_graph_walk(self, request_id: str, partition_name: str, graph_walk: str): - print(f"update graph walk {self.get_graph_walk(request_id, partition_name)} -> {graph_walk}") part_info = self.per_request_info[request_id].per_partition_info[partition_name] if self.get_graph_walk(request_id, partition_name) != graph_walk: part_info.graph_walk_worker_graph_ids = [ diff --git a/mstar/worker/worker.py b/mstar/worker/worker.py index b4de4618..9759b7b0 100644 --- a/mstar/worker/worker.py +++ b/mstar/worker/worker.py @@ -740,14 +740,14 @@ def _poll_stream_buffers(self) -> None: continue wgid = self.worker_graphs_manager.get_worker_graph_id_for_node(request_id, consumer_node) - # Defer a walk transition while a node is still ready (and so - # about to be scheduled) under the current walk — otherwise the - # scheduler would dispatch it under the new walk. A request with - # no queue entry yet has no ready nodes, so allow the transition. + # Only transition the walk when the worker graph is in a fully + # reset (clean) state — no partial progress that the scheduler + # would otherwise dispatch under the new walk. A request with no + # queue entry yet has no progress, so allow the transition. per_req_queue = self.worker_graphs_manager.queues[wgid].per_request_queues.get(request_id) allow_graph_walk_transition = ( - per_req_queue is None or len(per_req_queue.ready_node_names) == 0 - ) and request_id not in self._in_flight_rids + per_req_queue is None or per_req_queue.wg_state_registry.clean + ) synthetic_edge = self._pop_streaming_edge( sbuf, edge_name, request_id, From 96d1f876813d7fa16d6142d0fbcdbb5c78c180a4 Mon Sep 17 00:00:00 2001 From: NSagan271 Date: Tue, 2 Jun 2026 06:04:35 +0000 Subject: [PATCH 07/16] only allow transition when full partition is done, not just worker graph --- mstar/worker/node_manager_utils.py | 12 ++++++++++++ mstar/worker/worker.py | 5 ++--- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/mstar/worker/node_manager_utils.py b/mstar/worker/node_manager_utils.py index 9cf73e27..259b9e42 100644 --- a/mstar/worker/node_manager_utils.py +++ b/mstar/worker/node_manager_utils.py @@ -352,6 +352,18 @@ def get_fwd_info(self, request_id: str, partition_name: str): def get_partition_for_node(self, node_name: str) -> str | None: """Look up which partition a node belongs to.""" return self.node_to_partition.get(node_name) + + def partition_clean(self, request_id: str, partition_name: str) -> bool: + if request_id not in self.per_request_info: + return True + wgids = self.per_request_info[request_id].per_partition_info[partition_name].graph_walk_worker_graph_ids + for wgid in wgids: + queue = self.queues[wgid].per_request_queues.get(request_id) + if queue is None: + continue + if not queue.wg_state_registry.clean: + return False + return True def process_new_inputs( self, diff --git a/mstar/worker/worker.py b/mstar/worker/worker.py index 9759b7b0..45527aa2 100644 --- a/mstar/worker/worker.py +++ b/mstar/worker/worker.py @@ -744,9 +744,8 @@ def _poll_stream_buffers(self) -> None: # reset (clean) state — no partial progress that the scheduler # would otherwise dispatch under the new walk. A request with no # queue entry yet has no progress, so allow the transition. - per_req_queue = self.worker_graphs_manager.queues[wgid].per_request_queues.get(request_id) - allow_graph_walk_transition = ( - per_req_queue is None or per_req_queue.wg_state_registry.clean + allow_graph_walk_transition = self.worker_graphs_manager.partition_clean( + request_id, partition_name ) synthetic_edge = self._pop_streaming_edge( From 716ad6c09a5bed6ae07b6f870e62fea618ba6858 Mon Sep 17 00:00:00 2001 From: NSagan271 Date: Sun, 7 Jun 2026 08:50:25 +0000 Subject: [PATCH 08/16] Support for streaming consumer PD disaggregation --- mstar/conductor/conductor.py | 52 ++++++++++++++--- mstar/streaming/stream_buffer.py | 1 + mstar/worker/node_manager_utils.py | 3 +- mstar/worker/worker.py | 93 ++++++++++++++++-------------- 4 files changed, 98 insertions(+), 51 deletions(-) diff --git a/mstar/conductor/conductor.py b/mstar/conductor/conductor.py index 8b701a91..0dd4edb4 100644 --- a/mstar/conductor/conductor.py +++ b/mstar/conductor/conductor.py @@ -217,9 +217,24 @@ def __init__( # (1) Set up worker graph TP ranks # (2) Assert that streaming consumers don't have graph-walk-specific sharding config - self.streaming_consumers = set() + self.state_machine_streaming_consumers = set() + state_machine_controlled_walks = set() + self.gw_to_part = {} + self.part_to_walks: dict[str, set[str]] = {} + for part_def in self.model.get_partitions(): + if part_def.transition_source == TransitionSource.STATE_MACHINE: + state_machine_controlled_walks.update(part_def.graph_walks) + self.gw_to_part.update({ + gw: part_def.name for gw in part_def.graph_walks + }) + self.part_to_walks[part_def.name] = set(part_def.graph_walks) + self.node_walk_to_wg: dict[tuple[str, str], WorkerGraph] = {} + + # partition name -> producer-triggered streaming consumer node names + self.producer_triggered_nodes: dict[str, set[str]] = dict() + # (worker idx) -> {tp_group_str: tp_rank} self.worker_tp_group_to_tp_rank: dict[int, dict[str, int]] = {} @@ -228,10 +243,20 @@ def __init__( for walk in wg.graph_walks: graph_walks.add(walk) for name, node in wg.section.get_nodes().items(): + state_machine_controlled = False for walk in wg.graph_walks: self.node_walk_to_wg[(name, walk)] = wg - if node.consumes_stream: - self.streaming_consumers.add(name) + if walk in state_machine_controlled_walks: + state_machine_controlled = True + if node.consumes_stream and state_machine_controlled: + self.state_machine_streaming_consumers.add(name) + elif node.consumes_stream and wg.graph_walks: + # Producer-triggered streaming consumer (e.g. PD-disaggregated + # Talker). All walks of a wg belong to one partition, so any + # of them resolves the partition. + self.producer_triggered_nodes.setdefault( + self.gw_to_part[next(iter(wg.graph_walks))], set() + ).add(name) # v1: one sharding group per worker graph. Track which group "owns" # each wg so we can assert single-group-per-wg. @@ -239,7 +264,7 @@ def __init__( for group in self.default_sharding_config.groups: if group.graph_walks is not None and any([ - node in self.streaming_consumers for node in group.nodes + node in self.state_machine_streaming_consumers for node in group.nodes ]): raise RuntimeError(( f"Sharding group with nodes {group.nodes} includes a streaming consumer but " @@ -461,7 +486,7 @@ def _build_request_sharding_config( for node_name in wg.section.get_nodes(): node_to_workers[NodeAndGraphWalk(node_name, walk)] = worker_ids cfg.setup(node_to_workers) - cfg.assert_stream_consumer_compatibility(self.streaming_consumers) + cfg.assert_stream_consumer_compatibility(self.state_machine_streaming_consumers) return cfg def _split_inputs_to_workers( @@ -478,6 +503,19 @@ def _split_inputs_to_workers( per dest, which the consumer's fan-in path consolidates. """ inputs_per_worker: dict[str, list[GraphEdge]] = defaultdict(list) + + # Seed empty inputs for producer-triggered streaming consumers (PD disag) + # across ALL the partition's walks, not just the conductor's current + # (lagging) walk — otherwise only the prefill worker is reached, never + # decode. set_index is monotonic, so re-seeding is a no-op. + part = self.gw_to_part.get(graph_walk) + for node in self.producer_triggered_nodes.get(part, set()): + for walk in self.part_to_walks.get(part, set()): + for worker in sharding_config.node_to_worker.get( + NodeAndGraphWalk(node, walk), [], + ): + inputs_per_worker.setdefault(worker, []) + for edge in inputs: if not edge.tensor_info: # Signal-only — broadcast to every dest worker. @@ -978,8 +1016,8 @@ def _process_done_forward( if conn.from_partition == partition_name: conn.producer_done = True self._send_producer_done(request_id, conn.from_partition, conn.to_partition) - elif fwd_args.inputs: - # Partition has inputs to send — conductor-driven + else: + # Always send partition inputs so that stream buffer counts are updated properly self._send_partition_inputs(request_id, partition_name, fwd_args) # else: no inputs — partition self-triggers via StreamBuffer diff --git a/mstar/streaming/stream_buffer.py b/mstar/streaming/stream_buffer.py index 09a052bb..c6692d22 100644 --- a/mstar/streaming/stream_buffer.py +++ b/mstar/streaming/stream_buffer.py @@ -68,6 +68,7 @@ def pre_read_register(self): def _update_buffer(self): while self._current_index in self._tensors: + print(f"{self.edge_name} INGEST {self._current_index}") self._buffer.append(self._tensors.pop(self._current_index)) self._current_index += 1 diff --git a/mstar/worker/node_manager_utils.py b/mstar/worker/node_manager_utils.py index 259b9e42..a47890a9 100644 --- a/mstar/worker/node_manager_utils.py +++ b/mstar/worker/node_manager_utils.py @@ -18,6 +18,7 @@ from mstar.graph.special_destinations import EMIT_TO_CLIENT, SPECIAL_DESTINATIONS from mstar.model.base import WorkerGraph from mstar.streaming.stream_buffer import StreamBuffer +from mstar.streaming.topology import StreamingGraphEdge logger = logging.getLogger(__name__) @@ -565,7 +566,7 @@ def process_node_outputs( fanout = sharding_config.fanout_graph_edges( edge, source_node=node_name, source_graph_walk=graph_walk, - dest_graph_walk=None + dest_graph_walk=edge._graph_walk_transition ) this_worker_edge = fanout.pop(self.worker_id, None) if this_worker_edge: diff --git a/mstar/worker/worker.py b/mstar/worker/worker.py index 45527aa2..f27e38c0 100644 --- a/mstar/worker/worker.py +++ b/mstar/worker/worker.py @@ -1,5 +1,3 @@ -from copy import deepcopy -from itertools import chain import logging import os import sys @@ -490,7 +488,7 @@ def _process_new_inputs(self, body: InputSignals) -> None: # a conductor-triggered forward pass, not just streaming data from another # partition). Streaming-only InputSignals must not overwrite the current # partition's fwd_info. - if non_streaming: + if non_streaming or not body.inputs: # Tag each input with the walk the conductor intends it for. Must be # done BEFORE update_request_info: for producer-triggered partitions # that call rewrites body.request_info.graph_walk to the stream walk. @@ -738,7 +736,6 @@ def _poll_stream_buffers(self) -> None: if not self.worker_graphs_manager.has_partition(request_id, partition_name): continue - wgid = self.worker_graphs_manager.get_worker_graph_id_for_node(request_id, consumer_node) # Only transition the walk when the worker graph is in a fully # reset (clean) state — no partial progress that the scheduler @@ -977,6 +974,7 @@ def _register_outputs( def _send_outputs( self, request_id: str, outputs: NodeOutputRouting, nested_loop_indices: NestedLoopIndices, + consumer_graph_walk_transitions, graph_walk: str | None = None, partition_name: str | None = None, prematerialized_new_tokens: dict[str, list[int]] | None = None, @@ -1060,48 +1058,9 @@ def _send_outputs( ) self.communicator.send("api_server", message) - produced_streaming_edges: set[str] = set() - consumer_graph_walk_transitions: dict[str, WalkTransition] = {} - - # Handle streaming edges - # Local streaming: route to StreamBuffer req_info = self.worker_graphs_manager.per_request_info[request_id] fwd_info = self.worker_graphs_manager.get_fwd_info(request_id, partition_name) - # One stream index per logical edge per pass. A single edge may fan out - # to several physical copies (local + one per consumer-shard worker); - # they all carry the same _index so the consumer's dedup/PD-reseed sees - # one item, not N. Assign the index once, reuse it for every copy. - edge_indices: dict[str, int] = {} - for edge in chain(outputs.streaming_local, *outputs.streaming_to_workers.values()): - if edge.name not in produced_streaming_edges: - produced_streaming_edges.add(edge.name) - edge_indices[edge.name] = fwd_info.produced_edge_idx.get(edge.name, 0) - fwd_info.produced_edge_idx[edge.name] = edge_indices[edge.name] + 1 - conn = self.edge_to_connection.get(edge.name) - if conn and conn.consumer_walk_transition: - consumer_graph_walk_transitions[edge.name] = conn.consumer_walk_transition( - ConsumerTransitionCtx( - producer_walk=graph_walk, - consumer_walk=fwd_info.tracked_consumer_graph_walks.get( - conn.to_partition - ), - producer_fwd=fwd_info - ) - ) - - edge._index = edge_indices[edge.name] - walk_transition = consumer_graph_walk_transitions.get(edge.name) - edge._graph_walk_transition = walk_transition.graph_walk if walk_transition else None - - # Advance our local view of each consumer's walk by the transition just - # emitted: inside a dynamic loop there is no conductor round-trip to - # refresh tracked_consumer_graph_walks. After the edge loop so all edges - # of a pass see the same (pre-update) consumer walk. - for edge_name, wt in consumer_graph_walk_transitions.items(): - to_partition = self.edge_to_connection[edge_name].to_partition - fwd_info.tracked_consumer_graph_walks[to_partition] = wt.graph_walk - for edge in outputs.streaming_local: stream_buf = req_info.stream_buffers[edge.name] for _ in edge.tensor_info: @@ -1806,6 +1765,7 @@ def _postprocess_batch( # Mark nodes complete and route routing_per_request: dict[str, NodeOutputRouting] = {} per_request_uuids: dict[str, set[str]] = {} + consumer_walk_transitions_per_rid: dict[str, dict[str, WalkTransition]] = {} for rid, wg_id in batch_N.batch.request_to_worker_graph.items(): # Store output tensors before marking the node as complete so that # loop outputs can be buffered properly. @@ -1831,6 +1791,52 @@ def _postprocess_batch( ) real_outputs = [edge.clone() for edge in completion_output.output_edges] + # Handle streaming edges + # Local streaming: route to StreamBuffer + req_info = self.worker_graphs_manager.per_request_info[rid] + fwd_info = self.worker_graphs_manager.get_fwd_info(rid, batch_N.partition) + produced_streaming_edges: set[str] = set() + consumer_graph_walk_transitions: dict[str, WalkTransition] = {} + + # One stream index per logical edge per pass. A single edge may fan out + # to several physical copies (local + one per consumer-shard worker); + # they all carry the same _index so the consumer's dedup/PD-reseed sees + # one item, not N. Assign the index once, reuse it for every copy. + edge_indices: dict[str, int] = {} + for edge in real_outputs: + if not edge.is_streaming: + continue + if edge.name not in produced_streaming_edges: + produced_streaming_edges.add(edge.name) + edge_indices[edge.name] = fwd_info.produced_edge_idx.get(edge.name, 0) + fwd_info.produced_edge_idx[edge.name] = edge_indices[edge.name] + 1 + conn = self.edge_to_connection.get(edge.name) + if conn and conn.consumer_walk_transition: + consumer_graph_walk_transitions[edge.name] = conn.consumer_walk_transition( + ConsumerTransitionCtx( + producer_walk=batch_N.graph_walk, + consumer_walk=fwd_info.tracked_consumer_graph_walks.get( + conn.to_partition + ), + producer_fwd=fwd_info + ) + ) + + edge._index = edge_indices[edge.name] + walk_transition = consumer_graph_walk_transitions.get(edge.name) + edge._graph_walk_transition = walk_transition.graph_walk if walk_transition else \ + fwd_info.tracked_consumer_graph_walks.get( + self.edge_to_connection[edge.name].to_partition + ) + # Advance our local view of each consumer's walk by the transition just + # emitted: inside a dynamic loop there is no conductor round-trip to + # refresh tracked_consumer_graph_walks. After the edge loop so all edges + # of a pass see the same (pre-update) consumer walk. + for edge_name, wt in consumer_graph_walk_transitions.items(): + to_partition = self.edge_to_connection[edge_name].to_partition + fwd_info.tracked_consumer_graph_walks[to_partition] = wt.graph_walk + consumer_walk_transitions_per_rid[rid] = consumer_graph_walk_transitions + routing_per_request[rid] = self.worker_graphs_manager.process_node_outputs( rid, node_name=batch_N.node_name, outputs=real_outputs, graph_walk=batch_N.graph_walk @@ -1871,6 +1877,7 @@ def _postprocess_batch( self._send_outputs( rid, routing, nested_loop_indices=per_req_nested_idxs[rid], + consumer_graph_walk_transitions=consumer_walk_transitions_per_rid.get(rid, {}), graph_walk=batch_N.graph_walk, partition_name=batch_N.partition, node_speculatively_scheduled=batch_N.batch.node_objects[rid]._speculatively_scheduled From 7e1a7bb0be02bf5ebe8eeaf18a724e099341356b Mon Sep 17 00:00:00 2001 From: NSagan271 Date: Sun, 7 Jun 2026 09:06:24 +0000 Subject: [PATCH 09/16] remove print --- mstar/streaming/stream_buffer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mstar/streaming/stream_buffer.py b/mstar/streaming/stream_buffer.py index c6692d22..09a052bb 100644 --- a/mstar/streaming/stream_buffer.py +++ b/mstar/streaming/stream_buffer.py @@ -68,7 +68,6 @@ def pre_read_register(self): def _update_buffer(self): while self._current_index in self._tensors: - print(f"{self.edge_name} INGEST {self._current_index}") self._buffer.append(self._tensors.pop(self._current_index)) self._current_index += 1 From 7812b6d4a68cb09c8dc0c088da966ceed91739c6 Mon Sep 17 00:00:00 2001 From: NSagan271 Date: Sun, 7 Jun 2026 21:46:27 +0000 Subject: [PATCH 10/16] add talker pd disag config --- configs/qwen3omni_talker_pd_disag.yaml | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 configs/qwen3omni_talker_pd_disag.yaml diff --git a/configs/qwen3omni_talker_pd_disag.yaml b/configs/qwen3omni_talker_pd_disag.yaml new file mode 100644 index 00000000..5bc175bd --- /dev/null +++ b/configs/qwen3omni_talker_pd_disag.yaml @@ -0,0 +1,21 @@ +model: "qwen3_omni" +max_seq_len: 32768 +node_groups: + - node_names: [audio_encoder, vision_encoder, Code2Wav] + ranks: [0] + + - node_names: [Thinker] + ranks: [0] + graph_walks: [prefill_text, prefill_audio, prefill_vision] + + - node_names: [Thinker] + ranks: [1] + graph_walks: [thinker_decode] + + - node_names: [Talker] + ranks: [2] + graph_walks: [talker_prefill] + + - node_names: [Talker] + ranks: [3] + graph_walks: [talker_last_prefill, talker_decode] \ No newline at end of file From d663bbe2bd8e6c0e0fcf2d6736e238610373eb4e Mon Sep 17 00:00:00 2001 From: NSagan271 Date: Sun, 7 Jun 2026 22:28:20 +0000 Subject: [PATCH 11/16] move producer done to its own message type (was riding on INPUT_SIGNALS and erroneously resetting the step metadata --- mstar/conductor/conductor.py | 25 ++++++------------------- mstar/streaming/chunk_policy.py | 1 - mstar/streaming/stream_buffer.py | 2 +- mstar/streaming/topology.py | 2 ++ mstar/utils/ipc_format.py | 7 +++++++ mstar/worker/worker.py | 24 +++++++++++++++--------- 6 files changed, 31 insertions(+), 30 deletions(-) diff --git a/mstar/conductor/conductor.py b/mstar/conductor/conductor.py index 0dd4edb4..a2901ce7 100644 --- a/mstar/conductor/conductor.py +++ b/mstar/conductor/conductor.py @@ -34,6 +34,7 @@ InputSignals, NewRequest, NewRequestConductor, + ProducerDone, RemoveRequest, UnpersistTensors, WorkerGraphsDone, @@ -1017,9 +1018,10 @@ def _process_done_forward( conn.producer_done = True self._send_producer_done(request_id, conn.from_partition, conn.to_partition) else: - # Always send partition inputs so that stream buffer counts are updated properly + # Call even with no inputs: a no-op for self-triggering partitions + # (empty inputs_per_worker → no messages), but lets producer-triggered + # consumers get seeded (PD) / their consumed_edge_idx propagated. self._send_partition_inputs(request_id, partition_name, fwd_args) - # else: no inputs — partition self-triggers via StreamBuffer self._un_persist_tensors(request_id, fwd_args.unpersist_tensors) @@ -1085,7 +1087,6 @@ def _send_producer_done( ): """Send producer_done signal to the consumer partition's worker(s).""" request_data = self.requests[request_id] - pstate = request_data.partition_states[consumer_partition_name] # Find which workers handle this consumer partition consumer_workers = set() @@ -1097,23 +1098,9 @@ def _send_producer_done( for worker_id in consumer_workers: message = WorkerMessage( - message_type=WorkerMessageType.INPUT_SIGNALS, - body=InputSignals( + message_type=WorkerMessageType.PRODUCER_DONE, + body=ProducerDone( request_id=request_id, - inputs=[], - request_info=CurrentForwardPassInfo( - request_id=request_id, - graph_walk=pstate.metadata.graph_walk or "", - fwd_index=pstate.fwd_pass_number, - random_seed=pstate.random_seed, - requires_cfg=False, - partition_name=consumer_partition_name, - max_tokens=request_data.max_output_tokens, - sampling_config=request_data.sampling_config, - produced_edge_idx=pstate.produced_edge_idx, - consumed_edge_idx=pstate.consumed_edge_idx, - tracked_consumer_graph_walks=pstate.tracked_consumer_graph_walks, - ), partition_name=consumer_partition_name, producer_done=set([producer_partition]), ), diff --git a/mstar/streaming/chunk_policy.py b/mstar/streaming/chunk_policy.py index 1964b46a..b75d87ca 100644 --- a/mstar/streaming/chunk_policy.py +++ b/mstar/streaming/chunk_policy.py @@ -135,7 +135,6 @@ class FixedChunkPolicy(ChunkPolicy): the producer finishes and all buffered items are consumed. """ - # TODO: make continue_after_done a set of graph walks def __init__(self, chunk_size: int, continue_after_done: set[str] | None=None): super().__init__() self._chunk_size = chunk_size diff --git a/mstar/streaming/stream_buffer.py b/mstar/streaming/stream_buffer.py index 09a052bb..15973d3f 100644 --- a/mstar/streaming/stream_buffer.py +++ b/mstar/streaming/stream_buffer.py @@ -79,6 +79,7 @@ def put(self, item: torch.Tensor, index: int, graph_walk: str | None = None) -> the duplicate is dropped (first-arrival-wins). This handles the case where multiple colocated producer ranks emit the same streaming item. """ + self._num_buffer_writes += 1 if index < self._current_index or index in self._tensors: return self._tensors[index] = StreamingTensor( @@ -86,7 +87,6 @@ def put(self, item: torch.Tensor, index: int, graph_walk: str | None = None) -> tensor=item, graph_walk=graph_walk, ) - self._num_buffer_writes += 1 def signal_done(self) -> None: """Producer signals no more items will arrive.""" diff --git a/mstar/streaming/topology.py b/mstar/streaming/topology.py index bc331e3e..f357880d 100644 --- a/mstar/streaming/topology.py +++ b/mstar/streaming/topology.py @@ -32,6 +32,8 @@ def clone(self): is_streaming=self.is_streaming, output_modality=self.output_modality, _persist_for_loop=self._persist_for_loop, + _total_fanin=self._total_fanin, + _shard_dim=self._shard_dim, _target_graph_walk=self._target_graph_walk, target_partition=self.target_partition, _index=self._index, diff --git a/mstar/utils/ipc_format.py b/mstar/utils/ipc_format.py index 837b81b6..ec8a154a 100644 --- a/mstar/utils/ipc_format.py +++ b/mstar/utils/ipc_format.py @@ -30,6 +30,7 @@ class WorkerMessageType(Enum): NEW_REQUEST = "new_request" REMOVE_REQUEST = "remove_request" INPUT_SIGNALS = "input_signals" + PRODUCER_DONE = "producer_done" UNPERSIST_TENSORS = "unpersist" TENSOR_RECEIVED = "tensor_received" SCHEDULE_TP = "schedule_tp" @@ -56,6 +57,12 @@ class InputSignals(MessageBody): inputs: list[GraphEdge] request_info: CurrentForwardPassInfo partition_name: str = "default" + + +@dataclass +class ProducerDone(MessageBody): + request_id: str + partition_name: str producer_done: set = field(default_factory=set) diff --git a/mstar/worker/worker.py b/mstar/worker/worker.py index f27e38c0..4a87d7d7 100644 --- a/mstar/worker/worker.py +++ b/mstar/worker/worker.py @@ -31,6 +31,7 @@ ConductorMessageType, InputSignals, NewRequest, + ProducerDone, RemoveRequest, ScheduleTPNode, SetupDone, @@ -460,6 +461,16 @@ def _handle_tensor_received(self, body: TensorReceived) -> None: self.tensor_manager.dereference( body.request_id, uuid, n=ref_cnt ) + + def _handle_producer_done(self, body: ProducerDone) -> None: + # Handle producer_done signal: mark all StreamBuffers for this request as done + req_info = self.worker_graphs_manager.per_request_info.get(body.request_id) + if req_info: + for sbuf in req_info.stream_buffers.values(): + if sbuf.from_partition in body.producer_done: + # If we have multiple consumer partitions colocated, we need to signal + # the right one + sbuf.signal_done() def _process_new_inputs(self, body: InputSignals) -> None: logger.debug( @@ -470,14 +481,6 @@ def _process_new_inputs(self, body: InputSignals) -> None: if self.enable_nvtx: range_push("process_new_inputs.routing_update") - # Handle producer_done signal: mark all StreamBuffers for this request as done - if body.producer_done: - if req_info: - for sbuf in req_info.stream_buffers.values(): - if sbuf.from_partition in body.producer_done: - # If we have multiple consumer partitions colocated, we need to signal - # the right one - sbuf.signal_done() # Separate streaming edges — they'll be handled when tensors are ready # (streaming edges with tensor_info go through RDMA, handled in _check_ready_tensors) @@ -575,7 +578,8 @@ def _process_message_list(self, messages: list[WorkerMessage]): msg_types_needing_active_request = [ WorkerMessageType.REMOVE_REQUEST, WorkerMessageType.INPUT_SIGNALS, - WorkerMessageType.STOP_LOOPS + WorkerMessageType.STOP_LOOPS, + WorkerMessageType.PRODUCER_DONE ] for message in messages: if ( @@ -593,6 +597,8 @@ def _process_message_list(self, messages: list[WorkerMessage]): self._remove_request(message.body) elif message.message_type == WorkerMessageType.INPUT_SIGNALS: self._process_new_inputs(message.body) + elif message.message_type == WorkerMessageType.PRODUCER_DONE: + self._handle_producer_done(message.body) elif message.message_type == WorkerMessageType.TENSOR_RECEIVED: self._handle_tensor_received(message.body) elif message.message_type == WorkerMessageType.UNPERSIST_TENSORS: From 1577dc6436b0b55509277e66b10ecc37c903061b Mon Sep 17 00:00:00 2001 From: NSagan271 Date: Sun, 7 Jun 2026 22:44:09 +0000 Subject: [PATCH 12/16] some more cleanup --- mstar/conductor/conductor.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mstar/conductor/conductor.py b/mstar/conductor/conductor.py index a2901ce7..160fc731 100644 --- a/mstar/conductor/conductor.py +++ b/mstar/conductor/conductor.py @@ -899,6 +899,10 @@ def _process_worker_graphs_done( pstate.produced_edge_idx[name] = max(pstate.produced_edge_idx.get(name, 0), idx) for name, idx in body.new_consumed_edge_idx.items(): pstate.consumed_edge_idx[name] = max(pstate.consumed_edge_idx.get(name, 0), idx) + # Unlike the indices above, graph walks have no ordering to max-merge on, + # so this is a plain last-writer-wins update. Contract: a reporting rank + # must send either the correct (just-applied) consumer walk or nothing — + # a non-emitting TP sibling sends an empty dict, never a stale walk. pstate.tracked_consumer_graph_walks.update(body.consumer_graph_walk_transitions) # Persist signals: every rank contributes its shard (different uuid + From 1475c2869f5fd1d13876ab8340e40d621153dee7 Mon Sep 17 00:00:00 2001 From: NSagan271 Date: Fri, 12 Jun 2026 17:49:03 +0000 Subject: [PATCH 13/16] ruff --- mstar/conductor/conductor.py | 4 ++-- mstar/model/qwen3_omni/qwen3_omni_model.py | 12 +++++++++--- mstar/streaming/topology.py | 2 +- mstar/worker/node_manager_utils.py | 7 +++---- mstar/worker/worker.py | 4 ++-- 5 files changed, 17 insertions(+), 12 deletions(-) diff --git a/mstar/conductor/conductor.py b/mstar/conductor/conductor.py index 160fc731..5f4ea641 100644 --- a/mstar/conductor/conductor.py +++ b/mstar/conductor/conductor.py @@ -759,7 +759,7 @@ def _do_ingest_request( body.request_id, p.name, fwd_args.full_metadata.graph_walk, ) partition_fwd_args[p.name] = fwd_args - + # set up tracked_consumer_graph_walks for conn in topology.connections: producer = partition_states[conn.from_partition] @@ -890,7 +890,7 @@ def _process_worker_graphs_done( partition_name, body.request_id, ) return [] - + # Stream indices are monotonic, but only the emitting TP rank advances # them; sibling ranks report the (stale) seed they were given. With # last-writer-wins .update(), a sibling landing after the emitter would diff --git a/mstar/model/qwen3_omni/qwen3_omni_model.py b/mstar/model/qwen3_omni/qwen3_omni_model.py index 5ab7032f..04e028c7 100644 --- a/mstar/model/qwen3_omni/qwen3_omni_model.py +++ b/mstar/model/qwen3_omni/qwen3_omni_model.py @@ -46,12 +46,18 @@ from mstar.engine.kv_store import KVCacheConfig from mstar.graph.base import GraphEdge, GraphNode, Loop, Sequential, TensorPointerInfo from mstar.graph.special_destinations import EMIT_TO_CLIENT, EMPTY_DESTINATION -from mstar.model.base import ForwardPassArgs, MAX_OUTPUT_TOKENS, Model, TensorAndMetadata +from mstar.model.base import MAX_OUTPUT_TOKENS, ForwardPassArgs, Model, TensorAndMetadata from mstar.model.qwen3_omni.components.talker import Qwen3OmniCodePredictor from mstar.model.submodule_base import NodeSubmodule from mstar.model.utils import Operation, WeightConverter from mstar.streaming.chunk_policy import FixedChunkPolicy, LeftContextChunkPolicy -from mstar.streaming.topology import Connection, ConsumerTransitionCtx, PartitionTopology, StreamingGraphEdge, WalkTransition +from mstar.streaming.topology import ( + Connection, + ConsumerTransitionCtx, + PartitionTopology, + StreamingGraphEdge, + WalkTransition, +) from mstar.utils.sampling import SamplingConfig logger = logging.getLogger(__name__) @@ -451,7 +457,7 @@ def talker_state_transition(ctx: ConsumerTransitionCtx) -> WalkTransition: if ctx.consumer_walk == "talker_prefill": return WalkTransition("talker_last_prefill") return WalkTransition("talker_decode") - + return PartitionTopology( partitions=["Thinker", "Talker", "Code2Wav"], connections=[ diff --git a/mstar/streaming/topology.py b/mstar/streaming/topology.py index f357880d..6bf640e7 100644 --- a/mstar/streaming/topology.py +++ b/mstar/streaming/topology.py @@ -21,7 +21,7 @@ class StreamingGraphEdge(GraphEdge): def __post_init__(self): self.is_streaming = True - + def clone(self): return StreamingGraphEdge( next_node=self.next_node, diff --git a/mstar/worker/node_manager_utils.py b/mstar/worker/node_manager_utils.py index a47890a9..aa6f2250 100644 --- a/mstar/worker/node_manager_utils.py +++ b/mstar/worker/node_manager_utils.py @@ -18,7 +18,6 @@ from mstar.graph.special_destinations import EMIT_TO_CLIENT, SPECIAL_DESTINATIONS from mstar.model.base import WorkerGraph from mstar.streaming.stream_buffer import StreamBuffer -from mstar.streaming.topology import StreamingGraphEdge logger = logging.getLogger(__name__) @@ -193,7 +192,7 @@ class PerPartitionInfo: # graph_walk_worker_graph_ids = worker graphs for current graph walk graph_walk_worker_graph_ids: list[str] = field(default_factory=list) # for this worker stream_partition_done: bool = False # set True when last chunk pops with is_final - + # edges that are pending a producer-triggered graph walk transition # {graph walk -> edges} pending_edges: dict[str, list[GraphEdge]] = field(default_factory=dict) @@ -319,7 +318,7 @@ def update_request_info( if per_label_seq_info is not None: fwd_info = self.get_fwd_info(request_id, partition_name) fwd_info.per_label_seq_info.update(per_label_seq_info) - + def update_graph_walk(self, request_id: str, partition_name: str, graph_walk: str): part_info = self.per_request_info[request_id].per_partition_info[partition_name] if self.get_graph_walk(request_id, partition_name) != graph_walk: @@ -353,7 +352,7 @@ def get_fwd_info(self, request_id: str, partition_name: str): def get_partition_for_node(self, node_name: str) -> str | None: """Look up which partition a node belongs to.""" return self.node_to_partition.get(node_name) - + def partition_clean(self, request_id: str, partition_name: str) -> bool: if request_id not in self.per_request_info: return True diff --git a/mstar/worker/worker.py b/mstar/worker/worker.py index 4a87d7d7..17e30449 100644 --- a/mstar/worker/worker.py +++ b/mstar/worker/worker.py @@ -461,7 +461,7 @@ def _handle_tensor_received(self, body: TensorReceived) -> None: self.tensor_manager.dereference( body.request_id, uuid, n=ref_cnt ) - + def _handle_producer_done(self, body: ProducerDone) -> None: # Handle producer_done signal: mark all StreamBuffers for this request as done req_info = self.worker_graphs_manager.per_request_info.get(body.request_id) @@ -642,7 +642,7 @@ def _pop_streaming_edge( graph_walk = self.worker_graphs_manager.get_graph_walk( request_id, consumer_partition ) - + waiting_edge = sbuf.pop_waiting_edge() if waiting_edge is not None and waiting_edge.walk_transition is not None \ and waiting_edge.walk_transition != graph_walk: From bd565ec63e93ea5854002b1ee40d73a1b9c15393 Mon Sep 17 00:00:00 2001 From: NSagan271 Date: Fri, 12 Jun 2026 19:00:16 +0000 Subject: [PATCH 14/16] Some cleanup --- configs/qwen3omni_talker_pd_disag.yaml | 9 ++++++--- mstar/streaming/chunk_policy.py | 1 - mstar/streaming/stream_buffer.py | 4 ++++ mstar/streaming/topology.py | 6 +++++- mstar/worker/worker.py | 20 +++++++++++--------- 5 files changed, 26 insertions(+), 14 deletions(-) diff --git a/configs/qwen3omni_talker_pd_disag.yaml b/configs/qwen3omni_talker_pd_disag.yaml index 5bc175bd..7b6265de 100644 --- a/configs/qwen3omni_talker_pd_disag.yaml +++ b/configs/qwen3omni_talker_pd_disag.yaml @@ -1,7 +1,7 @@ model: "qwen3_omni" max_seq_len: 32768 node_groups: - - node_names: [audio_encoder, vision_encoder, Code2Wav] + - node_names: [audio_encoder, vision_encoder] ranks: [0] - node_names: [Thinker] @@ -14,8 +14,11 @@ node_groups: - node_names: [Talker] ranks: [2] - graph_walks: [talker_prefill] + graph_walks: [talker_prefill, talker_last_prefill] - node_names: [Talker] ranks: [3] - graph_walks: [talker_last_prefill, talker_decode] \ No newline at end of file + graph_walks: [talker_decode] + + - node_names: [Code2Wav] + ranks: [3] diff --git a/mstar/streaming/chunk_policy.py b/mstar/streaming/chunk_policy.py index b75d87ca..1ca1bee1 100644 --- a/mstar/streaming/chunk_policy.py +++ b/mstar/streaming/chunk_policy.py @@ -11,7 +11,6 @@ def register_chunk(self, chunk_size: int): self.first_chunk_read = True self.items_consumed += chunk_size - # TODO: add graph walks to the methods in the implementors... @abstractmethod def is_ready(self, buffer_len: int, graph_walk: str | None=None) -> bool: """Return True if the buffer has enough items for a chunk.""" diff --git a/mstar/streaming/stream_buffer.py b/mstar/streaming/stream_buffer.py index 15973d3f..13b6379e 100644 --- a/mstar/streaming/stream_buffer.py +++ b/mstar/streaming/stream_buffer.py @@ -79,6 +79,10 @@ def put(self, item: torch.Tensor, index: int, graph_walk: str | None = None) -> the duplicate is dropped (first-arrival-wins). This handles the case where multiple colocated producer ranks emit the same streaming item. """ + # Counts put attempts, not unique items: incremented before the dedup + # return so it stays balanced with ``_num_tensors_registered`` (which + # ``pre_read_register`` also bumps once per registered tensor, including + # duplicates). ``_producer_done_and_all_read`` relies on that symmetry. self._num_buffer_writes += 1 if index < self._current_index or index in self._tensors: return diff --git a/mstar/streaming/topology.py b/mstar/streaming/topology.py index 6bf640e7..12b8939a 100644 --- a/mstar/streaming/topology.py +++ b/mstar/streaming/topology.py @@ -45,13 +45,17 @@ def clone(self): class ConsumerTransitionCtx: producer_walk: str consumer_walk: str | None # None on the very first trigger + # Unused by qwen3omni's transition fn, but exposed so a future model can + # base its transition on the producer's full forward-pass state. This is + # the sole reason streaming depends on conductor.request_info; revisit + # (e.g. a lighter type) if that dependency direction becomes a problem. producer_fwd: CurrentForwardPassInfo @dataclass class WalkTransition: graph_walk: str | None = None - # TODO: hook up metadata if needed + # TODO [FUTURE]: hook up metadata if needed @dataclass class Connection: diff --git a/mstar/worker/worker.py b/mstar/worker/worker.py index 17e30449..8b1b603a 100644 --- a/mstar/worker/worker.py +++ b/mstar/worker/worker.py @@ -643,6 +643,17 @@ def _pop_streaming_edge( request_id, consumer_partition ) + sbuf._update_buffer() + # Note: "Consumed" here means drained into the buffer (high-water index), + # not popped out of it. These coincide for non-overlapping policies. + # For left-context/sliding-window policies the buffer retains items + # past this index as overlap, so a PD re-seed via set_index() would + # skip them. No sliding-window consumer spans multiple graph walks; + # revisit if that changes. + self.worker_graphs_manager.get_fwd_info( + request_id, consumer_partition + ).consumed_edge_idx[edge_name] = sbuf._current_index + waiting_edge = sbuf.pop_waiting_edge() if waiting_edge is not None and waiting_edge.walk_transition is not None \ and waiting_edge.walk_transition != graph_walk: @@ -693,15 +704,6 @@ def _pop_streaming_edge( request_id, consumer_partition, chunk.graph_walk_transition ) - # "Consumed" here means drained into the buffer (high-water index), - # not popped out of it. These coincide for non-overlapping policies. - # For left-context/sliding-window policies the buffer retains items - # past this index as overlap, so a PD re-seed via set_index() would - # skip them — fine today since no sliding-window consumer spans - # multiple graph walks; revisit if that changes. - self.worker_graphs_manager.get_fwd_info( - request_id, consumer_partition - ).consumed_edge_idx[edge_name] = sbuf._current_index return synthetic_edge def _poll_stream_buffers_for_speculation( From 22d7025f7fa33be9bb27d71f020902e78f95bd16 Mon Sep 17 00:00:00 2001 From: NSagan271 Date: Fri, 12 Jun 2026 21:08:45 +0000 Subject: [PATCH 15/16] set the consumed_edge_index as the index after the node is run --- mstar/conductor/conductor.py | 10 +++++----- mstar/conductor/request_info.py | 23 +++++++++++++++++++++-- mstar/graph/base.py | 6 ++++++ mstar/streaming/stream_buffer.py | 5 +++++ mstar/utils/ipc_format.py | 2 +- mstar/worker/node_manager_utils.py | 7 +++++++ mstar/worker/worker.py | 28 +++++++++++++++------------- 7 files changed, 60 insertions(+), 21 deletions(-) diff --git a/mstar/conductor/conductor.py b/mstar/conductor/conductor.py index 5f4ea641..90e702e6 100644 --- a/mstar/conductor/conductor.py +++ b/mstar/conductor/conductor.py @@ -803,7 +803,7 @@ def _do_ingest_request( max_tokens=request_data.max_output_tokens, sampling_config=request_data.sampling_config, produced_edge_idx=pstate.produced_edge_idx, - consumed_edge_idx=pstate.consumed_edge_idx, + next_stream_index=pstate.next_stream_index, tracked_consumer_graph_walks=pstate.tracked_consumer_graph_walks, ), ) @@ -897,8 +897,8 @@ def _process_worker_graphs_done( # rewind the index, re-emitting an already-used one. Merge with max. for name, idx in body.new_produced_edge_idx.items(): pstate.produced_edge_idx[name] = max(pstate.produced_edge_idx.get(name, 0), idx) - for name, idx in body.new_consumed_edge_idx.items(): - pstate.consumed_edge_idx[name] = max(pstate.consumed_edge_idx.get(name, 0), idx) + for name, idx in body.new_next_stream_index.items(): + pstate.next_stream_index[name] = max(pstate.next_stream_index.get(name, 0), idx) # Unlike the indices above, graph walks have no ordering to max-merge on, # so this is a plain last-writer-wins update. Contract: a reporting rank # must send either the correct (just-applied) consumer walk or nothing — @@ -1024,7 +1024,7 @@ def _process_done_forward( else: # Call even with no inputs: a no-op for self-triggering partitions # (empty inputs_per_worker → no messages), but lets producer-triggered - # consumers get seeded (PD) / their consumed_edge_idx propagated. + # consumers get seeded (PD) / their next_stream_index propagated. self._send_partition_inputs(request_id, partition_name, fwd_args) self._un_persist_tensors(request_id, fwd_args.unpersist_tensors) @@ -1077,7 +1077,7 @@ def _send_partition_inputs( max_tokens=request_data.max_output_tokens, sampling_config=request_data.sampling_config, produced_edge_idx=pstate.produced_edge_idx, - consumed_edge_idx=pstate.consumed_edge_idx, + next_stream_index=pstate.next_stream_index, tracked_consumer_graph_walks=pstate.tracked_consumer_graph_walks, ), partition_name=partition_name diff --git a/mstar/conductor/request_info.py b/mstar/conductor/request_info.py index 3a0023bb..1434b550 100644 --- a/mstar/conductor/request_info.py +++ b/mstar/conductor/request_info.py @@ -48,6 +48,23 @@ def update(self, other: "PerLabelSeqInfo"): **val } + def merge_keep_longest(self, other: "PerLabelSeqInfo"): + """Merge ``other`` into self, keeping — per (kv_cache_string, rank, + label) — the ``SequenceInfo`` with the larger ``seq_len``. + + Used on the worker to stop a lagging conductor copy from rewinding a + request's locally-advanced KV sequence positions: a request's KV length + only ever grows, so the longer entry is always the more recent one. + """ + for key, label_info in other.info.items(): + dst = self.info.setdefault(key, {}) + for label, seq_info in label_info.items(): + cur = dst.get(label) + if cur is None or seq_info.seq_len > cur.seq_len: + dst[label] = seq_info + for kv_cache_str, ws in other.world_size.items(): + self.world_size.setdefault(kv_cache_str, ws) + def get(self, kv_cache_str: str, rank: int) -> dict: return self.info.get((kv_cache_str, rank), {}) @@ -77,7 +94,7 @@ class CurrentForwardPassInfo: per_label_seq_info: PerLabelSeqInfo = field(default_factory=PerLabelSeqInfo) partition_name: str = field(default="default") produced_edge_idx: dict[str, int] = field(default_factory=dict) - consumed_edge_idx: dict[str, int] = field(default_factory=dict) + next_stream_index: dict[str, int] = field(default_factory=dict) tracked_consumer_graph_walks: dict[str, str] = field(default_factory=dict) # Per-loop stop indices; stop decisions come from each submodule's check_stop. @@ -141,7 +158,9 @@ class PartitionState: is_done: bool = False # streaming edge name -> number of times this edge has been emitted produced_edge_idx: dict[str, int] = field(default_factory=dict) - consumed_edge_idx: dict[str, int] = field(default_factory=dict) + # streaming edge name -> next stream index the consumer should drain + # (one past the last index a completed consumer pass has consumed) + next_stream_index: dict[str, int] = field(default_factory=dict) # only for producer-driven graph walk transitions tracked_consumer_graph_walks: dict[str, str] = field(default_factory=dict) new_tokens: dict[str, list[int]] = field(default_factory=dict) diff --git a/mstar/graph/base.py b/mstar/graph/base.py index 56198f41..f1e509ce 100644 --- a/mstar/graph/base.py +++ b/mstar/graph/base.py @@ -63,6 +63,11 @@ class GraphEdge: # only for EMIT_TO_CLIENT output_modality: str = field(default="") # text | image | video | audio _persist_for_loop: bool = field(default=False) + + # set on a synthetic streaming-input edge: the next stream index the + # consumer should drain after the pass that consumes this edge (None on + # non-streaming edges, which carry no stream position) + _next_stream_index: int | None = field(default=None) # set on a synthetic streaming-input edge carrying the final chunk, so the # consuming pass (not the earlier ingest) reports the partition done _final_stream_chunk: bool = field(default=False) @@ -84,6 +89,7 @@ def clone(self): is_streaming=self.is_streaming, output_modality=self.output_modality, _persist_for_loop=self._persist_for_loop, + _next_stream_index=self._next_stream_index, _final_stream_chunk=self._final_stream_chunk, _target_graph_walk=self._target_graph_walk, ) diff --git a/mstar/streaming/stream_buffer.py b/mstar/streaming/stream_buffer.py index 13b6379e..96d24e93 100644 --- a/mstar/streaming/stream_buffer.py +++ b/mstar/streaming/stream_buffer.py @@ -12,6 +12,10 @@ class StreamChunk: """A chunk of data popped from a StreamBuffer.""" data: dict[str, torch.Tensor | None] chunk_index: int + # the next stream index to drain after this chunk is consumed (one past the + # last item in the chunk); used to re-seed a worker that takes over the + # partition (e.g. prefill->decode PD handoff) + next_stream_index: int start_offset: int = 0 # global position of the first item in this chunk is_final: bool = False graph_walk_transition: str | None = None @@ -213,6 +217,7 @@ def pop_chunk(self, graph_walk: str) -> StreamChunk: chunk = StreamChunk( data=self._collate([it.tensor for it in items]), chunk_index=self._chunks_popped, + next_stream_index=items[-1].index + 1 if items else self._current_index, start_offset=offset, is_final=is_final, graph_walk_transition=transition, diff --git a/mstar/utils/ipc_format.py b/mstar/utils/ipc_format.py index ec8a154a..09f57018 100644 --- a/mstar/utils/ipc_format.py +++ b/mstar/utils/ipc_format.py @@ -127,7 +127,7 @@ class WorkerGraphsDone(MessageBody): new_tokens: dict[str, list[int]] = field(default_factory=dict) # name to tokens output_signal_names: int = field(default=0) new_produced_edge_idx: dict[str, int] = field(default_factory=dict) - new_consumed_edge_idx: dict[str, int] = field(default_factory=dict) + new_next_stream_index: dict[str, int] = field(default_factory=dict) consumer_graph_walk_transitions: dict[str, str] = field(default_factory=dict) per_label_seq_info: PerLabelSeqInfo = field(default_factory=PerLabelSeqInfo) partition_name: str = field(default="default") diff --git a/mstar/worker/node_manager_utils.py b/mstar/worker/node_manager_utils.py index aa6f2250..4e4ca382 100644 --- a/mstar/worker/node_manager_utils.py +++ b/mstar/worker/node_manager_utils.py @@ -304,6 +304,13 @@ def update_request_info( current_fwd_info.produced_edge_idx[name] = max( current_fwd_info.produced_edge_idx.get(name, 0), idx ) + # The conductor's per_label_seq_info lags (refreshed only at + # WorkerGraphsDone), so the wholesale replace below could rewind + # the KV write position mid-decode and overwrite live KV. Keep + # the locally-advanced (longer) per-(cache, label) seq info. + current_fwd_info.per_label_seq_info.merge_keep_longest( + old_fwd_info.per_label_seq_info + ) if allow_graph_walk_transition: self.update_graph_walk(request_id, partition_name, current_fwd_info.graph_walk) else: diff --git a/mstar/worker/worker.py b/mstar/worker/worker.py index 8b1b603a..63b54cac 100644 --- a/mstar/worker/worker.py +++ b/mstar/worker/worker.py @@ -504,7 +504,7 @@ def _process_new_inputs(self, body: InputSignals) -> None: partition_name=body.partition_name, allow_graph_walk_transition=body.partition_name not in self._producer_triggered_partitions ) - for edge, idx in body.request_info.consumed_edge_idx.items(): + for edge, idx in body.request_info.next_stream_index.items(): req_info = self.worker_graphs_manager.per_request_info.get(body.request_id) if edge in req_info.stream_buffers: req_info.stream_buffers[edge].set_index(idx) @@ -643,17 +643,6 @@ def _pop_streaming_edge( request_id, consumer_partition ) - sbuf._update_buffer() - # Note: "Consumed" here means drained into the buffer (high-water index), - # not popped out of it. These coincide for non-overlapping policies. - # For left-context/sliding-window policies the buffer retains items - # past this index as overlap, so a PD re-seed via set_index() would - # skip them. No sliding-window consumer spans multiple graph walks; - # revisit if that changes. - self.worker_graphs_manager.get_fwd_info( - request_id, consumer_partition - ).consumed_edge_idx[edge_name] = sbuf._current_index - waiting_edge = sbuf.pop_waiting_edge() if waiting_edge is not None and waiting_edge.walk_transition is not None \ and waiting_edge.walk_transition != graph_walk: @@ -676,6 +665,7 @@ def _pop_streaming_edge( next_node=consumer_node, name=edge_name, tensor_info=[], + _next_stream_index=chunk.next_stream_index, _final_stream_chunk=chunk.is_final, ) else: @@ -694,6 +684,7 @@ def _pop_streaming_edge( next_node=consumer_node, name=edge_name, tensor_info=tensor_infos.get(edge_name, []), + _next_stream_index=chunk.next_stream_index, _final_stream_chunk=chunk.is_final, ) if chunk.graph_walk_transition is not None and chunk.graph_walk_transition != graph_walk: @@ -1113,7 +1104,7 @@ def _send_outputs( output_signal_names=self.worker_graphs_manager.flush_output_signals(request_id), per_label_seq_info=fwd_info.per_label_seq_info, new_produced_edge_idx=fwd_info.produced_edge_idx, - new_consumed_edge_idx=fwd_info.consumed_edge_idx, + new_next_stream_index=fwd_info.next_stream_index, # Key by consumer partition (not edge name) so it merges # into the producer pstate's partition-keyed # tracked_consumer_graph_walks on the conductor. @@ -1650,6 +1641,17 @@ def _postprocess_batch( self, batch_N: PendingBatch, output: NodeOutput, ): + for rid, node in batch_N.batch.node_objects.items(): + for edge_name, edge in node.ready_signals.ready_inputs.items(): + # Only synthetic streaming-input edges carry a next stream index; + # skip non-streaming inputs (None) so they don't write spurious + # zero entries into the reported next_stream_index. + if edge._next_stream_index is None: + continue + self.worker_graphs_manager.get_fwd_info( + rid, batch_N.partition + ).next_stream_index[edge_name] = edge._next_stream_index + if self.enable_nvtx: range_push("worker.postprocess.cleanup_inputs", synchronize=False) self._cleanup_consumed_inputs(batch_N.batch) From fa3cf35a0608437b84701c20b397daaedbf03f59 Mon Sep 17 00:00:00 2001 From: NSagan271 Date: Fri, 12 Jun 2026 23:35:22 +0000 Subject: [PATCH 16/16] (1) gate partition done on reaching producer_edge_idx, (2) fix existing bug in the IPC kv transfer --- mstar/conductor/conductor.py | 2 ++ mstar/engine/kv_store.py | 4 ++-- mstar/streaming/stream_buffer.py | 10 +++++++--- mstar/utils/ipc_format.py | 1 + mstar/worker/worker.py | 4 +++- 5 files changed, 15 insertions(+), 6 deletions(-) diff --git a/mstar/conductor/conductor.py b/mstar/conductor/conductor.py index 90e702e6..c9d5b55d 100644 --- a/mstar/conductor/conductor.py +++ b/mstar/conductor/conductor.py @@ -1091,6 +1091,7 @@ def _send_producer_done( ): """Send producer_done signal to the consumer partition's worker(s).""" request_data = self.requests[request_id] + pstate = request_data.partition_states[producer_partition] # Find which workers handle this consumer partition consumer_workers = set() @@ -1106,6 +1107,7 @@ def _send_producer_done( body=ProducerDone( request_id=request_id, partition_name=consumer_partition_name, + last_produced_edge_idx=pstate.produced_edge_idx, producer_done=set([producer_partition]), ), ) diff --git a/mstar/engine/kv_store.py b/mstar/engine/kv_store.py index 380d5755..9cbfcbfc 100644 --- a/mstar/engine/kv_store.py +++ b/mstar/engine/kv_store.py @@ -378,11 +378,11 @@ def _do_read( for info in read_info: slice = tensor[ - info.layer_idx, info.remote_page_idx, + info.layer_idx, info.remote_page_idx, :, info.token_start:info.token_end ].to(self._device) self._kv_cache[ - info.layer_idx, info.local_page_idx, + info.layer_idx, info.local_page_idx, :, info.token_start:info.token_end ] = slice diff --git a/mstar/streaming/stream_buffer.py b/mstar/streaming/stream_buffer.py index 96d24e93..82385b91 100644 --- a/mstar/streaming/stream_buffer.py +++ b/mstar/streaming/stream_buffer.py @@ -59,6 +59,7 @@ class StreamBuffer: _consumed: int = 0 _chunks_popped: int = 0 producer_done: bool = False + _producer_edge_idx: int | None = None _num_tensors_registered = 0 _num_buffer_writes = 0 @@ -96,9 +97,10 @@ def put(self, item: torch.Tensor, index: int, graph_walk: str | None = None) -> graph_walk=graph_walk, ) - def signal_done(self) -> None: + def signal_done(self, producer_edge_idx: int | None=None) -> None: """Producer signals no more items will arrive.""" self.producer_done = True + self._producer_edge_idx = producer_edge_idx def set_index(self, index: int): """Seed the next index to drain (e.g. when a new consumer worker takes @@ -112,8 +114,10 @@ def set_index(self, index: int): self._current_index = max(self._current_index, index) def _producer_done_and_all_read(self) -> bool: - return self.producer_done and \ - self._num_buffer_writes >= self._num_tensors_registered + return self.producer_done and ( + self._producer_edge_idx is None or \ + self._current_index >= self._producer_edge_idx + ) and self._num_buffer_writes >= self._num_tensors_registered def pop_waiting_edge(self) -> WaitingEdge | None: if len(self._waiting_graph_edges) > 0: diff --git a/mstar/utils/ipc_format.py b/mstar/utils/ipc_format.py index 09f57018..984e273d 100644 --- a/mstar/utils/ipc_format.py +++ b/mstar/utils/ipc_format.py @@ -63,6 +63,7 @@ class InputSignals(MessageBody): class ProducerDone(MessageBody): request_id: str partition_name: str + last_produced_edge_idx: dict[str, int] producer_done: set = field(default_factory=set) diff --git a/mstar/worker/worker.py b/mstar/worker/worker.py index 63b54cac..61318679 100644 --- a/mstar/worker/worker.py +++ b/mstar/worker/worker.py @@ -470,7 +470,9 @@ def _handle_producer_done(self, body: ProducerDone) -> None: if sbuf.from_partition in body.producer_done: # If we have multiple consumer partitions colocated, we need to signal # the right one - sbuf.signal_done() + sbuf.signal_done( + body.last_produced_edge_idx.get(sbuf.edge_name) + ) def _process_new_inputs(self, body: InputSignals) -> None: logger.debug(