Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/srtctl/cli/do_sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,13 +556,22 @@ def _run_post_eval(self, stop_event: threading.Event) -> int:
env_to_set["EVAL_CONC"] = str(max(conc_list))
logger.info("Eval concurrency (max of %s): %s", conc_list, env_to_set["EVAL_CONC"])

bash_preamble = None
if self.config.setup_script:
script_path = f"/configs/{self.config.setup_script}"
bash_preamble = (
f"echo 'Running setup script: {script_path}' && "
f"if [ -f '{script_path}' ]; then bash '{script_path}'; else echo 'WARNING: {script_path} not found'; fi"
)

proc = start_srun_process(
command=cmd,
nodelist=[self.runtime.nodes.head],
output=str(eval_log),
container_image=str(self.runtime.container_image),
container_mounts=self.runtime.container_mounts,
env_to_set=env_to_set,
bash_preamble=bash_preamble,
het_group=self.runtime.nodes.het_group_for(self.runtime.nodes.head),
)

Expand Down
15 changes: 15 additions & 0 deletions src/srtctl/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,21 @@ def resolve_config_with_defaults(user_config: dict[str, Any], cluster_config: di
model["container"] = resolved_container
logger.debug(f"Resolved container alias '{container}' -> '{resolved_container}'")

# Resolve extra_mount host path aliases through model_paths. This lets
# recipes mount secondary model assets by alias rather than cluster path.
extra_mounts = config.get("extra_mount", [])
if model_paths and extra_mounts:
resolved_mounts = []
for mount_spec in extra_mounts:
host_path, container_path = mount_spec.split(":", 1)
if host_path in model_paths:
resolved_host = model_paths[host_path]
resolved_mounts.append(f"{resolved_host}:{container_path}")
logger.debug(f"Resolved extra_mount alias '{host_path}' -> '{resolved_host}'")
else:
resolved_mounts.append(mount_spec)
config["extra_mount"] = resolved_mounts

# Apply reporting defaults (if not specified in user config)
if "reporting" not in config and cluster_config.get("reporting"):
config["reporting"] = cluster_config["reporting"]
Expand Down
30 changes: 30 additions & 0 deletions tests/test_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,36 @@ def test_successful_eval(self):
result = orch._run_post_eval(stop)
assert result == 0

def test_post_eval_setup_script_preamble(self):
"""setup_script runs before post-eval inside the eval container."""
import os
import threading
from dataclasses import replace
from unittest.mock import MagicMock, patch

orch = self._make_orchestrator()
orch.config = replace(orch.config, setup_script="install-trtllm-pip.sh")
stop = threading.Event()

mock_proc = MagicMock()
mock_proc.poll.return_value = 0
mock_proc.returncode = 0

captured_kwargs = {}

def capture_srun(**kwargs):
captured_kwargs.update(kwargs)
return mock_proc

with patch.dict(os.environ, {"EVAL_ONLY": "false"}, clear=False):
with patch("srtctl.cli.do_sweep.wait_for_port", return_value=True):
with patch("srtctl.cli.do_sweep.start_srun_process", side_effect=capture_srun):
result = orch._run_post_eval(stop)

assert result == 0
assert "Running setup script: /configs/install-trtllm-pip.sh" in captured_kwargs["bash_preamble"]
assert "bash '/configs/install-trtllm-pip.sh'" in captured_kwargs["bash_preamble"]

def test_eval_only_successful(self):
"""Returns 0 in eval-only mode when health check and eval succeed."""
import os
Expand Down
18 changes: 18 additions & 0 deletions tests/test_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,24 @@ def test_default_health_check_does_not_override_recipe(self):

assert resolved["health_check"] == {"max_attempts": 720, "interval_seconds": 10}

def test_extra_mount_model_path_alias_resolution(self):
"""extra_mount host paths can resolve through srtslurm.yaml model_paths."""
from srtctl.core.config import resolve_config_with_defaults

user_config = {
"name": "test",
"model": {"path": "/model", "container": "/c.sqsh", "precision": "fp8"},
"resources": {"gpu_type": "h100", "gpus_per_node": 8, "agg_nodes": 1},
"extra_mount": ["glm5-eagle:/eagle", "/literal/path:/literal"],
}

resolved = resolve_config_with_defaults(
user_config,
{"model_paths": {"glm5-eagle": "/models/glm5-eagle"}},
)

assert resolved["extra_mount"] == ["/models/glm5-eagle:/eagle", "/literal/path:/literal"]

def test_cluster_sbatch_directives_are_not_treated_as_defaults(self):
"""srtslurm.yaml defaults must use default_sbatch_directives explicitly."""
from srtctl.core.config import resolve_config_with_defaults
Expand Down
Loading