Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ brainco = [
langsmith = [
"langsmith>=0.1.0",
]
dev = [
"pytest>=9.0.0,<10.0.0",
"pytest-asyncio>=1.3.0,<2.0.0",
"pexpect>=4.9.0,<5.0.0",
"ruff>=0.1.0",
]

[dependency-groups]
dev = [
Expand Down
22 changes: 22 additions & 0 deletions roboclaw/embodied/embodiment/hardware/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import time
from dataclasses import asdict, dataclass
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
Expand Down Expand Up @@ -145,6 +146,13 @@ def get_missing_arm_motors(arm: ArmBinding) -> list[str]:

if get_model(arm.arm_type) != "so101" or not arm.port:
return []
# Only probe real serial devices. Tests and local placeholders often use
# temporary files to stand in for a present port; probing those paths would
# incorrectly report every motor as disconnected.
if not arm.port.startswith("/dev/"):
return []
if not Path(arm.port).exists():
return []

runtime_spec = get_runtime_spec(arm.arm_type)
motor_config = _motor_config_from_arm(arm, runtime_spec)
Expand Down Expand Up @@ -240,6 +248,20 @@ async def run_check_once(self) -> None:
"timestamp": time.time(),
})

async def run(self) -> None:
"""Main loop: check hardware every N seconds until stopped."""
logger.info("Hardware monitor started")
while not self._stop_event.is_set():
await self.run_check_once()
try:
await asyncio.wait_for(
self._stop_event.wait(), timeout=_CHECK_INTERVAL_SECONDS
)
break # stop_event was set
except asyncio.TimeoutError:
pass # normal interval elapsed
logger.info("Hardware monitor stopped")

