diff --git a/benchmark/dataset.py b/benchmark/dataset.py index 9fc64248..4447e4f6 100644 --- a/benchmark/dataset.py +++ b/benchmark/dataset.py @@ -607,6 +607,39 @@ def _decode_frames_to_png_and_video( VideoEncoder(frames=tensors, frame_rate=fps).to_file(mp4_path) +def _resize_with_pad(chw, size: int): + """Aspect-preserving letterbox of a (C, H, W) uint8 tensor to size x size. + + Scales the longer side to ``size`` and pads the shorter with black (0). + Mirrors the server's ``Pi05ViTEncoderSubmodule._prepare_one`` geometry, so + sending the pre-resized frame produces the same model input as decoding at + native resolution and resizing on the worker. + """ + import torch + import torch.nn.functional as F + + _, h, w = chw.shape + if (h, w) == (size, size): + return chw + ratio = max(w / size, h / size) + rh, rw = int(h / ratio), int(w / ratio) + x = F.interpolate(chw[None].float(), size=(rh, rw), mode="bilinear", align_corners=False) + ph0, remh = divmod(size - rh, 2) + pw0, remw = divmod(size - rw, 2) + x = F.pad(x, (pw0, pw0 + remw, ph0, ph0 + remh), value=0.0) + return x[0].round().clamp(0, 255).to(torch.uint8) + + +def _decode_frame_to_npy(video_path: str, frame_index: int, npy_path: str, size: int) -> None: + """Decode one frame, letterbox-resize to ``size`` x ``size``, save as a + (C, H, W) uint8 ``.npy`` (the "numpy" upload the server np.loads in memory).""" + import numpy as np + from torchcodec.decoders import VideoDecoder + + frame = VideoDecoder(video_path).get_frames_at(indices=[frame_index]).data[0] # (C,H,W) uint8 + np.save(npy_path, _resize_with_pad(frame, size).numpy()) + + class DROIDDataset(BaseDataset): """DROID robotics dataset for evaluating pi0.5 and V-JEPA 2-AC. @@ -631,6 +664,14 @@ class DROIDDataset(BaseDataset): """ HF_REPO = "lerobot/droid_100" + # pi05 camera frames are letterboxed to this size client-side (matches the + # server's vit_image_size) so both mstar and openpi get identical input. + IMAGE_SIZE = 224 + + PI05_KEYS = [ + "observation.images.exterior_image_1_left", + "observation.images.wrist_image_left" + ] def __init__( self, @@ -663,6 +704,17 @@ def __init__( # producing 8 token-frames; only the first is used as rollout context. self.num_video_frames = num_video_frames + # Fast path: reuse a manifest (PNG paths + robot_state + prompt) built + # on a previous run so repeat benchmarks skip the full-parquet load and + # the per-frame video decode entirely. Only pi05 caches; vjepa2_ac + # streams the episode mp4 directly and is left uncached. + if task == "pi05": + cached = self._load_manifest() + if cached is not None: + print(f" [cache] reusing {len(cached)} pi05 items from manifest") + self.items = self._resize_data(cached) + return + def _dl(filename): return hf_hub_download( self.HF_REPO, filename, repo_type="dataset", cache_dir=cache_dir @@ -683,6 +735,12 @@ def _dl(filename): f"No video keys found in {self.HF_REPO}/meta/info.json. " f"Top-level keys: {list(info.keys())}" ) + + if task == "pi05": + assert all((key in camera_keys for key in self.PI05_KEYS)), \ + f"Expected camera keys {self.PI05_KEYS} not all found in {camera_keys}" + camera_keys = self.PI05_KEYS + chunks_size: int = info.get("chunks_size", info.get("chunk_size", 1000)) print(f" camera keys : {camera_keys}") print(f" chunks_size : {chunks_size}") @@ -724,7 +782,13 @@ def _dl(filename): for frames in episodes.values(): frames.sort(key=lambda r: int(r[frame_col])) - ep_ids = sorted(episodes.keys())[:num_requests] + # pi05 caches a complete manifest, so build every episode once + # (_resize_data truncates to num_requests below) and the cache is reused + # for any num_requests. vjepa2_ac streams mp4s uncached, so keep the + # original [:num_requests] cap to bound its first-run decode cost. + ep_ids = sorted(episodes.keys()) + if task != "pi05": + ep_ids = ep_ids[:num_requests] print(f" using {len(ep_ids)} of {len(episodes)} episodes") self.items: list[RequestInput] = [] @@ -749,6 +813,9 @@ def _dl(filename): if item is not None: self.items.append(item) + if task == "pi05": + self._save_manifest(self.items) + self.items = self._resize_data(self.items) # ------------------------------------------------------------------ @@ -772,19 +839,21 @@ def _make_pi05(self, idx, ep_id, frames, camera_keys, state_col, local_indices = self._local_frame_indices(frames) first_local = local_indices[0] - image_paths: list[str] = [] + # Decode + letterbox-resize each camera frame to 224x224 uint8 and save + # as a ".npy" (the "numpy" modality). Sending pre-resized arrays lets the + # server skip both image decode and the resize, and lets us hand mstar + # and openpi identical input. + npy_paths: list[str] = [] for cam_key in camera_keys[:3]: chunk_video = download_fn(self._chunk_video_path(ep_id, cam_key, chunks_size)) - png_path = os.path.join(self.local_file_dir, f"ep{ep_id}_cam{len(image_paths)}.png") - _decode_frames_to_png_and_video( - chunk_video, [first_local], png_path=png_path, mp4_path=None - ) - image_paths.append(png_path) + npy_path = os.path.join(self.local_file_dir, f"ep{ep_id}_cam{len(npy_paths)}.npy") + _decode_frame_to_npy(chunk_video, first_local, npy_path, self.IMAGE_SIZE) + npy_paths.append(npy_path) - if not image_paths: + if not npy_paths: raise ValueError("no camera videos found") - while len(image_paths) < 3: - image_paths.append(image_paths[0]) + while len(npy_paths) < 3: + npy_paths.append(npy_paths[0]) state = _to_float_list( frames[0].get(state_col) if state_col else None, self.action_dim @@ -792,8 +861,8 @@ def _make_pi05(self, idx, ep_id, frames, camera_keys, state_col, return RequestInput( req_type=RequestType.VLA, prompt=language or "manipulate the object", - image_path=image_paths[0], - extra_image_paths=image_paths[1:], + # openpi droid policy only uses the first extra image, so send 2 cameras. + _numpy_paths=npy_paths[:2], model_kwargs={"robot_state": state}, ) @@ -824,7 +893,9 @@ def _make_vjepa2_ac(self, idx, ep_id, frames, camera_keys, action_col, ) mp4_path = os.path.join(self.local_file_dir, f"ep{ep_id}.mp4") _decode_frames_to_png_and_video( - chunk_video, video_local_indices, png_path=None, mp4_path=mp4_path + video_path=chunk_video, + frame_indices=video_local_indices, + png_path=None, mp4_path=mp4_path ) actions = [_to_float_list(f.get(action_col), self.action_dim) @@ -842,6 +913,76 @@ def _make_vjepa2_ac(self, idx, ep_id, frames, camera_keys, action_col, }, ) + # ------------------------------------------------------------------ + # pi05 manifest cache + # ------------------------------------------------------------------ + + def _manifest_path(self) -> str: + """Manifest filename keyed by the params that change the built items.""" + return os.path.join( + self.local_file_dir, + f"manifest_pi05_npy{self.IMAGE_SIZE}_nvf{self.num_video_frames}_ad{self.action_dim}.json", + ) + + def _load_manifest(self) -> list[RequestInput] | None: + """Return cached pi05 RequestInputs, or None to force a rebuild. + + Returns None if the manifest is absent, unreadable, or references a .npy + that no longer exists on disk. + """ + import json as _json + + path = self._manifest_path() + if not os.path.exists(path): + return None + try: + with open(path) as f: + data = _json.load(f) + items: list[RequestInput] = [] + for entry in data["items"]: + npy_paths = [os.path.join(self.local_file_dir, p) + for p in entry.get("numpy_paths", [])] + for p in npy_paths: + if not os.path.exists(p): + print(f" [cache] missing {p}; rebuilding") + return None + items.append(RequestInput( + req_type=RequestType.VLA, + prompt=entry["prompt"], + _numpy_paths=npy_paths, + model_kwargs=entry.get("model_kwargs", {}), + )) + return items or None + except Exception as e: + print(f" [cache] manifest unreadable ({e}); rebuilding") + return None + + def _save_manifest(self, items: list[RequestInput]) -> None: + """Persist built pi05 items so the next run can skip parquet + decode. + + .npy paths are stored as basenames (relative to local_file_dir) and the + write is atomic (tmp + os.replace) so an interrupted run never leaves a + half-written manifest that would later be reused. + """ + import json as _json + + entries = [{ + "prompt": it.prompt, + "numpy_paths": [os.path.basename(p) for p in it._numpy_paths], + "model_kwargs": it.model_kwargs, + } for it in items] + payload = { + "version": 2, + "task": "pi05", + "num_video_frames": self.num_video_frames, + "action_dim": self.action_dim, + "items": entries, + } + tmp = self._manifest_path() + ".tmp" + with open(tmp, "w") as f: + _json.dump(payload, f) + os.replace(tmp, self._manifest_path()) + @property def num_requests(self) -> int: return self._num_requests @@ -852,7 +993,7 @@ def __len__(self) -> int: def __getitem__(self, idx: int) -> RequestInput: return self.items[idx] - # ------------------------------------------------------------------ + class VideoMMEDataset(BaseDataset): """ Dataset loader for Video-MME (https://video-mme.github.io/). diff --git a/benchmark/download_pi05_ckpt.py b/benchmark/download_pi05_ckpt.py new file mode 100644 index 00000000..088cceb8 --- /dev/null +++ b/benchmark/download_pi05_ckpt.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +"""Download the pi0.5 checkpoint and print the local path. + conda activate openpi + python benchmark/download_pi05_ckpt.py +""" + +from __future__ import annotations + +import argparse +import sys + +DEFAULT_CONFIG = "pi05_droid" +DEFAULT_CHECKPOINT = "gs://openpi-assets/checkpoints/pi05_droid" + + +def main(): + p = argparse.ArgumentParser(description="Download pi0.5 checkpoint") + p.add_argument("--config", default=DEFAULT_CONFIG) + p.add_argument("--checkpoint", default=DEFAULT_CHECKPOINT) + args = p.parse_args() + + try: + from openpi.shared import download + from openpi.training import config as _config + except ImportError as e: + sys.exit( + f"\n[ERROR] openpi is not importable ({e}).\n" + "Run inside the openpi conda env:\n" + " conda activate openpi\n" + ) + + _config.get_config(args.config) + ckpt_dir = download.maybe_download(args.checkpoint) + print(ckpt_dir) + + +if __name__ == "__main__": + main() diff --git a/benchmark/openpi_instructions.md b/benchmark/openpi_instructions.md new file mode 100644 index 00000000..4e1502fa --- /dev/null +++ b/benchmark/openpi_instructions.md @@ -0,0 +1,19 @@ +1. Make a python3.12 environment +2. Clone the `openpi` repo +3. Run the following in your environment: +``` +git submodule update --init --recursive +GIT_LFS_SKIP_SMUDGE=1 uv sync +GIT_LFS_SKIP_SMUDGE=1 uv pip install -e . +``` +4. From the mstar repo, run: +``` +pip install gsutil +python benchmark/download_pi05_ckpt.py +mkdir +mv /home/$USER/.cache/openpi/* +``` +5. Start the server with: +``` +uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi05_droid --policy.dir=/openpi-assets/checkpoints/pi05_droid +``` diff --git a/benchmark/request.py b/benchmark/request.py index 009646c4..91219f23 100644 --- a/benchmark/request.py +++ b/benchmark/request.py @@ -14,7 +14,7 @@ import aiohttp import numpy as np -from benchmark.base import Bagel, Model, Orpheus, RequestType, Status +from benchmark.base import Bagel, Model, Orpheus, Pi05, RequestType, Status from benchmark.utils import _write_wav @@ -604,6 +604,11 @@ class RequestInput: # All paths are uploaded as separate "files" form fields alongside image_path. extra_image_paths: list[str] = field(default_factory=list) + # Pre-decoded ".npy" uploads (the "numpy" modality): paths to raw uint8 + # arrays the server np.loads in memory (no disk, no decode). Used by pi0.5 + # (resized 224x224 camera frames); each path is one camera. + _numpy_paths: list[str] = field(default_factory=list) + # Per-request model_kwargs merged into the JSON payload at send time. # Use this for robotics-specific fields: robot_state, actions, states, # rollout_horizon, etc. @@ -620,6 +625,7 @@ class RequestInput: _audio_b64: Optional[str] = field(default=None, repr=False) _video_b64: Optional[str] = field(default=None, repr=False) _extra_image_bytes: list[bytes] = field(default_factory=list, repr=False) + _numpy_bytes: list[bytes] = field(default_factory=list, repr=False) def __post_init__(self): if self.image_path and self._image_bytes is None: @@ -633,6 +639,8 @@ def __post_init__(self): self._video_b64 = base64.b64encode(self._video_bytes).decode() if self.extra_image_paths and not self._extra_image_bytes: self._extra_image_bytes = [Path(p).read_bytes() for p in self.extra_image_paths] + if self._numpy_paths and not self._numpy_bytes: + self._numpy_bytes = [Path(p).read_bytes() for p in self._numpy_paths] def get_all_filepaths(self) -> dict[str, str]: res = {} @@ -721,7 +729,9 @@ async def send_request( input_mod = req_type.get_input_modalities() if "," in input_mod or input_mod not in ("text",): # TODO: if a request does not have text as an input modality, this must be revisited - form.add_field("input_modalities", ",".join([input_mod, "text"])) + if "text" not in input_mod: + input_mod = ",".join([input_mod, "text"]) + form.add_field("input_modalities", input_mod) for modality in req_input.get_all_filepaths(): file_content = req_input.get_bytes(modality) @@ -739,7 +749,16 @@ async def send_request( "files", content, filename=os.path.basename(path), - content_type="image/png", + content_type="application/octet-stream", + ) + # Pre-decoded ".npy" uploads (numpy modality): the server keeps these + # in memory and np.loads them — no disk, no decode (pi0.5 cameras). + for path, content in zip(req_input._numpy_paths, req_input._numpy_bytes, strict=True): + form.add_field( + "files", + content, + filename=os.path.basename(path), + content_type="application/octet-stream", ) metrics.start_time = time.monotonic() @@ -1652,3 +1671,142 @@ async def _send_request_audio_speech( else: metrics.record_completion() return metrics + + +# --------------------------------------------------------------------------- +# openpi: call their own api server +# --------------------------------------------------------------------------- +def _build_obs(req_input: RequestInput) -> dict: + """Map our DROIDDataset RequestInput → openpi DroidInputs dict. + + DROIDDataset emits the camera frames as ``_numpy_bytes`` — already-decoded, + letterboxed 224x224 uint8 arrays (CxHxW), the same bytes mstar receives, so + both systems get identical input. openpi wants (H,W,3) uint8, so we just + transpose; no decode/resize needed here. cam0 → exterior_image_1_left, + cam1 → wrist_image_left. The 32-dim DROID state holds joint positions in + [:7] and gripper in [7]; the rest is padding we ignore. + + NOTE: lerobot/droid_100 ships no gripper signal, so state[7:8] is always + padding 0.0 — actions are not semantically valid for either system. Fine + here: the benchmark measures latency (identical tensor shapes => identical + compute), not action quality. + + openpi DroidInputs (droid_policy.py:make_droid_example) expects: + observation/exterior_image_1_left : (H,W,3) uint8 + observation/wrist_image_left : (H,W,3) uint8 + observation/joint_position : (7,) float + observation/gripper_position : (1,) float + prompt : str + """ + import io + + import numpy as np + + state = np.asarray(req_input.model_kwargs.get("robot_state", []), dtype=np.float32) + if state.size < 8: + state = np.pad(state, (0, 8 - state.size)) + + imgs = [np.load(io.BytesIO(b)).transpose(1, 2, 0) # CxHxW uint8 -> HxWxC + for b in req_input._numpy_bytes] + base_img = imgs[0] + wrist_img = imgs[1] if len(imgs) > 1 else imgs[0] + + return { + "observation/exterior_image_1_left": base_img, + "observation/wrist_image_left": wrist_img, + "observation/joint_position": state[:7], + "observation/gripper_position": state[7:8], + "prompt": req_input.prompt or "manipulate the object", + } + +class OpenPi(InferenceSystem): + async def send_request( + self, + session: aiohttp.ClientSession, + req_input: RequestInput, + base_url: str, + request_id: int, + model: Model, + additional_model_kwargs: dict = {}, + ) -> RequestMetrics: + assert isinstance(model, Pi05), "openpi only supports Pi05 models" + assert req_input.req_type == RequestType.VLA, "openpi only supports VLA requests" + + import numpy as np + from openpi_client import msgpack_numpy + metrics = RequestMetrics( + request_id=request_id, + type=req_input.req_type, + expected_output_modalities=["action"], + ) + + # base_url is expected to be an http(s) URL for consistency with the rest + # of the harness; convert to ws(s) for the websocket handshake. + ws_url = base_url + if ws_url.startswith("http://"): + ws_url = "ws://" + ws_url[len("http://"):] + elif ws_url.startswith("https://"): + ws_url = "wss://" + ws_url[len("https://"):] + + # Build the observation. _build_obs gives us full-res images + a + # 32-dim DROID state; openpi expects 224x224 uint8 images and + # separate joint/gripper vectors, which _build_obs already provides. + obs = _build_obs(req_input) + packer = msgpack_numpy.Packer() + payload = packer.pack(obs) + + try: + metrics.start_time = time.monotonic() + async with session.ws_connect( + ws_url, + max_msg_size=0, # no limit; action chunks are small but obs is large + compress=0, # match the openpi client (compression=None) + timeout=aiohttp.ClientWSTimeout(ws_close=30), + ) as ws: + # Server sends metadata as the first message on connect. + # Drain it; we don't need it for benchmarking, but we MUST read + # it before sending or the server's send buffer can stall. + metadata_msg = await ws.receive() + if metadata_msg.type != aiohttp.WSMsgType.BINARY: + raise RuntimeError( + f"Expected binary metadata frame, got {metadata_msg.type}: " + f"{metadata_msg.data!r}" + ) + _ = msgpack_numpy.unpackb(metadata_msg.data) + + # Send observation, await action chunk. + await ws.send_bytes(payload) + response_msg = await ws.receive() + + if response_msg.type == aiohttp.WSMsgType.TEXT: + # The openpi server signals errors by sending a string. + raise RuntimeError(f"Error in inference server:\n{response_msg.data}") + if response_msg.type != aiohttp.WSMsgType.BINARY: + raise RuntimeError( + f"Unexpected ws frame type {response_msg.type}: " + f"{response_msg.data!r}" + ) + + arrival_time = time.monotonic() + response = msgpack_numpy.unpackb(response_msg.data) + action_chunk = response["actions"] # (action_horizon, action_dim) + + # One-shot output: the entire action chunk arrives at once, + # so TTFT == E2E. Encode the chunk as a single output unit. + # n_tokens = action_horizon so throughput numbers are + # in "actions/sec" if the metrics layer divides by n_tokens. + action_bytes = np.asarray(action_chunk, dtype=np.float32).tobytes() + metrics.record_output_chunk( + modality="action", + data_b64=base64.b64encode(action_bytes), + arrival_time=arrival_time, + n_tokens=int(action_chunk.shape[0]), + ) + + except Exception as e: + metrics.record_error(str(e)) + else: + metrics.record_completion() + + return metrics + diff --git a/benchmark/runner.py b/benchmark/runner.py index c7544c49..fb4e9883 100644 --- a/benchmark/runner.py +++ b/benchmark/runner.py @@ -24,6 +24,7 @@ from benchmark.request import ( AggregateMetrics, InferenceSystem, + OpenPi, OursOpenAI, OurSystem, RequestInput, @@ -52,6 +53,7 @@ class InferenceSystemType(Enum): VLLM_OMNI = "vllm_omni" VOX_SERVE = "vox_serve" SGLANG_OMNI = "sglang_omni" + OPENPI = "openpi" def instantiate(self) -> InferenceSystem: if self == InferenceSystemType.OURS: @@ -64,6 +66,10 @@ def instantiate(self) -> InferenceSystem: return VoxServe() elif self == InferenceSystemType.SGLANG_OMNI: return SGLangOmni() + elif self == InferenceSystemType.OPENPI: + return OpenPi() + else: + raise NotImplementedError("Unknown inference system", self) class ProfilingType(Enum): diff --git a/configs/pi05.yaml b/configs/pi05.yaml index e7aaaf09..85202fe7 100644 --- a/configs/pi05.yaml +++ b/configs/pi05.yaml @@ -8,6 +8,7 @@ node_groups: - 0 - node_names: - - LLM + - paligemma_LLM + - action_expert_LLM ranks: - 0 diff --git a/configs/pi05_droid.yaml b/configs/pi05_droid.yaml index 74b504b4..0d750a06 100644 --- a/configs/pi05_droid.yaml +++ b/configs/pi05_droid.yaml @@ -18,11 +18,13 @@ max_seq_len: 2048 # post-compute Python, so both systems do identical 32-dim work. # # CUDA-graph note: action_horizon and action_dim are both baked into the -# graph captures (see Pi05LLMSubmodule.get_cuda_graph_configs in -# mstar/model/pi05/submodules.py:325-329). They MUST be set at server-init -# time, never per-request — that's what this yaml override is for. +# graph captures (see Pi05ActionExpertSubmodule.get_cuda_graph_configs +# They MUST be set at server-init time, never per-request — that's what +# this yaml override is for. + model_kwargs: action_horizon: 15 + num_cameras: 2 node_groups: - node_names: @@ -31,6 +33,7 @@ node_groups: - 0 - node_names: - - LLM + - paligemma_LLM + - action_expert_LLM ranks: - 0 diff --git a/mstar/api_server/_timing.py b/mstar/api_server/_timing.py new file mode 100644 index 00000000..811ee0ba --- /dev/null +++ b/mstar/api_server/_timing.py @@ -0,0 +1,20 @@ +"""Lightweight, env-gated timing prints shared by the API server and data worker. + +Enabled with ``MSTAR_TIMING=1`` (anything other than unset/``0``/``false``). +``perf_counter`` is process-wide monotonic, so timestamps stamped in the +API-server handler thread and read in the data-worker thread are directly +comparable — that's how queue-wait (polling) latency is separated from actual +work in the [API-TIMING]/[DW-TIMING] brackets. +""" +import os + +TIMING_ENABLED = os.environ.get("MSTAR_TIMING", "") not in ("", "0", "false") + + +def make_tlog(prefix: str): + """Return a ``tlog(msg)`` that prints ``[] `` when enabled.""" + def _tlog(msg: str) -> None: + if TIMING_ENABLED: + print(f"[{prefix}] {msg}", flush=True) + + return _tlog diff --git a/mstar/api_server/data_worker.py b/mstar/api_server/data_worker.py index c7dabb7b..67731b78 100644 --- a/mstar/api_server/data_worker.py +++ b/mstar/api_server/data_worker.py @@ -1,10 +1,12 @@ +import io import logging import queue import threading import time +import numpy as np import torch from mstar.graph.loop_indices import NestedLoopIndices @@ -15,6 +17,8 @@ except (ImportError, RuntimeError, OSError): VideoDecoder = None +from mstar.api_server._timing import TIMING_ENABLED as _TIMING +from mstar.api_server._timing import make_tlog from mstar.api_server.request_types import PreprocessInput, ResultChunk, ResultTensors from mstar.communication.communicator import CommProtocol, ZMQCommunicator from mstar.communication.tensors import NameToTensorList, create_tensor_communication_manager @@ -30,6 +34,10 @@ logger = logging.getLogger(__name__) +# See mstar.api_server._timing; env-gated [DW-TIMING] prints (MSTAR_TIMING=1) +# that pair with the [API-TIMING] prints to split queue-wait from actual work. +_tlog = make_tlog("DW-TIMING") + def _preprocess_loop(**kwargs): worker = PreprocessWorkerThread(**kwargs) @@ -75,6 +83,7 @@ def __init__( self.thread.start() def new_request(self, input: PreprocessInput): + input._t_enqueue = time.perf_counter() # for queue-wait timing self.output_loop_idxs[input.request_id] = {} self.per_request_reading_tensors[input.request_id] = 0 self.request_input_queue.put(input) @@ -157,6 +166,7 @@ def __init__( self.model = model self.tensor_uuid_to_metadata_per_request = {} + self._t_read_start: dict[str, float] = {} # request_id -> read-start time self.communicator = ZMQCommunicator( my_id="api_server_preprocess_worker", @@ -176,6 +186,8 @@ def __init__( def _process_input( self, input: PreprocessInput ): + _t0 = time.perf_counter() + _enq = getattr(input, "_t_enqueue", None) tensors: NameToTensorList = {} input_metadata = {} @@ -207,6 +219,18 @@ def _process_input( tensors[key].append(out.data) input_metadata[key].append(out.metadata) + # ".npy" uploads (modality "numpy") are kept in memory and np.load'd + # here as "raw_inputs"; the model maps them in process_prompt. + if input.numpy_bytes: + tensors["raw_inputs"] = [] + input_metadata["raw_inputs"] = [] + for blob in input.numpy_bytes: + tensors["raw_inputs"].append( + torch.from_numpy(np.load(io.BytesIO(blob))).to(self.device) + ) + input_metadata["raw_inputs"].append({}) + + _t_load = time.perf_counter() # media decode (load_image/audio/video) done # Then, tokenize the prompt and let the model augment/transform the # tensors dict (e.g., Qwen3-Omni needs to compute pixel_values, @@ -231,6 +255,8 @@ def _process_input( list(byte_data), dtype=torch.uint8, device=self.device )] + _t_prompt = time.perf_counter() # tokenization / process_prompt done + initial_signals = self.tensor_manager.store_and_return_tensor_info( request_id=input.request_id, tensors=tensors # dict(modality_input: list[tensors]) @@ -248,6 +274,8 @@ def _process_input( input.request_id, uuid, persist=True ) + _t_store = time.perf_counter() # tensor store/register/persist done + msg = ConductorMessage( message_type=ConductorMessageType.NEW_REQUEST, body=NewRequestConductor( @@ -260,10 +288,26 @@ def _process_input( ), ) self.communicator.send("conductor", msg) + if _TIMING: + _t_send = time.perf_counter() + _qwait = (_t0 - _enq) * 1e3 if _enq is not None else -1.0 + _imgs = tensors.get("image_inputs") or [] + _img_shape = tuple(_imgs[0].shape) if _imgs else None + _tlog( + f"{input.request_id[:8]} INPUT " + f"img={_img_shape}x{len(_imgs)} " # decoded shape x count (decode cost driver) + f"qwait={_qwait:.2f} " # enqueue->dequeue (polling) + f"load={(_t_load - _t0) * 1e3:.2f} " # media decode + f"prompt={(_t_prompt - _t_load) * 1e3:.2f} " # tokenize + f"store={(_t_store - _t_prompt) * 1e3:.2f} " # tensor store/register + f"send={(_t_send - _t_store) * 1e3:.2f} " # zmq send to conductor + f"total={(_t_send - _t0) * 1e3:.2f}ms" + ) def _read_result_tensor( self, result: ResultTensors ): + self._t_read_start[result.request_id] = time.perf_counter() result.graph_edge.name = f"{result.modality}_output" self.tensor_manager.start_read_tensors( request_id=result.request_id, @@ -279,18 +323,31 @@ def _process_read_tensors(self): did_work = False for request_id, graph_edges in self.tensor_manager.get_ready_tensors().items(): did_work = True + _t_ready = time.perf_counter() # tensor became ready (RDMA read done) + _read_start = self._t_read_start.pop(request_id, None) for graph_edge in graph_edges: modality = graph_edge.name.replace("_output", "") for tensor_info in graph_edge.tensor_info: logger.debug("Reading in OUTPUT tensor %s with uuid %s", graph_edge.name, tensor_info.uuid) + _t_a = time.perf_counter() tensor = self.tensor_manager.get_tensor( request_id=request_id, uuid=tensor_info.uuid ) + _t_get = time.perf_counter() postprocessed = self.model.postprocess( tensor, modality ) + _t_post = time.perf_counter() + if _TIMING: + _rw = (_t_ready - _read_start) * 1e3 if _read_start else -1.0 + _tlog( + f"{request_id[:8]} OUTPUT " + f"read_wait={_rw:.2f} " # start_read -> ready (RDMA + polling) + f"get={(_t_get - _t_a) * 1e3:.2f} " # fetch tensor handle + f"post={(_t_post - _t_get) * 1e3:.2f}ms" # model.postprocess + ) chunk_metadata = self.tensor_uuid_to_metadata_per_request[request_id][ tensor_info.uuid] or {} @@ -302,12 +359,14 @@ def _process_read_tensors(self): "sample_rate": self.model.get_output_sample_rate("audio"), } - self.out_queue.put(ResultChunk( + _chunk = ResultChunk( request_id=request_id, modality=modality, data=postprocessed, metadata=chunk_metadata, - )) + ) + _chunk._t_outqueue = time.perf_counter() + self.out_queue.put(_chunk) del self.tensor_uuid_to_metadata_per_request[request_id][ tensor_info.uuid] self.tensor_manager.dereference( diff --git a/mstar/api_server/entrypoint.py b/mstar/api_server/entrypoint.py index d71611b2..4ee164d0 100644 --- a/mstar/api_server/entrypoint.py +++ b/mstar/api_server/entrypoint.py @@ -15,11 +15,13 @@ from typing import Optional import uvicorn -from fastapi import FastAPI, File, Form, HTTPException, UploadFile +from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse from starlette.concurrency import run_in_threadpool +from mstar.api_server._timing import TIMING_ENABLED as _TIMING +from mstar.api_server._timing import make_tlog from mstar.api_server.data_worker import PreprocessWorker from mstar.api_server.request_types import APIServerMessage, PreprocessInput, ResultChunk from mstar.communication.communicator import CommProtocol, ZMQCommunicator @@ -28,7 +30,15 @@ logger = logging.getLogger(__name__) -SUPPORTED_MODALITIES = frozenset({"text", "image", "audio", "video", "action", "scalar", "tensor"}) +# See mstar.api_server._timing; env-gated [API-TIMING] prints (MSTAR_TIMING=1) +# that pair with the [DW-TIMING] prints to split HTTP/handler overhead from +# data-worker work. +_tlog = make_tlog("API-TIMING") + + +SUPPORTED_MODALITIES = frozenset( + {"text", "image", "audio", "video", "action", "scalar", "tensor", "numpy"} +) # Extension-based modality detection for uploaded files. _EXT_TO_MODALITY: dict[str, str] = {} @@ -36,6 +46,7 @@ "image": (".png", ".jpg", ".jpeg", ".webp", ".bmp", ".tiff", ".gif"), "audio": (".wav", ".mp3", ".flac", ".ogg", ".m4a", ".aac"), "video": (".mp4", ".avi", ".mov", ".mkv", ".webm"), + "numpy": (".npy",) }.items(): for _ext in _exts: _EXT_TO_MODALITY[_ext] = _mod @@ -224,6 +235,7 @@ def submit_request( *, text: str | None = None, file_paths: dict[str, list[str]] | None = None, + numpy_bytes: list[bytes] | None = None, input_modalities: list[str], output_modalities: list[str], model_kwargs: dict | None = None, @@ -254,12 +266,14 @@ def submit_request( "input_modalities": input_modalities, "output_modalities": output_modalities, "final_outputs": {}, + "_t_submit": time.perf_counter(), # for end-to-end wait timing } self.preprocess_worker.new_request(PreprocessInput( request_id=request_id, text=text, file_paths=file_paths, + numpy_bytes=numpy_bytes, input_modalities=input_modalities, output_modalities=output_modalities, model_kwargs=model_kwargs @@ -335,6 +349,14 @@ def _process_messages(self) -> None: result_chunk.modality, result_chunk.request_id ) rid = result_chunk.request_id + if _TIMING: + _oq = getattr(result_chunk, "_t_outqueue", None) + if _oq is not None: + _tlog( + f"{rid[:8]} CHUNK " + # out_queue.put -> picked up here (output polling hop) + f"outq_wait={(time.perf_counter() - _oq) * 1e3:.2f}ms" + ) with self.request_lock: self.pending_requests[rid]["chunks"].append( result_chunk @@ -360,6 +382,10 @@ async def iter_result_chunks(self, request_id: str): pre-serialized line). """ start = time.time() + with self.request_lock: + _req0 = self.pending_requests.get(request_id) + _t_submit = _req0["_t_submit"] if _req0 else None + _t_first = None while True: if time.time() - start > self.timeout_seconds: with self.request_lock: @@ -380,9 +406,19 @@ async def iter_result_chunks(self, request_id: str): done = True for chunk in new_chunks: + if _t_first is None: + _t_first = time.perf_counter() yield chunk if done: + if _TIMING and _t_submit is not None: + _now = time.perf_counter() + _tlog( + f"{request_id[:8]} STREAM " + # submit -> first chunk delivered (full worker round-trip) + f"ttfc={(_t_first - _t_submit) * 1e3 if _t_first else -1:.2f} " + f"total={(_now - _t_submit) * 1e3:.2f}ms" # submit -> done + ) logger.info("Async stream results received finish for %s", request_id) # flush remaining remaining: list[ResultChunk] = [] @@ -460,6 +496,18 @@ def cleanup(self) -> None: allow_headers=["*"], ) + +@app.middleware("http") +async def _stamp_recv_time(request: Request, call_next): + # Stamp ASGI request arrival. The gap to the handler body (_t_in) covers + # routing + multipart form parsing (FastAPI reads the upload bodies while + # resolving the File()/Form() params, before the handler runs) — that's the + # HTTP-side overhead not visible in the [DW-TIMING]/STREAM brackets. + if _TIMING: + request.state._t_recv = time.perf_counter() + return await call_next(request) + + api_server: APIServer | None = None # Mount the OpenAI-compatible routes (/v1/*) alongside the native /generate. @@ -472,6 +520,7 @@ def cleanup(self) -> None: @app.post("/generate") async def generate( + request: Request, text: Optional[str] = Form(None), files: Optional[list[UploadFile]] = File(None), input_modalities: Optional[str] = Form(None), @@ -500,10 +549,20 @@ async def generate( if api_server is None: raise HTTPException(status_code=503, detail="Server not ready") + _t_in = time.perf_counter() + if _TIMING: + _recv = getattr(request.state, "_t_recv", None) + if _recv is not None: + # ASGI receive -> handler body = routing + multipart parse (HTTP-side) + _tlog(f"PREHDLR parse={(_t_in - _recv) * 1e3:.2f}ms") out_mods = [m.strip() for m in output_modalities.split(",") if m.strip()] # --- save uploaded files, grouped by modality ---------------- + # The "numpy" modality (.npy) is kept in memory and np.load'd by the data + # worker; image/audio/video are written to disk so their decoders work from + # a file (PNG/mp4 decode prefers a path). file_paths: dict[str, list[str]] = {} + numpy_bytes: list[bytes] = [] if files: for f in files: modality = _detect_modality(f.filename or "") @@ -512,9 +571,12 @@ async def generate( status_code=400, detail=f"Cannot determine modality for file: {f.filename}", ) + content = await f.read() + if modality == "numpy": + numpy_bytes.append(content) + continue save_name = f"{uuid.uuid4()}_{f.filename}" save_path = api_server.upload_dir / save_name - content = await f.read() await run_in_threadpool(save_path.write_bytes, content) file_paths.setdefault(modality, []).append(str(save_path)) @@ -524,21 +586,33 @@ async def generate( else: in_mods: list[str] = [] in_mods.extend(file_paths.keys()) + # ".npy" uploads bypass file_paths (kept in memory as numpy_bytes), so + # add their "numpy" modality explicitly or auto-detect would drop it. + if numpy_bytes: + in_mods.append("numpy") if text: in_mods.append("text") parsed_kwargs = json.loads(model_kwargs) if model_kwargs else None + _t_files = time.perf_counter() # multipart read + disk save done try: request_id = api_server.submit_request( text=text, file_paths=file_paths or None, + numpy_bytes=numpy_bytes or None, input_modalities=in_mods, output_modalities=out_mods, model_kwargs=parsed_kwargs, streaming=streaming, request_id=request_id, ) + if _TIMING: + _tlog( + f"{request_id[:8]} HANDLER " + f"files={(_t_files - _t_in) * 1e3:.2f} " # multipart read + disk write + f"submit={(time.perf_counter() - _t_files) * 1e3:.2f}ms" # submit_request + ) if streaming: return StreamingResponse( @@ -550,12 +624,20 @@ async def generate( chunks = await run_in_threadpool( api_server.collect_results, request_id ) + _t_results = time.perf_counter() outputs: dict[str, list[dict]] = {} for chunk in chunks: outputs.setdefault(chunk.modality, []).append({ "data": base64.b64encode(chunk.data).decode("ascii"), "metadata": chunk.metadata, }) + if _TIMING: + _tlog( + f"{request_id[:8]} BLOCKING " + f"wait={(_t_results - _t_files) * 1e3:.2f} " # submit -> all results in + f"serialize={(time.perf_counter() - _t_results) * 1e3:.2f} " # b64 + json + f"total={(time.perf_counter() - _t_in) * 1e3:.2f}ms" + ) return JSONResponse({ "request_id": request_id, "outputs": outputs, diff --git a/mstar/api_server/request_types.py b/mstar/api_server/request_types.py index 23c5c899..e7bfa415 100644 --- a/mstar/api_server/request_types.py +++ b/mstar/api_server/request_types.py @@ -49,3 +49,8 @@ class PreprocessInput: input_modalities: list[str] output_modalities: list[str] model_kwargs: dict + + # In-memory uploads for the "numpy" modality (.npy): the bytes are NOT + # written to disk (unlike images/audio/video), so the data worker np.loads + # them directly. Each entry is one .npy blob (e.g. one camera frame). + numpy_bytes: list[bytes] | None = None diff --git a/mstar/engine/kv_store.py b/mstar/engine/kv_store.py index 380d5755..b2975429 100644 --- a/mstar/engine/kv_store.py +++ b/mstar/engine/kv_store.py @@ -604,6 +604,10 @@ def add_request(self, request_id: str, labels: list[str]=None): } def remove_request(self, request_id: str): + if request_id not in self.request_states: + # This request has already been removed; e.g., if we have colocated + # nodes sharing a KV cache + return for label in self.request_states[request_id]: self.wait_for_retrieves(request_id, label) with self._lock: diff --git a/mstar/model/base.py b/mstar/model/base.py index 54a7e90d..892b22e1 100644 --- a/mstar/model/base.py +++ b/mstar/model/base.py @@ -379,7 +379,11 @@ def process_prompt( def load_image(self, filepath: str, device: str) -> TensorAndMetadata: import torchvision - img = torchvision.io.decode_image(filepath).to(device) # uint8 CxHxW + with open(filepath, "rb") as f: + raw = f.read() + img = torchvision.io.decode_image( + torch.frombuffer(bytearray(raw), dtype=torch.uint8) + ).to(device) # uint8 CxHxW img = img.float() / 255.0 return TensorAndMetadata(img) diff --git a/mstar/model/pi05/components/siglip.py b/mstar/model/pi05/components/siglip.py index d017558c..8c512441 100644 --- a/mstar/model/pi05/components/siglip.py +++ b/mstar/model/pi05/components/siglip.py @@ -1,45 +1,197 @@ -"""SigLIP vision encoder for Pi0.5. +"""SigLIP vision encoder for Pi0.5 (native mstar port). -Thin wrapper around the HuggingFace SiglipVisionModel that produces a fixed -number of image tokens (default 256) per camera image at the resolution Pi0.5 -expects (224x224). A learned linear projection maps SigLIP's hidden dim to the -LLM hidden dim so the resulting tokens can be concatenated with PaliGemma -language token embeddings. +Ports the inference path of HuggingFace's ``SiglipVisionModel`` (So400m/14) +into mstar so we own the code and can fuse projections. Differences from the +transformers implementation: + + * **Fused QKV** — the three ``q/k/v_proj`` GEMMs are merged into one + ``QKVParallelLinear`` (loaded from the separate checkpoint keys via the + ``q/k/v`` stacked-param rules; see ``SIGLIP_STACKED_PARAMS``). + * **SDPA attention** — full bidirectional ``scaled_dot_product_attention``. + We do NOT use flash-attn or the Triton ``sliding_window_attn`` here: the + encoder runs in **fp32** (Pi05VitEncoderSubmodule forces it, since bf16 + rounding over 27 layers perturbs the actions) and flash-attn is fp16/bf16 + only, while the Triton kernel is causal-only and rejects head_dim=72. + * **Inference-only** — all weight-init, gradient-checkpointing, the text + tower, pooling head, and variable-resolution position interpolation are + dropped. Images are a fixed 224x224 → 256 patches. + +Only ``last_hidden_state`` is consumed downstream (``vision_use_head=False`` +in the original), so the pooling head is omitted entirely. """ +from __future__ import annotations import torch +import torch.nn.functional as F from torch import nn -from transformers import SiglipVisionConfig, SiglipVisionModel +from mstar.distributed.communication import TPCommGroup +from mstar.model.components.distributed.linear import QKVParallelLinear +from mstar.model.loader import StackedParamRule from mstar.model.pi05.config import Pi05Config +# SigLIP architectural constants not carried on Pi05Config. These match +# HF ``SiglipVisionConfig`` defaults for the So400m checkpoint. +_LAYER_NORM_EPS = 1e-6 -class Pi05SiglipEncoder(nn.Module): - """SigLIP image encoder + linear connector to the LLM hidden size.""" +# Route the checkpoint's separate q/k/v projection keys into the fused +# ``qkv_proj`` parameter. Consumed by ``load_hf_weights`` when loading the +# encoder (the SigLIP MLP is ungated, so there are no gate/up rules). +SIGLIP_STACKED_PARAMS: list[StackedParamRule] = [ + StackedParamRule(".qkv_proj", ".q_proj", "q"), + StackedParamRule(".qkv_proj", ".k_proj", "k"), + StackedParamRule(".qkv_proj", ".v_proj", "v"), +] + + +class _SiglipVisionEmbeddings(nn.Module): + """Conv patch embedding + learned position embedding. + + Fixed-resolution only: 224x224 input → a 16x16 grid of 14px patches → + 256 tokens. Position ids are computed inline (no buffer) so the module + has no non-persistent state to re-materialize after ``to_empty``. + """ def __init__(self, config: Pi05Config): super().__init__() - self.config = config + self.embed_dim = config.vit_hidden_size + self.patch_embedding = nn.Conv2d( + in_channels=3, + out_channels=self.embed_dim, + kernel_size=config.vit_patch_size, + stride=config.vit_patch_size, + padding="valid", + ) + self.num_positions = (config.vit_image_size // config.vit_patch_size) ** 2 + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + # pixel_values: (N, 3, H, W) -> patches (N, embed_dim, gh, gw). + patch_embeds = self.patch_embedding(pixel_values.to(self.patch_embedding.weight.dtype)) + embeddings = patch_embeds.flatten(2).transpose(1, 2) # (N, num_patches, embed_dim) + positions = torch.arange(self.num_positions, device=embeddings.device) + return embeddings + self.position_embedding(positions) + + +class _SiglipAttention(nn.Module): + """Bidirectional multi-head self-attention with a fused QKV projection. + + Full MHA (no GQA): num_kv_heads == num_heads. Attention is computed + per-image over its own 256 patches (the batch dim isolates images), so + no attention mask is needed. + """ - siglip_cfg = SiglipVisionConfig( - hidden_size=config.vit_hidden_size, - intermediate_size=config.vit_intermediate_size, - num_hidden_layers=config.vit_num_layers, - num_attention_heads=config.vit_num_heads, - num_channels=3, - image_size=config.vit_image_size, - patch_size=config.vit_patch_size, - # Pi0.5 / lerobot's PaliGemma SigLIP does NOT use the pooling - # head — only ``last_hidden_state`` is consumed downstream by the - # multi_modal_projector. Disabling the head matches the - # production checkpoint key set (no ``vision_model.head.*`` keys). - vision_use_head=False, + def __init__(self, config: Pi05Config): + super().__init__() + self.embed_dim = config.vit_hidden_size + self.num_heads = config.vit_num_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"vit_hidden_size {self.embed_dim} not divisible by " + f"vit_num_heads {self.num_heads}" + ) + self.scale = self.head_dim**-0.5 + + # Trivial (single-rank) comm group: reuses the TP-aware fused-QKV + # loader without any actual sharding. bias=True — SigLIP projects + # q/k/v with bias. + self.qkv_proj = QKVParallelLinear( + comm_group=TPCommGroup.trivial(), + hidden_size=self.embed_dim, + head_size=self.head_dim, + total_num_heads=self.num_heads, + total_num_kv_heads=self.num_heads, + bias=True, + ) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + n, seq_len, _ = hidden_states.shape + qkv = self.qkv_proj(hidden_states) # (N, seq, 3*embed_dim) + q, k, v = qkv.split([self.embed_dim, self.embed_dim, self.embed_dim], dim=-1) + + # (N, seq, embed) -> (N, heads, seq, head_dim) for SDPA. + def to_heads(x: torch.Tensor) -> torch.Tensor: + return x.view(n, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + + out = F.scaled_dot_product_attention( + to_heads(q), to_heads(k), to_heads(v), scale=self.scale, ) - self.vision_model = SiglipVisionModel(siglip_cfg) - self.connector = nn.Linear( - config.vit_hidden_size, config.hidden_size, bias=True + out = out.transpose(1, 2).reshape(n, seq_len, self.embed_dim) + return self.out_proj(out) + + +class _SiglipMLP(nn.Module): + """Ungated 2-layer MLP with gelu-tanh activation.""" + + def __init__(self, config: Pi05Config): + super().__init__() + self.fc1 = nn.Linear(config.vit_hidden_size, config.vit_intermediate_size) + self.activation_fn = nn.GELU(approximate="tanh") # gelu_pytorch_tanh + self.fc2 = nn.Linear(config.vit_intermediate_size, config.vit_hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.fc2(self.activation_fn(self.fc1(hidden_states))) + + +class _SiglipEncoderLayer(nn.Module): + """Pre-norm transformer block: ln1→attn→res, ln2→mlp→res.""" + + def __init__(self, config: Pi05Config): + super().__init__() + embed_dim = config.vit_hidden_size + self.layer_norm1 = nn.LayerNorm(embed_dim, eps=_LAYER_NORM_EPS) + self.self_attn = _SiglipAttention(config) + self.layer_norm2 = nn.LayerNorm(embed_dim, eps=_LAYER_NORM_EPS) + self.mlp = _SiglipMLP(config) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = hidden_states + self.self_attn(self.layer_norm1(hidden_states)) + hidden_states = hidden_states + self.mlp(self.layer_norm2(hidden_states)) + return hidden_states + + +class _SiglipEncoder(nn.Module): + """Stack of encoder layers. Named to match the ``encoder.layers.N`` + checkpoint key layout so weights load without per-layer remapping.""" + + def __init__(self, config: Pi05Config): + super().__init__() + self.layers = nn.ModuleList( + [_SiglipEncoderLayer(config) for _ in range(config.vit_num_layers)] ) + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + for layer in self.layers: + hidden_states = layer(hidden_states) + return hidden_states + + +class _SiglipVisionTransformer(nn.Module): + """Embeddings → encoder stack → final layer norm.""" + + def __init__(self, config: Pi05Config): + super().__init__() + self.embeddings = _SiglipVisionEmbeddings(config) + self.encoder = _SiglipEncoder(config) + self.post_layernorm = nn.LayerNorm(config.vit_hidden_size, eps=_LAYER_NORM_EPS) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + hidden_states = self.embeddings(pixel_values) + hidden_states = self.encoder(hidden_states) + return self.post_layernorm(hidden_states) + + +class Pi05SiglipEncoder(nn.Module): + """SigLIP image encoder + linear connector to the LLM hidden size.""" + + def __init__(self, config: Pi05Config): + super().__init__() + self.config = config + self.vision_model = _SiglipVisionTransformer(config) + self.connector = nn.Linear(config.vit_hidden_size, config.hidden_size, bias=True) + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: """Encode a batch of images into LLM-space tokens. @@ -50,7 +202,5 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: Returns: Tensor of shape ``(N, tokens_per_image, hidden_size)``. """ - outputs = self.vision_model(pixel_values=pixel_values) - # last_hidden_state: [N, num_patches, vit_hidden_size] - features = outputs.last_hidden_state + features = self.vision_model(pixel_values) # (N, num_patches, vit_hidden) return self.connector(features) diff --git a/mstar/model/pi05/pi05_model.py b/mstar/model/pi05/pi05_model.py index f1846730..56d9f248 100644 --- a/mstar/model/pi05/pi05_model.py +++ b/mstar/model/pi05/pi05_model.py @@ -40,7 +40,6 @@ GraphEdge, GraphNode, GraphSection, - Loop, Sequential, TensorPointerInfo, ) @@ -49,10 +48,10 @@ from mstar.model.loader import LLAMA_STACKED_PARAMS, load_hf_weights from mstar.model.pi05.components.action_expert import Pi05ActionExpert, Pi05TimeMLP from mstar.model.pi05.components.paligemma import Pi05PaliGemmaExpert -from mstar.model.pi05.components.siglip import Pi05SiglipEncoder +from mstar.model.pi05.components.siglip import SIGLIP_STACKED_PARAMS, Pi05SiglipEncoder from mstar.model.pi05.components.tokenization import Pi05Tokenizer from mstar.model.pi05.config import Pi05Config, load_pi05_config -from mstar.model.pi05.submodules import Pi05LLMSubmodule, Pi05ViTEncoderSubmodule +from mstar.model.pi05.submodules import Pi05ActionExpertSubmodule, Pi05PaligemmaSubmodule, Pi05ViTEncoderSubmodule from mstar.model.submodule_base import NodeSubmodule logger = logging.getLogger(__name__) @@ -91,40 +90,8 @@ def __init__( self.time_mlp = time_mlp -def _reset_non_persistent_buffers(module: nn.Module, device) -> None: - """Re-initialize non-persistent buffers like ``position_ids`` after a - ``meta + to_empty`` materialization. - - Modules constructed on the meta device skip ``post_init``, and - ``to_empty`` only allocates uninitialized storage for parameters and - buffers. Non-persistent buffers (registered with ``persistent=False``) - are not in the state_dict, so ``load_state_dict`` will not restore them - either — leaving them as garbage. The most common offender is HuggingFace - SigLIP's ``position_ids`` buffer (``register_buffer("position_ids", - arange(num_positions), persistent=False)``), which feeds the position - embedding lookup. If left as garbage int64 it produces wildly incorrect - image embeddings (off by the full norm of the position table). - - This walks the module tree and resets any sub-module that has a - ``position_ids`` buffer to the canonical ``arange(num_positions)``. - """ - with torch.no_grad(): - for sub in module.modules(): - pos = getattr(sub, "position_ids", None) - if isinstance(pos, torch.Tensor): - shape = pos.shape - num_positions = shape[-1] - pos.copy_( - torch.arange( - num_positions, device=pos.device, dtype=pos.dtype - ).expand(shape) - ) - - class Pi05Model(Model): """Pi0.5 vision-language-action model implementation.""" - - PREFILL_WALK = "prefill" ACTION_GEN_WALK = "action_gen" def __init__( @@ -365,12 +332,13 @@ def _extract_siglip_state_dict( if inner.startswith("vision_tower.vision_model."): # The lerobot key is # paligemma.model.vision_tower.vision_model. - # Pi05SiglipEncoder owns ``self.vision_model = SiglipVisionModel(...)``, - # and HF's SiglipVisionModel has its own inner ``.vision_model`` - # attribute, so the corresponding key is - # ``vision_model.vision_model.``. We replace - # ``vision_tower`` with ``vision_model`` to make that explicit. - out["vision_model." + inner.removeprefix("vision_tower.")] = tensor + # Pi05SiglipEncoder owns ``self.vision_model`` (our native + # _SiglipVisionTransformer) directly, so the matching key is + # ``vision_model.``. Stripping ``vision_tower.`` yields + # exactly that. The separate q/k/v_proj keys under + # `` = encoder.layers.N.self_attn.*`` are fused into + # ``qkv_proj`` by SIGLIP_STACKED_PARAMS at load time. + out[inner.removeprefix("vision_tower.")] = tensor elif inner.startswith("multi_modal_projector.linear."): sub = inner.removeprefix("multi_modal_projector.linear.") out[f"connector.{sub}"] = tensor @@ -387,68 +355,50 @@ def get_kv_cache_config(self) -> KVCacheConfig: head_dim=self.config.head_dim, max_seq_len=self.config.max_position_embeddings, num_qo_heads=self.config.num_qo_heads, + nodes=["paligemma_LLM", "action_expert_LLM"] )] def get_node_engine_types(self) -> dict[str, EngineType]: return { "vit_encoder": EngineType.STATELESS, - "LLM": EngineType.KV_CACHE, + "paligemma_LLM": EngineType.KV_CACHE, + "action_expert_LLM": EngineType.KV_CACHE, } def get_graph_walk_graphs(self) -> dict[str, GraphSection]: - # Pi0.5 encodes the robot state as a decimal-string suffix on the - # language prompt (e.g. "Task: pick up the block, State: 12 87 ...; - # \nAction: ") and tokenizes the whole thing with the PaliGemma - # tokenizer. So the model only ever sees a single "text_inputs" - # stream — there are no separate state-bin tokens. This matches - # lerobot's processor_pi05.Pi05PrepareStateTokenizerProcessorStep. - prefill = Sequential( + # NOTE: the full action generation flow loop is extremely short (total < 50ms), so + # we opt to have it as one node to reduce cuda graph startup, flashinfer planning, + # etc. overhead. Cache planning only needs to happen at the beginning of the flow + # loop, so this collapsed loop is valid. + action_gen = Sequential( [ GraphNode( name="vit_encoder", input_names=["image_inputs"], - outputs=[GraphEdge(next_node="LLM", name="img_emb")], + outputs=[GraphEdge(next_node="paligemma_LLM", name="img_emb")], ), GraphNode( - name="LLM", + name="paligemma_LLM", input_names=["img_emb", "text_inputs"], - outputs=[], + outputs=[ + GraphEdge(next_node="action_expert_LLM", name="action_expert_trigger") + ], ), + GraphNode( + name="action_expert_LLM", + input_names=["action_expert_trigger"], + outputs=[ + GraphEdge( + next_node=EMIT_TO_CLIENT, + name="actions", + output_modality="action", + ) + ], + ) ] ) - # NOTE: The Loop's terminal ``outputs`` are matched into the section's - # node outputs by **name** (see Loop._replace_outputs_for_final_iter - # in mstar/graph/base.py): on the final iteration, any section-output - # edge whose name matches a terminal output's name is replaced with - # the terminal version. This is the same convention BAGEL's image_gen - # uses (section returns ``latents`` looping back to LLM, terminal - # output is ``name="latents" → vae_decoder``). So our terminal output - # MUST be named ``noisy_actions`` to match the section's loop-back - # edge — the name is just a graph-internal key, while the actual - # client-facing modality bucket is determined by ``output_modality``. - action_gen = Loop( - section=GraphNode( - name="LLM", - input_names=["noisy_actions", "timestep_index"], - outputs=[ - GraphEdge(next_node="LLM", name="noisy_actions"), - GraphEdge(next_node="LLM", name="timestep_index"), - ], - ), - max_iters=self.config.num_flow_steps, - outputs=[ - GraphEdge( - next_node=EMIT_TO_CLIENT, - name="noisy_actions", - output_modality="action", - persist=True, - ), - ], - ) - return { - self.PREFILL_WALK: prefill, self.ACTION_GEN_WALK: action_gen, } @@ -477,6 +427,18 @@ def process_prompt( here so the resulting ``text_inputs`` stream matches the production format. """ + # A "numpy" upload arrives as "raw_inputs"; Pi0.5 treats it as an image input + # We append the raw_inputs onto the image_inputs, so the user can pass in both + # images and numpy arrays + tensors = kwargs.get("tensors") + if tensors is not None and "raw_inputs" in tensors: + tensors.setdefault("image_inputs", []).extend(tensors.pop("raw_inputs")) + input_metadata = kwargs.get("input_metadata") + if input_metadata is not None and "raw_inputs" in input_metadata: + input_metadata.setdefault("image_inputs", []).extend( + input_metadata.pop("raw_inputs") + ) + if self.tokenizer is None: # Tokenizer-less fallback used by structural unit tests. if prompt is not None: @@ -527,7 +489,7 @@ def get_initial_forward_pass_args( full_metadata = CurrentForwardConductorMetadata( input_modalities=input_modalities, output_modalities=output_modalities, - graph_walk=self.PREFILL_WALK, + graph_walk=self.ACTION_GEN_WALK, is_prefill=True, kwargs={}, ) @@ -538,7 +500,7 @@ def get_initial_forward_pass_args( edge.tensor_info = input_signals["image_inputs"] inputs.append(edge) if "text_inputs" in input_signals: - edge = GraphEdge(next_node="LLM", name="text_inputs") + edge = GraphEdge(next_node="paligemma_LLM", name="text_inputs") edge.tensor_info = input_signals["text_inputs"] inputs.append(edge) @@ -547,7 +509,6 @@ def get_initial_forward_pass_args( full_metadata=full_metadata, inputs=inputs, unpersist_tensors=unpersist_tensors, - step_metadata={"is_prefill": True}, ) def get_partition_forward_pass_args( @@ -557,29 +518,12 @@ def get_partition_forward_pass_args( persist_signals: dict[str, list[TensorPointerInfo]], incoming_connections: list[StreamingConnectionState] | None = None, ) -> ForwardPassArgs: - metadata = partition_metadata - request_done = False - inputs: list[GraphEdge] = [] - - if metadata.graph_walk == self.PREFILL_WALK: - metadata.is_prefill = False - metadata.graph_walk = self.ACTION_GEN_WALK - # Inputs for the first action_gen iteration are sampled inside the - # LLM submodule's preprocess (Gaussian noise + timestep_index=0). - inputs = [ - GraphEdge(next_node="LLM", name="noisy_actions"), - GraphEdge(next_node="LLM", name="timestep_index"), - ] - elif metadata.graph_walk == self.ACTION_GEN_WALK: - request_done = True - - unpersist_tensors = sum([inp.tensor_info for inp in inputs], start=[]) + # only one graph walk, so we're done return ForwardPassArgs( - full_metadata=metadata, - inputs=inputs, - unpersist_tensors=unpersist_tensors, - step_metadata={"is_prefill": metadata.is_prefill}, - request_done=request_done, + full_metadata=partition_metadata, + inputs=[], + unpersist_tensors=[], + request_done=True, ) # ------------------------------------------------------------------ @@ -605,11 +549,16 @@ def _create_submodule( return Pi05ViTEncoderSubmodule( encoder=self.siglip, config=self.config ) - if node_name == "LLM": + if node_name == "paligemma_LLM": self._init_llm_components(device) - return Pi05LLMSubmodule( + return Pi05PaligemmaSubmodule( embed_tokens=self.embed_tokens, paligemma=self.paligemma, + config=self.config, + ) + if node_name == "action_expert_LLM": + self._init_llm_components(device) + return Pi05ActionExpertSubmodule( action_expert=self.action_expert, action_in_proj=self.action_in_proj, action_out_proj=self.action_out_proj, @@ -634,29 +583,19 @@ def _init_vit_components(self, device: str): self.siglip = Pi05SiglipEncoder(self.config) if self.skip_weight_loading: self.siglip = self.siglip.to_empty(device=device) - _reset_non_persistent_buffers(self.siglip, device) return flat = self._ensure_lerobot_flat() self.siglip.to_empty(device=device) - # CRITICAL: HF's SiglipVisionEmbeddings registers ``position_ids`` as - # a NON-persistent buffer (persistent=False), so it's not in any - # state_dict. ``to_empty`` materializes it as uninitialized GPU - # memory, ``_init_weights`` is never called (we never go through - # post_init), and ``load_state_dict(strict=False)`` does not restore - # it. The result is garbage int64 indices feeding into - # ``position_embedding``, which corrupts every image embedding by - # ~the full norm of the position table. We must manually reset any - # non-persistent ``position_ids`` buffer with the canonical - # ``arange`` values before running the forward. - _reset_non_persistent_buffers(self.siglip, device) # The extracted bucket may contain stray pooling-head keys that # Pi05SiglipEncoder doesn't model (``vision_use_head=False``); # ``load_hf_weights`` silently ignores any key that has no matching # parameter in the target module, so the leftover keys are dropped # without needing an explicit ``strict=False`` switch. siglip_sd = self._extract_siglip_state_dict(flat) - load_hf_weights(self.siglip, siglip_sd.items()) + load_hf_weights( + self.siglip, siglip_sd.items(), stacked_params=SIGLIP_STACKED_PARAMS, + ) def _init_llm_components(self, device: str): if self.embed_tokens is not None: diff --git a/mstar/model/pi05/submodules.py b/mstar/model/pi05/submodules.py index 9b55af89..6a3f7574 100644 --- a/mstar/model/pi05/submodules.py +++ b/mstar/model/pi05/submodules.py @@ -1,12 +1,11 @@ """NodeSubmodule wrappers for the Pi0.5 model nodes. -Two submodules: - Pi05ViTEncoderSubmodule -- SigLIP vision encoder for camera images. - Pi05LLMSubmodule -- combined PaliGemma + action expert. Dispatches by - graph_walk between prefill (PaliGemma writes the - prefix KV cache) and action_gen (action expert - reads the frozen prefix KV cache and runs one - Euler flow-matching step). +Three submodules: + Pi05ViTEncoderSubmodule -- SigLIP vision encoder for camera images. + Pi05PaligemmaSubmodule -- PaliGemma prefix expert; prefills and writes the + prefix KV cache. + Pi05ActionExpertSubmodule -- action expert; reads the frozen prefix KV cache + and runs the Euler flow-matching denoising loop. """ import logging @@ -69,7 +68,7 @@ def to(self, *args, **kwargs): return result def _prepare_one(self, images: torch.Tensor) -> torch.Tensor: - """Resize one request's stack of camera images with aspect-preserving + """Resize one request's camera image(s) with aspect-preserving letterbox padding. Matches openpi's ``image_tools.resize_with_pad_torch`` exactly: @@ -152,16 +151,13 @@ def get_cuda_graph_configs(self, device: torch.device, tp_world_size: int = 1) - (num_cameras, 3, H, W). preprocess() stacks them to (bs, num_cameras, 3, H, W) so shape[0] == bs, satisfying StatelessCudaGraphRunner's leading-dim == bs requirement. - - compile=False because warmup() already applies torch.compile to - forward_batched; _capture_one captures the compiled callable directly. """ from mstar.engine.cuda_graph_config import BasicBatchedCudaGraphConfig H = W = self.config.vit_image_size num_cameras = self.config.num_cameras return [ BasicBatchedCudaGraphConfig( - capture_graph_walk="prefill", + capture_graph_walk="action_gen", single_request_inputs=ARNodeInputs( input_seq_len=0, # not used by StatelessCudaGraphRunner tensor_inputs={ @@ -172,7 +168,7 @@ def get_cuda_graph_configs(self, device: torch.device, tp_world_size: int = 1) - }, ), capture_batch_sizes=[1], - compile=False, + compile=False, # empircally does better than compile=True for now ) ] @@ -183,9 +179,14 @@ def prepare_inputs( inputs: NameToTensorList, **kwargs ) -> NodeInputs: - return NodeInputs(tensor_inputs={"pixel_values": self._prepare_one( - inputs["image_inputs"][0] - )}) + images = torch.cat([ + self._prepare_one(img) for img in inputs["image_inputs"] + ]) + # TODO: assert images.shape == (num_cameras, 3, H, W) once worker errors + # are surfaced. A wrong count silently broadcasts in the static CUDA + # graph; today a raised error is swallowed and the client hangs, so this + # needs prepare_inputs errors threaded engine -> conductor -> API server. + return NodeInputs(tensor_inputs={"pixel_values": images}) def preprocess( self, @@ -254,20 +255,10 @@ def forward_batched( } -class Pi05LLMSubmodule(ARNodeSubmodule): - """Combined PaliGemma prefix expert + action expert. - - Dispatches by graph_walk: - - ``prefill``: PaliGemma forwards over the prefix - ``[image_tokens, language_tokens, state_tokens]`` and - writes the KV cache. - - ``action_gen``: action expert runs one Euler step of flow-matching - denoising over the action suffix, attending to the - frozen prefix KV cache. The current ``noisy_actions`` - and ``timestep_index`` cycle through the loop via - loop-back graph edges; on the final iteration the - denoised action tensor is emitted as ``action_output``. - """ +class Pi05PaligemmaSubmodule(ARNodeSubmodule): + """PaliGemma prefix expert: forwards over the prefix + ``[image_tokens, language_tokens]`` and writes the KV cache that the + action expert later reads.""" # Parameter name fragments whose weights must stay in float32 even when # the rest of the model is bf16. Matches lerobot's @@ -283,12 +274,229 @@ class Pi05LLMSubmodule(ARNodeSubmodule): # For the default image size and a simple text prompt, one request is about 400 tokens PREFILL_TOKEN_BUCKETS = [512, 1024, 1800] # 2048 was giving OOM PREFILL_CAPTURE_BATCH_SIZES = [1, 2, 4] - ACTION_GEN_CAPTURE_BATCH_SIZES = [1, 2, 4] def __init__( self, embed_tokens: nn.Embedding, paligemma: Pi05PaliGemmaExpert, + config: Pi05Config, + ): + super().__init__() + self.embed_tokens = embed_tokens + self.paligemma = paligemma + self.config = config + # lerobot scales images by sqrt(H) but text by H: its + # embed_language_tokens routes through HF Gemma's + # GemmaTextScaledWordEmbedding, which already bakes in a sqrt(H) factor, + # so the effective text scale is sqrt(H)*sqrt(H) = H. Our plain + # nn.Embedding has no internal scale, so we apply the full H here. + # Mismatching makes the text prefix ~45x too small and corrupts context. + self._image_embed_scale = math.sqrt(config.hidden_size) + self._text_embed_scale = float(config.hidden_size) + + def to(self, *args, **kwargs): + """Apply standard ``to()`` then upcast norm parameters back to fp32. + + Matches lerobot's ``to_bfloat16_for_selected_params`` which keeps + ``input_layernorm``, ``post_attention_layernorm``, and ``model.norm`` + in float32 while the rest of the transformer runs in bfloat16. + """ + result = super().to(*args, **kwargs) + for name, param in result.named_parameters(): + if any(sel in name for sel in self._FLOAT32_PARAM_SELECTORS): + param.data = param.data.to(torch.float32) + return result + + def can_batch( + self, + batch: NodeBatch, + model_inputs: list[NodeInputs], + ) -> bool: + return True + + def get_needed_cache_labels( + self, + graph_walk: str, + per_request_info: dict[str, CurrentForwardPassInfo], + ) -> list[str] | None: + return ["main"] + + def _embed_tokens_scaled(self, ids: torch.Tensor) -> torch.Tensor: + emb = self.embed_tokens(ids) + return emb * self._text_embed_scale + + def get_cuda_graph_configs( + self, device: torch.device, tp_world_size: int = 1, + ) -> list[BasicBatchedCudaGraphConfig | FlashInferPackedCudaGraphConfig]: + prefill_packed = { + num_tokens: { + "prefix_embs": torch.zeros(num_tokens, self.config.hidden_size, device=device) + } + for num_tokens in self.PREFILL_TOKEN_BUCKETS + } + return [ + FlashInferPackedCudaGraphConfig( + capture_graph_walk="action_gen", + packed_seq_len_to_inputs=prefill_packed, + requires_cfg=False, + labels=["main"], + compile=True, + causal_attention=False, + capture_batch_sizes=self.PREFILL_CAPTURE_BATCH_SIZES, + ), + ] + + def prepare_inputs( + self, + graph_walk: str, + fwd_info: CurrentForwardPassInfo, + inputs: NameToTensorList, + **kwargs + ) -> ARNodeInputs: + return self._prepare_inputs_prefill( + inputs=inputs, + fwd_info=fwd_info, + ) + + def _prepare_inputs_prefill( + self, + inputs: NameToTensorList, + **kwargs + ) -> ARNodeInputs: + # Prefix layout [image_tokens, language_tokens]. Robot state is not a + # separate stream — process_prompt already appended it as a decimal + # suffix on the prompt. Image and text embeds use different scales (see + # __init__); applying them here is load-bearing. + img_emb = inputs["img_emb"][0] * self._image_embed_scale + text_ids = inputs["text_inputs"][0] + text_emb = self._embed_tokens_scaled(text_ids) + prefix_emb = torch.cat([img_emb, text_emb], dim=0) + seq_len = prefix_emb.shape[0] + + return ARNodeInputs(input_embeds=prefix_emb, input_seq_len=seq_len) + + + def preprocess( + self, + graph_walk: str, + engine_inputs: ModelInputsFromEngine, + inputs: list[ARNodeInputs], + ) -> dict[str, torch.Tensor | Any]: + + return self._preprocess_prefill( + inputs=inputs, + cache_manager=engine_inputs.cache_manager, + ) + + def _preprocess_prefill( + self, + inputs: list[ARNodeInputs], + cache_manager: BatchedCacheManager, + ) -> dict[str, torch.Tensor | Any]: + per_request_seqs = [inp.input_embeds for inp in inputs] + prefix_embs = torch.cat(per_request_seqs, dim=0) + seq_lens = [inp.input_seq_len for inp in inputs] + + # Bidirectional attention over the prefix; PaliGemma is a prefix-LM. + cache_manager.plan_attention( + seq_lens=seq_lens, is_causal=False, label="main", dtype=torch.bfloat16 + ) + cache_manager.plan_rope(seq_lens=seq_lens, pos_ids=None, label="main") + return {"prefix_embs": prefix_embs} + + # ------------------------------------------------------------------ + # forward + # ------------------------------------------------------------------ + def forward( + self, + graph_walk: str, + engine_inputs: ModelInputsFromEngine, + **kwargs # coming from preprocess output + ) -> NameToTensorList: + cache_handle=engine_inputs.cache_manager + + return self._forward_prefill(cache_handle=cache_handle, **kwargs) + + def forward_batched( + self, + graph_walk: str, + engine_inputs: ModelInputsFromEngine, + **kwargs, # coming from preprocess output + ) -> dict[str, NameToTensorList]: + """Batched forward: process all requests in a single transformer pass. + + Called by ``KVCacheEngine._execute_batched`` when ``can_batch()`` returns + True. ``packed_inputs`` comes from ``preprocess()`` which already + concatenated per-request tensors and planned attention/RoPE for the + full batch. + """ + + return self._forward_prefill_batched( + cache_manager=engine_inputs.cache_manager, + request_ids=engine_inputs.request_ids, + **kwargs, + ) + + + def _forward_prefill_batched( + self, + cache_manager: BatchedCacheManager, + request_ids: list[str], + prefix_embs: torch.Tensor, + **kwargs, + ) -> dict[str, NameToTensorList]: + """Batched prefill: single PaliGemma forward over concatenated prefixes.""" + cache_manager.set_active_label("main") + self.paligemma( + query_sequence=prefix_embs, + cache_handle=cache_manager, + write_cache=True, + ) + # Prefill produces no graph-edge outputs. + return {rid: {} for rid in request_ids} + + def _forward_prefill( + self, + prefix_embs: torch.Tensor, + cache_handle: BatchedCacheManager, + **kwargs, + ) -> NameToTensorList: + if cache_handle is not None: + cache_handle.set_active_label("main") + self.paligemma( + query_sequence=prefix_embs, + cache_handle=cache_handle, + write_cache=True, + ) + return {} + + def postprocess(self, request_id, request_info, outputs, **kwargs): + outputs["action_expert_trigger"] = [] + + +class Pi05ActionExpertSubmodule(ARNodeSubmodule): + """Action expert flow-matching loop. + + Runs all ``num_flow_steps`` Euler denoising steps over the action suffix + in a single forward, attending to the frozen prefix KV cache that the + PaliGemma submodule wrote, then emits the denoised action tensor. + """ + + # Parameter name fragments whose weights must stay in float32 even when + # the rest of the model is bf16. Matches lerobot's + # ``to_bfloat16_for_selected_params`` — keeping norms in fp32 prevents + # the per-layer precision loss that otherwise compounds across 18 layers + # and causes ~0.2 abs delta on the final actions. + _FLOAT32_PARAM_SELECTORS = ( + "input_layernorm", + "post_attention_layernorm", + ".norm.", # final RMSNorm / adaRMS norm + ) + + ACTION_GEN_CAPTURE_BATCH_SIZES = [1, 2, 4] + + def __init__( + self, action_expert: Pi05ActionExpert, action_in_proj: nn.Linear, action_out_proj: nn.Linear, @@ -296,34 +504,11 @@ def __init__( config: Pi05Config, ): super().__init__() - self.embed_tokens = embed_tokens - self.paligemma = paligemma self.action_expert = action_expert self.action_in_proj = action_in_proj self.action_out_proj = action_out_proj self.time_mlp = time_mlp self.config = config - # Image features and language token embeddings use DIFFERENT scaling - # factors in lerobot's reference, even though both end up calling it - # ``sqrt(hidden_size)``: - # - # * Images: ``embed_image`` returns - # ``connector(siglip_features) * sqrt(hidden_size)`` -> scale = sqrt(H). - # - # * Text: lerobot's ``lang_embed_func`` does - # ``embed_language_tokens(tokens) * sqrt(hidden_size)``, but - # ``embed_language_tokens`` calls HF Gemma's - # ``GemmaTextScaledWordEmbedding`` whose ``forward`` already - # multiplies the raw lookup by an internal ``embed_scale = - # sqrt(hidden_size)``. So the EFFECTIVE text scale is - # ``sqrt(H) * sqrt(H) = H``, not ``sqrt(H)``. - # - # We use a plain ``nn.Embedding`` for ``embed_tokens`` (no internal - # scale), so we have to apply the full ``H`` factor manually here. - # Mismatching this produces a ~45x undersized text prefix and the - # action expert sees a wildly wrong context. - self._image_embed_scale = math.sqrt(config.hidden_size) - self._text_embed_scale = float(config.hidden_size) # Lazily allocated on first action Euler step, sized for the largest # captured batch. sincos_timestep_embedding fully overwrites this buffer @@ -349,16 +534,6 @@ def can_batch( batch: NodeBatch, model_inputs: list[NodeInputs], ) -> bool: - """Pi0.5 supports batched execution for both graph walks. - - - ``prefill``: prefix embeddings are concatenated across requests and - processed in a single PaliGemma forward with batched FlashInfer - attention. Each request can have a different prefix length (different - text prompt lengths). - - ``action_gen``: all requests in a batch are at the same Euler - iteration (guaranteed by the Loop primitive), so their suffix tokens - can be concatenated and processed in a single action expert forward. - """ return True def get_needed_cache_labels( @@ -400,10 +575,6 @@ def _get_time_emb_buffer(self, bs: int) -> torch.Tensor: ) return self._time_emb_buffer[:bs] - def _embed_tokens_scaled(self, ids: torch.Tensor) -> torch.Tensor: - emb = self.embed_tokens(ids) - return emb * self._text_embed_scale - def get_cuda_graph_configs( self, device: torch.device, tp_world_size: int = 1, ) -> list[BasicBatchedCudaGraphConfig | FlashInferPackedCudaGraphConfig]: @@ -413,7 +584,7 @@ def get_cuda_graph_configs( # are read directly from self.config — same source as the nn.Linear # weight shapes — so they're guaranteed consistent. logger.info( - "Pi05LLMSubmodule.get_cuda_graph_configs: capturing 'action_gen' " + "Pi05ActionExpertSubmodule.get_cuda_graph_configs: capturing 'action_gen' " "graph with input_seq_len=%d, noisy_actions=(%d, %d), batch_sizes=[1,2,4] " "(num_flow_steps=%d denoising iters runs INSIDE this captured graph; " "denoising count is independent of horizon)", @@ -421,12 +592,6 @@ def get_cuda_graph_configs( self.config.action_horizon, self.config.action_dim, self.config.num_flow_steps, ) - prefill_packed = { - num_tokens: { - "prefix_embs": torch.zeros(num_tokens, self.config.hidden_size, device=device) - } - for num_tokens in self.PREFILL_TOKEN_BUCKETS - } return [ # Action generation always has latents of the same size, so it is a similar # paradigm to AR decode and can use the batched cuda graphs @@ -443,15 +608,6 @@ def get_cuda_graph_configs( ), capture_batch_sizes=self.ACTION_GEN_CAPTURE_BATCH_SIZES ), - FlashInferPackedCudaGraphConfig( - capture_graph_walk="prefill", - packed_seq_len_to_inputs=prefill_packed, - requires_cfg=False, - labels=["main"], - compile=True, - causal_attention=False, - capture_batch_sizes=self.PREFILL_CAPTURE_BATCH_SIZES, - ), ] def prepare_inputs( @@ -461,46 +617,10 @@ def prepare_inputs( inputs: NameToTensorList, **kwargs ) -> ARNodeInputs: - if graph_walk == "prefill": - return self._prepare_inputs_prefill( - inputs=inputs, - ) - if graph_walk == "action_gen": - return self._prepare_inputs_action_gen( - inputs=inputs, - fwd_info=fwd_info, - ) - raise ValueError(f"Unknown Pi0.5 LLM graph walk: {graph_walk!r}") - - def _prepare_inputs_prefill( - self, - inputs: NameToTensorList, - **kwargs - ) -> ARNodeInputs: - # Pi0.5 prefix layout (matches lerobot's embed_prefix): - # [image_tokens, language_tokens] - # The robot state is *not* a separate token stream — it has already - # been formatted as a decimal-string suffix on the language prompt - # by ``Pi05Model.process_prompt``, then tokenized by the PaliGemma - # tokenizer. So the LLM only consumes ``img_emb`` + ``text_inputs``. - # - # IMPORTANT: lerobot's ``embed_prefix`` scales BOTH the image features - # (after the multi_modal_projector) and the language token embeddings - # by ``sqrt(hidden_size)``. We mirror that here. Without the image - # scaling the SigLIP tokens come in ~sqrt(2048)≈45x too small relative - # to the language tokens and the action expert sees a wildly wrong - # prefix. (The standalone test_pi05_model_loaded_via_remapper_matches_ - # lerobot integration test missed this because it bypasses - # _preprocess_prefill and feeds in lerobot's pre-scaled embed_prefix - # output directly.) - - img_emb = inputs["img_emb"][0] * self._image_embed_scale - text_ids = inputs["text_inputs"][0] - text_emb = self._embed_tokens_scaled(text_ids) - prefix_emb = torch.cat([img_emb, text_emb], dim=0) - seq_len = prefix_emb.shape[0] - - return ARNodeInputs(input_embeds=prefix_emb, input_seq_len=seq_len) + return self._prepare_inputs_action_gen( + inputs=inputs, + fwd_info=fwd_info, + ) def _prepare_inputs_action_gen( self, @@ -512,23 +632,20 @@ def _prepare_inputs_action_gen( action_horizon = self.config.action_horizon action_dim = self.config.action_dim - if "noisy_actions" not in inputs or len(inputs["noisy_actions"]) == 0: - generator = torch.Generator(device=device).manual_seed(fwd_info.random_seed) - noisy = torch.randn( - action_horizon, action_dim, device=device, generator=generator - ) - ts = torch.zeros(1, device=device, dtype=torch.long) - else: - noisy = inputs["noisy_actions"][0] - ts = inputs["timestep_index"][0] + generator = torch.Generator(device=device).manual_seed(fwd_info.random_seed) + noisy = torch.randn( + action_horizon, action_dim, device=device, generator=generator + ) + ts = torch.zeros(1, device=device, dtype=torch.long) seq_len = action_horizon - return ARNodeInputs(input_seq_len=seq_len, - tensor_inputs={ - "noisy_actions": noisy, - "ts": ts - }) - + return ARNodeInputs( + input_seq_len=seq_len, + tensor_inputs={ + "noisy_actions": noisy, + "ts": ts + } + ) def preprocess( self, @@ -536,34 +653,10 @@ def preprocess( engine_inputs: ModelInputsFromEngine, inputs: list[ARNodeInputs], ) -> dict[str, torch.Tensor | Any]: - - if graph_walk == "prefill": - return self._preprocess_prefill( - inputs=inputs, - cache_manager=engine_inputs.cache_manager, - ) - if graph_walk == "action_gen": - return self._preprocess_action_gen( - inputs=inputs, - cache_manager=engine_inputs.cache_manager, - ) - - def _preprocess_prefill( - self, - inputs: list[ARNodeInputs], - cache_manager: BatchedCacheManager, - ) -> dict[str, torch.Tensor | Any]: - per_request_seqs = [inp.input_embeds for inp in inputs] - prefix_embs = torch.cat(per_request_seqs, dim=0) - seq_lens = [inp.input_seq_len for inp in inputs] - - # Bidirectional attention over the prefix; PaliGemma is a prefix-LM. - cache_manager.plan_attention( - seq_lens=seq_lens, is_causal=False, label="main", dtype=torch.bfloat16 + return self._preprocess_action_gen( + inputs=inputs, + cache_manager=engine_inputs.cache_manager, ) - cache_manager.plan_rope(seq_lens=seq_lens, pos_ids=None, label="main") - - return {"prefix_embs": prefix_embs} def _preprocess_action_gen( self, @@ -614,12 +707,7 @@ def forward( **kwargs # coming from preprocess output ) -> NameToTensorList: cache_handle=engine_inputs.cache_manager - - if graph_walk == "prefill": - return self._forward_prefill(cache_handle=cache_handle, **kwargs) - if graph_walk == "action_gen": - return self._forward_action_gen(cache_handle=cache_handle, **kwargs) - raise ValueError(f"Unknown Pi0.5 LLM graph walk: {graph_walk!r}") + return self._forward_action_gen(cache_handle=cache_handle, **kwargs) def forward_batched( self, @@ -627,45 +715,11 @@ def forward_batched( engine_inputs: ModelInputsFromEngine, **kwargs, # coming from preprocess output ) -> dict[str, NameToTensorList]: - """Batched forward: process all requests in a single transformer pass. - - Called by ``KVCacheEngine._execute_batched`` when ``can_batch()`` returns - True. ``packed_inputs`` comes from ``preprocess()`` which already - concatenated per-request tensors and planned attention/RoPE for the - full batch. - """ - - if graph_walk == "prefill": - return self._forward_prefill_batched( - cache_manager=engine_inputs.cache_manager, - request_ids=engine_inputs.request_ids, - **kwargs, - ) - if graph_walk == "action_gen": - return self._forward_action_gen_batched( - cache_manager=engine_inputs.cache_manager, - request_ids=engine_inputs.request_ids, - **kwargs, - ) - raise ValueError(f"Batched forward not supported for graph walk: {graph_walk!r}") - - - def _forward_prefill_batched( - self, - cache_manager: BatchedCacheManager, - request_ids: list[str], - prefix_embs: torch.Tensor, - **kwargs, - ) -> dict[str, NameToTensorList]: - """Batched prefill: single PaliGemma forward over concatenated prefixes.""" - cache_manager.set_active_label("main") - self.paligemma( - query_sequence=prefix_embs, - cache_handle=cache_manager, - write_cache=True, + return self._forward_action_gen_batched( + cache_manager=engine_inputs.cache_manager, + request_ids=engine_inputs.request_ids, + **kwargs, ) - # Prefill produces no graph-edge outputs. - return {rid: {} for rid in request_ids} def _forward_action_gen_batched( self, @@ -681,12 +735,13 @@ def _forward_action_gen_batched( horizon = self.config.action_horizon - next_actions, next_index = self._euler_step( - noisy_actions, timestep_index, - fraction=fraction, - time_emb_buffer=time_emb_buffer, - cache_handle=cache_manager - ) + for _ in range(self.config.num_flow_steps): + noisy_actions, timestep_index = self._euler_step( + noisy_actions, timestep_index, + fraction=fraction, + time_emb_buffer=time_emb_buffer, + cache_handle=cache_manager + ) # Split back per-request by horizon. result: dict[str, NameToTensorList] = {} @@ -694,26 +749,10 @@ def _forward_action_gen_batched( start = i * horizon end = start + horizon result[rid] = { - "noisy_actions": [next_actions[start:end]], - "timestep_index": [next_index[i:i+1]], + "actions": [noisy_actions[start:end]], } return result - def _forward_prefill( - self, - prefix_embs: torch.Tensor, - cache_handle: BatchedCacheManager, - **kwargs, - ) -> NameToTensorList: - if cache_handle is not None: - cache_handle.set_active_label("main") - self.paligemma( - query_sequence=prefix_embs, - cache_handle=cache_handle, - write_cache=True, - ) - return {} - def _forward_action_gen( self, noisy_actions, @@ -727,8 +766,9 @@ def _forward_action_gen( ``noisy_actions`` and ``timestep_index`` arrive as single-element lists from preprocess (to keep the data structure uniform with the - batched path). We unpack the first element, run one Euler step, and - return the loop-back edges. + batched path). We unpack the first element, run the full + ``num_flow_steps`` Euler denoising loop, and return the denoised action + tensor. """ # Unpack from list form (preprocess always returns lists now). if isinstance(noisy_actions, list): @@ -736,22 +776,15 @@ def _forward_action_gen( if isinstance(timestep_index, list): timestep_index = timestep_index[0] - next_actions, next_index = self._euler_step( - noisy_actions, timestep_index, - fraction=fraction, - time_emb_buffer=time_emb_buffer, - cache_handle=cache_handle - ) - # We ALWAYS return both loop-back edges, even on the final iteration. - # The Loop primitive (mstar/graph/base.py:Loop) handles the final-iter - # swap automatically: it matches the section's output ``noisy_actions`` - # to the Loop's terminal output (also named ``noisy_actions``, but - # routed to EMIT_TO_CLIENT with ``output_modality="action"``) and - # filters out the ``timestep_index`` loop-back edge. Same convention - # BAGEL's image_gen uses for ``latents`` / ``time_index``. + for _ in range(self.config.num_flow_steps): + noisy_actions, timestep_index = self._euler_step( + noisy_actions, timestep_index, + fraction=fraction, + time_emb_buffer=time_emb_buffer, + cache_handle=cache_handle + ) return { - "noisy_actions": [next_actions], - "timestep_index": [next_index], + "actions": [noisy_actions], } def _euler_step( diff --git a/mstar/worker/worker.py b/mstar/worker/worker.py index 2e85ec53..5940787e 100644 --- a/mstar/worker/worker.py +++ b/mstar/worker/worker.py @@ -1511,8 +1511,11 @@ def _thread_outputs_to_speculative( rid_outputs = output_N.per_request_output_tensors.get(rid, {}) ok = True for input_name, _ in speculation.consumed_edges: - tensors = rid_outputs.get(input_name, []) - if not tensors: + # NOTE: this assumes that submodules may output a empty list as valid + # output, and will omit the key entirely from the output upon, e.g., + # an internal failure. Revisit if this contract ever changes. + tensors = rid_outputs.get(input_name, None) + if tensors is None: ok = False break speculation.node_batch.per_request_input_tensors[rid][input_name] \ diff --git a/test/integration/test_pi05_real_weights.py b/test/integration/test_pi05_real_weights.py index a698f265..39eb2f68 100644 --- a/test/integration/test_pi05_real_weights.py +++ b/test/integration/test_pi05_real_weights.py @@ -672,8 +672,9 @@ def test_pi05_model_loaded_via_remapper_matches_lerobot(): This is the strictest "real Pi05Model" check we can run without standing up a full mstar worker process: it exercises :class:`Pi05Model`'s lazy submodule construction, the lerobot→mstar - state-dict remap, and the actual ``Pi05ViTEncoderSubmodule`` and - ``Pi05LLMSubmodule`` forward methods. The only thing it bypasses is + state-dict remap, and the actual ``Pi05ViTEncoderSubmodule``, + ``Pi05PaligemmaSubmodule``, and ``Pi05ActionExpertSubmodule`` forward + methods. The only thing it bypasses is the FlashInfer/KVCacheEngine paged KV cache (replaced with the same ``MockCacheHandle`` used by the other integration tests, which has been validated against the real FlashInfer wrapper separately). diff --git a/test/modular/test_pi05_model.py b/test/modular/test_pi05_model.py index f373ff19..6b385e40 100644 --- a/test/modular/test_pi05_model.py +++ b/test/modular/test_pi05_model.py @@ -376,7 +376,8 @@ def t(*shape): pali = "paligemma_with_expert.paligemma.model" flat = { - # Vision tower -> vision_model.vision_model. + # Vision tower -> vision_model. (native port; our + # Pi05SiglipEncoder owns ``vision_model`` directly, no double nest). f"{pali}.vision_tower.vision_model.embeddings.patch_embedding.weight": t(1152, 3, 14, 14), f"{pali}.vision_tower.vision_model.encoder.layers.0.layer_norm1.weight": t(1152), # multi_modal_projector.linear -> connector @@ -389,8 +390,8 @@ def t(*shape): } siglip = Pi05Model._extract_siglip_state_dict(flat) - assert "vision_model.vision_model.embeddings.patch_embedding.weight" in siglip - assert "vision_model.vision_model.encoder.layers.0.layer_norm1.weight" in siglip + assert "vision_model.embeddings.patch_embedding.weight" in siglip + assert "vision_model.encoder.layers.0.layer_norm1.weight" in siglip assert "connector.weight" in siglip assert "connector.bias" in siglip assert not any("multi_modal_projector" in k for k in siglip) diff --git a/test/modular/test_pi05_reference_equivalence.py b/test/modular/test_pi05_reference_equivalence.py index 89d73963..a48ff73f 100644 --- a/test/modular/test_pi05_reference_equivalence.py +++ b/test/modular/test_pi05_reference_equivalence.py @@ -26,8 +26,9 @@ ``BatchPrefillWithPagedKVCacheWrapper`` against vanilla SDPA, both for the bidirectional prefill and the suffix-attends-to-prefix flow used during the action_gen denoising loop - * Pi05SiglipEncoder produces bit-identical features to a freshly-built - HF SiglipVisionModel with matched weights + * Pi05SiglipEncoder (native port w/ fused QKV + SDPA) produces features + matching a freshly-built HF SiglipVisionModel, loaded via the same + stacked-param path the real checkpoint loader uses The attention used inside the action-expert tests is a small vanilla-SDPA implementation shared by the mock cache handle and the reference code; the @@ -725,17 +726,23 @@ def test_flashinfer_paged_prefill_attention_matches_sdpa(): def test_pi05_siglip_encoder_matches_hf_reference(): - """``Pi05SiglipEncoder`` produces bit-identical features to HF SiglipVisionModel. - - Both wrap the same HF class; the only difference is mstar adds a - ``nn.Linear`` connector to project to the LLM hidden size. The reference - PaliGemma uses an analogous ``multi_modal_projector``. We verify the - pre-connector features match exactly and the connector preserves the - expected output shape. + """``Pi05SiglipEncoder`` (native port) matches HF SiglipVisionModel. + + The port fuses q/k/v into one projection and runs SDPA, so it is no + longer the same class as the reference. We load the HF reference's + weights into our encoder through ``load_hf_weights`` with + ``SIGLIP_STACKED_PARAMS`` — the same stacked-param path the real + checkpoint loader uses — then check the pre-connector features match + (allclose, since fused-QKV + SDPA differ from HF only by fp32 rounding) + and the connector preserves the expected output shape. """ from transformers import SiglipVisionConfig, SiglipVisionModel - from mstar.model.pi05.components.siglip import Pi05SiglipEncoder + from mstar.model.loader import load_hf_weights + from mstar.model.pi05.components.siglip import ( + SIGLIP_STACKED_PARAMS, + Pi05SiglipEncoder, + ) torch.manual_seed(0) config = Pi05Config( @@ -748,8 +755,6 @@ def test_pi05_siglip_encoder_matches_hf_reference(): hidden_size=128, ) - ours = Pi05SiglipEncoder(config).to(DEVICE).eval() - siglip_cfg = SiglipVisionConfig( hidden_size=config.vit_hidden_size, intermediate_size=config.vit_intermediate_size, @@ -761,19 +766,34 @@ def test_pi05_siglip_encoder_matches_hf_reference(): # Match Pi05SiglipEncoder, which disables the pooling head to match # the production lerobot/pi05_base checkpoint key set. vision_use_head=False, + attn_implementation="sdpa", ) ref_vision = SiglipVisionModel(siglip_cfg).to(DEVICE).eval() - ref_vision.load_state_dict(ours.vision_model.state_dict()) - images = torch.randn(2, 3, config.vit_image_size, config.vit_image_size, device=DEVICE) - with torch.no_grad(): - ref_features = ref_vision(pixel_values=images).last_hidden_state - ours_inner = ours.vision_model(pixel_values=images).last_hidden_state - ours_full = ours(images) + ours = Pi05SiglipEncoder(config).to(DEVICE).eval() + # HF SiglipVisionModel state_dict keys (``vision_model.encoder.layers.N. + # self_attn.{q,k,v,out}_proj.*`` etc.) line up 1:1 with our encoder after + # the stacked-param rules fuse q/k/v into ``qkv_proj``. + load_hf_weights( + ours, ref_vision.state_dict().items(), stacked_params=SIGLIP_STACKED_PARAMS, + ) - # Pre-connector features should be exactly bit-identical (same HF class, - # same weights, same input). - assert torch.equal(ref_features, ours_inner) + images = torch.randn(2, 3, config.vit_image_size, config.vit_image_size, device=DEVICE) + # Disable TF32 for the comparison: the fused [3*H, H] QKV GEMM tiles + # differently from HF's three separate [H, H] GEMMs, so with TF32 tensor + # cores enabled the two paths round differently (~1e-3 abs — negligible + # for actions, but not bit-exact). In true fp32 the port is identical. + tf32_prev = torch.backends.cuda.matmul.allow_tf32 + torch.backends.cuda.matmul.allow_tf32 = False + try: + with torch.no_grad(): + ref_features = ref_vision(pixel_values=images).last_hidden_state + ours_inner = ours.vision_model(images) + ours_full = ours(images) + finally: + torch.backends.cuda.matmul.allow_tf32 = tf32_prev + + torch.testing.assert_close(ours_inner, ref_features, atol=1e-5, rtol=1e-5) # Connector output shape: [batch, num_patches, llm_hidden_size] n_patches = (config.vit_image_size // config.vit_patch_size) ** 2 @@ -787,7 +807,7 @@ def test_pi05_siglip_encoder_matches_hf_reference(): def _ref_resize_with_pad(images: torch.Tensor, target_h: int, target_w: int) -> torch.Tensor: """Reference port of ``image_tools.resize_with_pad_torch`` (channels-first - float32 path). Used to verify ``Pi05ViTEncoderSubmodule._preprocess_one``. + float32 path). Used to verify ``Pi05ViTEncoderSubmodule._prepare_one``. """ assert images.dim() == 4 and images.dtype == torch.float32 _, _, cur_h, cur_w = images.shape @@ -807,7 +827,7 @@ def _ref_resize_with_pad(images: torch.Tensor, target_h: int, target_w: int) -> def test_pi05_image_preprocessing_matches_resize_with_pad_letterbox(): - """``Pi05ViTEncoderSubmodule._preprocess_one`` vs openpi's resize_with_pad_torch. + """``Pi05ViTEncoderSubmodule._prepare_one`` vs openpi's resize_with_pad_torch. Tests three cases that exercise the letterbox path: * already-target square (no resize / no pad — identity-ish) @@ -840,7 +860,7 @@ def test_pi05_image_preprocessing_matches_resize_with_pad_letterbox(): for name, shape in cases: torch.manual_seed(hash(name) & 0xFFFF) images = torch.rand(*shape) * 2.0 - 1.0 # [-1, 1] float32 - ours = submodule._preprocess_one(images) + ours = submodule._prepare_one(images) ref = _ref_resize_with_pad(images, cfg.vit_image_size, cfg.vit_image_size) assert ours.shape == ref.shape == (1, 3, 224, 224), f"{name}: shape mismatch" # Padding regions are exactly -1, content region matches the resized @@ -861,7 +881,7 @@ def test_pi05_image_preprocessing_uint8_to_float(): submodule = Pi05ViTEncoderSubmodule(Pi05SiglipEncoder(cfg), cfg) images_u8 = torch.zeros(1, 3, 224, 224, dtype=torch.uint8) images_u8[..., 100:200, 100:200] = 255 - out = submodule._preprocess_one(images_u8) + out = submodule._prepare_one(images_u8) assert out.dtype == torch.float32 # Background pixels (0) -> -1, foreground pixels (255) -> +1. assert out[0, 0, 0, 0].item() == pytest.approx(-1.0, abs=1e-6) diff --git a/test/pi05/compare_with_lerobot.py b/test/pi05/compare_with_lerobot.py index 077dc1f4..670428da 100644 --- a/test/pi05/compare_with_lerobot.py +++ b/test/pi05/compare_with_lerobot.py @@ -68,7 +68,7 @@ def server_seed_for(request_id: str) -> int: def reproduce_server_noise(request_id: str, device: torch.device) -> torch.Tensor: - """Reproduce the noise tensor that ``Pi05LLMSubmodule._preprocess_action_gen`` + """Reproduce the noise tensor that ``Pi05ActionExpertSubmodule._preprocess_action_gen`` will sample on iteration 0 for this request. Server code (mstar/model/pi05/submodules.py):: diff --git a/test/pi05/launch_server_pi05.sh b/test/pi05/launch_server_pi05.sh index 9352cf43..e64656e5 100755 --- a/test/pi05/launch_server_pi05.sh +++ b/test/pi05/launch_server_pi05.sh @@ -45,7 +45,7 @@ mkdir -p "${PI05_CACHE_DIR}" # Pick the yaml: default to base pi05.yaml; override with PI05_CONFIG env var # to swap in a variant (e.g. configs/pi05_droid.yaml for the DROID benchmark). -PI05_CONFIG_PATH="${PI05_CONFIG:-configs/pi05.yaml}" +PI05_CONFIG_PATH="${PI05_CONFIG:-configs/pi05_droid.yaml}" echo "[pi05] launching server" echo " user: ${WHO}" diff --git a/test/pi05/probe_mstar_vs_lerobot.py b/test/pi05/probe_mstar_vs_lerobot.py index 307cc757..1a11e366 100644 --- a/test/pi05/probe_mstar_vs_lerobot.py +++ b/test/pi05/probe_mstar_vs_lerobot.py @@ -9,7 +9,7 @@ Stage 1: Pi05ViTEncoderSubmodule output (per-camera image embeddings) vs lerobot ``paligemma_with_expert.embed_image(image)``. - Stage 2: Pi05LLMSubmodule._preprocess_prefill output (prefix_embs) + Stage 2: Pi05PaligemmaSubmodule._preprocess_prefill output (prefix_embs) vs lerobot ``embed_prefix(images, masks, tokens, masks)``. Stage 3: Action expert first-step velocity vs lerobot ``denoise_step`` first iteration. @@ -333,7 +333,7 @@ def main(): } ] # We don't actually need a real cache_manager for this stage — just to - # call the helper that builds prefix_embs. Pi05LLMSubmodule._preprocess_prefill + # call the helper that builds prefix_embs. Pi05PaligemmaSubmodule._preprocess_prefill # also calls plan_attention/plan_rope which need a real cache manager. # Build a dummy that just no-ops the plan_* calls: class _NoopCache: