Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions src/srtctl/cli/do_sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from srtctl.core.schema import SrtConfig
from srtctl.core.slurm import get_slurm_job_id, start_srun_process
from srtctl.core.status import JobStage, JobStatus, StatusReporter
from srtctl.core.topology import Endpoint, NodePortAllocator, Process
from srtctl.core.topology import Endpoint, NodePortAllocator, Process, allocate_endpoints_het
from srtctl.logging_utils import setup_logging
from srtctl.ports import (
ETCD_CLIENT_PORT,
Expand Down Expand Up @@ -86,9 +86,22 @@ def backend(self):
def endpoints(self) -> list[Endpoint]:
"""Compute endpoint allocation topology (cached).

This is the single source of truth for endpoint assignments.
This is the single source of truth for endpoint assignments. Under
SLURM heterogeneous jobs, prefill and decode workers are allocated
from their own component nodelists so neither side bleeds into the
other's topology segment.
"""
r = self.config.resources
if self.runtime.nodes.het:
return allocate_endpoints_het(
num_prefill=r.num_prefill,
gpus_per_prefill=r.gpus_per_prefill,
prefill_nodes=self.runtime.nodes.prefill_group,
num_decode=r.num_decode,
gpus_per_decode=r.gpus_per_decode,
decode_nodes=self.runtime.nodes.decode_group,
gpus_per_node=r.gpus_per_node,
)
return self.backend.allocate_endpoints(
num_prefill=r.num_prefill,
num_decode=r.num_decode,
Expand Down Expand Up @@ -151,6 +164,7 @@ def start_head_infrastructure(self, registry: ProcessRegistry) -> ManagedProcess
output=str(infra_log),
container_image=str(self.runtime.container_image),
container_mounts=mounts,
het_group=self.runtime.nodes.het_group_for(infra_node),
)

managed = ManagedProcess(
Expand Down Expand Up @@ -219,6 +233,7 @@ def start_mooncake_master(self, registry: ProcessRegistry) -> ManagedProcess | N
output=str(mooncake_log),
container_image=container,
container_mounts=self.runtime.container_mounts,
het_group=self.runtime.nodes.het_group_for(infra_node),
)

managed = ManagedProcess(
Expand Down Expand Up @@ -420,6 +435,7 @@ def _ensure_model_cached(self) -> None:
container_mounts=self.runtime.container_mounts,
env_to_set=hf_env,
use_bash_wrapper=False, # command is already bash -c
het_group=self.runtime.nodes.het_group_for(download_node),
)

timeout_sec = 60 * 60 # 1 hour; large models can take a while
Expand Down Expand Up @@ -547,6 +563,7 @@ def _run_post_eval(self, stop_event: threading.Event) -> int:
container_image=str(self.runtime.container_image),
container_mounts=self.runtime.container_mounts,
env_to_set=env_to_set,
het_group=self.runtime.nodes.het_group_for(self.runtime.nodes.head),
)

while proc.poll() is None:
Expand Down
1 change: 1 addition & 0 deletions src/srtctl/cli/mixins/benchmark_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ def _run_benchmark_script(
container_mounts=container_mounts,
env_to_set=env_to_set,
srun_options=self.runtime.srun_options,
het_group=self.runtime.nodes.het_group_for(self.runtime.nodes.head),
)

try:
Expand Down
1 change: 1 addition & 0 deletions src/srtctl/cli/mixins/frontend_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def _start_nginx(self, topology: FrontendTopology) -> ManagedProcess:
srun_options={
"container-remap-root": "",
},
het_group=self.runtime.nodes.het_group_for(topology.nginx_node),
)

return ManagedProcess(
Expand Down
2 changes: 2 additions & 0 deletions src/srtctl/cli/mixins/postprocess_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,7 @@ def _run_postprocess_container(self) -> tuple[Path | None, str | None]:
container_image="python:3.11",
container_mounts={self.runtime.log_dir: Path("/logs")},
env_to_set=env,
het_group=self.runtime.nodes.het_group_for(self.runtime.nodes.head),
)
proc.wait(timeout=600) # 10 min timeout for install + parse + full sync

