Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions configs/qwen3omni_talker_pd_disag.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Node placement for the "qwen3_omni" model. Each node_groups entry names the
# nodes it covers and the ranks they are assigned to.
# NOTE(review): an entry with graph_walks presumably applies only to those
# walks, and an entry without it to every walk — confirm against the loader.
# NOTE(review): "ranks" looks like worker/device indices — confirm.
model: "qwen3_omni"
max_seq_len: 32768
node_groups:
# Audio and vision encoders on rank 0 (no graph_walks restriction listed).
- node_names: [audio_encoder, vision_encoder]
ranks: [0]

# Thinker, prefill walks only, on rank 0.
- node_names: [Thinker]
ranks: [0]
graph_walks: [prefill_text, prefill_audio, prefill_vision]

# Thinker, decode walk only, on a separate rank (1) from its prefill walks.
- node_names: [Thinker]
ranks: [1]
graph_walks: [thinker_decode]

# Talker, prefill walks only, on rank 2.
- node_names: [Talker]
ranks: [2]
graph_walks: [talker_prefill, talker_last_prefill]

# Talker, decode walk only, on a separate rank (3) from its prefill walks.
- node_names: [Talker]
ranks: [3]
graph_walks: [talker_decode]

# Code2Wav shares rank 3 with the Talker decode walk.
- node_names: [Code2Wav]
ranks: [3]
168 changes: 145 additions & 23 deletions mstar/conductor/conductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
PartitionDefinition,
PartitionState,
StreamingConnectionState,
TransitionSource,
)
from mstar.distributed.base import ShardingConfig
from mstar.distributed.communication import GlobalTPConfig, WorkerTPGroups
Expand All @@ -33,6 +34,7 @@
InputSignals,
NewRequest,
NewRequestConductor,
ProducerDone,
RemoveRequest,
UnpersistTensors,
WorkerGraphsDone,
Expand Down Expand Up @@ -189,6 +191,7 @@ def __init__(
):
self.requests: dict[str, RequestData] = {}
self.model = model
self._validate_transition_sources()
self.hostname = hostname
self.socket_path_prefix = socket_path_prefix
self.log_level = log_level
Expand All @@ -215,9 +218,24 @@ def __init__(

# (1) Set up worker graph TP ranks
# (2) Assert that streaming consumers don't have graph-walk-specific sharding config
self.streaming_consumers = set()
self.state_machine_streaming_consumers = set()
state_machine_controlled_walks = set()
self.gw_to_part = {}
self.part_to_walks: dict[str, set[str]] = {}
for part_def in self.model.get_partitions():
if part_def.transition_source == TransitionSource.STATE_MACHINE:
state_machine_controlled_walks.update(part_def.graph_walks)
self.gw_to_part.update({
gw: part_def.name for gw in part_def.graph_walks
})
self.part_to_walks[part_def.name] = set(part_def.graph_walks)

self.node_walk_to_wg: dict[tuple[str, str], WorkerGraph] = {}


# partition name -> producer-triggered streaming consumer node names
self.producer_triggered_nodes: dict[str, set[str]] = dict()

# (worker idx) -> {tp_group_str: tp_rank}
self.worker_tp_group_to_tp_rank: dict[int, dict[str, int]] = {}

Expand All @@ -226,18 +244,28 @@ def __init__(
for walk in wg.graph_walks:
graph_walks.add(walk)
for name, node in wg.section.get_nodes().items():
state_machine_controlled = False
for walk in wg.graph_walks:
self.node_walk_to_wg[(name, walk)] = wg
if node.consumes_stream:
self.streaming_consumers.add(name)
if walk in state_machine_controlled_walks:
state_machine_controlled = True
if node.consumes_stream and state_machine_controlled:
self.state_machine_streaming_consumers.add(name)
elif node.consumes_stream and wg.graph_walks:
# Producer-triggered streaming consumer (e.g. PD-disaggregated
# Talker). All walks of a wg belong to one partition, so any
# of them resolves the partition.
self.producer_triggered_nodes.setdefault(
self.gw_to_part[next(iter(wg.graph_walks))], set()
).add(name)

# v1: one sharding group per worker graph. Track which group "owns"
# each wg so we can assert single-group-per-wg.
wg_to_owning_group: dict[str, str] = {}

for group in self.default_sharding_config.groups:
if group.graph_walks is not None and any([
node in self.streaming_consumers for node in group.nodes
node in self.state_machine_streaming_consumers for node in group.nodes
]):
raise RuntimeError((
f"Sharding group with nodes {group.nodes} includes a streaming consumer but "
Expand Down Expand Up @@ -459,7 +487,7 @@ def _build_request_sharding_config(
for node_name in wg.section.get_nodes():
node_to_workers[NodeAndGraphWalk(node_name, walk)] = worker_ids
cfg.setup(node_to_workers)
cfg.assert_stream_consumer_compatibility(self.streaming_consumers)
cfg.assert_stream_consumer_compatibility(self.state_machine_streaming_consumers)
return cfg

def _split_inputs_to_workers(
Expand All @@ -476,6 +504,19 @@ def _split_inputs_to_workers(
per dest, which the consumer's fan-in path consolidates.
"""
inputs_per_worker: dict[str, list[GraphEdge]] = defaultdict(list)

# Seed empty inputs for producer-triggered streaming consumers (PD disag)
# across ALL the partition's walks, not just the conductor's current
# (lagging) walk — otherwise only the prefill worker is reached, never
# decode. set_index is monotonic, so re-seeding is a no-op.
part = self.gw_to_part.get(graph_walk)
for node in self.producer_triggered_nodes.get(part, set()):
for walk in self.part_to_walks.get(part, set()):
for worker in sharding_config.node_to_worker.get(
NodeAndGraphWalk(node, walk), [],
):
inputs_per_worker.setdefault(worker, [])

for edge in inputs:
if not edge.tensor_info:
# Signal-only — broadcast to every dest worker.
Expand Down Expand Up @@ -564,6 +605,52 @@ def _try_admit_waiting(self):
)
self._do_ingest_request(body)

def _validate_transition_sources(self):
    """Check that every partition has an unambiguous walk-transition authority.

    Two rules are enforced over the model's partitions and its partition
    topology:

    * A connection that carries a ``consumer_walk_transition`` must target
      a known partition whose ``transition_source`` is
      ``PRODUCER_TRIGGERED``.
    * Every ``PRODUCER_TRIGGERED`` partition must be the target of at least
      one such connection. Several driving connections per partition are
      allowed.

    Raises:
        ValueError: On the first violation found. Connections are checked
            before partitions.
    """
    partitions = {}
    for partition in self.model.get_partitions():
        partitions[partition.name] = partition
    topology = self.model.get_partition_topology()

    # Consumer partition name -> names of the edges that drive its walk.
    driving_edges: dict[str, list[str]] = {}
    for link in topology.connections:
        if link.consumer_walk_transition is None:
            # This connection does not drive the consumer's walk.
            continue

        target = partitions.get(link.to_partition)
        if target is None:
            raise ValueError(
                f"Connection edge {link.edge_name!r} targets unknown "
                f"partition {link.to_partition!r}."
            )

        if target.transition_source != TransitionSource.PRODUCER_TRIGGERED:
            # A transition-carrying edge into a partition that is not
            # producer-triggered mixes the two mechanisms.
            raise ValueError(
                f"Connection edge {link.edge_name!r} defines a "
                f"consumer_walk_transition for partition "
                f"{link.to_partition!r}, but that partition's "
                f"transition_source is {target.transition_source.value!r}, "
                f"not 'producer_triggered'. The two transition mechanisms "
                f"must not be mixed."
            )

        edges_for_target = driving_edges.setdefault(link.to_partition, [])
        edges_for_target.append(link.edge_name)

    for part_name, part_def in partitions.items():
        is_producer_triggered = (
            part_def.transition_source == TransitionSource.PRODUCER_TRIGGERED
        )
        if not is_producer_triggered or driving_edges.get(part_name):
            continue
        # Producer-triggered, yet nothing upstream can move its walk.
        raise ValueError(
            f"Producer-triggered partition {part_name!r} has no incoming "
            f"connection with a consumer_walk_transition to drive its "
            f"graph-walk transitions."
        )

def _ingest_request(
self, body: NewRequestConductor
):
Expand Down Expand Up @@ -673,6 +760,14 @@ def _do_ingest_request(
)
partition_fwd_args[p.name] = fwd_args

# set up tracked_consumer_graph_walks
for conn in topology.connections:
producer = partition_states[conn.from_partition]
consumer = partition_states[conn.to_partition]
if conn.consumer_walk_transition is not None:
producer.tracked_consumer_graph_walks[conn.to_partition] \
= consumer.metadata.graph_walk

# Send NewRequest to each worker with the appropriate partition's inputs
for worker_id, worker_graph_ids in worker_to_worker_graph_ids.items():
# Determine which partition this worker serves
Expand Down Expand Up @@ -706,7 +801,10 @@ def _do_ingest_request(
requires_cfg=fwd_args.full_metadata.requires_cfg,
partition_name=partition_name,
max_tokens=request_data.max_output_tokens,
sampling_config=request_data.sampling_config
sampling_config=request_data.sampling_config,
produced_edge_idx=pstate.produced_edge_idx,
next_stream_index=pstate.next_stream_index,
tracked_consumer_graph_walks=pstate.tracked_consumer_graph_walks,
),
)
self.communicator.send(
Expand Down Expand Up @@ -793,6 +891,20 @@ def _process_worker_graphs_done(
)
return []

# Stream indices are monotonic, but only the emitting TP rank advances
# them; sibling ranks report the (stale) seed they were given. With
# last-writer-wins .update(), a sibling landing after the emitter would
# rewind the index, re-emitting an already-used one. Merge with max.
for name, idx in body.new_produced_edge_idx.items():
pstate.produced_edge_idx[name] = max(pstate.produced_edge_idx.get(name, 0), idx)
for name, idx in body.new_next_stream_index.items():
pstate.next_stream_index[name] = max(pstate.next_stream_index.get(name, 0), idx)
# Unlike the indices above, graph walks have no ordering to max-merge on,
# so this is a plain last-writer-wins update. Contract: a reporting rank
# must send either the correct (just-applied) consumer walk or nothing —
# a non-emitting TP sibling sends an empty dict, never a stale walk.
pstate.tracked_consumer_graph_walks.update(body.consumer_graph_walk_transitions)

# Persist signals: every rank contributes its shard (different uuid +
# source_tp_rank); accumulate across ranks, do not dedup.
if body.persist_signals:
Expand Down Expand Up @@ -824,6 +936,22 @@ def _process_worker_graphs_done(
body.output_signal_names, list
) else []

# Producer-triggered partitions self-apply their walk from the stream.
# Adopt the walk the consumer actually ran (reported back) before the
# completion check, and realign the completion-tracking worker-graph set
# to it — otherwise current_worker_graph_ids (a state-machine-derived
# walk) diverges and the subset check below breaks. The conductor does
# not drive this partition's transitions; this only mirrors them.
pdef = request_data.partition_definitions.get(partition_name)
if (pdef is not None
and pdef.transition_source == TransitionSource.PRODUCER_TRIGGERED
and body.partition_graph_walk
and body.partition_graph_walk != pstate.metadata.graph_walk):
pstate.metadata.graph_walk = body.partition_graph_walk
self._set_partition_worker_graph_ids(
body.request_id, partition_name, body.partition_graph_walk,
)

# Each wg is only marked complete when all its TP ranks have reported.
for wg_id in body.worker_graph_ids:
count = pstate.wg_rank_completions.get(wg_id, 0) + 1
Expand Down Expand Up @@ -893,10 +1021,11 @@ def _process_done_forward(
if conn.from_partition == partition_name:
conn.producer_done = True
self._send_producer_done(request_id, conn.from_partition, conn.to_partition)
elif fwd_args.inputs:
# Partition has inputs to send — conductor-driven
else:
# Call even with no inputs: a no-op for self-triggering partitions
# (empty inputs_per_worker → no messages), but lets producer-triggered
# consumers get seeded (PD) / their next_stream_index propagated.
self._send_partition_inputs(request_id, partition_name, fwd_args)
# else: no inputs — partition self-triggers via StreamBuffer

self._un_persist_tensors(request_id, fwd_args.unpersist_tensors)

Expand Down Expand Up @@ -947,6 +1076,9 @@ def _send_partition_inputs(
partition_name=partition_name,
max_tokens=request_data.max_output_tokens,
sampling_config=request_data.sampling_config,
produced_edge_idx=pstate.produced_edge_idx,
next_stream_index=pstate.next_stream_index,
tracked_consumer_graph_walks=pstate.tracked_consumer_graph_walks,
),
partition_name=partition_name
),
Expand All @@ -959,7 +1091,7 @@ def _send_producer_done(
):
"""Send producer_done signal to the consumer partition's worker(s)."""
request_data = self.requests[request_id]
pstate = request_data.partition_states[consumer_partition_name]
pstate = request_data.partition_states[producer_partition]

# Find which workers handle this consumer partition
consumer_workers = set()
Expand All @@ -971,21 +1103,11 @@ def _send_producer_done(

for worker_id in consumer_workers:
message = WorkerMessage(
message_type=WorkerMessageType.INPUT_SIGNALS,
body=InputSignals(
message_type=WorkerMessageType.PRODUCER_DONE,
body=ProducerDone(
request_id=request_id,
inputs=[],
request_info=CurrentForwardPassInfo(
request_id=request_id,
graph_walk=pstate.metadata.graph_walk or "",
fwd_index=pstate.fwd_pass_number,
random_seed=pstate.random_seed,
requires_cfg=False,
partition_name=consumer_partition_name,
max_tokens=request_data.max_output_tokens,
sampling_config=request_data.sampling_config
),
partition_name=consumer_partition_name,
last_produced_edge_idx=pstate.produced_edge_idx,
producer_done=set([producer_partition]),
),
)
Expand Down
Loading
Loading