From 478b905d7d8031be2d4d7abd1e96cad000714de7 Mon Sep 17 00:00:00 2001 From: NSagan271 Date: Tue, 9 Jun 2026 20:39:42 +0000 Subject: [PATCH 01/17] benchmark parity for pi05 --- benchmark/dataset.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/benchmark/dataset.py b/benchmark/dataset.py index 9fc64248..a06144c5 100644 --- a/benchmark/dataset.py +++ b/benchmark/dataset.py @@ -777,7 +777,10 @@ def _make_pi05(self, idx, ep_id, frames, camera_keys, state_col, chunk_video = download_fn(self._chunk_video_path(ep_id, cam_key, chunks_size)) png_path = os.path.join(self.local_file_dir, f"ep{ep_id}_cam{len(image_paths)}.png") _decode_frames_to_png_and_video( - chunk_video, [first_local], png_path=png_path, mp4_path=None + video_path=chunk_video, + frame_indices=[first_local], + png_path=png_path, + mp4_path=None ) image_paths.append(png_path) @@ -793,7 +796,9 @@ def _make_pi05(self, idx, ep_id, frames, camera_keys, state_col, req_type=RequestType.VLA, prompt=language or "manipulate the object", image_path=image_paths[0], - extra_image_paths=image_paths[1:], + # openpi droid policy only uses the first extra image path! 
So, to be consistent + # we omit the remainder entirely from benchmarking + extra_image_paths=image_paths[1:2], model_kwargs={"robot_state": state}, ) @@ -824,7 +829,9 @@ def _make_vjepa2_ac(self, idx, ep_id, frames, camera_keys, action_col, ) mp4_path = os.path.join(self.local_file_dir, f"ep{ep_id}.mp4") _decode_frames_to_png_and_video( - chunk_video, video_local_indices, png_path=None, mp4_path=mp4_path + video_path=chunk_video, + frame_indices=video_local_indices, + png_path=None, mp4_path=mp4_path ) actions = [_to_float_list(f.get(action_col), self.action_dim) @@ -841,7 +848,7 @@ def _make_vjepa2_ac(self, idx, ep_id, frames, camera_keys, action_col, "rollout_horizon": self.rollout_horizon, }, ) - + @property def num_requests(self) -> int: return self._num_requests @@ -852,7 +859,7 @@ def __len__(self) -> int: def __getitem__(self, idx: int) -> RequestInput: return self.items[idx] - # ------------------------------------------------------------------ + class VideoMMEDataset(BaseDataset): """ Dataset loader for Video-MME (https://video-mme.github.io/). From 94ef0e58383b8c8f0c7c3bc123953e2f36d7d4ea Mon Sep 17 00:00:00 2001 From: NSagan271 Date: Tue, 9 Jun 2026 22:22:47 +0000 Subject: [PATCH 02/17] use droid yaml by default --- configs/pi05_droid.yaml | 1 + mstar/model/pi05/submodules.py | 5 +---- test/pi05/launch_server_pi05.sh | 2 +- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/configs/pi05_droid.yaml b/configs/pi05_droid.yaml index 74b504b4..a4ae115e 100644 --- a/configs/pi05_droid.yaml +++ b/configs/pi05_droid.yaml @@ -23,6 +23,7 @@ max_seq_len: 2048 # time, never per-request — that's what this yaml override is for. 
model_kwargs: action_horizon: 15 + num_cameras: 2 node_groups: - node_names: diff --git a/mstar/model/pi05/submodules.py b/mstar/model/pi05/submodules.py index 9b55af89..134923a3 100644 --- a/mstar/model/pi05/submodules.py +++ b/mstar/model/pi05/submodules.py @@ -152,9 +152,6 @@ def get_cuda_graph_configs(self, device: torch.device, tp_world_size: int = 1) - (num_cameras, 3, H, W). preprocess() stacks them to (bs, num_cameras, 3, H, W) so shape[0] == bs, satisfying StatelessCudaGraphRunner's leading-dim == bs requirement. - - compile=False because warmup() already applies torch.compile to - forward_batched; _capture_one captures the compiled callable directly. """ from mstar.engine.cuda_graph_config import BasicBatchedCudaGraphConfig H = W = self.config.vit_image_size @@ -172,7 +169,7 @@ def get_cuda_graph_configs(self, device: torch.device, tp_world_size: int = 1) - }, ), capture_batch_sizes=[1], - compile=False, + compile=False, # empirically does better than compile=True for now ) ] diff --git a/test/pi05/launch_server_pi05.sh b/test/pi05/launch_server_pi05.sh index 9352cf43..e64656e5 100755 --- a/test/pi05/launch_server_pi05.sh +++ b/test/pi05/launch_server_pi05.sh @@ -45,7 +45,7 @@ mkdir -p "${PI05_CACHE_DIR}" # Pick the yaml: default to base pi05.yaml; override with PI05_CONFIG env var # to swap in a variant (e.g. configs/pi05_droid.yaml for the DROID benchmark). 
-PI05_CONFIG_PATH="${PI05_CONFIG:-configs/pi05.yaml}" +PI05_CONFIG_PATH="${PI05_CONFIG:-configs/pi05_droid.yaml}" echo "[pi05] launching server" echo " user: ${WHO}" From 73d5d29833526a438fd8b21582d80f63d3040a03 Mon Sep 17 00:00:00 2001 From: NSagan271 Date: Tue, 9 Jun 2026 22:46:12 +0000 Subject: [PATCH 03/17] add cache for droid --- benchmark/dataset.py | 95 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 94 insertions(+), 1 deletion(-) diff --git a/benchmark/dataset.py b/benchmark/dataset.py index a06144c5..7dcbdd02 100644 --- a/benchmark/dataset.py +++ b/benchmark/dataset.py @@ -663,6 +663,17 @@ def __init__( # producing 8 token-frames; only the first is used as rollout context. self.num_video_frames = num_video_frames + # Fast path: reuse a manifest (PNG paths + robot_state + prompt) built + # on a previous run so repeat benchmarks skip the full-parquet load and + # the per-frame video decode entirely. Only pi05 caches; vjepa2_ac + # streams the episode mp4 directly and is left uncached. + if task == "pi05": + cached = self._load_manifest() + if cached is not None: + print(f" [cache] reusing {len(cached)} pi05 items from manifest") + self.items = self._resize_data(cached) + return + def _dl(filename): return hf_hub_download( self.HF_REPO, filename, repo_type="dataset", cache_dir=cache_dir @@ -724,7 +735,13 @@ def _dl(filename): for frames in episodes.values(): frames.sort(key=lambda r: int(r[frame_col])) - ep_ids = sorted(episodes.keys())[:num_requests] + # pi05 caches a complete manifest, so build every episode once + # (_resize_data truncates to num_requests below) and the cache is reused + # for any num_requests. vjepa2_ac streams mp4s uncached, so keep the + # original [:num_requests] cap to bound its first-run decode cost. 
+ ep_ids = sorted(episodes.keys()) + if task != "pi05": + ep_ids = ep_ids[:num_requests] print(f" using {len(ep_ids)} of {len(episodes)} episodes") self.items: list[RequestInput] = [] @@ -749,6 +766,9 @@ def _dl(filename): if item is not None: self.items.append(item) + if task == "pi05": + self._save_manifest(self.items) + self.items = self._resize_data(self.items) # ------------------------------------------------------------------ @@ -849,6 +869,79 @@ def _make_vjepa2_ac(self, idx, ep_id, frames, camera_keys, action_col, }, ) + # ------------------------------------------------------------------ + # pi05 manifest cache + # ------------------------------------------------------------------ + + def _manifest_path(self) -> str: + """Manifest filename keyed by the params that change the built items.""" + return os.path.join( + self.local_file_dir, + f"manifest_pi05_nvf{self.num_video_frames}_ad{self.action_dim}.json", + ) + + def _load_manifest(self) -> list[RequestInput] | None: + """Return cached pi05 RequestInputs, or None to force a rebuild. + + Returns None if the manifest is absent, unreadable, or references a PNG + that no longer exists on disk. 
+ """ + import json as _json + + path = self._manifest_path() + if not os.path.exists(path): + return None + try: + with open(path) as f: + data = _json.load(f) + items: list[RequestInput] = [] + for entry in data["items"]: + img = os.path.join(self.local_file_dir, entry["image_path"]) + extra = [os.path.join(self.local_file_dir, p) + for p in entry.get("extra_image_paths", [])] + for p in (img, *extra): + if not os.path.exists(p): + print(f" [cache] missing {p}; rebuilding") + return None + items.append(RequestInput( + req_type=RequestType.VLA, + prompt=entry["prompt"], + image_path=img, + extra_image_paths=extra, + model_kwargs=entry.get("model_kwargs", {}), + )) + return items or None + except Exception as e: + print(f" [cache] manifest unreadable ({e}); rebuilding") + return None + + def _save_manifest(self, items: list[RequestInput]) -> None: + """Persist built pi05 items so the next run can skip parquet + decode. + + PNG paths are stored as basenames (relative to local_file_dir) and the + write is atomic (tmp + os.replace) so an interrupted run never leaves a + half-written manifest that would later be reused. 
+ """ + import json as _json + + entries = [{ + "prompt": it.prompt, + "image_path": os.path.basename(it.image_path), + "extra_image_paths": [os.path.basename(p) for p in it.extra_image_paths], + "model_kwargs": it.model_kwargs, + } for it in items] + payload = { + "version": 1, + "task": "pi05", + "num_video_frames": self.num_video_frames, + "action_dim": self.action_dim, + "items": entries, + } + tmp = self._manifest_path() + ".tmp" + with open(tmp, "w") as f: + _json.dump(payload, f) + os.replace(tmp, self._manifest_path()) + @property def num_requests(self) -> int: return self._num_requests From 26eca8fa25a4693f0a512c8af86c032e4a2edb06 Mon Sep 17 00:00:00 2001 From: NSagan271 Date: Tue, 9 Jun 2026 23:01:56 +0000 Subject: [PATCH 04/17] collapse full flow loop into one forward pass --- benchmark/dataset.py | 2 +- mstar/model/pi05/pi05_model.py | 33 ++++++------------- mstar/model/pi05/submodules.py | 58 ++++++++++++++++------------------ 3 files changed, 38 insertions(+), 55 deletions(-) diff --git a/benchmark/dataset.py b/benchmark/dataset.py index 7dcbdd02..46d1da73 100644 --- a/benchmark/dataset.py +++ b/benchmark/dataset.py @@ -868,7 +868,7 @@ def _make_vjepa2_ac(self, idx, ep_id, frames, camera_keys, action_col, "rollout_horizon": self.rollout_horizon, }, ) - + # ------------------------------------------------------------------ # pi05 manifest cache # ------------------------------------------------------------------ diff --git a/mstar/model/pi05/pi05_model.py b/mstar/model/pi05/pi05_model.py index f1846730..9f9d1f0a 100644 --- a/mstar/model/pi05/pi05_model.py +++ b/mstar/model/pi05/pi05_model.py @@ -40,7 +40,6 @@ GraphEdge, GraphNode, GraphSection, - Loop, Sequential, TensorPointerInfo, ) @@ -417,33 +416,19 @@ def get_graph_walk_graphs(self) -> dict[str, GraphSection]: ] ) - # NOTE: The Loop's terminal ``outputs`` are matched into the section's - # node outputs by **name** (see Loop._replace_outputs_for_final_iter - # in mstar/graph/base.py): on the final 
iteration, any section-output - # edge whose name matches a terminal output's name is replaced with - # the terminal version. This is the same convention BAGEL's image_gen - # uses (section returns ``latents`` looping back to LLM, terminal - # output is ``name="latents" → vae_decoder``). So our terminal output - # MUST be named ``noisy_actions`` to match the section's loop-back - # edge — the name is just a graph-internal key, while the actual - # client-facing modality bucket is determined by ``output_modality``. - action_gen = Loop( - section=GraphNode( - name="LLM", - input_names=["noisy_actions", "timestep_index"], - outputs=[ - GraphEdge(next_node="LLM", name="noisy_actions"), - GraphEdge(next_node="LLM", name="timestep_index"), - ], - ), - max_iters=self.config.num_flow_steps, + # NOTE: the full action generation flow loop is extremely short (total < 50ms), so + # we opt to have it as one node to reduce cuda graph startup, flashinfer planning, + # etc. overhead. Cache planning only needs to happen at the beginning of the flow + # loop, so this collapsed loop is valid. 
+ action_gen = GraphNode( + name="LLM", + input_names=["noisy_actions"], outputs=[ GraphEdge( next_node=EMIT_TO_CLIENT, - name="noisy_actions", + name="actions", output_modality="action", - persist=True, - ), + ) ], ) diff --git a/mstar/model/pi05/submodules.py b/mstar/model/pi05/submodules.py index 134923a3..a29ce11b 100644 --- a/mstar/model/pi05/submodules.py +++ b/mstar/model/pi05/submodules.py @@ -509,22 +509,20 @@ def _prepare_inputs_action_gen( action_horizon = self.config.action_horizon action_dim = self.config.action_dim - if "noisy_actions" not in inputs or len(inputs["noisy_actions"]) == 0: - generator = torch.Generator(device=device).manual_seed(fwd_info.random_seed) - noisy = torch.randn( - action_horizon, action_dim, device=device, generator=generator - ) - ts = torch.zeros(1, device=device, dtype=torch.long) - else: - noisy = inputs["noisy_actions"][0] - ts = inputs["timestep_index"][0] + generator = torch.Generator(device=device).manual_seed(fwd_info.random_seed) + noisy = torch.randn( + action_horizon, action_dim, device=device, generator=generator + ) + ts = torch.zeros(1, device=device, dtype=torch.long) seq_len = action_horizon - return ARNodeInputs(input_seq_len=seq_len, - tensor_inputs={ - "noisy_actions": noisy, - "ts": ts - }) + return ARNodeInputs( + input_seq_len=seq_len, + tensor_inputs={ + "noisy_actions": noisy, + "ts": ts + } + ) def preprocess( @@ -678,12 +676,13 @@ def _forward_action_gen_batched( horizon = self.config.action_horizon - next_actions, next_index = self._euler_step( - noisy_actions, timestep_index, - fraction=fraction, - time_emb_buffer=time_emb_buffer, - cache_handle=cache_manager - ) + for _ in range(self.config.num_flow_steps): + noisy_actions, timestep_index = self._euler_step( + noisy_actions, timestep_index, + fraction=fraction, + time_emb_buffer=time_emb_buffer, + cache_handle=cache_manager + ) # Split back per-request by horizon. 
result: dict[str, NameToTensorList] = {} @@ -691,8 +690,7 @@ def _forward_action_gen_batched( start = i * horizon end = start + horizon result[rid] = { - "noisy_actions": [next_actions[start:end]], - "timestep_index": [next_index[i:i+1]], + "actions": [noisy_actions[start:end]], } return result @@ -733,12 +731,13 @@ def _forward_action_gen( if isinstance(timestep_index, list): timestep_index = timestep_index[0] - next_actions, next_index = self._euler_step( - noisy_actions, timestep_index, - fraction=fraction, - time_emb_buffer=time_emb_buffer, - cache_handle=cache_handle - ) + for _ in range(self.config.num_flow_steps): + noisy_actions, timestep_index = self._euler_step( + noisy_actions, timestep_index, + fraction=fraction, + time_emb_buffer=time_emb_buffer, + cache_handle=cache_handle + ) # We ALWAYS return both loop-back edges, even on the final iteration. # The Loop primitive (mstar/graph/base.py:Loop) handles the final-iter # swap automatically: it matches the section's output ``noisy_actions`` @@ -747,8 +746,7 @@ def _forward_action_gen( # filters out the ``timestep_index`` loop-back edge. Same convention # BAGEL's image_gen uses for ``latents`` / ``time_index``. return { - "noisy_actions": [next_actions], - "timestep_index": [next_index], + "actions": [noisy_actions], } def _euler_step( From c81cd50a0916fa6054fd0cd368aa24bdd1935767 Mon Sep 17 00:00:00 2001 From: NSagan271 Date: Tue, 9 Jun 2026 23:56:18 +0000 Subject: [PATCH 05/17] port over pi05 siglip --- mstar/model/pi05/components/siglip.py | 208 +++++++++++++++--- mstar/model/pi05/pi05_model.py | 19 +- test/modular/test_pi05_model.py | 7 +- .../test_pi05_reference_equivalence.py | 70 +++--- 4 files changed, 239 insertions(+), 65 deletions(-) diff --git a/mstar/model/pi05/components/siglip.py b/mstar/model/pi05/components/siglip.py index d017558c..dbcdb21d 100644 --- a/mstar/model/pi05/components/siglip.py +++ b/mstar/model/pi05/components/siglip.py @@ -1,45 +1,197 @@ -"""SigLIP vision encoder for Pi0.5. 
+"""SigLIP vision encoder for Pi0.5 (native mminf port). -Thin wrapper around the HuggingFace SiglipVisionModel that produces a fixed -number of image tokens (default 256) per camera image at the resolution Pi0.5 -expects (224x224). A learned linear projection maps SigLIP's hidden dim to the -LLM hidden dim so the resulting tokens can be concatenated with PaliGemma -language token embeddings. +Ports the inference path of HuggingFace's ``SiglipVisionModel`` (So400m/14) +into mminf so we own the code and can fuse projections. Differences from the +transformers implementation: + + * **Fused QKV** — the three ``q/k/v_proj`` GEMMs are merged into one + ``QKVParallelLinear`` (loaded from the separate checkpoint keys via the + ``q/k/v`` stacked-param rules; see ``SIGLIP_STACKED_PARAMS``). + * **SDPA attention** — full bidirectional ``scaled_dot_product_attention``. + We do NOT use flash-attn or the Triton ``sliding_window_attn`` here: the + encoder runs in **fp32** (Pi05VitEncoderSubmodule forces it, since bf16 + rounding over 27 layers perturbs the actions) and flash-attn is fp16/bf16 + only, while the Triton kernel is causal-only and rejects head_dim=72. + * **Inference-only** — all weight-init, gradient-checkpointing, the text + tower, pooling head, and variable-resolution position interpolation are + dropped. Images are a fixed 224x224 → 256 patches. + +Only ``last_hidden_state`` is consumed downstream (``vision_use_head=False`` +in the original), so the pooling head is omitted entirely. """ +from __future__ import annotations import torch +import torch.nn.functional as F from torch import nn -from transformers import SiglipVisionConfig, SiglipVisionModel +from mstar.distributed.communication import TPCommGroup +from mstar.model.components.distributed.linear import QKVParallelLinear +from mstar.model.loader import StackedParamRule from mstar.model.pi05.config import Pi05Config +# SigLIP architectural constants not carried on Pi05Config. 
These match +# HF ``SiglipVisionConfig`` defaults for the So400m checkpoint. +_LAYER_NORM_EPS = 1e-6 -class Pi05SiglipEncoder(nn.Module): - """SigLIP image encoder + linear connector to the LLM hidden size.""" +# Route the checkpoint's separate q/k/v projection keys into the fused +# ``qkv_proj`` parameter. Consumed by ``load_hf_weights`` when loading the +# encoder (the SigLIP MLP is ungated, so there are no gate/up rules). +SIGLIP_STACKED_PARAMS: list[StackedParamRule] = [ + StackedParamRule(".qkv_proj", ".q_proj", "q"), + StackedParamRule(".qkv_proj", ".k_proj", "k"), + StackedParamRule(".qkv_proj", ".v_proj", "v"), +] + + +class _SiglipVisionEmbeddings(nn.Module): + """Conv patch embedding + learned position embedding. + + Fixed-resolution only: 224x224 input → a 16x16 grid of 14px patches → + 256 tokens. Position ids are computed inline (no buffer) so the module + has no non-persistent state to re-materialize after ``to_empty``. + """ def __init__(self, config: Pi05Config): super().__init__() - self.config = config + self.embed_dim = config.vit_hidden_size + self.patch_embedding = nn.Conv2d( + in_channels=3, + out_channels=self.embed_dim, + kernel_size=config.vit_patch_size, + stride=config.vit_patch_size, + padding="valid", + ) + self.num_positions = (config.vit_image_size // config.vit_patch_size) ** 2 + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + # pixel_values: (N, 3, H, W) -> patches (N, embed_dim, gh, gw). + patch_embeds = self.patch_embedding(pixel_values.to(self.patch_embedding.weight.dtype)) + embeddings = patch_embeds.flatten(2).transpose(1, 2) # (N, num_patches, embed_dim) + positions = torch.arange(self.num_positions, device=embeddings.device) + return embeddings + self.position_embedding(positions) + + +class _SiglipAttention(nn.Module): + """Bidirectional multi-head self-attention with a fused QKV projection. 
+ + Full MHA (no GQA): num_kv_heads == num_heads. Attention is computed + per-image over its own 256 patches (the batch dim isolates images), so + no attention mask is needed. + """ - siglip_cfg = SiglipVisionConfig( - hidden_size=config.vit_hidden_size, - intermediate_size=config.vit_intermediate_size, - num_hidden_layers=config.vit_num_layers, - num_attention_heads=config.vit_num_heads, - num_channels=3, - image_size=config.vit_image_size, - patch_size=config.vit_patch_size, - # Pi0.5 / lerobot's PaliGemma SigLIP does NOT use the pooling - # head — only ``last_hidden_state`` is consumed downstream by the - # multi_modal_projector. Disabling the head matches the - # production checkpoint key set (no ``vision_model.head.*`` keys). - vision_use_head=False, + def __init__(self, config: Pi05Config): + super().__init__() + self.embed_dim = config.vit_hidden_size + self.num_heads = config.vit_num_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"vit_hidden_size {self.embed_dim} not divisible by " + f"vit_num_heads {self.num_heads}" + ) + self.scale = self.head_dim**-0.5 + + # Trivial (single-rank) comm group: reuses the TP-aware fused-QKV + # loader without any actual sharding. bias=True — SigLIP projects + # q/k/v with bias. + self.qkv_proj = QKVParallelLinear( + comm_group=TPCommGroup.trivial(), + hidden_size=self.embed_dim, + head_size=self.head_dim, + total_num_heads=self.num_heads, + total_num_kv_heads=self.num_heads, + bias=True, + ) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + n, seq_len, _ = hidden_states.shape + qkv = self.qkv_proj(hidden_states) # (N, seq, 3*embed_dim) + q, k, v = qkv.split([self.embed_dim, self.embed_dim, self.embed_dim], dim=-1) + + # (N, seq, embed) -> (N, heads, seq, head_dim) for SDPA. 
+ def to_heads(x: torch.Tensor) -> torch.Tensor: + return x.view(n, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + + out = F.scaled_dot_product_attention( + to_heads(q), to_heads(k), to_heads(v), scale=self.scale, ) - self.vision_model = SiglipVisionModel(siglip_cfg) - self.connector = nn.Linear( - config.vit_hidden_size, config.hidden_size, bias=True + out = out.transpose(1, 2).reshape(n, seq_len, self.embed_dim) + return self.out_proj(out) + + +class _SiglipMLP(nn.Module): + """Ungated 2-layer MLP with gelu-tanh activation.""" + + def __init__(self, config: Pi05Config): + super().__init__() + self.fc1 = nn.Linear(config.vit_hidden_size, config.vit_intermediate_size) + self.activation_fn = nn.GELU(approximate="tanh") # gelu_pytorch_tanh + self.fc2 = nn.Linear(config.vit_intermediate_size, config.vit_hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.fc2(self.activation_fn(self.fc1(hidden_states))) + + +class _SiglipEncoderLayer(nn.Module): + """Pre-norm transformer block: ln1→attn→res, ln2→mlp→res.""" + + def __init__(self, config: Pi05Config): + super().__init__() + embed_dim = config.vit_hidden_size + self.layer_norm1 = nn.LayerNorm(embed_dim, eps=_LAYER_NORM_EPS) + self.self_attn = _SiglipAttention(config) + self.layer_norm2 = nn.LayerNorm(embed_dim, eps=_LAYER_NORM_EPS) + self.mlp = _SiglipMLP(config) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = hidden_states + self.self_attn(self.layer_norm1(hidden_states)) + hidden_states = hidden_states + self.mlp(self.layer_norm2(hidden_states)) + return hidden_states + + +class _SiglipEncoder(nn.Module): + """Stack of encoder layers. 
Named to match the ``encoder.layers.N`` + checkpoint key layout so weights load without per-layer remapping.""" + + def __init__(self, config: Pi05Config): + super().__init__() + self.layers = nn.ModuleList( + [_SiglipEncoderLayer(config) for _ in range(config.vit_num_layers)] ) + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + for layer in self.layers: + hidden_states = layer(hidden_states) + return hidden_states + + +class _SiglipVisionTransformer(nn.Module): + """Embeddings → encoder stack → final layer norm.""" + + def __init__(self, config: Pi05Config): + super().__init__() + self.embeddings = _SiglipVisionEmbeddings(config) + self.encoder = _SiglipEncoder(config) + self.post_layernorm = nn.LayerNorm(config.vit_hidden_size, eps=_LAYER_NORM_EPS) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + hidden_states = self.embeddings(pixel_values) + hidden_states = self.encoder(hidden_states) + return self.post_layernorm(hidden_states) + + +class Pi05SiglipEncoder(nn.Module): + """SigLIP image encoder + linear connector to the LLM hidden size.""" + + def __init__(self, config: Pi05Config): + super().__init__() + self.config = config + self.vision_model = _SiglipVisionTransformer(config) + self.connector = nn.Linear(config.vit_hidden_size, config.hidden_size, bias=True) + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: """Encode a batch of images into LLM-space tokens. @@ -50,7 +202,5 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: Returns: Tensor of shape ``(N, tokens_per_image, hidden_size)``. 
""" - outputs = self.vision_model(pixel_values=pixel_values) - # last_hidden_state: [N, num_patches, vit_hidden_size] - features = outputs.last_hidden_state + features = self.vision_model(pixel_values) # (N, num_patches, vit_hidden) return self.connector(features) diff --git a/mstar/model/pi05/pi05_model.py b/mstar/model/pi05/pi05_model.py index 9f9d1f0a..eab123be 100644 --- a/mstar/model/pi05/pi05_model.py +++ b/mstar/model/pi05/pi05_model.py @@ -48,7 +48,7 @@ from mstar.model.loader import LLAMA_STACKED_PARAMS, load_hf_weights from mstar.model.pi05.components.action_expert import Pi05ActionExpert, Pi05TimeMLP from mstar.model.pi05.components.paligemma import Pi05PaliGemmaExpert -from mstar.model.pi05.components.siglip import Pi05SiglipEncoder +from mstar.model.pi05.components.siglip import SIGLIP_STACKED_PARAMS, Pi05SiglipEncoder from mstar.model.pi05.components.tokenization import Pi05Tokenizer from mstar.model.pi05.config import Pi05Config, load_pi05_config from mstar.model.pi05.submodules import Pi05LLMSubmodule, Pi05ViTEncoderSubmodule @@ -364,12 +364,13 @@ def _extract_siglip_state_dict( if inner.startswith("vision_tower.vision_model."): # The lerobot key is # paligemma.model.vision_tower.vision_model. - # Pi05SiglipEncoder owns ``self.vision_model = SiglipVisionModel(...)``, - # and HF's SiglipVisionModel has its own inner ``.vision_model`` - # attribute, so the corresponding key is - # ``vision_model.vision_model.``. We replace - # ``vision_tower`` with ``vision_model`` to make that explicit. - out["vision_model." + inner.removeprefix("vision_tower.")] = tensor + # Pi05SiglipEncoder owns ``self.vision_model`` (our native + # _SiglipVisionTransformer) directly, so the matching key is + # ``vision_model.``. Stripping ``vision_tower.`` yields + # exactly that. The separate q/k/v_proj keys under + # `` = encoder.layers.N.self_attn.*`` are fused into + # ``qkv_proj`` by SIGLIP_STACKED_PARAMS at load time. 
+ out[inner.removeprefix("vision_tower.")] = tensor elif inner.startswith("multi_modal_projector.linear."): sub = inner.removeprefix("multi_modal_projector.linear.") out[f"connector.{sub}"] = tensor @@ -641,7 +642,9 @@ def _init_vit_components(self, device: str): # parameter in the target module, so the leftover keys are dropped # without needing an explicit ``strict=False`` switch. siglip_sd = self._extract_siglip_state_dict(flat) - load_hf_weights(self.siglip, siglip_sd.items()) + load_hf_weights( + self.siglip, siglip_sd.items(), stacked_params=SIGLIP_STACKED_PARAMS, + ) def _init_llm_components(self, device: str): if self.embed_tokens is not None: diff --git a/test/modular/test_pi05_model.py b/test/modular/test_pi05_model.py index f373ff19..6b385e40 100644 --- a/test/modular/test_pi05_model.py +++ b/test/modular/test_pi05_model.py @@ -376,7 +376,8 @@ def t(*shape): pali = "paligemma_with_expert.paligemma.model" flat = { - # Vision tower -> vision_model.vision_model. + # Vision tower -> vision_model. (native port; our + # Pi05SiglipEncoder owns ``vision_model`` directly, no double nest). 
f"{pali}.vision_tower.vision_model.embeddings.patch_embedding.weight": t(1152, 3, 14, 14), f"{pali}.vision_tower.vision_model.encoder.layers.0.layer_norm1.weight": t(1152), # multi_modal_projector.linear -> connector @@ -389,8 +390,8 @@ def t(*shape): } siglip = Pi05Model._extract_siglip_state_dict(flat) - assert "vision_model.vision_model.embeddings.patch_embedding.weight" in siglip - assert "vision_model.vision_model.encoder.layers.0.layer_norm1.weight" in siglip + assert "vision_model.embeddings.patch_embedding.weight" in siglip + assert "vision_model.encoder.layers.0.layer_norm1.weight" in siglip assert "connector.weight" in siglip assert "connector.bias" in siglip assert not any("multi_modal_projector" in k for k in siglip) diff --git a/test/modular/test_pi05_reference_equivalence.py b/test/modular/test_pi05_reference_equivalence.py index 89d73963..a48ff73f 100644 --- a/test/modular/test_pi05_reference_equivalence.py +++ b/test/modular/test_pi05_reference_equivalence.py @@ -26,8 +26,9 @@ ``BatchPrefillWithPagedKVCacheWrapper`` against vanilla SDPA, both for the bidirectional prefill and the suffix-attends-to-prefix flow used during the action_gen denoising loop - * Pi05SiglipEncoder produces bit-identical features to a freshly-built - HF SiglipVisionModel with matched weights + * Pi05SiglipEncoder (native port w/ fused QKV + SDPA) produces features + matching a freshly-built HF SiglipVisionModel, loaded via the same + stacked-param path the real checkpoint loader uses The attention used inside the action-expert tests is a small vanilla-SDPA implementation shared by the mock cache handle and the reference code; the @@ -725,17 +726,23 @@ def test_flashinfer_paged_prefill_attention_matches_sdpa(): def test_pi05_siglip_encoder_matches_hf_reference(): - """``Pi05SiglipEncoder`` produces bit-identical features to HF SiglipVisionModel. - - Both wrap the same HF class; the only difference is mstar adds a - ``nn.Linear`` connector to project to the LLM hidden size. 
The reference - PaliGemma uses an analogous ``multi_modal_projector``. We verify the - pre-connector features match exactly and the connector preserves the - expected output shape. + """``Pi05SiglipEncoder`` (native port) matches HF SiglipVisionModel. + + The port fuses q/k/v into one projection and runs SDPA, so it is no + longer the same class as the reference. We load the HF reference's + weights into our encoder through ``load_hf_weights`` with + ``SIGLIP_STACKED_PARAMS`` — the same stacked-param path the real + checkpoint loader uses — then check the pre-connector features match + (allclose, since fused-QKV + SDPA differ from HF only by fp32 rounding) + and the connector preserves the expected output shape. """ from transformers import SiglipVisionConfig, SiglipVisionModel - from mstar.model.pi05.components.siglip import Pi05SiglipEncoder + from mstar.model.loader import load_hf_weights + from mstar.model.pi05.components.siglip import ( + SIGLIP_STACKED_PARAMS, + Pi05SiglipEncoder, + ) torch.manual_seed(0) config = Pi05Config( @@ -748,8 +755,6 @@ def test_pi05_siglip_encoder_matches_hf_reference(): hidden_size=128, ) - ours = Pi05SiglipEncoder(config).to(DEVICE).eval() - siglip_cfg = SiglipVisionConfig( hidden_size=config.vit_hidden_size, intermediate_size=config.vit_intermediate_size, @@ -761,19 +766,34 @@ def test_pi05_siglip_encoder_matches_hf_reference(): # Match Pi05SiglipEncoder, which disables the pooling head to match # the production lerobot/pi05_base checkpoint key set. 
vision_use_head=False, + attn_implementation="sdpa", ) ref_vision = SiglipVisionModel(siglip_cfg).to(DEVICE).eval() - ref_vision.load_state_dict(ours.vision_model.state_dict()) - images = torch.randn(2, 3, config.vit_image_size, config.vit_image_size, device=DEVICE) - with torch.no_grad(): - ref_features = ref_vision(pixel_values=images).last_hidden_state - ours_inner = ours.vision_model(pixel_values=images).last_hidden_state - ours_full = ours(images) + ours = Pi05SiglipEncoder(config).to(DEVICE).eval() + # HF SiglipVisionModel state_dict keys (``vision_model.encoder.layers.N. + # self_attn.{q,k,v,out}_proj.*`` etc.) line up 1:1 with our encoder after + # the stacked-param rules fuse q/k/v into ``qkv_proj``. + load_hf_weights( + ours, ref_vision.state_dict().items(), stacked_params=SIGLIP_STACKED_PARAMS, + ) - # Pre-connector features should be exactly bit-identical (same HF class, - # same weights, same input). - assert torch.equal(ref_features, ours_inner) + images = torch.randn(2, 3, config.vit_image_size, config.vit_image_size, device=DEVICE) + # Disable TF32 for the comparison: the fused [3*H, H] QKV GEMM tiles + # differently from HF's three separate [H, H] GEMMs, so with TF32 tensor + # cores enabled the two paths round differently (~1e-3 abs — negligible + # for actions, but not bit-exact). In true fp32 the port is identical. 
+ tf32_prev = torch.backends.cuda.matmul.allow_tf32 + torch.backends.cuda.matmul.allow_tf32 = False + try: + with torch.no_grad(): + ref_features = ref_vision(pixel_values=images).last_hidden_state + ours_inner = ours.vision_model(images) + ours_full = ours(images) + finally: + torch.backends.cuda.matmul.allow_tf32 = tf32_prev + + torch.testing.assert_close(ours_inner, ref_features, atol=1e-5, rtol=1e-5) # Connector output shape: [batch, num_patches, llm_hidden_size] n_patches = (config.vit_image_size // config.vit_patch_size) ** 2 @@ -787,7 +807,7 @@ def test_pi05_siglip_encoder_matches_hf_reference(): def _ref_resize_with_pad(images: torch.Tensor, target_h: int, target_w: int) -> torch.Tensor: """Reference port of ``image_tools.resize_with_pad_torch`` (channels-first - float32 path). Used to verify ``Pi05ViTEncoderSubmodule._preprocess_one``. + float32 path). Used to verify ``Pi05ViTEncoderSubmodule._prepare_one``. """ assert images.dim() == 4 and images.dtype == torch.float32 _, _, cur_h, cur_w = images.shape @@ -807,7 +827,7 @@ def _ref_resize_with_pad(images: torch.Tensor, target_h: int, target_w: int) -> def test_pi05_image_preprocessing_matches_resize_with_pad_letterbox(): - """``Pi05ViTEncoderSubmodule._preprocess_one`` vs openpi's resize_with_pad_torch. + """``Pi05ViTEncoderSubmodule._prepare_one`` vs openpi's resize_with_pad_torch. 
Tests three cases that exercise the letterbox path: * already-target square (no resize / no pad — identity-ish) @@ -840,7 +860,7 @@ def test_pi05_image_preprocessing_matches_resize_with_pad_letterbox(): for name, shape in cases: torch.manual_seed(hash(name) & 0xFFFF) images = torch.rand(*shape) * 2.0 - 1.0 # [-1, 1] float32 - ours = submodule._preprocess_one(images) + ours = submodule._prepare_one(images) ref = _ref_resize_with_pad(images, cfg.vit_image_size, cfg.vit_image_size) assert ours.shape == ref.shape == (1, 3, 224, 224), f"{name}: shape mismatch" # Padding regions are exactly -1, content region matches the resized @@ -861,7 +881,7 @@ def test_pi05_image_preprocessing_uint8_to_float(): submodule = Pi05ViTEncoderSubmodule(Pi05SiglipEncoder(cfg), cfg) images_u8 = torch.zeros(1, 3, 224, 224, dtype=torch.uint8) images_u8[..., 100:200, 100:200] = 255 - out = submodule._preprocess_one(images_u8) + out = submodule._prepare_one(images_u8) assert out.dtype == torch.float32 # Background pixels (0) -> -1, foreground pixels (255) -> +1. 
assert out[0, 0, 0, 0].item() == pytest.approx(-1.0, abs=1e-6) From 692d43249f7b06e55a671fd0d26c5f34e05d1d47 Mon Sep 17 00:00:00 2001 From: NSagan271 Date: Wed, 10 Jun 2026 01:09:07 +0000 Subject: [PATCH 06/17] single graph walk for pi05 (remove conductor overhead) --- configs/pi05.yaml | 3 +- configs/pi05_droid.yaml | 3 +- mstar/engine/kv_cache_engine.py | 1 + mstar/engine/kv_store.py | 4 + mstar/model/pi05/pi05_model.py | 97 +++---- mstar/model/pi05/submodules.py | 433 +++++++++++++++++--------------- mstar/worker/worker.py | 4 +- 7 files changed, 285 insertions(+), 260 deletions(-) diff --git a/configs/pi05.yaml b/configs/pi05.yaml index e7aaaf09..85202fe7 100644 --- a/configs/pi05.yaml +++ b/configs/pi05.yaml @@ -8,6 +8,7 @@ node_groups: - 0 - node_names: - - LLM + - paligemma_LLM + - action_expert_LLM ranks: - 0 diff --git a/configs/pi05_droid.yaml b/configs/pi05_droid.yaml index a4ae115e..4259823d 100644 --- a/configs/pi05_droid.yaml +++ b/configs/pi05_droid.yaml @@ -32,6 +32,7 @@ node_groups: - 0 - node_names: - - LLM + - paligemma_LLM + - action_expert_LLM ranks: - 0 diff --git a/mstar/engine/kv_cache_engine.py b/mstar/engine/kv_cache_engine.py index 983e85e3..7e3c0503 100644 --- a/mstar/engine/kv_cache_engine.py +++ b/mstar/engine/kv_cache_engine.py @@ -904,6 +904,7 @@ def check_ready( ).items(): if needed_labels is not None and label not in needed_labels: continue + print("ASYNC RETRIEVE HAPPENING") cache_mgmt.alloc_manager.start_async_retrieve( request_id, label, seq_info ) diff --git a/mstar/engine/kv_store.py b/mstar/engine/kv_store.py index 380d5755..b2975429 100644 --- a/mstar/engine/kv_store.py +++ b/mstar/engine/kv_store.py @@ -604,6 +604,10 @@ def add_request(self, request_id: str, labels: list[str]=None): } def remove_request(self, request_id: str): + if request_id not in self.request_states: + # This request has already been removed; e.g., if we have colocated + # nodes sharing a KV cache + return for label in self.request_states[request_id]: 
self.wait_for_retrieves(request_id, label) with self._lock: diff --git a/mstar/model/pi05/pi05_model.py b/mstar/model/pi05/pi05_model.py index eab123be..5475d012 100644 --- a/mstar/model/pi05/pi05_model.py +++ b/mstar/model/pi05/pi05_model.py @@ -51,7 +51,7 @@ from mstar.model.pi05.components.siglip import SIGLIP_STACKED_PARAMS, Pi05SiglipEncoder from mstar.model.pi05.components.tokenization import Pi05Tokenizer from mstar.model.pi05.config import Pi05Config, load_pi05_config -from mstar.model.pi05.submodules import Pi05LLMSubmodule, Pi05ViTEncoderSubmodule +from mstar.model.pi05.submodules import Pi05ActionExpertSubmodule, Pi05PaligemmaSubmodule, Pi05ViTEncoderSubmodule from mstar.model.submodule_base import NodeSubmodule logger = logging.getLogger(__name__) @@ -122,8 +122,6 @@ def _reset_non_persistent_buffers(module: nn.Module, device) -> None: class Pi05Model(Model): """Pi0.5 vision-language-action model implementation.""" - - PREFILL_WALK = "prefill" ACTION_GEN_WALK = "action_gen" def __init__( @@ -387,54 +385,50 @@ def get_kv_cache_config(self) -> KVCacheConfig: head_dim=self.config.head_dim, max_seq_len=self.config.max_position_embeddings, num_qo_heads=self.config.num_qo_heads, + nodes=["paligemma_LLM", "action_expert_LLM"] )] def get_node_engine_types(self) -> dict[str, EngineType]: return { "vit_encoder": EngineType.STATELESS, - "LLM": EngineType.KV_CACHE, + "paligemma_LLM": EngineType.KV_CACHE, + "action_expert_LLM": EngineType.KV_CACHE, } def get_graph_walk_graphs(self) -> dict[str, GraphSection]: - # Pi0.5 encodes the robot state as a decimal-string suffix on the - # language prompt (e.g. "Task: pick up the block, State: 12 87 ...; - # \nAction: ") and tokenizes the whole thing with the PaliGemma - # tokenizer. So the model only ever sees a single "text_inputs" - # stream — there are no separate state-bin tokens. This matches - # lerobot's processor_pi05.Pi05PrepareStateTokenizerProcessorStep. 
- prefill = Sequential( + # NOTE: the full action generation flow loop is extremely short (total < 50ms), so + # we opt to have it as one node to reduce cuda graph startup, flashinfer planning, + # etc. overhead. Cache planning only needs to happen at the beginning of the flow + # loop, so this collapsed loop is valid. + action_gen = Sequential( [ GraphNode( name="vit_encoder", input_names=["image_inputs"], - outputs=[GraphEdge(next_node="LLM", name="img_emb")], + outputs=[GraphEdge(next_node="paligemma_LLM", name="img_emb")], ), GraphNode( - name="LLM", + name="paligemma_LLM", input_names=["img_emb", "text_inputs"], - outputs=[], + outputs=[ + GraphEdge(next_node="action_expert_LLM", name="action_expert_trigger") + ], ), - ] - ) - - # NOTE: the full action generation flow loop is extremely short (total < 50ms), so - # we opt to have it as one node to reduce cuda graph startup, flashinfer planning, - # etc. overhead. Cache planning only needs to happen at the beginning of the flow - # loop, so this collapsed loop is valid. 
- action_gen = GraphNode( - name="LLM", - input_names=["noisy_actions"], - outputs=[ - GraphEdge( - next_node=EMIT_TO_CLIENT, - name="actions", - output_modality="action", + GraphNode( + name="action_expert_LLM", + input_names=["action_expert_trigger"], + outputs=[ + GraphEdge( + next_node=EMIT_TO_CLIENT, + name="actions", + output_modality="action", + ) + ], ) - ], + ] ) return { - self.PREFILL_WALK: prefill, self.ACTION_GEN_WALK: action_gen, } @@ -513,7 +507,7 @@ def get_initial_forward_pass_args( full_metadata = CurrentForwardConductorMetadata( input_modalities=input_modalities, output_modalities=output_modalities, - graph_walk=self.PREFILL_WALK, + graph_walk=self.ACTION_GEN_WALK, is_prefill=True, kwargs={}, ) @@ -524,7 +518,7 @@ def get_initial_forward_pass_args( edge.tensor_info = input_signals["image_inputs"] inputs.append(edge) if "text_inputs" in input_signals: - edge = GraphEdge(next_node="LLM", name="text_inputs") + edge = GraphEdge(next_node="paligemma_LLM", name="text_inputs") edge.tensor_info = input_signals["text_inputs"] inputs.append(edge) @@ -533,7 +527,6 @@ def get_initial_forward_pass_args( full_metadata=full_metadata, inputs=inputs, unpersist_tensors=unpersist_tensors, - step_metadata={"is_prefill": True}, ) def get_partition_forward_pass_args( @@ -543,29 +536,12 @@ def get_partition_forward_pass_args( persist_signals: dict[str, list[TensorPointerInfo]], incoming_connections: list[StreamingConnectionState] | None = None, ) -> ForwardPassArgs: - metadata = partition_metadata - request_done = False - inputs: list[GraphEdge] = [] - - if metadata.graph_walk == self.PREFILL_WALK: - metadata.is_prefill = False - metadata.graph_walk = self.ACTION_GEN_WALK - # Inputs for the first action_gen iteration are sampled inside the - # LLM submodule's preprocess (Gaussian noise + timestep_index=0). 
- inputs = [ - GraphEdge(next_node="LLM", name="noisy_actions"), - GraphEdge(next_node="LLM", name="timestep_index"), - ] - elif metadata.graph_walk == self.ACTION_GEN_WALK: - request_done = True - - unpersist_tensors = sum([inp.tensor_info for inp in inputs], start=[]) + # only one graph walk, so we're done return ForwardPassArgs( - full_metadata=metadata, - inputs=inputs, - unpersist_tensors=unpersist_tensors, - step_metadata={"is_prefill": metadata.is_prefill}, - request_done=request_done, + full_metadata=partition_metadata, + inputs=[], + unpersist_tensors=[], + request_done=True, ) # ------------------------------------------------------------------ @@ -591,11 +567,16 @@ def _create_submodule( return Pi05ViTEncoderSubmodule( encoder=self.siglip, config=self.config ) - if node_name == "LLM": + if node_name == "paligemma_LLM": self._init_llm_components(device) - return Pi05LLMSubmodule( + return Pi05PaligemmaSubmodule( embed_tokens=self.embed_tokens, paligemma=self.paligemma, + config=self.config, + ) + if node_name == "action_expert_LLM": + self._init_llm_components(device) + return Pi05ActionExpertSubmodule( action_expert=self.action_expert, action_in_proj=self.action_in_proj, action_out_proj=self.action_out_proj, diff --git a/mstar/model/pi05/submodules.py b/mstar/model/pi05/submodules.py index a29ce11b..12e16a84 100644 --- a/mstar/model/pi05/submodules.py +++ b/mstar/model/pi05/submodules.py @@ -158,7 +158,7 @@ def get_cuda_graph_configs(self, device: torch.device, tp_world_size: int = 1) - num_cameras = self.config.num_cameras return [ BasicBatchedCudaGraphConfig( - capture_graph_walk="prefill", + capture_graph_walk="action_gen", single_request_inputs=ARNodeInputs( input_seq_len=0, # not used by StatelessCudaGraphRunner tensor_inputs={ @@ -251,20 +251,10 @@ def forward_batched( } -class Pi05LLMSubmodule(ARNodeSubmodule): - """Combined PaliGemma prefix expert + action expert. 
- - Dispatches by graph_walk: - - ``prefill``: PaliGemma forwards over the prefix - ``[image_tokens, language_tokens, state_tokens]`` and - writes the KV cache. - - ``action_gen``: action expert runs one Euler step of flow-matching - denoising over the action suffix, attending to the - frozen prefix KV cache. The current ``noisy_actions`` - and ``timestep_index`` cycle through the loop via - loop-back graph edges; on the final iteration the - denoised action tensor is emitted as ``action_output``. - """ +class Pi05PaligemmaSubmodule(ARNodeSubmodule): + """PaliGemma prefix expert: forwards over the prefix + ``[image_tokens, language_tokens]`` and writes the KV cache that the + action expert later reads.""" # Parameter name fragments whose weights must stay in float32 even when # the rest of the model is bf16. Matches lerobot's @@ -280,12 +270,229 @@ class Pi05LLMSubmodule(ARNodeSubmodule): # For the default image size and a simple text prompt, one request is about 400 tokens PREFILL_TOKEN_BUCKETS = [512, 1024, 1800] # 2048 was giving OOM PREFILL_CAPTURE_BATCH_SIZES = [1, 2, 4] - ACTION_GEN_CAPTURE_BATCH_SIZES = [1, 2, 4] def __init__( self, embed_tokens: nn.Embedding, paligemma: Pi05PaliGemmaExpert, + config: Pi05Config, + ): + super().__init__() + self.embed_tokens = embed_tokens + self.paligemma = paligemma + self.config = config + # lerobot scales images by sqrt(H) but text by H: its + # embed_language_tokens routes through HF Gemma's + # GemmaTextScaledWordEmbedding, which already bakes in a sqrt(H) factor, + # so the effective text scale is sqrt(H)*sqrt(H) = H. Our plain + # nn.Embedding has no internal scale, so we apply the full H here. + # Mismatching makes the text prefix ~45x too small and corrupts context. + self._image_embed_scale = math.sqrt(config.hidden_size) + self._text_embed_scale = float(config.hidden_size) + + def to(self, *args, **kwargs): + """Apply standard ``to()`` then upcast norm parameters back to fp32. 
+ + Matches lerobot's ``to_bfloat16_for_selected_params`` which keeps + ``input_layernorm``, ``post_attention_layernorm``, and ``model.norm`` + in float32 while the rest of the transformer runs in bfloat16. + """ + result = super().to(*args, **kwargs) + for name, param in result.named_parameters(): + if any(sel in name for sel in self._FLOAT32_PARAM_SELECTORS): + param.data = param.data.to(torch.float32) + return result + + def can_batch( + self, + batch: NodeBatch, + model_inputs: list[NodeInputs], + ) -> bool: + return True + + def get_needed_cache_labels( + self, + graph_walk: str, + per_request_info: dict[str, CurrentForwardPassInfo], + ) -> list[str] | None: + return ["main"] + + def _embed_tokens_scaled(self, ids: torch.Tensor) -> torch.Tensor: + emb = self.embed_tokens(ids) + return emb * self._text_embed_scale + + def get_cuda_graph_configs( + self, device: torch.device, tp_world_size: int = 1, + ) -> list[BasicBatchedCudaGraphConfig | FlashInferPackedCudaGraphConfig]: + prefill_packed = { + num_tokens: { + "prefix_embs": torch.zeros(num_tokens, self.config.hidden_size, device=device) + } + for num_tokens in self.PREFILL_TOKEN_BUCKETS + } + return [ + FlashInferPackedCudaGraphConfig( + capture_graph_walk="action_gen", + packed_seq_len_to_inputs=prefill_packed, + requires_cfg=False, + labels=["main"], + compile=True, + causal_attention=False, + capture_batch_sizes=self.PREFILL_CAPTURE_BATCH_SIZES, + ), + ] + + def prepare_inputs( + self, + graph_walk: str, + fwd_info: CurrentForwardPassInfo, + inputs: NameToTensorList, + **kwargs + ) -> ARNodeInputs: + return self._prepare_inputs_prefill( + inputs=inputs, + fwd_info=fwd_info, + ) + + def _prepare_inputs_prefill( + self, + inputs: NameToTensorList, + **kwargs + ) -> ARNodeInputs: + # Prefix layout [image_tokens, language_tokens]. Robot state is not a + # separate stream — process_prompt already appended it as a decimal + # suffix on the prompt. 
Image and text embeds use different scales (see + # __init__); applying them here is load-bearing. + img_emb = inputs["img_emb"][0] * self._image_embed_scale + text_ids = inputs["text_inputs"][0] + text_emb = self._embed_tokens_scaled(text_ids) + prefix_emb = torch.cat([img_emb, text_emb], dim=0) + seq_len = prefix_emb.shape[0] + + return ARNodeInputs(input_embeds=prefix_emb, input_seq_len=seq_len) + + + def preprocess( + self, + graph_walk: str, + engine_inputs: ModelInputsFromEngine, + inputs: list[ARNodeInputs], + ) -> dict[str, torch.Tensor | Any]: + + return self._preprocess_prefill( + inputs=inputs, + cache_manager=engine_inputs.cache_manager, + ) + + def _preprocess_prefill( + self, + inputs: list[ARNodeInputs], + cache_manager: BatchedCacheManager, + ) -> dict[str, torch.Tensor | Any]: + per_request_seqs = [inp.input_embeds for inp in inputs] + prefix_embs = torch.cat(per_request_seqs, dim=0) + seq_lens = [inp.input_seq_len for inp in inputs] + + # Bidirectional attention over the prefix; PaliGemma is a prefix-LM. + cache_manager.plan_attention( + seq_lens=seq_lens, is_causal=False, label="main", dtype=torch.bfloat16 + ) + cache_manager.plan_rope(seq_lens=seq_lens, pos_ids=None, label="main") + return {"prefix_embs": prefix_embs} + + # ------------------------------------------------------------------ + # forward + # ------------------------------------------------------------------ + def forward( + self, + graph_walk: str, + engine_inputs: ModelInputsFromEngine, + **kwargs # coming from preprocess output + ) -> NameToTensorList: + cache_handle=engine_inputs.cache_manager + + return self._forward_prefill(cache_handle=cache_handle, **kwargs) + + def forward_batched( + self, + graph_walk: str, + engine_inputs: ModelInputsFromEngine, + **kwargs, # coming from preprocess output + ) -> dict[str, NameToTensorList]: + """Batched forward: process all requests in a single transformer pass. 
+ + Called by ``KVCacheEngine._execute_batched`` when ``can_batch()`` returns + True. ``packed_inputs`` comes from ``preprocess()`` which already + concatenated per-request tensors and planned attention/RoPE for the + full batch. + """ + + return self._forward_prefill_batched( + cache_manager=engine_inputs.cache_manager, + request_ids=engine_inputs.request_ids, + **kwargs, + ) + + + def _forward_prefill_batched( + self, + cache_manager: BatchedCacheManager, + request_ids: list[str], + prefix_embs: torch.Tensor, + **kwargs, + ) -> dict[str, NameToTensorList]: + """Batched prefill: single PaliGemma forward over concatenated prefixes.""" + cache_manager.set_active_label("main") + self.paligemma( + query_sequence=prefix_embs, + cache_handle=cache_manager, + write_cache=True, + ) + # Prefill produces no graph-edge outputs. + return {rid: {} for rid in request_ids} + + def _forward_prefill( + self, + prefix_embs: torch.Tensor, + cache_handle: BatchedCacheManager, + **kwargs, + ) -> NameToTensorList: + if cache_handle is not None: + cache_handle.set_active_label("main") + self.paligemma( + query_sequence=prefix_embs, + cache_handle=cache_handle, + write_cache=True, + ) + return {} + + def postprocess(self, request_id, request_info, outputs, **kwargs): + outputs["action_expert_trigger"] = [] + + +class Pi05ActionExpertSubmodule(ARNodeSubmodule): + """Action expert flow-matching loop. + + Runs all ``num_flow_steps`` Euler denoising steps over the action suffix + in a single forward, attending to the frozen prefix KV cache that the + PaliGemma submodule wrote, then emits the denoised action tensor. + """ + + # Parameter name fragments whose weights must stay in float32 even when + # the rest of the model is bf16. Matches lerobot's + # ``to_bfloat16_for_selected_params`` — keeping norms in fp32 prevents + # the per-layer precision loss that otherwise compounds across 18 layers + # and causes ~0.2 abs delta on the final actions. 
+ _FLOAT32_PARAM_SELECTORS = ( + "input_layernorm", + "post_attention_layernorm", + ".norm.", # final RMSNorm / adaRMS norm + ) + + ACTION_GEN_CAPTURE_BATCH_SIZES = [1, 2, 4] + + def __init__( + self, action_expert: Pi05ActionExpert, action_in_proj: nn.Linear, action_out_proj: nn.Linear, @@ -293,34 +500,11 @@ def __init__( config: Pi05Config, ): super().__init__() - self.embed_tokens = embed_tokens - self.paligemma = paligemma self.action_expert = action_expert self.action_in_proj = action_in_proj self.action_out_proj = action_out_proj self.time_mlp = time_mlp self.config = config - # Image features and language token embeddings use DIFFERENT scaling - # factors in lerobot's reference, even though both end up calling it - # ``sqrt(hidden_size)``: - # - # * Images: ``embed_image`` returns - # ``connector(siglip_features) * sqrt(hidden_size)`` -> scale = sqrt(H). - # - # * Text: lerobot's ``lang_embed_func`` does - # ``embed_language_tokens(tokens) * sqrt(hidden_size)``, but - # ``embed_language_tokens`` calls HF Gemma's - # ``GemmaTextScaledWordEmbedding`` whose ``forward`` already - # multiplies the raw lookup by an internal ``embed_scale = - # sqrt(hidden_size)``. So the EFFECTIVE text scale is - # ``sqrt(H) * sqrt(H) = H``, not ``sqrt(H)``. - # - # We use a plain ``nn.Embedding`` for ``embed_tokens`` (no internal - # scale), so we have to apply the full ``H`` factor manually here. - # Mismatching this produces a ~45x undersized text prefix and the - # action expert sees a wildly wrong context. - self._image_embed_scale = math.sqrt(config.hidden_size) - self._text_embed_scale = float(config.hidden_size) # Lazily allocated on first action Euler step, sized for the largest # captured batch. sincos_timestep_embedding fully overwrites this buffer @@ -346,16 +530,6 @@ def can_batch( batch: NodeBatch, model_inputs: list[NodeInputs], ) -> bool: - """Pi0.5 supports batched execution for both graph walks. 
- - - ``prefill``: prefix embeddings are concatenated across requests and - processed in a single PaliGemma forward with batched FlashInfer - attention. Each request can have a different prefix length (different - text prompt lengths). - - ``action_gen``: all requests in a batch are at the same Euler - iteration (guaranteed by the Loop primitive), so their suffix tokens - can be concatenated and processed in a single action expert forward. - """ return True def get_needed_cache_labels( @@ -418,12 +592,6 @@ def get_cuda_graph_configs( self.config.action_horizon, self.config.action_dim, self.config.num_flow_steps, ) - prefill_packed = { - num_tokens: { - "prefix_embs": torch.zeros(num_tokens, self.config.hidden_size, device=device) - } - for num_tokens in self.PREFILL_TOKEN_BUCKETS - } return [ # Action generation always has latents of the same size, so it is a similar # paradigm to AR decode and can use the batched cuda graphs @@ -440,15 +608,6 @@ def get_cuda_graph_configs( ), capture_batch_sizes=self.ACTION_GEN_CAPTURE_BATCH_SIZES ), - FlashInferPackedCudaGraphConfig( - capture_graph_walk="prefill", - packed_seq_len_to_inputs=prefill_packed, - requires_cfg=False, - labels=["main"], - compile=True, - causal_attention=False, - capture_batch_sizes=self.PREFILL_CAPTURE_BATCH_SIZES, - ), ] def prepare_inputs( @@ -458,46 +617,10 @@ def prepare_inputs( inputs: NameToTensorList, **kwargs ) -> ARNodeInputs: - if graph_walk == "prefill": - return self._prepare_inputs_prefill( - inputs=inputs, - ) - if graph_walk == "action_gen": - return self._prepare_inputs_action_gen( - inputs=inputs, - fwd_info=fwd_info, - ) - raise ValueError(f"Unknown Pi0.5 LLM graph walk: {graph_walk!r}") - - def _prepare_inputs_prefill( - self, - inputs: NameToTensorList, - **kwargs - ) -> ARNodeInputs: - # Pi0.5 prefix layout (matches lerobot's embed_prefix): - # [image_tokens, language_tokens] - # The robot state is *not* a separate token stream — it has already - # been formatted as a 
decimal-string suffix on the language prompt - # by ``Pi05Model.process_prompt``, then tokenized by the PaliGemma - # tokenizer. So the LLM only consumes ``img_emb`` + ``text_inputs``. - # - # IMPORTANT: lerobot's ``embed_prefix`` scales BOTH the image features - # (after the multi_modal_projector) and the language token embeddings - # by ``sqrt(hidden_size)``. We mirror that here. Without the image - # scaling the SigLIP tokens come in ~sqrt(2048)≈45x too small relative - # to the language tokens and the action expert sees a wildly wrong - # prefix. (The standalone test_pi05_model_loaded_via_remapper_matches_ - # lerobot integration test missed this because it bypasses - # _preprocess_prefill and feeds in lerobot's pre-scaled embed_prefix - # output directly.) - - img_emb = inputs["img_emb"][0] * self._image_embed_scale - text_ids = inputs["text_inputs"][0] - text_emb = self._embed_tokens_scaled(text_ids) - prefix_emb = torch.cat([img_emb, text_emb], dim=0) - seq_len = prefix_emb.shape[0] - - return ARNodeInputs(input_embeds=prefix_emb, input_seq_len=seq_len) + return self._prepare_inputs_action_gen( + inputs=inputs, + fwd_info=fwd_info, + ) def _prepare_inputs_action_gen( self, @@ -524,41 +647,16 @@ def _prepare_inputs_action_gen( } ) - def preprocess( self, graph_walk: str, engine_inputs: ModelInputsFromEngine, inputs: list[ARNodeInputs], ) -> dict[str, torch.Tensor | Any]: - - if graph_walk == "prefill": - return self._preprocess_prefill( - inputs=inputs, - cache_manager=engine_inputs.cache_manager, - ) - if graph_walk == "action_gen": - return self._preprocess_action_gen( - inputs=inputs, - cache_manager=engine_inputs.cache_manager, - ) - - def _preprocess_prefill( - self, - inputs: list[ARNodeInputs], - cache_manager: BatchedCacheManager, - ) -> dict[str, torch.Tensor | Any]: - per_request_seqs = [inp.input_embeds for inp in inputs] - prefix_embs = torch.cat(per_request_seqs, dim=0) - seq_lens = [inp.input_seq_len for inp in inputs] - - # Bidirectional 
attention over the prefix; PaliGemma is a prefix-LM. - cache_manager.plan_attention( - seq_lens=seq_lens, is_causal=False, label="main", dtype=torch.bfloat16 + return self._preprocess_action_gen( + inputs=inputs, + cache_manager=engine_inputs.cache_manager, ) - cache_manager.plan_rope(seq_lens=seq_lens, pos_ids=None, label="main") - - return {"prefix_embs": prefix_embs} def _preprocess_action_gen( self, @@ -609,12 +707,7 @@ def forward( **kwargs # coming from preprocess output ) -> NameToTensorList: cache_handle=engine_inputs.cache_manager - - if graph_walk == "prefill": - return self._forward_prefill(cache_handle=cache_handle, **kwargs) - if graph_walk == "action_gen": - return self._forward_action_gen(cache_handle=cache_handle, **kwargs) - raise ValueError(f"Unknown Pi0.5 LLM graph walk: {graph_walk!r}") + return self._forward_action_gen(cache_handle=cache_handle, **kwargs) def forward_batched( self, @@ -622,45 +715,11 @@ def forward_batched( engine_inputs: ModelInputsFromEngine, **kwargs, # coming from preprocess output ) -> dict[str, NameToTensorList]: - """Batched forward: process all requests in a single transformer pass. - - Called by ``KVCacheEngine._execute_batched`` when ``can_batch()`` returns - True. ``packed_inputs`` comes from ``preprocess()`` which already - concatenated per-request tensors and planned attention/RoPE for the - full batch. 
- """ - - if graph_walk == "prefill": - return self._forward_prefill_batched( - cache_manager=engine_inputs.cache_manager, - request_ids=engine_inputs.request_ids, - **kwargs, - ) - if graph_walk == "action_gen": - return self._forward_action_gen_batched( - cache_manager=engine_inputs.cache_manager, - request_ids=engine_inputs.request_ids, - **kwargs, - ) - raise ValueError(f"Batched forward not supported for graph walk: {graph_walk!r}") - - - def _forward_prefill_batched( - self, - cache_manager: BatchedCacheManager, - request_ids: list[str], - prefix_embs: torch.Tensor, - **kwargs, - ) -> dict[str, NameToTensorList]: - """Batched prefill: single PaliGemma forward over concatenated prefixes.""" - cache_manager.set_active_label("main") - self.paligemma( - query_sequence=prefix_embs, - cache_handle=cache_manager, - write_cache=True, + return self._forward_action_gen_batched( + cache_manager=engine_inputs.cache_manager, + request_ids=engine_inputs.request_ids, + **kwargs, ) - # Prefill produces no graph-edge outputs. - return {rid: {} for rid in request_ids} def _forward_action_gen_batched( self, @@ -694,21 +753,6 @@ def _forward_action_gen_batched( } return result - def _forward_prefill( - self, - prefix_embs: torch.Tensor, - cache_handle: BatchedCacheManager, - **kwargs, - ) -> NameToTensorList: - if cache_handle is not None: - cache_handle.set_active_label("main") - self.paligemma( - query_sequence=prefix_embs, - cache_handle=cache_handle, - write_cache=True, - ) - return {} - def _forward_action_gen( self, noisy_actions, @@ -738,13 +782,6 @@ def _forward_action_gen( time_emb_buffer=time_emb_buffer, cache_handle=cache_handle ) - # We ALWAYS return both loop-back edges, even on the final iteration. 
- # The Loop primitive (mstar/graph/base.py:Loop) handles the final-iter - # swap automatically: it matches the section's output ``noisy_actions`` - # to the Loop's terminal output (also named ``noisy_actions``, but - # routed to EMIT_TO_CLIENT with ``output_modality="action"``) and - # filters out the ``timestep_index`` loop-back edge. Same convention - # BAGEL's image_gen uses for ``latents`` / ``time_index``. return { "actions": [noisy_actions], } diff --git a/mstar/worker/worker.py b/mstar/worker/worker.py index 2e85ec53..a3d5c0f2 100644 --- a/mstar/worker/worker.py +++ b/mstar/worker/worker.py @@ -1511,8 +1511,8 @@ def _thread_outputs_to_speculative( rid_outputs = output_N.per_request_output_tensors.get(rid, {}) ok = True for input_name, _ in speculation.consumed_edges: - tensors = rid_outputs.get(input_name, []) - if not tensors: + tensors = rid_outputs.get(input_name, None) + if tensors is None: ok = False break speculation.node_batch.per_request_input_tensors[rid][input_name] \ From 99118e958c31d21ada27db4abf7f98781d518fb7 Mon Sep 17 00:00:00 2001 From: NSagan271 Date: Wed, 10 Jun 2026 02:22:00 +0000 Subject: [PATCH 07/17] IN PROGRESS remove PNG loading overhead --- benchmark/request.py | 29 +++++++++++++- mstar/api_server/data_worker.py | 57 ++++++++++++++++++++++++++- mstar/api_server/entrypoint.py | 69 ++++++++++++++++++++++++++++++++- mstar/model/base.py | 16 +++++++- 4 files changed, 165 insertions(+), 6 deletions(-) diff --git a/benchmark/request.py b/benchmark/request.py index 009646c4..ecb12911 100644 --- a/benchmark/request.py +++ b/benchmark/request.py @@ -17,6 +17,31 @@ from benchmark.base import Bagel, Model, Orpheus, RequestType, Status from benchmark.utils import _write_wav +# Optional: send images as a raw uint8 CxHxW .npy instead of PNG (MMINF_BENCH_RAW=1). +# PNG decode in the data worker is ~2ms/image (zlib inflate + unfiltering), +# independent of resolution; a raw array is np.load'd for ~0. 
Lossless — the +# bytes are the decoded pixels (bit-identical to the server's torchvision +# decode), self-described by the .npy header. Trades a larger upload (~173KB vs +# ~95KB PNG) for ~zero server-side decode. The server's load_image sniffs the +# .npy magic bytes, so only the client bytes change. +_BENCH_RAW = os.environ.get("MMINF_BENCH_RAW", "") not in ("", "0", "false") + + +def _maybe_raw(data: bytes) -> bytes: + """Re-encode PNG/JPEG bytes as a raw CxHxW uint8 .npy when MMINF_BENCH_RAW set.""" + if not _BENCH_RAW: + return data + import numpy as np + import torch + import torchvision + + chw = torchvision.io.decode_image( + torch.frombuffer(bytearray(data), dtype=torch.uint8) + ).numpy() # CxHxW uint8 — matches the server's decode exactly + buf = io.BytesIO() + np.save(buf, chw) + return buf.getvalue() + @dataclass class LatencyStats: @@ -623,7 +648,7 @@ class RequestInput: def __post_init__(self): if self.image_path and self._image_bytes is None: - self._image_bytes = Path(self.image_path).read_bytes() + self._image_bytes = _maybe_raw(Path(self.image_path).read_bytes()) self._image_b64 = base64.b64encode(self._image_bytes).decode() if self.audio_path and self._audio_bytes is None: self._audio_bytes = Path(self.audio_path).read_bytes() @@ -632,7 +657,7 @@ def __post_init__(self): self._video_bytes = Path(self.video_path).read_bytes() self._video_b64 = base64.b64encode(self._video_bytes).decode() if self.extra_image_paths and not self._extra_image_bytes: - self._extra_image_bytes = [Path(p).read_bytes() for p in self.extra_image_paths] + self._extra_image_bytes = [_maybe_raw(Path(p).read_bytes()) for p in self.extra_image_paths] def get_all_filepaths(self) -> dict[str, str]: res = {} diff --git a/mstar/api_server/data_worker.py b/mstar/api_server/data_worker.py index c7dabb7b..8225cbc6 100644 --- a/mstar/api_server/data_worker.py +++ b/mstar/api_server/data_worker.py @@ -1,6 +1,7 @@ import logging +import os import queue import threading import time @@ 
-30,6 +31,17 @@ logger = logging.getLogger(__name__) +# Lightweight, env-gated timing prints (MMINF_TIMING=1). perf_counter is +# process-wide monotonic, so timestamps stamped in the API-server handler +# thread and read in this data-worker thread are directly comparable — that's +# how queue-wait (polling) latency is separated from actual work below. +_TIMING = os.environ.get("MMINF_TIMING", "") not in ("", "0", "false") + + +def _tlog(msg: str) -> None: + if _TIMING: + print(f"[DW-TIMING] {msg}", flush=True) + def _preprocess_loop(**kwargs): worker = PreprocessWorkerThread(**kwargs) @@ -75,6 +87,7 @@ def __init__( self.thread.start() def new_request(self, input: PreprocessInput): + input._t_enqueue = time.perf_counter() # for queue-wait timing self.output_loop_idxs[input.request_id] = {} self.per_request_reading_tensors[input.request_id] = 0 self.request_input_queue.put(input) @@ -157,6 +170,7 @@ def __init__( self.model = model self.tensor_uuid_to_metadata_per_request = {} + self._t_read_start: dict[str, float] = {} # request_id -> read-start time self.communicator = ZMQCommunicator( my_id="api_server_preprocess_worker", @@ -176,6 +190,8 @@ def __init__( def _process_input( self, input: PreprocessInput ): + _t0 = time.perf_counter() + _enq = getattr(input, "_t_enqueue", None) tensors: NameToTensorList = {} input_metadata = {} @@ -208,6 +224,8 @@ def _process_input( input_metadata[key].append(out.metadata) + _t_load = time.perf_counter() # media decode (load_image/audio/video) done + # Then, tokenize the prompt and let the model augment/transform the # tensors dict (e.g., Qwen3-Omni needs to compute pixel_values, # image_grid_thw, audio_features, audio_seqlens from the raw tensors @@ -231,6 +249,8 @@ def _process_input( list(byte_data), dtype=torch.uint8, device=self.device )] + _t_prompt = time.perf_counter() # tokenization / process_prompt done + initial_signals = self.tensor_manager.store_and_return_tensor_info( request_id=input.request_id, tensors=tensors # 
dict(modality_input: list[tensors]) @@ -248,6 +268,8 @@ def _process_input( input.request_id, uuid, persist=True ) + _t_store = time.perf_counter() # tensor store/register/persist done + msg = ConductorMessage( message_type=ConductorMessageType.NEW_REQUEST, body=NewRequestConductor( @@ -260,10 +282,26 @@ def _process_input( ), ) self.communicator.send("conductor", msg) + if _TIMING: + _t_send = time.perf_counter() + _qwait = (_t0 - _enq) * 1e3 if _enq is not None else -1.0 + _imgs = tensors.get("image_inputs") or [] + _img_shape = tuple(_imgs[0].shape) if _imgs else None + _tlog( + f"{input.request_id[:8]} INPUT " + f"img={_img_shape}x{len(_imgs)} " # decoded shape x count (decode cost driver) + f"qwait={_qwait:.2f} " # enqueue->dequeue (polling) + f"load={(_t_load - _t0) * 1e3:.2f} " # media decode + f"prompt={(_t_prompt - _t_load) * 1e3:.2f} " # tokenize + f"store={(_t_store - _t_prompt) * 1e3:.2f} " # tensor store/register + f"send={(_t_send - _t_store) * 1e3:.2f} " # zmq send to conductor + f"total={(_t_send - _t0) * 1e3:.2f}ms" + ) def _read_result_tensor( self, result: ResultTensors ): + self._t_read_start[result.request_id] = time.perf_counter() result.graph_edge.name = f"{result.modality}_output" self.tensor_manager.start_read_tensors( request_id=result.request_id, @@ -279,18 +317,31 @@ def _process_read_tensors(self): did_work = False for request_id, graph_edges in self.tensor_manager.get_ready_tensors().items(): did_work = True + _t_ready = time.perf_counter() # tensor became ready (RDMA read done) + _read_start = self._t_read_start.pop(request_id, None) for graph_edge in graph_edges: modality = graph_edge.name.replace("_output", "") for tensor_info in graph_edge.tensor_info: logger.debug("Reading in OUTPUT tensor %s with uuid %s", graph_edge.name, tensor_info.uuid) + _t_a = time.perf_counter() tensor = self.tensor_manager.get_tensor( request_id=request_id, uuid=tensor_info.uuid ) + _t_get = time.perf_counter() postprocessed = self.model.postprocess( 
tensor, modality ) + _t_post = time.perf_counter() + if _TIMING: + _rw = (_t_ready - _read_start) * 1e3 if _read_start else -1.0 + _tlog( + f"{request_id[:8]} OUTPUT " + f"read_wait={_rw:.2f} " # start_read -> ready (RDMA + polling) + f"get={(_t_get - _t_a) * 1e3:.2f} " # fetch tensor handle + f"post={(_t_post - _t_get) * 1e3:.2f}ms" # model.postprocess + ) chunk_metadata = self.tensor_uuid_to_metadata_per_request[request_id][ tensor_info.uuid] or {} @@ -302,12 +353,14 @@ def _process_read_tensors(self): "sample_rate": self.model.get_output_sample_rate("audio"), } - self.out_queue.put(ResultChunk( + _chunk = ResultChunk( request_id=request_id, modality=modality, data=postprocessed, metadata=chunk_metadata, - )) + ) + _chunk._t_outqueue = time.perf_counter() + self.out_queue.put(_chunk) del self.tensor_uuid_to_metadata_per_request[request_id][ tensor_info.uuid] self.tensor_manager.dereference( diff --git a/mstar/api_server/entrypoint.py b/mstar/api_server/entrypoint.py index d71611b2..40f20410 100644 --- a/mstar/api_server/entrypoint.py +++ b/mstar/api_server/entrypoint.py @@ -15,7 +15,7 @@ from typing import Optional import uvicorn -from fastapi import FastAPI, File, Form, HTTPException, UploadFile +from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse from starlette.concurrency import run_in_threadpool @@ -28,6 +28,16 @@ logger = logging.getLogger(__name__) +# Env-gated timing prints (MMINF_TIMING=1); pairs with the [DW-TIMING] prints +# in data_worker.py to split HTTP/handler overhead from data-worker work. +_TIMING = os.environ.get("MMINF_TIMING", "") not in ("", "0", "false") + + +def _tlog(msg: str) -> None: + if _TIMING: + print(f"[API-TIMING] {msg}", flush=True) + + SUPPORTED_MODALITIES = frozenset({"text", "image", "audio", "video", "action", "scalar", "tensor"}) # Extension-based modality detection for uploaded files. 
@@ -254,6 +264,7 @@ def submit_request( "input_modalities": input_modalities, "output_modalities": output_modalities, "final_outputs": {}, + "_t_submit": time.perf_counter(), # for end-to-end wait timing } self.preprocess_worker.new_request(PreprocessInput( @@ -335,6 +346,14 @@ def _process_messages(self) -> None: result_chunk.modality, result_chunk.request_id ) rid = result_chunk.request_id + if _TIMING: + _oq = getattr(result_chunk, "_t_outqueue", None) + if _oq is not None: + _tlog( + f"{rid[:8]} CHUNK " + # out_queue.put -> picked up here (output polling hop) + f"outq_wait={(time.perf_counter() - _oq) * 1e3:.2f}ms" + ) with self.request_lock: self.pending_requests[rid]["chunks"].append( result_chunk @@ -360,6 +379,10 @@ async def iter_result_chunks(self, request_id: str): pre-serialized line). """ start = time.time() + with self.request_lock: + _req0 = self.pending_requests.get(request_id) + _t_submit = _req0["_t_submit"] if _req0 else None + _t_first = None while True: if time.time() - start > self.timeout_seconds: with self.request_lock: @@ -380,9 +403,19 @@ async def iter_result_chunks(self, request_id: str): done = True for chunk in new_chunks: + if _t_first is None: + _t_first = time.perf_counter() yield chunk if done: + if _TIMING and _t_submit is not None: + _now = time.perf_counter() + _tlog( + f"{request_id[:8]} STREAM " + # submit -> first chunk delivered (full worker round-trip) + f"ttfc={(_t_first - _t_submit) * 1e3 if _t_first else -1:.2f} " + f"total={(_now - _t_submit) * 1e3:.2f}ms" # submit -> done + ) logger.info("Async stream results received finish for %s", request_id) # flush remaining remaining: list[ResultChunk] = [] @@ -460,6 +493,18 @@ def cleanup(self) -> None: allow_headers=["*"], ) + +@app.middleware("http") +async def _stamp_recv_time(request: Request, call_next): + # Stamp ASGI request arrival. 
The gap to the handler body (_t_in) covers + # routing + multipart form parsing (FastAPI reads the upload bodies while + # resolving the File()/Form() params, before the handler runs) — that's the + # HTTP-side overhead not visible in the [DW-TIMING]/STREAM brackets. + if _TIMING: + request.state._t_recv = time.perf_counter() + return await call_next(request) + + api_server: APIServer | None = None # Mount the OpenAI-compatible routes (/v1/*) alongside the native /generate. @@ -472,6 +517,7 @@ def cleanup(self) -> None: @app.post("/generate") async def generate( + request: Request, text: Optional[str] = Form(None), files: Optional[list[UploadFile]] = File(None), input_modalities: Optional[str] = Form(None), @@ -500,6 +546,12 @@ async def generate( if api_server is None: raise HTTPException(status_code=503, detail="Server not ready") + _t_in = time.perf_counter() + if _TIMING: + _recv = getattr(request.state, "_t_recv", None) + if _recv is not None: + # ASGI receive -> handler body = routing + multipart parse (HTTP-side) + _tlog(f"PREHDLR parse={(_t_in - _recv) * 1e3:.2f}ms") out_mods = [m.strip() for m in output_modalities.split(",") if m.strip()] # --- save uploaded files, grouped by modality ---------------- @@ -528,6 +580,7 @@ async def generate( in_mods.append("text") parsed_kwargs = json.loads(model_kwargs) if model_kwargs else None + _t_files = time.perf_counter() # multipart read + disk save done try: request_id = api_server.submit_request( @@ -539,6 +592,12 @@ async def generate( streaming=streaming, request_id=request_id, ) + if _TIMING: + _tlog( + f"{request_id[:8]} HANDLER " + f"files={(_t_files - _t_in) * 1e3:.2f} " # multipart read + disk write + f"submit={(time.perf_counter() - _t_files) * 1e3:.2f}ms" # submit_request + ) if streaming: return StreamingResponse( @@ -550,12 +609,20 @@ async def generate( chunks = await run_in_threadpool( api_server.collect_results, request_id ) + _t_results = time.perf_counter() outputs: dict[str, list[dict]] = {} for 
chunk in chunks: outputs.setdefault(chunk.modality, []).append({ "data": base64.b64encode(chunk.data).decode("ascii"), "metadata": chunk.metadata, }) + if _TIMING: + _tlog( + f"{request_id[:8]} BLOCKING " + f"wait={(_t_results - _t_files) * 1e3:.2f} " # submit -> all results in + f"serialize={(time.perf_counter() - _t_results) * 1e3:.2f} " # b64 + json + f"total={(time.perf_counter() - _t_in) * 1e3:.2f}ms" + ) return JSONResponse({ "request_id": request_id, "outputs": outputs, diff --git a/mstar/model/base.py b/mstar/model/base.py index 54a7e90d..1ad6d2ab 100644 --- a/mstar/model/base.py +++ b/mstar/model/base.py @@ -5,6 +5,7 @@ import torch import yaml +import numpy as np from mstar.communication.tensors import NameToTensorList from mstar.conductor.request_info import ( @@ -377,9 +378,22 @@ def process_prompt( pass def load_image(self, filepath: str, device: str) -> TensorAndMetadata: + import io + import torchvision - img = torchvision.io.decode_image(filepath).to(device) # uint8 CxHxW + # Read the file once, then dispatch on content: a raw uint8 CxHxW array + # uploaded as .npy (np.save magic = b"\x93NUMPY") skips PNG/JPEG decode + # entirely (np.load is ~a memcpy); anything else goes through torchvision. + # Sniffing the magic (not the extension) keeps the upload filename free. 
+ with open(filepath, "rb") as f: + raw = f.read() + if raw[:6] == b"\x93NUMPY": + img = torch.from_numpy(np.load(io.BytesIO(raw))).to(device) # uint8 CxHxW + else: + img = torchvision.io.decode_image( + torch.frombuffer(bytearray(raw), dtype=torch.uint8) + ).to(device) # uint8 CxHxW img = img.float() / 255.0 return TensorAndMetadata(img) From 03da3ad46a767d4491c091a3934bd6976b6e283c Mon Sep 17 00:00:00 2001 From: NSagan271 Date: Wed, 10 Jun 2026 04:54:20 +0000 Subject: [PATCH 08/17] cleanup --- benchmark/dataset.py | 86 +++++++++++++++++++++---------- benchmark/request.py | 48 +++++++---------- mstar/api_server/data_worker.py | 14 +++++ mstar/api_server/entrypoint.py | 13 ++++- mstar/api_server/request_types.py | 5 ++ mstar/model/base.py | 12 ++--- mstar/model/pi05/pi05_model.py | 9 ++++ 7 files changed, 121 insertions(+), 66 deletions(-) diff --git a/benchmark/dataset.py b/benchmark/dataset.py index 46d1da73..039bfdd2 100644 --- a/benchmark/dataset.py +++ b/benchmark/dataset.py @@ -607,6 +607,39 @@ def _decode_frames_to_png_and_video( VideoEncoder(frames=tensors, frame_rate=fps).to_file(mp4_path) +def _resize_with_pad(chw, size: int): + """Aspect-preserving letterbox of a (C, H, W) uint8 tensor to size x size. + + Scales the longer side to ``size`` and pads the shorter with black (0). + Mirrors the server's ``Pi05ViTEncoderSubmodule._prepare_one`` geometry, so + sending the pre-resized frame produces the same model input as decoding at + native resolution and resizing on the worker. 
+ """ + import torch + import torch.nn.functional as F + + _, h, w = chw.shape + if (h, w) == (size, size): + return chw + ratio = max(w / size, h / size) + rh, rw = int(h / ratio), int(w / ratio) + x = F.interpolate(chw[None].float(), size=(rh, rw), mode="bilinear", align_corners=False) + ph0, remh = divmod(size - rh, 2) + pw0, remw = divmod(size - rw, 2) + x = F.pad(x, (pw0, pw0 + remw, ph0, ph0 + remh), value=0.0) + return x[0].round().clamp(0, 255).to(torch.uint8) + + +def _decode_frame_to_npy(video_path: str, frame_index: int, npy_path: str, size: int) -> None: + """Decode one frame, letterbox-resize to ``size`` x ``size``, save as a + (C, H, W) uint8 ``.npy`` (the "numpy" upload the server np.loads in memory).""" + import numpy as np + from torchcodec.decoders import VideoDecoder + + frame = VideoDecoder(video_path).get_frames_at(indices=[frame_index]).data[0] # (C,H,W) uint8 + np.save(npy_path, _resize_with_pad(frame, size).numpy()) + + class DROIDDataset(BaseDataset): """DROID robotics dataset for evaluating pi0.5 and V-JEPA 2-AC. @@ -631,6 +664,9 @@ class DROIDDataset(BaseDataset): """ HF_REPO = "lerobot/droid_100" + # pi05 camera frames are letterboxed to this size client-side (matches the + # server's vit_image_size) so both mminf and openpi get identical input. + IMAGE_SIZE = 224 def __init__( self, @@ -792,22 +828,21 @@ def _make_pi05(self, idx, ep_id, frames, camera_keys, state_col, local_indices = self._local_frame_indices(frames) first_local = local_indices[0] - image_paths: list[str] = [] + # Decode + letterbox-resize each camera frame to 224x224 uint8 and save + # as a ".npy" (the "numpy" modality). Sending pre-resized arrays lets the + # server skip both image decode and the resize, and lets us hand mminf + # and openpi identical input. 
+ npy_paths: list[str] = [] for cam_key in camera_keys[:3]: chunk_video = download_fn(self._chunk_video_path(ep_id, cam_key, chunks_size)) - png_path = os.path.join(self.local_file_dir, f"ep{ep_id}_cam{len(image_paths)}.png") - _decode_frames_to_png_and_video( - video_path=chunk_video, - frame_indices=[first_local], - png_path=png_path, - mp4_path=None - ) - image_paths.append(png_path) + npy_path = os.path.join(self.local_file_dir, f"ep{ep_id}_cam{len(npy_paths)}.npy") + _decode_frame_to_npy(chunk_video, first_local, npy_path, self.IMAGE_SIZE) + npy_paths.append(npy_path) - if not image_paths: + if not npy_paths: raise ValueError("no camera videos found") - while len(image_paths) < 3: - image_paths.append(image_paths[0]) + while len(npy_paths) < 3: + npy_paths.append(npy_paths[0]) state = _to_float_list( frames[0].get(state_col) if state_col else None, self.action_dim @@ -815,10 +850,8 @@ def _make_pi05(self, idx, ep_id, frames, camera_keys, state_col, return RequestInput( req_type=RequestType.VLA, prompt=language or "manipulate the object", - image_path=image_paths[0], - # openpi droid policy only uses the first extra image path! So, to be consistent - # we emit the remainder entirely from bechmarking - extra_image_paths=image_paths[1:2], + # openpi droid policy only uses the first extra image, so send 2 cameras. + _numpy_paths=npy_paths[:2], model_kwargs={"robot_state": state}, ) @@ -877,13 +910,13 @@ def _manifest_path(self) -> str: """Manifest filename keyed by the params that change the built items.""" return os.path.join( self.local_file_dir, - f"manifest_pi05_nvf{self.num_video_frames}_ad{self.action_dim}.json", + f"manifest_pi05_npy{self.IMAGE_SIZE}_nvf{self.num_video_frames}_ad{self.action_dim}.json", ) def _load_manifest(self) -> list[RequestInput] | None: """Return cached pi05 RequestInputs, or None to force a rebuild. 
- Returns None if the manifest is absent, unreadable, or references a PNG + Returns None if the manifest is absent, unreadable, or references a .npy that no longer exists on disk. """ import json as _json @@ -896,18 +929,16 @@ def _load_manifest(self) -> list[RequestInput] | None: data = _json.load(f) items: list[RequestInput] = [] for entry in data["items"]: - img = os.path.join(self.local_file_dir, entry["image_path"]) - extra = [os.path.join(self.local_file_dir, p) - for p in entry.get("extra_image_paths", [])] - for p in (img, *extra): + npy_paths = [os.path.join(self.local_file_dir, p) + for p in entry.get("numpy_paths", [])] + for p in npy_paths: if not os.path.exists(p): print(f" [cache] missing {p}; rebuilding") return None items.append(RequestInput( req_type=RequestType.VLA, prompt=entry["prompt"], - image_path=img, - extra_image_paths=extra, + _numpy_paths=npy_paths, model_kwargs=entry.get("model_kwargs", {}), )) return items or None @@ -918,7 +949,7 @@ def _load_manifest(self) -> list[RequestInput] | None: def _save_manifest(self, items: list[RequestInput]) -> None: """Persist built pi05 items so the next run can skip parquet + decode. - PNG paths are stored as basenames (relative to local_file_dir) and the + .npy paths are stored as basenames (relative to local_file_dir) and the write is atomic (tmp + os.replace) so an interrupted run never leaves a half-written manifest that would later be reused. 
""" @@ -926,12 +957,11 @@ def _save_manifest(self, items: list[RequestInput]) -> None: entries = [{ "prompt": it.prompt, - "image_path": os.path.basename(it.image_path), - "extra_image_paths": [os.path.basename(p) for p in it.extra_image_paths], + "numpy_paths": [os.path.basename(p) for p in it._numpy_paths], "model_kwargs": it.model_kwargs, } for it in items] payload = { - "version": 1, + "version": 2, "task": "pi05", "num_video_frames": self.num_video_frames, "action_dim": self.action_dim, diff --git a/benchmark/request.py b/benchmark/request.py index ecb12911..5da3a6e0 100644 --- a/benchmark/request.py +++ b/benchmark/request.py @@ -17,31 +17,6 @@ from benchmark.base import Bagel, Model, Orpheus, RequestType, Status from benchmark.utils import _write_wav -# Optional: send images as a raw uint8 CxHxW .npy instead of PNG (MMINF_BENCH_RAW=1). -# PNG decode in the data worker is ~2ms/image (zlib inflate + unfiltering), -# independent of resolution; a raw array is np.load'd for ~0. Lossless — the -# bytes are the decoded pixels (bit-identical to the server's torchvision -# decode), self-described by the .npy header. Trades a larger upload (~173KB vs -# ~95KB PNG) for ~zero server-side decode. The server's load_image sniffs the -# .npy magic bytes, so only the client bytes change. -_BENCH_RAW = os.environ.get("MMINF_BENCH_RAW", "") not in ("", "0", "false") - - -def _maybe_raw(data: bytes) -> bytes: - """Re-encode PNG/JPEG bytes as a raw CxHxW uint8 .npy when MMINF_BENCH_RAW set.""" - if not _BENCH_RAW: - return data - import numpy as np - import torch - import torchvision - - chw = torchvision.io.decode_image( - torch.frombuffer(bytearray(data), dtype=torch.uint8) - ).numpy() # CxHxW uint8 — matches the server's decode exactly - buf = io.BytesIO() - np.save(buf, chw) - return buf.getvalue() - @dataclass class LatencyStats: @@ -629,6 +604,11 @@ class RequestInput: # All paths are uploaded as separate "files" form fields alongside image_path. 
extra_image_paths: list[str] = field(default_factory=list) + # Pre-decoded ".npy" uploads (the "numpy" modality): paths to raw uint8 + # arrays the server np.loads in memory (no disk, no decode). Used by pi0.5 + # (resized 224x224 camera frames); each path is one camera. + _numpy_paths: list[str] = field(default_factory=list) + # Per-request model_kwargs merged into the JSON payload at send time. # Use this for robotics-specific fields: robot_state, actions, states, # rollout_horizon, etc. @@ -645,10 +625,11 @@ class RequestInput: _audio_b64: Optional[str] = field(default=None, repr=False) _video_b64: Optional[str] = field(default=None, repr=False) _extra_image_bytes: list[bytes] = field(default_factory=list, repr=False) + _numpy_bytes: list[bytes] = field(default_factory=list, repr=False) def __post_init__(self): if self.image_path and self._image_bytes is None: - self._image_bytes = _maybe_raw(Path(self.image_path).read_bytes()) + self._image_bytes = Path(self.image_path).read_bytes() self._image_b64 = base64.b64encode(self._image_bytes).decode() if self.audio_path and self._audio_bytes is None: self._audio_bytes = Path(self.audio_path).read_bytes() @@ -657,7 +638,9 @@ def __post_init__(self): self._video_bytes = Path(self.video_path).read_bytes() self._video_b64 = base64.b64encode(self._video_bytes).decode() if self.extra_image_paths and not self._extra_image_bytes: - self._extra_image_bytes = [_maybe_raw(Path(p).read_bytes()) for p in self.extra_image_paths] + self._extra_image_bytes = [Path(p).read_bytes() for p in self.extra_image_paths] + if self._numpy_paths and not self._numpy_bytes: + self._numpy_bytes = [Path(p).read_bytes() for p in self._numpy_paths] def get_all_filepaths(self) -> dict[str, str]: res = {} @@ -764,7 +747,16 @@ async def send_request( "files", content, filename=os.path.basename(path), - content_type="image/png", + content_type="application/octet-stream", + ) + # Pre-decoded ".npy" uploads (numpy modality): the server keeps these + # in 
memory and np.loads them — no disk, no decode (pi0.5 cameras). + for path, content in zip(req_input._numpy_paths, req_input._numpy_bytes): + form.add_field( + "files", + content, + filename=os.path.basename(path), + content_type="application/octet-stream", ) metrics.start_time = time.monotonic() diff --git a/mstar/api_server/data_worker.py b/mstar/api_server/data_worker.py index 8225cbc6..c6b6d65d 100644 --- a/mstar/api_server/data_worker.py +++ b/mstar/api_server/data_worker.py @@ -223,6 +223,20 @@ def _process_input( tensors[key].append(out.data) input_metadata[key].append(out.metadata) + # ".npy" uploads (modality "numpy") are kept in memory and np.load'd + # here as "raw_inputs"; the model maps them in process_prompt. + if input.numpy_bytes: + import io as _io + + import numpy as np + + tensors["raw_inputs"] = [] + input_metadata["raw_inputs"] = [] + for blob in input.numpy_bytes: + tensors["raw_inputs"].append( + torch.from_numpy(np.load(_io.BytesIO(blob))).to(self.device) + ) + input_metadata["raw_inputs"].append({}) _t_load = time.perf_counter() # media decode (load_image/audio/video) done diff --git a/mstar/api_server/entrypoint.py b/mstar/api_server/entrypoint.py index 40f20410..72794aab 100644 --- a/mstar/api_server/entrypoint.py +++ b/mstar/api_server/entrypoint.py @@ -46,6 +46,7 @@ def _tlog(msg: str) -> None: "image": (".png", ".jpg", ".jpeg", ".webp", ".bmp", ".tiff", ".gif"), "audio": (".wav", ".mp3", ".flac", ".ogg", ".m4a", ".aac"), "video": (".mp4", ".avi", ".mov", ".mkv", ".webm"), + "numpy": (".npy",) }.items(): for _ext in _exts: _EXT_TO_MODALITY[_ext] = _mod @@ -234,6 +235,7 @@ def submit_request( *, text: str | None = None, file_paths: dict[str, list[str]] | None = None, + numpy_bytes: list[bytes] | None = None, input_modalities: list[str], output_modalities: list[str], model_kwargs: dict | None = None, @@ -271,6 +273,7 @@ def submit_request( request_id=request_id, text=text, file_paths=file_paths, + numpy_bytes=numpy_bytes, 
input_modalities=input_modalities, output_modalities=output_modalities, model_kwargs=model_kwargs @@ -555,7 +558,11 @@ async def generate( out_mods = [m.strip() for m in output_modalities.split(",") if m.strip()] # --- save uploaded files, grouped by modality ---------------- + # The "numpy" modality (.npy) is kept in memory and np.load'd by the data + # worker; image/audio/video are written to disk so their decoders work from + # a file (PNG/mp4 decode prefers a path). file_paths: dict[str, list[str]] = {} + numpy_bytes: list[bytes] = [] if files: for f in files: modality = _detect_modality(f.filename or "") @@ -564,9 +571,12 @@ async def generate( status_code=400, detail=f"Cannot determine modality for file: {f.filename}", ) + content = await f.read() + if modality == "numpy": + numpy_bytes.append(content) + continue save_name = f"{uuid.uuid4()}_{f.filename}" save_path = api_server.upload_dir / save_name - content = await f.read() await run_in_threadpool(save_path.write_bytes, content) file_paths.setdefault(modality, []).append(str(save_path)) @@ -586,6 +596,7 @@ async def generate( request_id = api_server.submit_request( text=text, file_paths=file_paths or None, + numpy_bytes=numpy_bytes or None, input_modalities=in_mods, output_modalities=out_mods, model_kwargs=parsed_kwargs, diff --git a/mstar/api_server/request_types.py b/mstar/api_server/request_types.py index 23c5c899..e7bfa415 100644 --- a/mstar/api_server/request_types.py +++ b/mstar/api_server/request_types.py @@ -49,3 +49,8 @@ class PreprocessInput: input_modalities: list[str] output_modalities: list[str] model_kwargs: dict + + # In-memory uploads for the "numpy" modality (.npy): the bytes are NOT + # written to disk (unlike images/audio/video), so the data worker np.loads + # them directly. Each entry is one .npy blob (e.g. one camera frame). 
+ numpy_bytes: list[bytes] | None = None diff --git a/mstar/model/base.py b/mstar/model/base.py index 1ad6d2ab..731ce6d2 100644 --- a/mstar/model/base.py +++ b/mstar/model/base.py @@ -5,7 +5,6 @@ import torch import yaml -import numpy as np from mstar.communication.tensors import NameToTensorList from mstar.conductor.request_info import ( @@ -378,8 +377,6 @@ def process_prompt( pass def load_image(self, filepath: str, device: str) -> TensorAndMetadata: - import io - import torchvision # Read the file once, then dispatch on content: a raw uint8 CxHxW array @@ -388,12 +385,9 @@ def load_image(self, filepath: str, device: str) -> TensorAndMetadata: # Sniffing the magic (not the extension) keeps the upload filename free. with open(filepath, "rb") as f: raw = f.read() - if raw[:6] == b"\x93NUMPY": - img = torch.from_numpy(np.load(io.BytesIO(raw))).to(device) # uint8 CxHxW - else: - img = torchvision.io.decode_image( - torch.frombuffer(bytearray(raw), dtype=torch.uint8) - ).to(device) # uint8 CxHxW + img = torchvision.io.decode_image( + torch.frombuffer(bytearray(raw), dtype=torch.uint8) + ).to(device) # uint8 CxHxW img = img.float() / 255.0 return TensorAndMetadata(img) diff --git a/mstar/model/pi05/pi05_model.py b/mstar/model/pi05/pi05_model.py index 5475d012..15625aa4 100644 --- a/mstar/model/pi05/pi05_model.py +++ b/mstar/model/pi05/pi05_model.py @@ -457,6 +457,15 @@ def process_prompt( here so the resulting ``text_inputs`` stream matches the production format. """ + # A "numpy" upload arrives as "raw_inputs"; Pi0.5 treats it as the image. 
+ tensors = kwargs.get("tensors") + if tensors is not None and "raw_inputs" in tensors: + assert "image_inputs" not in tensors, "got both raw_inputs and image_inputs" + tensors["image_inputs"] = tensors.pop("raw_inputs") + input_metadata = kwargs.get("input_metadata") + if input_metadata is not None and "raw_inputs" in input_metadata: + input_metadata["image_inputs"] = input_metadata.pop("raw_inputs") + if self.tokenizer is None: # Tokenizer-less fallback used by structural unit tests. if prompt is not None: From c36c00df211c4a095c565a657b2d8aefed442e31 Mon Sep 17 00:00:00 2001 From: NSagan271 Date: Wed, 10 Jun 2026 05:04:45 +0000 Subject: [PATCH 09/17] port over openpi benchmarking --- benchmark/download_pi05_ckpt.py | 38 +++++++++ benchmark/openpi_instructions.md | 17 ++++ benchmark/request.py | 140 ++++++++++++++++++++++++++++++- benchmark/runner.py | 6 ++ 4 files changed, 199 insertions(+), 2 deletions(-) create mode 100644 benchmark/download_pi05_ckpt.py create mode 100644 benchmark/openpi_instructions.md diff --git a/benchmark/download_pi05_ckpt.py b/benchmark/download_pi05_ckpt.py new file mode 100644 index 00000000..91cbe7c2 --- /dev/null +++ b/benchmark/download_pi05_ckpt.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +"""Download the pi0.5 checkpoint and print the local path. 
+ conda activate openpi + python benchmark/download_pi05_ckpt.py +""" + +from __future__ import annotations + +import argparse +import sys + +DEFAULT_CONFIG = "pi05_droid" +DEFAULT_CHECKPOINT = "gs://openpi-assets/checkpoints/pi05_droid" + + +def main(): + p = argparse.ArgumentParser(description="Download pi0.5 checkpoint") + p.add_argument("--config", default=DEFAULT_CONFIG) + p.add_argument("--checkpoint", default=DEFAULT_CHECKPOINT) + args = p.parse_args() + + try: + from openpi.shared import download + from openpi.training import config as _config + except ImportError as e: + sys.exit( + f"\n[ERROR] openpi is not importable ({e}).\n" + "Run inside the openpi conda env:\n" + " conda activate openpi\n" + ) + + _config.get_config(args.config) + ckpt_dir = download.maybe_download(args.checkpoint) + print(ckpt_dir) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/benchmark/openpi_instructions.md b/benchmark/openpi_instructions.md new file mode 100644 index 00000000..0a08e692 --- /dev/null +++ b/benchmark/openpi_instructions.md @@ -0,0 +1,17 @@ +1. Make a python3.12 environment +2. Clone the `openpi` reop +3. Run the following in your ennvironment: +``` +git submodule update --init --recursive +GIT_LFS_SKIP_SMUDGE=1 uv sync +GIT_LFS_SKIP_SMUDGE=1 uv pip install -e . +``` +4. From the mminf repo, run, e.g. for coriander,: +``` +python benchmark/download_pi05_ckpt.py[3:21 PM]mkdir /m-coriander/coriander/naomi/openpi-cache +mv /home/$USER/.cache/openpi/* /m-coriander/coriander/$USER/openpi-cache/ +``` +5. 
Start the server with: +``` +CUDA_VISIBLE_DEVICES=4 uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi05_droid --policy.dir=/m-coriander/coriander/$USER/openpi-cache/openpi-assets/checkpoints/pi05_droid +``` \ No newline at end of file diff --git a/benchmark/request.py b/benchmark/request.py index 5da3a6e0..e4412aba 100644 --- a/benchmark/request.py +++ b/benchmark/request.py @@ -14,7 +14,7 @@ import aiohttp import numpy as np -from benchmark.base import Bagel, Model, Orpheus, RequestType, Status +from benchmark.base import Bagel, Model, Orpheus, Pi05, RequestType, Status from benchmark.utils import _write_wav @@ -729,7 +729,9 @@ async def send_request( input_mod = req_type.get_input_modalities() if "," in input_mod or input_mod not in ("text",): # TODO: if a request does not have text as an input modality, this must be revisited - form.add_field("input_modalities", ",".join([input_mod, "text"])) + if "text" not in input_mod: + input_mod = ",".join([input_mod, "text"]) + form.add_field("input_modalities", input_mod) for modality in req_input.get_all_filepaths(): file_content = req_input.get_bytes(modality) @@ -1669,3 +1671,137 @@ async def _send_request_audio_speech( else: metrics.record_completion() return metrics + + +# --------------------------------------------------------------------------- +# openpi: call their own api server +# --------------------------------------------------------------------------- +def _build_obs(req_input: RequestInput) -> dict: + """Map our DROIDDataset RequestInput → openpi DroidInputs dict. + + DROIDDataset emits the camera frames as ``_numpy_bytes`` — already-decoded, + letterboxed 224x224 uint8 arrays (CxHxW), the same bytes mminf receives, so + both systems get identical input. openpi wants (H,W,3) uint8, so we just + transpose; no decode/resize needed here. cam0 → exterior_image_1_left, + cam1 → wrist_image_left. 
The 32-dim DROID state holds joint positions in + [:7] and gripper in [7]; the rest is padding we ignore. + + openpi DroidInputs (droid_policy.py:make_droid_example) expects: + observation/exterior_image_1_left : (H,W,3) uint8 + observation/wrist_image_left : (H,W,3) uint8 + observation/joint_position : (7,) float + observation/gripper_position : (1,) float + prompt : str + """ + import io + + import numpy as np + + state = np.asarray(req_input.model_kwargs.get("robot_state", []), dtype=np.float32) + if state.size < 8: + state = np.pad(state, (0, 8 - state.size)) + + imgs = [np.load(io.BytesIO(b)).transpose(1, 2, 0) # CxHxW uint8 -> HxWxC + for b in req_input._numpy_bytes] + base_img = imgs[0] + wrist_img = imgs[1] if len(imgs) > 1 else imgs[0] + + return { + "observation/exterior_image_1_left": base_img, + "observation/wrist_image_left": wrist_img, + "observation/joint_position": state[:7], + "observation/gripper_position": state[7:8], + "prompt": req_input.prompt or "manipulate the object", + } + +class OpenPi(InferenceSystem): + async def send_request( + self, + session: aiohttp.ClientSession, + req_input: RequestInput, + base_url: str, + request_id: int, + model: Model, + additional_model_kwargs: dict = {}, + ) -> RequestMetrics: + assert isinstance(model, Pi05), "openpi only supports Pi05 models" + assert req_input.req_type == RequestType.VLA, "openpi only supports VLA requests" + + import numpy as np + from openpi_client import msgpack_numpy + metrics = RequestMetrics( + request_id=request_id, + type=req_input.req_type, + expected_output_modalities=["action"], + ) + + # base_url is expected to be an http(s) URL for consistency with the rest + # of the harness; convert to ws(s) for the websocket handshake. + ws_url = base_url + if ws_url.startswith("http://"): + ws_url = "ws://" + ws_url[len("http://"):] + elif ws_url.startswith("https://"): + ws_url = "wss://" + ws_url[len("https://"):] + + # Build the observation. 
_build_obs gives us full-res images + a + # 32-dim DROID state; openpi expects 224x224 uint8 images and + # separate joint/gripper vectors, which _build_obs already provides. + obs = _build_obs(req_input) + packer = msgpack_numpy.Packer() + payload = packer.pack(obs) + + try: + metrics.start_time = time.monotonic() + async with session.ws_connect( + ws_url, + max_msg_size=0, # no limit; action chunks are small but obs is large + compress=0, # match the openpi client (compression=None) + timeout=aiohttp.ClientWSTimeout(ws_close=30), + ) as ws: + # Server sends metadata as the first message on connect. + # Drain it; we don't need it for benchmarking, but we MUST read + # it before sending or the server's send buffer can stall. + metadata_msg = await ws.receive() + if metadata_msg.type != aiohttp.WSMsgType.BINARY: + raise RuntimeError( + f"Expected binary metadata frame, got {metadata_msg.type}: " + f"{metadata_msg.data!r}" + ) + _ = msgpack_numpy.unpackb(metadata_msg.data) + + # Send observation, await action chunk. + await ws.send_bytes(payload) + response_msg = await ws.receive() + + if response_msg.type == aiohttp.WSMsgType.TEXT: + # The openpi server signals errors by sending a string. + raise RuntimeError(f"Error in inference server:\n{response_msg.data}") + if response_msg.type != aiohttp.WSMsgType.BINARY: + raise RuntimeError( + f"Unexpected ws frame type {response_msg.type}: " + f"{response_msg.data!r}" + ) + + arrival_time = time.monotonic() + response = msgpack_numpy.unpackb(response_msg.data) + action_chunk = response["actions"] # (action_horizon, action_dim) + + # One-shot output: the entire action chunk arrives at once, + # so TTFT == E2E. Encode the chunk as a single output unit. + # n_tokens = action_horizon so throughput numbers are + # in "actions/sec" if the metrics layer divides by n_tokens. 
+ action_bytes = np.asarray(action_chunk, dtype=np.float32).tobytes() + metrics.record_output_chunk( + modality="action", + data_b64=base64.b64encode(action_bytes), + arrival_time=arrival_time, + n_tokens=int(action_chunk.shape[0]), + ) + + except Exception as e: + metrics.record_error(str(e)) + else: + metrics.record_completion() + + return metrics + diff --git a/benchmark/runner.py b/benchmark/runner.py index c7544c49..fb4e9883 100644 --- a/benchmark/runner.py +++ b/benchmark/runner.py @@ -24,6 +24,7 @@ from benchmark.request import ( AggregateMetrics, InferenceSystem, + OpenPi, OursOpenAI, OurSystem, RequestInput, @@ -52,6 +53,7 @@ class InferenceSystemType(Enum): VLLM_OMNI = "vllm_omni" VOX_SERVE = "vox_serve" SGLANG_OMNI = "sglang_omni" + OPENPI = "openpi" def instantiate(self) -> InferenceSystem: if self == InferenceSystemType.OURS: @@ -64,6 +66,10 @@ def instantiate(self) -> InferenceSystem: return VoxServe() elif self == InferenceSystemType.SGLANG_OMNI: return SGLangOmni() + elif self == InferenceSystemType.OPENPI: + return OpenPi() + else: + raise NotImplementedError("Unknown inference system", self) class ProfilingType(Enum): From 8d97e94caba887ef3f066f2832be83a0cdf4cf5a Mon Sep 17 00:00:00 2001 From: NSagan271 Date: Thu, 11 Jun 2026 22:05:30 +0000 Subject: [PATCH 10/17] ruff check fix --- benchmark/download_pi05_ckpt.py | 2 +- benchmark/request.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmark/download_pi05_ckpt.py b/benchmark/download_pi05_ckpt.py index 91cbe7c2..088cceb8 100644 --- a/benchmark/download_pi05_ckpt.py +++ b/benchmark/download_pi05_ckpt.py @@ -35,4 +35,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/benchmark/request.py b/benchmark/request.py index e4412aba..27f98681 100644 --- a/benchmark/request.py +++ b/benchmark/request.py @@ -753,7 +753,7 @@ async def send_request( ) # Pre-decoded ".npy" uploads (numpy modality): the server keeps these # in memory 
and np.loads them — no disk, no decode (pi0.5 cameras). - for path, content in zip(req_input._numpy_paths, req_input._numpy_bytes): + for path, content in zip(req_input._numpy_paths, req_input._numpy_bytes, strict=True): form.add_field( "files", content, From 3880646410c27b552dc242489353de1f1dc3a278 Mon Sep 17 00:00:00 2001 From: NSagan271 Date: Thu, 11 Jun 2026 22:30:17 +0000 Subject: [PATCH 11/17] update openpi instructions --- benchmark/openpi_instructions.md | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/benchmark/openpi_instructions.md b/benchmark/openpi_instructions.md index 0a08e692..e2a31637 100644 --- a/benchmark/openpi_instructions.md +++ b/benchmark/openpi_instructions.md @@ -6,12 +6,14 @@ git submodule update --init --recursive GIT_LFS_SKIP_SMUDGE=1 uv sync GIT_LFS_SKIP_SMUDGE=1 uv pip install -e . ``` -4. From the mminf repo, run, e.g. for coriander,: +4. From the mstar repo, run: ``` -python benchmark/download_pi05_ckpt.py[3:21 PM]mkdir /m-coriander/coriander/naomi/openpi-cache -mv /home/$USER/.cache/openpi/* /m-coriander/coriander/$USER/openpi-cache/ +pip install gsutil +python benchmark/download_pi05_ckpt.py +mkdir +mv /home/$USER/.cache/openpi/* ``` 5. 
Start the server with: ``` -CUDA_VISIBLE_DEVICES=4 uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi05_droid --policy.dir=/m-coriander/coriander/$USER/openpi-cache/openpi-assets/checkpoints/pi05_droid +uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi05_droid --policy.dir=/openpi-assets/checkpoints/pi05_droid ``` \ No newline at end of file From d437bb999320765e5e25b6088508ffb65d7f6ce2 Mon Sep 17 00:00:00 2001 From: NSagan271 Date: Thu, 11 Jun 2026 22:39:56 +0000 Subject: [PATCH 12/17] refactor mminf -> mstar in a few places --- benchmark/dataset.py | 4 ++-- benchmark/request.py | 2 +- mstar/api_server/data_worker.py | 4 ++-- mstar/api_server/entrypoint.py | 4 ++-- mstar/model/pi05/components/siglip.py | 4 ++-- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/benchmark/dataset.py b/benchmark/dataset.py index 039bfdd2..28c7890e 100644 --- a/benchmark/dataset.py +++ b/benchmark/dataset.py @@ -665,7 +665,7 @@ class DROIDDataset(BaseDataset): HF_REPO = "lerobot/droid_100" # pi05 camera frames are letterboxed to this size client-side (matches the - # server's vit_image_size) so both mminf and openpi get identical input. + # server's vit_image_size) so both mstar and openpi get identical input. IMAGE_SIZE = 224 def __init__( @@ -830,7 +830,7 @@ def _make_pi05(self, idx, ep_id, frames, camera_keys, state_col, # Decode + letterbox-resize each camera frame to 224x224 uint8 and save # as a ".npy" (the "numpy" modality). Sending pre-resized arrays lets the - # server skip both image decode and the resize, and lets us hand mminf + # server skip both image decode and the resize, and lets us hand mstar # and openpi identical input. 
npy_paths: list[str] = [] for cam_key in camera_keys[:3]: diff --git a/benchmark/request.py b/benchmark/request.py index 27f98681..90a97e96 100644 --- a/benchmark/request.py +++ b/benchmark/request.py @@ -1680,7 +1680,7 @@ def _build_obs(req_input: RequestInput) -> dict: """Map our DROIDDataset RequestInput → openpi DroidInputs dict. DROIDDataset emits the camera frames as ``_numpy_bytes`` — already-decoded, - letterboxed 224x224 uint8 arrays (CxHxW), the same bytes mminf receives, so + letterboxed 224x224 uint8 arrays (CxHxW), the same bytes mstar receives, so both systems get identical input. openpi wants (H,W,3) uint8, so we just transpose; no decode/resize needed here. cam0 → exterior_image_1_left, cam1 → wrist_image_left. The 32-dim DROID state holds joint positions in diff --git a/mstar/api_server/data_worker.py b/mstar/api_server/data_worker.py index c6b6d65d..2715c3b1 100644 --- a/mstar/api_server/data_worker.py +++ b/mstar/api_server/data_worker.py @@ -31,11 +31,11 @@ logger = logging.getLogger(__name__) -# Lightweight, env-gated timing prints (MMINF_TIMING=1). perf_counter is +# Lightweight, env-gated timing prints (MSTAR_TIMING=1). perf_counter is # process-wide monotonic, so timestamps stamped in the API-server handler # thread and read in this data-worker thread are directly comparable — that's # how queue-wait (polling) latency is separated from actual work below. 
-_TIMING = os.environ.get("MMINF_TIMING", "") not in ("", "0", "false") +_TIMING = os.environ.get("MSTAR_TIMING", "") not in ("", "0", "false") def _tlog(msg: str) -> None: diff --git a/mstar/api_server/entrypoint.py b/mstar/api_server/entrypoint.py index 72794aab..01c38906 100644 --- a/mstar/api_server/entrypoint.py +++ b/mstar/api_server/entrypoint.py @@ -28,9 +28,9 @@ logger = logging.getLogger(__name__) -# Env-gated timing prints (MMINF_TIMING=1); pairs with the [DW-TIMING] prints +# Env-gated timing prints (MSTAR_TIMING=1); pairs with the [DW-TIMING] prints # in data_worker.py to split HTTP/handler overhead from data-worker work. -_TIMING = os.environ.get("MMINF_TIMING", "") not in ("", "0", "false") +_TIMING = os.environ.get("MSTAR_TIMING", "") not in ("", "0", "false") def _tlog(msg: str) -> None: diff --git a/mstar/model/pi05/components/siglip.py b/mstar/model/pi05/components/siglip.py index dbcdb21d..8c512441 100644 --- a/mstar/model/pi05/components/siglip.py +++ b/mstar/model/pi05/components/siglip.py @@ -1,7 +1,7 @@ -"""SigLIP vision encoder for Pi0.5 (native mminf port). +"""SigLIP vision encoder for Pi0.5 (native mstar port). Ports the inference path of HuggingFace's ``SiglipVisionModel`` (So400m/14) -into mminf so we own the code and can fuse projections. Differences from the +into mstar so we own the code and can fuse projections. 
Differences from the transformers implementation: * **Fused QKV** — the three ``q/k/v_proj`` GEMMs are merged into one From 491ff5b5bcee2565660f03eacc436af970a9dc97 Mon Sep 17 00:00:00 2001 From: NSagan271 Date: Thu, 11 Jun 2026 22:51:28 +0000 Subject: [PATCH 13/17] some cleanup --- mstar/engine/kv_cache_engine.py | 1 - mstar/model/base.py | 4 ---- mstar/model/pi05/submodules.py | 4 ---- 3 files changed, 9 deletions(-) diff --git a/mstar/engine/kv_cache_engine.py b/mstar/engine/kv_cache_engine.py index 7e3c0503..983e85e3 100644 --- a/mstar/engine/kv_cache_engine.py +++ b/mstar/engine/kv_cache_engine.py @@ -904,7 +904,6 @@ def check_ready( ).items(): if needed_labels is not None and label not in needed_labels: continue - print("ASYNC RETRIEVE HAPPENING") cache_mgmt.alloc_manager.start_async_retrieve( request_id, label, seq_info ) diff --git a/mstar/model/base.py b/mstar/model/base.py index 731ce6d2..892b22e1 100644 --- a/mstar/model/base.py +++ b/mstar/model/base.py @@ -379,10 +379,6 @@ def process_prompt( def load_image(self, filepath: str, device: str) -> TensorAndMetadata: import torchvision - # Read the file once, then dispatch on content: a raw uint8 CxHxW array - # uploaded as .npy (np.save magic = b"\x93NUMPY") skips PNG/JPEG decode - # entirely (np.load is ~a memcpy); anything else goes through torchvision. - # Sniffing the magic (not the extension) keeps the upload filename free. 
with open(filepath, "rb") as f: raw = f.read() img = torchvision.io.decode_image( diff --git a/mstar/model/pi05/submodules.py b/mstar/model/pi05/submodules.py index 12e16a84..472105f9 100644 --- a/mstar/model/pi05/submodules.py +++ b/mstar/model/pi05/submodules.py @@ -571,10 +571,6 @@ def _get_time_emb_buffer(self, bs: int) -> torch.Tensor: ) return self._time_emb_buffer[:bs] - def _embed_tokens_scaled(self, ids: torch.Tensor) -> torch.Tensor: - emb = self.embed_tokens(ids) - return emb * self._text_embed_scale - def get_cuda_graph_configs( self, device: torch.device, tp_world_size: int = 1, ) -> list[BasicBatchedCudaGraphConfig | FlashInferPackedCudaGraphConfig]: From a471ae8fc6c1f2b88e190272d4213fdccacda967 Mon Sep 17 00:00:00 2001 From: NSagan271 Date: Thu, 11 Jun 2026 23:05:28 +0000 Subject: [PATCH 14/17] Cleanup stale comments + abstract away timing prints --- benchmark/openpi_instructions.md | 8 ++++---- mstar/api_server/_timing.py | 20 ++++++++++++++++++++ mstar/api_server/data_worker.py | 24 ++++++++---------------- mstar/api_server/entrypoint.py | 22 +++++++++++++--------- mstar/model/pi05/submodules.py | 20 ++++++++++---------- 5 files changed, 55 insertions(+), 39 deletions(-) create mode 100644 mstar/api_server/_timing.py diff --git a/benchmark/openpi_instructions.md b/benchmark/openpi_instructions.md index e2a31637..4e1502fa 100644 --- a/benchmark/openpi_instructions.md +++ b/benchmark/openpi_instructions.md @@ -1,6 +1,6 @@ -1. Make a python3.12 environment -2. Clone the `openpi` reop -3. Run the following in your ennvironment: +1. Make a python3.12 environment +2. Clone the `openpi` repo +3. Run the following in your environment: ``` git submodule update --init --recursive GIT_LFS_SKIP_SMUDGE=1 uv sync @@ -16,4 +16,4 @@ mv /home/$USER/.cache/openpi/* 5. 
Start the server with: ``` uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi05_droid --policy.dir=/openpi-assets/checkpoints/pi05_droid -``` \ No newline at end of file +``` diff --git a/mstar/api_server/_timing.py b/mstar/api_server/_timing.py new file mode 100644 index 00000000..811ee0ba --- /dev/null +++ b/mstar/api_server/_timing.py @@ -0,0 +1,20 @@ +"""Lightweight, env-gated timing prints shared by the API server and data worker. + +Enabled with ``MSTAR_TIMING=1`` (anything other than unset/``0``/``false``). +``perf_counter`` is process-wide monotonic, so timestamps stamped in the +API-server handler thread and read in the data-worker thread are directly +comparable — that's how queue-wait (polling) latency is separated from actual +work in the [API-TIMING]/[DW-TIMING] brackets. +""" +import os + +TIMING_ENABLED = os.environ.get("MSTAR_TIMING", "") not in ("", "0", "false") + + +def make_tlog(prefix: str): + """Return a ``tlog(msg)`` that prints ``[{prefix}] {msg}`` when enabled.""" + def _tlog(msg: str) -> None: + if TIMING_ENABLED: + print(f"[{prefix}] {msg}", flush=True) + + return _tlog diff --git a/mstar/api_server/data_worker.py b/mstar/api_server/data_worker.py index 2715c3b1..67731b78 100644 --- a/mstar/api_server/data_worker.py +++ b/mstar/api_server/data_worker.py @@ -1,11 +1,12 @@ +import io import logging -import os import queue import threading import time +import numpy as np import torch from mstar.graph.loop_indices import NestedLoopIndices @@ -16,6 +17,8 @@ except (ImportError, RuntimeError, OSError): VideoDecoder = None +from mstar.api_server._timing import TIMING_ENABLED as _TIMING +from mstar.api_server._timing import make_tlog from mstar.api_server.request_types import PreprocessInput, ResultChunk, ResultTensors from mstar.communication.communicator import CommProtocol, ZMQCommunicator from mstar.communication.tensors import NameToTensorList, create_tensor_communication_manager @@ -31,16 +34,9 @@ logger = logging.getLogger(__name__) -# 
Lightweight, env-gated timing prints (MSTAR_TIMING=1). perf_counter is -# process-wide monotonic, so timestamps stamped in the API-server handler -# thread and read in this data-worker thread are directly comparable — that's -# how queue-wait (polling) latency is separated from actual work below. -_TIMING = os.environ.get("MSTAR_TIMING", "") not in ("", "0", "false") - - -def _tlog(msg: str) -> None: - if _TIMING: - print(f"[DW-TIMING] {msg}", flush=True) +# See mstar.api_server._timing; env-gated [DW-TIMING] prints (MSTAR_TIMING=1) +# that pair with the [API-TIMING] prints to split queue-wait from actual work. +_tlog = make_tlog("DW-TIMING") def _preprocess_loop(**kwargs): @@ -226,15 +222,11 @@ def _process_input( # ".npy" uploads (modality "numpy") are kept in memory and np.load'd # here as "raw_inputs"; the model maps them in process_prompt. if input.numpy_bytes: - import io as _io - - import numpy as np - tensors["raw_inputs"] = [] input_metadata["raw_inputs"] = [] for blob in input.numpy_bytes: tensors["raw_inputs"].append( - torch.from_numpy(np.load(_io.BytesIO(blob))).to(self.device) + torch.from_numpy(np.load(io.BytesIO(blob))).to(self.device) ) input_metadata["raw_inputs"].append({}) diff --git a/mstar/api_server/entrypoint.py b/mstar/api_server/entrypoint.py index 01c38906..4ee164d0 100644 --- a/mstar/api_server/entrypoint.py +++ b/mstar/api_server/entrypoint.py @@ -20,6 +20,8 @@ from fastapi.responses import JSONResponse, StreamingResponse from starlette.concurrency import run_in_threadpool +from mstar.api_server._timing import TIMING_ENABLED as _TIMING +from mstar.api_server._timing import make_tlog from mstar.api_server.data_worker import PreprocessWorker from mstar.api_server.request_types import APIServerMessage, PreprocessInput, ResultChunk from mstar.communication.communicator import CommProtocol, ZMQCommunicator @@ -28,17 +30,15 @@ logger = logging.getLogger(__name__) -# Env-gated timing prints (MSTAR_TIMING=1); pairs with the [DW-TIMING] prints 
-# in data_worker.py to split HTTP/handler overhead from data-worker work. -_TIMING = os.environ.get("MSTAR_TIMING", "") not in ("", "0", "false") +# See mstar.api_server._timing; env-gated [API-TIMING] prints (MSTAR_TIMING=1) +# that pair with the [DW-TIMING] prints to split HTTP/handler overhead from +# data-worker work. +_tlog = make_tlog("API-TIMING") -def _tlog(msg: str) -> None: - if _TIMING: - print(f"[API-TIMING] {msg}", flush=True) - - -SUPPORTED_MODALITIES = frozenset({"text", "image", "audio", "video", "action", "scalar", "tensor"}) +SUPPORTED_MODALITIES = frozenset( + {"text", "image", "audio", "video", "action", "scalar", "tensor", "numpy"} +) # Extension-based modality detection for uploaded files. _EXT_TO_MODALITY: dict[str, str] = {} @@ -586,6 +586,10 @@ async def generate( else: in_mods: list[str] = [] in_mods.extend(file_paths.keys()) + # ".npy" uploads bypass file_paths (kept in memory as numpy_bytes), so + # add their "numpy" modality explicitly or auto-detect would drop it. + if numpy_bytes: + in_mods.append("numpy") if text: in_mods.append("text") diff --git a/mstar/model/pi05/submodules.py b/mstar/model/pi05/submodules.py index 472105f9..120d89a5 100644 --- a/mstar/model/pi05/submodules.py +++ b/mstar/model/pi05/submodules.py @@ -1,12 +1,11 @@ """NodeSubmodule wrappers for the Pi0.5 model nodes. -Two submodules: - Pi05ViTEncoderSubmodule -- SigLIP vision encoder for camera images. - Pi05LLMSubmodule -- combined PaliGemma + action expert. Dispatches by - graph_walk between prefill (PaliGemma writes the - prefix KV cache) and action_gen (action expert - reads the frozen prefix KV cache and runs one - Euler flow-matching step). +Three submodules: + Pi05ViTEncoderSubmodule -- SigLIP vision encoder for camera images. + Pi05PaligemmaSubmodule -- PaliGemma prefix expert; prefills and writes the + prefix KV cache. + Pi05ActionExpertSubmodule -- action expert; reads the frozen prefix KV cache + and runs the Euler flow-matching denoising loop. 
""" import logging @@ -580,7 +579,7 @@ def get_cuda_graph_configs( # are read directly from self.config — same source as the nn.Linear # weight shapes — so they're guaranteed consistent. logger.info( - "Pi05LLMSubmodule.get_cuda_graph_configs: capturing 'action_gen' " + "Pi05ActionExpertSubmodule.get_cuda_graph_configs: capturing 'action_gen' " "graph with input_seq_len=%d, noisy_actions=(%d, %d), batch_sizes=[1,2,4] " "(num_flow_steps=%d denoising iters runs INSIDE this captured graph; " "denoising count is independent of horizon)", @@ -762,8 +761,9 @@ def _forward_action_gen( ``noisy_actions`` and ``timestep_index`` arrive as single-element lists from preprocess (to keep the data structure uniform with the - batched path). We unpack the first element, run one Euler step, and - return the loop-back edges. + batched path). We unpack the first element, run the full + ``num_flow_steps`` Euler denoising loop, and return the denoised action + tensor. """ # Unpack from list form (preprocess always returns lists now). 
if isinstance(noisy_actions, list): From 8954bf05ec6c59257ace19984d0466906e062799 Mon Sep 17 00:00:00 2001 From: NSagan271 Date: Thu, 11 Jun 2026 23:11:33 +0000 Subject: [PATCH 15/17] Remove stale Pi05LLMSubmodule references --- test/integration/test_pi05_real_weights.py | 5 +++-- test/pi05/compare_with_lerobot.py | 2 +- test/pi05/probe_mstar_vs_lerobot.py | 4 ++-- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/test/integration/test_pi05_real_weights.py b/test/integration/test_pi05_real_weights.py index a698f265..39eb2f68 100644 --- a/test/integration/test_pi05_real_weights.py +++ b/test/integration/test_pi05_real_weights.py @@ -672,8 +672,9 @@ def test_pi05_model_loaded_via_remapper_matches_lerobot(): This is the strictest "real Pi05Model" check we can run without standing up a full mstar worker process: it exercises :class:`Pi05Model`'s lazy submodule construction, the lerobot→mstar - state-dict remap, and the actual ``Pi05ViTEncoderSubmodule`` and - ``Pi05LLMSubmodule`` forward methods. The only thing it bypasses is + state-dict remap, and the actual ``Pi05ViTEncoderSubmodule``, + ``Pi05PaligemmaSubmodule``, and ``Pi05ActionExpertSubmodule`` forward + methods. The only thing it bypasses is the FlashInfer/KVCacheEngine paged KV cache (replaced with the same ``MockCacheHandle`` used by the other integration tests, which has been validated against the real FlashInfer wrapper separately). diff --git a/test/pi05/compare_with_lerobot.py b/test/pi05/compare_with_lerobot.py index 077dc1f4..670428da 100644 --- a/test/pi05/compare_with_lerobot.py +++ b/test/pi05/compare_with_lerobot.py @@ -68,7 +68,7 @@ def server_seed_for(request_id: str) -> int: def reproduce_server_noise(request_id: str, device: torch.device) -> torch.Tensor: - """Reproduce the noise tensor that ``Pi05LLMSubmodule._preprocess_action_gen`` + """Reproduce the noise tensor that ``Pi05ActionExpertSubmodule._preprocess_action_gen`` will sample on iteration 0 for this request. 
Server code (mstar/model/pi05/submodules.py):: diff --git a/test/pi05/probe_mstar_vs_lerobot.py b/test/pi05/probe_mstar_vs_lerobot.py index 307cc757..1a11e366 100644 --- a/test/pi05/probe_mstar_vs_lerobot.py +++ b/test/pi05/probe_mstar_vs_lerobot.py @@ -9,7 +9,7 @@ Stage 1: Pi05ViTEncoderSubmodule output (per-camera image embeddings) vs lerobot ``paligemma_with_expert.embed_image(image)``. - Stage 2: Pi05LLMSubmodule._preprocess_prefill output (prefix_embs) + Stage 2: Pi05PaligemmaSubmodule._preprocess_prefill output (prefix_embs) vs lerobot ``embed_prefix(images, masks, tokens, masks)``. Stage 3: Action expert first-step velocity vs lerobot ``denoise_step`` first iteration. @@ -333,7 +333,7 @@ def main(): } ] # We don't actually need a real cache_manager for this stage — just to - # call the helper that builds prefix_embs. Pi05LLMSubmodule._preprocess_prefill + # call the helper that builds prefix_embs. Pi05PaligemmaSubmodule._preprocess_prefill # also calls plan_attention/plan_rope which need a real cache manager. # Build a dummy that just no-ops the plan_* calls: class _NoopCache: From 0c83059480f647fe4b259e12744be06908ccf9b6 Mon Sep 17 00:00:00 2001 From: NSagan271 Date: Sun, 14 Jun 2026 04:51:52 +0000 Subject: [PATCH 16/17] respond to PR comments --- benchmark/dataset.py | 11 +++++++ benchmark/request.py | 5 ++++ configs/pi05_droid.yaml | 7 +++-- mstar/model/pi05/pi05_model.py | 53 +++++----------------------------- mstar/model/pi05/submodules.py | 13 ++++++--- mstar/worker/worker.py | 3 ++ 6 files changed, 39 insertions(+), 53 deletions(-) diff --git a/benchmark/dataset.py b/benchmark/dataset.py index 28c7890e..27999a15 100644 --- a/benchmark/dataset.py +++ b/benchmark/dataset.py @@ -668,6 +668,11 @@ class DROIDDataset(BaseDataset): # server's vit_image_size) so both mstar and openpi get identical input. 
IMAGE_SIZE = 224 + PI05_KEYS = [ + "observation.images.exterior_image_1_left", + "observation.images.wrist_image_left" + ] + def __init__( self, local_file_dir: str, @@ -730,6 +735,12 @@ def _dl(filename): f"No video keys found in {self.HF_REPO}/meta/info.json. " f"Top-level keys: {list(info.keys())}" ) + + if task == "pi05": + assert all((key in camera_keys for key in self.PI05_KEYS)), \ + f"Expected camera keys {self.PI05_KEYS} not all found in {camera_keys}" + camera_keys = self.PI05_KEYS + chunks_size: int = info.get("chunks_size", info.get("chunk_size", 1000)) print(f" camera keys : {camera_keys}") print(f" chunks_size : {chunks_size}") diff --git a/benchmark/request.py b/benchmark/request.py index 90a97e96..91219f23 100644 --- a/benchmark/request.py +++ b/benchmark/request.py @@ -1686,6 +1686,11 @@ def _build_obs(req_input: RequestInput) -> dict: cam1 → wrist_image_left. The 32-dim DROID state holds joint positions in [:7] and gripper in [7]; the rest is padding we ignore. + NOTE: lerobot/droid_100 ships no gripper signal, so state[7:8] is always + padding 0.0 — actions are not semantically valid for either system. Fine + here: the benchmark measures latency (identical tensor shapes => identical + compute), not action quality. + openpi DroidInputs (droid_policy.py:make_droid_example) expects: observation/exterior_image_1_left : (H,W,3) uint8 observation/wrist_image_left : (H,W,3) uint8 diff --git a/configs/pi05_droid.yaml b/configs/pi05_droid.yaml index 4259823d..0d750a06 100644 --- a/configs/pi05_droid.yaml +++ b/configs/pi05_droid.yaml @@ -18,9 +18,10 @@ max_seq_len: 2048 # post-compute Python, so both systems do identical 32-dim work. # # CUDA-graph note: action_horizon and action_dim are both baked into the -# graph captures (see Pi05LLMSubmodule.get_cuda_graph_configs in -# mstar/model/pi05/submodules.py:325-329). They MUST be set at server-init -# time, never per-request — that's what this yaml override is for. 
+# graph captures (see Pi05ActionExpertSubmodule.get_cuda_graph_configs). +# They MUST be set at server-init time, never per-request — that's what +# this yaml override is for. + model_kwargs: action_horizon: 15 num_cameras: 2 diff --git a/mstar/model/pi05/pi05_model.py b/mstar/model/pi05/pi05_model.py index 15625aa4..56d9f248 100644 --- a/mstar/model/pi05/pi05_model.py +++ b/mstar/model/pi05/pi05_model.py @@ -90,36 +90,6 @@ def __init__( self.time_mlp = time_mlp -def _reset_non_persistent_buffers(module: nn.Module, device) -> None: - """Re-initialize non-persistent buffers like ``position_ids`` after a - ``meta + to_empty`` materialization. - - Modules constructed on the meta device skip ``post_init``, and - ``to_empty`` only allocates uninitialized storage for parameters and - buffers. Non-persistent buffers (registered with ``persistent=False``) - are not in the state_dict, so ``load_state_dict`` will not restore them - either — leaving them as garbage. The most common offender is HuggingFace - SigLIP's ``position_ids`` buffer (``register_buffer("position_ids", - arange(num_positions), persistent=False)``), which feeds the position - embedding lookup. If left as garbage int64 it produces wildly incorrect - image embeddings (off by the full norm of the position table). - - This walks the module tree and resets any sub-module that has a - ``position_ids`` buffer to the canonical ``arange(num_positions)``. - """ - with torch.no_grad(): - for sub in module.modules(): - pos = getattr(sub, "position_ids", None) - if isinstance(pos, torch.Tensor): - shape = pos.shape - num_positions = shape[-1] - pos.copy_( - torch.arange( - num_positions, device=pos.device, dtype=pos.dtype - ).expand(shape) - ) - - class Pi05Model(Model): """Pi0.5 vision-language-action model implementation.""" ACTION_GEN_WALK = "action_gen" @@ -457,14 +427,17 @@ def process_prompt( here so the resulting ``text_inputs`` stream matches the production format. 
""" - # A "numpy" upload arrives as "raw_inputs"; Pi0.5 treats it as the image. + # A "numpy" upload arrives as "raw_inputs"; Pi0.5 treats it as an image input + # We append the raw_inputs onto the image_inputs, so the user can pass in both + # images and numpy arrays tensors = kwargs.get("tensors") if tensors is not None and "raw_inputs" in tensors: - assert "image_inputs" not in tensors, "got both raw_inputs and image_inputs" - tensors["image_inputs"] = tensors.pop("raw_inputs") + tensors.setdefault("image_inputs", []).extend(tensors.pop("raw_inputs")) input_metadata = kwargs.get("input_metadata") if input_metadata is not None and "raw_inputs" in input_metadata: - input_metadata["image_inputs"] = input_metadata.pop("raw_inputs") + input_metadata.setdefault("image_inputs", []).extend( + input_metadata.pop("raw_inputs") + ) if self.tokenizer is None: # Tokenizer-less fallback used by structural unit tests. @@ -610,22 +583,10 @@ def _init_vit_components(self, device: str): self.siglip = Pi05SiglipEncoder(self.config) if self.skip_weight_loading: self.siglip = self.siglip.to_empty(device=device) - _reset_non_persistent_buffers(self.siglip, device) return flat = self._ensure_lerobot_flat() self.siglip.to_empty(device=device) - # CRITICAL: HF's SiglipVisionEmbeddings registers ``position_ids`` as - # a NON-persistent buffer (persistent=False), so it's not in any - # state_dict. ``to_empty`` materializes it as uninitialized GPU - # memory, ``_init_weights`` is never called (we never go through - # post_init), and ``load_state_dict(strict=False)`` does not restore - # it. The result is garbage int64 indices feeding into - # ``position_embedding``, which corrupts every image embedding by - # ~the full norm of the position table. We must manually reset any - # non-persistent ``position_ids`` buffer with the canonical - # ``arange`` values before running the forward. 
- _reset_non_persistent_buffers(self.siglip, device) # The extracted bucket may contain stray pooling-head keys that # Pi05SiglipEncoder doesn't model (``vision_use_head=False``); # ``load_hf_weights`` silently ignores any key that has no matching diff --git a/mstar/model/pi05/submodules.py b/mstar/model/pi05/submodules.py index 120d89a5..6a3f7574 100644 --- a/mstar/model/pi05/submodules.py +++ b/mstar/model/pi05/submodules.py @@ -68,7 +68,7 @@ def to(self, *args, **kwargs): return result def _prepare_one(self, images: torch.Tensor) -> torch.Tensor: - """Resize one request's stack of camera images with aspect-preserving + """Resize one request's camera image(s) with aspect-preserving letterbox padding. Matches openpi's ``image_tools.resize_with_pad_torch`` exactly: @@ -179,9 +179,14 @@ def prepare_inputs( inputs: NameToTensorList, **kwargs ) -> NodeInputs: - return NodeInputs(tensor_inputs={"pixel_values": self._prepare_one( - inputs["image_inputs"][0] - )}) + images = torch.cat([ + self._prepare_one(img) for img in inputs["image_inputs"] + ]) + # TODO: assert images.shape == (num_cameras, 3, H, W) once worker errors + # are surfaced. A wrong count silently broadcasts in the static CUDA + # graph; today a raised error is swallowed and the client hangs, so this + # needs prepare_inputs errors threaded engine -> conductor -> API server. + return NodeInputs(tensor_inputs={"pixel_values": images}) def preprocess( self, diff --git a/mstar/worker/worker.py b/mstar/worker/worker.py index a3d5c0f2..5940787e 100644 --- a/mstar/worker/worker.py +++ b/mstar/worker/worker.py @@ -1511,6 +1511,9 @@ def _thread_outputs_to_speculative( rid_outputs = output_N.per_request_output_tensors.get(rid, {}) ok = True for input_name, _ in speculation.consumed_edges: + # NOTE: this assumes that submodules may output an empty list as valid + # output, and will omit the key entirely from the output upon, e.g., + # an internal failure. Revisit if this contract ever changes. 
tensors = rid_outputs.get(input_name, None) if tensors is None: ok = False From 41e5b206a0472aa6075d22d7e301eea01520616f Mon Sep 17 00:00:00 2001 From: NSagan271 Date: Sun, 14 Jun 2026 04:52:17 +0000 Subject: [PATCH 17/17] ruff --- benchmark/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmark/dataset.py b/benchmark/dataset.py index 27999a15..4447e4f6 100644 --- a/benchmark/dataset.py +++ b/benchmark/dataset.py @@ -735,7 +735,7 @@ def _dl(filename): f"No video keys found in {self.HF_REPO}/meta/info.json. " f"Top-level keys: {list(info.keys())}" ) - + if task == "pi05": assert all((key in camera_keys for key in self.PI05_KEYS)), \ f"Expected camera keys {self.PI05_KEYS} not all found in {camera_keys}"