Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions src/tether/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1785,6 +1785,20 @@ def serve(
"fires (whichever first). Lower = lower per-request latency; "
"higher = better batching efficiency under bursty load.",
),
inference_executor_workers: int = typer.Option(
1,
"--inference-executor-workers",
help="Dedicated worker threads for offloading synchronous inference from "
"the async server. Keep 1 for static-shape GPU exports; increase "
"only when the backend supports parallel inference safely.",
),
inference_executor_queue: int = typer.Option(
8,
"--inference-executor-queue",
help="Accepted-but-not-yet-running inference submissions before the "
"server returns inference_executor_full. Set 0 to reject instead "
"of queueing behind a busy worker.",
),
max_batch_cost_ms: float = typer.Option(
100.0,
"--max-batch-cost-ms",
Expand Down Expand Up @@ -2405,6 +2419,11 @@ def serve(
composed.append(f"[cyan]deadline[/cyan]={deadline_ms:.0f}ms")
if max_batch > 1:
composed.append(f"[cyan]batch[/cyan]={max_batch}@{batch_timeout_ms:.0f}ms")
if inference_executor_workers != 1 or inference_executor_queue != 8:
composed.append(
f"[cyan]inference-executor[/cyan]="
f"{inference_executor_workers}w/{inference_executor_queue}q"
)
if embodiment_cfg is not None:
composed.append(f"[cyan]embodiment[/cyan]={embodiment_cfg.embodiment}")
if so_arm100_adapter is not None:
Expand Down Expand Up @@ -2522,6 +2541,8 @@ def serve(
deadline_ms=deadline_ms if deadline_ms > 0 else None,
max_batch=max_batch,
batch_timeout_ms=batch_timeout_ms,
inference_executor_workers=inference_executor_workers,
inference_executor_queue=inference_executor_queue,
api_key=api_key or None,
replan_hz=replan_hz if replan_hz > 0 else None,
execute_hz=execute_hz if execute_hz > 0 else None,
Expand Down
4 changes: 4 additions & 0 deletions src/tether/observability/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
inc_cache_miss,
inc_denoise_steps,
inc_fallback_invocation,
inc_inference_executor_rejected,
inc_model_swap,
inc_safety_violation,
inc_slo_violation,
Expand All @@ -25,6 +26,7 @@
record_act_latency,
render_metrics,
set_episodes_active,
set_inference_executor_state,
set_robot_info,
set_server_up,
track_in_flight,
Expand All @@ -42,10 +44,12 @@
"inc_safety_violation",
"inc_slo_violation",
"inc_fallback_invocation",
"inc_inference_executor_rejected",
"inc_model_swap",
"observe_batch_flush",
"track_in_flight",
"set_server_up",
"set_robot_info",
"set_episodes_active",
"set_inference_executor_state",
]
64 changes: 64 additions & 0 deletions src/tether/observability/prometheus.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,13 @@
registry=REGISTRY,
)

tether_inference_executor_rejected_total = Counter(
"tether_inference_executor_rejected_total",
"Inference executor submissions rejected because the bounded queue was full",
labelnames=("embodiment", "model_id", "policy_slot"),
registry=REGISTRY,
)

# Action-similarity fast-path skip counter (action-similarity-fast-path
# Phase 1.5 — FlashVLA). Increments when the inference path returns a
# cached action chunk instead of running the expert. Operator visibility
Expand Down Expand Up @@ -167,6 +174,27 @@
registry=REGISTRY,
)

tether_inference_executor_in_flight = Gauge(
"tether_inference_executor_in_flight",
"Synchronous inference calls currently running in executor worker threads",
labelnames=("embodiment", "model_id", "policy_slot"),
registry=REGISTRY,
)

tether_inference_executor_queue_depth = Gauge(
"tether_inference_executor_queue_depth",
"Synchronous inference calls accepted but not yet running in executor workers",
labelnames=("embodiment", "model_id", "policy_slot"),
registry=REGISTRY,
)

tether_inference_executor_capacity = Gauge(
"tether_inference_executor_capacity",
"Configured inference executor capacity by kind",
labelnames=("embodiment", "model_id", "policy_slot", "kind"),
registry=REGISTRY,
)


# ---------------------------------------------------------------------------
# Helpers — typed call-sites keep the surface searchable
Expand Down Expand Up @@ -228,6 +256,18 @@ def inc_fallback_invocation(embodiment: str, target: str) -> None:
).inc()


def inc_inference_executor_rejected(
embodiment: str,
model_id: str,
policy_slot: str = "prod",
) -> None:
tether_inference_executor_rejected_total.labels(
embodiment=embodiment,
model_id=model_id,
policy_slot=policy_slot,
).inc()


def inc_action_skip() -> None:
tether_action_skip_total.inc()

Expand Down Expand Up @@ -258,6 +298,30 @@ def set_episodes_active(embodiment: str, value: int) -> None:
tether_episodes_active.labels(embodiment=embodiment).set(value)


