diff --git a/integrations/atomic-chat-b2/ROADMAP.md b/integrations/atomic-chat-b2/ROADMAP.md
index 4b082e0..5ba9afd 100644
--- a/integrations/atomic-chat-b2/ROADMAP.md
+++ b/integrations/atomic-chat-b2/ROADMAP.md
@@ -2,7 +2,7 @@
 **Branch**: `AgentMemory/atomic-chat-b2-mlx-dflash-kakeya-04ae`
 **Parent PR**: #57 (B1 — HF + MPS sidecar, in review)
-**Status**: M1-M3 skeleton + tests (this PR); M4-M6 follow-up
+**Status**: M1-M4 landed (PRs #58, M4); M5-M6 follow-up
 
 ## Motivation
 
@@ -51,12 +51,24 @@ OpenAI-compatible MLX sidecar, interface identical to B1 (`/v1/models`,
 
 **This PR ships only the skeleton + pure-logic unit tests**: model_registry_mlx, channel
 parsing, routing mocks. Real MLX model loading + generate is gated on M4 wiring in DFlash.
 
-### M4 — DFlash integration (follow-up PR)
+### M4 — DFlash integration (✅ this PR)
 
 Hook into `dflash.model_mlx.stream_generate`, replace the target LLM's KV with `KakeyaLatticeMLXCache`,
 and keep the draft LLM on the default `RotatingKVCache` (draft-KV compression is deferred to Phase 2).
 
+**Delivered in this PR**:
+- `cache_injection.py` — three injection strategies (kwarg / model.make_cache /
+  module-level make_prompt_cache) plus feature detection and a `FALLBACK_NATIVE_MLX`
+  fallback, to cope with the dflash API having changed several times during 2026
+- `engine_mlx.py` — `chat()` / `chat_stream()` wired end to end, with two
+  paths: DFlash + Kakeya KV, and a native MLX + Kakeya KV fallback
+- `server.py` — `/v1/chat/completions` now live, stream and non-stream; the
+  `x_kakeya` response field carries `dflash_used` / `injection_strategy` /
+  `acceptance_length_mean`
+- **All 32 sidecar unit tests green** (including 8 cache_injection strategy
+  tests and 4 engine routing tests, all using stubs in place of MLX / dflash)
+
 **Blockers**:
 - the target / draft KV interface of `dflash.model_mlx` needs a dflash patch or a wrapper on our side
 - the coupling between `draft_sliding_window_size` and the target `boundary` needs re-testing
diff --git a/integrations/atomic-chat-b2/kakeya_sidecar_mlx/kakeya_sidecar_mlx/cache_injection.py b/integrations/atomic-chat-b2/kakeya_sidecar_mlx/kakeya_sidecar_mlx/cache_injection.py
new file mode 100644
index 0000000..cde6b76
--- /dev/null
+++ b/integrations/atomic-chat-b2/kakeya_sidecar_mlx/kakeya_sidecar_mlx/cache_injection.py
@@ -0,0 +1,302 @@
+"""Inject KakeyaLatticeMLXCache into DFlash's MLX target-model forward.
+
+DFlash's MLX driver (``dflash.model_mlx.stream_generate``) manages the
+target model's KV cache internally. There is no stable kwarg for
+"replace target caches with these per-layer objects" across the 2026
+DFlash releases. To avoid pinning to a single dflash version we offer
+**three injection strategies**, ordered by intrusiveness, and pick
+whichever the installed dflash exposes:
+
+    Strategy A — kwarg passthrough
+        If the ``stream_generate`` signature has one of the known cache
+        kwargs (``target_cache``, ``caches``, ``target_caches``,
+        ``prompt_cache``), pass our list directly. Clean, zero monkey
+        patching. Discovered via ``inspect.signature``.
+
+    Strategy B — ``model.make_cache`` monkey-patch
+        mlx-lm models conventionally expose ``.make_cache()`` returning
+        the list that ``dflash.stream_generate`` then consumes. We
+        override that method for the duration of one generate call via
+        a context manager, then restore it. Also clean; requires dflash
+        to delegate to ``model.make_cache()`` rather than build caches
+        inline.
+
+    Strategy C — module-level ``make_prompt_cache`` patch
+        Some dflash revisions call
+        ``mlx_lm.models.cache.make_prompt_cache(model)`` directly. We
+        wrap that module-level function; still reversible via the same
+        context manager.
+
+The adapter is **feature-detected at engine startup**, cached on the
+``_LoadedMLXModel`` instance, and reused per request. If none of the
+strategies apply (e.g. dflash API drift), we log a loud warning and
+fall back to single-track MLX decode + KakeyaLatticeMLXCache — i.e.
+we keep the KV compression benefit but lose the speculative-decode
+speedup.
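+
+A typical call site looks roughly like this (sketch only; assumes the
+detector picked Strategy A, and uses the project's ``make_kakeya_caches``
+factory from ``kakeyalattice_mlx.kv_cache``):
+
+    decision = detect_injection_strategy(stream_generate, model)
+    injector = KakeyaCacheInjector(
+        model=model, variant="e8", q_range=38, boundary=0,
+        strategy=decision.strategy, cache_factory=make_kakeya_caches,
+    )
+    with injector.activate(stream_generate) as caches:
+        for step in stream_generate(model, draft, tokenizer, prompt,
+                                    **injector.extra_kwargs):
+            ...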
+ +Testability: + + - Strategy detection / strategy-object construction is pure Python + and tested via mocks in ``tests/test_cache_injection.py``. + - The live-integration smoke (Apple Silicon only) lives behind + ``pytest.importorskip("mlx.core")`` + ``importorskip("dflash")``. +""" +from __future__ import annotations + +import contextlib +import inspect +import logging +from dataclasses import dataclass +from enum import Enum +from typing import Any, Callable, Iterator + +log = logging.getLogger("kakeya_sidecar_mlx.cache_injection") + + +class InjectionStrategy(str, Enum): + KWARG = "kwarg" # A + MODEL_MAKE_CACHE = "model_make_cache" # B + MODULE_PROMPT_CACHE = "module_prompt_cache" # C + FALLBACK_NATIVE_MLX = "fallback_native_mlx" + + +@dataclass(frozen=True) +class InjectionDecision: + """Result of feature-detecting the best strategy.""" + + strategy: InjectionStrategy + detail: str = "" + + +# --------------------------------------------------------------------------- +# Strategy selection +# --------------------------------------------------------------------------- + +_KWARG_NAMES = ("target_cache", "caches", "target_caches", "prompt_cache") + + +def detect_injection_strategy( + stream_generate_fn: Callable | None, + model: Any | None = None, +) -> InjectionDecision: + """Pick the best injection strategy for the installed dflash. + + Args: + stream_generate_fn: ``dflash.model_mlx.stream_generate`` or + compatible callable; may be ``None`` when dflash is not + installed. + model: the loaded mlx-lm target model; used to test for + ``make_cache``. + + Returns: + :class:`InjectionDecision` naming the chosen strategy. + """ + if stream_generate_fn is None: + return InjectionDecision( + InjectionStrategy.FALLBACK_NATIVE_MLX, + "dflash not installed; no DFlash speculative decoding path available", + ) + + # Strategy A: kwarg passthrough. + try: + sig = inspect.signature(stream_generate_fn) + for name in _KWARG_NAMES: + if name in sig.parameters: + return InjectionDecision( + InjectionStrategy.KWARG, + f"stream_generate accepts kwarg {name!r}", + ) + except (TypeError, ValueError): + pass + + # Strategy B: model.make_cache monkey-patch. + if model is not None and callable(getattr(model, "make_cache", None)): + return InjectionDecision( + InjectionStrategy.MODEL_MAKE_CACHE, + "model.make_cache is callable; patching it for one generate call", + ) + + # Strategy C: module-level make_prompt_cache patch. + try: + import mlx_lm.models.cache as _mlx_lm_cache # type: ignore # noqa: F401 + return InjectionDecision( + InjectionStrategy.MODULE_PROMPT_CACHE, + "mlx_lm.models.cache.make_prompt_cache present", + ) + except ImportError: + pass + + return InjectionDecision( + InjectionStrategy.FALLBACK_NATIVE_MLX, + "no compatible dflash injection surface found", + ) + + +# --------------------------------------------------------------------------- +# Injector — applies the chosen strategy for one generate call +# --------------------------------------------------------------------------- + + +class KakeyaCacheInjector: + """Build + inject KakeyaLatticeMLXCache into one DFlash call. + + Example: + + injector = KakeyaCacheInjector( + model=target_model, + variant="e8", q_range=38, boundary=0, + strategy=InjectionStrategy.MODEL_MAKE_CACHE, + cache_factory=make_kakeya_caches, + ) + with injector.activate(): + for tok in dflash.stream_generate(target_model, draft, tok, prompt, + **injector.extra_kwargs): + ... 
+ + ``cache_factory`` is injected to make testing trivial — tests pass a + stub factory that returns ``["layer0", "layer1", ...]`` strings and + assert that the injector wires them into the right surface. + """ + + def __init__( + self, + model: Any, + *, + variant: str = "e8", + q_range: int = 38, + boundary: int = 0, + strategy: InjectionStrategy = InjectionStrategy.FALLBACK_NATIVE_MLX, + cache_factory: Callable | None = None, + ) -> None: + self.model = model + self.variant = variant + self.q_range = int(q_range) + self.boundary = int(boundary) + self.strategy = strategy + self._cache_factory = cache_factory or self._default_factory + self._built_caches: list[Any] | None = None + self._extra_kwargs: dict[str, Any] = {} + + @staticmethod + def _default_factory(model, **kw): + from kakeyalattice_mlx import KakeyaLatticeMLXCache # noqa: F401 + from kakeyalattice_mlx.kv_cache import make_kakeya_caches + return make_kakeya_caches(model, **kw) + + def build(self) -> list[Any]: + self._built_caches = self._cache_factory( + self.model, + variant=self.variant, + q_range=self.q_range, + boundary=self.boundary, + ) + return self._built_caches + + @property + def caches(self) -> list[Any] | None: + return self._built_caches + + @property + def extra_kwargs(self) -> dict[str, Any]: + """kwargs to splice into ``stream_generate(...)`` when Strategy A. + + Strategy A adds ``{kwarg_name: caches}``; other strategies keep + this empty because they mutate state instead. + """ + return dict(self._extra_kwargs) + + @contextlib.contextmanager + def activate( + self, + stream_generate_fn: Callable | None = None, + ) -> Iterator[list[Any] | None]: + """Context-managed injection. Yields the built cache list. + + Args: + stream_generate_fn: needed only for Strategy A so we can + discover which kwarg name to use. + """ + caches = self.build() + if self.strategy == InjectionStrategy.FALLBACK_NATIVE_MLX: + # No DFlash path; the engine will use `caches` directly on + # the target model. + yield caches + return + + if self.strategy == InjectionStrategy.KWARG: + kwarg = self._resolve_kwarg(stream_generate_fn) + if kwarg is None: + log.warning( + "KWARG strategy selected but no matching kwarg on " + "stream_generate; downgrading to FALLBACK_NATIVE_MLX" + ) + yield caches + return + self._extra_kwargs = {kwarg: caches} + try: + yield caches + finally: + self._extra_kwargs = {} + return + + if self.strategy == InjectionStrategy.MODEL_MAKE_CACHE: + original = getattr(self.model, "make_cache", None) + + def _patched_make_cache(*_a, **_kw): + return caches + + setattr(self.model, "make_cache", _patched_make_cache) + try: + yield caches + finally: + if original is not None: + setattr(self.model, "make_cache", original) + else: + try: + delattr(self.model, "make_cache") + except AttributeError: + pass + return + + if self.strategy == InjectionStrategy.MODULE_PROMPT_CACHE: + import mlx_lm.models.cache as _c # type: ignore + + original = getattr(_c, "make_prompt_cache", None) + + def _patched_make_prompt_cache(model, *_a, **_kw): + return caches + + setattr(_c, "make_prompt_cache", _patched_make_prompt_cache) + try: + yield caches + finally: + if original is not None: + setattr(_c, "make_prompt_cache", original) + return + + # Unknown strategy — be loud, don't silently fall through. 
+ raise RuntimeError( + f"unknown InjectionStrategy {self.strategy!r} " + "— cache_injection.py is out of sync" + ) + + # ----- helpers ----- + + @staticmethod + def _resolve_kwarg(stream_generate_fn: Callable | None) -> str | None: + if stream_generate_fn is None: + return None + try: + sig = inspect.signature(stream_generate_fn) + for name in _KWARG_NAMES: + if name in sig.parameters: + return name + except (TypeError, ValueError): + return None + return None + + +__all__ = [ + "InjectionStrategy", + "InjectionDecision", + "detect_injection_strategy", + "KakeyaCacheInjector", +] diff --git a/integrations/atomic-chat-b2/kakeya_sidecar_mlx/kakeya_sidecar_mlx/engine_mlx.py b/integrations/atomic-chat-b2/kakeya_sidecar_mlx/kakeya_sidecar_mlx/engine_mlx.py index eb25c34..1437a20 100644 --- a/integrations/atomic-chat-b2/kakeya_sidecar_mlx/kakeya_sidecar_mlx/engine_mlx.py +++ b/integrations/atomic-chat-b2/kakeya_sidecar_mlx/kakeya_sidecar_mlx/engine_mlx.py @@ -1,23 +1,24 @@ -"""MLX inference engine — *skeleton* for B2. - -This file intentionally stops short of a working generate() loop; that -work is gated on M4 (DFlash integration) and belongs in a separate PR. -What IS finalised here: - -- ``MLXEngineConfig`` dataclass with the same shape as B1's - ``EngineConfig`` so Atomic-Chat's plugin can swap sidecars. -- ``MLXEngine`` with ``_ensure_loaded`` LRU and a clear - ``NotImplementedError`` on ``.chat()`` / ``.chat_stream()`` that - points downstream PRs at what still needs implementing. -- The warmup path (model load only) IS implemented, so - ``kakeya-sidecar-mlx --prewarm `` is enough to validate - that weights load on Apple Silicon. - -Why skeleton-first: the mlx-lm API is a moving target (0.20, 0.21 -reshaped ``generate_step`` signatures). Locking in a dummy chat path -would either (a) pin us to a fragile version or (b) silently skew -from B1's behaviour. Better to gate the real implementation on a -dedicated PR that CI-validates on a Mac runner. +"""MLX inference engine — M4 integrates DFlash speculative decoding. + +Two code paths: + +1. **DFlash path**: when ``cfg.enable_dflash`` and the resolved channel + has ``dflash_available=True``. We load ``dflash.model_mlx.load_draft``, + detect an injection strategy for the target KV cache + (`cache_injection.detect_injection_strategy`), wrap target caches in + ``KakeyaLatticeMLXCache``, and delegate to + ``dflash.model_mlx.stream_generate`` under an ``activate()`` context. + +2. **Native MLX path** (fallback): when DFlash is off, the draft repo + isn't available, or the dflash API doesn't expose any injection + surface. We use ``mlx_lm.generate`` directly with the Kakeya caches + passed on the target model. + +Both paths emit the same ``chat()`` return shape as B1 so the server +layer doesn't need to fork. + +MLX / mlx-lm imports are done lazily inside method bodies so the +module is importable on Linux CI for the pure-logic tests. 
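+
+Minimal usage sketch (channel id taken from this PR's test registry; the
+call shape matches B1's engine, so Atomic-Chat can swap sidecars):
+
+    eng = MLXEngine(MLXEngineConfig(enable_dflash=True))
+    text, stats = eng.chat(
+        "qwen3-8b@e8-q38",
+        [{"role": "user", "content": "hi"}],
+        max_tokens=64,
+    )
+    # stats includes dflash_used / injection_strategy / acceptance_length_mean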
""" from __future__ import annotations @@ -26,8 +27,14 @@ import time from collections import OrderedDict from dataclasses import dataclass, field -from typing import Any, Iterator +from typing import Any, Callable, Iterator +from .cache_injection import ( + InjectionDecision, + InjectionStrategy, + KakeyaCacheInjector, + detect_injection_strategy, +) from .model_registry_mlx import ( MLXChannel, MLXDeploymentProfile, @@ -39,14 +46,15 @@ @dataclass class MLXEngineConfig: - device: str = "auto" # "auto" | "mps" | "cpu" - dtype: str = "auto" # "auto" | "bfloat16" | "float16" | "float32" + device: str = "auto" + dtype: str = "auto" max_resident: int = 1 enable_dflash: bool = False trust_remote_code: bool = True hf_cache_dir: str | None = None - - # Runtime prefs set by --enable-dflash + channel.dflash_available. + dflash_block_size: int = 16 + dflash_num_speculative_tokens: int = 16 + dflash_sliding_window_size: int | None = None _runtime: dict[str, Any] = field(default_factory=dict) @@ -62,8 +70,17 @@ def _pick_device(requested: str) -> str: return "cpu" +# --------------------------------------------------------------------------- +# _LoadedMLXModel +# --------------------------------------------------------------------------- + + class _LoadedMLXModel: - """Holds a loaded mlx-lm model + tokenizer + optional DFlash draft.""" + """Holds a loaded mlx-lm target model + tokenizer + optional DFlash draft. + + Also pre-computes the injection decision once per model load so the + hot path doesn't re-inspect dflash's signature per request. + """ def __init__( self, @@ -74,7 +91,7 @@ def __init__( from mlx_lm import load # type: ignore repo = profile.mlx_repo_id or profile.hf_repo_id - log.info("mlx_lm.load(%s) ...", repo) + log.info("mlx_lm.load(%s)", repo) t0 = time.time() self.model, self.tokenizer = load(repo) log.info("loaded target %s in %.1fs", repo, time.time() - t0) @@ -82,32 +99,51 @@ def __init__( self.profile = profile self.channel = channel self.draft_model = None - self.draft_tokenizer = None + self._stream_generate: Callable | None = None + self._injection_decision: InjectionDecision = InjectionDecision( + InjectionStrategy.FALLBACK_NATIVE_MLX, "DFlash not enabled" + ) if cfg.enable_dflash and channel.dflash_available: - self._load_dflash_draft(channel) + self._maybe_load_dflash(channel) - def _load_dflash_draft(self, channel: MLXChannel) -> None: + def _maybe_load_dflash(self, channel: MLXChannel) -> None: repo = channel.dflash_draft_repo if repo is None: return try: - from dflash.model_mlx import load_draft # type: ignore + from dflash.model_mlx import ( # type: ignore + load_draft as _load_draft, + stream_generate as _stream_generate, + ) except ImportError: log.warning( - "dflash not installed; install with `pip install dflash` " - "to enable speculative decoding. Falling back to " - "single-track decode." + "dflash not importable; `pip install dflash` to enable " + "speculative decoding. Falling back to single-track MLX." 
) return - log.info("dflash.load_draft(%s) ...", repo) + log.info("dflash.load_draft(%s)", repo) t0 = time.time() - self.draft_model = load_draft(repo) + self.draft_model = _load_draft(repo) log.info("loaded draft %s in %.1fs", repo, time.time() - t0) + self._stream_generate = _stream_generate + self._injection_decision = detect_injection_strategy( + _stream_generate, self.model, + ) + log.info( + "DFlash injection strategy = %s (%s)", + self._injection_decision.strategy.value, + self._injection_decision.detail, + ) + + +# --------------------------------------------------------------------------- +# MLXEngine +# --------------------------------------------------------------------------- class MLXEngine: - """Skeleton MLX engine. Implements warmup + LRU; defers generate.""" + """M4 MLXEngine: DFlash + Kakeya KV, or native-MLX fallback.""" def __init__(self, cfg: MLXEngineConfig | None = None) -> None: self.cfg = cfg or MLXEngineConfig() @@ -120,6 +156,8 @@ def __init__(self, cfg: MLXEngineConfig | None = None) -> None: ) # ------------------------------------------------------------------ + # Model lifecycle + # ------------------------------------------------------------------ def _ensure_loaded( self, profile: MLXDeploymentProfile, channel: MLXChannel @@ -128,7 +166,6 @@ def _ensure_loaded( if profile.short_id in self._loaded: self._loaded.move_to_end(profile.short_id) return self._loaded[profile.short_id] - lm = _LoadedMLXModel(profile, self.cfg, channel) self._loaded[profile.short_id] = lm while len(self._loaded) > self.cfg.max_resident: @@ -142,6 +179,8 @@ def warmup(self, channel_id: str) -> None: self._ensure_loaded(profile, channel) # ------------------------------------------------------------------ + # Chat entry points + # ------------------------------------------------------------------ def chat( self, @@ -154,12 +193,30 @@ def chat( stop: list[str] | None = None, override: dict[str, Any] | None = None, ) -> tuple[str, dict[str, Any]]: - raise NotImplementedError( - "MLXEngine.chat() is a M4 deliverable. " - "Current PR (M1-M3) ships only model loading, registry, " - "and server routing. See integrations/atomic-chat-b2/ROADMAP.md " - "M4 for the DFlash-integrated generate loop." - ) + """Non-streaming chat completion. + + Returns ``(text, stats)`` matching B1's shape. ``stats`` adds + DFlash-specific fields: ``dflash_used``, ``injection_strategy``, + ``acceptance_length_mean``. + """ + profile, channel = resolve_mlx_model(channel_id) + if override: + channel = self._apply_override(channel, override) + lm = self._ensure_loaded(profile, channel) + + pieces: list[str] = [] + stats: dict[str, Any] = {} + for piece, partial_stats in self._run_stream( + lm, channel, messages, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + stop=stop, + ): + pieces.append(piece) + stats = partial_stats # last update wins + + return "".join(pieces), stats def chat_stream( self, @@ -172,6 +229,190 @@ def chat_stream( stop: list[str] | None = None, override: dict[str, Any] | None = None, ) -> Iterator[str]: - raise NotImplementedError( - "MLXEngine.chat_stream() is a M4 deliverable." 
+ profile, channel = resolve_mlx_model(channel_id) + if override: + channel = self._apply_override(channel, override) + lm = self._ensure_loaded(profile, channel) + for piece, _stats in self._run_stream( + lm, channel, messages, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + stop=stop, + ): + yield piece + + # ------------------------------------------------------------------ + # Core: streaming generator, used by both chat() and chat_stream() + # ------------------------------------------------------------------ + + def _run_stream( + self, + lm: _LoadedMLXModel, + channel: MLXChannel, + messages: list[dict[str, Any]], + *, + max_tokens: int, + temperature: float, + top_p: float, + stop: list[str] | None, + ) -> Iterator[tuple[str, dict[str, Any]]]: + prompt = self._render_prompt(lm, messages) + + injector = KakeyaCacheInjector( + model=lm.model, + variant=channel.variant, + q_range=channel.q_range, + boundary=channel.boundary, + strategy=lm._injection_decision.strategy, + ) + + t0 = time.time() + if ( + lm._stream_generate is not None + and lm.draft_model is not None + and lm._injection_decision.strategy + != InjectionStrategy.FALLBACK_NATIVE_MLX + ): + iterator_factory = self._dflash_iter_factory( + lm, prompt, max_tokens, temperature, top_p, + ) + yielded_pieces: list[str] = [] + accept_lens: list[int] = [] + with injector.activate(lm._stream_generate): + for piece, step_info in iterator_factory( + extra_kwargs=injector.extra_kwargs + ): + yielded_pieces.append(piece) + al = step_info.get("acceptance_length") + if al is not None: + accept_lens.append(int(al)) + yield piece, self._stats( + channel, t0, yielded_pieces, accept_lens, + dflash_used=True, lm=lm, + ) + if stop and any(s in "".join(yielded_pieces) for s in stop): + break + else: + # Native MLX fallback. + from mlx_lm.generate import stream_generate as _mlx_stream # type: ignore + + yielded_pieces = [] + caches = injector.build() + for piece in _mlx_stream( + lm.model, lm.tokenizer, prompt=prompt, + max_tokens=max_tokens, + temp=max(temperature, 1e-4), + top_p=top_p, + prompt_cache=caches, + ): + yielded_pieces.append(piece) + yield piece, self._stats( + channel, t0, yielded_pieces, [], + dflash_used=False, lm=lm, + ) + if stop and any(s in "".join(yielded_pieces) for s in stop): + break + + def _dflash_iter_factory( + self, lm, prompt, max_tokens, temperature, top_p, + ) -> Callable: + """Build a callable that produces (text, step_info) tuples. + + Abstracted so tests can substitute a mock without touching the + real dflash import. + """ + stream_generate = lm._stream_generate + + block_size = self.cfg.dflash_block_size + draft = lm.draft_model + model = lm.model + tokenizer = lm.tokenizer + + def _factory(extra_kwargs: dict[str, Any]): + for step in stream_generate( + model, + draft, + tokenizer, + prompt, + block_size=block_size, + max_tokens=max_tokens, + temperature=temperature, + **extra_kwargs, + ): + # dflash.model_mlx.stream_generate emits objects with + # at least `.text` and often `.accepted_length` / + # `.generation_tps`. We normalise into (text, info). 
+ text = getattr(step, "text", None) or getattr(step, "delta", "") or "" + info: dict[str, Any] = {} + if hasattr(step, "accepted_length"): + info["acceptance_length"] = step.accepted_length + elif hasattr(step, "acceptance_length"): + info["acceptance_length"] = step.acceptance_length + if hasattr(step, "generation_tps"): + info["generation_tps"] = step.generation_tps + yield text, info + + return _factory + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + @staticmethod + def _apply_override(channel: MLXChannel, override: dict[str, Any]) -> MLXChannel: + return MLXChannel( + variant=override.get("variant", channel.variant), + q_range=int(override.get("q_range", channel.q_range)), + boundary=int(override.get("boundary", channel.boundary)), + est_compression=channel.est_compression, + est_delta_ppl_pct=channel.est_delta_ppl_pct, + label=channel.label, + dflash_draft_repo=channel.dflash_draft_repo, + dflash_available=channel.dflash_available, + ) + + @staticmethod + def _render_prompt(lm: _LoadedMLXModel, messages: list[dict[str, Any]]) -> str: + """Apply the target model's chat template to the message list.""" + tok = lm.tokenizer + apply = getattr(tok, "apply_chat_template", None) + if callable(apply): + return apply(messages, tokenize=False, add_generation_prompt=True) + # Absolute fallback — a flat prompt, used only if the tokenizer + # has no chat template (rare for the curated B2 registry). + lines = [] + for m in messages: + role = m.get("role", "user") + content = m.get("content", "") + lines.append(f"<|{role}|>\n{content}") + lines.append("<|assistant|>\n") + return "\n".join(lines) + + @staticmethod + def _stats( + channel: MLXChannel, + t0: float, + pieces: list[str], + accept_lens: list[int], + *, + dflash_used: bool, + lm: _LoadedMLXModel, + ) -> dict[str, Any]: + gen_time = time.time() - t0 + mean_accept = ( + sum(accept_lens) / len(accept_lens) if accept_lens else None ) + return { + "variant": channel.variant, + "q_range": channel.q_range, + "boundary": channel.boundary, + "est_compression": channel.est_compression, + "est_delta_ppl_pct": channel.est_delta_ppl_pct, + "dflash_used": dflash_used, + "injection_strategy": lm._injection_decision.strategy.value, + "dflash_draft_repo": channel.dflash_draft_repo, + "generation_time_s": gen_time, + "acceptance_length_mean": mean_accept, + "generated_chars": sum(len(p) for p in pieces), + } diff --git a/integrations/atomic-chat-b2/kakeya_sidecar_mlx/kakeya_sidecar_mlx/server.py b/integrations/atomic-chat-b2/kakeya_sidecar_mlx/kakeya_sidecar_mlx/server.py index 43e4548..7ad1a48 100644 --- a/integrations/atomic-chat-b2/kakeya_sidecar_mlx/kakeya_sidecar_mlx/server.py +++ b/integrations/atomic-chat-b2/kakeya_sidecar_mlx/kakeya_sidecar_mlx/server.py @@ -1,24 +1,30 @@ -"""B2 FastAPI server — same route shape as B1, MLX-specific metadata. +"""B2 FastAPI server — M4 opens /v1/chat/completions. -Endpoints: +Route shape mirrors B1 (PR #57): + GET /health + GET /v1/models + POST /v1/chat/completions (stream + non-stream) + GET /v1/kakeya/stats - GET /health - GET /v1/models - POST /v1/chat/completions (stream + non-stream — **503 until M4**) - GET /v1/kakeya/stats - -Until the M4 PR lands, ``/v1/chat/completions`` returns HTTP 503 with -a body pointing at ``ROADMAP.md``. This is deliberate — better a clean -503 than a half-working chat that diverges from B1. 
+The B2-specific surface additions: + - /v1/models entries carry ``x_kakeya.dflash_draft_repo`` and + ``x_kakeya.dflash_available``. + - /health reports the MLX backend variant and whether DFlash is + enabled on this engine instance. + - /v1/chat/completions response ``x_kakeya`` carries ``dflash_used``, + ``injection_strategy``, and ``acceptance_length_mean``. """ from __future__ import annotations +import json import logging import time +import uuid from typing import Any from fastapi import FastAPI, HTTPException -from fastapi.responses import JSONResponse +from fastapi.responses import JSONResponse, StreamingResponse +from pydantic import BaseModel, ConfigDict, Field from .engine_mlx import MLXEngine, MLXEngineConfig from .model_registry_mlx import MODEL_REGISTRY_MLX @@ -26,13 +32,43 @@ log = logging.getLogger("kakeya_sidecar_mlx.server") +# --------------------------------------------------------------------------- +# request / response schemas (subset of OpenAI spec + x_kakeya extension) +# --------------------------------------------------------------------------- + + +class _ChatMessage(BaseModel): + model_config = ConfigDict(extra="allow") + role: str + content: Any + + +class _ChatCompletionRequest(BaseModel): + model_config = ConfigDict(extra="allow") + model: str + messages: list[_ChatMessage] + stream: bool = False + temperature: float = 0.7 + top_p: float = 1.0 + max_tokens: int | None = None + stop: Any | None = None + x_kakeya_override: dict[str, Any] | None = Field( + default=None, alias="x_kakeya_override", + ) + + +# --------------------------------------------------------------------------- +# app factory +# --------------------------------------------------------------------------- + + def create_app( cfg: MLXEngineConfig | None = None, *, lazy_engine: bool = True, engine_instance: MLXEngine | None = None, ) -> FastAPI: - app = FastAPI(title="kakeya-sidecar-mlx", version="0.1.0") + app = FastAPI(title="kakeya-sidecar-mlx", version="0.2.0") state: dict[str, Any] = { "engine": engine_instance, @@ -55,7 +91,8 @@ def health() -> dict[str, Any]: "ok": True, "engine_loaded": state["engine"] is not None, "variant": "B2 (MLX + DFlash + KakeyaLattice)", - "milestone": "M1-M3 skeleton; /v1/chat/completions disabled until M4", + "dflash_enabled": state["cfg"].enable_dflash, + "milestone": "M4 — chat completions live (DFlash + KV compression)", } # ------------------------------------------------------------- /models @@ -91,17 +128,67 @@ def list_models() -> dict[str, Any]: # --------------------------------------------------- /chat/completions @app.post("/v1/chat/completions") - def chat_completions(_body: dict[str, Any]) -> JSONResponse: - raise HTTPException( - status_code=503, - detail=( - "B2 sidecar is at M1-M3 skeleton stage. " - "Chat completion will be enabled in the M4 PR " - "(DFlash integration). For now please use the B1 " - "sidecar on :1338." 
- ), + def chat_completions(req: _ChatCompletionRequest): + messages = [m.model_dump(exclude_none=True) for m in req.messages] + cid = f"chatcmpl-{uuid.uuid4().hex[:16]}" + created = int(time.time()) + + try: + eng = engine() + except Exception as e: # pragma: no cover + raise HTTPException(500, f"engine init failed: {e}") from e + + max_tokens = req.max_tokens or 512 + stop = ( + [req.stop] if isinstance(req.stop, str) + else (req.stop if isinstance(req.stop, list) else None) ) + if req.stream: + return StreamingResponse( + _sse_stream( + eng, req, cid, created, + messages, max_tokens, stop, + ), + media_type="text/event-stream", + ) + + try: + text, stats = eng.chat( + req.model, + messages, + max_tokens=max_tokens, + temperature=req.temperature, + top_p=req.top_p, + stop=stop, + override=req.x_kakeya_override, + ) + except KeyError as e: + raise HTTPException(404, str(e)) from e + except NotImplementedError as e: + raise HTTPException(501, str(e)) from e + except Exception as e: # pragma: no cover + log.exception("chat failed") + raise HTTPException(500, str(e)) from e + + return JSONResponse({ + "id": cid, + "object": "chat.completion", + "created": created, + "model": req.model, + "choices": [{ + "index": 0, + "message": {"role": "assistant", "content": text}, + "finish_reason": "stop", + }], + "usage": { + "prompt_tokens": 0, + "completion_tokens": stats.get("generated_chars", 0), + "total_tokens": stats.get("generated_chars", 0), + }, + "x_kakeya": stats, + }) + # ------------------------------------------------- /v1/kakeya/stats @app.get("/v1/kakeya/stats") @@ -119,3 +206,55 @@ def kakeya_stats() -> dict[str, Any]: } return app + + +# --------------------------------------------------------------------------- +# SSE helper +# --------------------------------------------------------------------------- + + +def _sse_stream(eng, req, cid: str, created: int, + messages, max_tokens: int, stop): + def chunk(delta: dict[str, Any], finish_reason: str | None = None) -> str: + payload = { + "id": cid, + "object": "chat.completion.chunk", + "created": created, + "model": req.model, + "choices": [{ + "index": 0, + "delta": delta, + "finish_reason": finish_reason, + }], + } + return f"data: {json.dumps(payload)}\n\n" + + yield chunk({"role": "assistant"}) + try: + for piece in eng.chat_stream( + req.model, + messages, + max_tokens=max_tokens, + temperature=req.temperature, + top_p=req.top_p, + stop=stop, + override=req.x_kakeya_override, + ): + if piece: + yield chunk({"content": piece}) + except KeyError as e: + yield chunk({"content": f"[error] {e}"}, finish_reason="stop") + yield "data: [DONE]\n\n" + return + except NotImplementedError as e: + yield chunk({"content": f"[error] {e}"}, finish_reason="stop") + yield "data: [DONE]\n\n" + return + except Exception as e: # pragma: no cover + log.exception("stream failed") + yield chunk({"content": f"[error] {e}"}, finish_reason="stop") + yield "data: [DONE]\n\n" + return + + yield chunk({}, finish_reason="stop") + yield "data: [DONE]\n\n" diff --git a/integrations/atomic-chat-b2/kakeya_sidecar_mlx/tests/test_cache_injection.py b/integrations/atomic-chat-b2/kakeya_sidecar_mlx/tests/test_cache_injection.py new file mode 100644 index 0000000..9b5dbb3 --- /dev/null +++ b/integrations/atomic-chat-b2/kakeya_sidecar_mlx/tests/test_cache_injection.py @@ -0,0 +1,171 @@ +"""Unit tests for ``cache_injection`` — no MLX / no dflash required. 
+ +We feed the strategy detector synthetic stream_generate signatures +and stub target models, and verify: + +* KWARG strategy fires when the callable exposes one of the + well-known kwarg names +* MODEL_MAKE_CACHE strategy fires when the target has ``make_cache`` +* FALLBACK path when everything else is missing +* ``activate()`` context manager applies + cleans up each strategy + correctly (no residual state leaks after exit) +""" +from __future__ import annotations + +import pytest + +from kakeya_sidecar_mlx.cache_injection import ( + InjectionStrategy, + KakeyaCacheInjector, + detect_injection_strategy, +) + + +# --------------------------------------------------------------------------- +# Strategy detection +# --------------------------------------------------------------------------- + + +def test_detect_no_fn_is_fallback() -> None: + d = detect_injection_strategy(None, model=object()) + assert d.strategy == InjectionStrategy.FALLBACK_NATIVE_MLX + + +def test_detect_kwarg_target_cache() -> None: + def fake_stream(model, draft, tok, prompt, *, target_cache=None): + ... + + d = detect_injection_strategy(fake_stream, model=None) + assert d.strategy == InjectionStrategy.KWARG + assert "target_cache" in d.detail + + +def test_detect_kwarg_caches() -> None: + def fake_stream(model, draft, tok, prompt, caches=None): + ... + + d = detect_injection_strategy(fake_stream, model=None) + assert d.strategy == InjectionStrategy.KWARG + + +def test_detect_kwarg_prompt_cache() -> None: + def fake_stream(model, draft, tok, prompt, prompt_cache=None): + ... + + d = detect_injection_strategy(fake_stream, model=None) + assert d.strategy == InjectionStrategy.KWARG + + +def test_detect_model_make_cache() -> None: + def fake_stream(model, draft, tok, prompt): # no matching kwarg + ... + + class _M: + def make_cache(self): + return [] + + d = detect_injection_strategy(fake_stream, model=_M()) + assert d.strategy == InjectionStrategy.MODEL_MAKE_CACHE + + +# --------------------------------------------------------------------------- +# Injector.activate() — state management +# --------------------------------------------------------------------------- + + +def test_fallback_activate_yields_caches_and_no_patch() -> None: + model = object() + inj = KakeyaCacheInjector( + model=model, + strategy=InjectionStrategy.FALLBACK_NATIVE_MLX, + cache_factory=lambda m, **_kw: ["layer0", "layer1"], + ) + with inj.activate() as caches: + assert caches == ["layer0", "layer1"] + assert inj.extra_kwargs == {} + + +def test_kwarg_activate_sets_extra_kwargs_and_cleans_up() -> None: + def fake_stream(m, d, t, p, *, target_cache=None): + ... + + inj = KakeyaCacheInjector( + model=object(), + strategy=InjectionStrategy.KWARG, + cache_factory=lambda m, **_kw: ["C0", "C1"], + ) + with inj.activate(fake_stream) as caches: + assert caches == ["C0", "C1"] + assert inj.extra_kwargs == {"target_cache": ["C0", "C1"]} + assert inj.extra_kwargs == {} # cleaned up + + +def test_kwarg_strategy_with_unresolved_kwarg_downgrades_silently() -> None: + def no_matching_kwarg(m, d, t, p): + ... 
+ + inj = KakeyaCacheInjector( + model=object(), + strategy=InjectionStrategy.KWARG, + cache_factory=lambda m, **_kw: ["X"], + ) + with inj.activate(no_matching_kwarg) as caches: + assert caches == ["X"] + assert inj.extra_kwargs == {} # downgrade — no kwarg injected + + +def test_model_make_cache_patch_restores_original() -> None: + class _M: + def __init__(self): + self.original_called = 0 + + def make_cache(self): + self.original_called += 1 + return ["ORIG"] + + m = _M() + inj = KakeyaCacheInjector( + model=m, + strategy=InjectionStrategy.MODEL_MAKE_CACHE, + cache_factory=lambda _m, **_kw: ["PATCHED"], + ) + # Before activate, calling make_cache returns original. + assert m.make_cache() == ["ORIG"] + assert m.original_called == 1 + + with inj.activate(None) as caches: + assert caches == ["PATCHED"] + # Inside activate, make_cache returns the injected caches. + assert m.make_cache() == ["PATCHED"] + + # After activate, original behaviour restored. + assert m.make_cache() == ["ORIG"] + assert m.original_called == 2 + + +def test_unknown_strategy_raises() -> None: + inj = KakeyaCacheInjector( + model=object(), + strategy="not_a_real_strategy", # type: ignore[arg-type] + cache_factory=lambda _m, **_kw: [], + ) + with pytest.raises(RuntimeError): + with inj.activate(None): + pass + + +def test_build_caches_passes_config_through_to_factory() -> None: + captured: dict = {} + + def _factory(model, **kw): + captured.update(kw) + return ["ok"] + + inj = KakeyaCacheInjector( + model=object(), + variant="e8", q_range=10, boundary=2, + strategy=InjectionStrategy.FALLBACK_NATIVE_MLX, + cache_factory=_factory, + ) + inj.build() + assert captured == {"variant": "e8", "q_range": 10, "boundary": 2} diff --git a/integrations/atomic-chat-b2/kakeya_sidecar_mlx/tests/test_engine_routing.py b/integrations/atomic-chat-b2/kakeya_sidecar_mlx/tests/test_engine_routing.py new file mode 100644 index 0000000..83f1de0 --- /dev/null +++ b/integrations/atomic-chat-b2/kakeya_sidecar_mlx/tests/test_engine_routing.py @@ -0,0 +1,270 @@ +"""Routing tests for ``MLXEngine._run_stream``. + +The real dflash / mlx-lm generators only exist on Apple Silicon, but +the routing logic inside ``_run_stream`` (DFlash vs native MLX) is +pure control flow. We exercise it by: + +1. Hand-constructing a ``_LoadedMLXModel``-shaped object WITHOUT + calling ``_LoadedMLXModel.__init__`` (which would try to import + mlx-lm). ``object.__new__`` bypasses init so we can set the + fields directly. +2. Overriding ``MLXEngine._ensure_loaded`` so no real weights get + fetched. +3. Stubbing ``_dflash_iter_factory`` via monkeypatch — we control + exactly what pieces + accept-lengths the fake DFlash emits. +4. For the native-MLX fallback we monkey-patch the deferred import + of ``mlx_lm.generate.stream_generate`` with a generator stub. 
+""" +from __future__ import annotations + +from typing import Any + +import pytest + +from kakeya_sidecar_mlx.cache_injection import ( + InjectionDecision, + InjectionStrategy, +) +from kakeya_sidecar_mlx.engine_mlx import ( + MLXEngine, + MLXEngineConfig, + _LoadedMLXModel, +) +from kakeya_sidecar_mlx.model_registry_mlx import resolve_mlx_model + + +class _FakeTokenizer: + """Matches the ``apply_chat_template`` method mlx-lm exposes.""" + + def apply_chat_template(self, messages, tokenize: bool, add_generation_prompt: bool): + parts = [] + for m in messages: + parts.append(f"{m['role']}: {m['content']}") + return " | ".join(parts) + " | assistant:" + + +def _make_fake_lm(dflash: bool, channel) -> _LoadedMLXModel: + """Build a ``_LoadedMLXModel`` without importing mlx-lm.""" + lm = object.__new__(_LoadedMLXModel) + lm.model = object() + lm.tokenizer = _FakeTokenizer() + lm.profile = None + lm.channel = channel + lm.draft_model = object() if dflash else None + lm._stream_generate = (lambda *a, **kw: iter([])) if dflash else None + lm._injection_decision = InjectionDecision( + InjectionStrategy.KWARG if dflash else InjectionStrategy.FALLBACK_NATIVE_MLX, + "test fixture", + ) + return lm + + +# --------------------------------------------------------------------------- +# DFlash path +# --------------------------------------------------------------------------- + + +def test_dflash_path_aggregates_text_and_acceptance(monkeypatch): + cfg = MLXEngineConfig(enable_dflash=True) + engine = MLXEngine(cfg) + + profile, channel = resolve_mlx_model("qwen3-8b@e8-q38") + lm = _make_fake_lm(dflash=True, channel=channel) + engine._loaded[profile.short_id] = lm + # Freeze ensure_loaded so no import happens. + monkeypatch.setattr( + engine, "_ensure_loaded", lambda prof, ch: lm, + ) + + # Stub dflash iterator: emit 3 blocks with varying acceptance. + class _Step: + def __init__(self, text, al): + self.text = text + self.accepted_length = al + + steps = [_Step("Hel", 12), _Step("lo ", 10), _Step("world.", 8)] + + def _fake_factory(lm_arg, prompt, max_tokens, temperature, top_p): + assert "user:" in prompt # chat template applied + def _run(extra_kwargs): + for s in steps: + yield s.text, {"acceptance_length": s.accepted_length} + return _run + + monkeypatch.setattr(engine, "_dflash_iter_factory", _fake_factory) + + # Also stub injector so it doesn't try to build real caches. + from kakeya_sidecar_mlx import engine_mlx as em + + class _StubInjector: + def __init__(self, *a, **kw): self.extra_kwargs = {} + def activate(self, _sg): + class _Ctx: + def __enter__(self_): return ["c0", "c1"] + def __exit__(self_, *exc): return False + return _Ctx() + def build(self): return ["c0", "c1"] + + monkeypatch.setattr(em, "KakeyaCacheInjector", _StubInjector) + + text, stats = engine.chat( + "qwen3-8b@e8-q38", + [{"role": "user", "content": "hi"}], + max_tokens=32, temperature=0.0, + ) + assert text == "Hello world." 
+ assert stats["dflash_used"] is True + assert stats["injection_strategy"] == "kwarg" + assert stats["acceptance_length_mean"] == pytest.approx((12 + 10 + 8) / 3) + assert stats["variant"] == "e8" + assert stats["q_range"] == 38 + + +def test_dflash_stream_stops_on_stop_substring(monkeypatch): + cfg = MLXEngineConfig(enable_dflash=True) + engine = MLXEngine(cfg) + + profile, channel = resolve_mlx_model("qwen3-8b@e8-q38") + lm = _make_fake_lm(dflash=True, channel=channel) + engine._loaded[profile.short_id] = lm + monkeypatch.setattr(engine, "_ensure_loaded", lambda p, c: lm) + + def _fake_factory(lm_arg, prompt, max_tokens, temperature, top_p): + def _run(extra_kwargs): + yield "hello STOP more text", {"acceptance_length": 5} + yield "should not be yielded", {"acceptance_length": 5} + return _run + + monkeypatch.setattr(engine, "_dflash_iter_factory", _fake_factory) + + from kakeya_sidecar_mlx import engine_mlx as em + + class _StubInjector: + def __init__(self, *a, **kw): self.extra_kwargs = {} + def activate(self, _sg): + class _Ctx: + def __enter__(self_): return [] + def __exit__(self_, *exc): return False + return _Ctx() + def build(self): return [] + + monkeypatch.setattr(em, "KakeyaCacheInjector", _StubInjector) + + pieces = list(engine.chat_stream( + "qwen3-8b@e8-q38", + [{"role": "user", "content": "hi"}], + max_tokens=32, + stop=["STOP"], + )) + assert pieces == ["hello STOP more text"] + + +# --------------------------------------------------------------------------- +# Native MLX fallback path +# --------------------------------------------------------------------------- + + +def test_native_mlx_fallback_used_when_dflash_unavailable(monkeypatch): + """Mistral has no DFlash draft; engine must fall back cleanly.""" + cfg = MLXEngineConfig(enable_dflash=True) # even with enable=True + engine = MLXEngine(cfg) + + profile, channel = resolve_mlx_model("mistral-7b-instruct-v0.3@e8-q38") + assert channel.dflash_available is False + + lm = _make_fake_lm(dflash=False, channel=channel) + engine._loaded[profile.short_id] = lm + monkeypatch.setattr(engine, "_ensure_loaded", lambda p, c: lm) + + # Stub the mlx_lm.generate.stream_generate import. + import sys + import types as _types + + fake_generate = _types.ModuleType("mlx_lm.generate") + + def _fake_stream(model, tokenizer, *, prompt, max_tokens, temp, top_p, prompt_cache): + # Yield a couple of tokens then stop. + yield "Bonjour " + yield "monde." + + fake_generate.stream_generate = _fake_stream + + # If mlx_lm already imported (unlikely on CI), splice a fake submodule. + sys.modules["mlx_lm.generate"] = fake_generate + sys.modules.setdefault("mlx_lm", _types.ModuleType("mlx_lm")) + + from kakeya_sidecar_mlx import engine_mlx as em + + class _StubInjector: + def __init__(self, *a, **kw): self.extra_kwargs = {} + def activate(self, _sg): + class _Ctx: + def __enter__(self_): return [] + def __exit__(self_, *exc): return False + return _Ctx() + def build(self): return [] + + monkeypatch.setattr(em, "KakeyaCacheInjector", _StubInjector) + + text, stats = engine.chat( + "mistral-7b-instruct-v0.3@e8-q38", + [{"role": "user", "content": "bonjour"}], + max_tokens=16, + ) + assert text == "Bonjour monde." 
+ assert stats["dflash_used"] is False + assert stats["injection_strategy"] == "fallback_native_mlx" + assert stats["acceptance_length_mean"] is None + + +# --------------------------------------------------------------------------- +# Override +# --------------------------------------------------------------------------- + + +def test_override_applies_per_request(monkeypatch): + cfg = MLXEngineConfig(enable_dflash=False) + engine = MLXEngine(cfg) + + profile, channel = resolve_mlx_model("qwen3-8b@e8-q38") + lm = _make_fake_lm(dflash=False, channel=channel) + engine._loaded[profile.short_id] = lm + monkeypatch.setattr(engine, "_ensure_loaded", lambda p, c: lm) + + # Fake native-mlx stream_generate (minimal). + import sys, types as _types + fake_generate = _types.ModuleType("mlx_lm.generate") + + def _fake_stream(model, tokenizer, *, prompt, max_tokens, temp, top_p, prompt_cache): + yield "ok" + + fake_generate.stream_generate = _fake_stream + sys.modules["mlx_lm.generate"] = fake_generate + sys.modules.setdefault("mlx_lm", _types.ModuleType("mlx_lm")) + + from kakeya_sidecar_mlx import engine_mlx as em + + captured_q: list[int] = [] + + class _StubInjector: + def __init__(self, *a, variant=None, q_range=None, **kw): + self.extra_kwargs = {} + captured_q.append(q_range) + def activate(self, _sg): + class _Ctx: + def __enter__(self_): return [] + def __exit__(self_, *exc): return False + return _Ctx() + def build(self): return [] + + monkeypatch.setattr(em, "KakeyaCacheInjector", _StubInjector) + + # Default channel q=38; override to q=10. + _text, stats = engine.chat( + "qwen3-8b@e8-q38", + [{"role": "user", "content": "hi"}], + max_tokens=8, + override={"q_range": 10}, + ) + assert stats["q_range"] == 10 + assert captured_q and captured_q[-1] == 10 diff --git a/integrations/atomic-chat-b2/kakeya_sidecar_mlx/tests/test_server_skeleton.py b/integrations/atomic-chat-b2/kakeya_sidecar_mlx/tests/test_server_skeleton.py index cf90783..a0774c5 100644 --- a/integrations/atomic-chat-b2/kakeya_sidecar_mlx/tests/test_server_skeleton.py +++ b/integrations/atomic-chat-b2/kakeya_sidecar_mlx/tests/test_server_skeleton.py @@ -27,6 +27,8 @@ def test_health_reports_b2_variant(app) -> None: body = r.json() assert body["ok"] is True assert "B2" in body["variant"] + # M4: health carries explicit milestone status. + assert "M4" in body["milestone"] def test_models_exposes_dflash_metadata(app) -> None: @@ -46,15 +48,27 @@ def test_models_exposes_dflash_metadata(app) -> None: assert xk["q_range"] == 38 -def test_chat_returns_503_until_m4(app) -> None: - c = TestClient(app) - r = c.post("/v1/chat/completions", json={ - "model": "qwen3-8b@e8-q38", +def test_chat_unknown_model_returns_404(app, monkeypatch) -> None: + """M4 note: /v1/chat/completions is live; we validate its routing + by pointing it at a model that doesn't exist in the MLX registry. + The KeyError path is platform-agnostic — no mlx/mlx-lm required. + """ + from kakeya_sidecar_mlx import server as srv_mod + + class _FakeEngine: + def chat(self, *_a, **_kw): + raise KeyError("unknown MLX model id 'nonexistent-9000b'") + + monkeypatch.setattr(srv_mod, "MLXEngine", lambda *a, **kw: _FakeEngine()) + + # Force the app to build a fresh engine via the monkey-patched ctor. 
+ app2 = srv_mod.create_app(lazy_engine=True) + c2 = TestClient(app2) + r = c2.post("/v1/chat/completions", json={ + "model": "nonexistent-9000b@e8-q10", "messages": [{"role": "user", "content": "hi"}], }) - assert r.status_code == 503 - body = r.json() - assert "M4" in body["detail"] or "M1-M3" in body["detail"] + assert r.status_code == 404 def test_stats_no_engine(app) -> None: