From 1550bf02b8e412833c4e9bdc694f1f65062a1fe2 Mon Sep 17 00:00:00 2001 From: Eli Gottlieb Date: Wed, 27 May 2026 01:59:11 +0000 Subject: [PATCH 1/8] feat(mm): make pixel_values ephemeral on the rollout path MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit generate() now hands back descriptor-only multi_modal_data (image_grid_thw + mm_hashes + mm_placeholders, no pixel_values). Pixels are re-attached only for the engine POST via the new materialize_pixels (cache hit, else reprocess from the message base64; grid_thw asserted), then stripped again. This keeps the env worker from retaining decoded image tensors for the life of a rollout — resident pixel memory is now bounded by the per-image cache instead of growing with turns x concurrency. Also fixes a latent bridge bug: the merge shallow-copied the mm dict but shared the inner lists, so .extend mutated previous_multi_modal_data in place and corrupted earlier trajectory steps' cumulative sets. Copy the lists. Co-Authored-By: Claude Opus 4.7 (1M context) --- renderers/base.py | 4 ++ renderers/client.py | 60 ++++++++++++++++++++--- renderers/qwen35.py | 18 +++++-- renderers/qwen3_vl.py | 108 ++++++++++++++++++++++++++++++++++++++++-- 4 files changed, 177 insertions(+), 13 deletions(-) diff --git a/renderers/base.py b/renderers/base.py index 65edf68..f7ca756 100644 --- a/renderers/base.py +++ b/renderers/base.py @@ -877,6 +877,10 @@ def bridge_to_next_turn(self, *args: Any, **kwargs: Any) -> "RenderedTokens | No with self.checkout() as r: return r.bridge_to_next_turn(*args, **kwargs) + def materialize_pixels(self, *args: Any, **kwargs: Any) -> "MultiModalData": + with self.checkout() as r: + return r.materialize_pixels(*args, **kwargs) + # ``mm_token_type_id_map`` (the MultimodalRenderer protocol attribute) # is set in ``__init__`` only for pools wrapping multimodal renderers; # see the comment there for why this isn't a class-level property. diff --git a/renderers/client.py b/renderers/client.py index 0c63c0e..8436c38 100644 --- a/renderers/client.py +++ b/renderers/client.py @@ -15,6 +15,7 @@ import json import logging from collections.abc import Mapping +from dataclasses import replace from typing import Any, cast import httpx @@ -248,11 +249,34 @@ def _prepare(): "token_ids": prompt_ids, "sampling_params": sp, } - features = ( - _build_mm_features(renderer, mm_data) - if mm_data and not mm_data.is_empty() - else None + # Multimodal: ``mm_data`` carried into the rollout is descriptor-only + # (no ``pixel_values``) so the env worker never retains decoded image + # tensors. Re-attach pixels for the POST via ``materialize_pixels`` + # (cache hit, else reprocess from the message base64), build the engine + # features, then strip pixels again so the value handed back to the + # trajectory stays descriptor-only. + def _features_and_descriptor_mm() -> ( + "tuple[dict[str, Any] | None, MultiModalData | None]" + ): + if mm_data is None or mm_data.is_empty(): + return None, mm_data + # ``materialize_pixels`` lives on multimodal renderers + the pool, not + # the base ``Renderer`` protocol; reached only when ``mm_data`` is + # non-empty, which implies a multimodal renderer. + full_mm = cast(Any, renderer).materialize_pixels(mm_data, messages) + return _build_mm_features(renderer, full_mm), _strip_pixels(mm_data) + + features, out_mm_data = await _maybe_offload( + renderer, _features_and_descriptor_mm ) + # ``prompt_attr.multi_modal_data`` aliases the original pixel-bearing + # ``mm_data``; rebind it to the stripped copy so the attribution surfaced + # to the trajectory is also descriptor-only. + if ( + prompt_attr is not None + and getattr(prompt_attr, "multi_modal_data", None) is not None + ): + prompt_attr = replace(prompt_attr, multi_modal_data=out_mm_data) if features is not None: body["features"] = features if cache_salt is not None: @@ -321,8 +345,11 @@ def _prepare(): "routed_experts": routed_experts, # The mm sidecar consumed on the request side, surfaced back so # callers can persist it on the trajectory step for downstream - # multi-turn bridging and training-sample construction. - "multi_modal_data": mm_data, + # multi-turn bridging and training-sample construction. Descriptor + # only (``pixel_values`` stripped) — the env worker keeps no decoded + # tensors; pixels are re-derived for the next-turn POST and for + # training-sample construction. + "multi_modal_data": out_mm_data, # The renderer's per-token attribution for the prompt — either # the RenderedTokens computed here via renderer.render(...) or # the one threaded in by the caller alongside prompt_ids (the @@ -334,6 +361,27 @@ def _prepare(): } +def _strip_pixels(mm_data: MultiModalData) -> MultiModalData: + """Return ``mm_data`` with ``pixel_values`` dropped from every item. + + Keeps the descriptor (``image_grid_thw`` etc.), ``mm_hashes`` and + ``mm_placeholders`` — everything needed for token alignment and for + re-deriving pixels later (POST via ``materialize_pixels``; training via + the orchestrator). The decoded pixel tensors are never retained on the + trajectory, which is what keeps env-worker memory flat across a rollout. + """ + if not mm_data.mm_items: + return mm_data + new_items = { + modality: [ + {k: v for k, v in item.items() if k != "pixel_values"} + for item in items + ] + for modality, items in mm_data.mm_items.items() + } + return replace(mm_data, mm_items=new_items) + + def _build_mm_features( renderer: Renderer | RendererPool, mm_data: MultiModalData, diff --git a/renderers/qwen35.py b/renderers/qwen35.py index abcacec..9c26dd9 100644 --- a/renderers/qwen35.py +++ b/renderers/qwen35.py @@ -38,6 +38,7 @@ _is_image_part, _is_video_part, _load_pil_image, + materialize_image_pixels, ) # --------------------------------------------------------------------------- @@ -194,6 +195,13 @@ def _process_image(self, part: dict[str, Any]): self._image_cache[h] = (out, num_image_tokens) return pil, out, num_image_tokens, h + def materialize_pixels( + self, mm_data: MultiModalData, messages: list[Message] + ) -> MultiModalData: + """Re-attach pixel_values to descriptor-only mm_data; see + :func:`materialize_image_pixels`.""" + return materialize_image_pixels(self, mm_data, messages) + @staticmethod def _content_has_media(content: Any) -> bool: """True when ``content`` is a structured list containing image / video parts.""" @@ -813,18 +821,22 @@ def flush_buf() -> None: emit_text("\n\n", -1) # Merge prev mm_data (images from earlier turns) with the new turn's. + # Copy the inner lists (not just the dict) so ``.extend`` below never + # mutates ``previous_multi_modal_data`` in place — earlier trajectory + # steps alias that object, and mutating it corrupts their per-step + # cumulative set (and the downstream delta encoding). merged_hashes: dict[str, list[str]] = ( - dict(previous_multi_modal_data.mm_hashes) + {k: list(v) for k, v in previous_multi_modal_data.mm_hashes.items()} if previous_multi_modal_data else {} ) merged_placeholders: dict[str, list[PlaceholderRange]] = ( - dict(previous_multi_modal_data.mm_placeholders) + {k: list(v) for k, v in previous_multi_modal_data.mm_placeholders.items()} if previous_multi_modal_data else {} ) merged_items: dict[str, list[dict[str, Any]]] = ( - dict(previous_multi_modal_data.mm_items) + {k: list(v) for k, v in previous_multi_modal_data.mm_items.items()} if previous_multi_modal_data else {} ) diff --git a/renderers/qwen3_vl.py b/renderers/qwen3_vl.py index 7287159..0bb8fe6 100644 --- a/renderers/qwen3_vl.py +++ b/renderers/qwen3_vl.py @@ -162,6 +162,96 @@ def _image_hash(pil_image) -> str: return h.hexdigest()[:32] +def _iter_image_parts(messages: "list[Any]"): + """Yield image content parts from a message list, in conversation order.""" + for msg in messages or []: + content = msg.get("content") if isinstance(msg, dict) else None + if not isinstance(content, list): + continue + for item in content: + if isinstance(item, dict) and _is_image_part(item): + yield item + + +def _grids_equal(a: Any, b: Any) -> bool: + if a is None or b is None: + return False + al = a.tolist() if hasattr(a, "tolist") else list(a) + bl = b.tolist() if hasattr(b, "tolist") else list(b) + return al == bl + + +def materialize_image_pixels( + renderer: Any, mm_data: MultiModalData, messages: "list[Any]" +) -> MultiModalData: + """Return a pixel-complete copy of ``mm_data``. + + Rollouts retain *descriptor-only* ``mm_data`` (``image_grid_thw`` + + ``mm_hashes`` + ``mm_placeholders``, no ``pixel_values``) so the env + worker never holds decoded image tensors for the life of a rollout. + Before a generate POST the pixels are re-attached here: each image item + missing ``pixel_values`` is reprocessed from its base64 in ``messages`` + via ``renderer._process_image`` (which reuses the per-image cache on a + hit), matched back by the renderer's content hash. The reconstructed + ``image_grid_thw`` is asserted equal to the descriptor's so a processor + skew can never silently change the placeholder count. + """ + from dataclasses import replace + + image_items = mm_data.mm_items.get("image") or [] + if not image_items: + return mm_data + hashes = mm_data.mm_hashes.get("image") or [] + if len(hashes) != len(image_items): + raise ValueError( + "materialize_image_pixels: mm_hashes/mm_items length mismatch " + f"({len(hashes)} vs {len(image_items)})" + ) + missing = { + hashes[i] + for i, item in enumerate(image_items) + if item.get("pixel_values") is None + } + if not missing: + return mm_data + + resolved: dict[str, dict[str, Any]] = {} + for part in _iter_image_parts(messages): + if not missing: + break + _, out, _, h = renderer._process_image(part) + if h in missing: + resolved[h] = out + missing.discard(h) + if missing: + raise ValueError( + f"materialize_image_pixels: {len(missing)} image hash(es) not " + "found in messages; cannot reconstruct pixel_values" + ) + + new_image_items: list[dict[str, Any]] = [] + for i, item in enumerate(image_items): + if item.get("pixel_values") is not None: + new_image_items.append(item) + continue + out = resolved[hashes[i]] + if not _grids_equal(out["image_grid_thw"], item.get("image_grid_thw")): + raise ValueError( + "materialize_image_pixels: reconstructed image_grid_thw " + f"{out['image_grid_thw']!r} != descriptor " + f"{item.get('image_grid_thw')!r} (processor skew?)" + ) + new_image_items.append( + { + "pixel_values": out["pixel_values"], + "image_grid_thw": out["image_grid_thw"], + } + ) + new_items = dict(mm_data.mm_items) + new_items["image"] = new_image_items + return replace(mm_data, mm_items=new_items) + + class _Emitter: """Token-stream builder with BPE-safe text buffering. @@ -433,6 +523,13 @@ def _process_image(self, part: dict[str, Any]): self._image_cache[h] = (out, num_image_tokens) return pil, out, num_image_tokens, h + def materialize_pixels( + self, mm_data: MultiModalData, messages: list[Message] + ) -> MultiModalData: + """Re-attach pixel_values to descriptor-only mm_data; see + :func:`materialize_image_pixels`.""" + return materialize_image_pixels(self, mm_data, messages) + def render( self, messages: list[Message], @@ -802,19 +899,22 @@ def render_media_content(content: Any) -> None: em.text("assistant\n", is_sampled=False, is_content=False) em.finalize() - # Merge prev mm_data with the new turn's items. + # Merge prev mm_data with the new turn's items. Copy the inner lists + # (not just the dict) so ``.extend`` never mutates + # ``previous_multi_modal_data`` in place — earlier trajectory steps + # alias it, and mutating it corrupts their per-step cumulative set. merged_hashes = ( - dict(previous_multi_modal_data.mm_hashes) + {k: list(v) for k, v in previous_multi_modal_data.mm_hashes.items()} if previous_multi_modal_data else {} ) merged_placeholders = ( - dict(previous_multi_modal_data.mm_placeholders) + {k: list(v) for k, v in previous_multi_modal_data.mm_placeholders.items()} if previous_multi_modal_data else {} ) merged_items = ( - dict(previous_multi_modal_data.mm_items) + {k: list(v) for k, v in previous_multi_modal_data.mm_items.items()} if previous_multi_modal_data else {} ) From 73345a25a91c1e9d731fdccabe096492a2db6ea2 Mon Sep 17 00:00:00 2001 From: Codex Date: Fri, 29 May 2026 01:33:10 +0000 Subject: [PATCH 2/8] feat(mm): hash-only/full serialization + request-scoped image memory MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - generate(force_full_pixels): first attempt sends new-turn images full and prior descriptor-only images hash-only; cache-miss fallback materializes all. - _build_qwen_vl_features: descriptor-aware — encode only pixel-bearing items, emit hash-only (None kwargs slot) for the rest, scattered back to original positions so kwargs_data stays aligned with mm_hashes / mm_placeholders. - image_cache_max default 0 (processed pixels stay request-scoped) + a guard so the disabled path never pops an empty cache; RENDERERS_MM_MAX_INFLIGHT semaphore bounds concurrent payload builds. Co-Authored-By: Claude Opus 4.8 (1M context) --- renderers/client.py | 134 ++++++++++++++++++++++++++++++++---------- renderers/configs.py | 13 ++-- renderers/kimi_k25.py | 7 ++- renderers/qwen35.py | 7 ++- renderers/qwen3_vl.py | 11 ++-- 5 files changed, 124 insertions(+), 48 deletions(-) diff --git a/renderers/client.py b/renderers/client.py index 8436c38..8b248ef 100644 --- a/renderers/client.py +++ b/renderers/client.py @@ -12,9 +12,11 @@ from __future__ import annotations import asyncio +import contextlib import json import logging -from collections.abc import Mapping +import os +from collections.abc import AsyncIterator, Mapping from dataclasses import replace from typing import Any, cast @@ -33,6 +35,9 @@ _request_logger = logging.getLogger("renderers.client") ROUTED_EXPERTS_DATA_PREFIX = b'"routed_experts":{"data":"' +_MM_MAX_INFLIGHT_ENV = "RENDERERS_MM_MAX_INFLIGHT" +_DEFAULT_MM_MAX_INFLIGHT = 4 +_mm_payload_semaphores: dict[tuple[int, int], asyncio.Semaphore] = {} class OverlongPromptError(Exception): @@ -122,6 +127,41 @@ async def _maybe_offload(renderer: Renderer | RendererPool, fn): return fn() +def _mm_max_inflight() -> int | None: + raw = os.getenv(_MM_MAX_INFLIGHT_ENV) + if raw is None: + return _DEFAULT_MM_MAX_INFLIGHT + try: + value = int(raw) + except ValueError: + return _DEFAULT_MM_MAX_INFLIGHT + if value < 1: + return None + return value + + +@contextlib.asynccontextmanager +async def _limit_mm_payloads(mm_data: MultiModalData | None) -> AsyncIterator[None]: + if mm_data is None or mm_data.is_empty(): + yield + return + + limit = _mm_max_inflight() + if limit is None: + yield + return + + loop = asyncio.get_running_loop() + key = (id(loop), limit) + semaphore = _mm_payload_semaphores.get(key) + if semaphore is None: + semaphore = asyncio.Semaphore(limit) + _mm_payload_semaphores[key] = semaphore + + async with semaphore: + yield + + def strip_routed_experts_data(raw: bytes) -> tuple[bytes, memoryview | None]: data_start = raw.find(ROUTED_EXPERTS_DATA_PREFIX) if data_start < 0: @@ -157,6 +197,7 @@ async def generate( priority: int | None = None, extra_headers: dict[str, str] | None = None, max_prompt_len: int | None = None, + force_full_pixels: bool = False, ) -> dict[str, Any]: """Tokenize messages, call vLLM /inference/v1/generate, parse the response. @@ -187,6 +228,15 @@ async def generate( that still slip through propagate raw — converting them into a domain error is the calling client's job (its error shape is engine-specific). + ``force_full_pixels`` selects the multimodal serialization mode. When + ``False`` (default), images that arrive descriptor-only (no + ``pixel_values`` — typically prior-turn images carried through a bridge) + are sent hash-only on the assumption the engine still has them cached, + while new-turn images (which carry ``pixel_values``) are sent in full. + When ``True``, ``materialize_pixels`` re-attaches pixels for every image + and the whole prompt is sent in full — the caller's cache-miss fallback + after a hash-only request is rejected by the engine. + Returns a dict with: request_id, prompt_ids, completion_ids, completion_logprobs, content, reasoning_content, tool_calls, finish_reason, routed_experts, multi_modal_data, prompt_attribution. @@ -260,15 +310,27 @@ def _features_and_descriptor_mm() -> ( ): if mm_data is None or mm_data.is_empty(): return None, mm_data - # ``materialize_pixels`` lives on multimodal renderers + the pool, not - # the base ``Renderer`` protocol; reached only when ``mm_data`` is - # non-empty, which implies a multimodal renderer. - full_mm = cast(Any, renderer).materialize_pixels(mm_data, messages) - return _build_mm_features(renderer, full_mm), _strip_pixels(mm_data) - - features, out_mm_data = await _maybe_offload( - renderer, _features_and_descriptor_mm - ) + # First attempt (``force_full_pixels=False``): send ``mm_data`` as-is. + # New-turn images carry ``pixel_values`` (full payload); prior-turn + # images are descriptor-only and ``_build_mm_features`` serializes them + # hash-only, assuming the engine still has them cached. + # Cache-miss fallback (``force_full_pixels=True``): re-attach pixels for + # every image via ``materialize_pixels`` (reprocessed from the message + # base64) so the whole prompt is sent in full. ``materialize_pixels`` + # lives on multimodal renderers + the pool, not the base ``Renderer`` + # protocol; reached only when ``mm_data`` is non-empty, which implies a + # multimodal renderer. + build_mm = ( + cast(Any, renderer).materialize_pixels(mm_data, messages) + if force_full_pixels + else mm_data + ) + return _build_mm_features(renderer, build_mm), _strip_pixels(mm_data) + + async with _limit_mm_payloads(mm_data): + features, out_mm_data = await _maybe_offload( + renderer, _features_and_descriptor_mm + ) # ``prompt_attr.multi_modal_data`` aliases the original pixel-bearing # ``mm_data``; rebind it to the stripped copy so the attribution surfaced # to the trajectory is also descriptor-only. @@ -439,8 +501,8 @@ def _build_qwen_vl_features( ``MultiModalKwargsItems``, base64-encodes each item, and assembles a JSON-serializable dict matching vLLM's ``MultiModalFeatures`` schema. - Returns ``None`` semantics live one level up — this helper assumes - the caller already verified ``mm_data`` is non-empty. + Returns ``None`` semantics live one level up — this helper assumes the + caller already verified ``mm_data`` is non-empty. """ try: import torch @@ -463,21 +525,29 @@ def _build_qwen_vl_features( image_items = mm_data.mm_items.get("image") or [] if image_items: - # mm_items now ship numpy arrays (the renderer is torch-free); - # convert at this vLLM-glue boundary where torch is already a - # hard dependency. - pixel_values = torch.cat( - [torch.as_tensor(it["pixel_values"]) for it in image_items], dim=0 - ) - image_grid_thw = torch.cat( - [torch.as_tensor(it["image_grid_thw"]) for it in image_items], dim=0 - ) - hf_inputs = BatchFeature( - data={"pixel_values": pixel_values, "image_grid_thw": image_grid_thw} - ) - config = _create_qwen2vl_field_factory(spatial_merge_size)(hf_inputs) - kwargs_items = MultiModalKwargsItems.from_hf_inputs(hf_inputs, config) - encoded = [encode_mm_kwargs_item(it) for it in kwargs_items["image"]] + # An item carrying ``pixel_values`` is sent as a full payload; an item + # without (descriptor-only) is sent hash-only, on the assumption that + # the engine already has it cached from an earlier turn. ``kwargs_data`` + # stays aligned with ``mm_items``: ``None`` marks a hash-only slot. + # mm_items ship numpy arrays (the renderer is torch-free); convert at + # this vLLM-glue boundary where torch is already a hard dependency. + encoded: list[Any] = [None] * len(image_items) + full_indices = [i for i, it in enumerate(image_items) if it.get("pixel_values") is not None] + if full_indices: + full_items = [image_items[i] for i in full_indices] + pixel_values = torch.cat( + [torch.as_tensor(it["pixel_values"]) for it in full_items], dim=0 + ) + image_grid_thw = torch.cat( + [torch.as_tensor(it["image_grid_thw"]) for it in full_items], dim=0 + ) + hf_inputs = BatchFeature( + data={"pixel_values": pixel_values, "image_grid_thw": image_grid_thw} + ) + config = _create_qwen2vl_field_factory(spatial_merge_size)(hf_inputs) + kwargs_items = MultiModalKwargsItems.from_hf_inputs(hf_inputs, config) + for idx, item in zip(full_indices, kwargs_items["image"]): + encoded[idx] = encode_mm_kwargs_item(item) out["kwargs_data"]["image"] = encoded out["mm_hashes"]["image"] = list(mm_data.mm_hashes.get("image") or []) out["mm_placeholders"]["image"] = [ @@ -485,10 +555,12 @@ def _build_qwen_vl_features( for p in mm_data.mm_placeholders.get("image") or [] ] - # If kwargs_data is empty across all modalities, drop the key so vLLM - # falls back to the hash-only (cache-hit) path. Otherwise hand it the - # full payload. - if not any(out["kwargs_data"].values()): + # If no full payload was built across any modality, drop ``kwargs_data`` so + # vLLM takes the hash-only (cache-hit) path. Otherwise hand it the payload + # (with ``None`` slots for the hash-only images). + if not any( + any(item is not None for item in items) for items in out["kwargs_data"].values() + ): out["kwargs_data"] = None return out diff --git a/renderers/configs.py b/renderers/configs.py index e0098ba..145a043 100644 --- a/renderers/configs.py +++ b/renderers/configs.py @@ -148,9 +148,10 @@ class Qwen35RendererConfig(BaseRendererConfig): running across the entire conversation. Mirrors the chat template's ``add_vision_id`` toggle.""" - image_cache_max: int = 256 - """FIFO bound on the per-renderer image processor cache. Renderer- - internal — not a Jinja chat-template kwarg.""" + image_cache_max: int = 0 + """FIFO bound on the per-renderer image processor cache. Zero disables + caching so processed pixel buffers stay request-scoped. Renderer-internal — + not a Jinja chat-template kwarg.""" _internal_fields = frozenset({"image_cache_max"}) @@ -166,7 +167,7 @@ class Qwen36RendererConfig(BaseRendererConfig): add_vision_id: bool = False """See :class:`Qwen35RendererConfig.add_vision_id`.""" - image_cache_max: int = 256 + image_cache_max: int = 0 """See :class:`Qwen35RendererConfig.image_cache_max`.""" _internal_fields = frozenset({"image_cache_max"}) @@ -180,7 +181,7 @@ class Qwen3VLRendererConfig(BaseRendererConfig): add_vision_id: bool = False """See :class:`Qwen35RendererConfig.add_vision_id`.""" - image_cache_max: int = 256 + image_cache_max: int = 0 """See :class:`Qwen35RendererConfig.image_cache_max`.""" _internal_fields = frozenset({"image_cache_max"}) @@ -294,7 +295,7 @@ class KimiK25RendererConfig(BaseRendererConfig): ``thinking`` (not ``enable_thinking``) to match the upstream chat template's native variable name.""" - image_cache_max: int = 256 + image_cache_max: int = 0 """See :class:`Qwen35RendererConfig.image_cache_max`.""" _internal_fields = frozenset({"image_cache_max"}) diff --git a/renderers/kimi_k25.py b/renderers/kimi_k25.py index 352a9ee..5c7bc57 100644 --- a/renderers/kimi_k25.py +++ b/renderers/kimi_k25.py @@ -677,9 +677,10 @@ def _process_image(self, part: dict[str, Any]): # Patch count via the processor's own calculator (matches the # model's per-patch attention count); kept for debugging. num_patches = int(img_proc.media_tokens_calculator(media_item)) - if len(self._image_cache) >= self.config.image_cache_max: - self._image_cache.pop(next(iter(self._image_cache))) - self._image_cache[h] = (out, num_patches) + if self.config.image_cache_max > 0: + if len(self._image_cache) >= self.config.image_cache_max: + self._image_cache.pop(next(iter(self._image_cache))) + self._image_cache[h] = (out, num_patches) return pil, out, num_patches, h # ------------------------------------------------------------------ diff --git a/renderers/qwen35.py b/renderers/qwen35.py index 9c26dd9..518c185 100644 --- a/renderers/qwen35.py +++ b/renderers/qwen35.py @@ -190,9 +190,10 @@ def _process_image(self, part: dict[str, Any]): grid_thw = out["image_grid_thw"][0] merge_size = proc.image_processor.merge_size num_image_tokens = int(grid_thw.prod()) // (merge_size * merge_size) - if len(self._image_cache) >= self.config.image_cache_max: - self._image_cache.pop(next(iter(self._image_cache))) - self._image_cache[h] = (out, num_image_tokens) + if self.config.image_cache_max > 0: + if len(self._image_cache) >= self.config.image_cache_max: + self._image_cache.pop(next(iter(self._image_cache))) + self._image_cache[h] = (out, num_image_tokens) return pil, out, num_image_tokens, h def materialize_pixels( diff --git a/renderers/qwen3_vl.py b/renderers/qwen3_vl.py index 0bb8fe6..530ac90 100644 --- a/renderers/qwen3_vl.py +++ b/renderers/qwen3_vl.py @@ -516,11 +516,12 @@ def _process_image(self, part: dict[str, Any]): grid_thw = out["image_grid_thw"][0] merge_size = proc.image_processor.merge_size num_image_tokens = int(grid_thw.prod()) // (merge_size * merge_size) - if len(self._image_cache) >= self.config.image_cache_max: - # FIFO eviction — Python dicts preserve insertion order, so - # ``next(iter(...))`` is the oldest key. - self._image_cache.pop(next(iter(self._image_cache))) - self._image_cache[h] = (out, num_image_tokens) + if self.config.image_cache_max > 0: + if len(self._image_cache) >= self.config.image_cache_max: + # FIFO eviction — Python dicts preserve insertion order, so + # ``next(iter(...))`` is the oldest key. + self._image_cache.pop(next(iter(self._image_cache))) + self._image_cache[h] = (out, num_image_tokens) return pil, out, num_image_tokens, h def materialize_pixels( From 6d806bbe8a232d4e2b3c69f13332b5e50d2fd7ef Mon Sep 17 00:00:00 2001 From: Eli Gottlieb <78387377+eligotts@users.noreply.github.com> Date: Sat, 30 May 2026 22:32:37 +0000 Subject: [PATCH 3/8] feat(mm): add shared mm_store module for run-scoped offload artifacts Single source of truth for the on-disk MM-offload contract, imported by the verifiers env-worker (images), the renderers feature writer, and prime-rl (both readers): - run-scoped paths under /data/outputs/run_/assets/{images,mm_features} (run_id_from_env, run_dir, image_asset_dir, feature_asset_dir + subdir consts). - mmfile format: version-pinned feature fingerprint, mm_feature_path (+ traversal guard), mmfile_ref emit + split_mmfile_ref parse (co-located so they can't drift). - msgpack envelope build + match helpers. - sweep_stale_artifacts: mtime TTL eviction over both asset dirs (content-addressed + re-writable, so over-eviction is safe). Co-Authored-By: Codex Co-Authored-By: Claude Opus 4.8 (1M context) --- renderers/mm_store.py | 213 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 213 insertions(+) create mode 100644 renderers/mm_store.py diff --git a/renderers/mm_store.py b/renderers/mm_store.py new file mode 100644 index 0000000..d5fe03a --- /dev/null +++ b/renderers/mm_store.py @@ -0,0 +1,213 @@ +"""Shared run-scoped artifact store for offloaded multimodal data. + +Two subsystems offload heavy multimodal data to ``/data`` during a rollout, +ship a cheap reference, and re-load it on the consumer: + +1. **Image offload** — raw images written to + ``/assets/images/.jpg`` and shipped as ``file://`` refs. +2. **MM-feature offload** — processed vLLM ``MultiModalKwargsItem`` payloads + written to + ``/assets/mm_features/v1/vllm-mmitem////.msgpack`` + and shipped as ``mmfile:v1::::`` + tuple refs. + +This module is the single source of truth for the on-disk layout, the +fingerprint, the ref strings, and the msgpack envelope. It lives in +``renderers`` because that is the lowest common dependency of both the +verifiers env-worker client (writer of images), the renderers generate client +(writer of features), and prime-rl (reader of both). The writer/reader +file-I/O halves stay in their respective consumers; only the shared contract +lives here. +""" + +from __future__ import annotations + +import hashlib +import json +import os +import re +from pathlib import Path + +# Root of every run's output tree. ``/data/outputs/run_`` in prod. +RUN_OUTPUT_ROOT = Path("/data/outputs") +MM_FEATURE_ROOT_ENV = "PRIME_RL_MM_FEATURE_ROOT" + +MMFILE_PREFIX = "mmfile:v1" + +_SAFE_RUN_ID_RE = re.compile(r"^[A-Za-z0-9_.-]+$") +_SAFE_FINGERPRINT_RE = re.compile(r"^[a-f0-9]{16,64}$") +_SAFE_MM_HASH_RE = re.compile(r"^[a-f0-9]{16,128}$") + +_MM_FEATURE_SCHEMA_VERSION = 1 +_MM_FEATURE_KIND = "vllm.MultiModalKwargsItem" + +# Run-dir-relative asset subdirs, for callers that already hold a run dir. +IMAGE_ASSET_SUBDIR = Path("assets/images") +FEATURE_ASSET_SUBDIR = Path("assets/mm_features") + + +def run_id_from_env() -> str: + """Return the safe ``RUN_ID`` from the environment. + + The platform injects ``RUN_ID`` into every container (env worker, + orchestrator, inference) so the run-scoped artifact dir can be derived + consistently across pods that don't share other env. + """ + run_id = os.environ.get("RUN_ID", "").strip() + if not run_id or not _SAFE_RUN_ID_RE.fullmatch(run_id): + raise RuntimeError("RUN_ID must be set to a safe run id before writing multimodal feature artifacts.") + return run_id + + +def run_dir(run_id: str) -> Path: + """``/run_`` (resolved root from env or ``RUN_OUTPUT_ROOT``).""" + if not _SAFE_RUN_ID_RE.fullmatch(run_id): + raise ValueError(f"Invalid multimodal feature run id: {run_id!r}") + root = Path(os.environ.get(MM_FEATURE_ROOT_ENV, str(RUN_OUTPUT_ROOT))) + return root / f"run_{run_id}" + + +def image_asset_dir(run_id: str) -> Path: + """``/assets/images``, resolved.""" + return (run_dir(run_id) / IMAGE_ASSET_SUBDIR).resolve() + + +def feature_asset_dir(run_id: str) -> Path: + """``/assets/mm_features``, resolved. + + This is the root that ``mm_feature_path`` builds under and that the + traversal guard checks against. Identical to ``mm_feature_run_root``. + """ + return (run_dir(run_id) / FEATURE_ASSET_SUBDIR).resolve() + + +# Alias kept for symmetry with the feature-format function names; both resolve +# to ``/run_/assets/mm_features``. +mm_feature_run_root = feature_asset_dir + + +def mm_feature_fingerprint(*, family: str, spatial_merge_size: int) -> str: + import importlib.metadata + + parts = { + "schema_version": _MM_FEATURE_SCHEMA_VERSION, + "kind": _MM_FEATURE_KIND, + "family": family, + "spatial_merge_size": spatial_merge_size, + "vllm": importlib.metadata.version("vllm"), + "transformers": importlib.metadata.version("transformers"), + "torch": importlib.metadata.version("torch"), + } + raw = json.dumps(parts, sort_keys=True, separators=(",", ":")).encode("utf-8") + return hashlib.sha256(raw).hexdigest()[:32] + + +def mm_feature_path(*, run_id: str, fingerprint: str, modality: str, mm_hash: str) -> Path: + if not _SAFE_FINGERPRINT_RE.fullmatch(fingerprint): + raise ValueError(f"Invalid multimodal feature fingerprint: {fingerprint!r}") + if modality != "image": + raise ValueError(f"Unsupported multimodal feature modality: {modality!r}") + if not _SAFE_MM_HASH_RE.fullmatch(mm_hash): + raise ValueError(f"Invalid multimodal feature hash: {mm_hash!r}") + + root = mm_feature_run_root(run_id) + path = (root / "v1" / "vllm-mmitem" / fingerprint / modality / mm_hash[:2] / f"{mm_hash}.msgpack").resolve() + if not path.is_relative_to(root): + raise ValueError(f"Multimodal feature path escaped root: {path}") + return path + + +def mmfile_ref(*, run_id: str, fingerprint: str, modality: str, mm_hash: str) -> str: + return f"{MMFILE_PREFIX}:{run_id}:{fingerprint}:{modality}:{mm_hash}" + + +def split_mmfile_ref(ref: str) -> "tuple[str | None, str, str, str]": + """Inverse of :func:`mmfile_ref`: parse the ref SHAPE into + ``(run_id_or_None, fingerprint, modality, mm_hash)``. ``run_id`` is ``None`` + for the legacy 5-part form (the caller supplies it from its own context). + Raises ``ValueError`` on a bad prefix/version/arity. The ref field order + lives here, next to the emitter — so emit and parse can't drift apart. + Field-level validation (safe regexes, slot/fingerprint matching) is the + caller's responsibility.""" + parts = ref.split(":") + if parts[:2] != ["mmfile", "v1"] or len(parts) not in {5, 6}: + raise ValueError(f"Invalid mmfile ref shape: {ref!r}") + if len(parts) == 6: + return parts[2], parts[3], parts[4], parts[5] + return None, parts[2], parts[3], parts[4] + + +def build_mm_feature_envelope( + *, + run_id: str, + fingerprint: str, + modality: str, + mm_hash: str, + payload: bytes, + placeholder_length: int, +) -> dict: + """Envelope dict the writer packs (with the payload) into the msgpack file.""" + return { + "schema_version": _MM_FEATURE_SCHEMA_VERSION, + "kind": _MM_FEATURE_KIND, + "run_id": run_id, + "fingerprint": fingerprint, + "modality": modality, + "mm_hash": mm_hash, + "placeholder_length": int(placeholder_length), + "payload_sha256": hashlib.sha256(payload).hexdigest(), + } + + +def mm_feature_envelope_matches( + envelope: dict, + *, + run_id: str, + fingerprint: str, + modality: str, + mm_hash: str, + payload: bytes, + require_run_id: bool = True, +) -> bool: + """Validate a parsed envelope against the requested artifact identity. + + ``require_run_id=False`` mirrors the reader's tolerance for envelopes that + predate the ``run_id`` field (``envelope.get("run_id", run_id)``). + """ + envelope_run_id = envelope.get("run_id") if require_run_id else envelope.get("run_id", run_id) + return ( + envelope.get("schema_version") == _MM_FEATURE_SCHEMA_VERSION + and envelope.get("kind") == _MM_FEATURE_KIND + and envelope_run_id == run_id + and envelope.get("fingerprint") == fingerprint + and envelope.get("modality") == modality + and envelope.get("mm_hash") == mm_hash + and envelope.get("payload_sha256") == hashlib.sha256(payload).hexdigest() + ) + + +def sweep_stale_artifacts(run_dir: Path, ttl_seconds: float) -> int: + """Delete artifact files under run_dir/{assets/images, assets/mm_features} whose + mtime is older than ttl_seconds. Returns count deleted. Safe by construction: + artifacts are content-addressed and re-writable, so over-eviction just triggers + re-materialization, never corruption. No-op if the dirs don't exist; ignore + per-file errors (a file may be mid-write). Walk files only; leave dir structure.""" + import time + + cutoff = time.time() - ttl_seconds + deleted = 0 + for subdir in (IMAGE_ASSET_SUBDIR, FEATURE_ASSET_SUBDIR): + base = run_dir / subdir + if not base.is_dir(): + continue + for path in base.rglob("*"): + if not path.is_file(): + continue + try: + if path.stat().st_mtime < cutoff: + path.unlink() + deleted += 1 + except OSError: + # File may be mid-write or already gone; over/under-eviction is safe. + continue + return deleted From d6ed2248394cc41b5df9e74ad4c0f2b384601596 Mon Sep 17 00:00:00 2001 From: Eli Gottlieb <78387377+eligotts@users.noreply.github.com> Date: Sat, 30 May 2026 22:32:44 +0000 Subject: [PATCH 4/8] feat(mm): feature offload default-on + placeholder_length self-repair - _build_qwen_vl_features writes processed vLLM features to mm_store and ships mmfile refs; import the format/paths from mm_store (no local copies). - Collapse RENDERERS_MM_FEATURE_STORE_MODE to off/on, default on (deleted the never-differentiated disk-write-through/disk-read-nonstrict/disk-strict ladder; the latter two emitted identical refs). - _existing_mm_feature_valid now also checks placeholder_length: vLLM validates it on load but the envelope match did not, so a stale wrong-length artifact would fail in vLLM and never self-repair (we kept skipping the rewrite). Mismatch -> treat as invalid -> rewrite. Co-Authored-By: Codex Co-Authored-By: Claude Opus 4.8 (1M context) --- renderers/client.py | 223 ++++++++++++++++++++++++++++++++----------- tests/test_client.py | 126 +++++++++++++++++++----- 2 files changed, 270 insertions(+), 79 deletions(-) diff --git a/renderers/client.py b/renderers/client.py index 8b248ef..c5844ae 100644 --- a/renderers/client.py +++ b/renderers/client.py @@ -12,12 +12,15 @@ from __future__ import annotations import asyncio +import base64 import contextlib import json import logging import os +import tempfile from collections.abc import AsyncIterator, Mapping from dataclasses import replace +from pathlib import Path from typing import Any, cast import httpx @@ -32,12 +35,21 @@ ToolCallParseStatus, ToolSpec, ) +from renderers.mm_store import ( + build_mm_feature_envelope, + mm_feature_envelope_matches, + mm_feature_fingerprint, + mm_feature_path, + mmfile_ref, + run_id_from_env, +) _request_logger = logging.getLogger("renderers.client") ROUTED_EXPERTS_DATA_PREFIX = b'"routed_experts":{"data":"' _MM_MAX_INFLIGHT_ENV = "RENDERERS_MM_MAX_INFLIGHT" _DEFAULT_MM_MAX_INFLIGHT = 4 _mm_payload_semaphores: dict[tuple[int, int], asyncio.Semaphore] = {} +_MM_FEATURE_STORE_MODE_ENV = "RENDERERS_MM_FEATURE_STORE_MODE" class OverlongPromptError(Exception): @@ -58,10 +70,7 @@ class OverlongPromptError(Exception): def __init__(self, *, prompt_len: int, max_prompt_len: int) -> None: self.prompt_len = prompt_len self.max_prompt_len = max_prompt_len - super().__init__( - f"Prompt length ({prompt_len}) exceeds maximum " - f"context length ({max_prompt_len})." - ) + super().__init__(f"Prompt length ({prompt_len}) exceeds maximum context length ({max_prompt_len}).") # Per-process cache of resolved engine context-length caps, keyed by @@ -140,6 +149,106 @@ def _mm_max_inflight() -> int | None: return value +def _mm_feature_store_mode() -> str: + """Offload processed mm-feature payloads to disk and ship ``mmfile`` refs + (``on``, the default) or inline them as base64 in the request (``off``).""" + mode = os.getenv(_MM_FEATURE_STORE_MODE_ENV, "on").strip().lower() + if mode in {"", "0", "false", "off", "disabled", "none", "no"}: + return "off" + if mode in {"1", "true", "on", "enabled", "yes"}: + return "on" + raise ValueError(f"Invalid {_MM_FEATURE_STORE_MODE_ENV}={mode!r}; expected on or off.") + + +def _fsync_dir(path: Path) -> None: + fd = os.open(path, os.O_RDONLY | getattr(os, "O_DIRECTORY", 0)) + try: + os.fsync(fd) + finally: + os.close(fd) + + +def _existing_mm_feature_valid( + path: Path, *, run_id: str, fingerprint: str, modality: str, mm_hash: str, placeholder_length: int +) -> bool: + try: + import msgpack + + with path.open("rb") as f: + packed = f.read() + artifact = msgpack.unpackb(packed, raw=False) + envelope = artifact.get("envelope") if isinstance(artifact, dict) else None + payload = artifact.get("payload") if isinstance(artifact, dict) else None + if not isinstance(envelope, dict) or not isinstance(payload, bytes): + return False + # Validate placeholder_length too: vLLM checks it on load, but the envelope + # match doesn't — so a stale artifact with the right hash/fingerprint but a + # wrong placeholder_length would fail in vLLM and never get repaired (we'd + # keep skipping the rewrite). Treat a mismatch as invalid → rewrite. + if envelope.get("placeholder_length") != int(placeholder_length): + return False + return mm_feature_envelope_matches( + envelope, + run_id=run_id, + fingerprint=fingerprint, + modality=modality, + mm_hash=mm_hash, + payload=payload, + ) + except Exception: + return False + + +def _write_mm_feature_artifact( + *, + run_id: str, + fingerprint: str, + modality: str, + mm_hash: str, + payload: bytes, + placeholder_length: int, +) -> str: + import msgpack + + path = mm_feature_path(run_id=run_id, fingerprint=fingerprint, modality=modality, mm_hash=mm_hash) + if path.exists() and _existing_mm_feature_valid( + path, + run_id=run_id, + fingerprint=fingerprint, + modality=modality, + mm_hash=mm_hash, + placeholder_length=placeholder_length, + ): + return mmfile_ref(run_id=run_id, fingerprint=fingerprint, modality=modality, mm_hash=mm_hash) + + envelope = build_mm_feature_envelope( + run_id=run_id, + fingerprint=fingerprint, + modality=modality, + mm_hash=mm_hash, + payload=payload, + placeholder_length=placeholder_length, + ) + packed = msgpack.packb({"envelope": envelope, "payload": payload}, use_bin_type=True) + + path.parent.mkdir(parents=True, exist_ok=True) + fd, tmp_name = tempfile.mkstemp(prefix=f".{path.name}.", suffix=".tmp", dir=str(path.parent)) + tmp_path = Path(tmp_name) + try: + with os.fdopen(fd, "wb") as f: + f.write(packed) + f.flush() + os.fsync(f.fileno()) + os.replace(tmp_path, path) + _fsync_dir(path.parent) + except Exception: + with contextlib.suppress(FileNotFoundError): + tmp_path.unlink() + raise + + return mmfile_ref(run_id=run_id, fingerprint=fingerprint, modality=modality, mm_hash=mm_hash) + + @contextlib.asynccontextmanager async def _limit_mm_payloads(mm_data: MultiModalData | None) -> AsyncIterator[None]: if mm_data is None or mm_data.is_empty(): @@ -278,16 +387,12 @@ def _prepare(): rendered, ) - prompt_ids, stop_token_ids, mm_data, prompt_attr = await _maybe_offload( - renderer, _prepare - ) + prompt_ids, stop_token_ids, mm_data, prompt_attr = await _maybe_offload(renderer, _prepare) if max_prompt_len is None: max_prompt_len = await _resolve_max_prompt_len(client, model) if max_prompt_len is not None and len(prompt_ids) > max_prompt_len: - raise OverlongPromptError( - prompt_len=len(prompt_ids), max_prompt_len=max_prompt_len - ) + raise OverlongPromptError(prompt_len=len(prompt_ids), max_prompt_len=max_prompt_len) sp: dict[str, Any] = dict(sampling_params or {}) sp["stop_token_ids"] = stop_token_ids @@ -299,15 +404,14 @@ def _prepare(): "token_ids": prompt_ids, "sampling_params": sp, } + # Multimodal: ``mm_data`` carried into the rollout is descriptor-only # (no ``pixel_values``) so the env worker never retains decoded image # tensors. Re-attach pixels for the POST via ``materialize_pixels`` # (cache hit, else reprocess from the message base64), build the engine # features, then strip pixels again so the value handed back to the # trajectory stays descriptor-only. - def _features_and_descriptor_mm() -> ( - "tuple[dict[str, Any] | None, MultiModalData | None]" - ): + def _features_and_descriptor_mm() -> "tuple[dict[str, Any] | None, MultiModalData | None]": if mm_data is None or mm_data.is_empty(): return None, mm_data # First attempt (``force_full_pixels=False``): send ``mm_data`` as-is. @@ -320,24 +424,15 @@ def _features_and_descriptor_mm() -> ( # lives on multimodal renderers + the pool, not the base ``Renderer`` # protocol; reached only when ``mm_data`` is non-empty, which implies a # multimodal renderer. - build_mm = ( - cast(Any, renderer).materialize_pixels(mm_data, messages) - if force_full_pixels - else mm_data - ) + build_mm = cast(Any, renderer).materialize_pixels(mm_data, messages) if force_full_pixels else mm_data return _build_mm_features(renderer, build_mm), _strip_pixels(mm_data) async with _limit_mm_payloads(mm_data): - features, out_mm_data = await _maybe_offload( - renderer, _features_and_descriptor_mm - ) + features, out_mm_data = await _maybe_offload(renderer, _features_and_descriptor_mm) # ``prompt_attr.multi_modal_data`` aliases the original pixel-bearing # ``mm_data``; rebind it to the stripped copy so the attribution surfaced # to the trajectory is also descriptor-only. - if ( - prompt_attr is not None - and getattr(prompt_attr, "multi_modal_data", None) is not None - ): + if prompt_attr is not None and getattr(prompt_attr, "multi_modal_data", None) is not None: prompt_attr = replace(prompt_attr, multi_modal_data=out_mm_data) if features is not None: body["features"] = features @@ -369,9 +464,7 @@ def _features_and_descriptor_mm() -> ( choice = (data.get("choices") or [{}])[0] completion_ids = choice.get("token_ids") or [] - parsed = await _maybe_offload( - renderer, lambda: renderer.parse_response(completion_ids, tools=tools) - ) + parsed = await _maybe_offload(renderer, lambda: renderer.parse_response(completion_ids, tools=tools)) # ChatCompletionLogProbs flatten: {"content": [{"logprob": ...}, ...]} raw_logprobs = choice.get("logprobs") or {} @@ -389,9 +482,7 @@ def _features_and_descriptor_mm() -> ( # ``parsed.tool_calls`` so verifiers can inspect them, but they don't # trigger the tool-loop continuation. finish_reason = choice.get("finish_reason") - ok_tool_calls = [ - tc for tc in parsed.tool_calls if tc.status == ToolCallParseStatus.OK - ] + ok_tool_calls = [tc for tc in parsed.tool_calls if tc.status == ToolCallParseStatus.OK] if ok_tool_calls and finish_reason == "stop": finish_reason = "tool_calls" @@ -435,10 +526,7 @@ def _strip_pixels(mm_data: MultiModalData) -> MultiModalData: if not mm_data.mm_items: return mm_data new_items = { - modality: [ - {k: v for k, v in item.items() if k != "pixel_values"} - for item in items - ] + modality: [{k: v for k, v in item.items() if k != "pixel_values"} for item in items] for modality, items in mm_data.mm_items.items() } return replace(mm_data, mm_items=new_items) @@ -475,9 +563,7 @@ def _build_mm_features( # Type dispatch only needs the renderer class. Pools expose # ``renderer_cls`` as a snapshot attribute, so we don't have to check # out a slot just to read ``type(r)``. - renderer_cls = ( - renderer.renderer_cls if isinstance(renderer, RendererPool) else type(renderer) - ) + renderer_cls = renderer.renderer_cls if isinstance(renderer, RendererPool) else type(renderer) # Qwen3-VL and Qwen3.5 both ship ``pixel_values`` + ``image_grid_thw`` # via the shared Qwen2-VL field factory. ``spatial_merge_size=2`` is @@ -491,9 +577,7 @@ def _build_mm_features( ) -def _build_qwen_vl_features( - mm_data: MultiModalData, *, spatial_merge_size: int -) -> dict[str, Any]: +def _build_qwen_vl_features(mm_data: MultiModalData, *, spatial_merge_size: int) -> dict[str, Any]: """vLLM features payload for the Qwen-VL family (Qwen2-VL / Qwen3-VL). Stacks per-image processor outputs back into a batched ``BatchFeature``, @@ -507,9 +591,9 @@ def _build_qwen_vl_features( try: import torch from transformers.feature_extraction_utils import BatchFeature - from vllm.entrypoints.serve.disagg.mm_serde import encode_mm_kwargs_item from vllm.model_executor.models.qwen2_vl import _create_qwen2vl_field_factory from vllm.multimodal.inputs import MultiModalKwargsItems + from vllm.v1.serial_utils import MsgpackEncoder except ImportError as exc: raise RuntimeError( "Multimodal generate via /inference/v1/generate requires `vllm` " @@ -517,6 +601,11 @@ def _build_qwen_vl_features( "environment, or pre-build features upstream." ) from exc + mode = _mm_feature_store_mode() + run_id = run_id_from_env() if mode != "off" else "" + encoder = MsgpackEncoder(size_threshold=2**62) + fingerprint = mm_feature_fingerprint(family="qwen_vl", spatial_merge_size=spatial_merge_size) + out: dict[str, Any] = { "mm_hashes": {}, "mm_placeholders": {}, @@ -532,35 +621,55 @@ def _build_qwen_vl_features( # mm_items ship numpy arrays (the renderer is torch-free); convert at # this vLLM-glue boundary where torch is already a hard dependency. encoded: list[Any] = [None] * len(image_items) + mmfile_count = 0 + inline_count = 0 full_indices = [i for i, it in enumerate(image_items) if it.get("pixel_values") is not None] if full_indices: full_items = [image_items[i] for i in full_indices] - pixel_values = torch.cat( - [torch.as_tensor(it["pixel_values"]) for it in full_items], dim=0 - ) - image_grid_thw = torch.cat( - [torch.as_tensor(it["image_grid_thw"]) for it in full_items], dim=0 - ) - hf_inputs = BatchFeature( - data={"pixel_values": pixel_values, "image_grid_thw": image_grid_thw} - ) + pixel_values = torch.cat([torch.as_tensor(it["pixel_values"]) for it in full_items], dim=0) + image_grid_thw = torch.cat([torch.as_tensor(it["image_grid_thw"]) for it in full_items], dim=0) + hf_inputs = BatchFeature(data={"pixel_values": pixel_values, "image_grid_thw": image_grid_thw}) config = _create_qwen2vl_field_factory(spatial_merge_size)(hf_inputs) kwargs_items = MultiModalKwargsItems.from_hf_inputs(hf_inputs, config) for idx, item in zip(full_indices, kwargs_items["image"]): - encoded[idx] = encode_mm_kwargs_item(item) + bufs = encoder.encode(item) + assert len(bufs) == 1, "All tensors should be inline" + raw_payload = bufs[0] + if mode == "off": + encoded[idx] = base64.b64encode(raw_payload).decode("ascii") + inline_count += 1 + continue + + mm_hash = (mm_data.mm_hashes.get("image") or [])[idx] + placeholder = (mm_data.mm_placeholders.get("image") or [])[idx] + ref = _write_mm_feature_artifact( + run_id=run_id, + fingerprint=fingerprint, + modality="image", + mm_hash=mm_hash, + payload=raw_payload, + placeholder_length=placeholder.length, + ) + encoded[idx] = ref + mmfile_count += 1 + if image_items: + _request_logger.debug( + "built qwen-vl mm features mode=%s none=%d inline=%d mmfile=%d", + mode, + len(image_items) - len(full_indices), + inline_count, + mmfile_count, + ) out["kwargs_data"]["image"] = encoded out["mm_hashes"]["image"] = list(mm_data.mm_hashes.get("image") or []) out["mm_placeholders"]["image"] = [ - {"offset": p.offset, "length": p.length} - for p in mm_data.mm_placeholders.get("image") or [] + {"offset": p.offset, "length": p.length} for p in mm_data.mm_placeholders.get("image") or [] ] # If no full payload was built across any modality, drop ``kwargs_data`` so # vLLM takes the hash-only (cache-hit) path. Otherwise hand it the payload # (with ``None`` slots for the hash-only images). - if not any( - any(item is not None for item in items) for items in out["kwargs_data"].values() - ): + if not any(any(item is not None for item in items) for items in out["kwargs_data"].values()): out["kwargs_data"] = None return out diff --git a/tests/test_client.py b/tests/test_client.py index 1cc1000..763afb5 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -32,16 +32,12 @@ def render(self, messages, *, tools=None, add_generation_prompt=False): ) def render_ids(self, messages, *, tools=None, add_generation_prompt=False): - return self.render( - messages, tools=tools, add_generation_prompt=add_generation_prompt - ).token_ids + return self.render(messages, tools=tools, add_generation_prompt=add_generation_prompt).token_ids def get_stop_token_ids(self): return [99] - def parse_response( - self, completion_ids: list[int], *, tools=None - ) -> ParsedResponse: + def parse_response(self, completion_ids: list[int], *, tools=None) -> ParsedResponse: assert completion_ids == [7, 8] # Stores tools so tests can assert the client plumbed them through. self._last_parse_tools = tools @@ -69,9 +65,7 @@ def __init__(self): self.base_url = "http://fake-host:8000/v1" async def post(self, path, *, cast_to=dict, body=None, options=None): - self.calls.append( - {"path": path, "cast_to": cast_to, "body": body, "options": options} - ) + self.calls.append({"path": path, "cast_to": cast_to, "body": body, "options": options}) routed_experts = np.array([[[1]], [[2]]], dtype=np.uint8) payload = { "request_id": "gen-test", @@ -87,9 +81,7 @@ async def post(self, path, *, cast_to=dict, body=None, options=None): }, "finish_reason": "stop", "routed_experts": { - "data": base64.b64encode(routed_experts.tobytes()).decode( - "ascii" - ), + "data": base64.b64encode(routed_experts.tobytes()).decode("ascii"), "shape": list(routed_experts.shape), }, } @@ -119,9 +111,7 @@ def test_generate_builds_request_body_and_parses_response(): # The client must plumb `tools` through to parse_response so XML-style # parsers can preserve declared-string args verbatim. - assert renderer._last_parse_tools == [ - {"type": "function", "function": {"name": "echo"}} - ] + assert renderer._last_parse_tools == [{"type": "function", "function": {"name": "echo"}}] assert len(client.calls) == 1 # /inference/v1/generate is mounted at the server root, so we post to @@ -175,9 +165,7 @@ def test_generate_builds_request_body_and_parses_response(): class _MalformedToolRenderer(_FakeRenderer): """Returns only a malformed tool-call attempt — finish_reason must stay "stop".""" - def parse_response( - self, completion_ids: list[int], *, tools=None - ) -> ParsedResponse: + def parse_response(self, completion_ids: list[int], *, tools=None) -> ParsedResponse: return ParsedResponse( content="", reasoning_content=None, @@ -288,15 +276,16 @@ def test_generate_threads_prompt_attribution_through_prebuilt_prompt_path(): ], ids=["qwen3_vl", "qwen35"], ) -def test_generate_serializes_multimodal_features_for_qwen_vl_family( - model_id, renderer_class_path -): +def test_generate_serializes_multimodal_features_for_qwen_vl_family(model_id, renderer_class_path, monkeypatch): """When the renderer emits ``MultiModalData``, ``generate`` translates it into vLLM's ``features`` payload (mm_hashes + mm_placeholders + base64-encoded kwargs_data) and sticks it in the request body. Covers - every renderer routed through ``_build_qwen_vl_features``.""" + every renderer routed through ``_build_qwen_vl_features``. Pins the store + mode off so it exercises the inline-base64 path (the on path, which emits + mmfile refs, is covered by ``test_qwen_vl_features_can_emit_mmfile_refs``).""" import importlib + monkeypatch.setenv("RENDERERS_MM_FEATURE_STORE_MODE", "off") pytest.importorskip("torch") pytest.importorskip("vllm", reason="vllm needed for features serialization") @@ -371,6 +360,46 @@ def test_generate_serializes_multimodal_features_for_qwen_vl_family( assert isinstance(item, str) and len(item) > 0 +def test_qwen_vl_features_can_emit_mmfile_refs(tmp_path, monkeypatch): + pytest.importorskip("torch") + pytest.importorskip("vllm", reason="vllm needed for features serialization") + + import torch as _torch + from renderers.base import MultiModalData, PlaceholderRange + from renderers.client import _build_qwen_vl_features + + monkeypatch.setenv("RENDERERS_MM_FEATURE_STORE_MODE", "on") + monkeypatch.setenv("PRIME_RL_MM_FEATURE_ROOT", str(tmp_path)) + monkeypatch.setenv("RUN_ID", "mmfiletest") + + mm_data = MultiModalData( + mm_hashes={"image": ["a" * 32, "b" * 32]}, + mm_placeholders={ + "image": [ + PlaceholderRange(offset=5, length=1), + PlaceholderRange(offset=10, length=1), + ] + }, + mm_items={ + "image": [ + { + "pixel_values": _torch.zeros(4, 8, dtype=_torch.float32), + "image_grid_thw": _torch.tensor([[1, 2, 2]], dtype=_torch.int64), + }, + {"image_grid_thw": _torch.tensor([[1, 2, 2]], dtype=_torch.int64)}, + ] + }, + ) + + features = _build_qwen_vl_features(mm_data, spatial_merge_size=2) + + items = features["kwargs_data"]["image"] + assert items[0].startswith("mmfile:v1:mmfiletest:") + assert items[0].endswith(":image:" + "a" * 32) + assert items[1] is None + assert len(list(tmp_path.rglob("*.msgpack"))) == 1 + + # --------------------------------------------------------------------------- # Prompt overflow handling. # --------------------------------------------------------------------------- @@ -502,3 +531,56 @@ def test_generate_caches_max_prompt_len_lookup_failure(): assert len(client.calls) == 1 assert result["prompt_ids"] == list(range(10)) assert _max_prompt_len_cache[("http://no-models:8000/v1", "test-model")] is None + + +def test_sweep_stale_artifacts_evicts_only_stale_files(tmp_path): + import os + import time + + from renderers.mm_store import sweep_stale_artifacts + + run_dir = tmp_path / "run_x" + images = run_dir / "assets" / "images" + features = run_dir / "assets" / "mm_features" / "v1" + images.mkdir(parents=True) + features.mkdir(parents=True) + + stale_img = images / "stale.jpg" + fresh_img = images / "fresh.jpg" + stale_feat = features / "stale.msgpack" + fresh_feat = features / "fresh.msgpack" + for p in (stale_img, fresh_img, stale_feat, fresh_feat): + p.write_bytes(b"x") + + old = time.time() - 10_000 + os.utime(stale_img, (old, old)) + os.utime(stale_feat, (old, old)) + + deleted = sweep_stale_artifacts(run_dir, ttl_seconds=3600.0) + + assert deleted == 2 + assert not stale_img.exists() + assert not stale_feat.exists() + assert fresh_img.exists() + assert fresh_feat.exists() + + +def test_sweep_stale_artifacts_noops_on_missing_dirs(tmp_path): + from renderers.mm_store import sweep_stale_artifacts + + assert sweep_stale_artifacts(tmp_path / "does_not_exist", ttl_seconds=1.0) == 0 + + +def test_mmfile_ref_emit_parse_roundtrip(): + """The ref shape is defined once: split_mmfile_ref is the exact inverse of + mmfile_ref (guards against emit/parse drift across repos).""" + from renderers.mm_store import mmfile_ref, split_mmfile_ref + + ref = mmfile_ref(run_id="run-a", fingerprint="deadbeef", modality="image", mm_hash="abc123") + assert ref == "mmfile:v1:run-a:deadbeef:image:abc123" + assert split_mmfile_ref(ref) == ("run-a", "deadbeef", "image", "abc123") + # Legacy 5-part form → run_id is None (caller supplies it). + assert split_mmfile_ref("mmfile:v1:fp:image:hash") == (None, "fp", "image", "hash") + for bad in ("mmfile:v2:a:b:c:d", "notmmfile:v1:a:b:c:d", "mmfile:v1:a:b"): + with pytest.raises(ValueError): + split_mmfile_ref(bad) From 10b71d627184d4db5448fb12e2941e42b32b07b4 Mon Sep 17 00:00:00 2001 From: Eli Gottlieb <78387377+eligotts@users.noreply.github.com> Date: Mon, 1 Jun 2026 00:49:55 +0000 Subject: [PATCH 5/8] feat(mm): scope artifact sweep to features only; last-use mtime on reuse sweep_stale_artifacts now evicts only assets/mm_features (the expensive processed MultiModalKwargsItem payloads). assets/images are never swept: screenshots are terminal browser output with no regeneration path, so they are kept for the whole run as the recoverable source of truth, whereas features are a regenerable cache (the trainer rebuilds pixels from the image and never reads these files; the env-worker rewrites any missing feature on demand). Over- eviction of a feature is therefore safe; over-eviction of an image is not, which is why the sweep deliberately excludes the image subdir. The feature writer (_write_mm_feature_artifact) now refreshes mtime on the already-on-disk-and-valid path, so a recurring feature is treated as hot by the last-use sweep instead of aging out on its first-write mtime and forcing an expensive force_full_pixels reprocess. Test updated to the features-only / keep-images contract. Co-Authored-By: Codex Co-Authored-By: Claude Opus 4.8 (1M context) --- renderers/client.py | 9 ++++++++ renderers/mm_store.py | 49 +++++++++++++++++++++++++++---------------- tests/test_client.py | 11 ++++++---- 3 files changed, 47 insertions(+), 22 deletions(-) diff --git a/renderers/client.py b/renderers/client.py index c5844ae..8ea5983 100644 --- a/renderers/client.py +++ b/renderers/client.py @@ -219,6 +219,15 @@ def _write_mm_feature_artifact( mm_hash=mm_hash, placeholder_length=placeholder_length, ): + # Recurring image: the feature is already on disk and valid. Refresh its + # mtime so the orchestrator's last-use feature sweep treats it as hot — + # otherwise a frequently-reused feature keeps its first-write mtime, ages + # past the TTL while still being referenced, and gets swept (forcing an + # expensive force_full_pixels reprocess on the next reference). Suppress + # OSError in case a concurrent sweep just removed it — the rewrite below + # would handle that on a subsequent call anyway. + with contextlib.suppress(OSError): + os.utime(path, None) return mmfile_ref(run_id=run_id, fingerprint=fingerprint, modality=modality, mm_hash=mm_hash) envelope = build_mm_feature_envelope( diff --git a/renderers/mm_store.py b/renderers/mm_store.py index d5fe03a..41b2219 100644 --- a/renderers/mm_store.py +++ b/renderers/mm_store.py @@ -187,27 +187,40 @@ def mm_feature_envelope_matches( def sweep_stale_artifacts(run_dir: Path, ttl_seconds: float) -> int: - """Delete artifact files under run_dir/{assets/images, assets/mm_features} whose - mtime is older than ttl_seconds. Returns count deleted. Safe by construction: - artifacts are content-addressed and re-writable, so over-eviction just triggers - re-materialization, never corruption. No-op if the dirs don't exist; ignore - per-file errors (a file may be mid-write). Walk files only; leave dir structure.""" + """Delete stale ``assets/mm_features`` artifacts (the expensive processed + ``MultiModalKwargsItem`` payloads, ~tens of MB each) whose mtime is older than + ttl_seconds. Returns count deleted. + + Features ONLY — ``assets/images`` are never swept here. Features are a + regenerable cache: the trainer rebuilds pixels from the source image + (``materialize_pixels``) and never reads these files, and the env-worker + rewrites any missing feature on demand (``force_full_pixels`` repair retry + + write-if-missing). Source images, by contrast, are terminal browser output + with no regeneration path, so they are retained for the whole run as the + recoverable source of truth. Over-eviction of a feature is therefore safe + (it just forces a reprocess); over-eviction of an image is NOT, which is why + this sweep deliberately excludes ``IMAGE_ASSET_SUBDIR``. + + ttl_seconds only needs to exceed the write->vLLM-admit window (seconds), so + any horizon of minutes leaves a huge safety margin against racing in-flight + reads. No-op if the dir doesn't exist; ignore per-file errors (a file may be + mid-write). Walk files only; leave dir structure.""" import time cutoff = time.time() - ttl_seconds deleted = 0 - for subdir in (IMAGE_ASSET_SUBDIR, FEATURE_ASSET_SUBDIR): - base = run_dir / subdir - if not base.is_dir(): + base = run_dir / FEATURE_ASSET_SUBDIR + if not base.is_dir(): + return 0 + for path in base.rglob("*"): + if not path.is_file(): + continue + try: + if path.stat().st_mtime < cutoff: + path.unlink() + deleted += 1 + except OSError: + # File may be mid-write or already gone; over-eviction is safe (the + # feature is regenerable), so ignore and continue. continue - for path in base.rglob("*"): - if not path.is_file(): - continue - try: - if path.stat().st_mtime < cutoff: - path.unlink() - deleted += 1 - except OSError: - # File may be mid-write or already gone; over/under-eviction is safe. - continue return deleted diff --git a/tests/test_client.py b/tests/test_client.py index 763afb5..f23824e 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -533,7 +533,7 @@ def test_generate_caches_max_prompt_len_lookup_failure(): assert _max_prompt_len_cache[("http://no-models:8000/v1", "test-model")] is None -def test_sweep_stale_artifacts_evicts_only_stale_files(tmp_path): +def test_sweep_stale_artifacts_evicts_only_stale_features_never_images(tmp_path): import os import time @@ -558,11 +558,14 @@ def test_sweep_stale_artifacts_evicts_only_stale_files(tmp_path): deleted = sweep_stale_artifacts(run_dir, ttl_seconds=3600.0) - assert deleted == 2 - assert not stale_img.exists() + # Features only: the stale feature is evicted, the fresh feature kept. + assert deleted == 1 assert not stale_feat.exists() - assert fresh_img.exists() assert fresh_feat.exists() + # Images are NEVER swept (terminal, non-regenerable source of truth) — even a + # stale image is retained for the whole run. + assert stale_img.exists() + assert fresh_img.exists() def test_sweep_stale_artifacts_noops_on_missing_dirs(tmp_path): From a8f874c416ecb155250db4ca1b732384018289af Mon Sep 17 00:00:00 2001 From: Eli Gottlieb <78387377+eligotts@users.noreply.github.com> Date: Wed, 3 Jun 2026 21:43:35 +0000 Subject: [PATCH 6/8] chore(mm): drop dead feature-build counters + redundant alias - client.py: remove the inline_count/mmfile_count debug counters and the "built qwen-vl mm features ..." debug log (the mode=="off" inline path itself is unchanged). - mm_store.py: fold the redundant mm_feature_run_root alias into feature_asset_dir (internal-only, no external importers). Pure cleanup; no behavior change. Co-Authored-By: Codex Co-Authored-By: Claude Opus 4.8 (1M context) --- renderers/client.py | 12 ------------ renderers/mm_store.py | 9 ++------- 2 files changed, 2 insertions(+), 19 deletions(-) diff --git a/renderers/client.py b/renderers/client.py index 8ea5983..d748502 100644 --- a/renderers/client.py +++ b/renderers/client.py @@ -630,8 +630,6 @@ def _build_qwen_vl_features(mm_data: MultiModalData, *, spatial_merge_size: int) # mm_items ship numpy arrays (the renderer is torch-free); convert at # this vLLM-glue boundary where torch is already a hard dependency. encoded: list[Any] = [None] * len(image_items) - mmfile_count = 0 - inline_count = 0 full_indices = [i for i, it in enumerate(image_items) if it.get("pixel_values") is not None] if full_indices: full_items = [image_items[i] for i in full_indices] @@ -646,7 +644,6 @@ def _build_qwen_vl_features(mm_data: MultiModalData, *, spatial_merge_size: int) raw_payload = bufs[0] if mode == "off": encoded[idx] = base64.b64encode(raw_payload).decode("ascii") - inline_count += 1 continue mm_hash = (mm_data.mm_hashes.get("image") or [])[idx] @@ -660,15 +657,6 @@ def _build_qwen_vl_features(mm_data: MultiModalData, *, spatial_merge_size: int) placeholder_length=placeholder.length, ) encoded[idx] = ref - mmfile_count += 1 - if image_items: - _request_logger.debug( - "built qwen-vl mm features mode=%s none=%d inline=%d mmfile=%d", - mode, - len(image_items) - len(full_indices), - inline_count, - mmfile_count, - ) out["kwargs_data"]["image"] = encoded out["mm_hashes"]["image"] = list(mm_data.mm_hashes.get("image") or []) out["mm_placeholders"]["image"] = [ diff --git a/renderers/mm_store.py b/renderers/mm_store.py index 41b2219..72c58bb 100644 --- a/renderers/mm_store.py +++ b/renderers/mm_store.py @@ -76,16 +76,11 @@ def feature_asset_dir(run_id: str) -> Path: """``/assets/mm_features``, resolved. This is the root that ``mm_feature_path`` builds under and that the - traversal guard checks against. Identical to ``mm_feature_run_root``. + traversal guard checks against. """ return (run_dir(run_id) / FEATURE_ASSET_SUBDIR).resolve() -# Alias kept for symmetry with the feature-format function names; both resolve -# to ``/run_/assets/mm_features``. -mm_feature_run_root = feature_asset_dir - - def mm_feature_fingerprint(*, family: str, spatial_merge_size: int) -> str: import importlib.metadata @@ -110,7 +105,7 @@ def mm_feature_path(*, run_id: str, fingerprint: str, modality: str, mm_hash: st if not _SAFE_MM_HASH_RE.fullmatch(mm_hash): raise ValueError(f"Invalid multimodal feature hash: {mm_hash!r}") - root = mm_feature_run_root(run_id) + root = feature_asset_dir(run_id) path = (root / "v1" / "vllm-mmitem" / fingerprint / modality / mm_hash[:2] / f"{mm_hash}.msgpack").resolve() if not path.is_relative_to(root): raise ValueError(f"Multimodal feature path escaped root: {path}") From db5058e8e84ee716f662438bb8451367187b3463 Mon Sep 17 00:00:00 2001 From: Eli Gottlieb <78387377+eligotts@users.noreply.github.com> Date: Tue, 9 Jun 2026 05:00:56 +0000 Subject: [PATCH 7/8] feat: raw-image (mmraw) layout-only renderer mode for Qwen3-VL family Env workers emit layout-only descriptors + mmraw refs instead of running the HF image processor; vLLM materializes pixels from the raw image on shared disk (hash + fingerprint + grid/placeholder validated). Avoids AutoProcessor and pixel_values on the env worker, cutting RSS. Co-Authored-By: Codex Co-Authored-By: Claude Opus 4.8 (1M context) --- renderers/client.py | 107 ++++++++---- renderers/configs.py | 78 ++++++++- renderers/mm_store.py | 126 +++++++++++++++ renderers/qwen35.py | 26 ++- renderers/qwen3_vl.py | 367 ++++++++++++++++++++++++++++++++++++++++-- tests/test_client.py | 208 ++++++++++++++++++++++++ 6 files changed, 852 insertions(+), 60 deletions(-) diff --git a/renderers/client.py b/renderers/client.py index d748502..57f5e6f 100644 --- a/renderers/client.py +++ b/renderers/client.py @@ -36,11 +36,15 @@ ToolSpec, ) from renderers.mm_store import ( + MM_RAW_PAYLOAD_KEY, + MM_RAW_PAYLOAD_VALUE, build_mm_feature_envelope, mm_feature_envelope_matches, mm_feature_fingerprint, mm_feature_path, + mm_payload_mode, mmfile_ref, + mmraw_ref, run_id_from_env, ) @@ -49,7 +53,6 @@ _MM_MAX_INFLIGHT_ENV = "RENDERERS_MM_MAX_INFLIGHT" _DEFAULT_MM_MAX_INFLIGHT = 4 _mm_payload_semaphores: dict[tuple[int, int], asyncio.Semaphore] = {} -_MM_FEATURE_STORE_MODE_ENV = "RENDERERS_MM_FEATURE_STORE_MODE" class OverlongPromptError(Exception): @@ -149,17 +152,6 @@ def _mm_max_inflight() -> int | None: return value -def _mm_feature_store_mode() -> str: - """Offload processed mm-feature payloads to disk and ship ``mmfile`` refs - (``on``, the default) or inline them as base64 in the request (``off``).""" - mode = os.getenv(_MM_FEATURE_STORE_MODE_ENV, "on").strip().lower() - if mode in {"", "0", "false", "off", "disabled", "none", "no"}: - return "off" - if mode in {"1", "true", "on", "enabled", "yes"}: - return "on" - raise ValueError(f"Invalid {_MM_FEATURE_STORE_MODE_ENV}={mode!r}; expected on or off.") - - def _fsync_dir(path: Path) -> None: fd = os.open(path, os.O_RDONLY | getattr(os, "O_DIRECTORY", 0)) try: @@ -263,6 +255,9 @@ async def _limit_mm_payloads(mm_data: MultiModalData | None) -> AsyncIterator[No if mm_data is None or mm_data.is_empty(): yield return + if mm_payload_mode() == "raw": + yield + return limit = _mm_max_inflight() if limit is None: @@ -433,7 +428,13 @@ def _features_and_descriptor_mm() -> "tuple[dict[str, Any] | None, MultiModalDat # lives on multimodal renderers + the pool, not the base ``Renderer`` # protocol; reached only when ``mm_data`` is non-empty, which implies a # multimodal renderer. - build_mm = cast(Any, renderer).materialize_pixels(mm_data, messages) if force_full_pixels else mm_data + if force_full_pixels: + if mm_payload_mode() == "raw": + build_mm = cast(Any, renderer).materialize_raw_refs(mm_data, messages) + else: + build_mm = cast(Any, renderer).materialize_pixels(mm_data, messages) + else: + build_mm = mm_data return _build_mm_features(renderer, build_mm), _strip_pixels(mm_data) async with _limit_mm_payloads(mm_data): @@ -524,18 +525,26 @@ def _features_and_descriptor_mm() -> "tuple[dict[str, Any] | None, MultiModalDat def _strip_pixels(mm_data: MultiModalData) -> MultiModalData: - """Return ``mm_data`` with ``pixel_values`` dropped from every item. + """Return ``mm_data`` with request-scoped multimodal payloads dropped. Keeps the descriptor (``image_grid_thw`` etc.), ``mm_hashes`` and ``mm_placeholders`` — everything needed for token alignment and for - re-deriving pixels later (POST via ``materialize_pixels``; training via - the orchestrator). The decoded pixel tensors are never retained on the - trajectory, which is what keeps env-worker memory flat across a rollout. + re-deriving pixels later (POST repair via raw refs / ``materialize_pixels``; + training via the orchestrator). Decoded pixels and one-request raw-ref + markers are never retained on the trajectory, so later bridge turns use the + cache-only ``None`` slot instead of reprocessing every image. """ if not mm_data.mm_items: return mm_data + drop_keys = { + "pixel_values", + "raw_uri", + "raw_image_id", + "mm_processor_fingerprint", + MM_RAW_PAYLOAD_KEY, + } new_items = { - modality: [{k: v for k, v in item.items() if k != "pixel_values"} for item in items] + modality: [{k: v for k, v in item.items() if k not in drop_keys} for item in items] for modality, items in mm_data.mm_items.items() } return replace(mm_data, mm_items=new_items) @@ -597,6 +606,55 @@ def _build_qwen_vl_features(mm_data: MultiModalData, *, spatial_merge_size: int) Returns ``None`` semantics live one level up — this helper assumes the caller already verified ``mm_data`` is non-empty. """ + mode = mm_payload_mode() + out: dict[str, Any] = { + "mm_hashes": {}, + "mm_placeholders": {}, + "kwargs_data": {}, + } + + image_items = mm_data.mm_items.get("image") or [] + if image_items: + mm_hashes = list(mm_data.mm_hashes.get("image") or []) + placeholders = list(mm_data.mm_placeholders.get("image") or []) + if len(mm_hashes) != len(image_items) or len(placeholders) != len(image_items): + raise ValueError( + "Qwen-VL mm sidecar length mismatch: " + f"items={len(image_items)} hashes={len(mm_hashes)} placeholders={len(placeholders)}" + ) + + if mode == "raw": + run_id = run_id_from_env() + encoded: list[Any] = [None] * len(image_items) + for idx, item in enumerate(image_items): + if item.get(MM_RAW_PAYLOAD_KEY) != MM_RAW_PAYLOAD_VALUE: + continue + raw_image_id = item.get("raw_image_id") + grid_thw = item.get("image_grid_thw") + fingerprint = item.get("mm_processor_fingerprint") + if not isinstance(raw_image_id, str) or not raw_image_id: + raise ValueError("raw multimodal image item is missing raw_image_id") + if grid_thw is None: + raise ValueError("raw multimodal image item is missing image_grid_thw") + if not isinstance(fingerprint, str) or not fingerprint: + raise ValueError("raw multimodal image item is missing mm_processor_fingerprint") + encoded[idx] = mmraw_ref( + run_id=run_id, + fingerprint=fingerprint, + modality="image", + mm_hash=mm_hashes[idx], + raw_image_id=raw_image_id, + grid_thw=grid_thw, + ) + out["kwargs_data"]["image"] = encoded + out["mm_hashes"]["image"] = mm_hashes + out["mm_placeholders"]["image"] = [ + {"offset": p.offset, "length": p.length} for p in placeholders + ] + if not any(item is not None for item in encoded): + out["kwargs_data"] = None + return out + try: import torch from transformers.feature_extraction_utils import BatchFeature @@ -610,18 +668,9 @@ def _build_qwen_vl_features(mm_data: MultiModalData, *, spatial_merge_size: int) "environment, or pre-build features upstream." ) from exc - mode = _mm_feature_store_mode() - run_id = run_id_from_env() if mode != "off" else "" + run_id = run_id_from_env() if mode == "processed" else "" encoder = MsgpackEncoder(size_threshold=2**62) fingerprint = mm_feature_fingerprint(family="qwen_vl", spatial_merge_size=spatial_merge_size) - - out: dict[str, Any] = { - "mm_hashes": {}, - "mm_placeholders": {}, - "kwargs_data": {}, - } - - image_items = mm_data.mm_items.get("image") or [] if image_items: # An item carrying ``pixel_values`` is sent as a full payload; an item # without (descriptor-only) is sent hash-only, on the assumption that @@ -642,7 +691,7 @@ def _build_qwen_vl_features(mm_data: MultiModalData, *, spatial_merge_size: int) bufs = encoder.encode(item) assert len(bufs) == 1, "All tensors should be inline" raw_payload = bufs[0] - if mode == "off": + if mode == "inline": encoded[idx] = base64.b64encode(raw_payload).decode("ascii") continue diff --git a/renderers/configs.py b/renderers/configs.py index e27307b..72444d5 100644 --- a/renderers/configs.py +++ b/renderers/configs.py @@ -153,7 +153,31 @@ class Qwen35RendererConfig(BaseRendererConfig): caching so processed pixel buffers stay request-scoped. Renderer-internal — not a Jinja chat-template kwarg.""" - _internal_fields = frozenset({"image_cache_max"}) + image_patch_size: int | None = None + """Optional Qwen image patch size override for raw-ref layout.""" + + image_temporal_patch_size: int | None = None + """Optional Qwen temporal patch size override for raw-ref layout.""" + + image_merge_size: int | None = None + """Optional Qwen spatial merge size override for raw-ref layout.""" + + image_min_pixels: int | None = None + """Optional Qwen minimum resized image area override for raw-ref layout.""" + + image_max_pixels: int | None = None + """Optional Qwen maximum resized image area override for raw-ref layout.""" + + _internal_fields = frozenset( + { + "image_cache_max", + "image_patch_size", + "image_temporal_patch_size", + "image_merge_size", + "image_min_pixels", + "image_max_pixels", + } + ) class Qwen36RendererConfig(BaseRendererConfig): @@ -170,7 +194,31 @@ class Qwen36RendererConfig(BaseRendererConfig): image_cache_max: int = 0 """See :class:`Qwen35RendererConfig.image_cache_max`.""" - _internal_fields = frozenset({"image_cache_max"}) + image_patch_size: int | None = None + """See :class:`Qwen35RendererConfig.image_patch_size`.""" + + image_temporal_patch_size: int | None = None + """See :class:`Qwen35RendererConfig.image_temporal_patch_size`.""" + + image_merge_size: int | None = None + """See :class:`Qwen35RendererConfig.image_merge_size`.""" + + image_min_pixels: int | None = None + """See :class:`Qwen35RendererConfig.image_min_pixels`.""" + + image_max_pixels: int | None = None + """See :class:`Qwen35RendererConfig.image_max_pixels`.""" + + _internal_fields = frozenset( + { + "image_cache_max", + "image_patch_size", + "image_temporal_patch_size", + "image_merge_size", + "image_min_pixels", + "image_max_pixels", + } + ) class Qwen3VLRendererConfig(BaseRendererConfig): @@ -184,7 +232,31 @@ class Qwen3VLRendererConfig(BaseRendererConfig): image_cache_max: int = 0 """See :class:`Qwen35RendererConfig.image_cache_max`.""" - _internal_fields = frozenset({"image_cache_max"}) + image_patch_size: int | None = None + """See :class:`Qwen35RendererConfig.image_patch_size`.""" + + image_temporal_patch_size: int | None = None + """See :class:`Qwen35RendererConfig.image_temporal_patch_size`.""" + + image_merge_size: int | None = None + """See :class:`Qwen35RendererConfig.image_merge_size`.""" + + image_min_pixels: int | None = None + """See :class:`Qwen35RendererConfig.image_min_pixels`.""" + + image_max_pixels: int | None = None + """See :class:`Qwen35RendererConfig.image_max_pixels`.""" + + _internal_fields = frozenset( + { + "image_cache_max", + "image_patch_size", + "image_temporal_patch_size", + "image_merge_size", + "image_min_pixels", + "image_max_pixels", + } + ) class GLM5RendererConfig(BaseRendererConfig): diff --git a/renderers/mm_store.py b/renderers/mm_store.py index 72c58bb..b2e9f3b 100644 --- a/renderers/mm_store.py +++ b/renderers/mm_store.py @@ -10,6 +10,10 @@ ``/assets/mm_features/v1/vllm-mmitem////.msgpack`` and shipped as ``mmfile:v1::::`` tuple refs. +3. **Raw-image inference refs** — raw images already written under + ``/assets/images`` and shipped as compact + ``mmraw:v1::::::`` + refs. vLLM loads the raw image and runs its own processor. This module is the single source of truth for the on-disk layout, the fingerprint, the ref strings, and the msgpack envelope. It lives in @@ -33,13 +37,20 @@ MM_FEATURE_ROOT_ENV = "PRIME_RL_MM_FEATURE_ROOT" MMFILE_PREFIX = "mmfile:v1" +MMRAW_PREFIX = "mmraw:v1" +MM_PAYLOAD_MODE_ENV = "RENDERERS_MM_FEATURE_STORE_MODE" +MM_RAW_PAYLOAD_KEY = "_prime_rl_mm_payload" +MM_RAW_PAYLOAD_VALUE = "raw" _SAFE_RUN_ID_RE = re.compile(r"^[A-Za-z0-9_.-]+$") _SAFE_FINGERPRINT_RE = re.compile(r"^[a-f0-9]{16,64}$") _SAFE_MM_HASH_RE = re.compile(r"^[a-f0-9]{16,128}$") +_SAFE_RAW_IMAGE_ID_RE = re.compile(r"^[A-Za-z0-9_.-]+$") +_SAFE_GRID_THW_RE = re.compile(r"^[0-9]+x[0-9]+x[0-9]+$") _MM_FEATURE_SCHEMA_VERSION = 1 _MM_FEATURE_KIND = "vllm.MultiModalKwargsItem" +_MM_RAW_SCHEMA_VERSION = 1 # Run-dir-relative asset subdirs, for callers that already hold a run dir. IMAGE_ASSET_SUBDIR = Path("assets/images") @@ -81,6 +92,32 @@ def feature_asset_dir(run_id: str) -> Path: return (run_dir(run_id) / FEATURE_ASSET_SUBDIR).resolve() +def mm_payload_mode() -> str: + """Return the inference multimodal payload mode. + + Explicit env wins: + + - ``raw``: send only cache-only ``None`` slots or raw-image ``mmraw`` refs. + - ``on`` / ``processed``: legacy processed ``mmfile`` feature artifacts. + - ``off`` / ``inline``: legacy inline base64 processed payloads. + + With no explicit env, hosted runs (``RUN_ID`` present) default to ``raw`` so + env workers avoid the image processor. Local/dev processes without + ``RUN_ID`` keep the old processed-artifact behavior. + """ + raw = os.getenv(MM_PAYLOAD_MODE_ENV) + if raw is None: + return "raw" if os.getenv("RUN_ID", "").strip() else "processed" + mode = raw.strip().lower() + if mode in {"raw", "raw-ref", "raw_refs", "mmraw"}: + return "raw" + if mode in {"", "1", "true", "on", "enabled", "yes", "processed", "mmfile"}: + return "processed" + if mode in {"0", "false", "off", "disabled", "none", "no", "inline"}: + return "inline" + raise ValueError(f"Invalid {MM_PAYLOAD_MODE_ENV}={mode!r}; expected raw, on/processed, or off/inline.") + + def mm_feature_fingerprint(*, family: str, spatial_merge_size: int) -> str: import importlib.metadata @@ -97,6 +134,37 @@ def mm_feature_fingerprint(*, family: str, spatial_merge_size: int) -> str: return hashlib.sha256(raw).hexdigest()[:32] +def mm_processor_fingerprint( + *, + family: str, + patch_size: int, + merge_size: int, + temporal_patch_size: int, + min_pixels: int, + max_pixels: int, +) -> str: + import importlib.metadata + + try: + transformers_version = importlib.metadata.version("transformers") + except importlib.metadata.PackageNotFoundError: + transformers_version = "missing" + + parts = { + "schema_version": _MM_RAW_SCHEMA_VERSION, + "kind": "raw-image-processor-layout", + "family": family, + "patch_size": int(patch_size), + "merge_size": int(merge_size), + "temporal_patch_size": int(temporal_patch_size), + "min_pixels": int(min_pixels), + "max_pixels": int(max_pixels), + "transformers": transformers_version, + } + raw = json.dumps(parts, sort_keys=True, separators=(",", ":")).encode("utf-8") + return hashlib.sha256(raw).hexdigest()[:32] + + def mm_feature_path(*, run_id: str, fingerprint: str, modality: str, mm_hash: str) -> Path: if not _SAFE_FINGERPRINT_RE.fullmatch(fingerprint): raise ValueError(f"Invalid multimodal feature fingerprint: {fingerprint!r}") @@ -132,6 +200,64 @@ def split_mmfile_ref(ref: str) -> "tuple[str | None, str, str, str]": return None, parts[2], parts[3], parts[4] +def _grid_to_ref(grid_thw: object) -> str: + data = grid_thw.tolist() if hasattr(grid_thw, "tolist") else grid_thw + if isinstance(data, list) and data and isinstance(data[0], list): + data = data[0] + if not isinstance(data, (list, tuple)) or len(data) != 3: + raise ValueError(f"Invalid image grid_thw for raw ref: {grid_thw!r}") + return "x".join(str(int(v)) for v in data) + + +def _grid_from_ref(value: str) -> list[int]: + if not _SAFE_GRID_THW_RE.fullmatch(value): + raise ValueError(f"Invalid image grid_thw ref segment: {value!r}") + return [int(v) for v in value.split("x")] + + +def raw_image_path(*, run_id: str, raw_image_id: str) -> Path: + if not _SAFE_RAW_IMAGE_ID_RE.fullmatch(raw_image_id): + raise ValueError(f"Invalid raw image id: {raw_image_id!r}") + root = image_asset_dir(run_id) + path = (root / raw_image_id).resolve() + if not path.is_relative_to(root): + raise ValueError(f"Raw image path escaped root: {path}") + return path + + +def mmraw_ref( + *, + run_id: str, + fingerprint: str, + modality: str, + mm_hash: str, + raw_image_id: str, + grid_thw: object, +) -> str: + if not _SAFE_RUN_ID_RE.fullmatch(run_id): + raise ValueError(f"Invalid raw multimodal run id: {run_id!r}") + if not _SAFE_FINGERPRINT_RE.fullmatch(fingerprint): + raise ValueError(f"Invalid raw multimodal fingerprint: {fingerprint!r}") + if modality != "image": + raise ValueError(f"Unsupported raw multimodal modality: {modality!r}") + if not _SAFE_MM_HASH_RE.fullmatch(mm_hash): + raise ValueError(f"Invalid raw multimodal hash: {mm_hash!r}") + raw_image_path(run_id=run_id, raw_image_id=raw_image_id) + return f"{MMRAW_PREFIX}:{run_id}:{fingerprint}:{modality}:{mm_hash}:{raw_image_id}:{_grid_to_ref(grid_thw)}" + + +def split_mmraw_ref(ref: str) -> tuple[str, str, str, str, str, list[int]]: + """Parse a run-scoped raw-image ref into + ``(run_id, fingerprint, modality, mm_hash, raw_image_id, grid_thw)``. + Field-level safe-regex checks live in the reader, but the grid segment is + parsed here so emit/parse cannot drift. + """ + parts = ref.split(":") + if parts[:2] != ["mmraw", "v1"] or len(parts) != 8: + raise ValueError(f"Invalid mmraw ref shape: {ref!r}") + return parts[2], parts[3], parts[4], parts[5], parts[6], _grid_from_ref(parts[7]) + + def build_mm_feature_envelope( *, run_id: str, diff --git a/renderers/qwen35.py b/renderers/qwen35.py index a98719a..4a966e8 100644 --- a/renderers/qwen35.py +++ b/renderers/qwen35.py @@ -40,6 +40,8 @@ _is_video_part, _load_pil_image, materialize_image_pixels, + materialize_image_raw_refs, + qwen_image_item_for_render, ) # --------------------------------------------------------------------------- @@ -209,6 +211,12 @@ def materialize_pixels( :func:`materialize_image_pixels`.""" return materialize_image_pixels(self, mm_data, messages) + def materialize_raw_refs( + self, mm_data: MultiModalData, messages: list[Message] + ) -> MultiModalData: + """Attach raw-image refs to descriptor-only mm_data without pixels.""" + return materialize_image_raw_refs(self, mm_data, messages) + @staticmethod def _content_has_media(content: Any) -> bool: """True when ``content`` is a structured list containing image / video parts.""" @@ -373,7 +381,7 @@ def emit_image(part: dict[str, Any], msg_idx: int) -> None: # image data, so they ARE body content (is_content=True); # the surrounding ``<|vision_start|>`` / ``<|vision_end|>`` # specials are template scaffold. - _, out, n, h = self._process_image(part) + n, h, mm_item = qwen_image_item_for_render(self, part) vision_counts["image"] += 1 if self.config.add_vision_id: emit_text( @@ -395,12 +403,7 @@ def emit_image(part: dict[str, Any], msg_idx: int) -> None: mm_placeholders.setdefault("image", []).append( PlaceholderRange(offset=offset, length=n) ) - mm_items.setdefault("image", []).append( - { - "pixel_values": out["pixel_values"], - "image_grid_thw": out["image_grid_thw"], - } - ) + mm_items.setdefault("image", []).append(mm_item) def emit_user_with_media(content_list: list[Any], msg_idx: int) -> None: """Emit a user message whose content list contains image parts. @@ -724,7 +727,7 @@ def emit_text_segments( content_mask.append(is_content) def emit_image(part: dict[str, Any], msg_idx: int = -1) -> None: - _, out, n, h = self._process_image(part) + n, h, mm_item = qwen_image_item_for_render(self, part) vision_counts["image"] += 1 if self.config.add_vision_id: emit_text(f"Picture {vision_counts['image']}: ", msg_idx) @@ -737,12 +740,7 @@ def emit_image(part: dict[str, Any], msg_idx: int = -1) -> None: new_placeholders.setdefault("image", []).append( PlaceholderRange(offset=offset, length=n) ) - new_items.setdefault("image", []).append( - { - "pixel_values": out["pixel_values"], - "image_grid_thw": out["image_grid_thw"], - } - ) + new_items.setdefault("image", []).append(mm_item) def emit_user_with_media(content_list: list[Any], msg_idx: int) -> None: emit_special(self._im_start, msg_idx) diff --git a/renderers/qwen3_vl.py b/renderers/qwen3_vl.py index ce849a5..dffdd48 100644 --- a/renderers/qwen3_vl.py +++ b/renderers/qwen3_vl.py @@ -30,6 +30,10 @@ import hashlib import io import json +import math +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path from typing import Any from urllib.parse import urlparse @@ -48,6 +52,12 @@ trim_to_turn_close, ) from renderers.configs import Qwen3VLRendererConfig +from renderers.mm_store import ( + MM_RAW_PAYLOAD_KEY, + MM_RAW_PAYLOAD_VALUE, + mm_payload_mode, + mm_processor_fingerprint, +) from renderers.parsing import parse_qwen3 _TOOLS_HEADER = ( @@ -163,6 +173,275 @@ def _image_hash(pil_image) -> str: return h.hexdigest()[:32] +@dataclass(frozen=True) +class QwenImageLayoutConfig: + patch_size: int + temporal_patch_size: int + merge_size: int + min_pixels: int + max_pixels: int + + @classmethod + def from_explicit_renderer_config(cls, config: Any) -> "QwenImageLayoutConfig | None": + values = { + "patch_size": getattr(config, "image_patch_size", None), + "temporal_patch_size": getattr(config, "image_temporal_patch_size", None), + "merge_size": getattr(config, "image_merge_size", None), + "min_pixels": getattr(config, "image_min_pixels", None), + "max_pixels": getattr(config, "image_max_pixels", None), + } + if all(value is None for value in values.values()): + return None + if any(value is None for value in values.values()): + missing = ", ".join(k for k, v in values.items() if v is None) + raise ValueError( + "Qwen raw image layout overrides must be all-or-none; " + f"missing: {missing}" + ) + return cls( + patch_size=int(values["patch_size"]), + temporal_patch_size=int(values["temporal_patch_size"]), + merge_size=int(values["merge_size"]), + min_pixels=int(values["min_pixels"]), + max_pixels=int(values["max_pixels"]), + ) + + @classmethod + def from_image_processor(cls, image_processor: Any) -> "QwenImageLayoutConfig": + size = getattr(image_processor, "size", None) + min_pixels = getattr(size, "shortest_edge", None) + max_pixels = getattr(size, "longest_edge", None) + if isinstance(size, dict): + min_pixels = size.get("shortest_edge", min_pixels) + max_pixels = size.get("longest_edge", max_pixels) + if min_pixels is None or max_pixels is None: + raise ValueError("Qwen image processor size must include shortest_edge and longest_edge") + return cls( + patch_size=int(getattr(image_processor, "patch_size")), + temporal_patch_size=int(getattr(image_processor, "temporal_patch_size")), + merge_size=int(getattr(image_processor, "merge_size")), + min_pixels=int(min_pixels), + max_pixels=int(max_pixels), + ) + + @classmethod + def from_preprocessor_config(cls, data: dict[str, Any]) -> "QwenImageLayoutConfig": + if data.get("do_resize", True) is False: + raise ValueError("Qwen raw image layout requires do_resize=True") + size = data.get("size") + if not isinstance(size, dict): + raise ValueError("Qwen preprocessor_config.json is missing size") + min_pixels = data.get("min_pixels", size.get("shortest_edge")) + max_pixels = data.get("max_pixels", size.get("longest_edge")) + required = { + "patch_size": data.get("patch_size"), + "temporal_patch_size": data.get("temporal_patch_size"), + "merge_size": data.get("merge_size"), + "min_pixels": min_pixels, + "max_pixels": max_pixels, + } + missing = [k for k, v in required.items() if v is None] + if missing: + raise ValueError(f"Qwen preprocessor_config.json missing required field(s): {', '.join(missing)}") + return cls( + patch_size=int(required["patch_size"]), + temporal_patch_size=int(required["temporal_patch_size"]), + merge_size=int(required["merge_size"]), + min_pixels=int(required["min_pixels"]), + max_pixels=int(required["max_pixels"]), + ) + + +@dataclass(frozen=True) +class QwenImageLayoutDescriptor: + mm_hash: str + raw_uri: str + raw_image_id: str + image_grid_thw: list[list[int]] + num_image_tokens: int + fingerprint: str + + +@lru_cache(maxsize=32) +def _load_preprocessor_config_json(model_name_or_path: str) -> dict[str, Any]: + path = Path(model_name_or_path) + candidates: list[Path] = [] + if path.is_dir(): + candidates.append(path / "preprocessor_config.json") + elif path.is_file() and path.name == "preprocessor_config.json": + candidates.append(path) + else: + try: + from huggingface_hub import try_to_load_from_cache + + cached = try_to_load_from_cache(model_name_or_path, "preprocessor_config.json") + if isinstance(cached, str): + candidates.append(Path(cached)) + except Exception: + pass + + for candidate in candidates: + if candidate.is_file(): + with candidate.open("r", encoding="utf-8") as f: + data = json.load(f) + if not isinstance(data, dict): + raise ValueError(f"preprocessor_config.json is not an object: {candidate}") + return data + + raise RuntimeError( + "Qwen raw image layout could not find preprocessor_config.json for " + f"{model_name_or_path!r}. Ensure the model is cached locally or set all " + "image_* layout fields explicitly in the renderer config." + ) + + +def qwen_image_layout_config_for_renderer(renderer: Any) -> QwenImageLayoutConfig: + explicit = QwenImageLayoutConfig.from_explicit_renderer_config(renderer.config) + if explicit is not None: + return explicit + + processor = getattr(renderer, "_processor", None) + image_processor = getattr(processor, "image_processor", None) + if image_processor is not None: + return QwenImageLayoutConfig.from_image_processor(image_processor) + + model_name = getattr(getattr(renderer, "_tokenizer", None), "name_or_path", None) + if not model_name: + raise RuntimeError( + "Qwen raw image layout requires tokenizer.name_or_path, an explicit " + "processor, or explicit image_* layout config fields." + ) + return QwenImageLayoutConfig.from_preprocessor_config(_load_preprocessor_config_json(str(model_name))) + + +def _smart_resize( + height: int, + width: int, + *, + factor: int, + min_pixels: int, + max_pixels: int, +) -> tuple[int, int]: + """Qwen image resize math without materializing resized pixels.""" + if height <= 0 or width <= 0: + raise ValueError(f"image dimensions must be positive, got {height}x{width}") + if max(height, width) / min(height, width) > 200: + raise ValueError( + "absolute aspect ratio must be smaller than 200, got " + f"{max(height, width) / min(height, width)}" + ) + h_bar = round(height / factor) * factor + w_bar = round(width / factor) * factor + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = max(factor, math.floor(height / beta / factor) * factor) + w_bar = max(factor, math.floor(width / beta / factor) * factor) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = math.ceil(height * beta / factor) * factor + w_bar = math.ceil(width * beta / factor) * factor + return h_bar, w_bar + + +def _image_source_url(item: dict[str, Any]) -> str | None: + raw: Any + if "image" in item: + raw = item["image"] + elif "image_url" in item: + iu = item.get("image_url") + raw = iu.get("url") if isinstance(iu, dict) else iu + else: + raw = item.get("url") or item.get("path") + return raw if isinstance(raw, str) else None + + +def _raw_uri_and_id(item: dict[str, Any]) -> tuple[str, str]: + source = _image_source_url(item) + if not source: + raise ValueError("raw multimodal mode requires image parts backed by file:// URLs") + parsed = urlparse(source) + if parsed.scheme == "file": + raw_uri = source + raw_image_id = Path(parsed.path).name + elif parsed.scheme == "": + path = Path(source).resolve() + raw_uri = f"file://{path}" + raw_image_id = path.name + else: + raise ValueError( + "raw multimodal mode requires file:// image URLs; " + f"got scheme {parsed.scheme!r}" + ) + if not raw_image_id: + raise ValueError(f"raw multimodal image URL has no basename: {source!r}") + return raw_uri, raw_image_id + + +def describe_qwen_image_layout(renderer: Any, part: dict[str, Any]) -> QwenImageLayoutDescriptor: + """Cheap Qwen image layout metadata, with no HF image processor call.""" + pil = _load_pil_image(part) + mm_hash = _image_hash(pil) + raw_uri, raw_image_id = _raw_uri_and_id(part) + layout = qwen_image_layout_config_for_renderer(renderer) + resized_h, resized_w = _smart_resize( + pil.height, + pil.width, + factor=layout.patch_size * layout.merge_size, + min_pixels=layout.min_pixels, + max_pixels=layout.max_pixels, + ) + grid_t = 1 + grid_h = resized_h // layout.patch_size + grid_w = resized_w // layout.patch_size + num_image_tokens = grid_t * grid_h * grid_w // (layout.merge_size * layout.merge_size) + fingerprint = mm_processor_fingerprint( + family="qwen_vl", + patch_size=layout.patch_size, + merge_size=layout.merge_size, + temporal_patch_size=layout.temporal_patch_size, + min_pixels=layout.min_pixels, + max_pixels=layout.max_pixels, + ) + return QwenImageLayoutDescriptor( + mm_hash=mm_hash, + raw_uri=raw_uri, + raw_image_id=raw_image_id, + image_grid_thw=[[grid_t, grid_h, grid_w]], + num_image_tokens=num_image_tokens, + fingerprint=fingerprint, + ) + + +def qwen_image_item_for_render(renderer: Any, part: dict[str, Any]) -> tuple[int, str, dict[str, Any]]: + """Return ``(num_tokens, mm_hash, mm_item)`` for Qwen image rendering. + + Raw mode emits layout-only descriptors so env workers do not instantiate + ``AutoProcessor``. Legacy modes preserve the processed-pixel path. + """ + if mm_payload_mode() == "raw": + desc = describe_qwen_image_layout(renderer, part) + return ( + desc.num_image_tokens, + desc.mm_hash, + { + "image_grid_thw": desc.image_grid_thw, + "raw_uri": desc.raw_uri, + "raw_image_id": desc.raw_image_id, + "mm_processor_fingerprint": desc.fingerprint, + MM_RAW_PAYLOAD_KEY: MM_RAW_PAYLOAD_VALUE, + }, + ) + _, out, n, h = renderer._process_image(part) + return ( + n, + h, + { + "pixel_values": out["pixel_values"], + "image_grid_thw": out["image_grid_thw"], + }, + ) + + def _iter_image_parts(messages: "list[Any]"): """Yield image content parts from a message list, in conversation order.""" for msg in messages or []: @@ -253,6 +532,70 @@ def materialize_image_pixels( return replace(mm_data, mm_items=new_items) +def materialize_image_raw_refs( + renderer: Any, mm_data: MultiModalData, messages: "list[Any]" +) -> MultiModalData: + """Attach raw-image refs for every image item without processing pixels.""" + from dataclasses import replace + + image_items = mm_data.mm_items.get("image") or [] + if not image_items: + return mm_data + hashes = mm_data.mm_hashes.get("image") or [] + if len(hashes) != len(image_items): + raise ValueError( + "materialize_image_raw_refs: mm_hashes/mm_items length mismatch " + f"({len(hashes)} vs {len(image_items)})" + ) + + missing = set(hashes) + resolved: dict[str, QwenImageLayoutDescriptor] = {} + for part in _iter_image_parts(messages): + if not missing: + break + desc = describe_qwen_image_layout(renderer, part) + if desc.mm_hash in missing: + resolved[desc.mm_hash] = desc + missing.discard(desc.mm_hash) + if missing: + raise ValueError( + f"materialize_image_raw_refs: {len(missing)} image hash(es) not " + "found in messages; cannot attach raw refs" + ) + + new_image_items: list[dict[str, Any]] = [] + for i, item in enumerate(image_items): + desc = resolved[hashes[i]] + item_grid = item.get("image_grid_thw") + if item_grid is not None and not _grids_equal(desc.image_grid_thw, item_grid): + raise ValueError( + "materialize_image_raw_refs: reconstructed image_grid_thw " + f"{desc.image_grid_thw!r} != descriptor {item_grid!r}" + ) + new_item = { + k: v + for k, v in item.items() + if k + not in { + "pixel_values", + "raw_uri", + "raw_image_id", + "mm_processor_fingerprint", + MM_RAW_PAYLOAD_KEY, + } + } + new_item["image_grid_thw"] = item_grid if item_grid is not None else desc.image_grid_thw + new_item["raw_uri"] = desc.raw_uri + new_item["raw_image_id"] = desc.raw_image_id + new_item["mm_processor_fingerprint"] = desc.fingerprint + new_item[MM_RAW_PAYLOAD_KEY] = MM_RAW_PAYLOAD_VALUE + new_image_items.append(new_item) + + new_items = dict(mm_data.mm_items) + new_items["image"] = new_image_items + return replace(mm_data, mm_items=new_items) + + class _Emitter: """Token-stream builder with BPE-safe text buffering. @@ -532,6 +875,12 @@ def materialize_pixels( :func:`materialize_image_pixels`.""" return materialize_image_pixels(self, mm_data, messages) + def materialize_raw_refs( + self, mm_data: MultiModalData, messages: list[Message] + ) -> MultiModalData: + """Attach raw-image refs to descriptor-only mm_data without pixels.""" + return materialize_image_raw_refs(self, mm_data, messages) + def render( self, messages: list[Message], @@ -561,7 +910,7 @@ def emit_image(part: dict[str, Any]) -> None: # image data, so they ARE body content (is_content=True); # the surrounding ``<|vision_start|>`` / ``<|vision_end|>`` # markers are renderer-emitted scaffold. - _, out, n, h = self._process_image(part) + n, h, mm_item = qwen_image_item_for_render(self, part) vision_counts["image"] += 1 if self.config.add_vision_id: em.text( @@ -578,12 +927,7 @@ def emit_image(part: dict[str, Any]) -> None: mm_placeholders.setdefault("image", []).append( PlaceholderRange(offset=offset, length=n) ) - mm_items.setdefault("image", []).append( - { - "pixel_values": out["pixel_values"], - "image_grid_thw": out["image_grid_thw"], - } - ) + mm_items.setdefault("image", []).append(mm_item) def render_media_content(content: Any) -> None: """Emit a user/tool content list with media handled inline. @@ -826,7 +1170,7 @@ def bridge_to_next_turn( vision_counts = {"image": prev_image_count, "video": prev_video_count} def emit_image(part: dict[str, Any]) -> None: - _, out, n, h = self._process_image(part) + n, h, mm_item = qwen_image_item_for_render(self, part) vision_counts["image"] += 1 if self.config.add_vision_id: em.text( @@ -843,12 +1187,7 @@ def emit_image(part: dict[str, Any]) -> None: new_placeholders.setdefault("image", []).append( PlaceholderRange(offset=offset, length=n) ) - new_items.setdefault("image", []).append( - { - "pixel_values": out["pixel_values"], - "image_grid_thw": out["image_grid_thw"], - } - ) + new_items.setdefault("image", []).append(mm_item) def render_media_content(content: Any) -> None: if isinstance(content, str): diff --git a/tests/test_client.py b/tests/test_client.py index f23824e..b5f70d6 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -400,6 +400,193 @@ def test_qwen_vl_features_can_emit_mmfile_refs(tmp_path, monkeypatch): assert len(list(tmp_path.rglob("*.msgpack"))) == 1 +def test_qwen_vl_features_can_emit_mmraw_refs_without_processed_payloads(tmp_path, monkeypatch): + from renderers.base import MultiModalData, PlaceholderRange + from renderers.client import _build_qwen_vl_features + from renderers.mm_store import ( + MM_RAW_PAYLOAD_KEY, + MM_RAW_PAYLOAD_VALUE, + mm_processor_fingerprint, + raw_image_path, + split_mmraw_ref, + ) + + monkeypatch.setenv("RENDERERS_MM_FEATURE_STORE_MODE", "raw") + monkeypatch.setenv("PRIME_RL_MM_FEATURE_ROOT", str(tmp_path)) + monkeypatch.setenv("RUN_ID", "rawtest") + raw_image_path(run_id="rawtest", raw_image_id="image.png").parent.mkdir(parents=True) + raw_image_path(run_id="rawtest", raw_image_id="image.png").write_bytes(b"not-read-by-serializer") + fingerprint = mm_processor_fingerprint( + family="qwen_vl", + patch_size=16, + merge_size=2, + temporal_patch_size=2, + min_pixels=65536, + max_pixels=16777216, + ) + mm_hash = "a" * 32 + mm_data = MultiModalData( + mm_hashes={"image": [mm_hash, "b" * 32]}, + mm_placeholders={ + "image": [ + PlaceholderRange(offset=5, length=1), + PlaceholderRange(offset=10, length=1), + ] + }, + mm_items={ + "image": [ + { + "image_grid_thw": [[1, 2, 2]], + "raw_image_id": "image.png", + "mm_processor_fingerprint": fingerprint, + MM_RAW_PAYLOAD_KEY: MM_RAW_PAYLOAD_VALUE, + }, + {"image_grid_thw": [[1, 2, 2]]}, + ] + }, + ) + + features = _build_qwen_vl_features(mm_data, spatial_merge_size=2) + + items = features["kwargs_data"]["image"] + assert items[1] is None + assert split_mmraw_ref(items[0]) == ( + "rawtest", + fingerprint, + "image", + mm_hash, + "image.png", + [1, 2, 2], + ) + assert list(tmp_path.rglob("*.msgpack")) == [] + + +def test_strip_pixels_removes_one_request_raw_markers(): + from renderers.base import MultiModalData + from renderers.client import _strip_pixels + from renderers.mm_store import MM_RAW_PAYLOAD_KEY, MM_RAW_PAYLOAD_VALUE + + mm_data = MultiModalData( + mm_items={ + "image": [ + { + "image_grid_thw": [[1, 2, 2]], + "raw_uri": "file:///tmp/image.png", + "raw_image_id": "image.png", + "mm_processor_fingerprint": "a" * 32, + MM_RAW_PAYLOAD_KEY: MM_RAW_PAYLOAD_VALUE, + } + ] + } + ) + + stripped = _strip_pixels(mm_data) + + assert stripped.mm_items == {"image": [{"image_grid_thw": [[1, 2, 2]]}]} + + +def test_qwen3_vl_raw_mode_render_does_not_process_pixels(tmp_path, monkeypatch): + import json + + from PIL import Image + from renderers.mm_store import MM_RAW_PAYLOAD_KEY, MM_RAW_PAYLOAD_VALUE + from renderers.qwen3_vl import Qwen3VLRenderer + + class _Tokenizer: + unk_token_id = -1 + _specials = { + "<|im_start|>": 1, + "<|im_end|>": 2, + "<|endoftext|>": 3, + "": 4, + "": 5, + "": 6, + "": 7, + "<|vision_start|>": 8, + "<|vision_end|>": 9, + "<|image_pad|>": 10, + "<|video_pad|>": 11, + } + + def __init__(self, name_or_path): + self.name_or_path = name_or_path + + def convert_tokens_to_ids(self, token): + return self._specials.get(token, self.unk_token_id) + + def encode(self, text, add_special_tokens=False): + return [100 + ord(ch) % 50 for ch in text] + + monkeypatch.setenv("RENDERERS_MM_FEATURE_STORE_MODE", "raw") + model_dir = tmp_path / "model" + model_dir.mkdir() + (model_dir / "preprocessor_config.json").write_text( + json.dumps( + { + "patch_size": 16, + "temporal_patch_size": 2, + "merge_size": 2, + "size": {"shortest_edge": 65536, "longest_edge": 16777216}, + } + ) + ) + path = tmp_path / "image.png" + Image.new("RGB", (32, 32), color=(255, 0, 0)).save(path) + renderer = Qwen3VLRenderer(_Tokenizer(str(model_dir)), processor=object()) + + rendered = renderer.render( + [{"role": "user", "content": [{"type": "image_url", "image_url": {"url": f"file://{path}"}}]}], + add_generation_prompt=True, + ) + + item = rendered.multi_modal_data.mm_items["image"][0] + assert "pixel_values" not in item + assert item[MM_RAW_PAYLOAD_KEY] == MM_RAW_PAYLOAD_VALUE + assert item["raw_image_id"] == "image.png" + assert item["image_grid_thw"] == [[1, 16, 16]] + assert rendered.multi_modal_data.mm_placeholders["image"][0].length == 64 + + +def test_qwen3_vl_raw_layout_matches_real_processor(tmp_path, monkeypatch): + from huggingface_hub import try_to_load_from_cache + from PIL import Image + + model_id = "Qwen/Qwen3-VL-4B-Instruct" + if not isinstance(try_to_load_from_cache(model_id, "preprocessor_config.json"), str): + pytest.skip(f"{model_id} preprocessor_config.json is not cached locally") + + transformers = pytest.importorskip("transformers") + from renderers.base import load_tokenizer + from renderers.qwen3_vl import Qwen3VLRenderer, describe_qwen_image_layout + + monkeypatch.setenv("RENDERERS_MM_FEATURE_STORE_MODE", "raw") + processor = transformers.AutoProcessor.from_pretrained(model_id, local_files_only=True) + tokenizer = load_tokenizer(model_id) + renderer = Qwen3VLRenderer(tokenizer) + + sizes = [ + (32, 32), + (512, 512), + (333, 777), + (1200, 300), + (4096, 2048), + (65, 97), + ] + for width, height in sizes: + path = tmp_path / f"image_{width}x{height}.png" + Image.new("RGB", (width, height), color=(width % 255, height % 255, 7)).save(path) + part = {"type": "image_url", "image_url": {"url": f"file://{path}"}} + desc = describe_qwen_image_layout(renderer, part) + with Image.open(path) as image: + expected = processor.image_processor(images=[image.convert("RGB")], return_tensors="np")["image_grid_thw"][ + 0 + ].tolist() + assert desc.image_grid_thw == [expected] + assert desc.num_image_tokens == int(expected[0] * expected[1] * expected[2]) // ( + processor.image_processor.merge_size**2 + ) + + # --------------------------------------------------------------------------- # Prompt overflow handling. # --------------------------------------------------------------------------- @@ -587,3 +774,24 @@ def test_mmfile_ref_emit_parse_roundtrip(): for bad in ("mmfile:v2:a:b:c:d", "notmmfile:v1:a:b:c:d", "mmfile:v1:a:b"): with pytest.raises(ValueError): split_mmfile_ref(bad) + + +def test_mmraw_ref_emit_parse_roundtrip(tmp_path, monkeypatch): + from renderers.mm_store import mmraw_ref, raw_image_path, split_mmraw_ref + + monkeypatch.setenv("PRIME_RL_MM_FEATURE_ROOT", str(tmp_path)) + raw_image_path(run_id="run-a", raw_image_id="abc.png").parent.mkdir(parents=True) + ref = mmraw_ref( + run_id="run-a", + fingerprint="deadbeefdeadbeef", + modality="image", + mm_hash="a" * 32, + raw_image_id="abc.png", + grid_thw=[[1, 2, 2]], + ) + + assert ref == "mmraw:v1:run-a:deadbeefdeadbeef:image:" + "a" * 32 + ":abc.png:1x2x2" + assert split_mmraw_ref(ref) == ("run-a", "deadbeefdeadbeef", "image", "a" * 32, "abc.png", [1, 2, 2]) + for bad in ("mmraw:v2:a:b:c:d:e:f", "notmmraw:v1:a:b:c:d:e:f", "mmraw:v1:a:b:c"): + with pytest.raises(ValueError): + split_mmraw_ref(bad) From 91f3040dcc1bb1228c2ae36c4163beb76fb1b58e Mon Sep 17 00:00:00 2001 From: hubert-marek Date: Tue, 9 Jun 2026 20:50:07 +0000 Subject: [PATCH 8/8] fix(qwen3-vl): fall back to hf_hub_download for mmraw preprocessor_config.json The raw-image (mmraw) layout path resolves the model's image geometry from preprocessor_config.json, but _load_preprocessor_config_json only checked local paths and the local HF cache (try_to_load_from_cache). Hosted env workers render models they never loaded locally, so hub-style ids always missed the cache and every image rollout failed with RuntimeError: Qwen raw image layout could not find preprocessor_config.json for 'Qwen/Qwen3.6-35B-A3B' ... even when the file is publicly available on the Hub. Add an hf_hub_download fallback on cache miss (a few hundred bytes, lands in the HF cache, then memoized by the lru_cache). Offline/no-network workers fall through to the existing RuntimeError, whose message now also mentions Hub reachability alongside the explicit image_* config escape hatch. Co-Authored-By: Claude Opus 4.8 (1M context) --- renderers/qwen3_vl.py | 19 +++++++++++++++++-- tests/test_client.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 2 deletions(-) diff --git a/renderers/qwen3_vl.py b/renderers/qwen3_vl.py index dffdd48..3fe272e 100644 --- a/renderers/qwen3_vl.py +++ b/renderers/qwen3_vl.py @@ -279,6 +279,20 @@ def _load_preprocessor_config_json(model_name_or_path: str) -> dict[str, Any]: candidates.append(Path(cached)) except Exception: pass + if not candidates: + # Cache miss for a hub-style id: fall back to downloading the file + # (a few hundred bytes; lands in the HF cache so this is one-time + # per process pool). Hosted env workers render models they never + # loaded locally, so the cache-only lookup above rarely hits there. + # Offline/no-network environments fall through to the error below, + # which names the explicit image_* config escape hatch. + try: + from huggingface_hub import hf_hub_download + + downloaded = hf_hub_download(model_name_or_path, "preprocessor_config.json") + candidates.append(Path(downloaded)) + except Exception: + pass for candidate in candidates: if candidate.is_file(): @@ -290,8 +304,9 @@ def _load_preprocessor_config_json(model_name_or_path: str) -> dict[str, Any]: raise RuntimeError( "Qwen raw image layout could not find preprocessor_config.json for " - f"{model_name_or_path!r}. Ensure the model is cached locally or set all " - "image_* layout fields explicitly in the renderer config." + f"{model_name_or_path!r}. Ensure the model is cached locally or " + "reachable on the Hugging Face Hub, or set all image_* layout fields " + "explicitly in the renderer config." ) diff --git a/tests/test_client.py b/tests/test_client.py index b5f70d6..42566cc 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -587,6 +587,41 @@ def test_qwen3_vl_raw_layout_matches_real_processor(tmp_path, monkeypatch): ) +def test_qwen3_vl_preprocessor_config_hub_download_fallback(tmp_path, monkeypatch): + """Hub-style ids that miss the local HF cache fall back to + ``hf_hub_download``; download failure (offline) keeps the explicit-config + error.""" + import huggingface_hub + + from renderers.qwen3_vl import _load_preprocessor_config_json + + config = { + "patch_size": 16, + "temporal_patch_size": 2, + "merge_size": 2, + "size": {"shortest_edge": 65536, "longest_edge": 16777216}, + } + downloaded = tmp_path / "preprocessor_config.json" + downloaded.write_text(json.dumps(config)) + + def fake_download(repo_id, filename): + assert filename == "preprocessor_config.json" + return str(downloaded) + + monkeypatch.setattr(huggingface_hub, "try_to_load_from_cache", lambda *a, **k: None) + monkeypatch.setattr(huggingface_hub, "hf_hub_download", fake_download) + _load_preprocessor_config_json.cache_clear() + assert _load_preprocessor_config_json("org/uncached-model") == config + + def offline_download(repo_id, filename): + raise OSError("offline") + + monkeypatch.setattr(huggingface_hub, "hf_hub_download", offline_download) + _load_preprocessor_config_json.cache_clear() + with pytest.raises(RuntimeError, match="could not find preprocessor_config.json"): + _load_preprocessor_config_json("org/uncached-model-offline") + + # --------------------------------------------------------------------------- # Prompt overflow handling. # ---------------------------------------------------------------------------