diff --git a/src/tether/cli.py b/src/tether/cli.py index dca60bd..b731c44 100644 --- a/src/tether/cli.py +++ b/src/tether/cli.py @@ -1785,6 +1785,20 @@ def serve( "fires (whichever first). Lower = lower per-request latency; " "higher = better batching efficiency under bursty load.", ), + inference_executor_workers: int = typer.Option( + 1, + "--inference-executor-workers", + help="Dedicated worker threads for offloading synchronous inference from " + "the async server. Keep 1 for static-shape GPU exports; increase " + "only when the backend supports parallel inference safely.", + ), + inference_executor_queue: int = typer.Option( + 8, + "--inference-executor-queue", + help="Accepted-but-not-yet-running inference submissions before the " + "server returns inference_executor_full. Set 0 to reject instead " + "of queueing behind a busy worker.", + ), max_batch_cost_ms: float = typer.Option( 100.0, "--max-batch-cost-ms", @@ -2405,6 +2419,11 @@ def serve( composed.append(f"[cyan]deadline[/cyan]={deadline_ms:.0f}ms") if max_batch > 1: composed.append(f"[cyan]batch[/cyan]={max_batch}@{batch_timeout_ms:.0f}ms") + if inference_executor_workers != 1 or inference_executor_queue != 8: + composed.append( + f"[cyan]inference-executor[/cyan]=" + f"{inference_executor_workers}w/{inference_executor_queue}q" + ) if embodiment_cfg is not None: composed.append(f"[cyan]embodiment[/cyan]={embodiment_cfg.embodiment}") if so_arm100_adapter is not None: @@ -2522,6 +2541,8 @@ def serve( deadline_ms=deadline_ms if deadline_ms > 0 else None, max_batch=max_batch, batch_timeout_ms=batch_timeout_ms, + inference_executor_workers=inference_executor_workers, + inference_executor_queue=inference_executor_queue, api_key=api_key or None, replan_hz=replan_hz if replan_hz > 0 else None, execute_hz=execute_hz if execute_hz > 0 else None, diff --git a/src/tether/observability/__init__.py b/src/tether/observability/__init__.py index 04202e2..19be2ea 100644 --- a/src/tether/observability/__init__.py +++ b/src/tether/observability/__init__.py @@ -17,6 +17,7 @@ inc_cache_miss, inc_denoise_steps, inc_fallback_invocation, + inc_inference_executor_rejected, inc_model_swap, inc_safety_violation, inc_slo_violation, @@ -25,6 +26,7 @@ record_act_latency, render_metrics, set_episodes_active, + set_inference_executor_state, set_robot_info, set_server_up, track_in_flight, @@ -42,10 +44,12 @@ "inc_safety_violation", "inc_slo_violation", "inc_fallback_invocation", + "inc_inference_executor_rejected", "inc_model_swap", "observe_batch_flush", "track_in_flight", "set_server_up", "set_robot_info", "set_episodes_active", + "set_inference_executor_state", ] diff --git a/src/tether/observability/prometheus.py b/src/tether/observability/prometheus.py index ad8c221..36d4f47 100644 --- a/src/tether/observability/prometheus.py +++ b/src/tether/observability/prometheus.py @@ -113,6 +113,13 @@ registry=REGISTRY, ) +tether_inference_executor_rejected_total = Counter( + "tether_inference_executor_rejected_total", + "Inference executor submissions rejected because the bounded queue was full", + labelnames=("embodiment", "model_id", "policy_slot"), + registry=REGISTRY, +) + # Action-similarity fast-path skip counter (action-similarity-fast-path # Phase 1.5 — FlashVLA). Increments when the inference path returns a # cached action chunk instead of running the expert. Operator visibility @@ -167,6 +174,27 @@ registry=REGISTRY, ) +tether_inference_executor_in_flight = Gauge( + "tether_inference_executor_in_flight", + "Synchronous inference calls currently running in executor worker threads", + labelnames=("embodiment", "model_id", "policy_slot"), + registry=REGISTRY, +) + +tether_inference_executor_queue_depth = Gauge( + "tether_inference_executor_queue_depth", + "Synchronous inference calls accepted but not yet running in executor workers", + labelnames=("embodiment", "model_id", "policy_slot"), + registry=REGISTRY, +) + +tether_inference_executor_capacity = Gauge( + "tether_inference_executor_capacity", + "Configured inference executor capacity by kind", + labelnames=("embodiment", "model_id", "policy_slot", "kind"), + registry=REGISTRY, +) + # --------------------------------------------------------------------------- # Helpers — typed call-sites keep the surface searchable @@ -228,6 +256,18 @@ def inc_fallback_invocation(embodiment: str, target: str) -> None: ).inc() +def inc_inference_executor_rejected( + embodiment: str, + model_id: str, + policy_slot: str = "prod", +) -> None: + tether_inference_executor_rejected_total.labels( + embodiment=embodiment, + model_id=model_id, + policy_slot=policy_slot, + ).inc() + + def inc_action_skip() -> None: tether_action_skip_total.inc() @@ -258,6 +298,30 @@ def set_episodes_active(embodiment: str, value: int) -> None: tether_episodes_active.labels(embodiment=embodiment).set(value) +def set_inference_executor_state( + embodiment: str, + model_id: str, + policy_slot: str = "prod", + *, + in_flight: int, + queue_depth: int, + max_workers: int, + max_queue: int, +) -> None: + labels = { + "embodiment": embodiment, + "model_id": model_id, + "policy_slot": policy_slot, + } + tether_inference_executor_in_flight.labels(**labels).set(in_flight) + tether_inference_executor_queue_depth.labels(**labels).set(queue_depth) + tether_inference_executor_capacity.labels(**labels, kind="workers").set(max_workers) + tether_inference_executor_capacity.labels(**labels, kind="queue").set(max_queue) + tether_inference_executor_capacity.labels(**labels, kind="total").set( + max_workers + max_queue + ) + + @contextmanager def track_in_flight(embodiment: str, policy_slot: str = "prod") -> Iterator[None]: """Context manager increments/decrements in-flight gauge for safe diff --git a/src/tether/runtime/inference_executor.py b/src/tether/runtime/inference_executor.py new file mode 100644 index 0000000..5a8254a --- /dev/null +++ b/src/tether/runtime/inference_executor.py @@ -0,0 +1,157 @@ +"""Bounded async offload executor for synchronous inference work.""" + +from __future__ import annotations + +import asyncio +import functools +import logging +import threading +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from typing import Any, Callable, TypeVar + +T = TypeVar("T") +logger = logging.getLogger(__name__) + + +class InferenceExecutorFull(RuntimeError): + """Raised when all inference worker and queue slots are occupied.""" + + +@dataclass(frozen=True) +class InferenceExecutorSnapshot: + """Point-in-time executor state for health and metrics surfaces.""" + + max_workers: int + max_queue: int + capacity: int + pending: int + running: int + queue_depth: int + rejected: int + + +class BoundedInferenceExecutor: + """Run sync inference functions in a bounded dedicated thread pool. + + The default asyncio executor has an unbounded submission queue. For robot + serving, that hides overload and lets latency tails grow silently. This + wrapper rejects fast when all worker + queue slots are occupied. + """ + + def __init__( + self, + *, + max_workers: int = 1, + max_queue: int = 8, + thread_name_prefix: str = "tether-inference", + on_state_change: Callable[[InferenceExecutorSnapshot], None] | None = None, + ) -> None: + if max_workers <= 0: + raise ValueError(f"max_workers must be > 0, got {max_workers}") + if max_queue < 0: + raise ValueError(f"max_queue must be >= 0, got {max_queue}") + self._max_workers = int(max_workers) + self._max_queue = int(max_queue) + self._capacity = self._max_workers + self._max_queue + self._executor = ThreadPoolExecutor( + max_workers=self._max_workers, + thread_name_prefix=thread_name_prefix, + ) + self._lock = threading.Lock() + self._pending = 0 + self._running = 0 + self._rejected = 0 + self._closed = False + self._on_state_change = on_state_change + + @property + def max_workers(self) -> int: + return self._max_workers + + @property + def max_queue(self) -> int: + return self._max_queue + + @property + def capacity(self) -> int: + return self._capacity + + async def submit(self, fn: Callable[..., T], /, *args: Any, **kwargs: Any) -> T: + """Submit sync work or raise InferenceExecutorFull without waiting.""" + + full_message = "" + with self._lock: + if self._closed: + raise RuntimeError("inference executor is shut down") + if self._pending >= self._capacity: + self._rejected += 1 + snapshot = self._snapshot_locked() + full_message = ( + "inference executor is full " + f"(pending={self._pending}, capacity={self._capacity})" + ) + else: + self._pending += 1 + snapshot = self._snapshot_locked() + self._notify_state(snapshot) + if full_message: + raise InferenceExecutorFull(full_message) + + loop = asyncio.get_running_loop() + callback = functools.partial(self._invoke, fn, args, kwargs) + try: + return await loop.run_in_executor(self._executor, callback) + finally: + with self._lock: + self._pending -= 1 + snapshot = self._snapshot_locked() + self._notify_state(snapshot) + + def snapshot(self) -> InferenceExecutorSnapshot: + with self._lock: + return self._snapshot_locked() + + def shutdown(self, *, wait: bool = False) -> None: + with self._lock: + if self._closed: + return + self._closed = True + self._executor.shutdown(wait=wait, cancel_futures=True) + + def _invoke( + self, + fn: Callable[..., T], + args: tuple[Any, ...], + kwargs: dict[str, Any], + ) -> T: + with self._lock: + self._running += 1 + snapshot = self._snapshot_locked() + self._notify_state(snapshot) + try: + return fn(*args, **kwargs) + finally: + with self._lock: + self._running -= 1 + snapshot = self._snapshot_locked() + self._notify_state(snapshot) + + def _snapshot_locked(self) -> InferenceExecutorSnapshot: + return InferenceExecutorSnapshot( + max_workers=self._max_workers, + max_queue=self._max_queue, + capacity=self._capacity, + pending=self._pending, + running=self._running, + queue_depth=max(0, self._pending - self._running), + rejected=self._rejected, + ) + + def _notify_state(self, snapshot: InferenceExecutorSnapshot) -> None: + if self._on_state_change is None: + return + try: + self._on_state_change(snapshot) + except Exception as exc: # noqa: BLE001 + logger.debug("inference executor state callback failed: %s", exc) diff --git a/src/tether/runtime/server.py b/src/tether/runtime/server.py index 63e9ea3..872e64d 100644 --- a/src/tether/runtime/server.py +++ b/src/tether/runtime/server.py @@ -22,15 +22,19 @@ import io import json import logging +import os import time from pathlib import Path from typing import Any import numpy as np import torch -import torch.nn as nn -import torch.nn.functional as F +from .inference_executor import ( + BoundedInferenceExecutor, + InferenceExecutorFull, + InferenceExecutorSnapshot, +) from .record import ( RecordWriter, compute_config_hash, @@ -43,8 +47,10 @@ try: from tether.observability import ( METRICS_CONTENT_TYPE, + inc_inference_executor_rejected, record_act_latency, render_metrics, + set_inference_executor_state, set_robot_info, set_server_up, track_in_flight, @@ -53,8 +59,10 @@ except ImportError: # pragma: no cover _METRICS_AVAILABLE = False METRICS_CONTENT_TYPE = "text/plain" + def inc_inference_executor_rejected(*args, **kwargs): pass def record_act_latency(*args, **kwargs): pass def render_metrics() -> bytes: return b"# prometheus_client not installed\n" + def set_inference_executor_state(*args, **kwargs): pass def set_server_up(*args): pass def set_robot_info(*args, **kwargs): pass @@ -161,6 +169,8 @@ def __init__( deadline_ms: float | None = None, max_batch: int = 1, batch_timeout_ms: float = 5.0, + inference_executor_workers: int = 1, + inference_executor_queue: int = 8, ): """Create the server. @@ -188,6 +198,10 @@ def __init__( deadline_ms: soft deadline per `predict()` call. If the denoise loop + safety check exceeds this, the server returns the last known good action (or a zero vector) and logs a deadline miss. + inference_executor_workers: worker threads used by async entrypoints + to offload synchronous predict calls. + inference_executor_queue: accepted-but-not-yet-running inference + submissions before async entrypoints reject overload. """ self.export_dir = Path(export_dir) self.device = torch.device(device if torch.cuda.is_available() else "cpu") @@ -201,6 +215,7 @@ def __init__( self._vlm = None self._vlm_loaded = False self._expert_input_names: list[str] = [] + self._ort_iobinding_enabled = self.device.type == "cuda" # Composed wedges (Phase I.2) self._safety_config_path = Path(safety_config) if safety_config else None @@ -221,6 +236,15 @@ def __init__( self._batches_run = 0 self._batched_requests = 0 + # Async inference offload: bounded, dedicated executor so sync predict() + # work cannot build an unbounded queue behind the event loop. + self._inference_policy_slot = "prod" + self._inference_executor = BoundedInferenceExecutor( + max_workers=inference_executor_workers, + max_queue=inference_executor_queue, + on_state_change=self._record_inference_executor_state, + ) + # Rolling latency history for p50/p95/p99 reporting (goal: # latency-histograms). Capped at 1024 samples — a ring buffer in # all but name. @@ -594,6 +618,75 @@ def _load_vlm_orchestrator(self) -> None: def ready(self) -> bool: return self._ready + def _prepare_ort_iobinding( + self, + constant_inputs: dict[str, np.ndarray], + ) -> tuple[Any, str, list[Any]] | None: + """Bind denoise-loop constant inputs once for ORT I/O Binding.""" + if ( + not self._ort_iobinding_enabled + or getattr(self, "_ort_session", None) is None + or not hasattr(self._ort_session, "io_binding") + ): + return None + + try: + output_name = self._ort_session.get_outputs()[0].name + binding = self._ort_session.io_binding() + kept_alive: list[Any] = [] + for name, array in constant_inputs.items(): + self._bind_ort_input(binding, name, array, kept_alive) + return binding, output_name, kept_alive + except Exception as e: + logger.debug("ORT I/O Binding unavailable; falling back to session.run: %s", e) + return None + + def _bind_ort_input( + self, + binding: Any, + name: str, + array: np.ndarray, + kept_alive: list[Any], + ) -> None: + if self.device.type == "cuda": + import onnxruntime as ort + + ort_value = ort.OrtValue.ortvalue_from_numpy(array, "cuda", 0) + binding.bind_ortvalue_input(name, ort_value) + kept_alive.append(ort_value) + else: + binding.bind_cpu_input(name, array) + + def _run_ort_velocity( + self, + dynamic_inputs: dict[str, np.ndarray], + constant_inputs: dict[str, np.ndarray], + iobinding: tuple[Any, str, list[Any]] | None, + ) -> np.ndarray: + if iobinding is None: + return self._ort_session.run( + None, + {**dynamic_inputs, **constant_inputs}, + )[0] + + binding, output_name, kept_alive = iobinding + dynamic_kept_alive: list[Any] = [] + try: + if hasattr(binding, "clear_binding_outputs"): + binding.clear_binding_outputs() + for name, array in dynamic_inputs.items(): + self._bind_ort_input(binding, name, array, dynamic_kept_alive) + if self.device.type == "cuda": + binding.bind_output(output_name, "cuda", 0) + else: + binding.bind_output(output_name, "cpu") + self._ort_session.run_with_iobinding(binding) + return binding.get_outputs()[0].numpy() + finally: + # Keep OrtValues alive until ORT has finished the call. + kept_alive.extend(dynamic_kept_alive) + del kept_alive[len(kept_alive) - len(dynamic_kept_alive):] + def _run_denoise( self, noisy_actions: np.ndarray, @@ -648,35 +741,35 @@ def _run_denoise( # v0.3/v0.4 single-tensor fallback vlm_kv_single = zeros_4d + constant_feed: dict[str, np.ndarray] = {"position_ids": position_ids} + if expert_has_split_kv and vlm_k is not None and vlm_v is not None: + constant_feed["vlm_k"] = vlm_k + constant_feed["vlm_v"] = vlm_v + prefix_len = int(vlm_k.shape[2]) # [L, B, seq, kv] + batch = noisy_actions.shape[0] + if "prefix_offset" in self._expert_input_names: + constant_feed["prefix_offset"] = np.full( + (batch, 1), prefix_len, dtype=np.int64 + ) + if "kv_mask" in self._expert_input_names: + # All-valid mask when we don't have the prefix pad mask handy. + # TODO: plumb the real padded-token mask through from the + # VLM orchestrator. + constant_feed["kv_mask"] = np.ones((batch, prefix_len), dtype=bool) + elif expert_has_single_kv and vlm_kv_single is not None: + constant_feed["vlm_kv"] = vlm_kv_single + + iobinding = self._prepare_ort_iobinding(constant_feed) + for step in range(self.num_denoising_steps): t = 1.0 + step * dt timestep = np.array([t], dtype=np.float32) - feed_dict = { + dynamic_feed = { "noisy_actions": noisy_actions, "timestep": timestep, - "position_ids": position_ids, } - if expert_has_split_kv and vlm_k is not None and vlm_v is not None: - feed_dict["vlm_k"] = vlm_k - feed_dict["vlm_v"] = vlm_v - prefix_len = int(vlm_k.shape[2]) # [L, B, seq, kv] - batch = noisy_actions.shape[0] - if "prefix_offset" in self._expert_input_names: - feed_dict["prefix_offset"] = np.full( - (batch, 1), prefix_len, dtype=np.int64 - ) - if "kv_mask" in self._expert_input_names: - # All-valid mask when we don't have the prefix pad mask handy. - # TODO: plumb the real padded-token mask through from the - # VLM orchestrator. - feed_dict["kv_mask"] = np.ones( - (batch, prefix_len), dtype=bool - ) - elif expert_has_single_kv and vlm_kv_single is not None: - feed_dict["vlm_kv"] = vlm_kv_single - - velocity = self._ort_session.run(None, feed_dict)[0] + velocity = self._run_ort_velocity(dynamic_feed, constant_feed, iobinding) noisy_actions = noisy_actions + velocity * dt @@ -916,6 +1009,57 @@ async def stop_batch_worker(self) -> None: self._batch_worker_task = None self._batch_queue = None + def shutdown_inference_executor(self) -> None: + """Stop the dedicated inference offload pool.""" + self._inference_executor.shutdown(wait=False) + + def _inference_executor_metric_labels(self) -> tuple[str, str, str]: + ec = getattr(self, "embodiment_config", None) + embodiment = getattr(ec, "embodiment", None) or "custom" + model_id = Path(self.export_dir).name or "unknown" + policy_slot = getattr(self, "_inference_policy_slot", "prod") or "prod" + return embodiment, model_id, policy_slot + + def _record_inference_executor_state( + self, + snapshot: InferenceExecutorSnapshot | None = None, + ) -> None: + snapshot = snapshot or self._inference_executor.snapshot() + embodiment, model_id, policy_slot = self._inference_executor_metric_labels() + set_inference_executor_state( + embodiment=embodiment, + model_id=model_id, + policy_slot=policy_slot, + in_flight=snapshot.running, + queue_depth=snapshot.queue_depth, + max_workers=snapshot.max_workers, + max_queue=snapshot.max_queue, + ) + + def _record_inference_executor_rejected(self) -> None: + embodiment, model_id, policy_slot = self._inference_executor_metric_labels() + inc_inference_executor_rejected( + embodiment=embodiment, + model_id=model_id, + policy_slot=policy_slot, + ) + + def _inference_executor_full_result( + self, + exc: InferenceExecutorFull, + ) -> dict[str, Any]: + snapshot = self._inference_executor.snapshot() + return { + "error": "inference_executor_full", + "message": str(exc), + "max_workers": snapshot.max_workers, + "max_queue": snapshot.max_queue, + "queue_depth": snapshot.queue_depth, + "in_flight": snapshot.running, + "pending": snapshot.pending, + "rejected_total": snapshot.rejected, + } + async def predict_async( self, image: np.ndarray | None = None, @@ -924,14 +1068,24 @@ async def predict_async( ) -> dict[str, Any]: """Async front-door used by the HTTP /act handler. - - If max_batch <= 1: runs `self.predict()` synchronously in this task. + - If max_batch <= 1: runs `self.predict()` in the bounded inference + executor. - If max_batch > 1: enqueues the request onto a batch queue. A worker coroutine drains the queue every `batch_timeout_ms` ms (or when the queue hits max_batch) and runs ONE batched ONNX inference, then splits the results back to each waiter. """ if self._max_batch <= 1 or self._batch_queue is None: - return self.predict(image=image, instruction=instruction, state=state) + try: + return await self._inference_executor.submit( + self.predict, + image=image, + instruction=instruction, + state=state, + ) + except InferenceExecutorFull as exc: + self._record_inference_executor_rejected() + return self._inference_executor_full_result(exc) import asyncio loop = asyncio.get_event_loop() @@ -969,13 +1123,19 @@ async def _batch_worker_loop(self) -> None: fut.set_exception(asyncio.CancelledError()) return - # Run batched inference (sync — we're holding the event loop, but - # the actual ORT call is the bottleneck and yields the GIL). try: - results = self._predict_batch_sync(batch) + results = await self._inference_executor.submit( + self._predict_batch_sync, batch, + ) for (fut, *_), result in zip(batch, results): if not fut.done(): fut.set_result(result) + except InferenceExecutorFull as exc: + self._record_inference_executor_rejected() + result = self._inference_executor_full_result(exc) + for fut, *_ in batch: + if not fut.done(): + fut.set_result(dict(result)) except Exception as e: for fut, *_ in batch: if not fut.done(): @@ -1005,17 +1165,19 @@ def _predict_batch_sync(self, batch: list[tuple]) -> list[dict[str, Any]]: ) dt = -1.0 / self.num_denoising_steps + constant_feed = {"position_ids": position_ids_batched} + iobinding = self._prepare_ort_iobinding(constant_feed) for step in range(self.num_denoising_steps): t = 1.0 + step * dt timestep = np.full((b,), t, dtype=np.float32) - velocity = self._ort_session.run( - None, + velocity = self._run_ort_velocity( { "noisy_actions": noisy_batched, "timestep": timestep, - "position_ids": position_ids_batched, }, - )[0] + constant_feed, + iobinding, + ) noisy_batched = noisy_batched + velocity * dt elapsed_ms = (time.perf_counter() - start) * 1000 @@ -1136,6 +1298,8 @@ def create_app( deadline_ms: float | None = None, max_batch: int = 1, batch_timeout_ms: float = 5.0, + inference_executor_workers: int = 1, + inference_executor_queue: int = 8, max_batch_cost_ms: float = 100.0, # PolicyRuntime budget per chunk-budget-batching ADR api_key: str | None = None, replan_hz: float | None = None, @@ -1221,6 +1385,11 @@ def create_app( server.health_state flips to "degraded" — /health returns 503 and /act returns 503 with Retry-After: 60. Successful /act resets the counter. Default 5. Set to 0 to disable. + + inference_executor_workers / inference_executor_queue: bounded async + offload capacity for synchronous predict work. Saturation returns an + `inference_executor_full` error result instead of growing an unbounded + default-executor queue. """ try: from contextlib import asynccontextmanager @@ -1326,6 +1495,8 @@ def create_app( deadline_ms=deadline_ms, max_batch=max_batch, batch_timeout_ms=batch_timeout_ms, + inference_executor_workers=inference_executor_workers, + inference_executor_queue=inference_executor_queue, ) else: raise ValueError( @@ -1344,6 +1515,8 @@ def create_app( deadline_ms=deadline_ms, max_batch=max_batch, batch_timeout_ms=batch_timeout_ms, + inference_executor_workers=inference_executor_workers, + inference_executor_queue=inference_executor_queue, ) # Attach embodiment config (B.1) — optional, downstream consumers @@ -1871,6 +2044,7 @@ async def lifespan(app): # Build server B via setup_two_policy_serving's server_factory, # which mirrors the same load() path. We pass server A in via a # closure to avoid re-loading it. + server._inference_policy_slot = "a" # type: ignore[attr-defined] servers_pair = {"a": server} def _two_policy_server_factory(*, export_dir, **kwargs): @@ -1895,7 +2069,10 @@ def _two_policy_server_factory(*, export_dir, **kwargs): deadline_ms=deadline_ms, max_batch=max_batch, batch_timeout_ms=batch_timeout_ms, + inference_executor_workers=inference_executor_workers, + inference_executor_queue=inference_executor_queue, ) + srv_b._inference_policy_slot = "b" srv_b.load() servers_pair["b"] = srv_b return srv_b @@ -1996,6 +2173,7 @@ def _shape_key(_req): # documented. Operator sees the error in logs + the # banner the CLI prints. server.two_policy_state = None # type: ignore[attr-defined] + server._inference_policy_slot = "prod" # type: ignore[attr-defined] # PolicyRuntime — per-policy queue + cost-weighted scheduler (Phase 1 # chunk-budget-batching). Single-policy default key "prod"; multi-policy @@ -2171,6 +2349,25 @@ async def _heartbeat_loop(): _rec.write_footer({"total_requests": _rec.seq}) finally: _rec.close() + _servers_to_shutdown = [server] + _two_state_shutdown = getattr(server, "two_policy_state", None) + if _two_state_shutdown is not None: + _servers_to_shutdown.extend( + [ + getattr(_two_state_shutdown, "server_a", None), + getattr(_two_state_shutdown, "server_b", None), + ] + ) + for _srv_shutdown in { + id(_srv): _srv for _srv in _servers_to_shutdown if _srv is not None + }.values(): + _shutdown = getattr(_srv_shutdown, "shutdown_inference_executor", None) + if _shutdown is None: + continue + try: + _shutdown() + except Exception as exc: # noqa: BLE001 + logger.warning("inference_executor.shutdown failed: %s", exc) shutdown_tracing() app = FastAPI( diff --git a/tests/test_inference_executor.py b/tests/test_inference_executor.py new file mode 100644 index 0000000..9ca791e --- /dev/null +++ b/tests/test_inference_executor.py @@ -0,0 +1,91 @@ +"""Tests for the bounded inference offload executor.""" + +from __future__ import annotations + +import asyncio +import threading +import time + +import pytest + +from tether.runtime.inference_executor import ( + BoundedInferenceExecutor, + InferenceExecutorFull, +) + + +def test_rejects_invalid_capacity(): + with pytest.raises(ValueError, match="max_workers"): + BoundedInferenceExecutor(max_workers=0) + with pytest.raises(ValueError, match="max_queue"): + BoundedInferenceExecutor(max_queue=-1) + + +@pytest.mark.asyncio +async def test_rejects_when_worker_and_queue_are_full(): + executor = BoundedInferenceExecutor(max_workers=1, max_queue=0) + started = threading.Event() + release = threading.Event() + + def blocking_work(): + started.set() + release.wait(timeout=2.0) + return "done" + + first = asyncio.create_task(executor.submit(blocking_work)) + try: + assert await asyncio.to_thread(started.wait, 1.0) + + with pytest.raises(InferenceExecutorFull): + await executor.submit(lambda: "second") + + snapshot = executor.snapshot() + assert snapshot.pending == 1 + assert snapshot.running == 1 + assert snapshot.queue_depth == 0 + assert snapshot.rejected == 1 + finally: + release.set() + assert await first == "done" + executor.shutdown() + + +@pytest.mark.asyncio +async def test_reports_queued_and_running_work(): + states = [] + executor = BoundedInferenceExecutor( + max_workers=1, + max_queue=1, + on_state_change=states.append, + ) + first_started = threading.Event() + release_first = threading.Event() + + def blocking_work(): + first_started.set() + release_first.wait(timeout=2.0) + return "first" + + first = asyncio.create_task(executor.submit(blocking_work)) + second = None + try: + assert await asyncio.to_thread(first_started.wait, 1.0) + + second = asyncio.create_task(executor.submit(lambda: "second")) + deadline = time.monotonic() + 1.0 + while time.monotonic() < deadline: + snapshot = executor.snapshot() + if snapshot.pending == 2 and snapshot.queue_depth == 1: + break + await asyncio.sleep(0.01) + else: + pytest.fail(f"queued work was not observed: {executor.snapshot()}") + + assert any(state.running == 1 for state in states) + assert any(state.queue_depth == 1 for state in states) + finally: + release_first.set() + assert await first == "first" + if second is not None: + assert await second == "second" + executor.shutdown() diff --git a/tests/test_observability_prometheus.py b/tests/test_observability_prometheus.py index 9e5e4d0..5537c73 100644 --- a/tests/test_observability_prometheus.py +++ b/tests/test_observability_prometheus.py @@ -15,9 +15,9 @@ from tether.observability import ( METRICS_CONTENT_TYPE, inc_cache_hit, - inc_cache_miss, inc_denoise_steps, inc_fallback_invocation, + inc_inference_executor_rejected, inc_model_swap, inc_safety_violation, inc_slo_violation, @@ -25,6 +25,7 @@ record_act_latency, render_metrics, set_episodes_active, + set_inference_executor_state, set_server_up, track_in_flight, ) @@ -68,8 +69,6 @@ def test_render_includes_help_and_type_lines(self): class TestRecordActLatency: def test_observation_increments_count(self): - # Snapshot + record + diff - before = list(text_string_to_metric_families(render_metrics().decode())) for _ in range(3): record_act_latency(0.020, embodiment="ur5", model_id="pi05") out = render_metrics().decode() @@ -132,6 +131,16 @@ def test_fallback_invocation(self): inc_fallback_invocation(embodiment="so100", target="hold_position") assert "tether_fallback_invocations_total" in render_metrics().decode() + def test_inference_executor_rejected(self): + inc_inference_executor_rejected( + embodiment="franka", + model_id="pi05", + policy_slot="prod", + ) + out = render_metrics().decode() + assert "tether_inference_executor_rejected_total" in out + assert 'model_id="pi05"' in out + def test_model_swap(self): inc_model_swap(embodiment="franka", from_model="pi0", to_model="pi05") out = render_metrics().decode() @@ -193,6 +202,24 @@ def test_set_episodes_active(self): out = render_metrics().decode() assert 'tether_episodes_active{embodiment="franka"} 5' in out + def test_set_inference_executor_state(self): + set_inference_executor_state( + embodiment="franka", + model_id="pi05", + policy_slot="prod", + in_flight=1, + queue_depth=2, + max_workers=1, + max_queue=8, + ) + out = render_metrics().decode() + assert "tether_inference_executor_in_flight" in out + assert "tether_inference_executor_queue_depth" in out + assert "tether_inference_executor_capacity" in out + assert 'kind="workers"' in out + assert 'kind="queue"' in out + assert 'kind="total"' in out + # --------------------------------------------------------------------------- # Cardinality + anti-patterns diff --git a/tests/test_server.py b/tests/test_server.py index 2c480b8..c9979fd 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,8 +1,7 @@ """Tests for the VLA inference server.""" import json -import tempfile -from pathlib import Path +import time from unittest.mock import patch, MagicMock import numpy as np @@ -51,6 +50,183 @@ def test_loads_with_missing_onnx(self, mock_export_dir): class TestTetherServerWithMockORT: + def test_predict_async_offloads_non_batched_predict(self, mock_export_dir): + import asyncio + + server = TetherServer(mock_export_dir, device="cpu", max_batch=1) + + def slow_predict(**kwargs): + time.sleep(0.15) + return {"ok": True, "kwargs": kwargs} + + server.predict = slow_predict + + async def run_check(): + start = time.perf_counter() + task = asyncio.create_task( + server.predict_async(instruction="pick", state=[0.0]) + ) + await asyncio.sleep(0.01) + elapsed = time.perf_counter() - start + result = await task + return elapsed, result + + elapsed, result = asyncio.run(run_check()) + server.shutdown_inference_executor() + + assert elapsed < 0.08 + assert result["ok"] is True + assert result["kwargs"]["instruction"] == "pick" + + def test_predict_async_rejects_when_executor_is_full(self, mock_export_dir): + import asyncio + import threading + + server = TetherServer( + mock_export_dir, + device="cpu", + max_batch=1, + inference_executor_workers=1, + inference_executor_queue=0, + ) + started = threading.Event() + release = threading.Event() + + def slow_predict(**_kwargs): + started.set() + release.wait(timeout=2.0) + return {"ok": True} + + server.predict = slow_predict + + async def run_check(): + first = asyncio.create_task(server.predict_async(instruction="first")) + try: + assert await asyncio.to_thread(started.wait, 1.0) + rejected = await server.predict_async(instruction="second") + finally: + release.set() + accepted = await first + return accepted, rejected + + accepted, rejected = asyncio.run(run_check()) + server.shutdown_inference_executor() + + assert rejected["error"] == "inference_executor_full" + assert rejected["max_workers"] == 1 + assert rejected["max_queue"] == 0 + assert rejected["rejected_total"] >= 1 + assert accepted["ok"] is True + + def test_batch_worker_offloads_sync_batch_predict(self, mock_export_dir): + import asyncio + + server = TetherServer( + mock_export_dir, + device="cpu", + max_batch=2, + batch_timeout_ms=1, + ) + + def slow_batch(batch): + time.sleep(0.15) + return [{"ok": True, "batch_size": len(batch)} for _ in batch] + + server._predict_batch_sync = slow_batch + + async def run_check(): + await server.start_batch_worker() + try: + start = time.perf_counter() + task = asyncio.create_task( + server.predict_async(instruction="pick", state=[0.0]) + ) + await asyncio.sleep(0.01) + elapsed = time.perf_counter() - start + result = await task + return elapsed, result + finally: + await server.stop_batch_worker() + + elapsed, result = asyncio.run(run_check()) + server.shutdown_inference_executor() + + assert elapsed < 0.08 + assert result["ok"] is True + assert result["batch_size"] == 1 + + def test_denoise_uses_iobinding_when_enabled(self, mock_export_dir): + class _FakeOutput: + name = "velocity" + + class _FakeOrtValue: + def __init__(self, array): + self._array = array + + def numpy(self): + return self._array + + class _FakeBinding: + def __init__(self, velocity): + self.velocity = velocity + self.bound_inputs = [] + self.bound_outputs = [] + self.clear_outputs_calls = 0 + + def bind_cpu_input(self, name, array): + self.bound_inputs.append((name, array.shape)) + + def bind_output(self, name, *args): + self.bound_outputs.append((name, args)) + + def clear_binding_outputs(self): + self.clear_outputs_calls += 1 + + def get_outputs(self): + return [_FakeOrtValue(self.velocity)] + + class _FakeSession: + def __init__(self, velocity): + self.binding = _FakeBinding(velocity) + self.run_calls = 0 + self.run_with_iobinding_calls = 0 + + def get_outputs(self): + return [_FakeOutput()] + + def io_binding(self): + return self.binding + + def run(self, *_args, **_kwargs): + self.run_calls += 1 + raise AssertionError("session.run should not be used") + + def run_with_iobinding(self, _binding): + self.run_with_iobinding_calls += 1 + + server = TetherServer(mock_export_dir, device="cpu") + server.num_denoising_steps = 1 + server._expert_input_names = [] + server._ort_iobinding_enabled = True + + noisy = np.zeros((1, 2, 3), dtype=np.float32) + velocity = np.ones_like(noisy) + fake_session = _FakeSession(velocity) + server._ort_session = fake_session + + actions, steps = server._run_denoise( + noisy_actions=noisy, + position_ids=np.arange(2, dtype=np.int64)[None, :], + ) + + np.testing.assert_allclose(actions, -np.ones_like(noisy)) + assert steps == 1 + assert fake_session.run_calls == 0 + assert fake_session.run_with_iobinding_calls == 1 + bound_names = [name for name, _shape in fake_session.binding.bound_inputs] + assert bound_names == ["position_ids", "noisy_actions", "timestep"] + assert fake_session.binding.bound_outputs == [("velocity", ("cpu",))] + def test_predict_returns_actions(self, mock_export_dir): server = TetherServer(mock_export_dir, device="cpu") server.action_dim = 32