Expand Down Expand Up @@ -588,6 +589,7 @@ def _run_ai_analysis(self, config: AIAnalysisConfig) -> None:
container_image="python:3.11",
container_mounts={self.runtime.log_dir: Path("/logs")},
env_to_set=env_to_set,
het_group=self.runtime.nodes.het_group_for(self.runtime.nodes.head),
)

# Wait for completion with timeout (15 minutes for install + analysis)
Expand Down
65 changes: 46 additions & 19 deletions src/srtctl/cli/mixins/telemetry_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,30 +45,56 @@ def _start_exporter_container(
nodelist: list[str],
log_file: Path,
default_command_template: str,
) -> ManagedProcess:
"""Start one exporter container across the requested nodes."""
) -> list[ManagedProcess]:
"""Start one exporter container across the requested nodes.

Under SLURM heterogeneous jobs the nodelist may span both het
components (prefill on group 0, decode on group 1). A single srun
cannot target multiple het components, so we split the launch into
one srun per group when needed.
"""
if exporter_config.command is None:
cmd_str = default_command_template.format(port=exporter_config.port)
elif "{port}" in exporter_config.command:
cmd_str = exporter_config.command.format(port=exporter_config.port)
else:
cmd_str = exporter_config.command

proc = start_srun_process(
command=shlex.split(cmd_str),
ntasks=len(nodelist),
nodelist=nodelist,
output=str(log_file),
container_image=exporter_config.container_image,
container_mounts=self.runtime.container_mounts,
srun_options=self.runtime.srun_options,
)
return ManagedProcess(
name=name,
popen=proc,
log_file=log_file,
node=",".join(nodelist),
)
if self.runtime.nodes.het:
groups: dict[int, list[str]] = {}
for node in nodelist:
g = self.runtime.nodes.het_group_for(node)
if g is None:
raise RuntimeError(f"node {node!r} not in any het component")
groups.setdefault(g, []).append(node)
chunks = sorted(groups.items())
else:
chunks = [(-1, nodelist)] # sentinel: no --het-group

managed: list[ManagedProcess] = []
for group_id, nodes in chunks:
het_group = group_id if group_id >= 0 else None
chunk_log = log_file if len(chunks) == 1 else log_file.with_suffix(f".g{group_id}.out")
proc = start_srun_process(
command=shlex.split(cmd_str),
ntasks=len(nodes),
nodelist=nodes,
output=str(chunk_log),
container_image=exporter_config.container_image,
container_mounts=self.runtime.container_mounts,
srun_options=self.runtime.srun_options,
het_group=het_group,
)
chunk_name = name if len(chunks) == 1 else f"{name}_g{group_id}"
managed.append(
ManagedProcess(
name=chunk_name,
popen=proc,
log_file=chunk_log,
node=",".join(nodes),
)
)
return managed

def start_telemetry(self) -> list[ManagedProcess]:
"""Start the configured telemetry provider."""
Expand Down Expand Up @@ -99,7 +125,7 @@ def start_telemetry(self) -> list[ManagedProcess]:

