diff --git a/renderers/client.py b/renderers/client.py index 0c63c0e..06bd5ee 100644 --- a/renderers/client.py +++ b/renderers/client.py @@ -1,7 +1,20 @@ -"""Renderer-based generate client for vLLM 0.20's /inference/v1/generate. +"""Renderer-based generate client for vLLM + Dynamo. - messages → Renderer.render_ids() → token IDs → POST /inference/v1/generate - → completion tokens → Renderer.parse_response() → structured message +Two transports, selected per-call via ``transport=`` parameter: + + "vllm_generate" (default) + messages → Renderer.render_ids() → token IDs → POST /inference/v1/generate + → completion tokens → Renderer.parse_response() → structured message + vLLM's TITO surface (server.py mounts the route in prime-rl). + + "dynamo_chat" + messages → Renderer.render_ids() → token IDs → POST /v1/chat/completions + with ``nvext.token_data`` + ``nvext.extra_fields=["engine_data"]`` + → completion tokens via ``nvext.engine_data.completion_token_ids`` + → Renderer.parse_response() → structured message + Dynamo has no ``/inference/v1/generate`` route; this branch posts to + the standard OpenAI chat-completions surface and reads the engine + token IDs back via the ``nvext.engine_data`` channel. When a RendererPool is passed instead of a single Renderer, the sync tokenization and parsing work is offloaded to threads for parallel execution across rollouts. @@ -12,10 +25,13 @@ from __future__ import annotations import asyncio +import base64 import json import logging +from abc import ABC, abstractmethod from collections.abc import Mapping -from typing import Any, cast +from dataclasses import dataclass +from typing import Any, Literal, cast import httpx from openai import AsyncOpenAI @@ -107,6 +123,425 @@ async def _resolve_max_prompt_len(client: AsyncOpenAI, model: str) -> int | None return value +# Public type alias; matches verifiers.types.RendererTransport string set. 
RendererTransport = Literal["vllm_generate", "dynamo_chat"]

# Keys never forwarded to Dynamo at the top level: vLLM/prime-only fields its
# strict validator rejects (mirrors the token client's drop set). ``priority``
# is routed into nvext.agent_hints and ``routed_experts_prompt_start`` into
# nvext.routed_experts_prompt_start instead. ``max_tokens``, ``cache_salt``
# and ``nvext`` are handled explicitly and skipped separately.
_DYNAMO_DROP_KEYS = frozenset(
    {
        "return_token_ids",
        "spaces_between_special_tokens",
        "priority",
        "routed_experts_prompt_start",
    }
)

# Absolute /inference/v1/generate URLs, cached per client base_url.
_vllm_endpoint_cache: dict[str, str] = {}


def _vllm_generate_endpoint(base_url: str) -> str:
    """Return the absolute ``/inference/v1/generate`` URL for ``base_url``.

    The route is mounted at the server root, not under /v1, so strip the
    client's trailing /v1 and build an absolute URL — otherwise AsyncOpenAI
    prepends its automatic /v1. Results are memoized per ``base_url``.
    """
    endpoint = _vllm_endpoint_cache.get(base_url)
    if endpoint is None:
        endpoint = f"{base_url.rstrip('/').removesuffix('/v1')}/inference/v1/generate"
        _vllm_endpoint_cache[base_url] = endpoint
    return endpoint


def _flatten_chat_logprobs(choice: Mapping[str, Any]) -> list[float]:
    """Flatten ChatCompletionLogProbs ``{"content": [{"logprob": ...}, ...]}``.

    Returns one float per content entry (a missing/null ``logprob`` becomes
    ``0.0``) and an empty list when the choice carries no logprobs.
    """
    raw = choice.get("logprobs") or {}
    content = raw.get("content") if isinstance(raw, dict) else None
    return [float(c.get("logprob") or 0.0) for c in content or []]


@dataclass(frozen=True)
class _WireResult:
    """Normalized fields extracted from a backend's raw response."""

    completion_ids: list[int]
    completion_logprobs: list[float]
    routed_experts: Any
    request_id: str
    finish_reason: str | None


class _Transport(ABC):
    """Per-backend request/response strategy for :func:`generate`."""

    @abstractmethod
    async def post(
        self,
        *,
        client: AsyncOpenAI,
        model: str,
        prompt_ids: list[int],
        sp: dict[str, Any],
        renderer: Renderer | RendererPool,
        mm_data: MultiModalData | None,
        cache_salt: str | None,
        priority: int | None,
        extra_headers: dict[str, str] | None,
    ) -> dict[str, Any]:
        """Build the wire body, POST it, and return the decoded response dict."""

    @abstractmethod
    def parse(self, data: dict[str, Any]) -> _WireResult:
        """Extract normalized completion fields from the backend response."""


class _VllmGenerateTransport(_Transport):
    """vLLM TITO surface: ``POST /inference/v1/generate``."""

    async def post(
        self,
        *,
        client: AsyncOpenAI,
        model: str,
        prompt_ids: list[int],
        sp: dict[str, Any],
        renderer: Renderer | RendererPool,
        mm_data: MultiModalData | None,
        cache_salt: str | None,
        priority: int | None,
        extra_headers: dict[str, str] | None,
    ) -> dict[str, Any]:
        """POST the token-in request and return the decoded response dict.

        ``features``, ``cache_salt`` and ``priority`` are only added to the
        body when present, so the wire payload stays minimal.
        """
        features = (
            _build_mm_features(renderer, mm_data)
            if mm_data and not mm_data.is_empty()
            else None
        )
        body: dict[str, Any] = {
            "model": model,
            "token_ids": prompt_ids,
            "sampling_params": sp,
        }
        if features is not None:
            body["features"] = features
        if cache_salt is not None:
            body["cache_salt"] = cache_salt
        if priority is not None:
            body["priority"] = priority

        endpoint = _vllm_generate_endpoint(str(client.base_url))
        _request_logger.debug(
            "POST %s prompt_len=%d max_tokens=%s",
            endpoint,
            len(prompt_ids),
            sp.get("max_tokens"),
        )
        post_kwargs: dict[str, Any] = {"cast_to": httpx.Response, "body": body}
        if extra_headers:
            post_kwargs["options"] = cast(Any, {"headers": extra_headers})
        raw_response = await client.post(endpoint, **post_kwargs)
        return parse_generate_response(raw_response.content)

    def parse(self, data: dict[str, Any]) -> _WireResult:
        """Read ids, logprobs and routed_experts from ``choices[0]``."""
        choice = (data.get("choices") or [{}])[0]
        return _WireResult(
            completion_ids=list(choice.get("token_ids") or []),
            completion_logprobs=_flatten_chat_logprobs(choice),
            routed_experts=choice.get("routed_experts"),
            request_id=data.get("request_id") or "",
            finish_reason=choice.get("finish_reason"),
        )


