Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/hal0/providers/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def _render_unit(
device_paths: list[str] | None = None,
context_size: int | None = None,
extra_args: str | None = None,
model_alias: str | None = None,
) -> str:
"""Render a complete (non-drop-in) systemd unit for a container slot.

Expand Down Expand Up @@ -188,6 +189,10 @@ def _render_unit(
model_path,
]
)
# Advertise the hal0 registry model id (else llama-server reports the raw
# GGUF basename, which the dispatcher can't match to hal0/* virtual names).
if model_alias:
argv.extend(["--alias", model_alias])
# Slot context window (else llama-server defaults to 4096).
if context_size is not None:
argv.extend(["--ctx-size", str(context_size)])
Expand Down Expand Up @@ -365,6 +370,11 @@ def load_sync(
context_size = model_table.get("context_size") if isinstance(model_table, dict) else None
server_table = slot_cfg.get("server") or {}
extra_args = server_table.get("extra_args") if isinstance(server_table, dict) else None
# Registry model id → llama-server --alias so the container advertises
# the hal0 id (not the raw GGUF basename) for dispatcher matching.
model_alias = model_info.get("_model_key") or (
model_table.get("default") if isinstance(model_table, dict) else None
)

unit_path = self._unit_path(slot_name)
unit_text = _render_unit(
Expand All @@ -375,6 +385,7 @@ def load_sync(
flags_str,
context_size=context_size,
extra_args=extra_args,
model_alias=model_alias,
)

log.info(
Expand Down Expand Up @@ -578,6 +589,8 @@ def resolved_command_for_slot(
]
if effective_model:
argv += ["--model", effective_model]
if default_model:
argv += ["--alias", str(default_model)]
if context_size is not None:
argv += ["--ctx-size", str(context_size)]
argv.extend(flag_tokens)
Expand Down
65 changes: 65 additions & 0 deletions tests/providers/test_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,26 @@ def test_explicit_device_nodes_emitted_no_bare_dri_dir(self) -> None:
assert "--device=/dev/dri/renderD128" in tokens
assert "--device=/dev/dri" not in tokens

def test_model_alias_in_exec_start(self) -> None:
    """A ``model_alias`` given to ``_render_unit`` shows up in ExecStart.

    The rendered unit's ExecStart line is tokenised and must contain
    ``--alias`` immediately followed by the alias that was passed in, so the
    container advertises the hal0 registry model id rather than the raw GGUF
    basename.
    """
    moe = _moe_profile()
    exec_flags = resolve_profile_flags(moe)
    rendered = _render_unit(
        "test-slot",
        moe.image,
        8095,
        "/mnt/ai-models/model.gguf",
        exec_flags,
        runtime_bin=_TEST_RUNTIME,
        device_paths=["/dev/kfd", "/dev/dri/renderD128"],
        model_alias="qwopus3.6-27b-v2",
    )

    argv = shlex.split(self._get_exec_start(rendered))

    # The flag must be present, and its value is the token right after it.
    assert "--alias" in argv
    alias_pos = argv.index("--alias")
    assert argv[alias_pos + 1] == "qwopus3.6-27b-v2"

def test_ctx_size_in_exec_start(self) -> None:
"""The slot's context_size must reach the container as --ctx-size,
else llama-server boots at its 4096 default (severe ctx regression)."""
Expand Down Expand Up @@ -516,6 +536,37 @@ def fake_run(*args: str, check: bool = True) -> MagicMock:
assert "--ctx-size 131072" in unit
assert "--override-kv k=bool:false" in unit

def test_load_sync_advertises_model_id_alias(self, tmp_path: Path) -> None:
    """``load_sync`` writes ``--alias <_model_key>`` into the unit file.

    Profile resolution, GPU device discovery, the command runner and the
    unit path are all patched, so the only observable output is the unit
    text written under ``tmp_path``.
    """
    moe = _moe_profile()
    provider = ContainerProvider()
    unit_file = tmp_path / "test.service"

    def fake_run(*args: str, check: bool = True) -> MagicMock:
        # Stand-in for the provider's command runner: always reports success.
        return MagicMock(returncode=0)

    slot_cfg = {"name": "agent", "port": 8101, "profile": "moe-rocmfp4"}
    model_info = {
        "path": "/mnt/ai-models/m.gguf",
        "_model_key": "chadrock-35b-ace-saber",
    }

    with (
        patch("hal0.providers.container._resolve_profile", return_value=moe),
        patch(
            "hal0.providers.container.resolve_gpu_device_paths",
            return_value=["/dev/kfd", "/dev/dri/renderD128"],
        ),
        patch.object(provider, "_run", side_effect=fake_run),
        patch.object(provider, "_unit_path", return_value=unit_file),
    ):
        provider.load_sync(slot_cfg, model_info)

    written = unit_file.read_text()
    assert "--alias chadrock-35b-ace-saber" in written

def test_resolved_command_includes_ctx_size(self) -> None:
"""The displayed resolved_command must show --ctx-size so it matches
what actually launches."""
Expand All @@ -531,6 +582,20 @@ def test_resolved_command_includes_ctx_size(self) -> None:
assert "--ctx-size" in argv
assert argv[argv.index("--ctx-size") + 1] == "131072"

def test_resolved_command_includes_model_alias(self) -> None:
    """``resolved_command_for_slot`` emits ``--alias`` with the slot's default model.

    The slot config's ``model.default`` value must appear as the token
    immediately after ``--alias`` in the returned argv.
    """
    moe = _moe_profile()
    slot_cfg = {
        "profile": "moe-rocmfp4",
        "port": 8095,
        "model": {"default": "chadrock-35b-ace-saber", "context_size": 131072},
    }

    with patch("hal0.providers.container._resolve_profile", return_value=moe):
        command = resolved_command_for_slot(slot_cfg)

    assert command is not None
    assert "--alias" in command
    alias_pos = command.index("--alias")
    assert command[alias_pos + 1] == "chadrock-35b-ace-saber"

def test_unload_sync_calls_stop(self, tmp_path: Path) -> None:
provider = ContainerProvider()
unit_file = tmp_path / "hal0-slot@test-container.service"
Expand Down
Loading