From f158c77bc3aa76c9b8cc0364ce3545612c22ecad Mon Sep 17 00:00:00 2001 From: Matt Date: Sat, 14 Mar 2026 15:36:17 -0400 Subject: [PATCH] feat: add --restart and --no-rm options Add Docker restart policy support (--restart) and the ability to keep containers after exit (--no-rm) to the `run` command. When a restart policy is set, --rm is automatically disabled since Docker does not allow both flags simultaneously. --- src/sparkrun/cli/_run.py | 12 +++++++++++- src/sparkrun/core/launcher.py | 5 +++++ src/sparkrun/orchestration/docker.py | 17 ++++++++++++++++- src/sparkrun/orchestration/scripts.py | 12 ++++++++++++ src/sparkrun/runtimes/base.py | 14 ++++++++++++++ src/sparkrun/runtimes/llama_cpp.py | 4 ++++ src/sparkrun/runtimes/sglang.py | 4 ++++ src/sparkrun/runtimes/trtllm.py | 4 ++++ src/sparkrun/runtimes/vllm_distributed.py | 4 ++++ src/sparkrun/runtimes/vllm_ray.py | 6 ++++++ 10 files changed, 80 insertions(+), 2 deletions(-) diff --git a/src/sparkrun/cli/_run.py b/src/sparkrun/cli/_run.py index a95fc08..37eb007 100644 --- a/src/sparkrun/cli/_run.py +++ b/src/sparkrun/cli/_run.py @@ -43,6 +43,9 @@ @click.option("--foreground", is_flag=True, help="Run in foreground (don't detach)") @click.option("--no-follow", is_flag=True, help="Don't follow container logs after launch") @click.option("--no-sync-tuning", is_flag=True, help="Skip syncing tuning configs from registries") +@click.option("--no-rm", is_flag=True, help="Don't auto-remove containers on exit (keeps containers after stop)") +@click.option("--restart", "restart_policy", default=None, + help="Docker restart policy (no, always, unless-stopped, on-failure[:N])") @click.option("--transfer-mode", default=None, type=click.Choice(["auto", "local", "push", "delegated"], case_sensitive=False), help="Resource transfer mode (overrides cluster setting)") @@ -52,7 +55,7 @@ def run( ctx, recipe_name, hosts, hosts_file, cluster_name, solo, port, tensor_parallel, pipeline_parallel, gpu_mem, served_model_name, max_model_len, image, cache_dir, ray_port, init_port, dashboard, dashboard_port, - dry_run, foreground, no_follow, no_sync_tuning, transfer_mode, + dry_run, foreground, no_follow, no_sync_tuning, no_rm, restart_policy, transfer_mode, options, extra_args, config_path=None, setup=True, ): """Run an inference recipe. @@ -222,6 +225,11 @@ def run( click.echo(" Workers: %s" % ", ".join(host_list[1:])) click.echo() + # Resolve container lifecycle options + auto_remove = not no_rm + if restart_policy: + auto_remove = False + # Launch via shared pipeline result = launch_inference( recipe=recipe, @@ -242,6 +250,8 @@ def run( dashboard_port=dashboard_port, dashboard=dashboard, init_port=init_port, + auto_remove=auto_remove, + restart_policy=restart_policy, ) click.echo("Cluster: %s" % result.cluster_id) diff --git a/src/sparkrun/core/launcher.py b/src/sparkrun/core/launcher.py index 9509454..45070b2 100644 --- a/src/sparkrun/core/launcher.py +++ b/src/sparkrun/core/launcher.py @@ -65,6 +65,9 @@ def launch_inference( dashboard_port: int | None = None, dashboard: bool = False, init_port: int | None = None, + # Container lifecycle options + auto_remove: bool = True, + restart_policy: str | None = None, ) -> LaunchResult: """Launch an inference workload. @@ -285,6 +288,8 @@ def launch_inference( nccl_env=nccl_env, ib_ip_map=ib_ip_map, skip_keys=skip_keys, + auto_remove=auto_remove, + restart_policy=restart_policy, **run_kwargs, ) diff --git a/src/sparkrun/orchestration/docker.py b/src/sparkrun/orchestration/docker.py index e1d14f0..d5630d0 100644 --- a/src/sparkrun/orchestration/docker.py +++ b/src/sparkrun/orchestration/docker.py @@ -15,7 +15,6 @@ _DEFAULT_DOCKER_OPTS = [ "--privileged", "--gpus all", - "--rm", "--ipc=host", "--shm-size=10.24gb", "--network host", @@ -30,6 +29,8 @@ def docker_run_cmd( env: dict[str, str] | None = None, volumes: dict[str, str] | None = None, extra_opts: list[str] | None = None, + auto_remove: bool = True, + restart_policy: str | None = None, ) -> str: """Generate a ``docker run`` command string. @@ -41,10 +42,18 @@ def docker_run_cmd( env: Environment variables to set (``-e KEY=VALUE``). volumes: Volume mounts (``-v host:container``). extra_opts: Additional docker run options. + auto_remove: Add ``--rm`` flag (default True). Forced to False + when *restart_policy* is set (Docker does not allow both). + restart_policy: Docker restart policy (e.g. ``always``, + ``unless-stopped``, ``on-failure:3``). Returns: Complete ``docker run`` command string. """ + # Docker does not allow --rm with --restart + if restart_policy: + auto_remove = False + parts = ["docker", "run"] if detach: @@ -52,6 +61,12 @@ def docker_run_cmd( parts.extend(_DEFAULT_DOCKER_OPTS) + if auto_remove: + parts.append("--rm") + + if restart_policy: + parts.extend(["--restart", restart_policy]) + if container_name: parts.extend(["--name", container_name]) diff --git a/src/sparkrun/orchestration/scripts.py b/src/sparkrun/orchestration/scripts.py index f8fddb7..ec1451b 100644 --- a/src/sparkrun/orchestration/scripts.py +++ b/src/sparkrun/orchestration/scripts.py @@ -38,6 +38,8 @@ def generate_container_launch_script( nccl_env: dict[str, str] | None = None, detach: bool = True, extra_docker_opts: list[str] | None = None, + auto_remove: bool = True, + restart_policy: str | None = None, ) -> str: """Generate a script that launches a Docker container. @@ -68,6 +70,8 @@ def generate_container_launch_script( env=all_env, volumes=volumes, extra_opts=extra_docker_opts, + auto_remove=auto_remove, + restart_policy=restart_policy, ) template = read_script("container_launch.sh") @@ -88,6 +92,8 @@ def generate_ray_head_script( env: dict[str, str] | None = None, volumes: dict[str, str] | None = None, nccl_env: dict[str, str] | None = None, + auto_remove: bool = True, + restart_policy: str | None = None, ) -> str: """Generate a script that starts a Ray head node in a container. @@ -135,6 +141,8 @@ def generate_ray_head_script( detach=True, env=all_env, volumes=volumes, + auto_remove=auto_remove, + restart_policy=restart_policy, ) template = read_script("ray_head.sh") @@ -152,6 +160,8 @@ def generate_ray_worker_script( env: dict[str, str] | None = None, volumes: dict[str, str] | None = None, nccl_env: dict[str, str] | None = None, + auto_remove: bool = True, + restart_policy: str | None = None, ) -> str: """Generate a script that starts a Ray worker node. @@ -185,6 +195,8 @@ def generate_ray_worker_script( detach=True, env=all_env, volumes=volumes, + auto_remove=auto_remove, + restart_policy=restart_policy, ) template = read_script("ray_worker.sh") diff --git a/src/sparkrun/runtimes/base.py b/src/sparkrun/runtimes/base.py index 3485551..e6fadac 100644 --- a/src/sparkrun/runtimes/base.py +++ b/src/sparkrun/runtimes/base.py @@ -521,6 +521,8 @@ def run( nccl_env: dict[str, str] | None = None, ib_ip_map: dict[str, str] | None = None, skip_keys: set[str] | frozenset[str] = frozenset(), + auto_remove: bool = True, + restart_policy: str | None = None, **kwargs, ) -> int: """Launch a workload -- delegates to solo or cluster implementation. @@ -569,6 +571,8 @@ def run( nccl_env=nccl_env, recipe=recipe, overrides=overrides, + auto_remove=auto_remove, + restart_policy=restart_policy, ) return self._run_cluster( hosts=hosts, @@ -585,6 +589,8 @@ def run( nccl_env=nccl_env, ib_ip_map=ib_ip_map, skip_keys=skip_keys, + auto_remove=auto_remove, + restart_policy=restart_policy, **kwargs, ) @@ -670,6 +676,8 @@ def _run_solo( nccl_env: dict[str, str] | None = None, recipe: Recipe | None = None, overrides: dict[str, Any] | None = None, + auto_remove: bool = True, + restart_policy: str | None = None, ) -> int: """Launch a single-node inference workload. @@ -728,6 +736,8 @@ def _run_solo( volumes=volumes, nccl_env=nccl_env, extra_docker_opts=self.get_extra_docker_opts() or None, + auto_remove=auto_remove, + restart_policy=restart_policy, ) result = run_script_on_host( host, launch_script, ssh_kwargs=ssh_kwargs, timeout=120, dry_run=dry_run, @@ -809,6 +819,8 @@ def _generate_node_script( volumes: dict[str, str] | None = None, nccl_env: dict[str, str] | None = None, extra_docker_opts: list[str] | None = None, + auto_remove: bool = True, + restart_policy: str | None = None, ) -> str: """Generate a script that launches a container with a direct entrypoint command. @@ -842,6 +854,8 @@ def _generate_node_script( env=all_env, volumes=volumes, extra_opts=extra_docker_opts, + auto_remove=auto_remove, + restart_policy=restart_policy, ) return ( diff --git a/src/sparkrun/runtimes/llama_cpp.py b/src/sparkrun/runtimes/llama_cpp.py index 754d713..8847c02 100644 --- a/src/sparkrun/runtimes/llama_cpp.py +++ b/src/sparkrun/runtimes/llama_cpp.py @@ -353,6 +353,8 @@ def _run_cluster( ib_ip_map: dict[str, str] | None = None, rpc_port: int = _DEFAULT_RPC_PORT, skip_keys: set[str] | frozenset[str] = frozenset(), + auto_remove: bool = True, + restart_policy: str | None = None, **kwargs, ) -> int: """Orchestrate a multi-node llama.cpp cluster using RPC. @@ -454,6 +456,7 @@ def _run_cluster( image=image, container_name=worker_container_name, serve_command=rpc_worker_command, label="llama.cpp node", env=all_env, volumes=volumes, nccl_env=nccl_env, + auto_remove=auto_remove, restart_policy=restart_policy, ) future = executor.submit( run_remote_script, host, script, @@ -512,6 +515,7 @@ def _run_cluster( image=image, container_name=head_container, serve_command=head_command, label="llama.cpp node", env=all_env, volumes=volumes, nccl_env=nccl_env, + auto_remove=auto_remove, restart_policy=restart_policy, ) head_result = run_remote_script( head_host, head_script, timeout=120, dry_run=dry_run, **ssh_kwargs, diff --git a/src/sparkrun/runtimes/sglang.py b/src/sparkrun/runtimes/sglang.py index 9a0204a..bf7e0e2 100644 --- a/src/sparkrun/runtimes/sglang.py +++ b/src/sparkrun/runtimes/sglang.py @@ -273,6 +273,8 @@ def _run_cluster( nccl_env: dict[str, str] | None = None, init_port: int = 25000, skip_keys: set[str] | frozenset[str] = frozenset(), + auto_remove: bool = True, + restart_policy: str | None = None, **kwargs, ) -> int: """Orchestrate a multi-node SGLang cluster using native distribution. @@ -374,6 +376,7 @@ def _run_cluster( image=image, container_name=head_container, serve_command=head_command, label="sglang node", env=all_env, volumes=volumes, nccl_env=nccl_env, + auto_remove=auto_remove, restart_policy=restart_policy, ) head_result = run_remote_script( head_host, head_script, timeout=120, dry_run=dry_run, **ssh_kwargs, @@ -439,6 +442,7 @@ def _run_cluster( image=image, container_name=worker_container, serve_command=worker_command, label="sglang node", env=all_env, volumes=volumes, nccl_env=nccl_env, + auto_remove=auto_remove, restart_policy=restart_policy, ) future = executor.submit( run_remote_script, host, worker_script, diff --git a/src/sparkrun/runtimes/trtllm.py b/src/sparkrun/runtimes/trtllm.py index 59ba76f..0c02946 100644 --- a/src/sparkrun/runtimes/trtllm.py +++ b/src/sparkrun/runtimes/trtllm.py @@ -388,6 +388,8 @@ def _run_cluster( detached: bool = True, nccl_env: dict[str, str] | None = None, skip_keys: set[str] | frozenset[str] = frozenset(), + auto_remove: bool = True, + restart_policy: str | None = None, **kwargs, ) -> int: """Orchestrate a multi-node TRT-LLM cluster using MPI. @@ -484,6 +486,8 @@ def _run_cluster( volumes=volumes, nccl_env=nccl_env, extra_docker_opts=extra_docker_opts or None, + auto_remove=auto_remove, + restart_policy=restart_policy, ) future = executor.submit( run_remote_script, host, launch_script, diff --git a/src/sparkrun/runtimes/vllm_distributed.py b/src/sparkrun/runtimes/vllm_distributed.py index 093b40c..d0af5a4 100644 --- a/src/sparkrun/runtimes/vllm_distributed.py +++ b/src/sparkrun/runtimes/vllm_distributed.py @@ -206,6 +206,8 @@ def _run_cluster( nccl_env: dict[str, str] | None = None, init_port: int = 25000, skip_keys: set[str] | frozenset[str] = frozenset(), + auto_remove: bool = True, + restart_policy: str | None = None, **kwargs, ) -> int: """Orchestrate a multi-node vLLM cluster using native distribution. @@ -307,6 +309,7 @@ def _run_cluster( image=image, container_name=head_container, serve_command=head_command, label="vllm node", env=all_env, volumes=volumes, nccl_env=nccl_env, + auto_remove=auto_remove, restart_policy=restart_policy, ) head_result = run_remote_script( head_host, head_script, timeout=120, dry_run=dry_run, **ssh_kwargs, @@ -372,6 +375,7 @@ def _run_cluster( image=image, container_name=worker_container, serve_command=worker_command, label="vllm node", env=all_env, volumes=volumes, nccl_env=nccl_env, + auto_remove=auto_remove, restart_policy=restart_policy, ) future = executor.submit( run_remote_script, host, worker_script, diff --git a/src/sparkrun/runtimes/vllm_ray.py b/src/sparkrun/runtimes/vllm_ray.py index ff2ea92..ea2758c 100644 --- a/src/sparkrun/runtimes/vllm_ray.py +++ b/src/sparkrun/runtimes/vllm_ray.py @@ -173,6 +173,8 @@ def _run_cluster( ray_port: int = 46379, dashboard_port: int = 8265, dashboard: bool = False, + auto_remove: bool = True, + restart_policy: str | None = None, **kwargs, ) -> int: """Orchestrate a multi-node Ray cluster for vLLM. @@ -254,6 +256,8 @@ def _run_cluster( env=all_env, volumes=volumes, nccl_env=nccl_env, + auto_remove=auto_remove, + restart_policy=restart_policy, ) head_result = run_remote_script( head_host, head_script, timeout=120, dry_run=dry_run, **ssh_kwargs, @@ -303,6 +307,8 @@ def _run_cluster( env=all_env, volumes=volumes, nccl_env=nccl_env, + auto_remove=auto_remove, + restart_policy=restart_policy, ) worker_results = run_remote_scripts_parallel( worker_hosts, worker_script, timeout=120, dry_run=dry_run, **ssh_kwargs,