class _DynamoChatTransport(_Transport):
    """NVIDIA Dynamo: ``POST /v1/chat/completions`` with the nvext envelope.

    Dynamo has no /inference/v1/generate route. ``nvext.token_data`` carries
    the pre-tokenized prompt and ``nvext.extra_fields`` always includes
    ``"engine_data"``, which opts into the completion-IDs/logprobs channel.
    Only ``"engine_data"`` is requested by this transport: routed_experts is
    read from ``nvext.engine_data.routed_experts``, or from the dedicated
    ``nvext.routed_experts`` field when a caller opted into it via their own
    ``nvext.extra_fields``. Either way it is surfaced as the
    ``{data, shape, start, dtype}`` contract. Mirrors
    ``verifiers.clients.openai_chat_completions_token_client._post_dynamo_chat``
    so the wire payload is identical via either client.
    """

    async def post(
        self,
        *,
        client: AsyncOpenAI,
        model: str,
        prompt_ids: list[int],
        sp: dict[str, Any],
        renderer: Renderer | RendererPool,
        mm_data: MultiModalData | None,
        cache_salt: str | None,
        priority: int | None,
        extra_headers: dict[str, str] | None,
    ) -> dict[str, Any]:
        """POST the chat-completions request and return the response dict.

        Raises:
            NotImplementedError: if ``mm_data`` carries multimodal inputs.
        """
        # TODO: Implement multimodal support for dynamo_chat transport.
        if mm_data is not None and not mm_data.is_empty():
            raise NotImplementedError(
                "Multimodal renderers are not yet supported on the dynamo_chat "
                "transport. Use vllm_generate or stay on the token-client TITO "
                "path for VLMs."
            )
        body = self._build_body(model, prompt_ids, sp, cache_salt, priority)
        post_kwargs: dict[str, Any] = {
            "cast_to": cast(Any, dict[str, Any]),
            "body": body,
        }
        if extra_headers:
            post_kwargs["options"] = cast(Any, {"headers": extra_headers})
        # Engine 4xx propagate raw (matches the vLLM path).
        resp = await client.post("/chat/completions", **post_kwargs)
        # _build_body forwards routed_experts_prompt_start via nvext so the
        # worker can trim the prompt rows engine-side and stamp ``start``.
        # A worker that ignores the field returns full-sequence routing with
        # start=0; trim those leading prompt rows here so the payload matches
        # the consumer contract (row 0 == position start). This is a no-op
        # when the worker already trimmed.
        _trim_dynamo_routed_experts(resp, sp)
        return resp

    @staticmethod
    def _build_body(
        model: str,
        prompt_ids: list[int],
        sp: dict[str, Any],
        cache_salt: str | None,
        priority: int | None,
    ) -> dict[str, Any]:
        """Build the chat-completions body with the nvext envelope.

        ``sp`` is read, never mutated. The dedicated ``cache_salt`` /
        ``priority`` arguments win over same-named ``sp`` entries.
        """
        # cache_salt / priority may arrive as dedicated kwargs or inside
        # sampling_params (the kwargs win). On Dynamo both belong in nvext, so
        # they're routed there and never forwarded as top-level chat fields —
        # keeping a shared sampling_params dict consistent with vllm_generate.
        if cache_salt is None:
            cache_salt = sp.get("cache_salt")
        if priority is None:
            priority = sp.get("priority")

        # Merge caller-supplied nvext rather than overwriting it, then layer on
        # the required fields: token_data (authoritative renderer tokens) and a
        # cumulative extra_fields union with "engine_data". The cache_salt /
        # priority values win over any caller nvext values.
        nvext: dict[str, Any] = dict(sp.get("nvext") or {})
        nvext["token_data"] = list(prompt_ids)
        extra_fields = list(nvext.get("extra_fields") or [])
        # Only request "engine_data": the worker nests routed_experts inside it
        # (engine_data.routed_experts), so also requesting the dedicated
        # "routed_experts" field would duplicate the (large) base64 blob on the
        # wire — once under engine_data and once promoted at the top level.
        if "engine_data" not in extra_fields:
            extra_fields.append("engine_data")
        nvext["extra_fields"] = extra_fields
        if cache_salt is not None:
            nvext["cache_salt"] = cache_salt
        if priority is not None:
            agent_hints = dict(nvext.get("agent_hints") or {})
            agent_hints["priority"] = priority
            nvext["agent_hints"] = agent_hints
        # routed_experts_prompt_start rides nvext (Dynamo rejects unknown
        # top-level chat fields) so the worker can trim the leading prompt
        # rows engine-side and stamp the payload's `start`; the client-side
        # trim then no-ops (see _trim_dynamo_routed_experts).
        reps = sp.get("routed_experts_prompt_start")
        if reps is not None:
            nvext["routed_experts_prompt_start"] = reps

        # messages is a placeholder stub the OpenAI schema requires but Dynamo
        # ignores. tools are baked into token_data; forwarding the renderer
        # ToolSpec (not the OpenAI tool shape) would 400.
        body: dict[str, Any] = {
            "model": model,
            "messages": [{"role": "user", "content": ""}],
            "stream": False,
            "nvext": nvext,
        }
        if sp.get("max_tokens") is not None:
            body["max_completion_tokens"] = sp["max_tokens"]

        # Forward every other non-None sampling field (denylist, not allowlist)
        # so caller-requested params (presence_penalty, stop, guided_*, ...) are
        # not silently dropped. Mirrors the token client's body construction.
        for key, value in sp.items():
            if value is None or key in _DYNAMO_DROP_KEYS or key in body:
                continue
            if key in ("nvext", "max_tokens", "cache_salt"):
                continue  # handled above (cache_salt -> nvext; priority denylisted)
            if key == "logprobs":
                # vLLM takes logprobs=N (int); Dynamo's chat schema wants the
                # OpenAI bool + top_logprobs split.
                body["logprobs"] = True
                if isinstance(value, int) and value > 1:
                    body["top_logprobs"] = value
            else:
                body[key] = value
        return body

    def parse(self, data: dict[str, Any]) -> _WireResult:
        """Extract completion ids, logprobs and routed_experts.

        Raises:
            RuntimeError: when no completion token IDs are present, when the
                logprobs length disagrees with the id count, or when the
                routed_experts payload violates the contract.
        """
        choice = (data.get("choices") or [{}])[0]
        # A non-mapping nvext / engine_data is treated as absent so a malformed
        # response fails below with context rather than as an AttributeError.
        nvext = data.get("nvext")
        if not isinstance(nvext, Mapping):
            nvext = {}
        engine = nvext.get("engine_data")
        if not isinstance(engine, Mapping):
            engine = {}

        # Canonical Dynamo channel first (nvext.engine_data, then top-level
        # nvext), then the OpenAI-extended choices[0].token_ids. The
        # engine channel is authoritative — choices[0].token_ids may be a
        # detokenize-then-retokenize echo that differs from what was sampled.
        completion_ids = None
        present = False
        for src in (engine, nvext):
            if src.get("completion_token_ids") is not None:
                completion_ids = src["completion_token_ids"]
                present = True
                break
        if not present and choice.get("token_ids") is not None:
            completion_ids = choice["token_ids"]
            present = True
        if not present:
            # Field absent (vs. an empty completion) — usually a missing
            # nvext.extra_fields=["engine_data"] opt-in.
            raise RuntimeError(
                "dynamo_chat response carried no completion token IDs "
                "(expected nvext.engine_data.completion_token_ids)."
            )
        completion_ids = list(completion_ids or [])

        # Prefer engine_data.completion_logprobs — the same authoritative source
        # as the engine completion_token_ids used above — so logprobs stay
        # positionally aligned with the ids. Fall back to the chat logprobs
        # only when the engine channel is absent — a present but empty engine
        # list is authoritative (logprobs off) and must NOT fall through to
        # the chat echo (distinguish presence from truthiness).
        engine_logprobs = engine.get("completion_logprobs")
        if engine_logprobs is not None:
            logprobs = [float(x) for x in engine_logprobs]
        else:
            logprobs = _flatten_chat_logprobs(choice)
        # Logprobs are indexed positionally against completion_ids downstream;
        # a length mismatch would silently misalign tokens and logprobs.
        if logprobs and len(logprobs) != len(completion_ids):
            raise RuntimeError(
                f"dynamo_chat logprobs length ({len(logprobs)}) does not match "
                f"completion token count ({len(completion_ids)})."
            )

        # routed_experts: read the dedicated nvext.routed_experts field if a
        # caller opted into it, else the engine_data passthrough where the
        # worker nests it (engine_data.routed_experts). Normalize to the
        # {data, shape, start, dtype} contract so a malformed/legacy payload
        # fails here with context instead of deep in trajectory processing.
        routed_experts = nvext.get("routed_experts")
        if routed_experts is None:
            routed_experts = engine.get("routed_experts")
        routed_experts = _normalize_routed_experts(routed_experts)

        return _WireResult(
            completion_ids=completion_ids,
            completion_logprobs=logprobs,
            routed_experts=routed_experts,
            request_id=data.get("request_id") or data.get("id") or "",
            finish_reason=choice.get("finish_reason"),
        )


