diff --git a/src/srtctl/cli/do_sweep.py b/src/srtctl/cli/do_sweep.py index c4041a7f..ea5ebfd6 100644 --- a/src/srtctl/cli/do_sweep.py +++ b/src/srtctl/cli/do_sweep.py @@ -44,7 +44,7 @@ from srtctl.core.schema import SrtConfig from srtctl.core.slurm import get_slurm_job_id, start_srun_process from srtctl.core.status import JobStage, JobStatus, StatusReporter -from srtctl.core.topology import Endpoint, NodePortAllocator, Process +from srtctl.core.topology import Endpoint, NodePortAllocator, Process, allocate_endpoints_het from srtctl.logging_utils import setup_logging from srtctl.ports import ( ETCD_CLIENT_PORT, @@ -86,9 +86,22 @@ def backend(self): def endpoints(self) -> list[Endpoint]: """Compute endpoint allocation topology (cached). - This is the single source of truth for endpoint assignments. + This is the single source of truth for endpoint assignments. Under + SLURM heterogeneous jobs, prefill and decode workers are allocated + from their own component nodelists so neither side bleeds into the + other's topology segment. 
""" r = self.config.resources + if self.runtime.nodes.het: + return allocate_endpoints_het( + num_prefill=r.num_prefill, + gpus_per_prefill=r.gpus_per_prefill, + prefill_nodes=self.runtime.nodes.prefill_group, + num_decode=r.num_decode, + gpus_per_decode=r.gpus_per_decode, + decode_nodes=self.runtime.nodes.decode_group, + gpus_per_node=r.gpus_per_node, + ) return self.backend.allocate_endpoints( num_prefill=r.num_prefill, num_decode=r.num_decode, @@ -151,6 +164,7 @@ def start_head_infrastructure(self, registry: ProcessRegistry) -> ManagedProcess output=str(infra_log), container_image=str(self.runtime.container_image), container_mounts=mounts, + het_group=self.runtime.nodes.het_group_for(infra_node), ) managed = ManagedProcess( @@ -219,6 +233,7 @@ def start_mooncake_master(self, registry: ProcessRegistry) -> ManagedProcess | N output=str(mooncake_log), container_image=container, container_mounts=self.runtime.container_mounts, + het_group=self.runtime.nodes.het_group_for(infra_node), ) managed = ManagedProcess( @@ -420,6 +435,7 @@ def _ensure_model_cached(self) -> None: container_mounts=self.runtime.container_mounts, env_to_set=hf_env, use_bash_wrapper=False, # command is already bash -c + het_group=self.runtime.nodes.het_group_for(download_node), ) timeout_sec = 60 * 60 # 1 hour; large models can take a while @@ -547,6 +563,7 @@ def _run_post_eval(self, stop_event: threading.Event) -> int: container_image=str(self.runtime.container_image), container_mounts=self.runtime.container_mounts, env_to_set=env_to_set, + het_group=self.runtime.nodes.het_group_for(self.runtime.nodes.head), ) while proc.poll() is None: diff --git a/src/srtctl/cli/mixins/benchmark_stage.py b/src/srtctl/cli/mixins/benchmark_stage.py index baa7029c..3c258a5c 100644 --- a/src/srtctl/cli/mixins/benchmark_stage.py +++ b/src/srtctl/cli/mixins/benchmark_stage.py @@ -201,6 +201,7 @@ def _run_benchmark_script( container_mounts=container_mounts, env_to_set=env_to_set, 
srun_options=self.runtime.srun_options, + het_group=self.runtime.nodes.het_group_for(self.runtime.nodes.head), ) try: diff --git a/src/srtctl/cli/mixins/frontend_stage.py b/src/srtctl/cli/mixins/frontend_stage.py index 9cd129d9..4d372f33 100644 --- a/src/srtctl/cli/mixins/frontend_stage.py +++ b/src/srtctl/cli/mixins/frontend_stage.py @@ -154,6 +154,7 @@ def _start_nginx(self, topology: FrontendTopology) -> ManagedProcess: srun_options={ "container-remap-root": "", }, + het_group=self.runtime.nodes.het_group_for(topology.nginx_node), ) return ManagedProcess( diff --git a/src/srtctl/cli/mixins/postprocess_stage.py b/src/srtctl/cli/mixins/postprocess_stage.py index be78b0ea..6293f5f1 100644 --- a/src/srtctl/cli/mixins/postprocess_stage.py +++ b/src/srtctl/cli/mixins/postprocess_stage.py @@ -436,6 +436,7 @@ def _run_postprocess_container(self) -> tuple[Path | None, str | None]: container_image="python:3.11", container_mounts={self.runtime.log_dir: Path("/logs")}, env_to_set=env, + het_group=self.runtime.nodes.het_group_for(self.runtime.nodes.head), ) proc.wait(timeout=600) # 10 min timeout for install + parse + full sync @@ -588,6 +589,7 @@ def _run_ai_analysis(self, config: AIAnalysisConfig) -> None: container_image="python:3.11", container_mounts={self.runtime.log_dir: Path("/logs")}, env_to_set=env_to_set, + het_group=self.runtime.nodes.het_group_for(self.runtime.nodes.head), ) # Wait for completion with timeout (15 minutes for install + analysis) diff --git a/src/srtctl/cli/mixins/telemetry_stage.py b/src/srtctl/cli/mixins/telemetry_stage.py index 0baa23a9..941f5a15 100644 --- a/src/srtctl/cli/mixins/telemetry_stage.py +++ b/src/srtctl/cli/mixins/telemetry_stage.py @@ -45,8 +45,14 @@ def _start_exporter_container( nodelist: list[str], log_file: Path, default_command_template: str, - ) -> ManagedProcess: - """Start one exporter container across the requested nodes.""" + ) -> list[ManagedProcess]: + """Start one exporter container across the requested nodes. 
+ + Under SLURM heterogeneous jobs the nodelist may span both het + components (prefill on group 0, decode on group 1). A single srun + cannot target multiple het components, so we split the launch into + one srun per group when needed. + """ if exporter_config.command is None: cmd_str = default_command_template.format(port=exporter_config.port) elif "{port}" in exporter_config.command: @@ -54,21 +60,41 @@ def _start_exporter_container( else: cmd_str = exporter_config.command - proc = start_srun_process( - command=shlex.split(cmd_str), - ntasks=len(nodelist), - nodelist=nodelist, - output=str(log_file), - container_image=exporter_config.container_image, - container_mounts=self.runtime.container_mounts, - srun_options=self.runtime.srun_options, - ) - return ManagedProcess( - name=name, - popen=proc, - log_file=log_file, - node=",".join(nodelist), - ) + if self.runtime.nodes.het: + groups: dict[int, list[str]] = {} + for node in nodelist: + g = self.runtime.nodes.het_group_for(node) + if g is None: + raise RuntimeError(f"node {node!r} not in any het component") + groups.setdefault(g, []).append(node) + chunks = sorted(groups.items()) + else: + chunks = [(-1, nodelist)] # sentinel: no --het-group + + managed: list[ManagedProcess] = [] + for group_id, nodes in chunks: + het_group = group_id if group_id >= 0 else None + chunk_log = log_file if len(chunks) == 1 else log_file.with_suffix(f".g{group_id}.out") + proc = start_srun_process( + command=shlex.split(cmd_str), + ntasks=len(nodes), + nodelist=nodes, + output=str(chunk_log), + container_image=exporter_config.container_image, + container_mounts=self.runtime.container_mounts, + srun_options=self.runtime.srun_options, + het_group=het_group, + ) + chunk_name = name if len(chunks) == 1 else f"{name}_g{group_id}" + managed.append( + ManagedProcess( + name=chunk_name, + popen=proc, + log_file=chunk_log, + node=",".join(nodes), + ) + ) + return managed def start_telemetry(self) -> list[ManagedProcess]: """Start the 
configured telemetry provider.""" @@ -99,7 +125,7 @@ def start_telemetry(self) -> list[ManagedProcess]: worker_nodes = sorted({process.node for process in self.backend_processes}) processes: list[ManagedProcess] = [] - processes.append( + processes.extend( self._start_exporter_container( exporter_config=telemetry.dcgm_exporter, name="telemetry_dcgm_exporter", @@ -108,7 +134,7 @@ def start_telemetry(self) -> list[ManagedProcess]: default_command_template="dcgm-exporter --collect-interval=100 --address :{port}", ) ) - processes.append( + processes.extend( self._start_exporter_container( exporter_config=telemetry.node_exporter, name="telemetry_node_exporter", @@ -149,6 +175,7 @@ def start_telemetry(self) -> list[ManagedProcess]: container_mounts=scraper_mounts, env_to_set=env_to_set, srun_options=self.runtime.srun_options, + het_group=self.runtime.nodes.het_group_for(self.runtime.nodes.head), ), log_file=self.runtime.log_dir / "telemetry.out", node=self.runtime.nodes.head, diff --git a/src/srtctl/cli/mixins/worker_stage.py b/src/srtctl/cli/mixins/worker_stage.py index c9f1b0d1..c2c46bd1 100644 --- a/src/srtctl/cli/mixins/worker_stage.py +++ b/src/srtctl/cli/mixins/worker_stage.py @@ -233,6 +233,7 @@ def __missing__(self, key: str) -> str: env_to_set=env_to_set, bash_preamble=bash_preamble, srun_options=self.runtime.srun_options, + het_group=process.het_group, ) return ManagedProcess( @@ -364,6 +365,7 @@ def start_endpoint_worker(self, endpoint_processes: list["Process"]) -> ManagedP mpi=srun_config.mpi, oversubscribe=srun_config.oversubscribe, cpu_bind=srun_config.cpu_bind, + het_group=leader.het_group, ) return ManagedProcess( diff --git a/src/srtctl/cli/submit.py b/src/srtctl/cli/submit.py index 6937f6b6..2f560ec8 100644 --- a/src/srtctl/cli/submit.py +++ b/src/srtctl/cli/submit.py @@ -232,6 +232,31 @@ def show_config_details(config: SrtConfig) -> None: console.print(Panel(mounts_table, border_style="green")) + # --- SLURM heterogeneous job structure --- + 
het_components = config.resources.het_components( + infra_dedicated=config.infra.etcd_nats_dedicated_node, + cluster_default=get_srtslurm_setting("use_het_jobs", False), + ) + if het_components is not None: + het_table = Table(title="SLURM Heterogeneous Job", show_lines=False, pad_edge=False) + het_table.add_column("Group", style="dim", width=5) + het_table.add_column("Side", style="cyan", width=8) + het_table.add_column("Nodes", style="white", justify="right", width=6) + het_table.add_column("Segment", style="white", justify="right", width=8) + het_table.add_column("GPUs/node", style="white", justify="right", width=10) + het_table.add_column("Infra", style="dim") + for c in het_components: + infra_note = "first node" if c.name == "prefill" and config.infra.etcd_nats_dedicated_node else "" + het_table.add_row( + str(c.group), + c.name, + str(c.nodes), + str(c.segment), + str(c.gpus_per_node), + infra_note, + ) + console.print(Panel(het_table, border_style="magenta")) + # --- Environment Variables --- dynamo_environment = config.dynamo.get_wheel_environment() has_env = bool(config.environment or dynamo_environment) @@ -390,10 +415,19 @@ def generate_minimal_sbatch_script( env = Environment(loader=FileSystemLoader(str(template_dir))) template = env.get_template("job_script_minimal.j2") - total_nodes = config.total_nodes - # Add extra node for dedicated etcd/nats infrastructure - if config.infra.etcd_nats_dedicated_node: - total_nodes += 1 + het_components = config.resources.het_components( + infra_dedicated=config.infra.etcd_nats_dedicated_node, + cluster_default=get_srtslurm_setting("use_het_jobs", False), + ) + if het_components is None: + total_nodes = config.total_nodes + # Add extra node for dedicated etcd/nats infrastructure + if config.infra.etcd_nats_dedicated_node: + total_nodes += 1 + else: + # Sum is informational only — the template iterates het_components and + # ignores total_nodes when het_components is set. 
+ total_nodes = sum(c.nodes for c in het_components) timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") # Resolve container image path (expand aliases from srtslurm.yaml) @@ -406,6 +440,7 @@ def generate_minimal_sbatch_script( rendered = template.render( job_name=job_name, total_nodes=total_nodes, + het_components=het_components, gpus_per_node=config.resources.gpus_per_node, backend_type=config.backend_type, account=config.slurm.account or os.environ.get("SLURM_ACCOUNT", "default"), diff --git a/src/srtctl/core/config.py b/src/srtctl/core/config.py index f6104ec2..205dd9b4 100644 --- a/src/srtctl/core/config.py +++ b/src/srtctl/core/config.py @@ -121,6 +121,18 @@ def resolve_config_with_defaults(user_config: dict[str, Any], cluster_config: di sbatch_directives.setdefault(key, value) logger.debug("Applied default sbatch_directives: %s", default_sbatch_directives) + # Apply cluster-level het-job default. Without this, a recipe with + # `het_jobs: None` would defer the cluster default at render-time but skip + # SrtConfig validation (which only fires on `het_jobs is True`). Writing + # the cluster value into the resolved recipe ensures __post_init__ catches + # bad combinations (het + trtllm, het + agg, ...) at load time. 
+ resources = config.get("resources") + if isinstance(resources, dict) and resources.get("het_jobs") is None: + cluster_het = cluster_config.get("use_het_jobs") + if cluster_het is not None: + resources["het_jobs"] = bool(cluster_het) + logger.debug("Applied cluster use_het_jobs default: %s", cluster_het) + # Resolve model path alias model = config.get("model", {}) model_path = model.get("path", "") diff --git a/src/srtctl/core/runtime.py b/src/srtctl/core/runtime.py index 69a1a914..7f5d48da 100644 --- a/src/srtctl/core/runtime.py +++ b/src/srtctl/core/runtime.py @@ -16,7 +16,7 @@ from srtctl.ports import FRONTEND_PUBLIC_PORT from .config import get_srtslurm_setting -from .slurm import get_hostname_ip, get_slurm_nodelist +from .slurm import get_hostname_ip, get_slurm_het_nodelists, get_slurm_nodelist if TYPE_CHECKING: from srtctl.core.schema import SrtConfig @@ -32,12 +32,40 @@ class Nodes: infra: Infrastructure node hostname (runs NATS, etcd). Same as head unless etcd_nats_dedicated_node is enabled. worker: Tuple of all worker node hostnames (prefill + decode) + het: True when the job was submitted as a SLURM heterogeneous job. In + this mode worker srun calls need ``--het-group=<N>`` so SLURM + routes them to the right component. + prefill_group: Prefill worker nodes in het component 0; a dedicated infra + node also sits in component 0 but is excluded. Empty tuple when het=False. + decode_group: Worker nodes that belong to het component 1 (decode). + Empty tuple when het=False. """ head: str bench: str infra: str worker: tuple[str, ...] + het: bool = False + prefill_group: tuple[str, ...] = () + decode_group: tuple[str, ...] = () + + def het_group_for(self, node: str) -> int | None: + """Return the het component (0 or 1) a node belongs to, or None. + + Returns None for non-het jobs so callers can pass the result directly + to ``start_srun_process(het_group=...)`` as a no-op fallback. 
+ """ + if not self.het: + return None + if node in self.prefill_group: + return 0 + if node in self.decode_group: + return 1 + # Head and infra share group 0 under het (infra is folded into the + # prefill component, head sits on the prefill side). + if node == self.infra or node == self.head: + return 0 + return None @classmethod def from_slurm( @@ -53,6 +81,10 @@ def from_slurm( etcd_nats_dedicated_node: If True, dedicate first node for etcd/nats, second node is head, rest are workers. """ + het_lists = get_slurm_het_nodelists() + if het_lists is not None: + return cls._from_het_slurm(het_lists, etcd_nats_dedicated_node) + nodelist = get_slurm_nodelist() if not nodelist: raise RuntimeError("SLURM_NODELIST not set - are we running in SLURM?") @@ -79,6 +111,48 @@ def from_slurm( return cls(head=head, bench=bench, infra=infra, worker=worker) + @classmethod + def _from_het_slurm( + cls, + het_lists: list[list[str]], + etcd_nats_dedicated_node: bool, + ) -> "Nodes": + """Carve a Nodes from a SLURM heterogeneous-job allocation. + + Group 0 holds prefill (and the dedicated infra node when configured); + group 1 holds decode. Head/bench live on group 0. 
+ """ + if len(het_lists) != 2: + raise ValueError( + f"het_jobs expects exactly 2 components (prefill, decode); SLURM_HET_SIZE reported {len(het_lists)}" + ) + group0, group1 = het_lists + if not group0 or not group1: + raise RuntimeError("Empty SLURM_JOB_NODELIST_HET_GROUP_* — are we inside a het job?") + + if etcd_nats_dedicated_node: + if len(group0) < 2: + raise ValueError("etcd_nats_dedicated_node requires >= 2 nodes in het group 0") + infra = group0[0] + head = group0[1] + prefill_group = tuple(group0[1:]) + else: + infra = group0[0] + head = group0[0] + prefill_group = tuple(group0) + bench = head + decode_group = tuple(group1) + worker = prefill_group + decode_group + return cls( + head=head, + bench=bench, + infra=infra, + worker=worker, + het=True, + prefill_group=prefill_group, + decode_group=decode_group, + ) + @dataclass(frozen=True) class RuntimeContext: diff --git a/src/srtctl/core/schema.py b/src/srtctl/core/schema.py index 193e20d7..ebd6729b 100644 --- a/src/srtctl/core/schema.py +++ b/src/srtctl/core/schema.py @@ -189,6 +189,12 @@ class ClusterConfig: use_gpus_per_node_directive: bool = True use_segment_sbatch_directive: bool = True use_exclusive_sbatch_directive: bool = False + # Default for ``ResourceConfig.het_jobs`` when the recipe doesn't set it. + # When True (and recipe doesn't override), the prefill side and decode side + # are submitted as two SLURM heterogeneous-job components, each with its + # own ``--segment``. Lets asymmetric layouts (e.g. prefill 12 + decode 10 + # nodes on GB200/GB300) preserve NVL72 affinity per side. + use_het_jobs: bool = False default_sbatch_directives: dict[str, str] | None = None default_health_check: dict[str, int] | None = None srtctl_root: str | None = None @@ -449,6 +455,26 @@ class IdentityConfig: Schema: ClassVar[type[Schema]] = Schema +@dataclass(frozen=True) +class HetComponent: + """One component of a SLURM heterogeneous job. 
+ + A het job is submitted as multiple `#SBATCH` blocks separated by + `#SBATCH hetjob`. SLURM places each component within a single topology + segment, so we get per-side NVL72 affinity. At runtime each component + exposes its own `SLURM_JOB_NODELIST_HET_GROUP_<N>`, and worker srun + calls target a component with `--het-group=<N>`. + """ + + name: Literal["prefill", "decode"] + group: int + nodes: int + segment: int + gpus_per_node: int + + Schema: ClassVar[type[Schema]] = Schema + + @dataclass(frozen=True) class ResourceConfig: """Resource allocation configuration.""" @@ -471,6 +497,13 @@ class ResourceConfig: # (e.g. set decode_nodes=decode_workers when gpus_per_decode int: """Total GPUs used by all decode workers.""" return self.num_decode * self.gpus_per_decode + def het_components( + self, + *, + infra_dedicated: bool, + cluster_default: bool = False, + ) -> tuple[HetComponent, HetComponent] | None: + """Return the (prefill, decode) het components, or None when het is off. + + Het is enabled when ``self.het_jobs`` is True, or when it is None and + ``cluster_default`` is True. Only valid in disaggregated mode. Group 0 + is prefill (folds in the dedicated infra node when present); group 1 is + decode. Segment matches each component's node count, so each side lands + in its own topology segment (NVL72 domain on GB200/GB300). + + Pass ``cluster_default=get_srtslurm_setting("use_het_jobs", False)`` + from callers that have access to the cluster config; schema.py cannot + import from core.config without a cycle. 
+ """ + enabled = self.het_jobs if self.het_jobs is not None else cluster_default + if not enabled or not self.is_disaggregated: + return None + prefill_nodes = (self.prefill_nodes or 0) + (1 if infra_dedicated else 0) + decode_nodes = self.decode_nodes or 0 + return ( + HetComponent( + name="prefill", + group=0, + nodes=prefill_nodes, + segment=prefill_nodes, + gpus_per_node=self.gpus_per_node, + ), + HetComponent( + name="decode", + group=1, + nodes=decode_nodes, + segment=decode_nodes, + gpus_per_node=self.gpus_per_node, + ), + ) + Schema: ClassVar[type[Schema]] = Schema @@ -1320,6 +1393,29 @@ def __post_init__(self): self._validate_profiling() self._validate_telemetry() self._validate_mooncake_kv_store() + self._validate_het_jobs() + + def _validate_het_jobs(self): + """When ``resources.het_jobs`` is set to True, enforce supported shape. + + Validation runs only when the per-recipe override is explicitly True; + a cluster-level default still in effect (recipe None) is permissive at + load-time and resolved later by callers that pass the cluster default + into ``het_components()``. This keeps a single recipe that disables het + via ``het_jobs: false`` from tripping on a cluster default. 
+ """ + if self.resources.het_jobs is not True: + return + if not self.resources.is_disaggregated: + raise ValidationError( + "het_jobs=true requires a disaggregated layout (set resources.prefill_nodes and resources.decode_nodes)" + ) + if (self.resources.prefill_nodes or 0) < 1 or (self.resources.decode_nodes or 0) < 1: + raise ValidationError("het_jobs=true requires prefill_nodes >= 1 and decode_nodes >= 1") + if self.backend_type != "sglang": + raise ValidationError( + f"het_jobs=true is only supported on the sglang backend; got backend.type={self.backend_type!r}" + ) def _validate_mooncake_kv_store(self): """Catch the common misconfiguration: mooncake_kv_store set without a diff --git a/src/srtctl/core/slurm.py b/src/srtctl/core/slurm.py index ed60a369..af588f85 100644 --- a/src/srtctl/core/slurm.py +++ b/src/srtctl/core/slurm.py @@ -52,11 +52,14 @@ def get_slurm_nodelist() -> list[str]: Returns: List of node hostnames, or empty list if not in SLURM. """ - nodelist_raw = os.environ.get("SLURM_NODELIST", "") + return _expand_nodelist(os.environ.get("SLURM_NODELIST", "")) + + +def _expand_nodelist(nodelist_raw: str) -> list[str]: + """Expand a SLURM ranged nodelist via ``scontrol show hostnames``.""" if not nodelist_raw: return [] - # Use scontrol to expand the nodelist try: result = subprocess.run( ["scontrol", "show", "hostnames", nodelist_raw], @@ -70,6 +73,30 @@ def get_slurm_nodelist() -> list[str]: return [nodelist_raw] +def get_slurm_het_nodelists() -> list[list[str]] | None: + """Per-component nodelists for a SLURM heterogeneous job, else None. + + Returns one expanded nodelist per het component when ``SLURM_HET_SIZE`` is + set to a value greater than 1. Returns None for non-het jobs so callers can + fall back to ``get_slurm_nodelist()``. 
+ """ + het_size_raw = os.environ.get("SLURM_HET_SIZE", "") + if not het_size_raw: + return None + try: + het_size = int(het_size_raw) + except ValueError: + return None + if het_size < 2: + return None + + groups: list[list[str]] = [] + for i in range(het_size): + nodelist_raw = os.environ.get(f"SLURM_JOB_NODELIST_HET_GROUP_{i}", "") + groups.append(_expand_nodelist(nodelist_raw)) + return groups + + # ============================================================================ # Network Resolution # ============================================================================ @@ -166,6 +193,7 @@ def start_srun_process( mpi: str | None = None, oversubscribe: bool = False, cpu_bind: str | None = None, + het_group: int | None = None, ) -> subprocess.Popen: """Start a process via srun with container support. @@ -231,6 +259,11 @@ def start_srun_process( if nodelist: srun_cmd.extend(["--nodelist", ",".join(nodelist)]) + # Route this srun to a specific component of a SLURM heterogeneous job. + # Omitted (None) for non-het jobs; safe to always pass-through from callers. + if het_group is not None: + srun_cmd.append(f"--het-group={het_group}") + if output: srun_cmd.extend(["--output", output]) diff --git a/src/srtctl/core/topology.py b/src/srtctl/core/topology.py index f0022e9f..610f90d9 100644 --- a/src/srtctl/core/topology.py +++ b/src/srtctl/core/topology.py @@ -154,6 +154,10 @@ class Endpoint: nodes: tuple[str, ...] gpu_indices: frozenset[int] = field(default_factory=lambda: frozenset(range(8))) gpus_per_node: int = 8 + # SLURM heterogeneous-job component index (0=prefill side, 1=decode side). + # None when the job is non-het — callers that pass this to srun treat None + # as "omit --het-group". + het_group: int | None = None @property def leader_node(self) -> str: @@ -207,6 +211,8 @@ class Process: kv_events_port: int | None = None nixl_port: int | None = None dp_rpc_port: int | None = None + # Inherited from the parent Endpoint when the job is heterogeneous. 
+ het_group: int | None = None @property def is_leader(self) -> bool: @@ -410,6 +416,75 @@ def allocate_workers_simple(mode: WorkerMode, count: int, gpus_per_worker: int) return endpoints +def allocate_endpoints_het( + *, + num_prefill: int, + gpus_per_prefill: int, + prefill_nodes: Sequence[str], + num_decode: int, + gpus_per_decode: int, + decode_nodes: Sequence[str], + gpus_per_node: int, +) -> list[Endpoint]: + """Allocate endpoints for a SLURM heterogeneous job. + + Prefill workers come from ``prefill_nodes`` (het component 0); decode + workers come from ``decode_nodes`` (het component 1). Side pools are + independent — no gpu-offset bleed across sides — so each side's workers + stay inside their own component's topology segment. + + Each returned Endpoint is tagged with ``het_group`` (0 for prefill, 1 for + decode) for downstream srun ``--het-group=<N>`` threading. + + Aggregated mode is unsupported under het and rejected at config validation. + """ + prefill_eps = allocate_endpoints( + num_prefill=num_prefill, + num_decode=0, + num_agg=0, + gpus_per_prefill=gpus_per_prefill, + gpus_per_decode=gpus_per_decode, + gpus_per_agg=0, + gpus_per_node=gpus_per_node, + available_nodes=prefill_nodes, + ) + decode_eps = allocate_endpoints( + num_prefill=0, + num_decode=num_decode, + num_agg=0, + gpus_per_prefill=gpus_per_prefill, + gpus_per_decode=gpus_per_decode, + gpus_per_agg=0, + gpus_per_node=gpus_per_node, + available_nodes=decode_nodes, + ) + # Endpoint is frozen; re-emit with het_group set. 
+ tagged: list[Endpoint] = [] + for ep in prefill_eps: + tagged.append( + Endpoint( + mode=ep.mode, + index=ep.index, + nodes=ep.nodes, + gpu_indices=ep.gpu_indices, + gpus_per_node=ep.gpus_per_node, + het_group=0, + ) + ) + for ep in decode_eps: + tagged.append( + Endpoint( + mode=ep.mode, + index=ep.index, + nodes=ep.nodes, + gpu_indices=ep.gpu_indices, + gpus_per_node=ep.gpus_per_node, + het_group=1, + ) + ) + return tagged + + def endpoints_to_processes( endpoints: list[Endpoint], base_sys_port: int = DYN_SYSTEM_PORT_BASE, @@ -469,6 +544,7 @@ def endpoints_to_processes( bootstrap_port=endpoint_bootstrap_port, kv_events_port=node_kv_events_port, nixl_port=node_nixl_port, + het_group=endpoint.het_group, ) ) current_sys_port += 1 diff --git a/src/srtctl/frontends/dynamo.py b/src/srtctl/frontends/dynamo.py index f7241e6a..5fe47ed2 100644 --- a/src/srtctl/frontends/dynamo.py +++ b/src/srtctl/frontends/dynamo.py @@ -112,6 +112,7 @@ def start_frontends( # TODO(jthomson): I don't have the faintest clue of # why this is needed in later versions of Dynamo, but it is. mpi="pmix", + het_group=runtime.nodes.het_group_for(node), ) processes.append( diff --git a/src/srtctl/frontends/sglang.py b/src/srtctl/frontends/sglang.py index 6df7915f..060d2e31 100644 --- a/src/srtctl/frontends/sglang.py +++ b/src/srtctl/frontends/sglang.py @@ -147,6 +147,7 @@ def start_frontends( container_image=str(runtime.container_image), container_mounts=runtime.container_mounts, env_to_set=env_to_set if env_to_set else None, + het_group=runtime.nodes.het_group_for(node), ) processes.append( diff --git a/src/srtctl/templates/job_script_minimal.j2 b/src/srtctl/templates/job_script_minimal.j2 index 34a19dcd..2ff19eb0 100644 --- a/src/srtctl/templates/job_script_minimal.j2 +++ b/src/srtctl/templates/job_script_minimal.j2 @@ -2,6 +2,45 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 #SBATCH --job-name={{ job_name }} +#SBATCH --output={{ output_base }}/%j/logs/sweep_%j.log +{% if het_components %} +{# SLURM applies each `#SBATCH ...` to the component immediately preceding it. + Per-component required scheduling directives (--account/--partition/--time + and per-side --nodes/--ntasks/--segment) repeat inside the loop; --output + above is job-wide and applies once. #} +{% for c in het_components %} +{% if not loop.first %} +#SBATCH hetjob +{% endif %} +#SBATCH --nodes={{ c.nodes }} +{% if backend_type == "trtllm" %} +#SBATCH --ntasks={{ c.gpus_per_node * c.nodes }} +#SBATCH --ntasks-per-node={{ c.gpus_per_node }} +{% else %} +#SBATCH --ntasks={{ c.nodes }} +#SBATCH --ntasks-per-node=1 +{% endif %} +{% if use_gpus_per_node_directive %} +#SBATCH --gpus-per-node={{ c.gpus_per_node }} +{% endif %} +{% if use_segment_sbatch_directive %} +#SBATCH --segment={{ c.segment }} +{% endif %} +{% if use_exclusive_sbatch_directive %} +#SBATCH --exclusive +{% endif %} +#SBATCH --account={{ account }} +#SBATCH --time={{ time_limit }} +#SBATCH --partition={{ partition }} +{% for key, value in sbatch_directives.items() %} +{% if value %} +#SBATCH --{{ key }}={{ value }} +{% else %} +#SBATCH --{{ key }} +{% endif %} +{% endfor %} +{% endfor %} +{% else %} #SBATCH --nodes={{ total_nodes }} {% if backend_type == "trtllm" %} #SBATCH --ntasks={{ gpus_per_node * total_nodes }} @@ -21,7 +60,6 @@ {% endif %} #SBATCH --account={{ account }} #SBATCH --time={{ time_limit }} -#SBATCH --output={{ output_base }}/%j/logs/sweep_%j.log #SBATCH --partition={{ partition }} {% for key, value in sbatch_directives.items() %} {% if value %} @@ -30,6 +68,7 @@ #SBATCH --{{ key }} {% endif %} {% endfor %} +{% endif %} # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 diff --git a/tests/test_configs.py b/tests/test_configs.py index afdfaa3c..74a7bade 100644 --- a/tests/test_configs.py +++ b/tests/test_configs.py @@ -227,7 +227,10 @@ def test_hash_install_command(self): assert "touch /configs/dynamo-wheels/abc123/.complete" in cmd # Final install from cache - assert "pip install --break-system-packages --force-reinstall /configs/dynamo-wheels/abc123/ai_dynamo_runtime-*.whl" in cmd + assert ( + "pip install --break-system-packages --force-reinstall /configs/dynamo-wheels/abc123/ai_dynamo_runtime-*.whl" + in cmd + ) assert "tar -xzf /configs/dynamo-wheels/abc123/dynamo-src.tar.gz" in cmd assert "pip install --break-system-packages -e /tmp/dynamo-src/dynamo" in cmd @@ -878,108 +881,111 @@ def mock_scontrol(cmd, **kwargs): return result raise subprocess.CalledProcessError(1, cmd) - with patch.dict(os.environ, slurm_env), patch("subprocess.run", mock_scontrol): - with patch("srtctl.core.slurm.get_hostname_ip", return_value="10.0.0.1"): - # Create config with templated environment variables - config = SrtConfig( - name="test", - model=ModelConfig( - path=str(model_path), - container=str(container_path), - precision="fp8", - ), - resources=ResourceConfig( - gpu_type="h100", - gpus_per_node=8, - prefill_nodes=1, - decode_nodes=2, - ), - backend=SGLangProtocol( - prefill_environment={ - "SGLANG_DG_CACHE_DIR": "/configs/dg-{node_id}", - "WORKER_NODE": "{node}", - }, - decode_environment={ - "SGLANG_DG_CACHE_DIR": "/configs/dg-{node_id}", - }, - ), - ) + with ( + patch.dict(os.environ, slurm_env), + patch("subprocess.run", mock_scontrol), + patch("srtctl.core.slurm.get_hostname_ip", return_value="10.0.0.1"), + ): + # Create config with templated environment variables + config = SrtConfig( + name="test", + model=ModelConfig( + path=str(model_path), + container=str(container_path), + precision="fp8", + ), + resources=ResourceConfig( + gpu_type="h100", + gpus_per_node=8, + prefill_nodes=1, + 
decode_nodes=2, + ), + backend=SGLangProtocol( + prefill_environment={ + "SGLANG_DG_CACHE_DIR": "/configs/dg-{node_id}", + "WORKER_NODE": "{node}", + }, + decode_environment={ + "SGLANG_DG_CACHE_DIR": "/configs/dg-{node_id}", + }, + ), + ) - runtime = RuntimeContext.from_config(config, job_id="12345") + runtime = RuntimeContext.from_config(config, job_id="12345") - # Create a mock worker stage - class MockWorkerStage(WorkerStageMixin): - def __init__(self, config, runtime): - self.config = config - self.runtime = runtime - - worker_stage = MockWorkerStage(config, runtime) - - # Create test processes on different nodes - processes = [ - Process( - node="gpu-01", - gpu_indices=frozenset([0, 1, 2, 3, 4, 5, 6, 7]), - sys_port=8081, - http_port=30000, - endpoint_mode="prefill", - endpoint_index=0, - node_rank=0, - ), - Process( - node="gpu-02", - gpu_indices=frozenset([0, 1, 2, 3, 4, 5, 6, 7]), - sys_port=8082, - http_port=30001, - endpoint_mode="decode", - endpoint_index=0, - node_rank=0, - ), - Process( - node="gpu-03", - gpu_indices=frozenset([0, 1, 2, 3, 4, 5, 6, 7]), - sys_port=8083, - http_port=30002, - endpoint_mode="decode", - endpoint_index=1, - node_rank=0, - ), - ] + # Create a mock worker stage + class MockWorkerStage(WorkerStageMixin): + def __init__(self, config, runtime): + self.config = config + self.runtime = runtime + + worker_stage = MockWorkerStage(config, runtime) + + # Create test processes on different nodes + processes = [ + Process( + node="gpu-01", + gpu_indices=frozenset([0, 1, 2, 3, 4, 5, 6, 7]), + sys_port=8081, + http_port=30000, + endpoint_mode="prefill", + endpoint_index=0, + node_rank=0, + ), + Process( + node="gpu-02", + gpu_indices=frozenset([0, 1, 2, 3, 4, 5, 6, 7]), + sys_port=8082, + http_port=30001, + endpoint_mode="decode", + endpoint_index=0, + node_rank=0, + ), + Process( + node="gpu-03", + gpu_indices=frozenset([0, 1, 2, 3, 4, 5, 6, 7]), + sys_port=8083, + http_port=30002, + endpoint_mode="decode", + endpoint_index=1, + 
node_rank=0, + ), + ] - # Mock backend command builder and srun process to capture environment variables - mock_backend = MagicMock() - mock_backend.get_environment_for_mode.side_effect = config.backend.get_environment_for_mode - mock_backend.build_worker_command.return_value = ["echo", "test"] + # Mock backend command builder and srun process to capture environment variables + mock_backend = MagicMock() + mock_backend.get_environment_for_mode.side_effect = config.backend.get_environment_for_mode + mock_backend.build_worker_command.return_value = ["echo", "test"] - with patch.object(worker_stage, "config") as mock_config: - mock_config.backend = mock_backend - mock_config.profiling = config.profiling + with patch.object(worker_stage, "config") as mock_config: + mock_config.backend = mock_backend + mock_config.profiling = config.profiling - with patch("srtctl.cli.mixins.worker_stage.start_srun_process") as mock_srun: - mock_srun.return_value = MagicMock() + with patch("srtctl.cli.mixins.worker_stage.start_srun_process") as mock_srun: + mock_srun.return_value = MagicMock() - # Test prefill worker on gpu-01 (index 0) - worker_stage.start_worker(processes[0], []) - call_kwargs = mock_srun.call_args.kwargs - env_vars = call_kwargs.get("env_to_set", {}) + # Test prefill worker on gpu-01 (index 0) + worker_stage.start_worker(processes[0], []) + call_kwargs = mock_srun.call_args.kwargs + env_vars = call_kwargs.get("env_to_set", {}) - assert "SGLANG_DG_CACHE_DIR" in env_vars - assert env_vars["SGLANG_DG_CACHE_DIR"] == "/configs/dg-0" - assert env_vars["WORKER_NODE"] == "gpu-01" + assert "SGLANG_DG_CACHE_DIR" in env_vars + assert env_vars["SGLANG_DG_CACHE_DIR"] == "/configs/dg-0" + assert env_vars["WORKER_NODE"] == "gpu-01" - # Test decode worker on gpu-02 (index 1) - worker_stage.start_worker(processes[1], []) - call_kwargs = mock_srun.call_args.kwargs - env_vars = call_kwargs.get("env_to_set", {}) + # Test decode worker on gpu-02 (index 1) + 
worker_stage.start_worker(processes[1], []) + call_kwargs = mock_srun.call_args.kwargs + env_vars = call_kwargs.get("env_to_set", {}) - assert env_vars["SGLANG_DG_CACHE_DIR"] == "/configs/dg-1" + assert env_vars["SGLANG_DG_CACHE_DIR"] == "/configs/dg-1" - # Test decode worker on gpu-03 (index 2) - worker_stage.start_worker(processes[2], []) - call_kwargs = mock_srun.call_args.kwargs - env_vars = call_kwargs.get("env_to_set", {}) + # Test decode worker on gpu-03 (index 2) + worker_stage.start_worker(processes[2], []) + call_kwargs = mock_srun.call_args.kwargs + env_vars = call_kwargs.get("env_to_set", {}) - assert env_vars["SGLANG_DG_CACHE_DIR"] == "/configs/dg-2" + assert env_vars["SGLANG_DG_CACHE_DIR"] == "/configs/dg-2" def test_environment_variable_unsupported_placeholder(self, monkeypatch, tmp_path): """Test that unsupported placeholders like {foo} remain unchanged and don't throw errors.""" @@ -1016,76 +1022,79 @@ def mock_scontrol(cmd, **kwargs): return result raise subprocess.CalledProcessError(1, cmd) - with patch.dict(os.environ, slurm_env), patch("subprocess.run", mock_scontrol): - with patch("srtctl.core.slurm.get_hostname_ip", return_value="10.0.0.1"): - # Create config with unsupported template placeholders - config = SrtConfig( - name="test", - model=ModelConfig( - path=str(model_path), - container=str(container_path), - precision="fp8", - ), - resources=ResourceConfig( - gpu_type="h100", - gpus_per_node=8, - prefill_nodes=1, - decode_nodes=1, - ), - backend=SGLangProtocol( - prefill_environment={ - # Mix of supported and unsupported placeholders - "CACHE_DIR": "/cache/{node_id}/data", - "UNSUPPORTED": "/path/{foo}/bar/{baz}", - "MIXED": "{node}-{unsupported_var}-cache", - }, - ), - ) + with ( + patch.dict(os.environ, slurm_env), + patch("subprocess.run", mock_scontrol), + patch("srtctl.core.slurm.get_hostname_ip", return_value="10.0.0.1"), + ): + # Create config with unsupported template placeholders + config = SrtConfig( + name="test", + 
model=ModelConfig( + path=str(model_path), + container=str(container_path), + precision="fp8", + ), + resources=ResourceConfig( + gpu_type="h100", + gpus_per_node=8, + prefill_nodes=1, + decode_nodes=1, + ), + backend=SGLangProtocol( + prefill_environment={ + # Mix of supported and unsupported placeholders + "CACHE_DIR": "/cache/{node_id}/data", + "UNSUPPORTED": "/path/{foo}/bar/{baz}", + "MIXED": "{node}-{unsupported_var}-cache", + }, + ), + ) - runtime = RuntimeContext.from_config(config, job_id="12345") + runtime = RuntimeContext.from_config(config, job_id="12345") - class MockWorkerStage(WorkerStageMixin): - def __init__(self, config, runtime): - self.config = config - self.runtime = runtime + class MockWorkerStage(WorkerStageMixin): + def __init__(self, config, runtime): + self.config = config + self.runtime = runtime - worker_stage = MockWorkerStage(config, runtime) + worker_stage = MockWorkerStage(config, runtime) - process = Process( - node="gpu-01", - gpu_indices=frozenset([0, 1, 2, 3, 4, 5, 6, 7]), - sys_port=8081, - http_port=30000, - endpoint_mode="prefill", - endpoint_index=0, - node_rank=0, - ) + process = Process( + node="gpu-01", + gpu_indices=frozenset([0, 1, 2, 3, 4, 5, 6, 7]), + sys_port=8081, + http_port=30000, + endpoint_mode="prefill", + endpoint_index=0, + node_rank=0, + ) - # Mock backend command builder and srun process to capture environment variables - mock_backend = MagicMock() - mock_backend.get_environment_for_mode.side_effect = config.backend.get_environment_for_mode - mock_backend.build_worker_command.return_value = ["echo", "test"] + # Mock backend command builder and srun process to capture environment variables + mock_backend = MagicMock() + mock_backend.get_environment_for_mode.side_effect = config.backend.get_environment_for_mode + mock_backend.build_worker_command.return_value = ["echo", "test"] - with patch.object(worker_stage, "config") as mock_config: - mock_config.backend = mock_backend - mock_config.profiling = 
config.profiling + with patch.object(worker_stage, "config") as mock_config: + mock_config.backend = mock_backend + mock_config.profiling = config.profiling - with patch("srtctl.cli.mixins.worker_stage.start_srun_process") as mock_srun: - mock_srun.return_value = MagicMock() + with patch("srtctl.cli.mixins.worker_stage.start_srun_process") as mock_srun: + mock_srun.return_value = MagicMock() - # This should NOT throw an error - worker_stage.start_worker(process, []) - call_kwargs = mock_srun.call_args.kwargs - env_vars = call_kwargs.get("env_to_set", {}) + # This should NOT throw an error + worker_stage.start_worker(process, []) + call_kwargs = mock_srun.call_args.kwargs + env_vars = call_kwargs.get("env_to_set", {}) - # Supported placeholder should be replaced - assert env_vars["CACHE_DIR"] == "/cache/0/data" + # Supported placeholder should be replaced + assert env_vars["CACHE_DIR"] == "/cache/0/data" - # Unsupported placeholders should remain unchanged - assert env_vars["UNSUPPORTED"] == "/path/{foo}/bar/{baz}" + # Unsupported placeholders should remain unchanged + assert env_vars["UNSUPPORTED"] == "/path/{foo}/bar/{baz}" - # Mixed case: supported replaced, unsupported kept - assert env_vars["MIXED"] == "gpu-01-{unsupported_var}-cache" + # Mixed case: supported replaced, unsupported kept + assert env_vars["MIXED"] == "gpu-01-{unsupported_var}-cache" class TestInfraConfig: @@ -1156,9 +1165,11 @@ def test_nodes_dedicated_infra_requires_two_nodes(self): from srtctl.core.runtime import Nodes - with patch("srtctl.core.runtime.get_slurm_nodelist", return_value=["node0"]): - with pytest.raises(ValueError, match="at least 2 nodes"): - Nodes.from_slurm(etcd_nats_dedicated_node=True) + with ( + patch("srtctl.core.runtime.get_slurm_nodelist", return_value=["node0"]), + pytest.raises(ValueError, match="at least 2 nodes"), + ): + Nodes.from_slurm(etcd_nats_dedicated_node=True) class TestSbatchNodeCount: @@ -1443,6 +1454,302 @@ def 
test_enabled_does_not_pack_when_one_node_does_not_fit(self): assert endpoints[1].nodes == ("node1",) +class TestHetJobsValidation: + """SrtConfig.__post_init__ validation for `resources.het_jobs: true`.""" + + def _make(self, **resource_overrides): + from srtctl.core.schema import ModelConfig, ResourceConfig, SrtConfig + + resources_kwargs = dict( + gpu_type="gb200", + gpus_per_node=4, + prefill_nodes=12, + decode_nodes=10, + prefill_workers=12, + decode_workers=10, + het_jobs=True, + ) + backend = resource_overrides.pop("backend", None) + resources_kwargs.update(resource_overrides) + kwargs = dict( + name="t", + model=ModelConfig(path="/m", container="/c.sqsh", precision="fp8"), + resources=ResourceConfig(**resources_kwargs), + ) + if backend is not None: + kwargs["backend"] = backend + return SrtConfig, kwargs + + def test_het_jobs_passes_with_disagg_sglang(self): + SrtConfig, kwargs = self._make() + cfg = SrtConfig(**kwargs) + assert cfg.resources.het_jobs is True + + def test_het_jobs_rejected_in_agg_mode(self): + import pytest + from marshmallow import ValidationError + + SrtConfig, kwargs = self._make( + prefill_nodes=None, + decode_nodes=None, + prefill_workers=None, + decode_workers=None, + agg_nodes=2, + agg_workers=2, + ) + with pytest.raises(ValidationError, match="disaggregated layout"): + SrtConfig(**kwargs) + + def test_het_jobs_rejected_on_trtllm(self): + import pytest + from marshmallow import ValidationError + + from srtctl.backends import TRTLLMProtocol + + SrtConfig, kwargs = self._make(backend=TRTLLMProtocol()) + with pytest.raises(ValidationError, match="only supported on the sglang backend"): + SrtConfig(**kwargs) + + def test_het_jobs_rejected_with_zero_nodes(self): + import pytest + from marshmallow import ValidationError + + SrtConfig, kwargs = self._make(prefill_nodes=0) + with pytest.raises(ValidationError, match="prefill_nodes >= 1"): + SrtConfig(**kwargs) + + def test_het_jobs_off_is_unrestricted(self): + """Recipe with het_jobs=None or 
False should not trigger het validation.""" + from srtctl.backends import TRTLLMProtocol + from srtctl.core.schema import ModelConfig, ResourceConfig, SrtConfig + + # trtllm + agg is fine when het is off — would only fail if het_jobs were True. + cfg = SrtConfig( + name="t", + model=ModelConfig(path="/m", container="/c.sqsh", precision="fp8"), + resources=ResourceConfig( + gpu_type="gb200", + gpus_per_node=4, + agg_nodes=2, + agg_workers=2, + het_jobs=False, + ), + backend=TRTLLMProtocol(), + ) + assert cfg.resources.het_jobs is False + + +class TestHetComponents: + """ResourceConfig.het_components() shape.""" + + def _resources(self, **overrides): + from srtctl.core.schema import ResourceConfig + + base = dict( + gpu_type="gb200", + gpus_per_node=4, + prefill_nodes=12, + decode_nodes=10, + prefill_workers=12, + decode_workers=10, + het_jobs=True, + ) + base.update(overrides) + return ResourceConfig(**base) + + def test_het_components_returns_two_components(self): + r = self._resources() + components = r.het_components(infra_dedicated=False) + assert components is not None + prefill, decode = components + assert prefill.name == "prefill" + assert prefill.group == 0 + assert prefill.nodes == 12 + assert prefill.segment == 12 + assert decode.name == "decode" + assert decode.group == 1 + assert decode.nodes == 10 + assert decode.segment == 10 + + def test_het_components_folds_infra_into_prefill(self): + r = self._resources() + components = r.het_components(infra_dedicated=True) + assert components is not None + prefill, decode = components + # prefill_nodes (12) + 1 dedicated infra + assert prefill.nodes == 13 + assert prefill.segment == 13 + # decode unchanged + assert decode.nodes == 10 + assert decode.segment == 10 + + def test_het_components_none_when_off(self): + from srtctl.core.schema import ResourceConfig + + r = ResourceConfig( + gpu_type="gb200", + gpus_per_node=4, + prefill_nodes=12, + decode_nodes=10, + prefill_workers=12, + decode_workers=10, + 
het_jobs=False, + ) + assert r.het_components(infra_dedicated=False) is None + + def test_het_components_cluster_default_applies_when_recipe_none(self): + from srtctl.core.schema import ResourceConfig + + r = ResourceConfig( + gpu_type="gb200", + gpus_per_node=4, + prefill_nodes=12, + decode_nodes=10, + prefill_workers=12, + decode_workers=10, + het_jobs=None, + ) + # cluster_default=False -> off + assert r.het_components(infra_dedicated=False) is None + # cluster_default=True -> on + assert r.het_components(infra_dedicated=False, cluster_default=True) is not None + + +class TestHetJobsSbatchScript: + """generate_minimal_sbatch_script() emits het structure when het_jobs is True.""" + + def _config(self, *, het_jobs, infra_dedicated): + from srtctl.core.schema import InfraConfig, ModelConfig, ResourceConfig, SrtConfig + + return SrtConfig( + name="t", + model=ModelConfig(path="/m", container="/c.sqsh", precision="fp8"), + resources=ResourceConfig( + gpu_type="gb200", + gpus_per_node=4, + prefill_nodes=12, + decode_nodes=10, + prefill_workers=12, + decode_workers=10, + het_jobs=het_jobs, + ), + infra=InfraConfig(etcd_nats_dedicated_node=infra_dedicated), + ) + + def test_emits_hetjob_separator_and_two_segments(self): + from pathlib import Path + + from srtctl.cli.submit import generate_minimal_sbatch_script + + cfg = self._config(het_jobs=True, infra_dedicated=False) + script = generate_minimal_sbatch_script(cfg, Path("/tmp/test.yaml")) + + assert script.count("#SBATCH hetjob") == 1 + assert "#SBATCH --segment=12" in script + assert "#SBATCH --segment=10" in script + # SLURM het-jobs need --account/--time/--partition repeated per component + # (each #SBATCH directive applies to the component it follows, not the job). + assert script.count("#SBATCH --account=") == 2 + assert script.count("#SBATCH --partition=") == 2 + # --output is job-wide (only one log file), so it appears once at the top. 
+ assert script.count("#SBATCH --output=") == 1 + # Per-component --nodes lines + assert "#SBATCH --nodes=12" in script + assert "#SBATCH --nodes=10" in script + + def test_infra_folds_into_prefill_component(self): + from pathlib import Path + + from srtctl.cli.submit import generate_minimal_sbatch_script + + cfg = self._config(het_jobs=True, infra_dedicated=True) + script = generate_minimal_sbatch_script(cfg, Path("/tmp/test.yaml")) + + # prefill component grows by 1 for the dedicated infra node + assert "#SBATCH --nodes=13" in script + assert "#SBATCH --segment=13" in script + assert "#SBATCH --nodes=10" in script + assert "#SBATCH --segment=10" in script + + def test_no_hetjob_block_when_off(self): + from pathlib import Path + + from srtctl.cli.submit import generate_minimal_sbatch_script + + cfg = self._config(het_jobs=False, infra_dedicated=False) + script = generate_minimal_sbatch_script(cfg, Path("/tmp/test.yaml")) + assert "#SBATCH hetjob" not in script + # Single --nodes line (12 prefill + 10 decode = 22) + assert "#SBATCH --nodes=22" in script + + +class TestNodesHetGroupParsing: + """Nodes.from_slurm reads SLURM_HET_SIZE/SLURM_JOB_NODELIST_HET_GROUP_*.""" + + def test_from_slurm_returns_het_layout(self): + from unittest.mock import patch + + from srtctl.core.runtime import Nodes + + het_lists = [ + ["gb200-01", "gb200-02", "gb200-03"], # group 0: prefill (+ infra) + ["gb200-04", "gb200-05"], # group 1: decode + ] + with patch("srtctl.core.runtime.get_slurm_het_nodelists", return_value=het_lists): + nodes = Nodes.from_slurm(etcd_nats_dedicated_node=False) + + assert nodes.het is True + assert nodes.prefill_group == ("gb200-01", "gb200-02", "gb200-03") + assert nodes.decode_group == ("gb200-04", "gb200-05") + assert nodes.worker == ("gb200-01", "gb200-02", "gb200-03", "gb200-04", "gb200-05") + + def test_from_slurm_het_with_dedicated_infra(self): + from unittest.mock import patch + + from srtctl.core.runtime import Nodes + + het_lists = [ + ["gb200-00", 
"gb200-01", "gb200-02"], # group 0: [infra, prefill...] + ["gb200-03", "gb200-04"], # group 1: decode + ] + with patch("srtctl.core.runtime.get_slurm_het_nodelists", return_value=het_lists): + nodes = Nodes.from_slurm(etcd_nats_dedicated_node=True) + + assert nodes.infra == "gb200-00" + assert nodes.head == "gb200-01" + assert nodes.prefill_group == ("gb200-01", "gb200-02") + assert nodes.decode_group == ("gb200-03", "gb200-04") + # Infra node carved out of worker pool + assert "gb200-00" not in nodes.worker + + def test_het_group_for_returns_correct_group(self): + from unittest.mock import patch + + from srtctl.core.runtime import Nodes + + het_lists = [["p0", "p1"], ["d0", "d1"]] + with patch("srtctl.core.runtime.get_slurm_het_nodelists", return_value=het_lists): + nodes = Nodes.from_slurm(etcd_nats_dedicated_node=False) + + assert nodes.het_group_for("p0") == 0 + assert nodes.het_group_for("d0") == 1 + assert nodes.het_group_for("unknown") is None + + def test_het_group_for_returns_none_on_non_het(self): + from unittest.mock import patch + + from srtctl.core.runtime import Nodes + + with ( + patch("srtctl.core.runtime.get_slurm_het_nodelists", return_value=None), + patch("srtctl.core.runtime.get_slurm_nodelist", return_value=["n0", "n1"]), + ): + nodes = Nodes.from_slurm(etcd_nats_dedicated_node=False) + + assert nodes.het is False + assert nodes.het_group_for("n0") is None + + class TestVLLMDataParallelMode: """Tests for vLLM DP+EP (Data Parallel + Expert Parallel) mode.""" @@ -2122,9 +2429,11 @@ def test_trtllm_hf_model_uses_model_id(self): runtime = self._make_runtime(is_hf=True) runtime.log_dir = Path("/tmp/test-logs") - with patch("pathlib.Path.write_text"): - with patch("srtctl.core.slurm.get_hostname_ip", return_value="10.0.0.1"): - cmd = backend.build_worker_command(process=process, endpoint_processes=[process], runtime=runtime) + with ( + patch("pathlib.Path.write_text"), + patch("srtctl.core.slurm.get_hostname_ip", return_value="10.0.0.1"), + ): + cmd 
= backend.build_worker_command(process=process, endpoint_processes=[process], runtime=runtime) idx = cmd.index("--model-path") assert cmd[idx + 1] == "facebook/opt-125m" @@ -2141,9 +2450,11 @@ def test_trtllm_local_model_uses_container_mount(self): runtime = self._make_runtime(is_hf=False) runtime.log_dir = Path("/tmp/test-logs") - with patch("pathlib.Path.write_text"): - with patch("srtctl.core.slurm.get_hostname_ip", return_value="10.0.0.1"): - cmd = backend.build_worker_command(process=process, endpoint_processes=[process], runtime=runtime) + with ( + patch("pathlib.Path.write_text"), + patch("srtctl.core.slurm.get_hostname_ip", return_value="10.0.0.1"), + ): + cmd = backend.build_worker_command(process=process, endpoint_processes=[process], runtime=runtime) idx = cmd.index("--model-path") assert cmd[idx + 1] == "/model" @@ -2184,26 +2495,28 @@ def mock_scontrol(cmd, **kwargs): return result raise subprocess.CalledProcessError(1, cmd) - with patch.dict(os.environ, slurm_env): - with patch("subprocess.run", mock_scontrol): - with patch("srtctl.core.slurm.get_hostname_ip", return_value="10.0.0.1"): - config = SrtConfig( - name="test", - model=ModelConfig( - path=str(model_path), - container=str(container_path), - precision="fp8", - ), - resources=ResourceConfig( - gpu_type="h100", - gpus_per_node=8, - prefill_nodes=1, - decode_nodes=1, - ), - ) - runtime = RuntimeContext.from_config(config, job_id="12345") - - assert Path("/infmax-workspace") in runtime.container_mounts.values() + with ( + patch.dict(os.environ, slurm_env), + patch("subprocess.run", mock_scontrol), + patch("srtctl.core.slurm.get_hostname_ip", return_value="10.0.0.1"), + ): + config = SrtConfig( + name="test", + model=ModelConfig( + path=str(model_path), + container=str(container_path), + precision="fp8", + ), + resources=ResourceConfig( + gpu_type="h100", + gpus_per_node=8, + prefill_nodes=1, + decode_nodes=1, + ), + ) + runtime = RuntimeContext.from_config(config, job_id="12345") + + assert 
Path("/infmax-workspace") in runtime.container_mounts.values() def test_infmax_workspace_mount_not_added_without_env(self, tmp_path): """RuntimeContext does not include /infmax-workspace without env var.""" @@ -2238,25 +2551,27 @@ def mock_scontrol(cmd, **kwargs): with patch.dict(os.environ, slurm_env): os.environ.pop("INFMAX_WORKSPACE", None) - with patch("subprocess.run", mock_scontrol): - with patch("srtctl.core.slurm.get_hostname_ip", return_value="10.0.0.1"): - config = SrtConfig( - name="test", - model=ModelConfig( - path=str(model_path), - container=str(container_path), - precision="fp8", - ), - resources=ResourceConfig( - gpu_type="h100", - gpus_per_node=8, - prefill_nodes=1, - decode_nodes=1, - ), - ) - runtime = RuntimeContext.from_config(config, job_id="12345") - - assert Path("/infmax-workspace") not in runtime.container_mounts.values() + with ( + patch("subprocess.run", mock_scontrol), + patch("srtctl.core.slurm.get_hostname_ip", return_value="10.0.0.1"), + ): + config = SrtConfig( + name="test", + model=ModelConfig( + path=str(model_path), + container=str(container_path), + precision="fp8", + ), + resources=ResourceConfig( + gpu_type="h100", + gpus_per_node=8, + prefill_nodes=1, + decode_nodes=1, + ), + ) + runtime = RuntimeContext.from_config(config, job_id="12345") + + assert Path("/infmax-workspace") not in runtime.container_mounts.values() class TestExtraMountExpansion: @@ -2295,25 +2610,27 @@ def mock_scontrol(cmd, **kwargs): return result raise subprocess.CalledProcessError(1, cmd) - with patch.dict(os.environ, slurm_env): - with patch("subprocess.run", mock_scontrol): - with patch("srtctl.core.slurm.get_hostname_ip", return_value="10.0.0.1"): - config = SrtConfig( - name="test", - model=ModelConfig( - path=str(model_path), - container=str(container_path), - precision="fp8", - ), - resources=ResourceConfig( - gpu_type="h100", - gpus_per_node=8, - prefill_nodes=1, - decode_nodes=1, - ), - extra_mount=("$SRT_EXTRA_ROOT:/extra",), - ) - runtime = 
RuntimeContext.from_config(config, job_id="12345") - - assert extra_root.resolve() in runtime.container_mounts - assert runtime.container_mounts[extra_root.resolve()] == Path("/extra") + with ( + patch.dict(os.environ, slurm_env), + patch("subprocess.run", mock_scontrol), + patch("srtctl.core.slurm.get_hostname_ip", return_value="10.0.0.1"), + ): + config = SrtConfig( + name="test", + model=ModelConfig( + path=str(model_path), + container=str(container_path), + precision="fp8", + ), + resources=ResourceConfig( + gpu_type="h100", + gpus_per_node=8, + prefill_nodes=1, + decode_nodes=1, + ), + extra_mount=("$SRT_EXTRA_ROOT:/extra",), + ) + runtime = RuntimeContext.from_config(config, job_id="12345") + + assert extra_root.resolve() in runtime.container_mounts + assert runtime.container_mounts[extra_root.resolve()] == Path("/extra") diff --git a/tests/test_dry_run.py b/tests/test_dry_run.py index e806390b..3c54305d 100644 --- a/tests/test_dry_run.py +++ b/tests/test_dry_run.py @@ -313,3 +313,54 @@ def test_mooncake_kv_store_no_container_shows_default(self, capsys): output = capsys.readouterr().out assert "" in output assert "MOONCAKE_PROTOCOL" in output + + +class TestDryRunHetJobs: + """Het structure panel appears only when het is enabled.""" + + def test_het_panel_rendered_when_enabled(self, capsys): + config = _make_config( + { + "resources": { + "gpu_type": "gb200", + "gpus_per_node": 4, + "prefill_nodes": 12, + "decode_nodes": 10, + "prefill_workers": 12, + "decode_workers": 10, + "het_jobs": True, + }, + } + ) + show_config_details(config) + output = capsys.readouterr().out + assert "Heterogeneous Job" in output + assert "prefill" in output + assert "decode" in output + + def test_het_panel_hidden_when_disabled(self, capsys): + """No het panel when het_jobs is unset (recipe default).""" + config = _make_config() + show_config_details(config) + output = capsys.readouterr().out + assert "Heterogeneous Job" not in output + + def 
test_het_panel_shows_infra_folded_into_prefill(self, capsys): + config = _make_config( + { + "resources": { + "gpu_type": "gb200", + "gpus_per_node": 4, + "prefill_nodes": 12, + "decode_nodes": 10, + "prefill_workers": 12, + "decode_workers": 10, + "het_jobs": True, + }, + "infra": {"etcd_nats_dedicated_node": True}, + } + ) + show_config_details(config) + output = capsys.readouterr().out + assert "Heterogeneous Job" in output + assert "first node" in output # infra note on the prefill row diff --git a/tests/test_e2e.py b/tests/test_e2e.py index a692a243..30f33e53 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -90,6 +90,64 @@ def mock_run(cmd, **kwargs): return mock_run +class GB200HetRack: + """GB200 het-job allocation: prefill component (12 nodes) + decode (10 nodes). + + Models the 48+40 asymmetric case the het-job feature was built for. Group 0 + holds prefill nodes (and the dedicated infra node when configured); group 1 + holds decode nodes. + """ + + PREFILL_NODES = 12 + DECODE_NODES = 10 + GPUS_PER_NODE = 4 + + @classmethod + def prefill_nodelist(cls) -> list[str]: + return [f"gb200-{i:02d}" for i in range(1, cls.PREFILL_NODES + 1)] + + @classmethod + def decode_nodelist(cls) -> list[str]: + return [f"gb200-{i:02d}" for i in range(cls.PREFILL_NODES + 1, cls.PREFILL_NODES + cls.DECODE_NODES + 1)] + + @classmethod + def slurm_env(cls) -> dict[str, str]: + prefill_raw = f"gb200-[01-{cls.PREFILL_NODES:02d}]" + decode_raw = f"gb200-[{cls.PREFILL_NODES + 1:02d}-{cls.PREFILL_NODES + cls.DECODE_NODES:02d}]" + return { + "SLURM_JOB_ID": "13579", + "SLURM_JOBID": "13579", + # SLURM_NODELIST is intentionally omitted — Nodes.from_slurm() should + # take the het branch off SLURM_HET_SIZE before reading it. 
+ "SLURM_HET_SIZE": "2", + "SLURM_JOB_NODELIST_HET_GROUP_0": prefill_raw, + "SLURM_JOB_NODELIST_HET_GROUP_1": decode_raw, + "SLURM_JOB_NUM_NODES": str(cls.PREFILL_NODES + cls.DECODE_NODES), + "SRTCTL_SOURCE_DIR": str(Path(__file__).parent.parent), + } + + @classmethod + def mock_scontrol(cls): + prefill_raw = f"gb200-[01-{cls.PREFILL_NODES:02d}]" + decode_raw = f"gb200-[{cls.PREFILL_NODES + 1:02d}-{cls.PREFILL_NODES + cls.DECODE_NODES:02d}]" + + def mock_run(cmd, **kwargs): + if cmd[0] == "scontrol" and "hostnames" in cmd: + nodelist_raw = cmd[-1] + result = MagicMock() + if nodelist_raw == prefill_raw: + result.stdout = "\n".join(cls.prefill_nodelist()) + elif nodelist_raw == decode_raw: + result.stdout = "\n".join(cls.decode_nodelist()) + else: + raise AssertionError(f"unexpected nodelist {nodelist_raw}") + result.returncode = 0 + return result + raise subprocess.CalledProcessError(1, cmd) + + return mock_run + + # ============================================================================= # Tests # ============================================================================= @@ -108,60 +166,66 @@ class TestGB200FP4Cluster: @pytest.mark.parametrize("recipe_path", RECIPES, ids=lambda p: p.name) def test_gpus_per_node_is_4(self, recipe_path): """All GB200 FP4 1k1k configs use 4 GPUs per node.""" - with patch.dict(os.environ, self.RACK.slurm_env(), clear=False): - with patch("subprocess.run", side_effect=self.RACK.mock_scontrol()): - config = load_config(str(recipe_path)) - assert config.resources.gpus_per_node == self.RACK.GPUS_PER_NODE, ( - f"{recipe_path.name}: expected gpus_per_node={self.RACK.GPUS_PER_NODE}, " - f"got {config.resources.gpus_per_node}" - ) + with ( + patch.dict(os.environ, self.RACK.slurm_env(), clear=False), + patch("subprocess.run", side_effect=self.RACK.mock_scontrol()), + ): + config = load_config(str(recipe_path)) + assert config.resources.gpus_per_node == self.RACK.GPUS_PER_NODE, ( + f"{recipe_path.name}: expected 
gpus_per_node={self.RACK.GPUS_PER_NODE}, " + f"got {config.resources.gpus_per_node}" + ) @pytest.mark.parametrize("recipe_path", RECIPES, ids=lambda p: p.name) def test_fits_in_rack(self, recipe_path): """Recipe fits within the GB200 NVL rack (18 nodes).""" - with patch.dict(os.environ, self.RACK.slurm_env(), clear=False): - with patch("subprocess.run", side_effect=self.RACK.mock_scontrol()): - config = load_config(str(recipe_path)) - r = config.resources - total_nodes_needed = (r.prefill_nodes or 0) + (r.decode_nodes or 0) + (r.agg_nodes or 0) - assert total_nodes_needed <= self.RACK.NUM_NODES, ( - f"{recipe_path.name}: needs {total_nodes_needed} nodes, rack has {self.RACK.NUM_NODES}" - ) + with ( + patch.dict(os.environ, self.RACK.slurm_env(), clear=False), + patch("subprocess.run", side_effect=self.RACK.mock_scontrol()), + ): + config = load_config(str(recipe_path)) + r = config.resources + total_nodes_needed = (r.prefill_nodes or 0) + (r.decode_nodes or 0) + (r.agg_nodes or 0) + assert ( + total_nodes_needed <= self.RACK.NUM_NODES + ), f"{recipe_path.name}: needs {total_nodes_needed} nodes, rack has {self.RACK.NUM_NODES}" @pytest.mark.parametrize("recipe_path", RECIPES, ids=lambda p: p.name) def test_endpoint_allocation(self, recipe_path): """Endpoints are allocated correctly on GB200 NVL rack.""" - with patch.dict(os.environ, self.RACK.slurm_env(), clear=False): - with patch("subprocess.run", side_effect=self.RACK.mock_scontrol()): - config = load_config(str(recipe_path)) - r = config.resources - - endpoints = config.backend.allocate_endpoints( - num_prefill=r.num_prefill, - num_decode=r.num_decode, - num_agg=r.num_agg, - gpus_per_prefill=r.gpus_per_prefill, - gpus_per_decode=r.gpus_per_decode, - gpus_per_agg=r.gpus_per_agg, - gpus_per_node=r.gpus_per_node, - available_nodes=self.RACK.nodes(), - ) + with ( + patch.dict(os.environ, self.RACK.slurm_env(), clear=False), + patch("subprocess.run", side_effect=self.RACK.mock_scontrol()), + ): + config = 
load_config(str(recipe_path)) + r = config.resources + + endpoints = config.backend.allocate_endpoints( + num_prefill=r.num_prefill, + num_decode=r.num_decode, + num_agg=r.num_agg, + gpus_per_prefill=r.gpus_per_prefill, + gpus_per_decode=r.gpus_per_decode, + gpus_per_agg=r.gpus_per_agg, + gpus_per_node=r.gpus_per_node, + available_nodes=self.RACK.nodes(), + ) - prefill_eps = [e for e in endpoints if e.mode == "prefill"] - decode_eps = [e for e in endpoints if e.mode == "decode"] + prefill_eps = [e for e in endpoints if e.mode == "prefill"] + decode_eps = [e for e in endpoints if e.mode == "decode"] - assert len(prefill_eps) == r.num_prefill - assert len(decode_eps) == r.num_decode + assert len(prefill_eps) == r.num_prefill + assert len(decode_eps) == r.num_decode - for ep in prefill_eps: - assert ep.total_gpus == r.gpus_per_prefill, ( - f"prefill endpoint {ep.index} has {ep.total_gpus} GPUs, expected {r.gpus_per_prefill}" - ) + for ep in prefill_eps: + assert ( + ep.total_gpus == r.gpus_per_prefill + ), f"prefill endpoint {ep.index} has {ep.total_gpus} GPUs, expected {r.gpus_per_prefill}" - for ep in decode_eps: - assert ep.total_gpus == r.gpus_per_decode, ( - f"decode endpoint {ep.index} has {ep.total_gpus} GPUs, expected {r.gpus_per_decode}" - ) + for ep in decode_eps: + assert ( + ep.total_gpus == r.gpus_per_decode + ), f"decode endpoint {ep.index} has {ep.total_gpus} GPUs, expected {r.gpus_per_decode}" class TestH100Cluster: @@ -173,21 +237,60 @@ class TestH100Cluster: @pytest.mark.parametrize("recipe_path", RECIPES, ids=lambda p: p.name) def test_gpus_per_node_is_8(self, recipe_path): """All H100 configs use 8 GPUs per node.""" - with patch.dict(os.environ, self.RACK.slurm_env(), clear=False): - with patch("subprocess.run", side_effect=self.RACK.mock_scontrol()): - config = load_config(str(recipe_path)) - assert config.resources.gpus_per_node == self.RACK.GPUS_PER_NODE, ( - f"{recipe_path.name}: expected gpus_per_node={self.RACK.GPUS_PER_NODE}, " - f"got 
{config.resources.gpus_per_node}" - ) + with ( + patch.dict(os.environ, self.RACK.slurm_env(), clear=False), + patch("subprocess.run", side_effect=self.RACK.mock_scontrol()), + ): + config = load_config(str(recipe_path)) + assert config.resources.gpus_per_node == self.RACK.GPUS_PER_NODE, ( + f"{recipe_path.name}: expected gpus_per_node={self.RACK.GPUS_PER_NODE}, " + f"got {config.resources.gpus_per_node}" + ) @pytest.mark.parametrize("recipe_path", RECIPES, ids=lambda p: p.name) def test_endpoint_allocation(self, recipe_path): """Endpoints are allocated correctly on H100 rack.""" - with patch.dict(os.environ, self.RACK.slurm_env(), clear=False): - with patch("subprocess.run", side_effect=self.RACK.mock_scontrol()): - config = load_config(str(recipe_path)) - r = config.resources + with ( + patch.dict(os.environ, self.RACK.slurm_env(), clear=False), + patch("subprocess.run", side_effect=self.RACK.mock_scontrol()), + ): + config = load_config(str(recipe_path)) + r = config.resources + + endpoints = config.backend.allocate_endpoints( + num_prefill=r.num_prefill, + num_decode=r.num_decode, + num_agg=r.num_agg, + gpus_per_prefill=r.gpus_per_prefill, + gpus_per_decode=r.gpus_per_decode, + gpus_per_agg=r.gpus_per_agg, + gpus_per_node=r.gpus_per_node, + available_nodes=self.RACK.nodes(), + ) + + prefill_eps = [e for e in endpoints if e.mode == "prefill"] + decode_eps = [e for e in endpoints if e.mode == "decode"] + + assert len(prefill_eps) == r.num_prefill + assert len(decode_eps) == r.num_decode + + for ep in prefill_eps: + assert ep.total_gpus == r.gpus_per_prefill + for ep in decode_eps: + assert ep.total_gpus == r.gpus_per_decode + + @pytest.mark.parametrize("recipe_path", RECIPES, ids=lambda p: p.name) + def test_multi_node_tp(self, recipe_path): + """H100 configs with TP > 8 span multiple nodes correctly.""" + with ( + patch.dict(os.environ, self.RACK.slurm_env(), clear=False), + patch("subprocess.run", side_effect=self.RACK.mock_scontrol()), + ): + config = 
load_config(str(recipe_path)) + r = config.resources + + if r.gpus_per_prefill > self.RACK.GPUS_PER_NODE: + expected_nodes = r.gpus_per_prefill // self.RACK.GPUS_PER_NODE endpoints = config.backend.allocate_endpoints( num_prefill=r.num_prefill, @@ -200,43 +303,10 @@ def test_endpoint_allocation(self, recipe_path): available_nodes=self.RACK.nodes(), ) - prefill_eps = [e for e in endpoints if e.mode == "prefill"] - decode_eps = [e for e in endpoints if e.mode == "decode"] - - assert len(prefill_eps) == r.num_prefill - assert len(decode_eps) == r.num_decode - - for ep in prefill_eps: - assert ep.total_gpus == r.gpus_per_prefill - for ep in decode_eps: - assert ep.total_gpus == r.gpus_per_decode - - @pytest.mark.parametrize("recipe_path", RECIPES, ids=lambda p: p.name) - def test_multi_node_tp(self, recipe_path): - """H100 configs with TP > 8 span multiple nodes correctly.""" - with patch.dict(os.environ, self.RACK.slurm_env(), clear=False): - with patch("subprocess.run", side_effect=self.RACK.mock_scontrol()): - config = load_config(str(recipe_path)) - r = config.resources - - if r.gpus_per_prefill > self.RACK.GPUS_PER_NODE: - expected_nodes = r.gpus_per_prefill // self.RACK.GPUS_PER_NODE - - endpoints = config.backend.allocate_endpoints( - num_prefill=r.num_prefill, - num_decode=r.num_decode, - num_agg=r.num_agg, - gpus_per_prefill=r.gpus_per_prefill, - gpus_per_decode=r.gpus_per_decode, - gpus_per_agg=r.gpus_per_agg, - gpus_per_node=r.gpus_per_node, - available_nodes=self.RACK.nodes(), - ) - - for ep in [e for e in endpoints if e.mode == "prefill"]: - assert ep.num_nodes == expected_nodes, ( - f"prefill endpoint should span {expected_nodes} nodes, got {ep.num_nodes}" - ) + for ep in [e for e in endpoints if e.mode == "prefill"]: + assert ( + ep.num_nodes == expected_nodes + ), f"prefill endpoint should span {expected_nodes} nodes, got {ep.num_nodes}" class TestCIConfigs: @@ -250,26 +320,28 @@ def test_agg_config(self): if not recipe_path.exists(): 
pytest.skip("agg.yaml not found") - with patch.dict(os.environ, self.RACK.slurm_env(), clear=False): - with patch("subprocess.run", side_effect=self.RACK.mock_scontrol()): - config = load_config(str(recipe_path)) - r = config.resources - - endpoints = config.backend.allocate_endpoints( - num_prefill=r.num_prefill, - num_decode=r.num_decode, - num_agg=r.num_agg, - gpus_per_prefill=r.gpus_per_prefill, - gpus_per_decode=r.gpus_per_decode, - gpus_per_agg=r.gpus_per_agg, - gpus_per_node=r.gpus_per_node, - available_nodes=self.RACK.nodes(), - ) + with ( + patch.dict(os.environ, self.RACK.slurm_env(), clear=False), + patch("subprocess.run", side_effect=self.RACK.mock_scontrol()), + ): + config = load_config(str(recipe_path)) + r = config.resources + + endpoints = config.backend.allocate_endpoints( + num_prefill=r.num_prefill, + num_decode=r.num_decode, + num_agg=r.num_agg, + gpus_per_prefill=r.gpus_per_prefill, + gpus_per_decode=r.gpus_per_decode, + gpus_per_agg=r.gpus_per_agg, + gpus_per_node=r.gpus_per_node, + available_nodes=self.RACK.nodes(), + ) - agg_eps = [e for e in endpoints if e.mode == "agg"] - assert len(agg_eps) == r.num_agg - for ep in agg_eps: - assert ep.total_gpus == r.gpus_per_agg + agg_eps = [e for e in endpoints if e.mode == "agg"] + assert len(agg_eps) == r.num_agg + for ep in agg_eps: + assert ep.total_gpus == r.gpus_per_agg def test_disagg_config(self): """Disaggregated CI config allocates correctly.""" @@ -277,32 +349,34 @@ def test_disagg_config(self): if not recipe_path.exists(): pytest.skip("disagg.yaml not found") - with patch.dict(os.environ, self.RACK.slurm_env(), clear=False): - with patch("subprocess.run", side_effect=self.RACK.mock_scontrol()): - config = load_config(str(recipe_path)) - r = config.resources - - endpoints = config.backend.allocate_endpoints( - num_prefill=r.num_prefill, - num_decode=r.num_decode, - num_agg=r.num_agg, - gpus_per_prefill=r.gpus_per_prefill, - gpus_per_decode=r.gpus_per_decode, - gpus_per_agg=r.gpus_per_agg, - 
gpus_per_node=r.gpus_per_node, - available_nodes=self.RACK.nodes(), - ) + with ( + patch.dict(os.environ, self.RACK.slurm_env(), clear=False), + patch("subprocess.run", side_effect=self.RACK.mock_scontrol()), + ): + config = load_config(str(recipe_path)) + r = config.resources + + endpoints = config.backend.allocate_endpoints( + num_prefill=r.num_prefill, + num_decode=r.num_decode, + num_agg=r.num_agg, + gpus_per_prefill=r.gpus_per_prefill, + gpus_per_decode=r.gpus_per_decode, + gpus_per_agg=r.gpus_per_agg, + gpus_per_node=r.gpus_per_node, + available_nodes=self.RACK.nodes(), + ) - prefill_eps = [e for e in endpoints if e.mode == "prefill"] - decode_eps = [e for e in endpoints if e.mode == "decode"] + prefill_eps = [e for e in endpoints if e.mode == "prefill"] + decode_eps = [e for e in endpoints if e.mode == "decode"] - assert len(prefill_eps) == r.num_prefill - assert len(decode_eps) == r.num_decode + assert len(prefill_eps) == r.num_prefill + assert len(decode_eps) == r.num_decode - for ep in prefill_eps: - assert ep.total_gpus == r.gpus_per_prefill - for ep in decode_eps: - assert ep.total_gpus == r.gpus_per_decode + for ep in prefill_eps: + assert ep.total_gpus == r.gpus_per_prefill + for ep in decode_eps: + assert ep.total_gpus == r.gpus_per_decode class TestQwen32BCluster: @@ -314,11 +388,13 @@ class TestQwen32BCluster: @pytest.mark.parametrize("recipe_path", RECIPES, ids=lambda p: p.name) def test_config_loads(self, recipe_path): """Qwen3-32B configs load correctly.""" - with patch.dict(os.environ, self.RACK.slurm_env(), clear=False): - with patch("subprocess.run", side_effect=self.RACK.mock_scontrol()): - config = load_config(str(recipe_path)) - assert config.name is not None - assert config.resources.gpus_per_node == 8 + with ( + patch.dict(os.environ, self.RACK.slurm_env(), clear=False), + patch("subprocess.run", side_effect=self.RACK.mock_scontrol()), + ): + config = load_config(str(recipe_path)) + assert config.name is not None + assert 
config.resources.gpus_per_node == 8 def test_disagg_kv_router_shared_node_allocation(self): """disagg-kv-sglang.yaml: 6P+2D on 2 nodes with decode_nodes=0.""" @@ -326,58 +402,60 @@ def test_disagg_kv_router_shared_node_allocation(self): if not recipe_path.exists(): pytest.skip("disagg-kv-sglang.yaml not found") - with patch.dict(os.environ, self.RACK.slurm_env(), clear=False): - with patch("subprocess.run", side_effect=self.RACK.mock_scontrol()): - config = load_config(str(recipe_path)) - r = config.resources - - # Verify decode_nodes=0 triggers inheritance from prefill - assert r.decode_nodes == 0, "decode_nodes should be 0" - assert r.gpus_per_prefill == 2, "prefill TP should be 2" - assert r.gpus_per_decode == 2, "decode TP should inherit 2 from prefill" - - # Allocate endpoints - nodes = self.RACK.nodes()[:2] - endpoints = allocate_endpoints( - num_prefill=r.num_prefill, - num_decode=r.num_decode, - num_agg=0, - gpus_per_prefill=r.gpus_per_prefill, - gpus_per_decode=r.gpus_per_decode, - gpus_per_agg=8, - gpus_per_node=r.gpus_per_node, - available_nodes=nodes, - ) + with ( + patch.dict(os.environ, self.RACK.slurm_env(), clear=False), + patch("subprocess.run", side_effect=self.RACK.mock_scontrol()), + ): + config = load_config(str(recipe_path)) + r = config.resources + + # Verify decode_nodes=0 triggers inheritance from prefill + assert r.decode_nodes == 0, "decode_nodes should be 0" + assert r.gpus_per_prefill == 2, "prefill TP should be 2" + assert r.gpus_per_decode == 2, "decode TP should inherit 2 from prefill" + + # Allocate endpoints + nodes = self.RACK.nodes()[:2] + endpoints = allocate_endpoints( + num_prefill=r.num_prefill, + num_decode=r.num_decode, + num_agg=0, + gpus_per_prefill=r.gpus_per_prefill, + gpus_per_decode=r.gpus_per_decode, + gpus_per_agg=8, + gpus_per_node=r.gpus_per_node, + available_nodes=nodes, + ) - prefill_eps = [e for e in endpoints if e.mode == "prefill"] - decode_eps = [e for e in endpoints if e.mode == "decode"] + prefill_eps = [e 
for e in endpoints if e.mode == "prefill"] + decode_eps = [e for e in endpoints if e.mode == "decode"] - assert len(prefill_eps) == 6 - assert len(decode_eps) == 2 + assert len(prefill_eps) == 6 + assert len(decode_eps) == 2 - # Check prefill allocation: first 4 on node0, next 2 on node1 - for i, ep in enumerate(prefill_eps[:4]): - assert ep.nodes[0] == nodes[0], f"prefill {i} should be on node0" - for i, ep in enumerate(prefill_eps[4:]): - assert ep.nodes[0] == nodes[1], f"prefill {i + 4} should be on node1" + # Check prefill allocation: first 4 on node0, next 2 on node1 + for i, ep in enumerate(prefill_eps[:4]): + assert ep.nodes[0] == nodes[0], f"prefill {i} should be on node0" + for i, ep in enumerate(prefill_eps[4:]): + assert ep.nodes[0] == nodes[1], f"prefill {i + 4} should be on node1" - # Check decode allocation: on node1 (GPUs 4-5, 6-7) - for ep in decode_eps: - assert ep.nodes[0] == nodes[1], "decode should be on node1" + # Check decode allocation: on node1 (GPUs 4-5, 6-7) + for ep in decode_eps: + assert ep.nodes[0] == nodes[1], "decode should be on node1" - # Verify GPU indices don't overlap on shared node (node1) - node1_prefill_gpus = set() - for ep in prefill_eps: - if ep.nodes[0] == nodes[1]: - node1_prefill_gpus.update(ep.gpu_indices) + # Verify GPU indices don't overlap on shared node (node1) + node1_prefill_gpus = set() + for ep in prefill_eps: + if ep.nodes[0] == nodes[1]: + node1_prefill_gpus.update(ep.gpu_indices) - node1_decode_gpus = set() - for ep in decode_eps: - node1_decode_gpus.update(ep.gpu_indices) + node1_decode_gpus = set() + for ep in decode_eps: + node1_decode_gpus.update(ep.gpu_indices) - assert node1_prefill_gpus.isdisjoint(node1_decode_gpus), ( - f"GPU overlap on node1! prefill uses {node1_prefill_gpus}, decode uses {node1_decode_gpus}" - ) + assert node1_prefill_gpus.isdisjoint( + node1_decode_gpus + ), f"GPU overlap on node1! 
prefill uses {node1_prefill_gpus}, decode uses {node1_decode_gpus}" def test_disagg_kv_router_cuda_visible_devices(self): """Processes on shared node have non-overlapping CUDA_VISIBLE_DEVICES.""" @@ -385,49 +463,58 @@ def test_disagg_kv_router_cuda_visible_devices(self): if not recipe_path.exists(): pytest.skip("disagg-kv-sglang.yaml not found") - with patch.dict(os.environ, self.RACK.slurm_env(), clear=False): - with patch("subprocess.run", side_effect=self.RACK.mock_scontrol()): - config = load_config(str(recipe_path)) - r = config.resources - - nodes = self.RACK.nodes()[:2] - endpoints = allocate_endpoints( - num_prefill=r.num_prefill, - num_decode=r.num_decode, - num_agg=0, - gpus_per_prefill=r.gpus_per_prefill, - gpus_per_decode=r.gpus_per_decode, - gpus_per_agg=8, - gpus_per_node=r.gpus_per_node, - available_nodes=nodes, - ) - - processes = endpoints_to_processes(endpoints) - - # Group processes by node - node1_processes = [p for p in processes if p.node == nodes[1]] - - # Should have 2 prefill + 2 decode = 4 processes on node1 - assert len(node1_processes) == 4, f"Expected 4 processes on node1, got {len(node1_processes)}" - - # Each process should have unique, non-overlapping GPU indices - all_gpus_on_node1 = set() - for proc in node1_processes: - for gpu in proc.gpu_indices: - assert gpu not in all_gpus_on_node1, f"GPU {gpu} assigned to multiple processes on {nodes[1]}!" 
- all_gpus_on_node1.add(gpu) - - # All 8 GPUs on node1 should be used - assert all_gpus_on_node1 == {0, 1, 2, 3, 4, 5, 6, 7}, ( - f"Expected all 8 GPUs used on node1, got {all_gpus_on_node1}" - ) + with ( + patch.dict(os.environ, self.RACK.slurm_env(), clear=False), + patch("subprocess.run", side_effect=self.RACK.mock_scontrol()), + ): + config = load_config(str(recipe_path)) + r = config.resources + + nodes = self.RACK.nodes()[:2] + endpoints = allocate_endpoints( + num_prefill=r.num_prefill, + num_decode=r.num_decode, + num_agg=0, + gpus_per_prefill=r.gpus_per_prefill, + gpus_per_decode=r.gpus_per_decode, + gpus_per_agg=8, + gpus_per_node=r.gpus_per_node, + available_nodes=nodes, + ) - # Verify CUDA_VISIBLE_DEVICES strings are correct - for proc in node1_processes: - cvd = proc.cuda_visible_devices - expected_gpus = sorted(proc.gpu_indices) - expected_cvd = ",".join(str(g) for g in expected_gpus) - assert cvd == expected_cvd, f"Expected CUDA_VISIBLE_DEVICES={expected_cvd}, got {cvd}" + processes = endpoints_to_processes(endpoints) + + # Group processes by node + node1_processes = [p for p in processes if p.node == nodes[1]] + + # Should have 2 prefill + 2 decode = 4 processes on node1 + assert len(node1_processes) == 4, f"Expected 4 processes on node1, got {len(node1_processes)}" + + # Each process should have unique, non-overlapping GPU indices + all_gpus_on_node1 = set() + for proc in node1_processes: + for gpu in proc.gpu_indices: + assert gpu not in all_gpus_on_node1, f"GPU {gpu} assigned to multiple processes on {nodes[1]}!" 
+ all_gpus_on_node1.add(gpu) + + # All 8 GPUs on node1 should be used + assert all_gpus_on_node1 == { + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + }, f"Expected all 8 GPUs used on node1, got {all_gpus_on_node1}" + + # Verify CUDA_VISIBLE_DEVICES strings are correct + for proc in node1_processes: + cvd = proc.cuda_visible_devices + expected_gpus = sorted(proc.gpu_indices) + expected_cvd = ",".join(str(g) for g in expected_gpus) + assert cvd == expected_cvd, f"Expected CUDA_VISIBLE_DEVICES={expected_cvd}, got {cvd}" def test_disagg_kv_router_total_allocation_fits(self): """Total GPU allocation fits within declared nodes.""" @@ -435,18 +522,20 @@ def test_disagg_kv_router_total_allocation_fits(self): if not recipe_path.exists(): pytest.skip("disagg-kv-sglang.yaml not found") - with patch.dict(os.environ, self.RACK.slurm_env(), clear=False): - with patch("subprocess.run", side_effect=self.RACK.mock_scontrol()): - config = load_config(str(recipe_path)) - r = config.resources + with ( + patch.dict(os.environ, self.RACK.slurm_env(), clear=False), + patch("subprocess.run", side_effect=self.RACK.mock_scontrol()), + ): + config = load_config(str(recipe_path)) + r = config.resources - total_gpus_needed = r.num_prefill * r.gpus_per_prefill + r.num_decode * r.gpus_per_decode - total_gpus_available = r.total_nodes * r.gpus_per_node + total_gpus_needed = r.num_prefill * r.gpus_per_prefill + r.num_decode * r.gpus_per_decode + total_gpus_available = r.total_nodes * r.gpus_per_node - assert total_gpus_needed <= total_gpus_available, ( - f"Need {total_gpus_needed} GPUs but only have {total_gpus_available} " - f"({r.total_nodes} nodes × {r.gpus_per_node} GPUs)" - ) + assert total_gpus_needed <= total_gpus_available, ( + f"Need {total_gpus_needed} GPUs but only have {total_gpus_available} " + f"({r.total_nodes} nodes × {r.gpus_per_node} GPUs)" + ) class TestMooncakeKVStore: @@ -678,3 +767,48 @@ def test_mooncake_kv_store_no_container(self): assert config.backend.mooncake_kv_store is not 
None assert config.backend.mooncake_kv_store.container is None assert config.backend.mooncake_kv_store.env["MOONCAKE_PROTOCOL"] == "rdma" + + +class TestGB200HetAsymmetric: + """End-to-end test of het-job nodelist parsing + endpoint allocation.""" + + def test_nodes_carves_into_two_components(self): + from srtctl.core.runtime import Nodes + + with patch.dict(os.environ, GB200HetRack.slurm_env()), patch("subprocess.run", GB200HetRack.mock_scontrol()): + nodes = Nodes.from_slurm(etcd_nats_dedicated_node=False) + + assert nodes.het is True + assert len(nodes.prefill_group) == GB200HetRack.PREFILL_NODES + assert len(nodes.decode_group) == GB200HetRack.DECODE_NODES + # Worker pool is the concatenation + assert len(nodes.worker) == GB200HetRack.PREFILL_NODES + GB200HetRack.DECODE_NODES + + def test_endpoint_allocation_respects_group_isolation(self): + from srtctl.core.runtime import Nodes + from srtctl.core.topology import allocate_endpoints_het + + with patch.dict(os.environ, GB200HetRack.slurm_env()), patch("subprocess.run", GB200HetRack.mock_scontrol()): + nodes = Nodes.from_slurm(etcd_nats_dedicated_node=False) + + # 12 prefill workers at TP4 (1 node each) + 10 decode workers at TP4 + endpoints = allocate_endpoints_het( + num_prefill=12, + gpus_per_prefill=4, + prefill_nodes=nodes.prefill_group, + num_decode=10, + gpus_per_decode=4, + decode_nodes=nodes.decode_group, + gpus_per_node=GB200HetRack.GPUS_PER_NODE, + ) + prefill_eps = [e for e in endpoints if e.mode == "prefill"] + decode_eps = [e for e in endpoints if e.mode == "decode"] + assert len(prefill_eps) == 12 + assert len(decode_eps) == 10 + # No prefill worker on a decode node + for ep in prefill_eps: + assert all(n in nodes.prefill_group for n in ep.nodes) + assert ep.het_group == 0 + for ep in decode_eps: + assert all(n in nodes.decode_group for n in ep.nodes) + assert ep.het_group == 1 diff --git a/tests/test_endpoint_allocation.py b/tests/test_endpoint_allocation.py index 0b8afd3c..5b362e40 100644 --- 
a/tests/test_endpoint_allocation.py +++ b/tests/test_endpoint_allocation.py @@ -8,6 +8,7 @@ from srtctl.core.topology import ( NodePortAllocator, allocate_endpoints, + allocate_endpoints_het, endpoints_to_processes, ) from srtctl.ports import ( @@ -462,3 +463,89 @@ def test_endpoints_to_processes_uses_default_sys_port(self): processes = endpoints_to_processes(endpoints) assert [p.sys_port for p in processes] == [DYN_SYSTEM_PORT_BASE, DYN_SYSTEM_PORT_BASE + 1] + + +class TestAllocateEndpointsHet: + """Per-side heterogeneous-job allocation.""" + + def test_prefill_and_decode_isolated_to_own_pools(self): + """Prefill workers only land on prefill_nodes; decode only on decode_nodes.""" + # Asymmetric: 12 prefill nodes + 10 decode nodes (the 48+40 case) + prefill_nodes = tuple(f"p-{i:02d}" for i in range(12)) + decode_nodes = tuple(f"d-{i:02d}" for i in range(10)) + + endpoints = allocate_endpoints_het( + num_prefill=12, + gpus_per_prefill=4, + prefill_nodes=prefill_nodes, + num_decode=10, + gpus_per_decode=4, + decode_nodes=decode_nodes, + gpus_per_node=4, + ) + + prefill_eps = [e for e in endpoints if e.mode == "prefill"] + decode_eps = [e for e in endpoints if e.mode == "decode"] + assert len(prefill_eps) == 12 + assert len(decode_eps) == 10 + + # Side isolation: no prefill endpoint lands on a decode node, and vice versa. 
+ for ep in prefill_eps: + for node in ep.nodes: + assert node in prefill_nodes, f"prefill ep on decode node {node}" + for ep in decode_eps: + for node in ep.nodes: + assert node in decode_nodes, f"decode ep on prefill node {node}" + + def test_het_group_tagged_on_endpoints(self): + prefill_nodes = ("p0", "p1") + decode_nodes = ("d0", "d1") + endpoints = allocate_endpoints_het( + num_prefill=2, + gpus_per_prefill=4, + prefill_nodes=prefill_nodes, + num_decode=2, + gpus_per_decode=4, + decode_nodes=decode_nodes, + gpus_per_node=4, + ) + for ep in endpoints: + if ep.mode == "prefill": + assert ep.het_group == 0 + elif ep.mode == "decode": + assert ep.het_group == 1 + + def test_het_group_propagates_to_processes(self): + endpoints = allocate_endpoints_het( + num_prefill=1, + gpus_per_prefill=4, + prefill_nodes=("p0",), + num_decode=1, + gpus_per_decode=4, + decode_nodes=("d0",), + gpus_per_node=4, + ) + processes = endpoints_to_processes(endpoints) + for proc in processes: + if proc.endpoint_mode == "prefill": + assert proc.het_group == 0 + elif proc.endpoint_mode == "decode": + assert proc.het_group == 1 + + def test_multi_node_prefill_worker_stays_in_prefill_pool(self): + # Single prefill worker with TP8 (2 nodes) — confirm it pulls from prefill_nodes only. 
+ prefill_nodes = ("p0", "p1", "p2", "p3") + decode_nodes = ("d0", "d1") + endpoints = allocate_endpoints_het( + num_prefill=1, + gpus_per_prefill=8, + prefill_nodes=prefill_nodes, + num_decode=2, + gpus_per_decode=4, + decode_nodes=decode_nodes, + gpus_per_node=4, + ) + prefill_ep = next(e for e in endpoints if e.mode == "prefill") + assert len(prefill_ep.nodes) == 2 + for node in prefill_ep.nodes: + assert node in prefill_nodes diff --git a/tests/test_slurm.py b/tests/test_slurm.py index 6dcc20cc..789c449b 100644 --- a/tests/test_slurm.py +++ b/tests/test_slurm.py @@ -10,7 +10,7 @@ from srtctl.cli.mixins.worker_stage import WorkerStageMixin from srtctl.core.schema import ObservabilityConfig -from srtctl.core.slurm import start_srun_process +from srtctl.core.slurm import get_slurm_het_nodelists, start_srun_process def _built_bash_command(mock_popen: MagicMock) -> str: @@ -157,6 +157,7 @@ def test_worker_stage_wraps_nonfatal_fingerprint_hook(tmp_path: Path) -> None: sys_port=5000, gpu_indices=list(range(8)), cuda_visible_devices="0,1,2,3,4,5,6,7", + het_group=None, ) with ( @@ -170,3 +171,72 @@ def test_worker_stage_wraps_nonfatal_fingerprint_hook(tmp_path: Path) -> None: assert "setup.sh" in bash_preamble assert "/configs/patches/${setup_script}" in bash_preamble assert bash_preamble.endswith("&& ( fingerprint || true )") + + +# ---- Heterogeneous-job nodelist parsing ---- + + +def test_get_slurm_het_nodelists_returns_none_without_het_size() -> None: + with patch.dict("os.environ", {}, clear=False): + # Make sure SLURM_HET_SIZE is unset + import os + + os.environ.pop("SLURM_HET_SIZE", None) + assert get_slurm_het_nodelists() is None + + +def test_get_slurm_het_nodelists_returns_none_for_size_one() -> None: + with patch.dict("os.environ", {"SLURM_HET_SIZE": "1"}): + assert get_slurm_het_nodelists() is None + + +def test_get_slurm_het_nodelists_expands_two_groups() -> None: + env = { + "SLURM_HET_SIZE": "2", + "SLURM_JOB_NODELIST_HET_GROUP_0": "gb200-[01-03]", + 
"SLURM_JOB_NODELIST_HET_GROUP_1": "gb200-[04-05]", + } + + def mock_run(cmd, **kwargs): + result = MagicMock() + # cmd[-1] is the raw nodelist passed to `scontrol show hostnames` + nodelist_raw = cmd[-1] + if nodelist_raw == "gb200-[01-03]": + result.stdout = "gb200-01\ngb200-02\ngb200-03\n" + elif nodelist_raw == "gb200-[04-05]": + result.stdout = "gb200-04\ngb200-05\n" + else: + raise AssertionError(f"unexpected nodelist {nodelist_raw}") + result.returncode = 0 + return result + + with patch.dict("os.environ", env), patch("subprocess.run", side_effect=mock_run): + groups = get_slurm_het_nodelists() + assert groups == [["gb200-01", "gb200-02", "gb200-03"], ["gb200-04", "gb200-05"]] + + +def test_start_srun_emits_het_group_flag() -> None: + with ( + patch("srtctl.core.slurm.get_slurm_job_id", return_value="12345"), + patch("srtctl.core.slurm._get_cluster_bash_preamble", return_value=None), + patch("subprocess.Popen") as mock_popen, + ): + mock_popen.return_value = MagicMock() + start_srun_process(["echo", "hi"], het_group=1) + + srun_cmd = mock_popen.call_args.args[0] + assert "--het-group=1" in srun_cmd + + +def test_start_srun_omits_het_group_when_none() -> None: + with ( + patch("srtctl.core.slurm.get_slurm_job_id", return_value="12345"), + patch("srtctl.core.slurm._get_cluster_bash_preamble", return_value=None), + patch("subprocess.Popen") as mock_popen, + ): + mock_popen.return_value = MagicMock() + start_srun_process(["echo", "hi"]) # default het_group=None + + srun_cmd = mock_popen.call_args.args[0] + for arg in srun_cmd: + assert not str(arg).startswith("--het-group") diff --git a/tests/test_telemetry.py b/tests/test_telemetry.py index 8ef7063b..84bed44c 100644 --- a/tests/test_telemetry.py +++ b/tests/test_telemetry.py @@ -127,6 +127,7 @@ def __init__(self): self.runtime = MagicMock() self.runtime.log_dir = tmp_path self.runtime.nodes.head = "node-a" + self.runtime.nodes.het = False self.runtime.srun_options = {} self.runtime.container_mounts = 
{Path(tmp_path): Path("/logs")} self._backend_processes = [