Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/hal0/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -966,6 +966,10 @@ async def _fetch_and_cache(u: Upstream) -> list[str]:
# and status routes snapshot ``as_dict()`` rather than hold the
# dataclass across event-loop ticks.
app.state.model_pull_jobs = {}
# Container image-pull job registry — keyed by slot name, value is a
# dict with keys: state (pulling|completed|failed), layer, total_layers,
# error, and a threading.Event for SSE fan-out.
app.state.slot_pull_jobs = {}
# Dashboard footer event bus. Constructed above (so SlotManager could
# be wired with the same instance); published on app.state here so
# request handlers can reach it via ``request.app.state.events``.
Expand Down
218 changes: 216 additions & 2 deletions src/hal0/api/routes/slots.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import json
from typing import Any

from fastapi import APIRouter, Request
from fastapi import APIRouter, BackgroundTasks, Request
from fastapi.responses import StreamingResponse

from hal0.api.middleware.error_codes import BadRequest, Conflict, Hal0Error
Expand Down Expand Up @@ -405,13 +405,15 @@ async def _container_state_enrichment(request: Request) -> dict[str, dict[str, A
entry["runtime"] = "container"
profile_name = str(cfg.get("profile") or "")
entry["profile"] = profile_name
image: str | None = None
if profile_name:
try:
from hal0.config.loader import load_profiles_config

catalog = load_profiles_config()
prof = catalog.profile.get(profile_name)
entry["image"] = prof.image if prof else None
image = prof.image if prof else None
entry["image"] = image
# resolved_command = llama-server argv starting from the image
from hal0.providers.container import resolved_command_for_slot

Expand All @@ -423,6 +425,26 @@ async def _container_state_enrichment(request: Request) -> dict[str, dict[str, A
entry["image"] = None
entry["resolved_command"] = None

# image_status: present | pulling | missing
# Check the slot_pull_jobs registry first so an in-flight pull
# surfaces as "pulling" without an extra inspect syscall.
slot_pull_jobs: dict[str, Any] = getattr(request.app.state, "slot_pull_jobs", {})
active_job = slot_pull_jobs.get(name)
if active_job is not None and getattr(active_job, "state", None) == "pulling":
entry["image_status"] = "pulling"
elif image:
try:
from hal0.providers.container import container_provider

present = await asyncio.get_event_loop().run_in_executor(
None, container_provider().image_present, image
)
entry["image_status"] = "present" if present else "missing"
except Exception:
entry["image_status"] = "missing"
else:
entry["image_status"] = "missing"

out[name] = entry
return out

Expand Down Expand Up @@ -1792,3 +1814,195 @@ async def event_source() -> Any:
"X-Accel-Buffering": "no",
},
)


# ── container image pull ───────────────────────────────────────────────────────


class _ImagePullJob:
"""Lightweight job object for a container-image pull.

Tracks state (pulling | completed | failed), layer progress, and an
asyncio.Event used to wake SSE subscribers on each line of output.

Unlike the HF-model PullJob (byte-oriented), this job is layer-oriented:
layer = layers finished, total_layers = layers discovered.
"""

__slots__ = ("error", "image", "layer", "slot_name", "state", "total_layers")

def __init__(self, slot_name: str, image: str) -> None:
self.slot_name = slot_name
self.image = image
self.state: str = "pulling"
self.layer: int = 0
self.total_layers: int = 0
self.error: str | None = None

def as_dict(self) -> dict[str, Any]:
return {
"slot_name": self.slot_name,
"image": self.image,
"state": self.state,
"layer": self.layer,
"total_layers": self.total_layers,
"error": self.error,
}


async def _run_image_pull(job: _ImagePullJob, request: Request) -> None:
"""Run the container pull in background, updating ``job`` per line.

Writes progress into ``job`` so the 0.5-s polling SSE loop picks it up.
``request`` is accepted for future use (event bus, slot invalidation)
but not read currently — marked ARG001 to suppress the linter.
"""
from hal0.providers.container import container_provider

cp = container_provider()
try:
async for chunk in cp.pull_image_stream(job.image):
job.state = chunk.get("state", "pulling")
job.layer = int(chunk.get("layer", job.layer))
job.total_layers = int(chunk.get("total_layers", job.total_layers))
if chunk.get("error"):
job.error = str(chunk["error"])
if job.state in ("completed", "failed"):
break
except Exception as exc:
job.state = "failed"
job.error = str(exc)


@router.post("/{name}/pull", status_code=202)
async def pull_slot_image(
name: str, request: Request, background: BackgroundTasks
) -> dict[str, object]:
"""Start a background container image pull for slot ``name``.

Idempotent: if a pull is already in-flight for this slot, returns
the existing job's snapshot rather than starting a second pull.

Returns a job snapshot::

{"slot_name": "...", "image": "...", "state": "pulling",
"layer": 0, "total_layers": 0}

Clients should open ``GET /api/slots/{name}/pull/stream`` to receive
live layer-progress events after POSTing here.
"""
sm = _get_slot_manager(request)
# Validate slot exists.
await sm.status(name)

slot_pull_jobs: dict[str, Any] = getattr(request.app.state, "slot_pull_jobs", {})

existing = slot_pull_jobs.get(name)
if existing is not None and existing.state == "pulling":
return {"resumed": True, **existing.as_dict()}

# Resolve image from profile.
image: str | None = None
try:
configs = await sm.iter_configs()
for cfg in configs:
if str(cfg.get("name", "")) == name:
profile_name = str(cfg.get("profile") or "")
if profile_name:
from hal0.config.loader import load_profiles_config

catalog = load_profiles_config()
prof = catalog.profile.get(profile_name)
if prof:
image = prof.image
break
except Exception:
pass

if not image:
raise BadRequest(
f"slot {name!r} has no container profile / image — cannot pull",
details={"slot": name},
)

job = _ImagePullJob(name, image)
if not hasattr(request.app.state, "slot_pull_jobs"):
request.app.state.slot_pull_jobs = {}
request.app.state.slot_pull_jobs[name] = job
background.add_task(_run_image_pull, job, request)
return {"resumed": False, **job.as_dict()}


@router.get("/{name}/pull/stream")
async def pull_slot_image_stream(name: str, request: Request) -> StreamingResponse:
"""SSE stream of container image-pull layer progress for slot ``name``.

Emits one frame immediately (snapshot or terminal-already state),
then one per layer line, and a final terminal frame on completion or
failure. Graceful when no pull is active: emits a ``present`` or
``missing`` frame and closes.

Frame shape::

data: {"slot_name": "...", "image": "...", "state": "pulling",
"layer": N, "total_layers": M}

Terminal states: ``completed`` | ``failed`` | ``present`` | ``missing``.
"""

async def _gen() -> Any:
slot_pull_jobs: dict[str, Any] = getattr(request.app.state, "slot_pull_jobs", {})
job = slot_pull_jobs.get(name)

if job is None:
# No active pull — inspect the image to surface present|missing.
image: str | None = None
try:
sm = _get_slot_manager(request)
configs = await sm.iter_configs()
for cfg in configs:
if str(cfg.get("name", "")) == name:
profile_name = str(cfg.get("profile") or "")
if profile_name:
from hal0.config.loader import load_profiles_config

catalog = load_profiles_config()
prof = catalog.profile.get(profile_name)
if prof:
image = prof.image
break
except Exception:
pass

if image:
try:
from hal0.providers.container import container_provider

present = await asyncio.get_event_loop().run_in_executor(
None, container_provider().image_present, image
)
state = "present" if present else "missing"
except Exception:
state = "missing"
else:
state = "missing"

yield f"data: {json.dumps({'slot_name': name, 'image': image, 'state': state, 'layer': 0, 'total_layers': 0})}\n\n"
return

# Emit initial snapshot.
yield f"data: {json.dumps(job.as_dict())}\n\n"
last_layer = job.layer
while job.state == "pulling":
await asyncio.sleep(0.5)
if job.layer != last_layer or job.state != "pulling":
last_layer = job.layer
yield f"data: {json.dumps(job.as_dict())}\n\n"
# Terminal frame.
yield f"data: {json.dumps(job.as_dict())}\n\n"

return StreamingResponse(
_gen(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
)
97 changes: 97 additions & 0 deletions src/hal0/providers/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from __future__ import annotations

import asyncio
import contextlib
import logging
import shlex
import subprocess
Expand Down Expand Up @@ -383,6 +384,102 @@ def is_active(self, slot_name: str) -> bool:
result = self._run("systemctl", "is-active", self._unit_name(slot_name), check=False)
return result.returncode == 0

def image_present(self, image: str) -> bool:
"""Return True if ``image`` is in the local container image store.

Uses ``<runtime> image inspect`` (exit 0 = present, non-zero = missing).
Runs synchronously — callers must dispatch to a thread executor when
called from an async context.
"""
try:
runtime = _container_runtime()
except RuntimeError:
return False
result = subprocess.run(
[runtime, "image", "inspect", image],
capture_output=True,
check=False,
)
return result.returncode == 0

async def pull_image_stream(self, image: str):
"""Async generator that runs ``<runtime> pull <image>`` and yields
layer-progress dicts.

Yields dicts::

{"state": "pulling", "layer": N, "total_layers": M, "line": "<raw line>"}
{"state": "completed"}
{"state": "failed", "error": "<message>"}

Layer counting heuristic (docker non-TTY output):
- Each ``Pulling fs layer`` / ``Waiting`` / ``Verifying Checksum`` /
``Already exists`` lines indicate a discovered layer (M increments).
- Each ``Pull complete`` / ``Download complete`` line indicates a
finished layer (N increments, capped at M).
"""
import asyncio as _asyncio

try:
runtime = _container_runtime()
except RuntimeError as exc:
yield {"state": "failed", "error": str(exc)}
return

proc = await _asyncio.create_subprocess_exec(
runtime,
"pull",
image,
stdout=_asyncio.subprocess.PIPE,
stderr=_asyncio.subprocess.STDOUT,
)

total_layers = 0
done_layers = 0

try:
assert proc.stdout is not None
async for raw in proc.stdout:
line = raw.decode("utf-8", errors="replace").rstrip()
if not line:
continue
# Discover new layers.
if any(
kw in line
for kw in (
"Pulling fs layer",
"Waiting",
"Verifying Checksum",
"Already exists",
)
):
total_layers += 1
# Count finished layers.
if (
"Pull complete" in line
or "Download complete" in line
or "Already exists" in line
):
done_layers = min(done_layers + 1, max(total_layers, 1))
yield {
"state": "pulling",
"layer": done_layers,
"total_layers": total_layers,
"line": line,
}
except Exception as exc:
yield {"state": "failed", "error": str(exc)}
return
finally:
with contextlib.suppress(ProcessLookupError, OSError):
proc.kill()

exit_code = await proc.wait()
if exit_code == 0:
yield {"state": "completed", "layer": done_layers, "total_layers": total_layers}
else:
yield {"state": "failed", "error": f"pull exited with code {exit_code}"}


# ── Module-level singleton (matches lemonade_provider() pattern) ─────────────

Expand Down
Loading
Loading