diff --git a/src/hal0/api/__init__.py b/src/hal0/api/__init__.py index 9a66168..875c752 100644 --- a/src/hal0/api/__init__.py +++ b/src/hal0/api/__init__.py @@ -966,6 +966,10 @@ async def _fetch_and_cache(u: Upstream) -> list[str]: # and status routes snapshot ``as_dict()`` rather than hold the # dataclass across event-loop ticks. app.state.model_pull_jobs = {} + # Container image-pull job registry — keyed by slot name, value is a + # dict with keys: state (pulling|completed|failed), layer, total_layers, + # error, and a threading.Event for SSE fan-out. + app.state.slot_pull_jobs = {} # Dashboard footer event bus. Constructed above (so SlotManager could # be wired with the same instance); published on app.state here so # request handlers can reach it via ``request.app.state.events``. diff --git a/src/hal0/api/routes/slots.py b/src/hal0/api/routes/slots.py index 1f4ff25..311850d 100644 --- a/src/hal0/api/routes/slots.py +++ b/src/hal0/api/routes/slots.py @@ -29,7 +29,7 @@ import json from typing import Any -from fastapi import APIRouter, Request +from fastapi import APIRouter, BackgroundTasks, Request from fastapi.responses import StreamingResponse from hal0.api.middleware.error_codes import BadRequest, Conflict, Hal0Error @@ -405,13 +405,15 @@ async def _container_state_enrichment(request: Request) -> dict[str, dict[str, A entry["runtime"] = "container" profile_name = str(cfg.get("profile") or "") entry["profile"] = profile_name + image: str | None = None if profile_name: try: from hal0.config.loader import load_profiles_config catalog = load_profiles_config() prof = catalog.profile.get(profile_name) - entry["image"] = prof.image if prof else None + image = prof.image if prof else None + entry["image"] = image # resolved_command = llama-server argv starting from the image from hal0.providers.container import resolved_command_for_slot @@ -423,6 +425,26 @@ async def _container_state_enrichment(request: Request) -> dict[str, dict[str, A entry["image"] = None entry["resolved_command"] = None + # image_status: present | pulling | missing + # Check the slot_pull_jobs registry first so an in-flight pull + # surfaces as "pulling" without an extra inspect syscall. + slot_pull_jobs: dict[str, Any] = getattr(request.app.state, "slot_pull_jobs", {}) + active_job = slot_pull_jobs.get(name) + if active_job is not None and getattr(active_job, "state", None) == "pulling": + entry["image_status"] = "pulling" + elif image: + try: + from hal0.providers.container import container_provider + + present = await asyncio.get_event_loop().run_in_executor( + None, container_provider().image_present, image + ) + entry["image_status"] = "present" if present else "missing" + except Exception: + entry["image_status"] = "missing" + else: + entry["image_status"] = "missing" + out[name] = entry return out @@ -1792,3 +1814,195 @@ async def event_source() -> Any: "X-Accel-Buffering": "no", }, ) + + +# ── container image pull ─────────────────────────────────────────────────────── + + +class _ImagePullJob: + """Lightweight job object for a container-image pull. + + Tracks state (pulling | completed | failed), layer progress, and an + asyncio.Event used to wake SSE subscribers on each line of output. + + Unlike the HF-model PullJob (byte-oriented), this job is layer-oriented: + layer = layers finished, total_layers = layers discovered. + """ + + __slots__ = ("error", "image", "layer", "slot_name", "state", "total_layers") + + def __init__(self, slot_name: str, image: str) -> None: + self.slot_name = slot_name + self.image = image + self.state: str = "pulling" + self.layer: int = 0 + self.total_layers: int = 0 + self.error: str | None = None + + def as_dict(self) -> dict[str, Any]: + return { + "slot_name": self.slot_name, + "image": self.image, + "state": self.state, + "layer": self.layer, + "total_layers": self.total_layers, + "error": self.error, + } + + +async def _run_image_pull(job: _ImagePullJob, request: Request) -> None: + """Run the container pull in background, updating ``job`` per line. + + Writes progress into ``job`` so the 0.5-s polling SSE loop picks it up. + ``request`` is accepted for future use (event bus, slot invalidation) + but not read currently — marked ARG001 to suppress the linter. + """ + from hal0.providers.container import container_provider + + cp = container_provider() + try: + async for chunk in cp.pull_image_stream(job.image): + job.state = chunk.get("state", "pulling") + job.layer = int(chunk.get("layer", job.layer)) + job.total_layers = int(chunk.get("total_layers", job.total_layers)) + if chunk.get("error"): + job.error = str(chunk["error"]) + if job.state in ("completed", "failed"): + break + except Exception as exc: + job.state = "failed" + job.error = str(exc) + + +@router.post("/{name}/pull", status_code=202) +async def pull_slot_image( + name: str, request: Request, background: BackgroundTasks +) -> dict[str, object]: + """Start a background container image pull for slot ``name``. + + Idempotent: if a pull is already in-flight for this slot, returns + the existing job's snapshot rather than starting a second pull. + + Returns a job snapshot:: + + {"slot_name": "...", "image": "...", "state": "pulling", + "layer": 0, "total_layers": 0} + + Clients should open ``GET /api/slots/{name}/pull/stream`` to receive + live layer-progress events after POSTing here. + """ + sm = _get_slot_manager(request) + # Validate slot exists. + await sm.status(name) + + slot_pull_jobs: dict[str, Any] = getattr(request.app.state, "slot_pull_jobs", {}) + + existing = slot_pull_jobs.get(name) + if existing is not None and existing.state == "pulling": + return {"resumed": True, **existing.as_dict()} + + # Resolve image from profile. + image: str | None = None + try: + configs = await sm.iter_configs() + for cfg in configs: + if str(cfg.get("name", "")) == name: + profile_name = str(cfg.get("profile") or "") + if profile_name: + from hal0.config.loader import load_profiles_config + + catalog = load_profiles_config() + prof = catalog.profile.get(profile_name) + if prof: + image = prof.image + break + except Exception: + pass + + if not image: + raise BadRequest( + f"slot {name!r} has no container profile / image — cannot pull", + details={"slot": name}, + ) + + job = _ImagePullJob(name, image) + if not hasattr(request.app.state, "slot_pull_jobs"): + request.app.state.slot_pull_jobs = {} + request.app.state.slot_pull_jobs[name] = job + background.add_task(_run_image_pull, job, request) + return {"resumed": False, **job.as_dict()} + + +@router.get("/{name}/pull/stream") +async def pull_slot_image_stream(name: str, request: Request) -> StreamingResponse: + """SSE stream of container image-pull layer progress for slot ``name``. + + Emits one frame immediately (snapshot or terminal-already state), + then one per layer line, and a final terminal frame on completion or + failure. Graceful when no pull is active: emits a ``present`` or + ``missing`` frame and closes. + + Frame shape:: + + data: {"slot_name": "...", "image": "...", "state": "pulling", + "layer": N, "total_layers": M} + + Terminal states: ``completed`` | ``failed`` | ``present`` | ``missing``. + """ + + async def _gen() -> Any: + slot_pull_jobs: dict[str, Any] = getattr(request.app.state, "slot_pull_jobs", {}) + job = slot_pull_jobs.get(name) + + if job is None: + # No active pull — inspect the image to surface present|missing. + image: str | None = None + try: + sm = _get_slot_manager(request) + configs = await sm.iter_configs() + for cfg in configs: + if str(cfg.get("name", "")) == name: + profile_name = str(cfg.get("profile") or "") + if profile_name: + from hal0.config.loader import load_profiles_config + + catalog = load_profiles_config() + prof = catalog.profile.get(profile_name) + if prof: + image = prof.image + break + except Exception: + pass + + if image: + try: + from hal0.providers.container import container_provider + + present = await asyncio.get_event_loop().run_in_executor( + None, container_provider().image_present, image + ) + state = "present" if present else "missing" + except Exception: + state = "missing" + else: + state = "missing" + + yield f"data: {json.dumps({'slot_name': name, 'image': image, 'state': state, 'layer': 0, 'total_layers': 0})}\n\n" + return + + # Emit initial snapshot. + yield f"data: {json.dumps(job.as_dict())}\n\n" + last_layer = job.layer + while job.state == "pulling": + await asyncio.sleep(0.5) + if job.layer != last_layer or job.state != "pulling": + last_layer = job.layer + yield f"data: {json.dumps(job.as_dict())}\n\n" + # Terminal frame. + yield f"data: {json.dumps(job.as_dict())}\n\n" + + return StreamingResponse( + _gen(), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, + ) diff --git a/src/hal0/providers/container.py b/src/hal0/providers/container.py index c21aa7a..0d8c677 100644 --- a/src/hal0/providers/container.py +++ b/src/hal0/providers/container.py @@ -36,6 +36,7 @@ from __future__ import annotations import asyncio +import contextlib import logging import shlex import subprocess @@ -383,6 +384,102 @@ def is_active(self, slot_name: str) -> bool: result = self._run("systemctl", "is-active", self._unit_name(slot_name), check=False) return result.returncode == 0 + def image_present(self, image: str) -> bool: + """Return True if ``image`` is in the local container image store. + + Uses `` image inspect`` (exit 0 = present, non-zero = missing). + Runs synchronously — callers must dispatch to a thread executor when + called from an async context. + """ + try: + runtime = _container_runtime() + except RuntimeError: + return False + result = subprocess.run( + [runtime, "image", "inspect", image], + capture_output=True, + check=False, + ) + return result.returncode == 0 + + async def pull_image_stream(self, image: str): + """Async generator that runs `` pull `` and yields + layer-progress dicts. + + Yields dicts:: + + {"state": "pulling", "layer": N, "total_layers": M, "line": ""} + {"state": "completed"} + {"state": "failed", "error": ""} + + Layer counting heuristic (docker non-TTY output): + - Each ``Pulling fs layer`` / ``Waiting`` / ``Verifying Checksum`` / + ``Already exists`` lines indicate a discovered layer (M increments). + - Each ``Pull complete`` / ``Download complete`` line indicates a + finished layer (N increments, capped at M). + """ + import asyncio as _asyncio + + try: + runtime = _container_runtime() + except RuntimeError as exc: + yield {"state": "failed", "error": str(exc)} + return + + proc = await _asyncio.create_subprocess_exec( + runtime, + "pull", + image, + stdout=_asyncio.subprocess.PIPE, + stderr=_asyncio.subprocess.STDOUT, + ) + + total_layers = 0 + done_layers = 0 + + try: + assert proc.stdout is not None + async for raw in proc.stdout: + line = raw.decode("utf-8", errors="replace").rstrip() + if not line: + continue + # Discover new layers. + if any( + kw in line + for kw in ( + "Pulling fs layer", + "Waiting", + "Verifying Checksum", + "Already exists", + ) + ): + total_layers += 1 + # Count finished layers. + if ( + "Pull complete" in line + or "Download complete" in line + or "Already exists" in line + ): + done_layers = min(done_layers + 1, max(total_layers, 1)) + yield { + "state": "pulling", + "layer": done_layers, + "total_layers": total_layers, + "line": line, + } + except Exception as exc: + yield {"state": "failed", "error": str(exc)} + return + finally: + with contextlib.suppress(ProcessLookupError, OSError): + proc.kill() + + exit_code = await proc.wait() + if exit_code == 0: + yield {"state": "completed", "layer": done_layers, "total_layers": total_layers} + else: + yield {"state": "failed", "error": f"pull exited with code {exit_code}"} + # ── Module-level singleton (matches lemonade_provider() pattern) ───────────── diff --git a/tests/api/test_slots_image_pull.py b/tests/api/test_slots_image_pull.py new file mode 100644 index 0000000..6f63d5a --- /dev/null +++ b/tests/api/test_slots_image_pull.py @@ -0,0 +1,361 @@ +"""Tests for container image-pull progress (Issue #659). + +Verifies: + - ``image_status`` (present | pulling | missing) appears on container slots. + - ``image_status=pulling`` is set when a slot_pull_job is active. + - ``image_status=present`` is set when image_present() returns True. + - ``image_status=missing`` is set when image_present() returns False. + - POST /api/slots/{name}/pull returns 202 with job snapshot. + - POST /api/slots/{name}/pull is idempotent when already in-flight. + - GET /api/slots/{name}/pull/stream emits terminal frame when no pull is active. + - ContainerProvider.image_present() returns True for zero exit-code, False otherwise. + - ContainerProvider.pull_image_stream() yields completed / failed correctly. +""" + +from __future__ import annotations + +import json +from collections.abc import Iterator +from pathlib import Path +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +import hal0.providers as providers_mod +from hal0.api import create_app +from hal0.lemonade.client import LemonadeClient +from hal0.providers.lemonade import LemonadeProvider + +# ── helpers ─────────────────────────────────────────────────────────────────── + + +def _seed_slot_toml(home: str, name: str, lines: list[str]) -> Path: + root = Path(home) / "etc" / "hal0" / "slots" + root.mkdir(parents=True, exist_ok=True) + path = root / f"{name}.toml" + path.write_text("\n".join(lines) + "\n", encoding="utf-8") + return path + + +def _fake_profile_catalog(): + """Return (fake_catalog, fake_profile) for the vulkan-radv profile.""" + from hal0.config.schema import ProfileConfig + + fake_profile = ProfileConfig( + image="ghcr.io/hal0ai/amd-strix-halo-toolboxes:vulkan-radv-server", + flags="--flash-attn on", + mtp=False, + ) + return MagicMock(profile={"vulkan-radv": fake_profile}), fake_profile + + +# ── fixtures ────────────────────────────────────────────────────────────────── + + +@pytest.fixture +def lemonade_stub(monkeypatch: pytest.MonkeyPatch) -> dict[str, Any]: + """Minimal Lemonade stub so list_slots doesn't error on lemond health.""" + state: dict[str, Any] = {"loaded": []} + + def h(req: httpx.Request) -> httpx.Response: + if req.url.path == "/v1/health": + return httpx.Response(200, json={"loaded": state["loaded"]}) + return httpx.Response(200, json={"status": "ok"}) + + transport = httpx.AsyncClient( + transport=httpx.MockTransport(h), + base_url="http://test", + ) + provider = LemonadeProvider(client=LemonadeClient(http_client=transport)) + original = providers_mod._PROVIDERS["lemonade"] + providers_mod._PROVIDERS["lemonade"] = provider + try: + yield state + finally: + providers_mod._PROVIDERS["lemonade"] = original + + +@pytest.fixture +def container_app( + tmp_hal0_home: str, + lemonade_stub: dict[str, Any], +) -> FastAPI: + """App with one container slot (gpu-chat).""" + _seed_slot_toml( + tmp_hal0_home, + "gpu-chat", + [ + 'name = "gpu-chat"', + "port = 8088", + 'type = "llm"', + 'profile = "vulkan-radv"', + "[model]", + 'default = "llama-3b"', + ], + ) + return create_app() + + +@pytest.fixture +def container_client(container_app: FastAPI) -> Iterator[TestClient]: + with TestClient(container_app) as c: + yield c + + +# ── image_status tests ──────────────────────────────────────────────────────── + + +def test_image_status_present(container_client: TestClient) -> None: + """image_status=present when image_present() returns True.""" + fake_catalog, _ = _fake_profile_catalog() + with ( + patch("hal0.providers.container.ContainerProvider.is_active", return_value=False), + patch( + "hal0.providers.container.ContainerProvider.health", + new_callable=AsyncMock, + return_value={"ok": False}, + ), + patch("hal0.config.loader.load_profiles_config", return_value=fake_catalog), + patch("hal0.providers.container.load_profiles_config", return_value=fake_catalog), + patch("hal0.providers.container.ContainerProvider.image_present", return_value=True), + ): + r = container_client.get("/api/slots") + assert r.status_code == 200, r.text + by_name = {e["name"]: e for e in r.json()} + assert by_name["gpu-chat"]["image_status"] == "present" + + +def test_image_status_missing(container_client: TestClient) -> None: + """image_status=missing when image_present() returns False.""" + fake_catalog, _ = _fake_profile_catalog() + with ( + patch("hal0.providers.container.ContainerProvider.is_active", return_value=False), + patch( + "hal0.providers.container.ContainerProvider.health", + new_callable=AsyncMock, + return_value={"ok": False}, + ), + patch("hal0.config.loader.load_profiles_config", return_value=fake_catalog), + patch("hal0.providers.container.load_profiles_config", return_value=fake_catalog), + patch("hal0.providers.container.ContainerProvider.image_present", return_value=False), + ): + r = container_client.get("/api/slots") + assert r.status_code == 200, r.text + by_name = {e["name"]: e for e in r.json()} + assert by_name["gpu-chat"]["image_status"] == "missing" + + +def test_image_status_pulling_when_job_active( + container_client: TestClient, + container_app: FastAPI, +) -> None: + """image_status=pulling when a slot_pull_jobs entry with state=pulling exists.""" + from hal0.api.routes.slots import _ImagePullJob + + fake_catalog, _ = _fake_profile_catalog() + job = _ImagePullJob("gpu-chat", "ghcr.io/hal0ai/test:tag") + container_app.state.slot_pull_jobs = {"gpu-chat": job} + + with ( + patch("hal0.providers.container.ContainerProvider.is_active", return_value=False), + patch( + "hal0.providers.container.ContainerProvider.health", + new_callable=AsyncMock, + return_value={"ok": False}, + ), + patch("hal0.config.loader.load_profiles_config", return_value=fake_catalog), + patch("hal0.providers.container.load_profiles_config", return_value=fake_catalog), + ): + r = container_client.get("/api/slots") + assert r.status_code == 200, r.text + by_name = {e["name"]: e for e in r.json()} + assert by_name["gpu-chat"]["image_status"] == "pulling" + + +# ── POST /api/slots/{name}/pull tests ───────────────────────────────────────── + + +def test_pull_start_returns_202(container_client: TestClient) -> None: + """POST /api/slots/{name}/pull returns 202 with a job snapshot.""" + fake_catalog, _ = _fake_profile_catalog() + + async def _noop(job, request): + pass # Don't actually pull — background task stub. + + with ( + patch("hal0.providers.container.ContainerProvider.is_active", return_value=False), + patch( + "hal0.providers.container.ContainerProvider.health", + new_callable=AsyncMock, + return_value={"ok": False}, + ), + patch("hal0.config.loader.load_profiles_config", return_value=fake_catalog), + patch("hal0.providers.container.load_profiles_config", return_value=fake_catalog), + patch("hal0.api.routes.slots._run_image_pull", side_effect=_noop), + ): + r = container_client.post("/api/slots/gpu-chat/pull") + assert r.status_code == 202, r.text + body = r.json() + assert body["slot_name"] == "gpu-chat" + assert body["state"] == "pulling" + assert "image" in body + assert body.get("resumed") is False + + +def test_pull_start_idempotent( + container_client: TestClient, + container_app: FastAPI, +) -> None: + """POST /api/slots/{name}/pull returns resumed=True when already in-flight.""" + from hal0.api.routes.slots import _ImagePullJob + + fake_catalog, _ = _fake_profile_catalog() + job = _ImagePullJob("gpu-chat", "ghcr.io/hal0ai/test:tag") + container_app.state.slot_pull_jobs = {"gpu-chat": job} + + with ( + patch("hal0.providers.container.ContainerProvider.is_active", return_value=False), + patch( + "hal0.providers.container.ContainerProvider.health", + new_callable=AsyncMock, + return_value={"ok": False}, + ), + patch("hal0.config.loader.load_profiles_config", return_value=fake_catalog), + patch("hal0.providers.container.load_profiles_config", return_value=fake_catalog), + ): + r = container_client.post("/api/slots/gpu-chat/pull") + assert r.status_code == 202, r.text + assert r.json()["resumed"] is True + + +def test_pull_start_404_for_unknown_slot(container_client: TestClient) -> None: + """POST /api/slots/{name}/pull returns 404 for an unknown slot name.""" + r = container_client.post("/api/slots/no-such-slot/pull") + assert r.status_code == 404, r.text + + +# ── GET /api/slots/{name}/pull/stream tests ─────────────────────────────────── + + +def test_pull_stream_present_when_no_job(container_client: TestClient) -> None: + """GET /pull/stream with no active job and image present emits state=present.""" + fake_catalog, _ = _fake_profile_catalog() + with ( + patch("hal0.providers.container.ContainerProvider.is_active", return_value=False), + patch( + "hal0.providers.container.ContainerProvider.health", + new_callable=AsyncMock, + return_value={"ok": False}, + ), + patch("hal0.config.loader.load_profiles_config", return_value=fake_catalog), + patch("hal0.providers.container.load_profiles_config", return_value=fake_catalog), + patch("hal0.providers.container.ContainerProvider.image_present", return_value=True), + container_client.stream("GET", "/api/slots/gpu-chat/pull/stream") as resp, + ): + assert resp.status_code == 200 + lines = [ln for ln in resp.iter_lines() if ln.startswith("data:")] + assert lines, "must emit at least one SSE data frame" + payload = json.loads(lines[0].removeprefix("data:").strip()) + assert payload["state"] == "present" + + +def test_pull_stream_missing_when_no_job(container_client: TestClient) -> None: + """GET /pull/stream with no active job and image absent emits state=missing.""" + fake_catalog, _ = _fake_profile_catalog() + with ( + patch("hal0.providers.container.ContainerProvider.is_active", return_value=False), + patch( + "hal0.providers.container.ContainerProvider.health", + new_callable=AsyncMock, + return_value={"ok": False}, + ), + patch("hal0.config.loader.load_profiles_config", return_value=fake_catalog), + patch("hal0.providers.container.load_profiles_config", return_value=fake_catalog), + patch("hal0.providers.container.ContainerProvider.image_present", return_value=False), + container_client.stream("GET", "/api/slots/gpu-chat/pull/stream") as resp, + ): + assert resp.status_code == 200 + lines = [ln for ln in resp.iter_lines() if ln.startswith("data:")] + payload = json.loads(lines[0].removeprefix("data:").strip()) + assert payload["state"] == "missing" + + +# ── ContainerProvider.image_present() unit tests ────────────────────────────── + + +def test_image_present_returns_true_on_zero_exit(tmp_path: Path) -> None: + """image_present() returns True when the runtime exits 0.""" + from hal0.providers.container import ContainerProvider + + cp = ContainerProvider() + fake_runtime = tmp_path / "fake-runtime" + fake_runtime.write_text("#!/bin/sh\nexit 0\n") + fake_runtime.chmod(0o755) + with patch("hal0.providers.container._container_runtime", return_value=str(fake_runtime)): + assert cp.image_present("some/image:tag") is True + + +def test_image_present_returns_false_on_nonzero_exit(tmp_path: Path) -> None: + """image_present() returns False when the runtime exits non-zero.""" + from hal0.providers.container import ContainerProvider + + cp = ContainerProvider() + fake_runtime = tmp_path / "fake-runtime" + fake_runtime.write_text("#!/bin/sh\nexit 1\n") + fake_runtime.chmod(0o755) + with patch("hal0.providers.container._container_runtime", return_value=str(fake_runtime)): + assert cp.image_present("some/image:tag") is False + + +# ── ContainerProvider.pull_image_stream() unit tests ───────────────────────── + + +@pytest.mark.asyncio +async def test_pull_image_stream_completed_on_success(tmp_path: Path) -> None: + """pull_image_stream() yields a completed frame when the pull exits 0.""" + from hal0.providers.container import ContainerProvider + + cp = ContainerProvider() + fake_runtime = tmp_path / "fake-pull" + fake_runtime.write_text( + "#!/bin/sh\n" + 'echo "Pulling from library/alpine"\n' + 'echo "abc123: Pulling fs layer"\n' + 'echo "abc123: Download complete"\n' + 'echo "abc123: Pull complete"\n' + 'echo "Digest: sha256:abc"\n' + "exit 0\n" + ) + fake_runtime.chmod(0o755) + + chunks = [] + with patch("hal0.providers.container._container_runtime", return_value=str(fake_runtime)): + async for chunk in cp.pull_image_stream("alpine:latest"): + chunks.append(chunk) + + states = [c["state"] for c in chunks] + assert "pulling" in states, "must emit at least one pulling frame" + assert states[-1] == "completed", f"last frame must be completed, got: {states}" + + +@pytest.mark.asyncio +async def test_pull_image_stream_failed_on_nonzero_exit(tmp_path: Path) -> None: + """pull_image_stream() yields a failed frame when the pull exits non-zero.""" + from hal0.providers.container import ContainerProvider + + cp = ContainerProvider() + fake_runtime = tmp_path / "fake-pull-fail" + fake_runtime.write_text("#!/bin/sh\nexit 1\n") + fake_runtime.chmod(0o755) + + chunks = [] + with patch("hal0.providers.container._container_runtime", return_value=str(fake_runtime)): + async for chunk in cp.pull_image_stream("bad/image:tag"): + chunks.append(chunk) + + assert chunks, "must yield at least one chunk" + assert chunks[-1]["state"] == "failed" diff --git a/ui/src/api/endpoints.ts b/ui/src/api/endpoints.ts index 95f5f90..905ab5c 100644 --- a/ui/src/api/endpoints.ts +++ b/ui/src/api/endpoints.ts @@ -35,6 +35,10 @@ export const ENDPOINTS = { `/api/slots/${encodeURIComponent(name)}/state/stream`, slotLogsStream: (name: string) => `/api/slots/${encodeURIComponent(name)}/logs/stream`, + slotPull: (name: string) => + `/api/slots/${encodeURIComponent(name)}/pull`, + slotPullStream: (name: string) => + `/api/slots/${encodeURIComponent(name)}/pull/stream`, // ── Models / pull lifecycle ────────────────────────────────────── models: '/api/models', diff --git a/ui/src/api/hooks/useSlots.ts b/ui/src/api/hooks/useSlots.ts index 2c2aca0..9969ffc 100644 --- a/ui/src/api/hooks/useSlots.ts +++ b/ui/src/api/hooks/useSlots.ts @@ -8,6 +8,7 @@ // (slot defs change on edit), so 5s is enough. import { useMutation, useQuery, useQueryClient, type UseQueryResult } from '@tanstack/react-query' +import { useEffect, useRef, useState } from 'react' import { api, apiGet } from '../client' import { ENDPOINTS } from '../endpoints' @@ -446,3 +447,124 @@ export function useSlotConfig(name: string | null | undefined) { staleTime: 10_000, }) } + +// ─── useSlotImagePull ───────────────────────────────────────────────────────── + +export type ImagePullState = 'idle' | 'pulling' | 'completed' | 'failed' | 'present' | 'missing' + +export interface ImagePullSnapshot { + slotName: string | null + image: string | null + state: ImagePullState + layer: number + totalLayers: number + error: string | null + inFlight: boolean + /** Start a pull for the given slot name: POST then open SSE stream. */ + start: (name: string) => Promise + reset: () => void +} + +const IMAGE_PULL_TERMINAL = new Set(['completed', 'failed', 'present', 'missing']) + +/** + * Container image-pull composable — mirrors the model `usePullJob` pattern. + * + * Usage: + * const pull = useSlotImagePull() + * pull.start(slot.name) // POSTs /api/slots/{name}/pull, opens SSE stream + * // render pull.state, pull.layer, pull.totalLayers in a progress bar + */ +export function useSlotImagePull(): ImagePullSnapshot { + const [slotName, setSlotName] = useState(null) + const [image, setImage] = useState(null) + const [state, setState] = useState('idle') + const [layer, setLayer] = useState(0) + const [totalLayers, setTotalLayers] = useState(0) + const [error, setError] = useState(null) + const esRef = useRef(null) + const qc = useQueryClient() + + const closeStream = () => { + if (esRef.current) { + esRef.current.close() + esRef.current = null + } + } + + useEffect(() => () => closeStream(), []) + + const applyPayload = (payload: any) => { + if (!payload || typeof payload !== 'object') return + if (typeof payload.slot_name === 'string') setSlotName(payload.slot_name) + if (typeof payload.image === 'string') setImage(payload.image) + if (typeof payload.state === 'string') setState(payload.state as ImagePullState) + if (typeof payload.layer === 'number') setLayer(payload.layer) + if (typeof payload.total_layers === 'number') setTotalLayers(payload.total_layers) + if (payload.error) setError(String(payload.error)) + if (typeof payload.state === 'string' && IMAGE_PULL_TERMINAL.has(payload.state as ImagePullState)) { + closeStream() + // Invalidate slots so image_status refreshes on the card. + qc.invalidateQueries({ queryKey: ['slots'] }) + } + } + + const attachStream = (name: string) => { + closeStream() + try { + esRef.current = new EventSource(ENDPOINTS.slotPullStream(name)) + } catch (e: any) { + setError(e?.message ?? 'EventSource failed') + setState('failed') + return + } + const es = esRef.current + es.onmessage = (evt: MessageEvent) => { + try { applyPayload(JSON.parse(evt.data)) } catch { /* skip */ } + } + es.onerror = () => { + setState('failed') + setError('stream error') + closeStream() + } + } + + const start = async (name: string) => { + setSlotName(name) + setState('pulling') + setLayer(0) + setTotalLayers(0) + setError(null) + try { + const resp = await api(ENDPOINTS.slotPull(name), { method: 'POST', raw: true }) + if (typeof resp?.image === 'string') setImage(resp.image) + } catch (e: any) { + setState('failed') + setError(e?.message ?? 'pull start failed') + return + } + attachStream(name) + } + + const reset = () => { + closeStream() + setSlotName(null) + setImage(null) + setState('idle') + setLayer(0) + setTotalLayers(0) + setError(null) + } + + return { + slotName, + image, + state, + layer, + totalLayers, + error, + inFlight: state === 'pulling', + start, + reset, + } +} diff --git a/ui/src/dash/slot-modals.jsx b/ui/src/dash/slot-modals.jsx index bb868d2..018f08d 100644 --- a/ui/src/dash/slot-modals.jsx +++ b/ui/src/dash/slot-modals.jsx @@ -10,6 +10,7 @@ import { useSlotDefaults, useSlotDelete, useSlotBackend, + useSlotImagePull, } from '@/api/hooks/useSlots' import { useHardware } from '@/api/hooks/useHardware' import { useBackends } from '@/api/hooks/useBackends' @@ -1218,17 +1219,85 @@ function EmptySlotCard({ name, type, group, device, onConfigure }) { ); } +// ─── Image pull progress bar ───────────────────────────────────── +function ImagePullBar({ pull }) { + // pull: ImagePullSnapshot from useSlotImagePull() + const { state, layer, totalLayers, image, error } = pull; + if (state !== "pulling" && state !== "completed" && state !== "failed") return null; + const pct = totalLayers > 0 ? Math.round((layer / totalLayers) * 100) : null; + // Truncate the image tag to the last segment for display. + const imgShort = image ? image.split("/").pop() : null; + const label = + state === "completed" ? `Image ready` : + state === "failed" ? `Pull failed${error ? `: ${error}` : ""}` : + totalLayers > 0 ? `Pulling image${imgShort ? ` ${imgShort}` : ""}… (layer ${layer}/${totalLayers})` : + `Pulling image${imgShort ? ` ${imgShort}` : ""}…`; + const barColor = state === "failed" ? "var(--err)" : state === "completed" ? "var(--ok)" : "var(--accent)"; + return ( +
+
+ {label} +
+
+
+
+
+ ); +} + // ─── Error SlotCard ───────────────────────────────────────────── function ErrorSlotCardBanner({ slot, message }) { + const pull = useSlotImagePull(); + const isPulling = pull.slotName === slot?.name && pull.inFlight; + + const handleRePull = async () => { + if (!slot?.name) return; + try { + await pull.start(slot.name); + } catch (err) { + window.__hal0Toast && window.__hal0Toast( + `Re-pull failed for ${slot.name}: ${err?.message || err}`, "warn" + ); + } + }; + return (
{Icons.warn}
load failed
{message}
+ {(isPulling || pull.state === "completed" || pull.state === "failed") && pull.slotName === slot?.name && ( + + )}
- +
diff --git a/ui/src/dash/slots.jsx b/ui/src/dash/slots.jsx index a3991bb..2e46071 100644 --- a/ui/src/dash/slots.jsx +++ b/ui/src/dash/slots.jsx @@ -13,6 +13,7 @@ import { useSlotLoad, useSlotSwap, useSlotEdit, + useSlotImagePull, } from '@/api/hooks/useSlots' import { useModels } from '@/api/hooks/useModels' import { useLemonadeConfig, useLemonadeConfigSet } from '@/api/hooks/useLemonadeConfig' @@ -212,6 +213,48 @@ function Spark({ data, height = 18 }) { ); } +// ─── Container image pull progress bar ──────────────────────────── +// Shows while image_status === "pulling" (backend-polled) or while an +// explicit Re-pull is in flight from the error banner. +// Distinct from the model-download bar — this is a ~6GB OCI layer pull, +// one-time per image tag. +function SlotImagePullBar({ slot }) { + const isContainer = slot?.runtime === "container" || slot?.container_status != null; + const imageStatus = slot?.image_status; + const pulling = imageStatus === "pulling"; + if (!isContainer || !pulling) return null; + // image tag short form for the label. + const imgFull = slot?.image || null; + const imgShort = imgFull ? imgFull.split("/").pop() : null; + const label = `Pulling image${imgShort ? ` ${imgShort}` : ""}…`; + return ( +
+
+ {label} +
+
+
+
+
+ ); +} + // ─── SlotCard (instrument variant) ─── function SlotCard({ slot, @@ -411,6 +454,9 @@ function SlotCard({ ))}
)} + {/* Container image pull progress — shown when image_status === "pulling" + (backend-polled), distinct from model download. */} + {/* N3: touch-action:manipulation prevents 300ms tap-delay on mobile while keeping pan/pinch-to-zoom intact (no `touch-action: none`). */}
diff --git a/ui/src/dashboard.css b/ui/src/dashboard.css index 69a156b..18caa4b 100644 --- a/ui/src/dashboard.css +++ b/ui/src/dashboard.css @@ -2323,6 +2323,14 @@ select.npu-sel { animation: spin 0.7s linear infinite; } @keyframes spin { to { transform: rotate(360deg); } } +/* Container image-pull indeterminate progress bar (issue #659). + Slides a 40% wide bar across the track — distinct from skel-shimmer + (which uses a gradient). Used in SlotImagePullBar + ImagePullBar. */ +@keyframes hal0-indeterminate { + 0% { transform: translateX(-100%); } + 60% { transform: translateX(150%); } + 100% { transform: translateX(150%); } +} /* hero close cursor */ .hero-strip .close { cursor: pointer; }