_TRANSPORTS: dict[str, _Transport] = {
    "vllm_generate": _VllmGenerateTransport(),
    "dynamo_chat": _DynamoChatTransport(),
}


def _normalize_routed_experts(payload: Any) -> dict[str, Any] | None:
    """Validate/normalize a dynamo_chat routed_experts payload to the
    ``{data, shape, start, dtype}`` contract.

    Defaults ``start=0`` and ``dtype="uint8"`` when those fields are missing
    *or null* (back-compat with payloads serialized before they existed).

    Returns:
        The normalized dict, or ``None`` when ``payload`` is ``None``.

    Raises:
        RuntimeError: for a non-contract payload (not a mapping, missing
            ``data``/``shape``, wrong rank, non-integer ``shape``/``start``),
            so the failure surfaces here with context instead of as a
            ``TypeError``/``KeyError`` in trajectory processing.
    """
    if payload is None:
        return None
    if not isinstance(payload, Mapping) or "data" not in payload or "shape" not in payload:
        raise RuntimeError(
            "dynamo_chat routed_experts must be a mapping with 'data' and "
            f"'shape'; got {type(payload).__name__}"
        )
    shape = payload["shape"]
    if not (isinstance(shape, (list, tuple)) and len(shape) == 3):
        raise RuntimeError(
            "dynamo_chat routed_experts 'shape' must be 3-D [seq, layers, topk]; "
            f"got {shape!r}"
        )
    start = payload.get("start")
    try:
        dims = [int(d) for d in shape]
        # ``or 0`` treats an explicit null like an absent field.
        start_row = int(start or 0)
    except (TypeError, ValueError) as exc:
        raise RuntimeError(
            "dynamo_chat routed_experts 'shape' and 'start' must be integers; "
            f"got shape={shape!r}, start={start!r}"
        ) from exc
    return {
        "data": payload["data"],
        "shape": dims,
        "start": start_row,
        "dtype": payload.get("dtype") or "uint8",
    }


_ROUTED_EXPERTS_ITEMSIZE = {"uint8": 1, "uint16": 2, "int16": 2, "int32": 4}


def _trim_dynamo_routed_experts(resp: Any, sp: dict[str, Any]) -> None:
    """Client-side trim of a dynamo_chat routed_experts payload, in place.

    This is a **back-compat fallback**. ``_build_body`` forwards
    ``routed_experts_prompt_start`` via ``nvext``; a worker that honors it
    stamps the payload's ``start`` > 0, in which case this is a no-op. We only
    trim when the payload reports ``start == 0`` and the caller supplied a
    positive ``routed_experts_prompt_start``: drop that many leading rows and
    set ``start``.

    No-op when routed_experts is absent or not in the expected
    ``{data: str, shape: 3-D}`` form, ``start`` is already > 0, no positive
    offset is supplied, or the decoded buffer does not match ``shape`` (and
    ``dtype`` when it is a known one) — slicing at a guessed row boundary
    would silently corrupt the routing data.

    The bytes-per-row value is derived from the decoded buffer itself, so
    dtypes missing from ``_ROUTED_EXPERTS_ITEMSIZE`` are trimmed correctly
    instead of being assumed to be one byte wide.
    """
    if not isinstance(resp, Mapping):
        return
    nvext = resp.get("nvext")
    if not isinstance(nvext, Mapping):
        return
    routed = nvext.get("routed_experts")
    if not isinstance(routed, dict):
        engine = nvext.get("engine_data")
        routed = engine.get("routed_experts") if isinstance(engine, Mapping) else None
    if not isinstance(routed, dict):
        return
    data = routed.get("data")
    shape = routed.get("shape")
    if not isinstance(data, str) or not (
        isinstance(shape, (list, tuple)) and len(shape) == 3
    ):
        return

    offset = sp.get("routed_experts_prompt_start")
    if offset is None:
        return
    # Worker already trimmed engine-side (stamped start > 0) — don't double-trim.
    if int(routed.get("start") or 0) != 0:
        return
    rows, layers, topk = (int(d) for d in shape)
    offset = max(0, min(int(offset), rows))
    if offset == 0:
        return

    # offset >= 1 and offset <= rows, so rows is positive from here on.
    raw = base64.b64decode(data)
    if len(raw) % rows:
        return
    row_size = len(raw) // rows
    elems_per_row = layers * topk
    itemsize = _ROUTED_EXPERTS_ITEMSIZE.get(routed.get("dtype") or "uint8")
    if itemsize is not None:
        consistent = row_size == elems_per_row * itemsize
    elif elems_per_row:
        consistent = row_size % elems_per_row == 0
    else:
        consistent = row_size == 0
    if not consistent:
        return

    routed["data"] = base64.b64encode(raw[offset * row_size :]).decode("ascii")
    routed["shape"] = [rows - offset, layers, topk]
    routed["start"] = offset
