Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion src/sparkrun/cli/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@
@click.option("--foreground", is_flag=True, help="Run in foreground (don't detach)")
@click.option("--no-follow", is_flag=True, help="Don't follow container logs after launch")
@click.option("--no-sync-tuning", is_flag=True, help="Skip syncing tuning configs from registries")
@click.option("--no-rm", is_flag=True, help="Don't auto-remove containers on exit (keeps containers after stop)")
@click.option("--restart", "restart_policy", default=None,
help="Docker restart policy (no, always, unless-stopped, on-failure[:N])")
@click.option("--transfer-mode", default=None,
type=click.Choice(["auto", "local", "push", "delegated"], case_sensitive=False),
help="Resource transfer mode (overrides cluster setting)")
Expand All @@ -52,7 +55,7 @@ def run(
ctx, recipe_name, hosts, hosts_file, cluster_name, solo, port, tensor_parallel,
pipeline_parallel, gpu_mem, served_model_name, max_model_len, image, cache_dir,
ray_port, init_port, dashboard, dashboard_port,
dry_run, foreground, no_follow, no_sync_tuning, transfer_mode,
dry_run, foreground, no_follow, no_sync_tuning, no_rm, restart_policy, transfer_mode,
options, extra_args, config_path=None, setup=True,
):
"""Run an inference recipe.
Expand Down Expand Up @@ -222,6 +225,11 @@ def run(
click.echo(" Workers: %s" % ", ".join(host_list[1:]))
click.echo()

# Resolve container lifecycle options
auto_remove = not no_rm
if restart_policy:
auto_remove = False

# Launch via shared pipeline
result = launch_inference(
recipe=recipe,
Expand All @@ -242,6 +250,8 @@ def run(
dashboard_port=dashboard_port,
dashboard=dashboard,
init_port=init_port,
auto_remove=auto_remove,
restart_policy=restart_policy,
)

click.echo("Cluster: %s" % result.cluster_id)
Expand Down
5 changes: 5 additions & 0 deletions src/sparkrun/core/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ def launch_inference(
dashboard_port: int | None = None,
dashboard: bool = False,
init_port: int | None = None,
# Container lifecycle options
auto_remove: bool = True,
restart_policy: str | None = None,
) -> LaunchResult:
"""Launch an inference workload.
Expand Down Expand Up @@ -285,6 +288,8 @@ def launch_inference(
nccl_env=nccl_env,
ib_ip_map=ib_ip_map,
skip_keys=skip_keys,
auto_remove=auto_remove,
restart_policy=restart_policy,
**run_kwargs,
)

