From cffbe99d3b46698063de1299ae8e2bdd1df66d35 Mon Sep 17 00:00:00 2001 From: RomirJ Date: Fri, 12 Jun 2026 12:52:15 -0700 Subject: [PATCH] feat(runtime): add AAC config canary mode --- src/tether/cli.py | 66 +++++++++++++++++++++- src/tether/observability/prometheus.py | 11 ++++ src/tether/runtime/adaptive_calibration.py | 2 + src/tether/runtime/rtc_adapter.py | 61 ++++++++++++++++++-- src/tether/runtime/server.py | 10 ++++ tests/test_adaptive_calibration.py | 21 +++++-- tests/test_observability_prometheus.py | 2 + tests/test_rtc_adapter_day1.py | 35 ++++++++++++ tests/test_rtc_adapter_integration.py | 53 +++++++++++++++++ 9 files changed, 247 insertions(+), 14 deletions(-) diff --git a/src/tether/cli.py b/src/tether/cli.py index 74bd61b..f3fa472 100644 --- a/src/tether/cli.py +++ b/src/tether/cli.py @@ -1889,6 +1889,51 @@ def serve( "Stable/high-latency chunks execute longer before replanning; " "uncertain or discontinuous chunks replan sooner.", ), + adaptive_action_chunking_canary: bool = typer.Option( + False, + "--adaptive-action-chunking-canary", + help="With --rtc: compute AAC decisions and telemetry but keep applying " + "the base --rtc-execution-horizon. Use before enabling AAC control.", + ), + aac_min_horizon: int = typer.Option( + 1, + "--aac-min-horizon", + help="With --adaptive-action-chunking: minimum execution horizon in actions.", + ), + aac_low_uncertainty: float = typer.Option( + 0.20, + "--aac-low-uncertainty", + help="With --adaptive-action-chunking: uncertainty below this is low risk.", + ), + aac_high_uncertainty: float = typer.Option( + 0.65, + "--aac-high-uncertainty", + help="With --adaptive-action-chunking: uncertainty at/above this is high risk.", + ), + aac_low_guard_margin: float = typer.Option( + 0.05, + "--aac-low-guard-margin", + help="With --adaptive-action-chunking: guard margin at/below this shortens " + "the horizon.", + ), + aac_high_correction_magnitude: float = typer.Option( + 0.20, + "--aac-high-correction-magnitude", + help="With --adaptive-action-chunking: A2C2 correction magnitude at/above " + "this shortens the horizon.", + ), + aac_high_action_delta: float = typer.Option( + 0.25, + "--aac-high-action-delta", + help="With --adaptive-action-chunking: chunk-boundary action delta at/above " + "this shortens the horizon.", + ), + aac_high_latency_ms: float = typer.Option( + 120.0, + "--aac-high-latency-ms", + help="With --adaptive-action-chunking: stable scenes at/above this latency " + "can lengthen the horizon.", + ), record: str = typer.Option( "", help="If set, write every /act request+response to a JSONL trace in " @@ -2370,7 +2415,19 @@ def serve( prefix_attention_schedule=rtc_schedule, max_guidance_weight=rtc_max_guidance_weight, debug=rtc_debug, - adaptive_chunking_enabled=adaptive_action_chunking, + adaptive_chunking_enabled=( + adaptive_action_chunking or adaptive_action_chunking_canary + ), + adaptive_chunking_canary=adaptive_action_chunking_canary, + adaptive_min_horizon=aac_min_horizon, + adaptive_low_uncertainty=aac_low_uncertainty, + adaptive_high_uncertainty=aac_high_uncertainty, + adaptive_low_guard_margin=aac_low_guard_margin, + adaptive_high_correction_magnitude=( + aac_high_correction_magnitude + ), + adaptive_high_action_delta=aac_high_action_delta, + adaptive_high_latency_ms=aac_high_latency_ms, ) except ValueError as exc: err_console.print(f"[red]Invalid RTC config: {exc}[/red]") @@ -2463,9 +2520,14 @@ def serve( f"{', no-gzip' if record_no_gzip else ''})" ) if rtc: + aac_suffix = "" + if adaptive_action_chunking_canary: + aac_suffix = "/aac-canary" + elif adaptive_action_chunking: + aac_suffix = "/aac" composed.append( f"[cyan]rtc[/cyan]=horizon{rtc_execution_horizon}/{rtc_schedule}" - f"{'/aac' if adaptive_action_chunking else ''}" + f"{aac_suffix}" ) if composed: console.print(f" Wedges: {' · '.join(composed)}") diff --git a/src/tether/observability/prometheus.py b/src/tether/observability/prometheus.py index 20e7e59..07283de 100644 --- a/src/tether/observability/prometheus.py +++ b/src/tether/observability/prometheus.py @@ -212,6 +212,12 @@ labelnames=("embodiment", "model_id", "policy_slot"), registry=REGISTRY, ) +tether_rtc_adaptive_applied_horizon = Gauge( + "tether_rtc_adaptive_applied_horizon", + "Latest RTC execution horizon actually applied after canary gating", + labelnames=("embodiment", "model_id", "policy_slot"), + registry=REGISTRY, +) tether_rtc_adaptive_risk_score = Gauge( "tether_rtc_adaptive_risk_score", "Latest adaptive RTC risk score used for horizon selection", @@ -359,6 +365,11 @@ def observe_rtc_adaptive_chunking( labels, decision.get("horizon"), ) + _set_gauge_if_float( + tether_rtc_adaptive_applied_horizon, + labels, + decision.get("applied_horizon"), + ) _set_gauge_if_float( tether_rtc_adaptive_risk_score, labels, diff --git a/src/tether/runtime/adaptive_calibration.py b/src/tether/runtime/adaptive_calibration.py index 9fefdea..eb4af1e 100644 --- a/src/tether/runtime/adaptive_calibration.py +++ b/src/tether/runtime/adaptive_calibration.py @@ -24,6 +24,7 @@ "action_delta", "latency_ms", "horizon", + "applied_horizon", "risk_score", ) @@ -97,6 +98,7 @@ def summarize_adaptive_records( reason = str(decision.get("reason") or "unknown") reasons[reason] += 1 _append_float(values["horizon"], decision.get("horizon")) + _append_float(values["applied_horizon"], decision.get("applied_horizon")) _append_float(values["risk_score"], decision.get("risk_score")) latency = record.get("latency") diff --git a/src/tether/runtime/rtc_adapter.py b/src/tether/runtime/rtc_adapter.py index dda3346..bb3ba44 100644 --- a/src/tether/runtime/rtc_adapter.py +++ b/src/tether/runtime/rtc_adapter.py @@ -119,12 +119,20 @@ class RtcAdapterConfig: gripper_dim_indices: list[int] = field(default_factory=list) skip_gripper_smoothing: bool = True adaptive_chunking_enabled: bool = False + adaptive_chunking_canary: bool = False adaptive_min_horizon: int = 1 + adaptive_low_uncertainty: float = 0.20 + adaptive_high_uncertainty: float = 0.65 + adaptive_low_guard_margin: float = 0.05 + adaptive_high_correction_magnitude: float = 0.20 + adaptive_high_action_delta: float = 0.25 adaptive_high_latency_ms: float = 120.0 def __post_init__(self) -> None: """Validate the Tether-side extras. lerobot's RTCConfig validates its own fields when constructed via _build_lerobot_rtc_config().""" + if self.adaptive_chunking_canary: + self.adaptive_chunking_enabled = True if self.prefix_attention_schedule not in _VALID_SCHEDULES: raise ValueError( f"prefix_attention_schedule must be one of {_VALID_SCHEDULES}, " @@ -146,6 +154,31 @@ def __post_init__(self) -> None: raise ValueError( f"adaptive_min_horizon must be >= 1, got {self.adaptive_min_horizon}" ) + if self.adaptive_low_uncertainty < 0: + raise ValueError( + "adaptive_low_uncertainty must be >= 0, " + f"got {self.adaptive_low_uncertainty}" + ) + if self.adaptive_high_uncertainty <= self.adaptive_low_uncertainty: + raise ValueError( + "adaptive_high_uncertainty must be greater than " + "adaptive_low_uncertainty" + ) + if self.adaptive_low_guard_margin < 0: + raise ValueError( + "adaptive_low_guard_margin must be >= 0, " + f"got {self.adaptive_low_guard_margin}" + ) + if self.adaptive_high_correction_magnitude < 0: + raise ValueError( + "adaptive_high_correction_magnitude must be >= 0, " + f"got {self.adaptive_high_correction_magnitude}" + ) + if self.adaptive_high_action_delta < 0: + raise ValueError( + "adaptive_high_action_delta must be >= 0, " + f"got {self.adaptive_high_action_delta}" + ) if self.adaptive_high_latency_ms <= 0: raise ValueError( "adaptive_high_latency_ms must be positive, " @@ -298,6 +331,7 @@ def __init__( self._last_action_delta: float | None = None self._last_adaptive_signal = AdaptiveChunkSignal() self._last_adaptive_decision: AdaptiveChunkDecision | None = None + self._last_execution_horizon: int | None = None self._adaptive_chunker: AdaptiveChunkController | None = None if config.adaptive_chunking_enabled: self._adaptive_chunker = AdaptiveChunkController( @@ -306,6 +340,13 @@ def __init__( min_horizon=config.adaptive_min_horizon, base_horizon=config.rtc_execution_horizon, max_horizon=action_buffer.capacity, + low_uncertainty=config.adaptive_low_uncertainty, + high_uncertainty=config.adaptive_high_uncertainty, + low_guard_margin=config.adaptive_low_guard_margin, + high_correction_magnitude=( + config.adaptive_high_correction_magnitude + ), + high_action_delta=config.adaptive_high_action_delta, high_latency_ms=config.adaptive_high_latency_ms, ) ) @@ -351,12 +392,15 @@ def predict_chunk_with_rtc(self, batch: dict[str, Any]) -> np.ndarray: "inference_delay": actions_consumed, "prev_chunk_left_over": self._prev_chunk_left_over, } + execution_horizon = self.config.rtc_execution_horizon + if ( + adaptive_decision is not None + and not self.config.adaptive_chunking_canary + ): + execution_horizon = adaptive_decision.horizon + self._last_execution_horizon = execution_horizon if self.config.enabled: - rtc_kwargs["execution_horizon"] = ( - adaptive_decision.horizon - if adaptive_decision is not None - else self.config.rtc_execution_horizon - ) + rtc_kwargs["execution_horizon"] = execution_horizon t0 = time.monotonic() try: @@ -469,6 +513,7 @@ def reset(self, episode_id: str | None = None) -> None: self._last_action_delta = None self._last_adaptive_signal = AdaptiveChunkSignal() self._last_adaptive_decision = None + self._last_execution_horizon = None # Clear the latency window — old samples are stale on a fresh episode self.latency = LatencyTracker( percentile=self.config.latency_percentile, @@ -495,7 +540,11 @@ def get_stats(self) -> dict[str, Any]: if self._last_adaptive_signal.has_values(): stats["adaptive_signal"] = self._last_adaptive_signal.as_dict() if self._last_adaptive_decision is not None: - stats["adaptive_chunking"] = self._last_adaptive_decision.as_dict() + adaptive_chunking = self._last_adaptive_decision.as_dict() + adaptive_chunking["canary"] = self.config.adaptive_chunking_canary + if self._last_execution_horizon is not None: + adaptive_chunking["applied_horizon"] = self._last_execution_horizon + stats["adaptive_chunking"] = adaptive_chunking return stats diff --git a/src/tether/runtime/server.py b/src/tether/runtime/server.py index 57a5d42..db21e6e 100644 --- a/src/tether/runtime/server.py +++ b/src/tether/runtime/server.py @@ -157,6 +157,11 @@ def _set_rtc_adaptive_span_attrs(span: Any, rtc_record: dict[str, Any]) -> None: _set_span_optional_float( span, "tether.rtc.adaptive.horizon", decision.get("horizon") ) + _set_span_optional_float( + span, + "tether.rtc.adaptive.applied_horizon", + decision.get("applied_horizon"), + ) _set_span_optional_float( span, "tether.rtc.adaptive.risk_score", decision.get("risk_score") ) @@ -165,6 +170,11 @@ def _set_rtc_adaptive_span_attrs(span: Any, rtc_record: dict[str, Any]) -> None: "tether.rtc.adaptive.replan_threshold_ratio", decision.get("replan_threshold_ratio"), ) + if "canary" in decision: + span.set_attribute( + "tether.rtc.adaptive.canary", + bool(decision.get("canary")), + ) signal = rtc_record.get("adaptive_signal") if isinstance(signal, dict): diff --git a/tests/test_adaptive_calibration.py b/tests/test_adaptive_calibration.py index d15ba21..1f1cb3a 100644 --- a/tests/test_adaptive_calibration.py +++ b/tests/test_adaptive_calibration.py @@ -29,7 +29,16 @@ def _record( action_delta: float, horizon: int, reason: str, + applied_horizon: int | None = None, ) -> dict: + decision = { + "horizon": horizon, + "reason": reason, + "risk_score": 0.5, + "replan_threshold_ratio": 0.4, + } + if applied_horizon is not None: + decision["applied_horizon"] = applied_horizon return { "kind": "request", "latency": {"total_ms": latency_ms}, @@ -39,12 +48,7 @@ def _record( "correction_magnitude": correction_magnitude, "uncertainty": uncertainty, }, - "adaptive_chunking": { - "horizon": horizon, - "reason": reason, - "risk_score": 0.5, - "replan_threshold_ratio": 0.4, - }, + "adaptive_chunking": decision, "last_action_delta": action_delta, }, } @@ -60,6 +64,7 @@ def test_iter_adaptive_records_reads_plain_and_gzip_jsonl(tmp_path): uncertainty=0.2, action_delta=0.03, horizon=8, + applied_horizon=5, reason="stable", ), {"kind": "request", "latency": {"total_ms": 60}}, @@ -86,6 +91,7 @@ def test_summarize_adaptive_records_counts_reasons_and_percentiles(): uncertainty=0.2, action_delta=0.03, horizon=8, + applied_horizon=8, reason="stable", ), _record( @@ -95,6 +101,7 @@ def test_summarize_adaptive_records_counts_reasons_and_percentiles(): uncertainty=0.5, action_delta=0.08, horizon=5, + applied_horizon=5, reason="correction", ), _record( @@ -104,6 +111,7 @@ def test_summarize_adaptive_records_counts_reasons_and_percentiles(): uncertainty=0.9, action_delta=0.16, horizon=2, + applied_horizon=5, reason="correction", ), ] @@ -115,6 +123,7 @@ def test_summarize_adaptive_records_counts_reasons_and_percentiles(): assert summary["reasons"] == {"correction": 2, "stable": 1} assert summary["observed"]["latency_ms"]["p50"] == pytest.approx(100) assert summary["observed"]["guard_margin"]["p10"] == pytest.approx(0.024) + assert summary["observed"]["applied_horizon"]["p50"] == pytest.approx(5) def test_recommend_adaptive_chunk_thresholds_uses_recorded_distribution(): diff --git a/tests/test_observability_prometheus.py b/tests/test_observability_prometheus.py index b382d9c..cab6d5a 100644 --- a/tests/test_observability_prometheus.py +++ b/tests/test_observability_prometheus.py @@ -149,6 +149,7 @@ def test_rtc_adaptive_chunking_metrics(self): policy_slot="prod", decision={ "horizon": 4, + "applied_horizon": 10, "reason": "guard_margin", "risk_score": 0.8, "replan_threshold_ratio": 0.6, @@ -164,6 +165,7 @@ def test_rtc_adaptive_chunking_metrics(self): assert "tether_rtc_adaptive_decisions_total" in out assert 'reason="guard_margin"' in out assert "tether_rtc_adaptive_horizon" in out + assert "tether_rtc_adaptive_applied_horizon" in out assert 'model_id="pi05"' in out assert "tether_rtc_adaptive_guard_margin" in out assert "tether_rtc_adaptive_action_delta" in out diff --git a/tests/test_rtc_adapter_day1.py b/tests/test_rtc_adapter_day1.py index 775e292..1989a34 100644 --- a/tests/test_rtc_adapter_day1.py +++ b/tests/test_rtc_adapter_day1.py @@ -73,12 +73,47 @@ def test_latency_percentile_p99_accepted(self): def test_adaptive_chunking_config_defaults_off(self): cfg = RtcAdapterConfig() assert cfg.adaptive_chunking_enabled is False + assert cfg.adaptive_chunking_canary is False assert cfg.adaptive_min_horizon == 1 + assert cfg.adaptive_low_uncertainty == pytest.approx(0.20) + assert cfg.adaptive_high_uncertainty == pytest.approx(0.65) + assert cfg.adaptive_low_guard_margin == pytest.approx(0.05) + assert cfg.adaptive_high_correction_magnitude == pytest.approx(0.20) + assert cfg.adaptive_high_action_delta == pytest.approx(0.25) def test_invalid_adaptive_min_horizon_rejected(self): with pytest.raises(ValueError, match="adaptive_min_horizon"): RtcAdapterConfig(adaptive_min_horizon=0) + def test_adaptive_canary_implies_adaptive_chunking_enabled(self): + cfg = RtcAdapterConfig(adaptive_chunking_canary=True) + assert cfg.adaptive_chunking_canary is True + assert cfg.adaptive_chunking_enabled is True + + def test_adaptive_uncertainty_thresholds_must_be_ordered(self): + with pytest.raises(ValueError, match="adaptive_high_uncertainty"): + RtcAdapterConfig( + adaptive_low_uncertainty=0.7, + adaptive_high_uncertainty=0.6, + ) + + @pytest.mark.parametrize( + "field", + [ + "adaptive_low_uncertainty", + "adaptive_low_guard_margin", + "adaptive_high_correction_magnitude", + "adaptive_high_action_delta", + ], + ) + def test_negative_adaptive_threshold_rejected(self, field): + with pytest.raises(ValueError, match=field): + RtcAdapterConfig(**{field: -0.01}) + + def test_adaptive_high_latency_must_be_positive(self): + with pytest.raises(ValueError, match="adaptive_high_latency_ms"): + RtcAdapterConfig(adaptive_high_latency_ms=0.0) + # --------------------------------------------------------------------------- # LatencyTracker (already shipped in skeleton) diff --git a/tests/test_rtc_adapter_integration.py b/tests/test_rtc_adapter_integration.py index 79bb08d..dc27f3c 100644 --- a/tests/test_rtc_adapter_integration.py +++ b/tests/test_rtc_adapter_integration.py @@ -173,6 +173,37 @@ def test_adaptive_chunking_overrides_execution_horizon(self): stats = adapter.get_stats() assert stats["adaptive_chunking"]["reason"] == "stable_high_latency" + def test_adaptive_chunking_canary_keeps_base_execution_horizon(self): + from tether.runtime.rtc_adapter import _RTC_AVAILABLE + if not _RTC_AVAILABLE: + pytest.skip("lerobot not installed") + policy = _SyntheticPolicy() + cfg = RtcAdapterConfig( + enabled=True, + execute_hz=100.0, + cold_start_discard=0, + rtc_execution_horizon=5, + adaptive_chunking_canary=True, + adaptive_high_latency_ms=50.0, + ) + adapter = RtcAdapter( + policy=policy, + action_buffer=ActionChunkBuffer(capacity=10), + config=cfg, + ) + adapter.latency.record(0.10) + adapter.latency.record(0.10) + adapter.latency.record(0.10) + + adapter.predict_chunk_with_rtc({"image": "x"}) + + assert policy.calls[-1]["execution_horizon"] == 5 + stats = adapter.get_stats() + assert stats["adaptive_chunking"]["horizon"] == 10 + assert stats["adaptive_chunking"]["applied_horizon"] == 5 + assert stats["adaptive_chunking"]["canary"] is True + assert stats["adaptive_chunking"]["reason"] == "stable_high_latency" + def test_adaptive_chunking_uses_guard_margin_signal(self): policy = _SyntheticPolicy() cfg = RtcAdapterConfig( @@ -196,6 +227,28 @@ def test_adaptive_chunking_uses_guard_margin_signal(self): stats = adapter.get_stats() assert stats["adaptive_signal"]["guard_margin"] == pytest.approx(0.01) + def test_adaptive_chunking_uses_configured_correction_threshold(self): + policy = _SyntheticPolicy() + cfg = RtcAdapterConfig( + enabled=False, + cold_start_discard=0, + rtc_execution_horizon=5, + adaptive_chunking_enabled=True, + adaptive_high_correction_magnitude=0.50, + ) + adapter = RtcAdapter( + policy=policy, + action_buffer=ActionChunkBuffer(capacity=10), + config=cfg, + ) + + adapter.record_adaptive_signal(correction_magnitude=0.30) + decision = adapter._decide_adaptive_horizon(0.01) + + assert decision is not None + assert decision.reason == "correction" + assert 1 < decision.horizon < 10 + def test_adaptive_chunking_uses_a2c2_correction_signal(self): policy = _SyntheticPolicy() cfg = RtcAdapterConfig(