@@ -152,6 +587,7 @@ async def generate( prompt_attribution: RenderedTokens | None = None, tools: list[ToolSpec] | None = None, sampling_params: dict[str, Any] | None = None, + transport: RendererTransport = "vllm_generate", cache_salt: str | None = None, priority: int | None = None, extra_headers: dict[str, str] | None = None, @@ -243,66 +679,30 @@ def _prepare(): sp["logprobs"] = 1 sp.setdefault("skip_special_tokens", False) - body: dict[str, Any] = { - "model": model, - "token_ids": prompt_ids, - "sampling_params": sp, - } - features = ( - _build_mm_features(renderer, mm_data) - if mm_data and not mm_data.is_empty() - else None - ) - if features is not None: - body["features"] = features - if cache_salt is not None: - body["cache_salt"] = cache_salt - if priority is not None: - body["priority"] = priority - - # /inference/v1/generate is mounted at the server root, not under /v1 - # like the OpenAI-compatible endpoints. Build an absolute URL so the - # AsyncOpenAI client doesn't prepend its automatic /v1. 
- base = str(client.base_url).rstrip("/").removesuffix("/v1") - endpoint = f"{base}/inference/v1/generate" - _request_logger.debug( - "POST %s prompt_len=%d max_tokens=%s", - endpoint, - len(prompt_ids), - sp.get("max_tokens"), + impl = _TRANSPORTS.get(transport) + if impl is None: + raise ValueError(f"Unknown renderer transport: {transport!r}") + data = await impl.post( + client=client, + model=model, + prompt_ids=prompt_ids, + sp=sp, + renderer=renderer, + mm_data=mm_data, + cache_salt=cache_salt, + priority=priority, + extra_headers=extra_headers, ) - post_kwargs: dict[str, Any] = { - "cast_to": httpx.Response, - "body": body, - } - if extra_headers: - post_kwargs["options"] = cast(Any, {"headers": extra_headers}) - raw_response = await client.post(endpoint, **post_kwargs) - data = parse_generate_response(raw_response.content) - - choice = (data.get("choices") or [{}])[0] - completion_ids = choice.get("token_ids") or [] + wire = impl.parse(data) parsed = await _maybe_offload( - renderer, lambda: renderer.parse_response(completion_ids, tools=tools) + renderer, lambda: renderer.parse_response(wire.completion_ids, tools=tools) ) - # ChatCompletionLogProbs flatten: {"content": [{"logprob": ...}, ...]} - raw_logprobs = choice.get("logprobs") or {} - content_lp = raw_logprobs.get("content") if isinstance(raw_logprobs, dict) else None - completion_logprobs = [float(c.get("logprob") or 0.0) for c in content_lp or []] - - routed_experts = choice.get("routed_experts") - - # /inference/v1/generate returns finish_reason in {"stop","length",...} — - # never "tool_calls" (a chat-completions concept). Promote stop→tool_calls - # when we extracted at least one well-formed tool call client-side, so - # OpenAI-compatible agent loops continue past the tool turn instead of - # treating the response as final. Malformed attempts (INVALID_JSON, - # UNCLOSED_BLOCK, ...) 
don't qualify — those still surface on - # ``parsed.tool_calls`` so verifiers can inspect them, but they don't - # trigger the tool-loop continuation. - finish_reason = choice.get("finish_reason") + # /inference/v1/generate never returns "tool_calls", so promote + # stop→tool_calls when we parsed tool calls client-side (keeps agent + # loops going). No-op on dynamo_chat, which can return it directly. + finish_reason = wire.finish_reason ok_tool_calls = [ tc for tc in parsed.tool_calls if tc.status == ToolCallParseStatus.OK ] @@ -310,15 +710,15 @@ def _prepare(): finish_reason = "tool_calls" return { - "request_id": data.get("request_id") or "", + "request_id": wire.request_id, "prompt_ids": list(prompt_ids), - "completion_ids": list(completion_ids), - "completion_logprobs": completion_logprobs, + "completion_ids": list(wire.completion_ids), + "completion_logprobs": wire.completion_logprobs, "content": parsed.content, "reasoning_content": parsed.reasoning_content, "tool_calls": parsed.tool_calls, "finish_reason": finish_reason, - "routed_experts": routed_experts, + "routed_experts": wire.routed_experts, # The mm sidecar consumed on the request side, surfaced back so # callers can persist it on the trajectory step for downstream # multi-turn bridging and training-sample construction. @@ -348,7 +748,7 @@ def _build_mm_features( model-family specific. For now we dispatch on the renderer class; extend the dispatch table as more multimodal renderers land. - NOTE — future engine pluggability: this encoder is vLLM 0.20-specific + NOTE — future engine pluggability: this encoder is vLLM-specific (uses ``vllm.multimodal.inputs.MultiModalKwargsItems``, ``vllm.entrypoints.serve.disagg.mm_serde.encode_mm_kwargs_item``, and ``_create_qwen2vl_field_factory``). 
When a second inference engine diff --git a/tests/test_client.py b/tests/test_client.py index 1cc1000..76c57d5 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -95,10 +95,13 @@ async def post(self, path, *, cast_to=dict, body=None, options=None): } ], } - return httpx.Response( - 200, - content=json.dumps(payload, separators=(",", ":")).encode("utf-8"), - ) + # vLLM path requests cast_to=httpx.Response; Dynamo path uses cast_to=dict. + if cast_to is httpx.Response: + return httpx.Response( + 200, + content=json.dumps(payload, separators=(",", ":")).encode("utf-8"), + ) + return payload def test_generate_builds_request_body_and_parses_response(): @@ -275,6 +278,438 @@ def test_generate_threads_prompt_attribution_through_prebuilt_prompt_path(): assert result["prompt_attribution"] is supplied +class _DynamoFakeClient(_FakeClient): + """Dynamo-shaped response: engine fields + routed_experts under nvext (not + choices[0]); used to prove routed_experts now surfaces on dynamo_chat.""" + + async def post(self, path, *, cast_to=dict, body=None, options=None): + self.calls.append( + {"path": path, "cast_to": cast_to, "body": body, "options": options} + ) + return { + "id": "gen-dyn", + "choices": [ + { + "index": 0, + "logprobs": {"content": [{"logprob": -0.1}, {"logprob": -0.2}]}, + "finish_reason": "stop", + } + ], + "nvext": { + "engine_data": {"completion_token_ids": [7, 8]}, + "routed_experts": { + # full-sequence routing (4 rows); worker can't trim + "data": "AQIDBA==", + "shape": [4, 1, 1], + "start": 0, + "dtype": "uint8", + }, + }, + } + + +def test_dynamo_transport_forwards_priority_and_detokenize(): + client = _DynamoFakeClient() + + result = asyncio.run( + generate( + client=client, + renderer=_FakeRenderer(), + messages=[{"role": "user", "content": "hi"}], + model="test-model", + tools=[{"type": "function", "function": {"name": "echo"}}], + sampling_params={ + "temperature": 0.3, + "max_tokens": 7, + "detokenize": False, + "allowed_token_ids": [7, 
8], + "bad_words_token_ids": [[1, 2]], + }, + cache_salt="ckpt-42", + priority=17, + transport="dynamo_chat", + ) + ) + + assert len(client.calls) == 1 + assert client.calls[0]["path"] == "/chat/completions" + assert client.calls[0]["body"] == { + "model": "test-model", + "messages": [{"role": "user", "content": ""}], + "stream": False, + "nvext": { + "token_data": [1, 2, 3], + "extra_fields": ["engine_data"], + "cache_salt": "ckpt-42", + "agent_hints": {"priority": 17}, + }, + # tools are NOT forwarded on the wire (baked into token_data instead). + "temperature": 0.3, + "max_completion_tokens": 7, + "logprobs": True, + "skip_special_tokens": False, + "stop_token_ids": [99], + "bad_words_token_ids": [[1, 2]], + "allowed_token_ids": [7, 8], + "detokenize": False, + } + assert result["completion_ids"] == [7, 8] + # routed_experts surfaces on dynamo_chat as the {data, shape, start, dtype} + # contract. No routed_experts_prompt_start is set here (first-turn case), so + # the renderer does NOT trim — full-sequence routing passes through with + # start=0. 
+ assert result["routed_experts"] == { + "data": "AQIDBA==", + "shape": [4, 1, 1], + "start": 0, + "dtype": "uint8", + } + + +class _NoCompletionIdsClient(_FakeClient): + """Dynamo response that carries no completion token IDs.""" + + async def post(self, path, *, cast_to=dict, body=None, options=None): + self.calls.append( + {"path": path, "cast_to": cast_to, "body": body, "options": options} + ) + return {"request_id": "x", "choices": [{"index": 0, "finish_reason": "stop"}]} + + +def test_dynamo_transport_raises_without_completion_ids(): + with pytest.raises(RuntimeError, match="completion token IDs"): + asyncio.run( + generate( + client=_NoCompletionIdsClient(), + renderer=_FakeRenderer(), + messages=[{"role": "user", "content": "hi"}], + model="test-model", + tools=[{"type": "function", "function": {"name": "echo"}}], + sampling_params={"max_tokens": 7}, + transport="dynamo_chat", + max_prompt_len=10_000, + ) + ) + + +class _EmptyCompletionClient(_FakeClient): + """Dynamo response with a present-but-empty completion_token_ids list.""" + + async def post(self, path, *, cast_to=dict, body=None, options=None): + self.calls.append( + {"path": path, "cast_to": cast_to, "body": body, "options": options} + ) + return { + "request_id": "x", + "choices": [ + {"index": 0, "finish_reason": "stop", "logprobs": {"content": []}} + ], + "nvext": {"engine_data": {"completion_token_ids": []}}, + } + + +class _EmptyParseRenderer(_FakeRenderer): + def parse_response(self, completion_ids, *, tools=None) -> ParsedResponse: + assert completion_ids == [] + return ParsedResponse(content="", reasoning_content=None, tool_calls=[]) + + +def test_dynamo_transport_allows_present_but_empty_completion(): + """A present-but-empty completion_token_ids is a valid zero-token completion + and must NOT raise (only an absent field raises).""" + result = asyncio.run( + generate( + client=_EmptyCompletionClient(), + renderer=_EmptyParseRenderer(), + messages=[{"role": "user", "content": "hi"}], + 
model="test-model", + tools=[{"type": "function", "function": {"name": "echo"}}], + sampling_params={"max_tokens": 7}, + transport="dynamo_chat", + max_prompt_len=10_000, + ) + ) + assert result["completion_ids"] == [] + + +class _BoomClient(_FakeClient): + async def post(self, path, *, cast_to=dict, body=None, options=None): + raise ValueError("boom") + + +@pytest.mark.parametrize("transport", ["vllm_generate", "dynamo_chat"]) +def test_generate_propagates_post_errors_raw(transport): + # POST errors must propagate unchanged (no NameError from a stale handler). + with pytest.raises(ValueError, match="boom"): + asyncio.run( + generate( + client=_BoomClient(), + renderer=_FakeRenderer(), + messages=[{"role": "user", "content": "hi"}], + model="test-model", + tools=[{"type": "function", "function": {"name": "echo"}}], + sampling_params={"max_tokens": 7}, + transport=transport, + max_prompt_len=10_000, + ) + ) + + +def test_dynamo_transport_forwards_extra_sampling_fields_and_drops_denylist(): + """F1: sampling fields outside the old allowlist (presence_penalty, stop, + guided_json) must reach the wire; vLLM-only/internal keys are dropped.""" + client = _DynamoFakeClient() + asyncio.run( + generate( + client=client, + renderer=_FakeRenderer(), + messages=[{"role": "user", "content": "hi"}], + model="test-model", + tools=[{"type": "function", "function": {"name": "echo"}}], + sampling_params={ + "max_tokens": 7, + "presence_penalty": 0.5, + "frequency_penalty": 0.25, + "stop": [""], + "guided_json": {"type": "object"}, + # denylisted — must NOT hit the wire + "return_token_ids": True, + # not a top-level field (Dynamo rejects unknown ones); routed + # into nvext.routed_experts_prompt_start so the worker trims + # routing engine-side + "routed_experts_prompt_start": 3, + }, + transport="dynamo_chat", + max_prompt_len=10_000, + ) + ) + body = client.calls[0]["body"] + assert body["presence_penalty"] == 0.5 + assert body["frequency_penalty"] == 0.25 + assert body["stop"] == 
[""] + assert body["guided_json"] == {"type": "object"} + assert "return_token_ids" not in body + # routed_experts_prompt_start is dropped from the top level (Dynamo rejects + # unknown top-level fields) but routed into nvext so the worker applies it to + # SamplingParams and trims routing engine-side. + assert "routed_experts_prompt_start" not in body + assert body["nvext"]["routed_experts_prompt_start"] == 3 + assert "extra_args" not in body.get("nvext", {}) + + +def test_trim_dynamo_routed_experts(): + """Client-side trim is a back-compat fallback: it trims ONLY when the worker + returned full routing (start=0) AND the caller supplied a positive + routed_experts_prompt_start. No-op when the worker already trimmed (start>0), + no offset is supplied (first turn), offset is 0, or routed_experts is absent.""" + from renderers.client import _trim_dynamo_routed_experts + + def _payload(channel): + re = { + "data": base64.b64encode(bytes([0, 1, 2, 3, 4])).decode(), + "shape": [5, 1, 1], "start": 0, "dtype": "uint8", + } + return {"nvext": {channel: {"routed_experts": re}} if channel == "engine_data" + else {"routed_experts": re}} + + # explicit prompt_start=3 -> drop 3 rows, start=3 (engine_data channel) + resp = _payload("engine_data") + _trim_dynamo_routed_experts(resp, {"routed_experts_prompt_start": 3}) + re = resp["nvext"]["engine_data"]["routed_experts"] + assert re["shape"] == [2, 1, 1] and re["start"] == 3 + assert base64.b64decode(re["data"]) == bytes([3, 4]) + + # explicit prompt_start=3 (top-level routed_experts channel) + resp2 = _payload("routed_experts") + _trim_dynamo_routed_experts(resp2, {"routed_experts_prompt_start": 3}) + re2 = resp2["nvext"]["routed_experts"] + assert re2["shape"] == [2, 1, 1] and re2["start"] == 3 + + # worker already trimmed engine-side (start>0) -> no-op (don't double-trim) + resp_wt = {"nvext": {"engine_data": {"routed_experts": { + "data": base64.b64encode(bytes([3, 4])).decode(), + "shape": [2, 1, 1], "start": 3, "dtype": 
"uint8", + }}}} + _trim_dynamo_routed_experts(resp_wt, {"routed_experts_prompt_start": 3}) + rewt = resp_wt["nvext"]["engine_data"]["routed_experts"] + assert rewt["shape"] == [2, 1, 1] and rewt["start"] == 3 + assert base64.b64decode(rewt["data"]) == bytes([3, 4]) + + # absent start (first turn) -> NO trim, full-sequence with start=0 + resp3 = _payload("engine_data") + _trim_dynamo_routed_experts(resp3, {}) + re3 = resp3["nvext"]["engine_data"]["routed_experts"] + assert re3["shape"] == [5, 1, 1] and re3["start"] == 0 + + # offset 0 -> no-op + resp0 = _payload("engine_data") + _trim_dynamo_routed_experts(resp0, {"routed_experts_prompt_start": 0}) + assert resp0["nvext"]["engine_data"]["routed_experts"]["shape"] == [5, 1, 1] + + # absent routed_experts -> no-op + resp4 = {"nvext": {"engine_data": {}}} + _trim_dynamo_routed_experts(resp4, {"routed_experts_prompt_start": 3}) + assert resp4 == {"nvext": {"engine_data": {}}} + + +def test_dynamo_parse_present_empty_engine_logprobs_does_not_fall_back_to_chat(): + """A present-but-empty engine_data.completion_logprobs is authoritative + (logprobs off): parse must NOT fall through to the divergent chat echo.""" + data = { + "choices": [ + { + # chat-echo logprobs (would mismatch the engine ids) + "logprobs": {"content": [{"logprob": -9.9}, {"logprob": -8.8}]}, + "finish_reason": "stop", + } + ], + "nvext": { + "engine_data": { + "completion_token_ids": [7, 8], + "completion_logprobs": [], + } + }, + } + from renderers.client import _TRANSPORTS + + wire = _TRANSPORTS["dynamo_chat"].parse(data) + assert wire.completion_ids == [7, 8] + assert wire.completion_logprobs == [] # engine present-empty wins over chat echo + + +def test_dynamo_transport_merges_caller_nvext(): + """F2: caller-supplied nvext is merged — extra_fields union with engine_data, + agent_hints merged with priority, unrelated caller keys preserved.""" + client = _DynamoFakeClient() + asyncio.run( + generate( + client=client, + renderer=_FakeRenderer(), + 
messages=[{"role": "user", "content": "hi"}], + model="test-model", + tools=[{"type": "function", "function": {"name": "echo"}}], + sampling_params={ + "max_tokens": 7, + "nvext": { + "extra_fields": ["timing"], + "agent_hints": {"osl": 4}, + "annotations": ["trace"], + }, + }, + priority=9, + transport="dynamo_chat", + max_prompt_len=10_000, + ) + ) + nvext = client.calls[0]["body"]["nvext"] + assert nvext["token_data"] == [1, 2, 3] + # extra_fields union preserves caller "timing" + our "engine_data" + assert nvext["extra_fields"] == ["timing", "engine_data"] + # agent_hints merged: caller osl kept, priority overlaid + assert nvext["agent_hints"] == {"osl": 4, "priority": 9} + # unrelated caller nvext keys survive + assert nvext["annotations"] == ["trace"] + + +def test_dynamo_transport_routes_sampling_params_cache_salt_and_priority_to_nvext(): + """cache_salt/priority supplied inside sampling_params (not the dedicated + kwargs) must still land in nvext, never as top-level chat fields.""" + client = _DynamoFakeClient() + asyncio.run( + generate( + client=client, + renderer=_FakeRenderer(), + messages=[{"role": "user", "content": "hi"}], + model="test-model", + tools=[{"type": "function", "function": {"name": "echo"}}], + sampling_params={"max_tokens": 7, "cache_salt": "ckpt-9", "priority": 5}, + transport="dynamo_chat", + max_prompt_len=10_000, + ) + ) + body = client.calls[0]["body"] + assert body["nvext"]["cache_salt"] == "ckpt-9" + assert body["nvext"]["agent_hints"] == {"priority": 5} + # neither leaks to a top-level chat field + assert "cache_salt" not in body + assert "priority" not in body + + +class _BothTokenIdsClient(_FakeClient): + """Dynamo response carrying engine_data.completion_token_ids AND a divergent + choices[0].token_ids — the canonical engine channel must win (F3).""" + + async def post(self, path, *, cast_to=dict, body=None, options=None): + self.calls.append( + {"path": path, "cast_to": cast_to, "body": body, "options": options} + ) + return 
{ + "id": "gen-dyn", + "choices": [ + { + "index": 0, + "token_ids": [99, 99], # divergent echo — must be ignored + "logprobs": {"content": [{"logprob": -0.1}, {"logprob": -0.2}]}, + "finish_reason": "stop", + } + ], + "nvext": {"engine_data": {"completion_token_ids": [7, 8]}}, + } + + +def test_dynamo_transport_prefers_engine_data_over_choices_token_ids(): + client = _BothTokenIdsClient() + result = asyncio.run( + generate( + client=client, + renderer=_FakeRenderer(), + messages=[{"role": "user", "content": "hi"}], + model="test-model", + tools=[{"type": "function", "function": {"name": "echo"}}], + sampling_params={"max_tokens": 7}, + transport="dynamo_chat", + max_prompt_len=10_000, + ) + ) + assert result["completion_ids"] == [7, 8] + + +class _MisalignedLogprobsClient(_FakeClient): + """Dynamo response whose logprobs length disagrees with completion_ids (F4).""" + + async def post(self, path, *, cast_to=dict, body=None, options=None): + self.calls.append( + {"path": path, "cast_to": cast_to, "body": body, "options": options} + ) + return { + "id": "gen-dyn", + "choices": [ + { + "index": 0, + "logprobs": {"content": [{"logprob": -0.1}]}, # only 1 logprob + "finish_reason": "stop", + } + ], + "nvext": {"engine_data": {"completion_token_ids": [7, 8]}}, # 2 tokens + } + + +def test_dynamo_transport_raises_on_logprob_length_mismatch(): + with pytest.raises(RuntimeError, match="logprobs length"): + asyncio.run( + generate( + client=_MisalignedLogprobsClient(), + renderer=_FakeRenderer(), + messages=[{"role": "user", "content": "hi"}], + model="test-model", + tools=[{"type": "function", "function": {"name": "echo"}}], + sampling_params={"max_tokens": 7}, + transport="dynamo_chat", + max_prompt_len=10_000, + ) + ) + + # --------------------------------------------------------------------------- # Multimodal features payload. # ---------------------------------------------------------------------------