def set_inference_executor_state(
embodiment: str,
model_id: str,
policy_slot: str = "prod",
*,
in_flight: int,
queue_depth: int,
max_workers: int,
max_queue: int,
) -> None:
labels = {
"embodiment": embodiment,
"model_id": model_id,
"policy_slot": policy_slot,
}
tether_inference_executor_in_flight.labels(**labels).set(in_flight)
tether_inference_executor_queue_depth.labels(**labels).set(queue_depth)
tether_inference_executor_capacity.labels(**labels, kind="workers").set(max_workers)
tether_inference_executor_capacity.labels(**labels, kind="queue").set(max_queue)
tether_inference_executor_capacity.labels(**labels, kind="total").set(
max_workers + max_queue
)


@contextmanager
def track_in_flight(embodiment: str, policy_slot: str = "prod") -> Iterator[None]:
"""Context manager increments/decrements in-flight gauge for safe
Expand Down
157 changes: 157 additions & 0 deletions src/tether/runtime/inference_executor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
"""Bounded async offload executor for synchronous inference work."""

from __future__ import annotations

import asyncio
import functools
import logging
import threading
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Any, Callable, TypeVar

T = TypeVar("T")
logger = logging.getLogger(__name__)


class InferenceExecutorFull(RuntimeError):
"""Raised when all inference worker and queue slots are occupied."""


@dataclass(frozen=True)
class InferenceExecutorSnapshot:
"""Point-in-time executor state for health and metrics surfaces."""

max_workers: int
max_queue: int
capacity: int
pending: int
running: int
queue_depth: int
rejected: int


class BoundedInferenceExecutor:
"""Run sync inference functions in a bounded dedicated thread pool.

The default asyncio executor has an unbounded submission queue. For robot
serving, that hides overload and lets latency tails grow silently. This
wrapper rejects fast when all worker + queue slots are occupied.
"""

def __init__(
self,
*,
max_workers: int = 1,
max_queue: int = 8,
thread_name_prefix: str = "tether-inference",
on_state_change: Callable[[InferenceExecutorSnapshot], None] | None = None,
) -> None:
if max_workers <= 0:
raise ValueError(f"max_workers must be > 0, got {max_workers}")
if max_queue < 0:
raise ValueError(f"max_queue must be >= 0, got {max_queue}")
self._max_workers = int(max_workers)
self._max_queue = int(max_queue)
self._capacity = self._max_workers + self._max_queue
self._executor = ThreadPoolExecutor(
max_workers=self._max_workers,
thread_name_prefix=thread_name_prefix,
)
self._lock = threading.Lock()
self._pending = 0
self._running = 0
self._rejected = 0
self._closed = False
self._on_state_change = on_state_change

@property
def max_workers(self) -> int:
return self._max_workers

@property
def max_queue(self) -> int:
return self._max_queue

@property
def capacity(self) -> int:
return self._capacity

async def submit(self, fn: Callable[..., T], /, *args: Any, **kwargs: Any) -> T:
"""Submit sync work or raise InferenceExecutorFull without waiting."""

full_message = ""
with self._lock:
if self._closed:
raise RuntimeError("inference executor is shut down")
if self._pending >= self._capacity:
self._rejected += 1
snapshot = self._snapshot_locked()
full_message = (
"inference executor is full "
f"(pending={self._pending}, capacity={self._capacity})"
)
else:
self._pending += 1
snapshot = self._snapshot_locked()
self._notify_state(snapshot)
if full_message:
raise InferenceExecutorFull(full_message)

loop = asyncio.get_running_loop()
callback = functools.partial(self._invoke, fn, args, kwargs)
try:
return await loop.run_in_executor(self._executor, callback)
finally:
with self._lock:
self._pending -= 1
snapshot = self._snapshot_locked()
self._notify_state(snapshot)

def snapshot(self) -> InferenceExecutorSnapshot:
with self._lock:
return self._snapshot_locked()

def shutdown(self, *, wait: bool = False) -> None:
with self._lock:
if self._closed:
return
self._closed = True
self._executor.shutdown(wait=wait, cancel_futures=True)

def _invoke(
self,
fn: Callable[..., T],
args: tuple[Any, ...],
kwargs: dict[str, Any],
) -> T:
with self._lock:
self._running += 1
snapshot = self._snapshot_locked()
self._notify_state(snapshot)
try:
return fn(*args, **kwargs)
finally:
with self._lock:
self._running -= 1
snapshot = self._snapshot_locked()
self._notify_state(snapshot)

def _snapshot_locked(self) -> InferenceExecutorSnapshot:
return InferenceExecutorSnapshot(
max_workers=self._max_workers,
max_queue=self._max_queue,
capacity=self._capacity,
pending=self._pending,
running=self._running,
queue_depth=max(0, self._pending - self._running),
rejected=self._rejected,
)

def _notify_state(self, snapshot: InferenceExecutorSnapshot) -> None:
if self._on_state_change is None:
return
try:
self._on_state_change(snapshot)
except Exception as exc: # noqa: BLE001
logger.debug("inference executor state callback failed: %s", exc)
Loading
Loading