Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 64 additions & 2 deletions src/tether/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1889,6 +1889,51 @@ def serve(
"Stable/high-latency chunks execute longer before replanning; "
"uncertain or discontinuous chunks replan sooner.",
),
adaptive_action_chunking_canary: bool = typer.Option(
False,
"--adaptive-action-chunking-canary",
help="With --rtc: compute AAC decisions and telemetry but keep applying "
"the base --rtc-execution-horizon. Use before enabling AAC control.",
),
aac_min_horizon: int = typer.Option(
1,
"--aac-min-horizon",
help="With --adaptive-action-chunking: minimum execution horizon in actions.",
),
aac_low_uncertainty: float = typer.Option(
0.20,
"--aac-low-uncertainty",
help="With --adaptive-action-chunking: uncertainty below this is low risk.",
),
aac_high_uncertainty: float = typer.Option(
0.65,
"--aac-high-uncertainty",
help="With --adaptive-action-chunking: uncertainty at/above this is high risk.",
),
aac_low_guard_margin: float = typer.Option(
0.05,
"--aac-low-guard-margin",
help="With --adaptive-action-chunking: guard margin at/below this shortens "
"the horizon.",
),
aac_high_correction_magnitude: float = typer.Option(
0.20,
"--aac-high-correction-magnitude",
help="With --adaptive-action-chunking: A2C2 correction magnitude at/above "
"this shortens the horizon.",
),
aac_high_action_delta: float = typer.Option(
0.25,
"--aac-high-action-delta",
help="With --adaptive-action-chunking: chunk-boundary action delta at/above "
"this shortens the horizon.",
),
aac_high_latency_ms: float = typer.Option(
120.0,
"--aac-high-latency-ms",
help="With --adaptive-action-chunking: stable scenes at/above this latency "
"can lengthen the horizon.",
),
record: str = typer.Option(
"",
help="If set, write every /act request+response to a JSONL trace in "
Expand Down Expand Up @@ -2370,7 +2415,19 @@ def serve(
prefix_attention_schedule=rtc_schedule,
max_guidance_weight=rtc_max_guidance_weight,
debug=rtc_debug,
adaptive_chunking_enabled=adaptive_action_chunking,
adaptive_chunking_enabled=(
adaptive_action_chunking or adaptive_action_chunking_canary
),
adaptive_chunking_canary=adaptive_action_chunking_canary,
adaptive_min_horizon=aac_min_horizon,
adaptive_low_uncertainty=aac_low_uncertainty,
adaptive_high_uncertainty=aac_high_uncertainty,
adaptive_low_guard_margin=aac_low_guard_margin,
adaptive_high_correction_magnitude=(
aac_high_correction_magnitude
),
adaptive_high_action_delta=aac_high_action_delta,
adaptive_high_latency_ms=aac_high_latency_ms,
)
except ValueError as exc:
err_console.print(f"[red]Invalid RTC config: {exc}[/red]")
Expand Down Expand Up @@ -2463,9 +2520,14 @@ def serve(
f"{', no-gzip' if record_no_gzip else ''})"
)
if rtc:
aac_suffix = ""
if adaptive_action_chunking_canary:
aac_suffix = "/aac-canary"
elif adaptive_action_chunking:
aac_suffix = "/aac"
composed.append(
f"[cyan]rtc[/cyan]=horizon{rtc_execution_horizon}/{rtc_schedule}"
f"{'/aac' if adaptive_action_chunking else ''}"
f"{aac_suffix}"
)
if composed:
console.print(f" Wedges: {' · '.join(composed)}")
Expand Down
11 changes: 11 additions & 0 deletions src/tether/observability/prometheus.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,12 @@
labelnames=("embodiment", "model_id", "policy_slot"),
registry=REGISTRY,
)
# Gauge for the RTC execution horizon that was actually applied to the
# request, after canary gating (see the help text below). It is labelled
# per (embodiment, model_id, policy_slot) and registered on the module's
# REGISTRY. observe_rtc_adaptive_chunking feeds it from the decision's
# "applied_horizon" entry via _set_gauge_if_float.
# NOTE(review): presumably this differs from the decision's "horizon" value
# only when canary mode keeps the base horizon — confirm against the
# RTC adapter that produces the decision dict.
tether_rtc_adaptive_applied_horizon = Gauge(
"tether_rtc_adaptive_applied_horizon",
"Latest RTC execution horizon actually applied after canary gating",
labelnames=("embodiment", "model_id", "policy_slot"),
registry=REGISTRY,
)
tether_rtc_adaptive_risk_score = Gauge(
"tether_rtc_adaptive_risk_score",
"Latest adaptive RTC risk score used for horizon selection",
Expand Down Expand Up @@ -359,6 +365,11 @@ def observe_rtc_adaptive_chunking(
labels,
decision.get("horizon"),
)
_set_gauge_if_float(
tether_rtc_adaptive_applied_horizon,
labels,
decision.get("applied_horizon"),
)
_set_gauge_if_float(
tether_rtc_adaptive_risk_score,
labels,
Expand Down
2 changes: 2 additions & 0 deletions src/tether/runtime/adaptive_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
"action_delta",
"latency_ms",
"horizon",
"applied_horizon",
"risk_score",
)

Expand Down Expand Up @@ -97,6 +98,7 @@ def summarize_adaptive_records(
reason = str(decision.get("reason") or "unknown")
reasons[reason] += 1
_append_float(values["horizon"], decision.get("horizon"))
_append_float(values["applied_horizon"], decision.get("applied_horizon"))
_append_float(values["risk_score"], decision.get("risk_score"))

latency = record.get("latency")
Expand Down
61 changes: 55 additions & 6 deletions src/tether/runtime/rtc_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,20 @@ class RtcAdapterConfig:
gripper_dim_indices: list[int] = field(default_factory=list)
skip_gripper_smoothing: bool = True
adaptive_chunking_enabled: bool = False
adaptive_chunking_canary: bool = False
adaptive_min_horizon: int = 1
adaptive_low_uncertainty: float = 0.20
adaptive_high_uncertainty: float = 0.65
adaptive_low_guard_margin: float = 0.05
adaptive_high_correction_magnitude: float = 0.20
adaptive_high_action_delta: float = 0.25
adaptive_high_latency_ms: float = 120.0

def __post_init__(self) -> None:
"""Validate the Tether-side extras. lerobot's RTCConfig validates
its own fields when constructed via _build_lerobot_rtc_config()."""
if self.adaptive_chunking_canary:
self.adaptive_chunking_enabled = True
if self.prefix_attention_schedule not in _VALID_SCHEDULES:
raise ValueError(
f"prefix_attention_schedule must be one of {_VALID_SCHEDULES}, "
Expand All @@ -146,6 +154,31 @@ def __post_init__(self) -> None:
raise ValueError(
f"adaptive_min_horizon must be >= 1, got {self.adaptive_min_horizon}"
)
if self.adaptive_low_uncertainty < 0:
raise ValueError(
"adaptive_low_uncertainty must be >= 0, "
f"got {self.adaptive_low_uncertainty}"
)
if self.adaptive_high_uncertainty <= self.adaptive_low_uncertainty:
raise ValueError(
"adaptive_high_uncertainty must be greater than "
"adaptive_low_uncertainty"
)
if self.adaptive_low_guard_margin < 0:
raise ValueError(
"adaptive_low_guard_margin must be >= 0, "
f"got {self.adaptive_low_guard_margin}"
)
if self.adaptive_high_correction_magnitude < 0:
raise ValueError(
"adaptive_high_correction_magnitude must be >= 0, "
f"got {self.adaptive_high_correction_magnitude}"
)
if self.adaptive_high_action_delta < 0:
raise ValueError(
"adaptive_high_action_delta must be >= 0, "
f"got {self.adaptive_high_action_delta}"
)
if self.adaptive_high_latency_ms <= 0:
raise ValueError(
"adaptive_high_latency_ms must be positive, "
Expand Down Expand Up @@ -298,6 +331,7 @@ def __init__(
self._last_action_delta: float | None = None
self._last_adaptive_signal = AdaptiveChunkSignal()
self._last_adaptive_decision: AdaptiveChunkDecision | None = None
self._last_execution_horizon: int | None = None
self._adaptive_chunker: AdaptiveChunkController | None = None
if config.adaptive_chunking_enabled:
self._adaptive_chunker = AdaptiveChunkController(
Expand All @@ -306,6 +340,13 @@ def __init__(
min_horizon=config.adaptive_min_horizon,
base_horizon=config.rtc_execution_horizon,
max_horizon=action_buffer.capacity,
low_uncertainty=config.adaptive_low_uncertainty,
high_uncertainty=config.adaptive_high_uncertainty,
low_guard_margin=config.adaptive_low_guard_margin,
high_correction_magnitude=(
config.adaptive_high_correction_magnitude
),
high_action_delta=config.adaptive_high_action_delta,
high_latency_ms=config.adaptive_high_latency_ms,
)
)
Expand Down Expand Up @@ -351,12 +392,15 @@ def predict_chunk_with_rtc(self, batch: dict[str, Any]) -> np.ndarray:
"inference_delay": actions_consumed,
"prev_chunk_left_over": self._prev_chunk_left_over,
}
execution_horizon = self.config.rtc_execution_horizon
if (
adaptive_decision is not None
and not self.config.adaptive_chunking_canary
):
execution_horizon = adaptive_decision.horizon
self._last_execution_horizon = execution_horizon
if self.config.enabled:
rtc_kwargs["execution_horizon"] = (
adaptive_decision.horizon
if adaptive_decision is not None
else self.config.rtc_execution_horizon
)
rtc_kwargs["execution_horizon"] = execution_horizon

t0 = time.monotonic()
try:
Expand Down Expand Up @@ -469,6 +513,7 @@ def reset(self, episode_id: str | None = None) -> None:
self._last_action_delta = None
self._last_adaptive_signal = AdaptiveChunkSignal()
self._last_adaptive_decision = None
self._last_execution_horizon = None
# Clear the latency window — old samples are stale on a fresh episode
self.latency = LatencyTracker(
percentile=self.config.latency_percentile,
Expand All @@ -495,7 +540,11 @@ def get_stats(self) -> dict[str, Any]:
if self._last_adaptive_signal.has_values():
stats["adaptive_signal"] = self._last_adaptive_signal.as_dict()
if self._last_adaptive_decision is not None:
stats["adaptive_chunking"] = self._last_adaptive_decision.as_dict()
adaptive_chunking = self._last_adaptive_decision.as_dict()
adaptive_chunking["canary"] = self.config.adaptive_chunking_canary
if self._last_execution_horizon is not None:
adaptive_chunking["applied_horizon"] = self._last_execution_horizon
stats["adaptive_chunking"] = adaptive_chunking
return stats


Expand Down
10 changes: 10 additions & 0 deletions src/tether/runtime/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,11 @@ def _set_rtc_adaptive_span_attrs(span: Any, rtc_record: dict[str, Any]) -> None:
_set_span_optional_float(
span, "tether.rtc.adaptive.horizon", decision.get("horizon")
)
_set_span_optional_float(
span,
"tether.rtc.adaptive.applied_horizon",
decision.get("applied_horizon"),
)
_set_span_optional_float(
span, "tether.rtc.adaptive.risk_score", decision.get("risk_score")
)
Expand All @@ -165,6 +170,11 @@ def _set_rtc_adaptive_span_attrs(span: Any, rtc_record: dict[str, Any]) -> None:
"tether.rtc.adaptive.replan_threshold_ratio",
decision.get("replan_threshold_ratio"),
)
if "canary" in decision:
span.set_attribute(
"tether.rtc.adaptive.canary",
bool(decision.get("canary")),
)

signal = rtc_record.get("adaptive_signal")
if isinstance(signal, dict):
Expand Down
21 changes: 15 additions & 6 deletions tests/test_adaptive_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,16 @@ def _record(
action_delta: float,
horizon: int,
reason: str,
applied_horizon: int | None = None,
) -> dict:
decision = {
"horizon": horizon,
"reason": reason,
"risk_score": 0.5,
"replan_threshold_ratio": 0.4,
}
if applied_horizon is not None:
decision["applied_horizon"] = applied_horizon
return {
"kind": "request",
"latency": {"total_ms": latency_ms},
Expand All @@ -39,12 +48,7 @@ def _record(
"correction_magnitude": correction_magnitude,
"uncertainty": uncertainty,
},
"adaptive_chunking": {
"horizon": horizon,
"reason": reason,
"risk_score": 0.5,
"replan_threshold_ratio": 0.4,
},
"adaptive_chunking": decision,
"last_action_delta": action_delta,
},
}
Expand All @@ -60,6 +64,7 @@ def test_iter_adaptive_records_reads_plain_and_gzip_jsonl(tmp_path):
uncertainty=0.2,
action_delta=0.03,
horizon=8,
applied_horizon=5,
reason="stable",
),
{"kind": "request", "latency": {"total_ms": 60}},
Expand All @@ -86,6 +91,7 @@ def test_summarize_adaptive_records_counts_reasons_and_percentiles():
uncertainty=0.2,
action_delta=0.03,
horizon=8,
applied_horizon=8,
reason="stable",
),
_record(
Expand All @@ -95,6 +101,7 @@ def test_summarize_adaptive_records_counts_reasons_and_percentiles():
uncertainty=0.5,
action_delta=0.08,
horizon=5,
applied_horizon=5,
reason="correction",
),
_record(
Expand All @@ -104,6 +111,7 @@ def test_summarize_adaptive_records_counts_reasons_and_percentiles():
uncertainty=0.9,
action_delta=0.16,
horizon=2,
applied_horizon=5,
reason="correction",
),
]
Expand All @@ -115,6 +123,7 @@ def test_summarize_adaptive_records_counts_reasons_and_percentiles():
assert summary["reasons"] == {"correction": 2, "stable": 1}
assert summary["observed"]["latency_ms"]["p50"] == pytest.approx(100)
assert summary["observed"]["guard_margin"]["p10"] == pytest.approx(0.024)
assert summary["observed"]["applied_horizon"]["p50"] == pytest.approx(5)


def test_recommend_adaptive_chunk_thresholds_uses_recorded_distribution():
Expand Down
2 changes: 2 additions & 0 deletions tests/test_observability_prometheus.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def test_rtc_adaptive_chunking_metrics(self):
policy_slot="prod",
decision={
"horizon": 4,
"applied_horizon": 10,
"reason": "guard_margin",
"risk_score": 0.8,
"replan_threshold_ratio": 0.6,
Expand All @@ -164,6 +165,7 @@ def test_rtc_adaptive_chunking_metrics(self):
assert "tether_rtc_adaptive_decisions_total" in out
assert 'reason="guard_margin"' in out
assert "tether_rtc_adaptive_horizon" in out
assert "tether_rtc_adaptive_applied_horizon" in out
assert 'model_id="pi05"' in out
assert "tether_rtc_adaptive_guard_margin" in out
assert "tether_rtc_adaptive_action_delta" in out
Expand Down
Loading
Loading