worker_nodes = sorted({process.node for process in self.backend_processes})
processes: list[ManagedProcess] = []
processes.append(
processes.extend(
self._start_exporter_container(
exporter_config=telemetry.dcgm_exporter,
name="telemetry_dcgm_exporter",
Expand All @@ -108,7 +134,7 @@ def start_telemetry(self) -> list[ManagedProcess]:
default_command_template="dcgm-exporter --collect-interval=100 --address :{port}",
)
)
processes.append(
processes.extend(
self._start_exporter_container(
exporter_config=telemetry.node_exporter,
name="telemetry_node_exporter",
Expand Down Expand Up @@ -149,6 +175,7 @@ def start_telemetry(self) -> list[ManagedProcess]:
container_mounts=scraper_mounts,
env_to_set=env_to_set,
srun_options=self.runtime.srun_options,
het_group=self.runtime.nodes.het_group_for(self.runtime.nodes.head),
),
log_file=self.runtime.log_dir / "telemetry.out",
node=self.runtime.nodes.head,
Expand Down
2 changes: 2 additions & 0 deletions src/srtctl/cli/mixins/worker_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def __missing__(self, key: str) -> str:
env_to_set=env_to_set,
bash_preamble=bash_preamble,
srun_options=self.runtime.srun_options,
het_group=process.het_group,
)

return ManagedProcess(
Expand Down Expand Up @@ -364,6 +365,7 @@ def start_endpoint_worker(self, endpoint_processes: list["Process"]) -> ManagedP
mpi=srun_config.mpi,
oversubscribe=srun_config.oversubscribe,
cpu_bind=srun_config.cpu_bind,
het_group=leader.het_group,
)

return ManagedProcess(
Expand Down
43 changes: 39 additions & 4 deletions src/srtctl/cli/submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,31 @@ def show_config_details(config: SrtConfig) -> None:

console.print(Panel(mounts_table, border_style="green"))

# --- SLURM heterogeneous job structure ---
het_components = config.resources.het_components(
infra_dedicated=config.infra.etcd_nats_dedicated_node,
cluster_default=get_srtslurm_setting("use_het_jobs", False),
)
if het_components is not None:
het_table = Table(title="SLURM Heterogeneous Job", show_lines=False, pad_edge=False)
het_table.add_column("Group", style="dim", width=5)
het_table.add_column("Side", style="cyan", width=8)
het_table.add_column("Nodes", style="white", justify="right", width=6)
het_table.add_column("Segment", style="white", justify="right", width=8)
het_table.add_column("GPUs/node", style="white", justify="right", width=10)
het_table.add_column("Infra", style="dim")
for c in het_components:
infra_note = "first node" if c.name == "prefill" and config.infra.etcd_nats_dedicated_node else ""
het_table.add_row(
str(c.group),
c.name,
str(c.nodes),
str(c.segment),
str(c.gpus_per_node),
infra_note,
)
console.print(Panel(het_table, border_style="magenta"))

# --- Environment Variables ---
dynamo_environment = config.dynamo.get_wheel_environment()
has_env = bool(config.environment or dynamo_environment)
Expand Down Expand Up @@ -390,10 +415,19 @@ def generate_minimal_sbatch_script(
env = Environment(loader=FileSystemLoader(str(template_dir)))
template = env.get_template("job_script_minimal.j2")

total_nodes = config.total_nodes
# Add extra node for dedicated etcd/nats infrastructure
if config.infra.etcd_nats_dedicated_node:
total_nodes += 1
het_components = config.resources.het_components(
infra_dedicated=config.infra.etcd_nats_dedicated_node,
cluster_default=get_srtslurm_setting("use_het_jobs", False),
)
if het_components is None:
total_nodes = config.total_nodes
# Add extra node for dedicated etcd/nats infrastructure
if config.infra.etcd_nats_dedicated_node:
total_nodes += 1
else:
# Sum is informational only — the template iterates het_components and
# ignores total_nodes when het_components is set.
total_nodes = sum(c.nodes for c in het_components)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

# Resolve container image path (expand aliases from srtslurm.yaml)
Expand All @@ -406,6 +440,7 @@ def generate_minimal_sbatch_script(
rendered = template.render(
job_name=job_name,
total_nodes=total_nodes,
het_components=het_components,
gpus_per_node=config.resources.gpus_per_node,
backend_type=config.backend_type,
account=config.slurm.account or os.environ.get("SLURM_ACCOUNT", "default"),
Expand Down
12 changes: 12 additions & 0 deletions src/srtctl/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,18 @@ def resolve_config_with_defaults(user_config: dict[str, Any], cluster_config: di
sbatch_directives.setdefault(key, value)
logger.debug("Applied default sbatch_directives: %s", default_sbatch_directives)

# Apply cluster-level het-job default. Without this, a recipe with
# `het_jobs: None` would defer the cluster default at render-time but skip
# SrtConfig validation (which only fires on `het_jobs is True`). Writing
# the cluster value into the resolved recipe ensures __post_init__ catches
# bad combinations (het + trtllm, het + agg, ...) at load time.
resources = config.get("resources")
if isinstance(resources, dict) and resources.get("het_jobs") is None:
cluster_het = cluster_config.get("use_het_jobs")
if cluster_het is not None:
resources["het_jobs"] = bool(cluster_het)
logger.debug("Applied cluster use_het_jobs default: %s", cluster_het)

# Resolve model path alias
model = config.get("model", {})
model_path = model.get("path", "")
Expand Down
Loading
Loading