Expand Down
17 changes: 16 additions & 1 deletion src/sparkrun/orchestration/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
_DEFAULT_DOCKER_OPTS = [
"--privileged",
"--gpus all",
"--rm",
"--ipc=host",
"--shm-size=10.24gb",
"--network host",
Expand All @@ -30,6 +29,8 @@ def docker_run_cmd(
env: dict[str, str] | None = None,
volumes: dict[str, str] | None = None,
extra_opts: list[str] | None = None,
auto_remove: bool = True,
restart_policy: str | None = None,
) -> str:
"""Generate a ``docker run`` command string.

Expand All @@ -41,17 +42,31 @@ def docker_run_cmd(
env: Environment variables to set (``-e KEY=VALUE``).
volumes: Volume mounts (``-v host:container``).
extra_opts: Additional docker run options.
auto_remove: Add ``--rm`` flag (default True). Forced to False
when *restart_policy* is set (Docker does not allow both).
restart_policy: Docker restart policy (e.g. ``always``,
``unless-stopped``, ``on-failure:3``).

Returns:
Complete ``docker run`` command string.
"""
# Docker does not allow --rm with --restart
if restart_policy:
auto_remove = False

parts = ["docker", "run"]

if detach:
parts.append("-d")

parts.extend(_DEFAULT_DOCKER_OPTS)

if auto_remove:
parts.append("--rm")

if restart_policy:
parts.extend(["--restart", restart_policy])

if container_name:
parts.extend(["--name", container_name])

Expand Down
12 changes: 12 additions & 0 deletions src/sparkrun/orchestration/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def generate_container_launch_script(
nccl_env: dict[str, str] | None = None,
detach: bool = True,
extra_docker_opts: list[str] | None = None,
auto_remove: bool = True,
restart_policy: str | None = None,
) -> str:
"""Generate a script that launches a Docker container.

Expand Down Expand Up @@ -68,6 +70,8 @@ def generate_container_launch_script(
env=all_env,
volumes=volumes,
extra_opts=extra_docker_opts,
auto_remove=auto_remove,
restart_policy=restart_policy,
)

template = read_script("container_launch.sh")
Expand All @@ -88,6 +92,8 @@ def generate_ray_head_script(
env: dict[str, str] | None = None,
volumes: dict[str, str] | None = None,
nccl_env: dict[str, str] | None = None,
auto_remove: bool = True,
restart_policy: str | None = None,
) -> str:
"""Generate a script that starts a Ray head node in a container.

Expand Down Expand Up @@ -135,6 +141,8 @@ def generate_ray_head_script(
detach=True,
env=all_env,
volumes=volumes,
auto_remove=auto_remove,
restart_policy=restart_policy,
)

template = read_script("ray_head.sh")
Expand All @@ -152,6 +160,8 @@ def generate_ray_worker_script(
env: dict[str, str] | None = None,
volumes: dict[str, str] | None = None,
nccl_env: dict[str, str] | None = None,
auto_remove: bool = True,
restart_policy: str | None = None,
) -> str:
"""Generate a script that starts a Ray worker node.

Expand Down Expand Up @@ -185,6 +195,8 @@ def generate_ray_worker_script(
detach=True,
env=all_env,
volumes=volumes,
auto_remove=auto_remove,
restart_policy=restart_policy,
)

template = read_script("ray_worker.sh")
Expand Down
14 changes: 14 additions & 0 deletions src/sparkrun/runtimes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,8 @@ def run(
nccl_env: dict[str, str] | None = None,
ib_ip_map: dict[str, str] | None = None,
skip_keys: set[str] | frozenset[str] = frozenset(),
auto_remove: bool = True,
restart_policy: str | None = None,
**kwargs,
) -> int:
"""Launch a workload -- delegates to solo or cluster implementation.
Expand Down Expand Up @@ -569,6 +571,8 @@ def run(
nccl_env=nccl_env,
recipe=recipe,
overrides=overrides,
auto_remove=auto_remove,
restart_policy=restart_policy,
)
return self._run_cluster(
hosts=hosts,
Expand All @@ -585,6 +589,8 @@ def run(
nccl_env=nccl_env,
ib_ip_map=ib_ip_map,
skip_keys=skip_keys,
auto_remove=auto_remove,
restart_policy=restart_policy,
**kwargs,
)

Expand Down Expand Up @@ -670,6 +676,8 @@ def _run_solo(
nccl_env: dict[str, str] | None = None,
recipe: Recipe | None = None,
overrides: dict[str, Any] | None = None,
auto_remove: bool = True,
restart_policy: str | None = None,
) -> int:
"""Launch a single-node inference workload.

Expand Down Expand Up @@ -728,6 +736,8 @@ def _run_solo(
volumes=volumes,
nccl_env=nccl_env,
extra_docker_opts=self.get_extra_docker_opts() or None,
auto_remove=auto_remove,
restart_policy=restart_policy,
)
result = run_script_on_host(
host, launch_script, ssh_kwargs=ssh_kwargs, timeout=120, dry_run=dry_run,
Expand Down Expand Up @@ -809,6 +819,8 @@ def _generate_node_script(
volumes: dict[str, str] | None = None,
nccl_env: dict[str, str] | None = None,
extra_docker_opts: list[str] | None = None,
auto_remove: bool = True,
restart_policy: str | None = None,
) -> str:
"""Generate a script that launches a container with a direct entrypoint command.

Expand Down Expand Up @@ -842,6 +854,8 @@ def _generate_node_script(
env=all_env,
volumes=volumes,
extra_opts=extra_docker_opts,
auto_remove=auto_remove,
restart_policy=restart_policy,
)

return (
Expand Down
4 changes: 4 additions & 0 deletions src/sparkrun/runtimes/llama_cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,8 @@ def _run_cluster(
ib_ip_map: dict[str, str] | None = None,
rpc_port: int = _DEFAULT_RPC_PORT,
skip_keys: set[str] | frozenset[str] = frozenset(),
auto_remove: bool = True,
restart_policy: str | None = None,
**kwargs,
) -> int:
"""Orchestrate a multi-node llama.cpp cluster using RPC.
Expand Down Expand Up @@ -454,6 +456,7 @@ def _run_cluster(
image=image, container_name=worker_container_name,
serve_command=rpc_worker_command, label="llama.cpp node",
env=all_env, volumes=volumes, nccl_env=nccl_env,
auto_remove=auto_remove, restart_policy=restart_policy,
)
future = executor.submit(
run_remote_script, host, script,
Expand Down Expand Up @@ -512,6 +515,7 @@ def _run_cluster(
image=image, container_name=head_container,
serve_command=head_command, label="llama.cpp node",
env=all_env, volumes=volumes, nccl_env=nccl_env,
auto_remove=auto_remove, restart_policy=restart_policy,
)
head_result = run_remote_script(
head_host, head_script, timeout=120, dry_run=dry_run, **ssh_kwargs,
Expand Down
4 changes: 4 additions & 0 deletions src/sparkrun/runtimes/sglang.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,8 @@ def _run_cluster(
nccl_env: dict[str, str] | None = None,
init_port: int = 25000,
skip_keys: set[str] | frozenset[str] = frozenset(),
auto_remove: bool = True,
restart_policy: str | None = None,
**kwargs,
) -> int:
"""Orchestrate a multi-node SGLang cluster using native distribution.
Expand Down Expand Up @@ -374,6 +376,7 @@ def _run_cluster(
image=image, container_name=head_container,
serve_command=head_command, label="sglang node",
env=all_env, volumes=volumes, nccl_env=nccl_env,
auto_remove=auto_remove, restart_policy=restart_policy,
)
head_result = run_remote_script(
head_host, head_script, timeout=120, dry_run=dry_run, **ssh_kwargs,
Expand Down Expand Up @@ -439,6 +442,7 @@ def _run_cluster(
image=image, container_name=worker_container,
serve_command=worker_command, label="sglang node",
env=all_env, volumes=volumes, nccl_env=nccl_env,
auto_remove=auto_remove, restart_policy=restart_policy,
)
future = executor.submit(
run_remote_script, host, worker_script,
Expand Down
4 changes: 4 additions & 0 deletions src/sparkrun/runtimes/trtllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,8 @@ def _run_cluster(
detached: bool = True,
nccl_env: dict[str, str] | None = None,
skip_keys: set[str] | frozenset[str] = frozenset(),
auto_remove: bool = True,
restart_policy: str | None = None,
**kwargs,
) -> int:
"""Orchestrate a multi-node TRT-LLM cluster using MPI.
Expand Down Expand Up @@ -484,6 +486,8 @@ def _run_cluster(
volumes=volumes,
nccl_env=nccl_env,
extra_docker_opts=extra_docker_opts or None,
auto_remove=auto_remove,
restart_policy=restart_policy,
)
future = executor.submit(
run_remote_script, host, launch_script,
Expand Down
4 changes: 4 additions & 0 deletions src/sparkrun/runtimes/vllm_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,8 @@ def _run_cluster(
nccl_env: dict[str, str] | None = None,
init_port: int = 25000,
skip_keys: set[str] | frozenset[str] = frozenset(),
auto_remove: bool = True,
restart_policy: str | None = None,
**kwargs,
) -> int:
"""Orchestrate a multi-node vLLM cluster using native distribution.
Expand Down Expand Up @@ -307,6 +309,7 @@ def _run_cluster(
image=image, container_name=head_container,
serve_command=head_command, label="vllm node",
env=all_env, volumes=volumes, nccl_env=nccl_env,
auto_remove=auto_remove, restart_policy=restart_policy,
)
head_result = run_remote_script(
head_host, head_script, timeout=120, dry_run=dry_run, **ssh_kwargs,
Expand Down Expand Up @@ -372,6 +375,7 @@ def _run_cluster(
image=image, container_name=worker_container,
serve_command=worker_command, label="vllm node",
env=all_env, volumes=volumes, nccl_env=nccl_env,
auto_remove=auto_remove, restart_policy=restart_policy,
)
future = executor.submit(
run_remote_script, host, worker_script,
Expand Down
6 changes: 6 additions & 0 deletions src/sparkrun/runtimes/vllm_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ def _run_cluster(
ray_port: int = 46379,
dashboard_port: int = 8265,
dashboard: bool = False,
auto_remove: bool = True,
restart_policy: str | None = None,
**kwargs,
) -> int:
"""Orchestrate a multi-node Ray cluster for vLLM.
Expand Down Expand Up @@ -254,6 +256,8 @@ def _run_cluster(
env=all_env,
volumes=volumes,
nccl_env=nccl_env,
auto_remove=auto_remove,
restart_policy=restart_policy,
)
head_result = run_remote_script(
head_host, head_script, timeout=120, dry_run=dry_run, **ssh_kwargs,
Expand Down Expand Up @@ -303,6 +307,8 @@ def _run_cluster(
env=all_env,
volumes=volumes,
nccl_env=nccl_env,
auto_remove=auto_remove,
restart_policy=restart_policy,
)
worker_results = run_remote_scripts_parallel(
worker_hosts, worker_script, timeout=120, dry_run=dry_run, **ssh_kwargs,
Expand Down