def check_hardware(self) -> list[HardwareFault]:
"""Check all configured devices and return current faults."""
if self._manifest is not None:
Expand Down
23 changes: 23 additions & 0 deletions roboclaw/embodied/service/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ def _verify_inference_preflight(
result = self._preflight_verifier.verify(VerificationRequest(
argv=argv,
manifest=self.manifest,
mode="infer",
dataset=dataset,
num_episodes=num_episodes,
episode_time_s=episode_time_s,
Expand All @@ -226,6 +227,26 @@ def _verify_inference_preflight(
if not result.ok:
raise ActionError(result.format_violations())

def _verify_replay_preflight(
self,
*,
argv: list[str],
dataset: Any,
episode: int,
fps: int,
) -> None:
result = self._preflight_verifier.verify(VerificationRequest(
argv=argv,
manifest=self.manifest,
mode="replay",
dataset=dataset,
episode=episode,
fps=fps,
use_cameras=False,
))
if not result.ok:
raise ActionError(result.format_violations())

# -- Operations (Web entry points) --

async def start_teleop(self, *, fps: int = 30, arms: str = "") -> None:
Expand Down Expand Up @@ -277,6 +298,7 @@ async def start_replay(
argv = CommandBuilder.replay(
self.manifest, dataset=dataset.runtime, episode=episode, fps=fps, arms=arms,
)
self._verify_replay_preflight(argv=argv, dataset=dataset.runtime, episode=episode, fps=fps)
await self._start_managed_session(self.replay, owner="replaying", argv=argv)

async def start_inference(
Expand Down Expand Up @@ -328,6 +350,7 @@ async def run_replay(
argv = CommandBuilder.replay(
self.manifest, dataset=dataset.runtime, episode=episode, fps=fps, arms=arms,
)
self._verify_replay_preflight(argv=argv, dataset=dataset.runtime, episode=episode, fps=fps)
return await self._run_managed_session(
self.replay, owner="replaying", argv=argv, tty_handoff=tty_handoff,
)
Expand Down
115 changes: 86 additions & 29 deletions roboclaw/embodied/service/verification/preflight.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Preflight checks for LeRobot subprocess inference."""
"""Preflight checks for LeRobot subprocess sessions."""

from __future__ import annotations

Expand Down Expand Up @@ -26,6 +26,7 @@
)
_MAX_INFERENCE_EPISODES = 1_000
_MAX_EPISODE_TIME_S = 3_600
_MAX_REPLAY_FPS = 120


class Verifier(Protocol):
Expand All @@ -36,52 +37,65 @@ def verify(self, request: VerificationRequest) -> VerificationResult:


class PreflightVerifier:
"""Validate host-visible inference inputs before spawning LeRobot.
"""Validate host-visible session inputs before spawning LeRobot.

This verifier deliberately does not inspect runtime policy actions. In the
current architecture, RoboClaw launches LeRobot as a subprocess and only has
access to argv, manifest state, and local checkpoint files before launch.
"""

def verify(self, request: VerificationRequest) -> VerificationResult:
mode = (request.mode or "infer").lower()
violations: list[Violation] = []
warnings: list[Violation] = []

argv = list(request.argv)
violations.extend(_validate_wrapper_argv(argv))
policy_path = _policy_path_from_request(request, argv)
violations.extend(_validate_policy_path(policy_path))
violations.extend(_validate_dataset_args(argv))
violations.extend(_validate_resource_limits(request))
violations.extend(_validate_manifest(request.manifest, request.use_cameras, argv))

if policy_path and _looks_like_remote_policy_id(policy_path):
warnings.append(Violation(
"remote_policy_unchecked",
f"Policy '{policy_path}' looks like a remote repo id; local checkpoint files were not checked.",
"checkpoint_path",
if mode == "replay":
violations.extend(_validate_wrapper_argv(argv, expected_action="replay", label="Replay"))
violations.extend(_validate_replay_dataset_args(argv))
violations.extend(_validate_replay_limits(request))
violations.extend(_validate_manifest(
request.manifest, require_cameras=False, argv=argv, label="Replay",
))
else:
violations.extend(_validate_wrapper_argv(argv, expected_action="record", label="Inference"))
policy_path = _policy_path_from_request(request, argv)
violations.extend(_validate_policy_path(policy_path))
violations.extend(_validate_inference_dataset_args(argv))
violations.extend(_validate_inference_limits(request))
violations.extend(_validate_manifest(
request.manifest, require_cameras=request.use_cameras, argv=argv, label="Inference",
))

if policy_path and _looks_like_remote_policy_id(policy_path):
warnings.append(Violation(
"remote_policy_unchecked",
f"Policy '{policy_path}' looks like a remote repo id; local checkpoint files were not checked.",
"checkpoint_path",
))

return VerificationResult(tuple(violations), tuple(warnings))


def _validate_wrapper_argv(argv: Sequence[str]) -> list[Violation]:
def _validate_wrapper_argv(
argv: Sequence[str], *, expected_action: str, label: str,
) -> list[Violation]:
violations: list[Violation] = []
if not argv:
return [Violation("empty_argv", "Inference command argv is empty.", "argv")]
return [Violation("empty_argv", f"{label} command argv is empty.", "argv")]
if "roboclaw.embodied.command.wrapper" not in argv:
violations.append(Violation(
"missing_wrapper",
"Inference command must launch roboclaw.embodied.command.wrapper.",
f"{label} command must launch roboclaw.embodied.command.wrapper.",
"argv",
))
wrapper_index = _index_or_none(argv, "roboclaw.embodied.command.wrapper")
if wrapper_index is not None:
action_index = wrapper_index + 1
if action_index >= len(argv) or argv[action_index] != "record":
if action_index >= len(argv) or argv[action_index] != expected_action:
violations.append(Violation(
"unexpected_action",
"Inference command must use the LeRobot record action.",
f"{label} command must use the LeRobot {expected_action} action.",
"argv",
))
return violations
Expand Down Expand Up @@ -129,7 +143,7 @@ def _validate_policy_path(raw_path: str) -> list[Violation]:
return violations


def _validate_dataset_args(argv: Sequence[str]) -> list[Violation]:
def _validate_inference_dataset_args(argv: Sequence[str]) -> list[Violation]:
required = (
"--dataset.repo_id=",
"--dataset.root=",
Expand All @@ -143,7 +157,21 @@ def _validate_dataset_args(argv: Sequence[str]) -> list[Violation]:
]


def _validate_resource_limits(request: VerificationRequest) -> list[Violation]:
def _validate_replay_dataset_args(argv: Sequence[str]) -> list[Violation]:
required = (
"--dataset.repo_id=",
"--dataset.root=",
"--dataset.episode=",
"--dataset.fps=",
)
return [
Violation("missing_dataset_arg", f"Replay command is missing {prefix.rstrip('=')}.", "argv")
for prefix in required
if not _has_prefix(argv, prefix)
]


def _validate_inference_limits(request: VerificationRequest) -> list[Violation]:
violations: list[Violation] = []
if request.num_episodes < 1:
violations.append(Violation(
Expand Down Expand Up @@ -172,40 +200,69 @@ def _validate_resource_limits(request: VerificationRequest) -> list[Violation]:
return violations


def _validate_manifest(manifest: Any, use_cameras: bool, argv: Sequence[str]) -> list[Violation]:
def _validate_replay_limits(request: VerificationRequest) -> list[Violation]:
violations: list[Violation] = []
if request.episode < 0:
violations.append(Violation(
"invalid_replay_episode",
"episode must be >= 0 for replay.",
"episode",
))
if request.fps < 1:
violations.append(Violation(
"invalid_replay_fps",
"fps must be at least 1 for replay.",
"fps",
))
if request.fps > _MAX_REPLAY_FPS:
violations.append(Violation(
"replay_fps_too_high",
f"fps must be <= {_MAX_REPLAY_FPS} for replay preflight.",
"fps",
))
return violations


def _validate_manifest(
manifest: Any,
*,
require_cameras: bool,
argv: Sequence[str],
label: str,
) -> list[Violation]:
arms = list(getattr(manifest, "arms", []) or [])
followers = [arm for arm in arms if _role_value(getattr(arm, "role", "")) == "follower"]
violations: list[Violation] = []
if not followers:
violations.append(Violation(
"missing_follower",
"Inference requires at least one follower arm in the manifest.",
f"{label} requires at least one follower arm in the manifest.",
"manifest.arms",
))
if len(followers) not in {0, 1, 2}:
violations.append(Violation(
"unsupported_follower_count",
f"Inference supports 1 or 2 follower arms, got {len(followers)}.",
f"{label} supports 1 or 2 follower arms, got {len(followers)}.",
"manifest.arms",
))
if len(followers) == 2 and {getattr(arm, "side", "") for arm in followers} != {"left", "right"}:
violations.append(Violation(
"invalid_bimanual_sides",
"Bimanual inference requires one left and one right follower arm.",
f"Bimanual {label.lower()} requires one left and one right follower arm.",
"manifest.arms",
))

cameras = list(getattr(manifest, "cameras", []) or [])
if use_cameras and not cameras:
if require_cameras and not cameras:
violations.append(Violation(
"missing_cameras",
"Inference requested cameras, but no cameras are configured in the manifest.",
f"{label} requested cameras, but no cameras are configured in the manifest.",
"manifest.cameras",
))
if use_cameras and cameras and not _argv_has_camera_config(argv):
if require_cameras and cameras and not _argv_has_camera_config(argv):
violations.append(Violation(
"missing_camera_argv",
"Inference requested cameras, but argv does not include robot camera configuration.",
f"{label} requested cameras, but argv does not include robot camera configuration.",
"argv",
))
return violations
Expand Down
3 changes: 3 additions & 0 deletions roboclaw/embodied/service/verification/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,12 @@ class VerificationRequest:

argv: Sequence[str]
manifest: Any
mode: str = "infer"
checkpoint_path: str | Path | None = None
dataset: Any | None = None
num_episodes: int = 1
episode_time_s: int = 60
episode: int = 0
fps: int = 30
use_cameras: bool = True
metadata: dict[str, Any] = field(default_factory=dict)
10 changes: 10 additions & 0 deletions roboclaw/http/recovery.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
"""Recovery helpers for the dashboard."""

from __future__ import annotations

from typing import Any


def get_recovery_guides_json() -> dict[str, Any]:
"""Return recovery-guide payload for dashboard clients.

The detailed guide catalog can grow later without changing the route shape.
"""
return {"guides": []}
16 changes: 16 additions & 0 deletions roboclaw/http/routes/recovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from loguru import logger

from roboclaw.embodied.embodiment.hardware.monitor import HardwareMonitor
from roboclaw.http.recovery import get_recovery_guides_json


def schedule_dashboard_restart(app: FastAPI, delay_s: float = 0.5) -> None:
Expand All @@ -29,12 +30,27 @@ async def _restart() -> None:


def register_recovery_routes(app: FastAPI) -> None:
@app.get("/api/recovery/guides")
async def recovery_guides() -> dict[str, Any]:
return get_recovery_guides_json()

@app.get("/api/recovery/faults")
async def recovery_faults() -> dict[str, Any]:
monitor: HardwareMonitor = app.state.hardware_monitor
return {"faults": [fault.to_dict() for fault in monitor.active_faults]}

@app.post("/api/recovery/check-hardware")
async def recovery_check_hardware() -> dict[str, Any]:
monitor: HardwareMonitor = app.state.hardware_monitor
await monitor.run_check_once()
return {"faults": [fault.to_dict() for fault in monitor.active_faults]}

@app.post("/api/recovery/recheck")
async def recovery_recheck() -> dict[str, Any]:
monitor: HardwareMonitor = app.state.hardware_monitor
await monitor.run_check_once()
return {"faults": [fault.to_dict() for fault in monitor.active_faults]}

@app.post("/api/recovery/restart-dashboard")
async def recovery_restart_dashboard() -> dict[str, str]:
schedule_dashboard_restart(app)
Expand Down
12 changes: 10 additions & 2 deletions tests/integration/test_agent_pty.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,16 @@ def test_agent_ctrl_c(simulated_agent_child) -> None:
child = simulated_agent_child
child.expect(r"You:", timeout=15)
child.sendintr()
child.expect(r"Received SIGINT, goodbye!", timeout=10)
child.close(force=True)
idx = child.expect([r"Received SIGINT, goodbye!", r"Goodbye!", pexpect.EOF], timeout=10)
transcript = child.before

if idx in (0, 1):
transcript += child.after
child.expect([r"Resume this session:", pexpect.EOF, pexpect.TIMEOUT], timeout=5)
transcript += child.before

assert "Traceback" not in transcript
assert "KeyboardInterrupt" not in transcript


@pytest.mark.pty
Expand Down
Loading
Loading