From 334e496cbf5079b990f656594f48fb621f5a79e2 Mon Sep 17 00:00:00 2001 From: Biswa Panda Date: Thu, 14 May 2026 10:21:34 -0700 Subject: [PATCH 01/21] feat(client): add transport selector + dynamo_chat_nvext branch --- renderers/client.py | 277 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 231 insertions(+), 46 deletions(-) diff --git a/renderers/client.py b/renderers/client.py index 0c63c0e..50e28f1 100644 --- a/renderers/client.py +++ b/renderers/client.py @@ -1,7 +1,20 @@ -"""Renderer-based generate client for vLLM 0.20's /inference/v1/generate. +"""Renderer-based generate client for vLLM 0.20 + Dynamo. - messages → Renderer.render_ids() → token IDs → POST /inference/v1/generate - → completion tokens → Renderer.parse_response() → structured message +Two transports, selected per-call via ``transport=`` parameter: + + "prime_vllm_generate" (default) + messages → Renderer.render_ids() → token IDs → POST /inference/v1/generate + → completion tokens → Renderer.parse_response() → structured message + vLLM's TITO surface (server.py mounts the route in prime-rl). + + "dynamo_chat_nvext" + messages → Renderer.render_ids() → token IDs → POST /v1/chat/completions + with ``nvext.token_data`` + ``nvext.extra_fields=["engine_data"]`` + → completion tokens via ``nvext.engine_data.completion_token_ids`` + → Renderer.parse_response() → structured message + Dynamo has no ``/inference/v1/generate`` route; this branch posts to + the standard OpenAI chat-completions surface and reads the engine + token IDs back via the PR #8119 ``nvext.engine_data`` channel. When a RendererPool is passed instead of a single Renderer, the sync tokenization and parsing work is offloaded to threads for parallel execution across rollouts. @@ -15,7 +28,7 @@ import json import logging from collections.abc import Mapping -from typing import Any, cast +from typing import Any, Literal, cast import httpx from openai import AsyncOpenAI @@ -106,6 +119,9 @@ async def _resolve_max_prompt_len(client: AsyncOpenAI, model: str) -> int | None _max_prompt_len_cache[key] = value return value +# Public type alias; matches verifiers.types.RendererTransport string set. +RendererTransport = Literal["prime_vllm_generate", "dynamo_chat_nvext"] + async def _maybe_offload(renderer: Renderer | RendererPool, fn): """Run sync renderer work on a thread iff ``renderer`` is a pool. @@ -152,6 +168,7 @@ async def generate( prompt_attribution: RenderedTokens | None = None, tools: list[ToolSpec] | None = None, sampling_params: dict[str, Any] | None = None, + transport: RendererTransport = "prime_vllm_generate", cache_salt: str | None = None, priority: int | None = None, extra_headers: dict[str, str] | None = None, @@ -243,65 +260,128 @@ def _prepare(): sp["logprobs"] = 1 sp.setdefault("skip_special_tokens", False) - body: dict[str, Any] = { - "model": model, - "token_ids": prompt_ids, - "sampling_params": sp, - } features = ( _build_mm_features(renderer, mm_data) if mm_data and not mm_data.is_empty() else None ) - if features is not None: - body["features"] = features - if cache_salt is not None: - body["cache_salt"] = cache_salt - if priority is not None: - body["priority"] = priority - - # /inference/v1/generate is mounted at the server root, not under /v1 - # like the OpenAI-compatible endpoints. Build an absolute URL so the - # AsyncOpenAI client doesn't prepend its automatic /v1. - base = str(client.base_url).rstrip("/").removesuffix("/v1") - endpoint = f"{base}/inference/v1/generate" - _request_logger.debug( - "POST %s prompt_len=%d max_tokens=%s", - endpoint, - len(prompt_ids), - sp.get("max_tokens"), - ) - post_kwargs: dict[str, Any] = { - "cast_to": httpx.Response, - "body": body, - } - if extra_headers: - post_kwargs["options"] = cast(Any, {"headers": extra_headers}) - raw_response = await client.post(endpoint, **post_kwargs) - data = parse_generate_response(raw_response.content) + + if transport == "dynamo_chat_nvext": + # Dynamo branch: POST /v1/chat/completions with nvext.token_data. + # Dynamo has no /inference/v1/generate route; the equivalent TITO + # surface lives on chat-completions via the ``nvext`` envelope + # (PR #8119: response token IDs come back under + # ``nvext.engine_data.completion_token_ids``). + if features is not None: + raise NotImplementedError( + "Multimodal renderers are not yet supported on the " + "dynamo_chat_nvext transport. Use prime_vllm_generate or " + "stay on the token-client TITO path for VLMs." + ) + data = await _post_dynamo_chat_nvext( + client=client, + model=model, + prompt_ids=prompt_ids, + sp=sp, + tools=tools, + cache_salt=cache_salt, + priority=priority, + extra_headers=extra_headers, + messages=messages, + ) + else: + # vLLM-native branch: POST /inference/v1/generate (vLLM 0.20 TITO). + body: dict[str, Any] = { + "model": model, + "token_ids": prompt_ids, + "sampling_params": sp, + } + if features is not None: + body["features"] = features + if cache_salt is not None: + body["cache_salt"] = cache_salt + if priority is not None: + body["priority"] = priority + + # /inference/v1/generate is mounted at the server root, not under /v1 + # like the OpenAI-compatible endpoints. Build an absolute URL so the + # AsyncOpenAI client doesn't prepend its automatic /v1. + base = str(client.base_url).rstrip("/").removesuffix("/v1") + endpoint = f"{base}/inference/v1/generate" + _request_logger.debug( + "POST %s prompt_len=%d max_tokens=%s", + endpoint, + len(prompt_ids), + sp.get("max_tokens"), + ) + post_kwargs: dict[str, Any] = { + "cast_to": cast(Any, dict[str, Any]), + "body": body, + } + if extra_headers: + post_kwargs["options"] = cast(Any, {"headers": extra_headers}) + try: + data = await client.post(endpoint, **post_kwargs) + except BadRequestError as exc: + _log_overlong_prompt_diagnostic( + prompt_ids=prompt_ids, + messages=messages, + max_tokens=sp.get("max_tokens"), + exc=exc, + ) + raise choice = (data.get("choices") or [{}])[0] - completion_ids = choice.get("token_ids") or [] + # Dynamo emits engine token IDs under ``nvext.engine_data.completion_token_ids`` + # (PR #8119 channel) rather than ``choice.token_ids``. Try both — vLLM's + # /inference/v1/generate writes the top-level shape; Dynamo's + # /v1/chat/completions writes the nested one. The first present wins. + completion_ids = choice.get("token_ids") + if not completion_ids: + nvext_resp = data.get("nvext") or {} + engine_data = nvext_resp.get("engine_data") or {} + completion_ids = ( + engine_data.get("completion_token_ids") + or nvext_resp.get("completion_token_ids") + or [] + ) + completion_ids = list(completion_ids or []) parsed = await _maybe_offload( renderer, lambda: renderer.parse_response(completion_ids, tools=tools) ) - # ChatCompletionLogProbs flatten: {"content": [{"logprob": ...}, ...]} + # ChatCompletionLogProbs flatten: {"content": [{"logprob": ...}, ...]}. + # Same shape on both transports (Dynamo aliases the standard OpenAI + # logprobs field). engine_data.completion_logprobs is a fallback when + # the OpenAI-style logprobs array is absent. raw_logprobs = choice.get("logprobs") or {} content_lp = raw_logprobs.get("content") if isinstance(raw_logprobs, dict) else None completion_logprobs = [float(c.get("logprob") or 0.0) for c in content_lp or []] - - routed_experts = choice.get("routed_experts") + if not completion_logprobs: + nvext_resp = data.get("nvext") or {} + engine_data = nvext_resp.get("engine_data") or {} + engine_lp = engine_data.get("completion_logprobs") or [] + if engine_lp: + completion_logprobs = [float(x) for x in engine_lp] + + routed_experts = None + raw_re = choice.get("routed_experts") or (data.get("nvext") or {}).get( + "routed_experts" + ) + if isinstance(raw_re, dict) and "data" in raw_re and "shape" in raw_re: + routed_experts = ( + np.frombuffer(base64.b85decode(raw_re["data"]), dtype=np.int32) + .reshape(raw_re["shape"]) + .tolist() + ) # /inference/v1/generate returns finish_reason in {"stop","length",...} — # never "tool_calls" (a chat-completions concept). Promote stop→tool_calls - # when we extracted at least one well-formed tool call client-side, so - # OpenAI-compatible agent loops continue past the tool turn instead of - # treating the response as final. Malformed attempts (INVALID_JSON, - # UNCLOSED_BLOCK, ...) don't qualify — those still surface on - # ``parsed.tool_calls`` so verifiers can inspect them, but they don't - # trigger the tool-loop continuation. + # when we extracted tool calls client-side, so OpenAI-compatible agent + # loops continue past the tool turn instead of treating the response as + # final. Dynamo's chat-completions surface CAN return "tool_calls" + # directly, so this promotion is a no-op there. finish_reason = choice.get("finish_reason") ok_tool_calls = [ tc for tc in parsed.tool_calls if tc.status == ToolCallParseStatus.OK @@ -310,7 +390,7 @@ def _prepare(): finish_reason = "tool_calls" return { - "request_id": data.get("request_id") or "", + "request_id": data.get("request_id") or data.get("id") or "", "prompt_ids": list(prompt_ids), "completion_ids": list(completion_ids), "completion_logprobs": completion_logprobs, @@ -334,6 +414,111 @@ def _prepare(): } +async def _post_dynamo_chat_nvext( + *, + client: AsyncOpenAI, + model: str, + prompt_ids: list[int], + sp: dict[str, Any], + tools: list[ToolSpec] | None, + cache_salt: str | None, + priority: int | None, + extra_headers: dict[str, str] | None, + messages: list[Message], +) -> dict[str, Any]: + """POST ``prompt_ids`` to Dynamo's ``/v1/chat/completions`` route. + + Mirrors ``verifiers.clients.openai_chat_completions_token_client._post_dynamo_chat_nvext`` + in shape, so the wire payload is identical whether the rollout goes + through the token client or the renderer client. Anything that lands + on Dynamo's chat-completions surface, lands here. + + Wire shape: + + - ``nvext.token_data``: pre-tokenized prompt; Dynamo's preprocessor + skips tokenization when present. + - ``nvext.extra_fields = ["engine_data"]``: opt-in to the PR #8119 + channel — response carries ``nvext.engine_data.completion_token_ids`` + and ``nvext.engine_data.completion_logprobs``. + - ``messages``: placeholder (single user message). Dynamo ignores + when ``token_data`` is present, but the OpenAI schema requires + a non-empty messages array, so we send a 1-token stub. + - ``stop_token_ids`` / ``cache_salt`` / ``logprobs`` ride as + ``extra_body`` passthrough (Dynamo's + ``PASSTHROUGH_EXTRA_FIELDS`` allowlist accepts them). + """ + # Standard OpenAI fields that map 1:1 onto Dynamo's chat-completions + # request schema (validate.rs accepts them natively). + body: dict[str, Any] = { + "model": model, + # Single placeholder user message; ignored when token_data is set. + "messages": [{"role": "user", "content": ""}], + "stream": False, + "nvext": { + "token_data": list(prompt_ids), + "extra_fields": ["engine_data"], + }, + } + if tools: + body["tools"] = tools + if cache_salt is not None: + body["nvext"]["cache_salt"] = cache_salt + + # Surface standard sampling params at top level (Dynamo's schema + # recognizes them natively, so they flow into SamplingOptions cleanly). + promotable = ( + "max_tokens", + "temperature", + "top_p", + "top_k", + "min_p", + "seed", + "n", + "repetition_penalty", + "min_tokens", + "logprobs", + "skip_special_tokens", + ) + for key in promotable: + value = sp.get(key) + if value is None: + continue + if key == "max_tokens": + body["max_completion_tokens"] = value + elif key == "logprobs": + # Standard OpenAI shape: logprobs=true + top_logprobs=N. The + # vLLM TITO surface accepts ``logprobs=N`` (int); Dynamo's + # chat-completions schema requires the bool+top_logprobs split. + body["logprobs"] = True + if isinstance(value, int) and value > 1: + body["top_logprobs"] = value + else: + body[key] = value + + # Pass-through hints that Dynamo's PASSTHROUGH_EXTRA_FIELDS allowlist + # accepts (stop_token_ids, bad_words_token_ids, ...). + for key in ("stop_token_ids", "bad_words_token_ids", "allowed_token_ids"): + if sp.get(key) is not None: + body[key] = sp[key] + + post_kwargs: dict[str, Any] = { + "cast_to": cast(Any, dict[str, Any]), + "body": body, + } + if extra_headers: + post_kwargs["options"] = cast(Any, {"headers": extra_headers}) + try: + return await client.post("/chat/completions", **post_kwargs) + except BadRequestError as exc: + _log_overlong_prompt_diagnostic( + prompt_ids=prompt_ids, + messages=messages, + max_tokens=sp.get("max_tokens"), + exc=exc, + ) + raise + + def _build_mm_features( renderer: Renderer | RendererPool, mm_data: MultiModalData, From 6a215742b0d89a1d4e79e8d217efa4909ed8b0a5 Mon Sep 17 00:00:00 2001 From: Biswa Panda Date: Fri, 15 May 2026 22:02:52 -0700 Subject: [PATCH 02/21] feat: forward Dynamo nvext TITO fields --- renderers/client.py | 24 ++++++++++++++-------- tests/test_client.py | 49 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 9 deletions(-) diff --git a/renderers/client.py b/renderers/client.py index 50e28f1..708408b 100644 --- a/renderers/client.py +++ b/renderers/client.py @@ -437,15 +437,14 @@ async def _post_dynamo_chat_nvext( - ``nvext.token_data``: pre-tokenized prompt; Dynamo's preprocessor skips tokenization when present. - - ``nvext.extra_fields = ["engine_data"]``: opt-in to the PR #8119 - channel — response carries ``nvext.engine_data.completion_token_ids`` - and ``nvext.engine_data.completion_logprobs``. + - ``nvext.extra_fields = ["engine_data", "routed_experts"]``: opt-in + to Dynamo's engine metadata and router replay channels. - ``messages``: placeholder (single user message). Dynamo ignores when ``token_data`` is present, but the OpenAI schema requires a non-empty messages array, so we send a 1-token stub. - - ``stop_token_ids`` / ``cache_salt`` / ``logprobs`` ride as - ``extra_body`` passthrough (Dynamo's - ``PASSTHROUGH_EXTRA_FIELDS`` allowlist accepts them). + - ``stop_token_ids`` / ``cache_salt`` / ``logprobs`` / backend sampling + hints ride as passthrough fields accepted by Dynamo's + ``PASSTHROUGH_EXTRA_FIELDS`` allowlist. """ # Standard OpenAI fields that map 1:1 onto Dynamo's chat-completions # request schema (validate.rs accepts them natively). @@ -456,13 +455,15 @@ async def _post_dynamo_chat_nvext( "stream": False, "nvext": { "token_data": list(prompt_ids), - "extra_fields": ["engine_data"], + "extra_fields": ["engine_data", "routed_experts"], }, } if tools: body["tools"] = tools if cache_salt is not None: body["nvext"]["cache_salt"] = cache_salt + if priority is not None: + body["nvext"]["agent_hints"] = {"priority": priority} # Surface standard sampling params at top level (Dynamo's schema # recognizes them natively, so they flow into SamplingOptions cleanly). @@ -496,8 +497,13 @@ async def _post_dynamo_chat_nvext( body[key] = value # Pass-through hints that Dynamo's PASSTHROUGH_EXTRA_FIELDS allowlist - # accepts (stop_token_ids, bad_words_token_ids, ...). - for key in ("stop_token_ids", "bad_words_token_ids", "allowed_token_ids"): + # accepts (stop_token_ids, token constraints, backend sampling toggles). + for key in ( + "stop_token_ids", + "bad_words_token_ids", + "allowed_token_ids", + "detokenize", + ): if sp.get(key) is not None: body[key] = sp[key] diff --git a/tests/test_client.py b/tests/test_client.py index 1cc1000..b2ac8a6 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -275,6 +275,55 @@ def test_generate_threads_prompt_attribution_through_prebuilt_prompt_path(): assert result["prompt_attribution"] is supplied +def test_dynamo_transport_forwards_priority_and_detokenize(): + client = _FakeClient() + + result = asyncio.run( + generate( + client=client, + renderer=_FakeRenderer(), + messages=[{"role": "user", "content": "hi"}], + model="test-model", + tools=[{"type": "function", "function": {"name": "echo"}}], + sampling_params={ + "temperature": 0.3, + "max_tokens": 7, + "detokenize": False, + "allowed_token_ids": [7, 8], + "bad_words_token_ids": [[1, 2]], + }, + cache_salt="ckpt-42", + priority=17, + transport="dynamo_chat_nvext", + ) + ) + + assert len(client.calls) == 1 + assert client.calls[0]["path"] == "/chat/completions" + assert client.calls[0]["body"] == { + "model": "test-model", + "messages": [{"role": "user", "content": ""}], + "stream": False, + "nvext": { + "token_data": [1, 2, 3], + "extra_fields": ["engine_data", "routed_experts"], + "cache_salt": "ckpt-42", + "agent_hints": {"priority": 17}, + }, + "tools": [{"type": "function", "function": {"name": "echo"}}], + "temperature": 0.3, + "max_completion_tokens": 7, + "logprobs": True, + "skip_special_tokens": False, + "stop_token_ids": [99], + "bad_words_token_ids": [[1, 2]], + "allowed_token_ids": [7, 8], + "detokenize": False, + } + assert result["completion_ids"] == [7, 8] + assert result["routed_experts"] == [[[1]], [[2]]] + + # --------------------------------------------------------------------------- # Multimodal features payload. # --------------------------------------------------------------------------- From a35e0236e55c9542fda08ce641ce0c9ebf2f05df Mon Sep 17 00:00:00 2001 From: Biswa Panda Date: Tue, 9 Jun 2026 00:24:30 -0700 Subject: [PATCH 03/21] =?UTF-8?q?fix(client):=20address=20codex=20review?= =?UTF-8?q?=20=E2=80=94=20revert=20default=20vLLM=20path,=20drop=20tools?= =?UTF-8?q?=20from=20dynamo=20body,=20raise=20on=20missing=20ids;=20rename?= =?UTF-8?q?=20transport=20to=20dynamo=5Fchat?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- renderers/client.py | 85 +++++++++++++++++++------------------------- tests/test_client.py | 70 ++++++++++++++++++++++++++++++++---- 2 files changed, 99 insertions(+), 56 deletions(-) diff --git a/renderers/client.py b/renderers/client.py index 708408b..08b5c51 100644 --- a/renderers/client.py +++ b/renderers/client.py @@ -2,12 +2,12 @@ Two transports, selected per-call via ``transport=`` parameter: - "prime_vllm_generate" (default) + "vllm_generate" (default) messages → Renderer.render_ids() → token IDs → POST /inference/v1/generate → completion tokens → Renderer.parse_response() → structured message vLLM's TITO surface (server.py mounts the route in prime-rl). - "dynamo_chat_nvext" + "dynamo_chat" messages → Renderer.render_ids() → token IDs → POST /v1/chat/completions with ``nvext.token_data`` + ``nvext.extra_fields=["engine_data"]`` → completion tokens via ``nvext.engine_data.completion_token_ids`` @@ -120,7 +120,7 @@ async def _resolve_max_prompt_len(client: AsyncOpenAI, model: str) -> int | None return value # Public type alias; matches verifiers.types.RendererTransport string set. -RendererTransport = Literal["prime_vllm_generate", "dynamo_chat_nvext"] +RendererTransport = Literal["vllm_generate", "dynamo_chat"] async def _maybe_offload(renderer: Renderer | RendererPool, fn): @@ -168,7 +168,7 @@ async def generate( prompt_attribution: RenderedTokens | None = None, tools: list[ToolSpec] | None = None, sampling_params: dict[str, Any] | None = None, - transport: RendererTransport = "prime_vllm_generate", + transport: RendererTransport = "vllm_generate", cache_salt: str | None = None, priority: int | None = None, extra_headers: dict[str, str] | None = None, @@ -260,25 +260,19 @@ def _prepare(): sp["logprobs"] = 1 sp.setdefault("skip_special_tokens", False) - features = ( - _build_mm_features(renderer, mm_data) - if mm_data and not mm_data.is_empty() - else None - ) - - if transport == "dynamo_chat_nvext": + if transport == "dynamo_chat": # Dynamo branch: POST /v1/chat/completions with nvext.token_data. # Dynamo has no /inference/v1/generate route; the equivalent TITO # surface lives on chat-completions via the ``nvext`` envelope # (PR #8119: response token IDs come back under # ``nvext.engine_data.completion_token_ids``). - if features is not None: + if mm_data is not None and not mm_data.is_empty(): raise NotImplementedError( "Multimodal renderers are not yet supported on the " - "dynamo_chat_nvext transport. Use prime_vllm_generate or " + "dynamo_chat transport. Use vllm_generate or " "stay on the token-client TITO path for VLMs." ) - data = await _post_dynamo_chat_nvext( + data = await _post_dynamo_chat( client=client, model=model, prompt_ids=prompt_ids, @@ -289,8 +283,13 @@ def _prepare(): extra_headers=extra_headers, messages=messages, ) - else: - # vLLM-native branch: POST /inference/v1/generate (vLLM 0.20 TITO). + elif transport == "vllm_generate": + # vLLM-native branch: POST /inference/v1/generate (unchanged from upstream). + features = ( + _build_mm_features(renderer, mm_data) + if mm_data and not mm_data.is_empty() + else None + ) body: dict[str, Any] = { "model": model, "token_ids": prompt_ids, @@ -315,21 +314,15 @@ def _prepare(): sp.get("max_tokens"), ) post_kwargs: dict[str, Any] = { - "cast_to": cast(Any, dict[str, Any]), + "cast_to": httpx.Response, "body": body, } if extra_headers: post_kwargs["options"] = cast(Any, {"headers": extra_headers}) - try: - data = await client.post(endpoint, **post_kwargs) - except BadRequestError as exc: - _log_overlong_prompt_diagnostic( - prompt_ids=prompt_ids, - messages=messages, - max_tokens=sp.get("max_tokens"), - exc=exc, - ) - raise + raw_response = await client.post(endpoint, **post_kwargs) + data = parse_generate_response(raw_response.content) + else: + raise ValueError(f"Unknown renderer transport: {transport!r}") choice = (data.get("choices") or [{}])[0] # Dynamo emits engine token IDs under ``nvext.engine_data.completion_token_ids`` @@ -346,6 +339,13 @@ def _prepare(): or [] ) completion_ids = list(completion_ids or []) + if transport == "dynamo_chat" and not completion_ids: + # Fail loudly rather than parse an empty completion (usually a missing + # nvext.extra_fields=["engine_data"] opt-in). + raise RuntimeError( + "dynamo_chat response carried no completion token IDs " + "(expected nvext.engine_data.completion_token_ids)." + ) parsed = await _maybe_offload( renderer, lambda: renderer.parse_response(completion_ids, tools=tools) @@ -365,16 +365,11 @@ def _prepare(): if engine_lp: completion_logprobs = [float(x) for x in engine_lp] - routed_experts = None - raw_re = choice.get("routed_experts") or (data.get("nvext") or {}).get( + # Pass the routed_experts sidecar through unchanged (binary {shape, data} + # for vLLM; same dict under nvext for Dynamo). Decoding is the consumer's job. + routed_experts = choice.get("routed_experts") or (data.get("nvext") or {}).get( "routed_experts" ) - if isinstance(raw_re, dict) and "data" in raw_re and "shape" in raw_re: - routed_experts = ( - np.frombuffer(base64.b85decode(raw_re["data"]), dtype=np.int32) - .reshape(raw_re["shape"]) - .tolist() - ) # /inference/v1/generate returns finish_reason in {"stop","length",...} — # never "tool_calls" (a chat-completions concept). Promote stop→tool_calls @@ -414,7 +409,7 @@ def _prepare(): } -async def _post_dynamo_chat_nvext( +async def _post_dynamo_chat( *, client: AsyncOpenAI, model: str, @@ -428,7 +423,7 @@ async def _post_dynamo_chat_nvext( ) -> dict[str, Any]: """POST ``prompt_ids`` to Dynamo's ``/v1/chat/completions`` route. - Mirrors ``verifiers.clients.openai_chat_completions_token_client._post_dynamo_chat_nvext`` + Mirrors ``verifiers.clients.openai_chat_completions_token_client._post_dynamo_chat`` in shape, so the wire payload is identical whether the rollout goes through the token client or the renderer client. Anything that lands on Dynamo's chat-completions surface, lands here. @@ -458,8 +453,8 @@ async def _post_dynamo_chat_nvext( "extra_fields": ["engine_data", "routed_experts"], }, } - if tools: - body["tools"] = tools + # tools are NOT sent on the wire: the renderer already bakes them into + # token_data, and renderer ToolSpec isn't the OpenAI tool shape (would 400). if cache_salt is not None: body["nvext"]["cache_salt"] = cache_salt if priority is not None: @@ -513,16 +508,8 @@ async def _post_dynamo_chat_nvext( } if extra_headers: post_kwargs["options"] = cast(Any, {"headers": extra_headers}) - try: - return await client.post("/chat/completions", **post_kwargs) - except BadRequestError as exc: - _log_overlong_prompt_diagnostic( - prompt_ids=prompt_ids, - messages=messages, - max_tokens=sp.get("max_tokens"), - exc=exc, - ) - raise + # Engine 4xx propagate raw (matches the vLLM path). + return await client.post("/chat/completions", **post_kwargs) def _build_mm_features( diff --git a/tests/test_client.py b/tests/test_client.py index b2ac8a6..39a1e2c 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -95,10 +95,13 @@ async def post(self, path, *, cast_to=dict, body=None, options=None): } ], } - return httpx.Response( - 200, - content=json.dumps(payload, separators=(",", ":")).encode("utf-8"), - ) + # vLLM path requests cast_to=httpx.Response; Dynamo path uses cast_to=dict. + if cast_to is httpx.Response: + return httpx.Response( + 200, + content=json.dumps(payload, separators=(",", ":")).encode("utf-8"), + ) + return payload def test_generate_builds_request_body_and_parses_response(): @@ -294,7 +297,7 @@ def test_dynamo_transport_forwards_priority_and_detokenize(): }, cache_salt="ckpt-42", priority=17, - transport="dynamo_chat_nvext", + transport="dynamo_chat", ) ) @@ -310,7 +313,7 @@ def test_dynamo_transport_forwards_priority_and_detokenize(): "cache_salt": "ckpt-42", "agent_hints": {"priority": 17}, }, - "tools": [{"type": "function", "function": {"name": "echo"}}], + # tools are NOT forwarded on the wire (baked into token_data instead). "temperature": 0.3, "max_completion_tokens": 7, "logprobs": True, @@ -321,7 +324,60 @@ def test_dynamo_transport_forwards_priority_and_detokenize(): "detokenize": False, } assert result["completion_ids"] == [7, 8] - assert result["routed_experts"] == [[[1]], [[2]]] + # routed_experts passes through unchanged (no client-side decode). + assert result["routed_experts"] == { + "data": base64.b64encode(b"\x01\x02").decode("ascii"), + "shape": [2, 1, 1], + } + + +class _NoCompletionIdsClient(_FakeClient): + """Dynamo response that carries no completion token IDs.""" + + async def post(self, path, *, cast_to=dict, body=None, options=None): + self.calls.append( + {"path": path, "cast_to": cast_to, "body": body, "options": options} + ) + return {"request_id": "x", "choices": [{"index": 0, "finish_reason": "stop"}]} + + +def test_dynamo_transport_raises_without_completion_ids(): + with pytest.raises(RuntimeError, match="completion token IDs"): + asyncio.run( + generate( + client=_NoCompletionIdsClient(), + renderer=_FakeRenderer(), + messages=[{"role": "user", "content": "hi"}], + model="test-model", + tools=[{"type": "function", "function": {"name": "echo"}}], + sampling_params={"max_tokens": 7}, + transport="dynamo_chat", + max_prompt_len=10_000, + ) + ) + + +class _BoomClient(_FakeClient): + async def post(self, path, *, cast_to=dict, body=None, options=None): + raise ValueError("boom") + + +@pytest.mark.parametrize("transport", ["vllm_generate", "dynamo_chat"]) +def test_generate_propagates_post_errors_raw(transport): + # POST errors must propagate unchanged (no NameError from a stale handler). + with pytest.raises(ValueError, match="boom"): + asyncio.run( + generate( + client=_BoomClient(), + renderer=_FakeRenderer(), + messages=[{"role": "user", "content": "hi"}], + model="test-model", + tools=[{"type": "function", "function": {"name": "echo"}}], + sampling_params={"max_tokens": 7}, + transport=transport, + max_prompt_len=10_000, + ) + ) # --------------------------------------------------------------------------- From b6f50d0d441efcadfd843ed7d531efc1c9cf8129 Mon Sep 17 00:00:00 2001 From: Biswa Panda Date: Tue, 9 Jun 2026 01:21:08 -0700 Subject: [PATCH 04/21] fix(client): gate nvext fallbacks to dynamo path, fix zero-token guard, drop routed_experts on dynamo (codex round 2) --- renderers/client.py | 55 ++++++++++++++++++++++---------------------- tests/test_client.py | 40 +++++++++++++++++++++++++++++++- 2 files changed, 66 insertions(+), 29 deletions(-) diff --git a/renderers/client.py b/renderers/client.py index 08b5c51..10f66cc 100644 --- a/renderers/client.py +++ b/renderers/client.py @@ -325,23 +325,25 @@ def _prepare(): raise ValueError(f"Unknown renderer transport: {transport!r}") choice = (data.get("choices") or [{}])[0] - # Dynamo emits engine token IDs under ``nvext.engine_data.completion_token_ids`` - # (PR #8119 channel) rather than ``choice.token_ids``. Try both — vLLM's - # /inference/v1/generate writes the top-level shape; Dynamo's - # /v1/chat/completions writes the nested one. The first present wins. + is_dynamo = transport == "dynamo_chat" + # vLLM writes choices[0].token_ids; Dynamo returns engine fields under nvext + # (PR #8119). Only consult nvext on the dynamo path so the vLLM path stays + # byte-identical to upstream. + nvext_resp = (data.get("nvext") or {}) if is_dynamo else {} + engine_data = nvext_resp.get("engine_data") or {} + completion_ids = choice.get("token_ids") - if not completion_ids: - nvext_resp = data.get("nvext") or {} - engine_data = nvext_resp.get("engine_data") or {} - completion_ids = ( - engine_data.get("completion_token_ids") - or nvext_resp.get("completion_token_ids") - or [] - ) + ids_present = completion_ids is not None + if not ids_present and is_dynamo: + for src in (engine_data, nvext_resp): + if src.get("completion_token_ids") is not None: + completion_ids = src["completion_token_ids"] + ids_present = True + break completion_ids = list(completion_ids or []) - if transport == "dynamo_chat" and not completion_ids: - # Fail loudly rather than parse an empty completion (usually a missing - # nvext.extra_fields=["engine_data"] opt-in). + if is_dynamo and not ids_present: + # Field absent (not merely an empty completion) — usually a missing + # nvext.extra_fields=["engine_data"] opt-in. raise RuntimeError( "dynamo_chat response carried no completion token IDs " "(expected nvext.engine_data.completion_token_ids)." @@ -352,24 +354,19 @@ def _prepare(): ) # ChatCompletionLogProbs flatten: {"content": [{"logprob": ...}, ...]}. - # Same shape on both transports (Dynamo aliases the standard OpenAI - # logprobs field). engine_data.completion_logprobs is a fallback when - # the OpenAI-style logprobs array is absent. + # engine_data.completion_logprobs is a dynamo-only fallback. raw_logprobs = choice.get("logprobs") or {} content_lp = raw_logprobs.get("content") if isinstance(raw_logprobs, dict) else None completion_logprobs = [float(c.get("logprob") or 0.0) for c in content_lp or []] - if not completion_logprobs: - nvext_resp = data.get("nvext") or {} - engine_data = nvext_resp.get("engine_data") or {} + if not completion_logprobs and is_dynamo: engine_lp = engine_data.get("completion_logprobs") or [] if engine_lp: completion_logprobs = [float(x) for x in engine_lp] - # Pass the routed_experts sidecar through unchanged (binary {shape, data} - # for vLLM; same dict under nvext for Dynamo). Decoding is the consumer's job. - routed_experts = choice.get("routed_experts") or (data.get("nvext") or {}).get( - "routed_experts" - ) + # vLLM routed_experts sidecar ({shape, data}) passed through unchanged. Not + # surfaced on dynamo_chat: its nvext shape differs from the downstream + # RoutedExpertsPayload contract (base64 + ``start``). + routed_experts = choice.get("routed_experts") # /inference/v1/generate returns finish_reason in {"stop","length",...} — # never "tool_calls" (a chat-completions concept). Promote stop→tool_calls @@ -385,7 +382,9 @@ def _prepare(): finish_reason = "tool_calls" return { - "request_id": data.get("request_id") or data.get("id") or "", + "request_id": data.get("request_id") + or (data.get("id") if is_dynamo else None) + or "", "prompt_ids": list(prompt_ids), "completion_ids": list(completion_ids), "completion_logprobs": completion_logprobs, @@ -450,7 +449,7 @@ async def _post_dynamo_chat( "stream": False, "nvext": { "token_data": list(prompt_ids), - "extra_fields": ["engine_data", "routed_experts"], + "extra_fields": ["engine_data"], }, } # tools are NOT sent on the wire: the renderer already bakes them into diff --git a/tests/test_client.py b/tests/test_client.py index 39a1e2c..d65a429 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -309,7 +309,7 @@ def test_dynamo_transport_forwards_priority_and_detokenize(): "stream": False, "nvext": { "token_data": [1, 2, 3], - "extra_fields": ["engine_data", "routed_experts"], + "extra_fields": ["engine_data"], "cache_salt": "ckpt-42", "agent_hints": {"priority": 17}, }, @@ -357,6 +357,44 @@ def test_dynamo_transport_raises_without_completion_ids(): ) +class _EmptyCompletionClient(_FakeClient): + """Dynamo response with a present-but-empty completion_token_ids list.""" + + async def post(self, path, *, cast_to=dict, body=None, options=None): + self.calls.append( + {"path": path, "cast_to": cast_to, "body": body, "options": options} + ) + return { + "request_id": "x", + "choices": [{"index": 0, "finish_reason": "stop", "logprobs": {"content": []}}], + "nvext": {"engine_data": {"completion_token_ids": []}}, + } + + +class _EmptyParseRenderer(_FakeRenderer): + def parse_response(self, completion_ids, *, tools=None) -> ParsedResponse: + assert completion_ids == [] + return ParsedResponse(content="", reasoning_content=None, tool_calls=[]) + + +def test_dynamo_transport_allows_present_but_empty_completion(): + """A present-but-empty completion_token_ids is a valid zero-token completion + and must NOT raise (only an absent field raises).""" + result = asyncio.run( + generate( + client=_EmptyCompletionClient(), + renderer=_EmptyParseRenderer(), + messages=[{"role": "user", "content": "hi"}], + model="test-model", + tools=[{"type": "function", "function": {"name": "echo"}}], + sampling_params={"max_tokens": 7}, + transport="dynamo_chat", + max_prompt_len=10_000, + ) + ) + assert result["completion_ids"] == [] + + class _BoomClient(_FakeClient): async def post(self, path, *, cast_to=dict, body=None, options=None): raise ValueError("boom") From 5dbf4941f118a126c3be2031cf9d6b4b68cf23fe Mon Sep 17 00:00:00 2001 From: Biswa Panda Date: Tue, 9 Jun 2026 01:31:17 -0700 Subject: [PATCH 05/21] test(client): prove routed_experts dropped on dynamo (Dynamo-shaped fake); docstring fix --- renderers/client.py | 6 ++++-- tests/test_client.py | 33 +++++++++++++++++++++++++++------ 2 files changed, 31 insertions(+), 8 deletions(-) diff --git a/renderers/client.py b/renderers/client.py index 10f66cc..f6b5ea5 100644 --- a/renderers/client.py +++ b/renderers/client.py @@ -431,8 +431,10 @@ async def _post_dynamo_chat( - ``nvext.token_data``: pre-tokenized prompt; Dynamo's preprocessor skips tokenization when present. - - ``nvext.extra_fields = ["engine_data", "routed_experts"]``: opt-in - to Dynamo's engine metadata and router replay channels. + - ``nvext.extra_fields = ["engine_data"]``: opt-in to Dynamo's engine + metadata channel (completion token IDs + logprobs). routed_experts is + intentionally NOT requested — Dynamo's nvext shape differs from the + downstream RoutedExpertsPayload contract, so it is dropped on this path. - ``messages``: placeholder (single user message). Dynamo ignores when ``token_data`` is present, but the OpenAI schema requires a non-empty messages array, so we send a 1-token stub. diff --git a/tests/test_client.py b/tests/test_client.py index d65a429..0e116be 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -278,8 +278,32 @@ def test_generate_threads_prompt_attribution_through_prebuilt_prompt_path(): assert result["prompt_attribution"] is supplied +class _DynamoFakeClient(_FakeClient): + """Dynamo-shaped response: engine fields + routed_experts under nvext (not + choices[0]), so the test proves routed_experts is dropped on dynamo_chat.""" + + async def post(self, path, *, cast_to=dict, body=None, options=None): + self.calls.append( + {"path": path, "cast_to": cast_to, "body": body, "options": options} + ) + return { + "id": "gen-dyn", + "choices": [ + { + "index": 0, + "logprobs": {"content": [{"logprob": -0.1}, {"logprob": -0.2}]}, + "finish_reason": "stop", + } + ], + "nvext": { + "engine_data": {"completion_token_ids": [7, 8]}, + "routed_experts": {"data": "AQI=", "shape": [2, 1, 1]}, + }, + } + + def test_dynamo_transport_forwards_priority_and_detokenize(): - client = _FakeClient() + client = _DynamoFakeClient() result = asyncio.run( generate( @@ -324,11 +348,8 @@ def test_dynamo_transport_forwards_priority_and_detokenize(): "detokenize": False, } assert result["completion_ids"] == [7, 8] - # routed_experts passes through unchanged (no client-side decode). - assert result["routed_experts"] == { - "data": base64.b64encode(b"\x01\x02").decode("ascii"), - "shape": [2, 1, 1], - } + # routed_experts is dropped on dynamo_chat (nvext shape != downstream contract). + assert result["routed_experts"] is None class _NoCompletionIdsClient(_FakeClient): From 287871c39fd6a17d5d783f0c38c5d1e6c4023c54 Mon Sep 17 00:00:00 2001 From: Biswa Panda Date: Tue, 9 Jun 2026 15:31:09 -0700 Subject: [PATCH 06/21] style: apply ruff format to client + tests (fix CI) --- renderers/client.py | 1 + tests/test_client.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/renderers/client.py b/renderers/client.py index f6b5ea5..41b2f6a 100644 --- a/renderers/client.py +++ b/renderers/client.py @@ -119,6 +119,7 @@ async def _resolve_max_prompt_len(client: AsyncOpenAI, model: str) -> int | None _max_prompt_len_cache[key] = value return value + # Public type alias; matches verifiers.types.RendererTransport string set. RendererTransport = Literal["vllm_generate", "dynamo_chat"] diff --git a/tests/test_client.py b/tests/test_client.py index 0e116be..0cd8bd0 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -387,7 +387,9 @@ async def post(self, path, *, cast_to=dict, body=None, options=None): ) return { "request_id": "x", - "choices": [{"index": 0, "finish_reason": "stop", "logprobs": {"content": []}}], + "choices": [ + {"index": 0, "finish_reason": "stop", "logprobs": {"content": []}} + ], "nvext": {"engine_data": {"completion_token_ids": []}}, } From 60411345b258320ab6e984f299aa9fe406500a88 Mon Sep 17 00:00:00 2001 From: Biswa Panda Date: Tue, 9 Jun 2026 15:35:59 -0700 Subject: [PATCH 07/21] docs(client): trim verbose comments in dynamo_chat path --- renderers/client.py | 69 +++++++++++++-------------------------------- 1 file changed, 19 insertions(+), 50 deletions(-) diff --git a/renderers/client.py b/renderers/client.py index 41b2f6a..6774513 100644 --- a/renderers/client.py +++ b/renderers/client.py @@ -262,11 +262,6 @@ def _prepare(): sp.setdefault("skip_special_tokens", False) if transport == "dynamo_chat": - # Dynamo branch: POST /v1/chat/completions with nvext.token_data. - # Dynamo has no /inference/v1/generate route; the equivalent TITO - # surface lives on chat-completions via the ``nvext`` envelope - # (PR #8119: response token IDs come back under - # ``nvext.engine_data.completion_token_ids``). if mm_data is not None and not mm_data.is_empty(): raise NotImplementedError( "Multimodal renderers are not yet supported on the " @@ -285,7 +280,6 @@ def _prepare(): messages=messages, ) elif transport == "vllm_generate": - # vLLM-native branch: POST /inference/v1/generate (unchanged from upstream). features = ( _build_mm_features(renderer, mm_data) if mm_data and not mm_data.is_empty() @@ -327,9 +321,7 @@ def _prepare(): choice = (data.get("choices") or [{}])[0] is_dynamo = transport == "dynamo_chat" - # vLLM writes choices[0].token_ids; Dynamo returns engine fields under nvext - # (PR #8119). Only consult nvext on the dynamo path so the vLLM path stays - # byte-identical to upstream. + # Only consult nvext on the dynamo path, so the vLLM path is unchanged. nvext_resp = (data.get("nvext") or {}) if is_dynamo else {} engine_data = nvext_resp.get("engine_data") or {} @@ -343,7 +335,7 @@ def _prepare(): break completion_ids = list(completion_ids or []) if is_dynamo and not ids_present: - # Field absent (not merely an empty completion) — usually a missing + # Field absent (vs. an empty completion) — usually a missing # nvext.extra_fields=["engine_data"] opt-in. raise RuntimeError( "dynamo_chat response carried no completion token IDs " @@ -355,26 +347,22 @@ def _prepare(): ) # ChatCompletionLogProbs flatten: {"content": [{"logprob": ...}, ...]}. - # engine_data.completion_logprobs is a dynamo-only fallback. raw_logprobs = choice.get("logprobs") or {} content_lp = raw_logprobs.get("content") if isinstance(raw_logprobs, dict) else None completion_logprobs = [float(c.get("logprob") or 0.0) for c in content_lp or []] if not completion_logprobs and is_dynamo: + # Dynamo-only fallback when chat logprobs are absent. engine_lp = engine_data.get("completion_logprobs") or [] if engine_lp: completion_logprobs = [float(x) for x in engine_lp] - # vLLM routed_experts sidecar ({shape, data}) passed through unchanged. Not - # surfaced on dynamo_chat: its nvext shape differs from the downstream - # RoutedExpertsPayload contract (base64 + ``start``). + # Not surfaced on dynamo_chat: its nvext shape differs from the + # downstream RoutedExpertsPayload contract (base64 + ``start``). routed_experts = choice.get("routed_experts") - # /inference/v1/generate returns finish_reason in {"stop","length",...} — - # never "tool_calls" (a chat-completions concept). Promote stop→tool_calls - # when we extracted tool calls client-side, so OpenAI-compatible agent - # loops continue past the tool turn instead of treating the response as - # final. Dynamo's chat-completions surface CAN return "tool_calls" - # directly, so this promotion is a no-op there. + # /inference/v1/generate never returns "tool_calls", so promote + # stop→tool_calls when we parsed tool calls client-side (keeps agent + # loops going). No-op on dynamo_chat, which can return it directly. finish_reason = choice.get("finish_reason") ok_tool_calls = [ tc for tc in parsed.tool_calls if tc.status == ToolCallParseStatus.OK @@ -424,30 +412,14 @@ async def _post_dynamo_chat( """POST ``prompt_ids`` to Dynamo's ``/v1/chat/completions`` route. Mirrors ``verifiers.clients.openai_chat_completions_token_client._post_dynamo_chat`` - in shape, so the wire payload is identical whether the rollout goes - through the token client or the renderer client. Anything that lands - on Dynamo's chat-completions surface, lands here. - - Wire shape: - - - ``nvext.token_data``: pre-tokenized prompt; Dynamo's preprocessor - skips tokenization when present. - - ``nvext.extra_fields = ["engine_data"]``: opt-in to Dynamo's engine - metadata channel (completion token IDs + logprobs). routed_experts is - intentionally NOT requested — Dynamo's nvext shape differs from the - downstream RoutedExpertsPayload contract, so it is dropped on this path. - - ``messages``: placeholder (single user message). Dynamo ignores - when ``token_data`` is present, but the OpenAI schema requires - a non-empty messages array, so we send a 1-token stub. - - ``stop_token_ids`` / ``cache_salt`` / ``logprobs`` / backend sampling - hints ride as passthrough fields accepted by Dynamo's - ``PASSTHROUGH_EXTRA_FIELDS`` allowlist. + so the wire payload is identical via either client. ``nvext.token_data`` + carries the pre-tokenized prompt (Dynamo skips tokenization when present) + and ``extra_fields=["engine_data"]`` opts into the completion-IDs/logprobs + channel. ``messages`` is a placeholder stub the OpenAI schema requires but + Dynamo ignores. routed_experts is not requested (incompatible nvext shape). """ - # Standard OpenAI fields that map 1:1 onto Dynamo's chat-completions - # request schema (validate.rs accepts them natively). body: dict[str, Any] = { "model": model, - # Single placeholder user message; ignored when token_data is set. "messages": [{"role": "user", "content": ""}], "stream": False, "nvext": { @@ -455,15 +427,14 @@ async def _post_dynamo_chat( "extra_fields": ["engine_data"], }, } - # tools are NOT sent on the wire: the renderer already bakes them into - # token_data, and renderer ToolSpec isn't the OpenAI tool shape (would 400). + # tools are baked into token_data already; the renderer ToolSpec isn't the + # OpenAI tool shape, so forwarding it would 400. if cache_salt is not None: body["nvext"]["cache_salt"] = cache_salt if priority is not None: body["nvext"]["agent_hints"] = {"priority": priority} - # Surface standard sampling params at top level (Dynamo's schema - # recognizes them natively, so they flow into SamplingOptions cleanly). + # Sampling params Dynamo's schema recognizes natively at top level. promotable = ( "max_tokens", "temperature", @@ -484,17 +455,15 @@ async def _post_dynamo_chat( if key == "max_tokens": body["max_completion_tokens"] = value elif key == "logprobs": - # Standard OpenAI shape: logprobs=true + top_logprobs=N. The - # vLLM TITO surface accepts ``logprobs=N`` (int); Dynamo's - # chat-completions schema requires the bool+top_logprobs split. + # vLLM takes logprobs=N (int); Dynamo's chat schema wants the + # OpenAI bool + top_logprobs split. body["logprobs"] = True if isinstance(value, int) and value > 1: body["top_logprobs"] = value else: body[key] = value - # Pass-through hints that Dynamo's PASSTHROUGH_EXTRA_FIELDS allowlist - # accepts (stop_token_ids, token constraints, backend sampling toggles). + # Pass-through hints on Dynamo's PASSTHROUGH_EXTRA_FIELDS allowlist. for key in ( "stop_token_ids", "bad_words_token_ids", From 503846c3b15f3bab5ad691e6f6d1ce6da64def79 Mon Sep 17 00:00:00 2001 From: Biswa Panda Date: Tue, 9 Jun 2026 15:49:39 -0700 Subject: [PATCH 08/21] refactor(client): replace transport if/else with strategy classes + cached endpoints --- renderers/client.py | 475 +++++++++++++++++++++++++++----------------- 1 file changed, 288 insertions(+), 187 deletions(-) diff --git a/renderers/client.py b/renderers/client.py index 6774513..6666a2c 100644 --- a/renderers/client.py +++ b/renderers/client.py @@ -27,7 +27,9 @@ import asyncio import json import logging +from abc import ABC, abstractmethod from collections.abc import Mapping +from dataclasses import dataclass from typing import Any, Literal, cast import httpx @@ -123,6 +125,272 @@ async def _resolve_max_prompt_len(client: AsyncOpenAI, model: str) -> int | None # Public type alias; matches verifiers.types.RendererTransport string set. RendererTransport = Literal["vllm_generate", "dynamo_chat"] +# Sampling params Dynamo's chat schema recognizes natively at top level. +_DYNAMO_PROMOTABLE_KEYS = ( + "max_tokens", + "temperature", + "top_p", + "top_k", + "min_p", + "seed", + "n", + "repetition_penalty", + "min_tokens", + "logprobs", + "skip_special_tokens", +) +# Pass-through hints on Dynamo's PASSTHROUGH_EXTRA_FIELDS allowlist. +_DYNAMO_PASSTHROUGH_KEYS = ( + "stop_token_ids", + "bad_words_token_ids", + "allowed_token_ids", + "detokenize", +) + +# Absolute /inference/v1/generate URLs, cached per client base_url. +_vllm_endpoint_cache: dict[str, str] = {} + + +def _vllm_generate_endpoint(base_url: str) -> str: + """Absolute ``/inference/v1/generate`` URL for ``base_url`` (cached). + + The route is mounted at the server root, not under /v1, so strip the + client's trailing /v1 and build an absolute URL — otherwise AsyncOpenAI + prepends its automatic /v1. + """ + endpoint = _vllm_endpoint_cache.get(base_url) + if endpoint is None: + endpoint = f"{base_url.rstrip('/').removesuffix('/v1')}/inference/v1/generate" + _vllm_endpoint_cache[base_url] = endpoint + return endpoint + + +def _flatten_chat_logprobs(choice: Mapping[str, Any]) -> list[float]: + """Flatten ChatCompletionLogProbs ``{"content": [{"logprob": ...}, ...]}``.""" + raw = choice.get("logprobs") or {} + content = raw.get("content") if isinstance(raw, dict) else None + return [float(c.get("logprob") or 0.0) for c in content or []] + + +@dataclass(frozen=True) +class _WireResult: + """Normalized fields extracted from a backend's raw response.""" + + completion_ids: list[int] + completion_logprobs: list[float] + routed_experts: Any + request_id: str + finish_reason: str | None + + +class _Transport(ABC): + """Per-backend request/response strategy for :func:`generate`.""" + + @abstractmethod + async def post( + self, + *, + client: AsyncOpenAI, + model: str, + prompt_ids: list[int], + sp: dict[str, Any], + renderer: Renderer | RendererPool, + mm_data: MultiModalData | None, + cache_salt: str | None, + priority: int | None, + extra_headers: dict[str, str] | None, + ) -> dict[str, Any]: + """Build the wire body, POST it, and return the decoded response dict.""" + + @abstractmethod + def parse(self, data: dict[str, Any]) -> _WireResult: + """Extract normalized completion fields from the backend response.""" + + +class _VllmGenerateTransport(_Transport): + """vLLM 0.20 TITO surface: ``POST /inference/v1/generate``.""" + + async def post( + self, + *, + client: AsyncOpenAI, + model: str, + prompt_ids: list[int], + sp: dict[str, Any], + renderer: Renderer | RendererPool, + mm_data: MultiModalData | None, + cache_salt: str | None, + priority: int | None, + extra_headers: dict[str, str] | None, + ) -> dict[str, Any]: + features = ( + _build_mm_features(renderer, mm_data) + if mm_data and not mm_data.is_empty() + else None + ) + body: dict[str, Any] = { + "model": model, + "token_ids": prompt_ids, + "sampling_params": sp, + } + if features is not None: + body["features"] = features + if cache_salt is not None: + body["cache_salt"] = cache_salt + if priority is not None: + body["priority"] = priority + + endpoint = _vllm_generate_endpoint(str(client.base_url)) + _request_logger.debug( + "POST %s prompt_len=%d max_tokens=%s", + endpoint, + len(prompt_ids), + sp.get("max_tokens"), + ) + post_kwargs: dict[str, Any] = {"cast_to": httpx.Response, "body": body} + if extra_headers: + post_kwargs["options"] = cast(Any, {"headers": extra_headers}) + raw_response = await client.post(endpoint, **post_kwargs) + return parse_generate_response(raw_response.content) + + def parse(self, data: dict[str, Any]) -> _WireResult: + choice = (data.get("choices") or [{}])[0] + return _WireResult( + completion_ids=list(choice.get("token_ids") or []), + completion_logprobs=_flatten_chat_logprobs(choice), + routed_experts=choice.get("routed_experts"), + request_id=data.get("request_id") or "", + finish_reason=choice.get("finish_reason"), + ) + + +class _DynamoChatTransport(_Transport): + """NVIDIA Dynamo: ``POST /v1/chat/completions`` with the nvext envelope. + + Dynamo has no /inference/v1/generate route. ``nvext.token_data`` carries + the pre-tokenized prompt (Dynamo skips tokenization when present) and + ``extra_fields=["engine_data"]`` opts into the completion-IDs/logprobs + channel (PR #8119). Mirrors + ``verifiers.clients.openai_chat_completions_token_client._post_dynamo_chat`` + so the wire payload is identical via either client. routed_experts is not + requested (its nvext shape differs from the downstream contract). + """ + + async def post( + self, + *, + client: AsyncOpenAI, + model: str, + prompt_ids: list[int], + sp: dict[str, Any], + renderer: Renderer | RendererPool, + mm_data: MultiModalData | None, + cache_salt: str | None, + priority: int | None, + extra_headers: dict[str, str] | None, + ) -> dict[str, Any]: + if mm_data is not None and not mm_data.is_empty(): + raise NotImplementedError( + "Multimodal renderers are not yet supported on the dynamo_chat " + "transport. Use vllm_generate or stay on the token-client TITO " + "path for VLMs." + ) + body = self._build_body(model, prompt_ids, sp, cache_salt, priority) + post_kwargs: dict[str, Any] = { + "cast_to": cast(Any, dict[str, Any]), + "body": body, + } + if extra_headers: + post_kwargs["options"] = cast(Any, {"headers": extra_headers}) + # Engine 4xx propagate raw (matches the vLLM path). + return await client.post("/chat/completions", **post_kwargs) + + @staticmethod + def _build_body( + model: str, + prompt_ids: list[int], + sp: dict[str, Any], + cache_salt: str | None, + priority: int | None, + ) -> dict[str, Any]: + # messages is a placeholder stub the OpenAI schema requires but Dynamo + # ignores. tools are baked into token_data; forwarding the renderer + # ToolSpec (not the OpenAI tool shape) would 400. + nvext: dict[str, Any] = { + "token_data": list(prompt_ids), + "extra_fields": ["engine_data"], + } + if cache_salt is not None: + nvext["cache_salt"] = cache_salt + if priority is not None: + nvext["agent_hints"] = {"priority": priority} + body: dict[str, Any] = { + "model": model, + "messages": [{"role": "user", "content": ""}], + "stream": False, + "nvext": nvext, + } + + for key in _DYNAMO_PROMOTABLE_KEYS: + value = sp.get(key) + if value is None: + continue + if key == "max_tokens": + body["max_completion_tokens"] = value + elif key == "logprobs": + # vLLM takes logprobs=N (int); Dynamo's chat schema wants the + # OpenAI bool + top_logprobs split. + body["logprobs"] = True + if isinstance(value, int) and value > 1: + body["top_logprobs"] = value + else: + body[key] = value + + for key in _DYNAMO_PASSTHROUGH_KEYS: + if sp.get(key) is not None: + body[key] = sp[key] + return body + + def parse(self, data: dict[str, Any]) -> _WireResult: + choice = (data.get("choices") or [{}])[0] + nvext = data.get("nvext") or {} + engine = nvext.get("engine_data") or {} + + completion_ids = choice.get("token_ids") + present = completion_ids is not None + if not present: + for src in (engine, nvext): + if src.get("completion_token_ids") is not None: + completion_ids = src["completion_token_ids"] + present = True + break + if not present: + # Field absent (vs. an empty completion) — usually a missing + # nvext.extra_fields=["engine_data"] opt-in. + raise RuntimeError( + "dynamo_chat response carried no completion token IDs " + "(expected nvext.engine_data.completion_token_ids)." + ) + + logprobs = _flatten_chat_logprobs(choice) + if not logprobs: + # Dynamo-only fallback when chat logprobs are absent. + logprobs = [float(x) for x in engine.get("completion_logprobs") or []] + + return _WireResult( + completion_ids=list(completion_ids or []), + completion_logprobs=logprobs, + routed_experts=None, + request_id=data.get("request_id") or data.get("id") or "", + finish_reason=choice.get("finish_reason"), + ) + + +_TRANSPORTS: dict[str, _Transport] = { + "vllm_generate": _VllmGenerateTransport(), + "dynamo_chat": _DynamoChatTransport(), +} + async def _maybe_offload(renderer: Renderer | RendererPool, fn): """Run sync renderer work on a thread iff ``renderer`` is a pool. @@ -261,109 +529,30 @@ def _prepare(): sp["logprobs"] = 1 sp.setdefault("skip_special_tokens", False) - if transport == "dynamo_chat": - if mm_data is not None and not mm_data.is_empty(): - raise NotImplementedError( - "Multimodal renderers are not yet supported on the " - "dynamo_chat transport. Use vllm_generate or " - "stay on the token-client TITO path for VLMs." - ) - data = await _post_dynamo_chat( - client=client, - model=model, - prompt_ids=prompt_ids, - sp=sp, - tools=tools, - cache_salt=cache_salt, - priority=priority, - extra_headers=extra_headers, - messages=messages, - ) - elif transport == "vllm_generate": - features = ( - _build_mm_features(renderer, mm_data) - if mm_data and not mm_data.is_empty() - else None - ) - body: dict[str, Any] = { - "model": model, - "token_ids": prompt_ids, - "sampling_params": sp, - } - if features is not None: - body["features"] = features - if cache_salt is not None: - body["cache_salt"] = cache_salt - if priority is not None: - body["priority"] = priority - - # /inference/v1/generate is mounted at the server root, not under /v1 - # like the OpenAI-compatible endpoints. Build an absolute URL so the - # AsyncOpenAI client doesn't prepend its automatic /v1. - base = str(client.base_url).rstrip("/").removesuffix("/v1") - endpoint = f"{base}/inference/v1/generate" - _request_logger.debug( - "POST %s prompt_len=%d max_tokens=%s", - endpoint, - len(prompt_ids), - sp.get("max_tokens"), - ) - post_kwargs: dict[str, Any] = { - "cast_to": httpx.Response, - "body": body, - } - if extra_headers: - post_kwargs["options"] = cast(Any, {"headers": extra_headers}) - raw_response = await client.post(endpoint, **post_kwargs) - data = parse_generate_response(raw_response.content) - else: + impl = _TRANSPORTS.get(transport) + if impl is None: raise ValueError(f"Unknown renderer transport: {transport!r}") - - choice = (data.get("choices") or [{}])[0] - is_dynamo = transport == "dynamo_chat" - # Only consult nvext on the dynamo path, so the vLLM path is unchanged. - nvext_resp = (data.get("nvext") or {}) if is_dynamo else {} - engine_data = nvext_resp.get("engine_data") or {} - - completion_ids = choice.get("token_ids") - ids_present = completion_ids is not None - if not ids_present and is_dynamo: - for src in (engine_data, nvext_resp): - if src.get("completion_token_ids") is not None: - completion_ids = src["completion_token_ids"] - ids_present = True - break - completion_ids = list(completion_ids or []) - if is_dynamo and not ids_present: - # Field absent (vs. an empty completion) — usually a missing - # nvext.extra_fields=["engine_data"] opt-in. - raise RuntimeError( - "dynamo_chat response carried no completion token IDs " - "(expected nvext.engine_data.completion_token_ids)." - ) + data = await impl.post( + client=client, + model=model, + prompt_ids=prompt_ids, + sp=sp, + renderer=renderer, + mm_data=mm_data, + cache_salt=cache_salt, + priority=priority, + extra_headers=extra_headers, + ) + wire = impl.parse(data) parsed = await _maybe_offload( - renderer, lambda: renderer.parse_response(completion_ids, tools=tools) + renderer, lambda: renderer.parse_response(wire.completion_ids, tools=tools) ) - # ChatCompletionLogProbs flatten: {"content": [{"logprob": ...}, ...]}. - raw_logprobs = choice.get("logprobs") or {} - content_lp = raw_logprobs.get("content") if isinstance(raw_logprobs, dict) else None - completion_logprobs = [float(c.get("logprob") or 0.0) for c in content_lp or []] - if not completion_logprobs and is_dynamo: - # Dynamo-only fallback when chat logprobs are absent. - engine_lp = engine_data.get("completion_logprobs") or [] - if engine_lp: - completion_logprobs = [float(x) for x in engine_lp] - - # Not surfaced on dynamo_chat: its nvext shape differs from the - # downstream RoutedExpertsPayload contract (base64 + ``start``). - routed_experts = choice.get("routed_experts") - # /inference/v1/generate never returns "tool_calls", so promote # stop→tool_calls when we parsed tool calls client-side (keeps agent # loops going). No-op on dynamo_chat, which can return it directly. - finish_reason = choice.get("finish_reason") + finish_reason = wire.finish_reason ok_tool_calls = [ tc for tc in parsed.tool_calls if tc.status == ToolCallParseStatus.OK ] @@ -371,17 +560,15 @@ def _prepare(): finish_reason = "tool_calls" return { - "request_id": data.get("request_id") - or (data.get("id") if is_dynamo else None) - or "", + "request_id": wire.request_id, "prompt_ids": list(prompt_ids), - "completion_ids": list(completion_ids), - "completion_logprobs": completion_logprobs, + "completion_ids": list(wire.completion_ids), + "completion_logprobs": wire.completion_logprobs, "content": parsed.content, "reasoning_content": parsed.reasoning_content, "tool_calls": parsed.tool_calls, "finish_reason": finish_reason, - "routed_experts": routed_experts, + "routed_experts": wire.routed_experts, # The mm sidecar consumed on the request side, surfaced back so # callers can persist it on the trajectory step for downstream # multi-turn bridging and training-sample construction. @@ -397,92 +584,6 @@ def _prepare(): } -async def _post_dynamo_chat( - *, - client: AsyncOpenAI, - model: str, - prompt_ids: list[int], - sp: dict[str, Any], - tools: list[ToolSpec] | None, - cache_salt: str | None, - priority: int | None, - extra_headers: dict[str, str] | None, - messages: list[Message], -) -> dict[str, Any]: - """POST ``prompt_ids`` to Dynamo's ``/v1/chat/completions`` route. - - Mirrors ``verifiers.clients.openai_chat_completions_token_client._post_dynamo_chat`` - so the wire payload is identical via either client. ``nvext.token_data`` - carries the pre-tokenized prompt (Dynamo skips tokenization when present) - and ``extra_fields=["engine_data"]`` opts into the completion-IDs/logprobs - channel. ``messages`` is a placeholder stub the OpenAI schema requires but - Dynamo ignores. routed_experts is not requested (incompatible nvext shape). - """ - body: dict[str, Any] = { - "model": model, - "messages": [{"role": "user", "content": ""}], - "stream": False, - "nvext": { - "token_data": list(prompt_ids), - "extra_fields": ["engine_data"], - }, - } - # tools are baked into token_data already; the renderer ToolSpec isn't the - # OpenAI tool shape, so forwarding it would 400. - if cache_salt is not None: - body["nvext"]["cache_salt"] = cache_salt - if priority is not None: - body["nvext"]["agent_hints"] = {"priority": priority} - - # Sampling params Dynamo's schema recognizes natively at top level. - promotable = ( - "max_tokens", - "temperature", - "top_p", - "top_k", - "min_p", - "seed", - "n", - "repetition_penalty", - "min_tokens", - "logprobs", - "skip_special_tokens", - ) - for key in promotable: - value = sp.get(key) - if value is None: - continue - if key == "max_tokens": - body["max_completion_tokens"] = value - elif key == "logprobs": - # vLLM takes logprobs=N (int); Dynamo's chat schema wants the - # OpenAI bool + top_logprobs split. - body["logprobs"] = True - if isinstance(value, int) and value > 1: - body["top_logprobs"] = value - else: - body[key] = value - - # Pass-through hints on Dynamo's PASSTHROUGH_EXTRA_FIELDS allowlist. - for key in ( - "stop_token_ids", - "bad_words_token_ids", - "allowed_token_ids", - "detokenize", - ): - if sp.get(key) is not None: - body[key] = sp[key] - - post_kwargs: dict[str, Any] = { - "cast_to": cast(Any, dict[str, Any]), - "body": body, - } - if extra_headers: - post_kwargs["options"] = cast(Any, {"headers": extra_headers}) - # Engine 4xx propagate raw (matches the vLLM path). - return await client.post("/chat/completions", **post_kwargs) - - def _build_mm_features( renderer: Renderer | RendererPool, mm_data: MultiModalData, From ed03eaacb49bb57d8d2c538ddee69832259d0ce3 Mon Sep 17 00:00:00 2001 From: Biswa Panda Date: Tue, 9 Jun 2026 17:46:12 -0700 Subject: [PATCH 09/21] fix(client): address codex F1-F4 on dynamo_chat (denylist sampling, merge nvext, canonical completion-ids, logprobs alignment) --- renderers/client.py | 112 ++++++++++++++++++--------------- tests/test_client.py | 143 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 207 insertions(+), 48 deletions(-) diff --git a/renderers/client.py b/renderers/client.py index 6666a2c..517e817 100644 --- a/renderers/client.py +++ b/renderers/client.py @@ -125,26 +125,18 @@ async def _resolve_max_prompt_len(client: AsyncOpenAI, model: str) -> int | None # Public type alias; matches verifiers.types.RendererTransport string set. RendererTransport = Literal["vllm_generate", "dynamo_chat"] -# Sampling params Dynamo's chat schema recognizes natively at top level. -_DYNAMO_PROMOTABLE_KEYS = ( - "max_tokens", - "temperature", - "top_p", - "top_k", - "min_p", - "seed", - "n", - "repetition_penalty", - "min_tokens", - "logprobs", - "skip_special_tokens", -) -# Pass-through hints on Dynamo's PASSTHROUGH_EXTRA_FIELDS allowlist. -_DYNAMO_PASSTHROUGH_KEYS = ( - "stop_token_ids", - "bad_words_token_ids", - "allowed_token_ids", - "detokenize", +# Keys never forwarded to Dynamo at the top level: vLLM/prime-only fields its +# strict validator rejects (mirrors the token client's drop set), the +# renderer-internal ``routed_experts_prompt_start`` parse hint (never a wire +# field), and ``priority`` (routed into nvext.agent_hints instead). ``max_tokens`` +# and ``nvext`` are handled explicitly and skipped separately. +_DYNAMO_DROP_KEYS = frozenset( + { + "return_token_ids", + "spaces_between_special_tokens", + "priority", + "routed_experts_prompt_start", + } ) # Absolute /inference/v1/generate URLs, cached per client base_url. @@ -289,6 +281,7 @@ async def post( priority: int | None, extra_headers: dict[str, str] | None, ) -> dict[str, Any]: + # TODO: Implement multimodal support for dynamo_chat transport. if mm_data is not None and not mm_data.is_empty(): raise NotImplementedError( "Multimodal renderers are not yet supported on the dynamo_chat " @@ -313,31 +306,44 @@ def _build_body( cache_salt: str | None, priority: int | None, ) -> dict[str, Any]: - # messages is a placeholder stub the OpenAI schema requires but Dynamo - # ignores. tools are baked into token_data; forwarding the renderer - # ToolSpec (not the OpenAI tool shape) would 400. - nvext: dict[str, Any] = { - "token_data": list(prompt_ids), - "extra_fields": ["engine_data"], - } + # Merge caller-supplied nvext rather than overwriting it, then layer on + # the required fields: token_data (authoritative renderer tokens) and a + # cumulative extra_fields union with "engine_data". The cache_salt / + # priority kwargs win over any caller nvext values. + nvext: dict[str, Any] = dict(sp.get("nvext") or {}) + nvext["token_data"] = list(prompt_ids) + extra_fields = list(nvext.get("extra_fields") or []) + if "engine_data" not in extra_fields: + extra_fields.append("engine_data") + nvext["extra_fields"] = extra_fields if cache_salt is not None: nvext["cache_salt"] = cache_salt if priority is not None: - nvext["agent_hints"] = {"priority": priority} + agent_hints = dict(nvext.get("agent_hints") or {}) + agent_hints["priority"] = priority + nvext["agent_hints"] = agent_hints + + # messages is a placeholder stub the OpenAI schema requires but Dynamo + # ignores. tools are baked into token_data; forwarding the renderer + # ToolSpec (not the OpenAI tool shape) would 400. body: dict[str, Any] = { "model": model, "messages": [{"role": "user", "content": ""}], "stream": False, "nvext": nvext, } - - for key in _DYNAMO_PROMOTABLE_KEYS: - value = sp.get(key) - if value is None: + if sp.get("max_tokens") is not None: + body["max_completion_tokens"] = sp["max_tokens"] + + # Forward every other non-None sampling field (denylist, not allowlist) + # so caller-requested params (presence_penalty, stop, guided_*, ...) are + # not silently dropped. Mirrors the token client's body construction. + for key, value in sp.items(): + if value is None or key in _DYNAMO_DROP_KEYS or key in body: continue - if key == "max_tokens": - body["max_completion_tokens"] = value - elif key == "logprobs": + if key in ("nvext", "max_tokens"): + continue # handled above + if key == "logprobs": # vLLM takes logprobs=N (int); Dynamo's chat schema wants the # OpenAI bool + top_logprobs split. body["logprobs"] = True @@ -345,10 +351,6 @@ def _build_body( body["top_logprobs"] = value else: body[key] = value - - for key in _DYNAMO_PASSTHROUGH_KEYS: - if sp.get(key) is not None: - body[key] = sp[key] return body def parse(self, data: dict[str, Any]) -> _WireResult: @@ -356,14 +358,20 @@ def parse(self, data: dict[str, Any]) -> _WireResult: nvext = data.get("nvext") or {} engine = nvext.get("engine_data") or {} - completion_ids = choice.get("token_ids") - present = completion_ids is not None - if not present: - for src in (engine, nvext): - if src.get("completion_token_ids") is not None: - completion_ids = src["completion_token_ids"] - present = True - break + # Canonical Dynamo channel first (PR #8119: nvext.engine_data, then + # top-level nvext), then the OpenAI-extended choices[0].token_ids. The + # engine channel is authoritative — choices[0].token_ids may be a + # detokenize-then-retokenize echo that differs from what was sampled. + completion_ids = None + present = False + for src in (engine, nvext): + if src.get("completion_token_ids") is not None: + completion_ids = src["completion_token_ids"] + present = True + break + if not present and choice.get("token_ids") is not None: + completion_ids = choice["token_ids"] + present = True if not present: # Field absent (vs. an empty completion) — usually a missing # nvext.extra_fields=["engine_data"] opt-in. @@ -371,14 +379,22 @@ def parse(self, data: dict[str, Any]) -> _WireResult: "dynamo_chat response carried no completion token IDs " "(expected nvext.engine_data.completion_token_ids)." ) + completion_ids = list(completion_ids or []) logprobs = _flatten_chat_logprobs(choice) if not logprobs: # Dynamo-only fallback when chat logprobs are absent. logprobs = [float(x) for x in engine.get("completion_logprobs") or []] + # Logprobs are indexed positionally against completion_ids downstream; + # a length mismatch would silently misalign tokens and logprobs. + if logprobs and len(logprobs) != len(completion_ids): + raise RuntimeError( + f"dynamo_chat logprobs length ({len(logprobs)}) does not match " + f"completion token count ({len(completion_ids)})." + ) return _WireResult( - completion_ids=list(completion_ids or []), + completion_ids=completion_ids, completion_logprobs=logprobs, routed_experts=None, request_id=data.get("request_id") or data.get("id") or "", diff --git a/tests/test_client.py b/tests/test_client.py index 0cd8bd0..f81e5b8 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -441,6 +441,149 @@ def test_generate_propagates_post_errors_raw(transport): ) +def test_dynamo_transport_forwards_extra_sampling_fields_and_drops_denylist(): + """F1: sampling fields outside the old allowlist (presence_penalty, stop, + guided_json) must reach the wire; vLLM-only/internal keys are dropped.""" + client = _DynamoFakeClient() + asyncio.run( + generate( + client=client, + renderer=_FakeRenderer(), + messages=[{"role": "user", "content": "hi"}], + model="test-model", + tools=[{"type": "function", "function": {"name": "echo"}}], + sampling_params={ + "max_tokens": 7, + "presence_penalty": 0.5, + "frequency_penalty": 0.25, + "stop": [""], + "guided_json": {"type": "object"}, + # denylisted / renderer-internal — must NOT hit the wire + "return_token_ids": True, + "routed_experts_prompt_start": 3, + }, + transport="dynamo_chat", + max_prompt_len=10_000, + ) + ) + body = client.calls[0]["body"] + assert body["presence_penalty"] == 0.5 + assert body["frequency_penalty"] == 0.25 + assert body["stop"] == [""] + assert body["guided_json"] == {"type": "object"} + assert "return_token_ids" not in body + assert "routed_experts_prompt_start" not in body + + +def test_dynamo_transport_merges_caller_nvext(): + """F2: caller-supplied nvext is merged — extra_fields union with engine_data, + agent_hints merged with priority, unrelated caller keys preserved.""" + client = _DynamoFakeClient() + asyncio.run( + generate( + client=client, + renderer=_FakeRenderer(), + messages=[{"role": "user", "content": "hi"}], + model="test-model", + tools=[{"type": "function", "function": {"name": "echo"}}], + sampling_params={ + "max_tokens": 7, + "nvext": { + "extra_fields": ["timing"], + "agent_hints": {"osl": 4}, + "annotations": ["trace"], + }, + }, + priority=9, + transport="dynamo_chat", + max_prompt_len=10_000, + ) + ) + nvext = client.calls[0]["body"]["nvext"] + assert nvext["token_data"] == [1, 2, 3] + # extra_fields union preserves caller "timing" + our "engine_data" + assert nvext["extra_fields"] == ["timing", "engine_data"] + # agent_hints merged: caller osl kept, priority overlaid + assert nvext["agent_hints"] == {"osl": 4, "priority": 9} + # unrelated caller nvext keys survive + assert nvext["annotations"] == ["trace"] + + +class _BothTokenIdsClient(_FakeClient): + """Dynamo response carrying engine_data.completion_token_ids AND a divergent + choices[0].token_ids — the canonical engine channel must win (F3).""" + + async def post(self, path, *, cast_to=dict, body=None, options=None): + self.calls.append( + {"path": path, "cast_to": cast_to, "body": body, "options": options} + ) + return { + "id": "gen-dyn", + "choices": [ + { + "index": 0, + "token_ids": [99, 99], # divergent echo — must be ignored + "logprobs": {"content": [{"logprob": -0.1}, {"logprob": -0.2}]}, + "finish_reason": "stop", + } + ], + "nvext": {"engine_data": {"completion_token_ids": [7, 8]}}, + } + + +def test_dynamo_transport_prefers_engine_data_over_choices_token_ids(): + client = _BothTokenIdsClient() + result = asyncio.run( + generate( + client=client, + renderer=_FakeRenderer(), + messages=[{"role": "user", "content": "hi"}], + model="test-model", + tools=[{"type": "function", "function": {"name": "echo"}}], + sampling_params={"max_tokens": 7}, + transport="dynamo_chat", + max_prompt_len=10_000, + ) + ) + assert result["completion_ids"] == [7, 8] + + +class _MisalignedLogprobsClient(_FakeClient): + """Dynamo response whose logprobs length disagrees with completion_ids (F4).""" + + async def post(self, path, *, cast_to=dict, body=None, options=None): + self.calls.append( + {"path": path, "cast_to": cast_to, "body": body, "options": options} + ) + return { + "id": "gen-dyn", + "choices": [ + { + "index": 0, + "logprobs": {"content": [{"logprob": -0.1}]}, # only 1 logprob + "finish_reason": "stop", + } + ], + "nvext": {"engine_data": {"completion_token_ids": [7, 8]}}, # 2 tokens + } + + +def test_dynamo_transport_raises_on_logprob_length_mismatch(): + with pytest.raises(RuntimeError, match="logprobs length"): + asyncio.run( + generate( + client=_MisalignedLogprobsClient(), + renderer=_FakeRenderer(), + messages=[{"role": "user", "content": "hi"}], + model="test-model", + tools=[{"type": "function", "function": {"name": "echo"}}], + sampling_params={"max_tokens": 7}, + transport="dynamo_chat", + max_prompt_len=10_000, + ) + ) + + # --------------------------------------------------------------------------- # Multimodal features payload. # --------------------------------------------------------------------------- From eb0bdb2d5a44595554472675a91c38168ddd7195 Mon Sep 17 00:00:00 2001 From: Biswa Panda Date: Tue, 9 Jun 2026 20:05:11 -0700 Subject: [PATCH 10/21] fix(client): route sampling_params cache_salt and priority into nvext on dynamo path --- renderers/client.py | 15 ++++++++++++--- tests/test_client.py | 24 ++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/renderers/client.py b/renderers/client.py index 517e817..4c71c55 100644 --- a/renderers/client.py +++ b/renderers/client.py @@ -306,10 +306,19 @@ def _build_body( cache_salt: str | None, priority: int | None, ) -> dict[str, Any]: + # cache_salt / priority may arrive as dedicated kwargs or inside + # sampling_params (the kwargs win). On Dynamo both belong in nvext, so + # they're routed there and never forwarded as top-level chat fields — + # keeping a shared sampling_params dict consistent with vllm_generate. + if cache_salt is None: + cache_salt = sp.get("cache_salt") + if priority is None: + priority = sp.get("priority") + # Merge caller-supplied nvext rather than overwriting it, then layer on # the required fields: token_data (authoritative renderer tokens) and a # cumulative extra_fields union with "engine_data". The cache_salt / - # priority kwargs win over any caller nvext values. + # priority values win over any caller nvext values. nvext: dict[str, Any] = dict(sp.get("nvext") or {}) nvext["token_data"] = list(prompt_ids) extra_fields = list(nvext.get("extra_fields") or []) @@ -341,8 +350,8 @@ def _build_body( for key, value in sp.items(): if value is None or key in _DYNAMO_DROP_KEYS or key in body: continue - if key in ("nvext", "max_tokens"): - continue # handled above + if key in ("nvext", "max_tokens", "cache_salt"): + continue # handled above (cache_salt -> nvext; priority denylisted) if key == "logprobs": # vLLM takes logprobs=N (int); Dynamo's chat schema wants the # OpenAI bool + top_logprobs split. diff --git a/tests/test_client.py b/tests/test_client.py index f81e5b8..dd33b4c 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -509,6 +509,30 @@ def test_dynamo_transport_merges_caller_nvext(): assert nvext["annotations"] == ["trace"] +def test_dynamo_transport_routes_sampling_params_cache_salt_and_priority_to_nvext(): + """cache_salt/priority supplied inside sampling_params (not the dedicated + kwargs) must still land in nvext, never as top-level chat fields.""" + client = _DynamoFakeClient() + asyncio.run( + generate( + client=client, + renderer=_FakeRenderer(), + messages=[{"role": "user", "content": "hi"}], + model="test-model", + tools=[{"type": "function", "function": {"name": "echo"}}], + sampling_params={"max_tokens": 7, "cache_salt": "ckpt-9", "priority": 5}, + transport="dynamo_chat", + max_prompt_len=10_000, + ) + ) + body = client.calls[0]["body"] + assert body["nvext"]["cache_salt"] == "ckpt-9" + assert body["nvext"]["agent_hints"] == {"priority": 5} + # neither leaks to a top-level chat field + assert "cache_salt" not in body + assert "priority" not in body + + class _BothTokenIdsClient(_FakeClient): """Dynamo response carrying engine_data.completion_token_ids AND a divergent choices[0].token_ids — the canonical engine channel must win (F3).""" From 30c01b600ba2ff76066d19fdbe6ff8c025e3a910 Mon Sep 17 00:00:00 2001 From: Biswa Panda Date: Wed, 10 Jun 2026 02:33:59 -0700 Subject: [PATCH 11/21] feat(client): surface routed_experts on dynamo_chat transport --- renderers/client.py | 39 ++++++++++++++++++++++++++++++++------- tests/test_client.py | 36 ++++++++++++++++++++++++++++-------- 2 files changed, 60 insertions(+), 15 deletions(-) diff --git a/renderers/client.py b/renderers/client.py index 4c71c55..bb0662e 100644 --- a/renderers/client.py +++ b/renderers/client.py @@ -261,11 +261,13 @@ class _DynamoChatTransport(_Transport): Dynamo has no /inference/v1/generate route. ``nvext.token_data`` carries the pre-tokenized prompt (Dynamo skips tokenization when present) and - ``extra_fields=["engine_data"]`` opts into the completion-IDs/logprobs - channel (PR #8119). Mirrors + ``extra_fields=["engine_data", "routed_experts"]`` opts into the + completion-IDs/logprobs channel (PR #8119) and the MoE routed_experts + channel. Mirrors ``verifiers.clients.openai_chat_completions_token_client._post_dynamo_chat`` - so the wire payload is identical via either client. routed_experts is not - requested (its nvext shape differs from the downstream contract). + so the wire payload is identical via either client. routed_experts is read + from ``nvext.routed_experts`` (or ``nvext.engine_data.routed_experts``) and + surfaced as the prime-rl ``{data, shape, start, dtype}`` contract. """ async def post( @@ -322,9 +324,23 @@ def _build_body( nvext: dict[str, Any] = dict(sp.get("nvext") or {}) nvext["token_data"] = list(prompt_ids) extra_fields = list(nvext.get("extra_fields") or []) - if "engine_data" not in extra_fields: - extra_fields.append("engine_data") + for field in ("engine_data", "routed_experts"): + if field not in extra_fields: + extra_fields.append(field) nvext["extra_fields"] = extra_fields + # ``routed_experts_prompt_start`` is the multi-turn replay offset. Dynamo + # exposes no top-level field for it, so thread it through + # ``nvext.extra_args.sampling_options`` where the worker's + # ``build_sampling_params`` applies it to vLLM ``SamplingParams`` — vLLM + # then trims the leading prompt rows and the worker echoes the value back + # as ``routed_experts.start`` for completion alignment. + prompt_start = sp.get("routed_experts_prompt_start") + if prompt_start is not None: + extra_args = dict(nvext.get("extra_args") or {}) + sampling_options = dict(extra_args.get("sampling_options") or {}) + sampling_options["routed_experts_prompt_start"] = prompt_start + extra_args["sampling_options"] = sampling_options + nvext["extra_args"] = extra_args if cache_salt is not None: nvext["cache_salt"] = cache_salt if priority is not None: @@ -402,10 +418,19 @@ def parse(self, data: dict[str, Any]) -> _WireResult: f"completion token count ({len(completion_ids)})." ) + # routed_experts: prefer the dedicated nvext.routed_experts field + # (extra_fields=["routed_experts"]); fall back to the engine_data + # passthrough (extra_fields=["engine_data"]) since the worker also nests + # it there. The Dynamo vLLM worker already emits the prime-rl contract + # shape {data(base64), shape, start, dtype}, so pass it through unchanged. + routed_experts = nvext.get("routed_experts") + if routed_experts is None: + routed_experts = engine.get("routed_experts") + return _WireResult( completion_ids=completion_ids, completion_logprobs=logprobs, - routed_experts=None, + routed_experts=routed_experts, request_id=data.get("request_id") or data.get("id") or "", finish_reason=choice.get("finish_reason"), ) diff --git a/tests/test_client.py b/tests/test_client.py index dd33b4c..396b596 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -280,7 +280,7 @@ def test_generate_threads_prompt_attribution_through_prebuilt_prompt_path(): class _DynamoFakeClient(_FakeClient): """Dynamo-shaped response: engine fields + routed_experts under nvext (not - choices[0]), so the test proves routed_experts is dropped on dynamo_chat.""" + choices[0]); used to prove routed_experts now surfaces on dynamo_chat.""" async def post(self, path, *, cast_to=dict, body=None, options=None): self.calls.append( @@ -297,7 +297,12 @@ async def post(self, path, *, cast_to=dict, body=None, options=None): ], "nvext": { "engine_data": {"completion_token_ids": [7, 8]}, - "routed_experts": {"data": "AQI=", "shape": [2, 1, 1]}, + "routed_experts": { + "data": "AQI=", + "shape": [2, 1, 1], + "start": 0, + "dtype": "uint8", + }, }, } @@ -333,7 +338,7 @@ def test_dynamo_transport_forwards_priority_and_detokenize(): "stream": False, "nvext": { "token_data": [1, 2, 3], - "extra_fields": ["engine_data"], + "extra_fields": ["engine_data", "routed_experts"], "cache_salt": "ckpt-42", "agent_hints": {"priority": 17}, }, @@ -348,8 +353,14 @@ def test_dynamo_transport_forwards_priority_and_detokenize(): "detokenize": False, } assert result["completion_ids"] == [7, 8] - # routed_experts is dropped on dynamo_chat (nvext shape != downstream contract). - assert result["routed_experts"] is None + # routed_experts now surfaces on dynamo_chat via nvext.routed_experts, + # passed through as the prime-rl {data, shape, start, dtype} contract. + assert result["routed_experts"] == { + "data": "AQI=", + "shape": [2, 1, 1], + "start": 0, + "dtype": "uint8", + } class _NoCompletionIdsClient(_FakeClient): @@ -458,8 +469,10 @@ def test_dynamo_transport_forwards_extra_sampling_fields_and_drops_denylist(): "frequency_penalty": 0.25, "stop": [""], "guided_json": {"type": "object"}, - # denylisted / renderer-internal — must NOT hit the wire + # denylisted — must NOT hit the wire "return_token_ids": True, + # renderer-internal: forwarded via nvext.extra_args (below), + # never as a top-level chat field "routed_experts_prompt_start": 3, }, transport="dynamo_chat", @@ -472,7 +485,14 @@ def test_dynamo_transport_forwards_extra_sampling_fields_and_drops_denylist(): assert body["stop"] == [""] assert body["guided_json"] == {"type": "object"} assert "return_token_ids" not in body + # routed_experts_prompt_start is not a top-level field; it rides + # nvext.extra_args.sampling_options so the worker applies it to vLLM + # SamplingParams (trims the prompt rows) and echoes it back as start. assert "routed_experts_prompt_start" not in body + assert ( + body["nvext"]["extra_args"]["sampling_options"]["routed_experts_prompt_start"] + == 3 + ) def test_dynamo_transport_merges_caller_nvext(): @@ -501,8 +521,8 @@ def test_dynamo_transport_merges_caller_nvext(): ) nvext = client.calls[0]["body"]["nvext"] assert nvext["token_data"] == [1, 2, 3] - # extra_fields union preserves caller "timing" + our "engine_data" - assert nvext["extra_fields"] == ["timing", "engine_data"] + # extra_fields union preserves caller "timing" + our "engine_data"/"routed_experts" + assert nvext["extra_fields"] == ["timing", "engine_data", "routed_experts"] # agent_hints merged: caller osl kept, priority overlaid assert nvext["agent_hints"] == {"osl": 4, "priority": 9} # unrelated caller nvext keys survive From 28e3d023ce8cf9562d078f644a6933572ffcee2e Mon Sep 17 00:00:00 2001 From: Biswa Panda Date: Wed, 10 Jun 2026 11:30:57 -0700 Subject: [PATCH 12/21] fix(client): drop duplicate routed_experts request; normalize parsed payload to contract --- renderers/client.py | 51 ++++++++++++++++++++++++++++++++++++++------- 1 file changed, 43 insertions(+), 8 deletions(-) diff --git a/renderers/client.py b/renderers/client.py index bb0662e..526eef9 100644 --- a/renderers/client.py +++ b/renderers/client.py @@ -324,9 +324,12 @@ def _build_body( nvext: dict[str, Any] = dict(sp.get("nvext") or {}) nvext["token_data"] = list(prompt_ids) extra_fields = list(nvext.get("extra_fields") or []) - for field in ("engine_data", "routed_experts"): - if field not in extra_fields: - extra_fields.append(field) + # Only request "engine_data": the worker nests routed_experts inside it + # (engine_data.routed_experts), so also requesting the dedicated + # "routed_experts" field would duplicate the (large) base64 blob on the + # wire — once under engine_data and once promoted at the top level. + if "engine_data" not in extra_fields: + extra_fields.append("engine_data") nvext["extra_fields"] = extra_fields # ``routed_experts_prompt_start`` is the multi-turn replay offset. Dynamo # exposes no top-level field for it, so thread it through @@ -418,14 +421,15 @@ def parse(self, data: dict[str, Any]) -> _WireResult: f"completion token count ({len(completion_ids)})." ) - # routed_experts: prefer the dedicated nvext.routed_experts field - # (extra_fields=["routed_experts"]); fall back to the engine_data - # passthrough (extra_fields=["engine_data"]) since the worker also nests - # it there. The Dynamo vLLM worker already emits the prime-rl contract - # shape {data(base64), shape, start, dtype}, so pass it through unchanged. + # routed_experts: read the dedicated nvext.routed_experts field if a + # caller opted into it, else the engine_data passthrough where the + # worker nests it (engine_data.routed_experts). Normalize to the + # {data, shape, start, dtype} contract so a malformed/legacy payload + # fails here with context instead of deep in trajectory processing. routed_experts = nvext.get("routed_experts") if routed_experts is None: routed_experts = engine.get("routed_experts") + routed_experts = _normalize_routed_experts(routed_experts) return _WireResult( completion_ids=completion_ids, @@ -442,6 +446,37 @@ def parse(self, data: dict[str, Any]) -> _WireResult: } +def _normalize_routed_experts(payload: Any) -> dict[str, Any] | None: + """Validate/normalize a dynamo_chat routed_experts payload to the + ``{data, shape, start, dtype}`` contract. + + Defaults ``start=0`` and ``dtype="uint8"`` for back-compat with payloads + serialized before those fields existed; raises a clear ``RuntimeError`` for + a non-contract payload (string/map, wrong rank) so the failure surfaces here + with context instead of as a ``TypeError``/``KeyError`` in trajectory + processing. + """ + if payload is None: + return None + if not isinstance(payload, Mapping) or "data" not in payload or "shape" not in payload: + raise RuntimeError( + "dynamo_chat routed_experts must be a mapping with 'data' and " + f"'shape'; got {type(payload).__name__}" + ) + shape = payload["shape"] + if not (isinstance(shape, (list, tuple)) and len(shape) == 3): + raise RuntimeError( + "dynamo_chat routed_experts 'shape' must be 3-D [seq, layers, topk]; " + f"got {shape!r}" + ) + return { + "data": payload["data"], + "shape": [int(d) for d in shape], + "start": int(payload.get("start", 0)), + "dtype": payload.get("dtype", "uint8"), + } + + async def _maybe_offload(renderer: Renderer | RendererPool, fn): """Run sync renderer work on a thread iff ``renderer`` is a pool. From c31854f60184d796421688d11850d3421216f053 Mon Sep 17 00:00:00 2001 From: Biswa Panda Date: Wed, 10 Jun 2026 11:31:54 -0700 Subject: [PATCH 13/21] test(client): update dynamo extra_fields expectations to engine_data only --- tests/test_client.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_client.py b/tests/test_client.py index 396b596..d135303 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -338,7 +338,7 @@ def test_dynamo_transport_forwards_priority_and_detokenize(): "stream": False, "nvext": { "token_data": [1, 2, 3], - "extra_fields": ["engine_data", "routed_experts"], + "extra_fields": ["engine_data"], "cache_salt": "ckpt-42", "agent_hints": {"priority": 17}, }, @@ -521,8 +521,8 @@ def test_dynamo_transport_merges_caller_nvext(): ) nvext = client.calls[0]["body"]["nvext"] assert nvext["token_data"] == [1, 2, 3] - # extra_fields union preserves caller "timing" + our "engine_data"/"routed_experts" - assert nvext["extra_fields"] == ["timing", "engine_data", "routed_experts"] + # extra_fields union preserves caller "timing" + our "engine_data" + assert nvext["extra_fields"] == ["timing", "engine_data"] # agent_hints merged: caller osl kept, priority overlaid assert nvext["agent_hints"] == {"osl": 4, "priority": 9} # unrelated caller nvext keys survive From 59553dae6cfae7ef5da00ace3b4616e7ca9cff79 Mon Sep 17 00:00:00 2001 From: Biswa Panda Date: Wed, 10 Jun 2026 12:14:34 -0700 Subject: [PATCH 14/21] fix(client): stamp routed_experts.start on dynamo_chat from prompt offset --- renderers/client.py | 52 ++++++++++++++++++++++++++++++++------------ tests/test_client.py | 37 +++++++++++++++++++++++-------- 2 files changed, 66 insertions(+), 23 deletions(-) diff --git a/renderers/client.py b/renderers/client.py index 526eef9..e9ca48c 100644 --- a/renderers/client.py +++ b/renderers/client.py @@ -298,7 +298,16 @@ async def post( if extra_headers: post_kwargs["options"] = cast(Any, {"headers": extra_headers}) # Engine 4xx propagate raw (matches the vLLM path). - return await client.post("/chat/completions", **post_kwargs) + resp = await client.post("/chat/completions", **post_kwargs) + # Dynamo's NvExt request schema carries no field for + # routed_experts_prompt_start, so the worker cannot trim the prompt + # rows: it returns full-sequence routing with start=0. Stamp the + # intended trim offset (the prompt we sent) onto the payload here so the + # consumer aligns the completion. NOTE: this assumes the worker did NOT + # trim; if a forwarded prompt-start field is later added to NvExt (so the + # worker trims and reports the offset itself), drop this stamping. + _stamp_dynamo_routed_experts_start(resp, prompt_ids, sp) + return resp @staticmethod def _build_body( @@ -331,19 +340,6 @@ def _build_body( if "engine_data" not in extra_fields: extra_fields.append("engine_data") nvext["extra_fields"] = extra_fields - # ``routed_experts_prompt_start`` is the multi-turn replay offset. Dynamo - # exposes no top-level field for it, so thread it through - # ``nvext.extra_args.sampling_options`` where the worker's - # ``build_sampling_params`` applies it to vLLM ``SamplingParams`` — vLLM - # then trims the leading prompt rows and the worker echoes the value back - # as ``routed_experts.start`` for completion alignment. - prompt_start = sp.get("routed_experts_prompt_start") - if prompt_start is not None: - extra_args = dict(nvext.get("extra_args") or {}) - sampling_options = dict(extra_args.get("sampling_options") or {}) - sampling_options["routed_experts_prompt_start"] = prompt_start - extra_args["sampling_options"] = sampling_options - nvext["extra_args"] = extra_args if cache_salt is not None: nvext["cache_salt"] = cache_salt if priority is not None: @@ -477,6 +473,34 @@ def _normalize_routed_experts(payload: Any) -> dict[str, Any] | None: } +def _stamp_dynamo_routed_experts_start( + resp: Any, prompt_ids: list[int], sp: dict[str, Any] +) -> None: + """Set ``routed_experts.start`` on a dynamo_chat response in place. + + The worker returns full-sequence routing with ``start=0`` (it can't trim — + NvExt carries no ``routed_experts_prompt_start``). Stamp the intended trim + offset so the consumer aligns the completion: the caller's + ``routed_experts_prompt_start`` if set, else ``len(prompt_ids) - 1`` (where + the completion's routing begins). No-op when routed_experts is absent. + """ + if not isinstance(resp, Mapping): + return + nvext = resp.get("nvext") + if not isinstance(nvext, Mapping): + return + routed = nvext.get("routed_experts") + if not isinstance(routed, dict): + engine = nvext.get("engine_data") + routed = engine.get("routed_experts") if isinstance(engine, Mapping) else None + if not isinstance(routed, dict): + return + start = sp.get("routed_experts_prompt_start") + if start is None: + start = max(0, len(prompt_ids) - 1) + routed["start"] = int(start) + + async def _maybe_offload(renderer: Renderer | RendererPool, fn): """Run sync renderer work on a thread iff ``renderer`` is a pool. diff --git a/tests/test_client.py b/tests/test_client.py index d135303..e7e8f5d 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -471,8 +471,8 @@ def test_dynamo_transport_forwards_extra_sampling_fields_and_drops_denylist(): "guided_json": {"type": "object"}, # denylisted — must NOT hit the wire "return_token_ids": True, - # renderer-internal: forwarded via nvext.extra_args (below), - # never as a top-level chat field + # renderer-internal: never a wire field (stamped onto the + # response's routed_experts.start instead) "routed_experts_prompt_start": 3, }, transport="dynamo_chat", @@ -485,14 +485,33 @@ def test_dynamo_transport_forwards_extra_sampling_fields_and_drops_denylist(): assert body["stop"] == [""] assert body["guided_json"] == {"type": "object"} assert "return_token_ids" not in body - # routed_experts_prompt_start is not a top-level field; it rides - # nvext.extra_args.sampling_options so the worker applies it to vLLM - # SamplingParams (trims the prompt rows) and echoes it back as start. + # routed_experts_prompt_start is renderer-internal: Dynamo has no nvext + # field to forward it to the worker, so the renderer stamps it as + # routed_experts.start on the response instead. It must never hit the wire. assert "routed_experts_prompt_start" not in body - assert ( - body["nvext"]["extra_args"]["sampling_options"]["routed_experts_prompt_start"] - == 3 - ) + assert "extra_args" not in body.get("nvext", {}) + + +def test_stamp_dynamo_routed_experts_start(): + """The dynamo transport stamps routed_experts.start (the worker returns + full routing with start=0): caller's routed_experts_prompt_start wins, + else prompt_len-1; absent routed_experts is a no-op.""" + from renderers.client import _stamp_dynamo_routed_experts_start + + # caller-provided prompt_start wins (engine_data channel) + resp = {"nvext": {"engine_data": {"routed_experts": {"data": "x", "shape": [5, 1, 1], "start": 0}}}} + _stamp_dynamo_routed_experts_start(resp, [1, 2, 3, 4], {"routed_experts_prompt_start": 3}) + assert resp["nvext"]["engine_data"]["routed_experts"]["start"] == 3 + + # fallback to prompt_len - 1 (top-level routed_experts channel) + resp2 = {"nvext": {"routed_experts": {"data": "x", "shape": [5, 1, 1], "start": 0}}} + _stamp_dynamo_routed_experts_start(resp2, [1, 2, 3, 4], {}) + assert resp2["nvext"]["routed_experts"]["start"] == 3 + + # no routed_experts -> no-op + resp3 = {"nvext": {"engine_data": {}}} + _stamp_dynamo_routed_experts_start(resp3, [1, 2], {}) + assert resp3 == {"nvext": {"engine_data": {}}} def test_dynamo_transport_merges_caller_nvext(): From b7927f8319b54a8aac98f331c8f3cd02aae0b3aa Mon Sep 17 00:00:00 2001 From: Biswa Panda Date: Wed, 10 Jun 2026 12:15:31 -0700 Subject: [PATCH 15/21] test(client): expect stamped routed_experts.start on dynamo_chat --- tests/test_client.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_client.py b/tests/test_client.py index e7e8f5d..bc5ff5b 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -353,12 +353,13 @@ def test_dynamo_transport_forwards_priority_and_detokenize(): "detokenize": False, } assert result["completion_ids"] == [7, 8] - # routed_experts now surfaces on dynamo_chat via nvext.routed_experts, - # passed through as the prime-rl {data, shape, start, dtype} contract. + # routed_experts surfaces on dynamo_chat as the {data, shape, start, dtype} + # contract. start is stamped by the renderer from the prompt offset + # (no routed_experts_prompt_start set here -> prompt_len - 1 = 2). assert result["routed_experts"] == { "data": "AQI=", "shape": [2, 1, 1], - "start": 0, + "start": 2, "dtype": "uint8", } From b554520a0c18cad49dd0e4f3cd8781499b4f7a38 Mon Sep 17 00:00:00 2001 From: Biswa Panda Date: Wed, 10 Jun 2026 12:23:52 -0700 Subject: [PATCH 16/21] fix(client): trim dynamo_chat routed_experts rows to start (was stamping only) --- renderers/client.py | 63 +++++++++++++++++++++++++++------------ tests/test_client.py | 70 ++++++++++++++++++++++++++++---------------- 2 files changed, 89 insertions(+), 44 deletions(-) diff --git a/renderers/client.py b/renderers/client.py index e9ca48c..2e5776c 100644 --- a/renderers/client.py +++ b/renderers/client.py @@ -25,6 +25,7 @@ from __future__ import annotations import asyncio +import base64 import json import logging from abc import ABC, abstractmethod @@ -300,13 +301,12 @@ async def post( # Engine 4xx propagate raw (matches the vLLM path). resp = await client.post("/chat/completions", **post_kwargs) # Dynamo's NvExt request schema carries no field for - # routed_experts_prompt_start, so the worker cannot trim the prompt - # rows: it returns full-sequence routing with start=0. Stamp the - # intended trim offset (the prompt we sent) onto the payload here so the - # consumer aligns the completion. NOTE: this assumes the worker did NOT - # trim; if a forwarded prompt-start field is later added to NvExt (so the - # worker trims and reports the offset itself), drop this stamping. - _stamp_dynamo_routed_experts_start(resp, prompt_ids, sp) + # routed_experts_prompt_start, so the worker can't trim the prompt rows: + # it returns full-sequence routing with start=0. Trim the leading prompt + # rows here and set start, so the payload matches the consumer contract + # (row 0 == position start). NOTE: if a forwarded prompt-start field is + # later added to NvExt (worker trims + reports start), drop this. + _trim_dynamo_routed_experts(resp, prompt_ids, sp) return resp @staticmethod @@ -473,16 +473,25 @@ def _normalize_routed_experts(payload: Any) -> dict[str, Any] | None: } -def _stamp_dynamo_routed_experts_start( +_ROUTED_EXPERTS_ITEMSIZE = {"uint8": 1, "uint16": 2, "int16": 2, "int32": 4} + + +def _trim_dynamo_routed_experts( resp: Any, prompt_ids: list[int], sp: dict[str, Any] ) -> None: - """Set ``routed_experts.start`` on a dynamo_chat response in place. - - The worker returns full-sequence routing with ``start=0`` (it can't trim — - NvExt carries no ``routed_experts_prompt_start``). Stamp the intended trim - offset so the consumer aligns the completion: the caller's - ``routed_experts_prompt_start`` if set, else ``len(prompt_ids) - 1`` (where - the completion's routing begins). No-op when routed_experts is absent. + """Trim a dynamo_chat routed_experts payload to begin at ``start``, in place. + + The Dynamo worker returns FULL-sequence routing with ``start=0`` because + NvExt carries no field to forward ``routed_experts_prompt_start`` for + worker-side trimming. The consumer contract is that row 0 of the payload is + the row at ``start`` (the vllm_generate path has vLLM trim internally), so + we drop the leading prompt rows here and set ``start`` to the offset: + the caller's ``routed_experts_prompt_start`` if set, else ``prefix_len - 1`` + (the boundary row that produces the first completion token). No-op when + routed_experts is absent/empty or the offset is 0. + + Worker-side trimming (avoiding full-prompt routing on the wire) is a future + optimization gated on a forwarded NvExt prompt-start field. """ if not isinstance(resp, Mapping): return @@ -495,10 +504,26 @@ def _stamp_dynamo_routed_experts_start( routed = engine.get("routed_experts") if isinstance(engine, Mapping) else None if not isinstance(routed, dict): return - start = sp.get("routed_experts_prompt_start") - if start is None: - start = max(0, len(prompt_ids) - 1) - routed["start"] = int(start) + data = routed.get("data") + shape = routed.get("shape") + if not isinstance(data, str) or not ( + isinstance(shape, (list, tuple)) and len(shape) == 3 + ): + return + + offset = sp.get("routed_experts_prompt_start") + if offset is None: + offset = len(prompt_ids) - 1 + offset = max(0, min(int(offset), int(shape[0]))) + if offset == 0: + return + + itemsize = _ROUTED_EXPERTS_ITEMSIZE.get(routed.get("dtype", "uint8"), 1) + row_size = int(shape[1]) * int(shape[2]) * itemsize + trimmed = base64.b64decode(data)[offset * row_size :] + routed["data"] = base64.b64encode(trimmed).decode("ascii") + routed["shape"] = [int(shape[0]) - offset, int(shape[1]), int(shape[2])] + routed["start"] = offset async def _maybe_offload(renderer: Renderer | RendererPool, fn): diff --git a/tests/test_client.py b/tests/test_client.py index bc5ff5b..10d331c 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -298,8 +298,9 @@ async def post(self, path, *, cast_to=dict, body=None, options=None): "nvext": { "engine_data": {"completion_token_ids": [7, 8]}, "routed_experts": { - "data": "AQI=", - "shape": [2, 1, 1], + # full-sequence routing (4 rows); worker can't trim + "data": "AQIDBA==", + "shape": [4, 1, 1], "start": 0, "dtype": "uint8", }, @@ -354,10 +355,11 @@ def test_dynamo_transport_forwards_priority_and_detokenize(): } assert result["completion_ids"] == [7, 8] # routed_experts surfaces on dynamo_chat as the {data, shape, start, dtype} - # contract. start is stamped by the renderer from the prompt offset - # (no routed_experts_prompt_start set here -> prompt_len - 1 = 2). + # contract. The renderer trims the leading prompt rows (offset = prompt_len + # - 1 = 2 here, no routed_experts_prompt_start set): 4 rows -> last 2, with + # start=2 so row 0 is the boundary the consumer expects. assert result["routed_experts"] == { - "data": "AQI=", + "data": "AwQ=", "shape": [2, 1, 1], "start": 2, "dtype": "uint8", @@ -493,26 +495,44 @@ def test_dynamo_transport_forwards_extra_sampling_fields_and_drops_denylist(): assert "extra_args" not in body.get("nvext", {}) -def test_stamp_dynamo_routed_experts_start(): - """The dynamo transport stamps routed_experts.start (the worker returns - full routing with start=0): caller's routed_experts_prompt_start wins, - else prompt_len-1; absent routed_experts is a no-op.""" - from renderers.client import _stamp_dynamo_routed_experts_start - - # caller-provided prompt_start wins (engine_data channel) - resp = {"nvext": {"engine_data": {"routed_experts": {"data": "x", "shape": [5, 1, 1], "start": 0}}}} - _stamp_dynamo_routed_experts_start(resp, [1, 2, 3, 4], {"routed_experts_prompt_start": 3}) - assert resp["nvext"]["engine_data"]["routed_experts"]["start"] == 3 - - # fallback to prompt_len - 1 (top-level routed_experts channel) - resp2 = {"nvext": {"routed_experts": {"data": "x", "shape": [5, 1, 1], "start": 0}}} - _stamp_dynamo_routed_experts_start(resp2, [1, 2, 3, 4], {}) - assert resp2["nvext"]["routed_experts"]["start"] == 3 - - # no routed_experts -> no-op - resp3 = {"nvext": {"engine_data": {}}} - _stamp_dynamo_routed_experts_start(resp3, [1, 2], {}) - assert resp3 == {"nvext": {"engine_data": {}}} +def test_trim_dynamo_routed_experts(): + """The dynamo transport trims leading prompt rows (worker returns full + routing, start=0) and sets start: caller's routed_experts_prompt_start + wins, else prompt_len-1; offset 0 / absent routed_experts is a no-op.""" + from renderers.client import _trim_dynamo_routed_experts + + def _payload(channel): + return {"nvext": {channel: {"routed_experts": { + "data": base64.b64encode(bytes([0, 1, 2, 3, 4])).decode(), + "shape": [5, 1, 1], "start": 0, "dtype": "uint8", + }}}} if channel == "engine_data" else {"nvext": {"routed_experts": { + "data": base64.b64encode(bytes([0, 1, 2, 3, 4])).decode(), + "shape": [5, 1, 1], "start": 0, "dtype": "uint8", + }}} + + # caller-provided prompt_start=3 -> drop 3 rows, start=3 (engine_data channel) + resp = _payload("engine_data") + _trim_dynamo_routed_experts(resp, [1, 2, 3, 4], {"routed_experts_prompt_start": 3}) + re = resp["nvext"]["engine_data"]["routed_experts"] + assert re["shape"] == [2, 1, 1] and re["start"] == 3 + assert base64.b64decode(re["data"]) == bytes([3, 4]) + + # fallback prompt_len-1 = 3 (top-level routed_experts channel) + resp2 = _payload("routed_experts") + _trim_dynamo_routed_experts(resp2, [1, 2, 3, 4], {}) + re2 = resp2["nvext"]["routed_experts"] + assert re2["shape"] == [2, 1, 1] and re2["start"] == 3 + assert base64.b64decode(re2["data"]) == bytes([3, 4]) + + # offset 0 -> no-op + resp3 = _payload("engine_data") + _trim_dynamo_routed_experts(resp3, [1], {"routed_experts_prompt_start": 0}) + assert resp3["nvext"]["engine_data"]["routed_experts"]["shape"] == [5, 1, 1] + + # absent routed_experts -> no-op + resp4 = {"nvext": {"engine_data": {}}} + _trim_dynamo_routed_experts(resp4, [1, 2], {}) + assert resp4 == {"nvext": {"engine_data": {}}} def test_dynamo_transport_merges_caller_nvext(): From 010c894c1ad98194f50d1c1de99a610ea88e4765 Mon Sep 17 00:00:00 2001 From: Biswa Panda Date: Wed, 10 Jun 2026 12:29:39 -0700 Subject: [PATCH 17/21] fix(client): only trim routed_experts when caller sets prompt_start (first-turn stays full) --- renderers/client.py | 17 +++++++-------- tests/test_client.py | 50 ++++++++++++++++++++++++-------------------- 2 files changed, 35 insertions(+), 32 deletions(-) diff --git a/renderers/client.py b/renderers/client.py index 2e5776c..827989a 100644 --- a/renderers/client.py +++ b/renderers/client.py @@ -306,7 +306,7 @@ async def post( # rows here and set start, so the payload matches the consumer contract # (row 0 == position start). NOTE: if a forwarded prompt-start field is # later added to NvExt (worker trims + reports start), drop this. - _trim_dynamo_routed_experts(resp, prompt_ids, sp) + _trim_dynamo_routed_experts(resp, sp) return resp @staticmethod @@ -476,19 +476,18 @@ def _normalize_routed_experts(payload: Any) -> dict[str, Any] | None: _ROUTED_EXPERTS_ITEMSIZE = {"uint8": 1, "uint16": 2, "int16": 2, "int32": 4} -def _trim_dynamo_routed_experts( - resp: Any, prompt_ids: list[int], sp: dict[str, Any] -) -> None: +def _trim_dynamo_routed_experts(resp: Any, sp: dict[str, Any]) -> None: """Trim a dynamo_chat routed_experts payload to begin at ``start``, in place. The Dynamo worker returns FULL-sequence routing with ``start=0`` because NvExt carries no field to forward ``routed_experts_prompt_start`` for worker-side trimming. The consumer contract is that row 0 of the payload is the row at ``start`` (the vllm_generate path has vLLM trim internally), so - we drop the leading prompt rows here and set ``start`` to the offset: - the caller's ``routed_experts_prompt_start`` if set, else ``prefix_len - 1`` - (the boundary row that produces the first completion token). No-op when - routed_experts is absent/empty or the offset is 0. + when the caller explicitly supplies ``routed_experts_prompt_start`` we drop + that many leading rows and set ``start`` to it. No-op when routed_experts is + absent/empty, the offset is 0, or no offset is supplied — a first-turn + request with no caller start keeps full-sequence routing with ``start=0`` + rather than claiming a prefix the consumer has no state for. Worker-side trimming (avoiding full-prompt routing on the wire) is a future optimization gated on a forwarded NvExt prompt-start field. @@ -513,7 +512,7 @@ def _trim_dynamo_routed_experts( offset = sp.get("routed_experts_prompt_start") if offset is None: - offset = len(prompt_ids) - 1 + return offset = max(0, min(int(offset), int(shape[0]))) if offset == 0: return diff --git a/tests/test_client.py b/tests/test_client.py index 10d331c..2034e18 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -355,13 +355,13 @@ def test_dynamo_transport_forwards_priority_and_detokenize(): } assert result["completion_ids"] == [7, 8] # routed_experts surfaces on dynamo_chat as the {data, shape, start, dtype} - # contract. The renderer trims the leading prompt rows (offset = prompt_len - # - 1 = 2 here, no routed_experts_prompt_start set): 4 rows -> last 2, with - # start=2 so row 0 is the boundary the consumer expects. + # contract. No routed_experts_prompt_start is set here (first-turn case), so + # the renderer does NOT trim — full-sequence routing passes through with + # start=0. assert result["routed_experts"] == { - "data": "AwQ=", - "shape": [2, 1, 1], - "start": 2, + "data": "AQIDBA==", + "shape": [4, 1, 1], + "start": 0, "dtype": "uint8", } @@ -496,42 +496,46 @@ def test_dynamo_transport_forwards_extra_sampling_fields_and_drops_denylist(): def test_trim_dynamo_routed_experts(): - """The dynamo transport trims leading prompt rows (worker returns full - routing, start=0) and sets start: caller's routed_experts_prompt_start - wins, else prompt_len-1; offset 0 / absent routed_experts is a no-op.""" + """The dynamo transport trims leading prompt rows ONLY when the caller + supplies routed_experts_prompt_start (worker returns full routing, start=0). + Absent start (first turn), offset 0, or absent routed_experts are no-ops.""" from renderers.client import _trim_dynamo_routed_experts def _payload(channel): - return {"nvext": {channel: {"routed_experts": { + re = { "data": base64.b64encode(bytes([0, 1, 2, 3, 4])).decode(), "shape": [5, 1, 1], "start": 0, "dtype": "uint8", - }}}} if channel == "engine_data" else {"nvext": {"routed_experts": { - "data": base64.b64encode(bytes([0, 1, 2, 3, 4])).decode(), - "shape": [5, 1, 1], "start": 0, "dtype": "uint8", - }}} + } + return {"nvext": {channel: {"routed_experts": re}} if channel == "engine_data" + else {"routed_experts": re}} - # caller-provided prompt_start=3 -> drop 3 rows, start=3 (engine_data channel) + # explicit prompt_start=3 -> drop 3 rows, start=3 (engine_data channel) resp = _payload("engine_data") - _trim_dynamo_routed_experts(resp, [1, 2, 3, 4], {"routed_experts_prompt_start": 3}) + _trim_dynamo_routed_experts(resp, {"routed_experts_prompt_start": 3}) re = resp["nvext"]["engine_data"]["routed_experts"] assert re["shape"] == [2, 1, 1] and re["start"] == 3 assert base64.b64decode(re["data"]) == bytes([3, 4]) - # fallback prompt_len-1 = 3 (top-level routed_experts channel) + # explicit prompt_start=3 (top-level routed_experts channel) resp2 = _payload("routed_experts") - _trim_dynamo_routed_experts(resp2, [1, 2, 3, 4], {}) + _trim_dynamo_routed_experts(resp2, {"routed_experts_prompt_start": 3}) re2 = resp2["nvext"]["routed_experts"] assert re2["shape"] == [2, 1, 1] and re2["start"] == 3 - assert base64.b64decode(re2["data"]) == bytes([3, 4]) - # offset 0 -> no-op + # absent start (first turn) -> NO trim, full-sequence with start=0 resp3 = _payload("engine_data") - _trim_dynamo_routed_experts(resp3, [1], {"routed_experts_prompt_start": 0}) - assert resp3["nvext"]["engine_data"]["routed_experts"]["shape"] == [5, 1, 1] + _trim_dynamo_routed_experts(resp3, {}) + re3 = resp3["nvext"]["engine_data"]["routed_experts"] + assert re3["shape"] == [5, 1, 1] and re3["start"] == 0 + + # offset 0 -> no-op + resp0 = _payload("engine_data") + _trim_dynamo_routed_experts(resp0, {"routed_experts_prompt_start": 0}) + assert resp0["nvext"]["engine_data"]["routed_experts"]["shape"] == [5, 1, 1] # absent routed_experts -> no-op resp4 = {"nvext": {"engine_data": {}}} - _trim_dynamo_routed_experts(resp4, [1, 2], {}) + _trim_dynamo_routed_experts(resp4, {"routed_experts_prompt_start": 3}) assert resp4 == {"nvext": {"engine_data": {}}} From 51d9154958505533e124acbe22de09f1791992c8 Mon Sep 17 00:00:00 2001 From: Biswa Panda Date: Wed, 10 Jun 2026 12:34:57 -0700 Subject: [PATCH 18/21] fix(client): prefer engine_data.completion_logprobs to stay aligned with engine ids --- renderers/client.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/renderers/client.py b/renderers/client.py index 827989a..fd4b3ae 100644 --- a/renderers/client.py +++ b/renderers/client.py @@ -405,10 +405,15 @@ def parse(self, data: dict[str, Any]) -> _WireResult: ) completion_ids = list(completion_ids or []) - logprobs = _flatten_chat_logprobs(choice) + # Prefer engine_data.completion_logprobs — the same authoritative source + # as the engine completion_token_ids used above — so logprobs stay + # positionally aligned with the ids. The choices[0] chat logprobs are a + # detokenize/retokenize echo that can diverge from the sampled ids, which + # would misalign while still passing the length check below. Fall back to + # the chat logprobs only when the engine channel carries none. + logprobs = [float(x) for x in engine.get("completion_logprobs") or []] if not logprobs: - # Dynamo-only fallback when chat logprobs are absent. - logprobs = [float(x) for x in engine.get("completion_logprobs") or []] + logprobs = _flatten_chat_logprobs(choice) # Logprobs are indexed positionally against completion_ids downstream; # a length mismatch would silently misalign tokens and logprobs. if logprobs and len(logprobs) != len(completion_ids): From 5f2a91485681c75e766da0ef0b128156b759ae23 Mon Sep 17 00:00:00 2001 From: Biswa Panda Date: Wed, 10 Jun 2026 12:40:14 -0700 Subject: [PATCH 19/21] fix(client): treat present-empty engine logprobs as authoritative (no chat fallback) --- renderers/client.py | 10 +++++++--- tests/test_client.py | 25 +++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/renderers/client.py b/renderers/client.py index fd4b3ae..56d23a8 100644 --- a/renderers/client.py +++ b/renderers/client.py @@ -410,9 +410,13 @@ def parse(self, data: dict[str, Any]) -> _WireResult: # positionally aligned with the ids. The choices[0] chat logprobs are a # detokenize/retokenize echo that can diverge from the sampled ids, which # would misalign while still passing the length check below. Fall back to - # the chat logprobs only when the engine channel carries none. - logprobs = [float(x) for x in engine.get("completion_logprobs") or []] - if not logprobs: + # the chat logprobs only when the engine channel is absent — a present + # but empty engine list is authoritative (logprobs off) and must NOT + # fall through to the chat echo (distinguish presence from truthiness). + engine_logprobs = engine.get("completion_logprobs") + if engine_logprobs is not None: + logprobs = [float(x) for x in engine_logprobs] + else: logprobs = _flatten_chat_logprobs(choice) # Logprobs are indexed positionally against completion_ids downstream; # a length mismatch would silently misalign tokens and logprobs. diff --git a/tests/test_client.py b/tests/test_client.py index 2034e18..d292887 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -539,6 +539,31 @@ def _payload(channel): assert resp4 == {"nvext": {"engine_data": {}}} +def test_dynamo_parse_present_empty_engine_logprobs_does_not_fall_back_to_chat(): + """A present-but-empty engine_data.completion_logprobs is authoritative + (logprobs off): parse must NOT fall through to the divergent chat echo.""" + data = { + "choices": [ + { + # chat-echo logprobs (would mismatch the engine ids) + "logprobs": {"content": [{"logprob": -9.9}, {"logprob": -8.8}]}, + "finish_reason": "stop", + } + ], + "nvext": { + "engine_data": { + "completion_token_ids": [7, 8], + "completion_logprobs": [], + } + }, + } + from renderers.client import _TRANSPORTS + + wire = _TRANSPORTS["dynamo_chat"].parse(data) + assert wire.completion_ids == [7, 8] + assert wire.completion_logprobs == [] # engine present-empty wins over chat echo + + def test_dynamo_transport_merges_caller_nvext(): """F2: caller-supplied nvext is merged — extra_fields union with engine_data, agent_hints merged with priority, unrelated caller keys preserved.""" From 75673773bcd44e214a4038ec23f78638e4249d00 Mon Sep 17 00:00:00 2001 From: Biswa Panda Date: Wed, 10 Jun 2026 16:59:56 -0700 Subject: [PATCH 20/21] feat(client): send routed_experts_prompt_start in nvext; client-side trim is now a back-compat fallback --- renderers/client.py | 45 ++++++++++++++++++++++++++------------------ tests/test_client.py | 29 ++++++++++++++++++++-------- 2 files changed, 48 insertions(+), 26 deletions(-) diff --git a/renderers/client.py b/renderers/client.py index 56d23a8..ee10f86 100644 --- a/renderers/client.py +++ b/renderers/client.py @@ -127,10 +127,11 @@ async def _resolve_max_prompt_len(client: AsyncOpenAI, model: str) -> int | None RendererTransport = Literal["vllm_generate", "dynamo_chat"] # Keys never forwarded to Dynamo at the top level: vLLM/prime-only fields its -# strict validator rejects (mirrors the token client's drop set), the -# renderer-internal ``routed_experts_prompt_start`` parse hint (never a wire -# field), and ``priority`` (routed into nvext.agent_hints instead). ``max_tokens`` -# and ``nvext`` are handled explicitly and skipped separately. +# strict validator rejects (mirrors the token client's drop set). ``priority`` +# is routed into nvext.agent_hints and ``routed_experts_prompt_start`` into +# nvext.routed_experts_prompt_start instead (the worker applies the latter to +# SamplingParams so vLLM trims routing engine-side). ``max_tokens`` and ``nvext`` +# are handled explicitly and skipped separately. _DYNAMO_DROP_KEYS = frozenset( { "return_token_ids", @@ -346,6 +347,13 @@ def _build_body( agent_hints = dict(nvext.get("agent_hints") or {}) agent_hints["priority"] = priority nvext["agent_hints"] = agent_hints + # routed_experts_prompt_start rides nvext (Dynamo rejects unknown + # top-level chat fields). The worker applies it to SamplingParams so vLLM + # trims the leading prompt rows engine-side and stamps the payload's + # `start` — the client-side trim then no-ops (see _trim_dynamo_routed_experts). + reps = sp.get("routed_experts_prompt_start") + if reps is not None: + nvext["routed_experts_prompt_start"] = reps # messages is a placeholder stub the OpenAI schema requires but Dynamo # ignores. tools are baked into token_data; forwarding the renderer @@ -486,20 +494,18 @@ def _normalize_routed_experts(payload: Any) -> dict[str, Any] | None: def _trim_dynamo_routed_experts(resp: Any, sp: dict[str, Any]) -> None: - """Trim a dynamo_chat routed_experts payload to begin at ``start``, in place. - - The Dynamo worker returns FULL-sequence routing with ``start=0`` because - NvExt carries no field to forward ``routed_experts_prompt_start`` for - worker-side trimming. The consumer contract is that row 0 of the payload is - the row at ``start`` (the vllm_generate path has vLLM trim internally), so - when the caller explicitly supplies ``routed_experts_prompt_start`` we drop - that many leading rows and set ``start`` to it. No-op when routed_experts is - absent/empty, the offset is 0, or no offset is supplied — a first-turn - request with no caller start keeps full-sequence routing with ``start=0`` - rather than claiming a prefix the consumer has no state for. - - Worker-side trimming (avoiding full-prompt routing on the wire) is a future - optimization gated on a forwarded NvExt prompt-start field. + """Client-side trim of a dynamo_chat routed_experts payload, in place. + + This is a **back-compat fallback**. The renderer now forwards + ``routed_experts_prompt_start`` via ``nvext`` (see ``_build_body``), so a + current Dynamo worker trims the leading prompt rows engine-side (vLLM) and + stamps the payload's ``start`` > 0 — in which case this is a no-op. We only + trim here when the worker returned FULL-sequence routing with ``start == 0`` + (an older worker that ignored the nvext field) and the caller supplied a + positive ``routed_experts_prompt_start``: drop that many leading rows and set + ``start``. No-op when routed_experts is absent/empty, the worker already + trimmed (``start`` > 0), or no positive offset is supplied (first-turn + requests keep full-sequence routing with ``start=0``). """ if not isinstance(resp, Mapping): return @@ -522,6 +528,9 @@ def _trim_dynamo_routed_experts(resp: Any, sp: dict[str, Any]) -> None: offset = sp.get("routed_experts_prompt_start") if offset is None: return + # Worker already trimmed engine-side (stamped start > 0) — don't double-trim. + if int(routed.get("start") or 0) != 0: + return offset = max(0, min(int(offset), int(shape[0]))) if offset == 0: return diff --git a/tests/test_client.py b/tests/test_client.py index d292887..76c57d5 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -474,8 +474,9 @@ def test_dynamo_transport_forwards_extra_sampling_fields_and_drops_denylist(): "guided_json": {"type": "object"}, # denylisted — must NOT hit the wire "return_token_ids": True, - # renderer-internal: never a wire field (stamped onto the - # response's routed_experts.start instead) + # not a top-level field (Dynamo rejects unknown ones); routed + # into nvext.routed_experts_prompt_start so the worker trims + # routing engine-side "routed_experts_prompt_start": 3, }, transport="dynamo_chat", @@ -488,17 +489,19 @@ def test_dynamo_transport_forwards_extra_sampling_fields_and_drops_denylist(): assert body["stop"] == [""] assert body["guided_json"] == {"type": "object"} assert "return_token_ids" not in body - # routed_experts_prompt_start is renderer-internal: Dynamo has no nvext - # field to forward it to the worker, so the renderer stamps it as - # routed_experts.start on the response instead. It must never hit the wire. + # routed_experts_prompt_start is dropped from the top level (Dynamo rejects + # unknown top-level fields) but routed into nvext so the worker applies it to + # SamplingParams and trims routing engine-side. assert "routed_experts_prompt_start" not in body + assert body["nvext"]["routed_experts_prompt_start"] == 3 assert "extra_args" not in body.get("nvext", {}) def test_trim_dynamo_routed_experts(): - """The dynamo transport trims leading prompt rows ONLY when the caller - supplies routed_experts_prompt_start (worker returns full routing, start=0). - Absent start (first turn), offset 0, or absent routed_experts are no-ops.""" + """Client-side trim is a back-compat fallback: it trims ONLY when the worker + returned full routing (start=0) AND the caller supplied a positive + routed_experts_prompt_start. No-op when the worker already trimmed (start>0), + no offset is supplied (first turn), offset is 0, or routed_experts is absent.""" from renderers.client import _trim_dynamo_routed_experts def _payload(channel): @@ -522,6 +525,16 @@ def _payload(channel): re2 = resp2["nvext"]["routed_experts"] assert re2["shape"] == [2, 1, 1] and re2["start"] == 3 + # worker already trimmed engine-side (start>0) -> no-op (don't double-trim) + resp_wt = {"nvext": {"engine_data": {"routed_experts": { + "data": base64.b64encode(bytes([3, 4])).decode(), + "shape": [2, 1, 1], "start": 3, "dtype": "uint8", + }}}} + _trim_dynamo_routed_experts(resp_wt, {"routed_experts_prompt_start": 3}) + rewt = resp_wt["nvext"]["engine_data"]["routed_experts"] + assert rewt["shape"] == [2, 1, 1] and rewt["start"] == 3 + assert base64.b64decode(rewt["data"]) == bytes([3, 4]) + # absent start (first turn) -> NO trim, full-sequence with start=0 resp3 = _payload("engine_data") _trim_dynamo_routed_experts(resp3, {}) From f5c480d9cd22cd180640097d05fba130d3ac9475 Mon Sep 17 00:00:00 2001 From: Biswa Panda Date: Wed, 10 Jun 2026 17:30:49 -0700 Subject: [PATCH 21/21] docs(client): drop PR-number references and stale vLLM version from comments --- renderers/client.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/renderers/client.py b/renderers/client.py index ee10f86..06bd5ee 100644 --- a/renderers/client.py +++ b/renderers/client.py @@ -1,4 +1,4 @@ -"""Renderer-based generate client for vLLM 0.20 + Dynamo. +"""Renderer-based generate client for vLLM + Dynamo. Two transports, selected per-call via ``transport=`` parameter: @@ -14,7 +14,7 @@ → Renderer.parse_response() → structured message Dynamo has no ``/inference/v1/generate`` route; this branch posts to the standard OpenAI chat-completions surface and reads the engine - token IDs back via the PR #8119 ``nvext.engine_data`` channel. + token IDs back via the ``nvext.engine_data`` channel. When a RendererPool is passed instead of a single Renderer, the sync tokenization and parsing work is offloaded to threads for parallel execution across rollouts. @@ -202,7 +202,7 @@ def parse(self, data: dict[str, Any]) -> _WireResult: class _VllmGenerateTransport(_Transport): - """vLLM 0.20 TITO surface: ``POST /inference/v1/generate``.""" + """vLLM TITO surface: ``POST /inference/v1/generate``.""" async def post( self, @@ -264,8 +264,7 @@ class _DynamoChatTransport(_Transport): Dynamo has no /inference/v1/generate route. ``nvext.token_data`` carries the pre-tokenized prompt (Dynamo skips tokenization when present) and ``extra_fields=["engine_data", "routed_experts"]`` opts into the - completion-IDs/logprobs channel (PR #8119) and the MoE routed_experts - channel. Mirrors + completion-IDs/logprobs channel and the MoE routed_experts channel. Mirrors ``verifiers.clients.openai_chat_completions_token_client._post_dynamo_chat`` so the wire payload is identical via either client. routed_experts is read from ``nvext.routed_experts`` (or ``nvext.engine_data.routed_experts``) and @@ -390,8 +389,8 @@ def parse(self, data: dict[str, Any]) -> _WireResult: nvext = data.get("nvext") or {} engine = nvext.get("engine_data") or {} - # Canonical Dynamo channel first (PR #8119: nvext.engine_data, then - # top-level nvext), then the OpenAI-extended choices[0].token_ids. The + # Canonical Dynamo channel first (nvext.engine_data, then top-level + # nvext), then the OpenAI-extended choices[0].token_ids. The # engine channel is authoritative — choices[0].token_ids may be a # detokenize-then-retokenize echo that differs from what was sampled. completion_ids = None @@ -749,7 +748,7 @@ def _build_mm_features( model-family specific. For now we dispatch on the renderer class; extend the dispatch table as more multimodal renderers land. - NOTE — future engine pluggability: this encoder is vLLM 0.20-specific + NOTE — future engine pluggability: this encoder is vLLM-specific (uses ``vllm.multimodal.inputs.MultiModalKwargsItems``, ``vllm.entrypoints.serve.disagg.mm_serde.encode_mm_kwargs_item``, and ``_create_qwen2vl_field_factory``). When a second inference engine