From dba1d51673982f1a807956b96e49a53536f56d31 Mon Sep 17 00:00:00 2001 From: Xiaofang Wu <3642115339@qq.com> Date: Tue, 28 Apr 2026 18:48:27 +0800 Subject: [PATCH 1/5] feat: add replay preflight verification --- roboclaw/embodied/service/__init__.py | 23 ++++ .../service/verification/preflight.py | 115 +++++++++++++----- .../embodied/service/verification/types.py | 3 + tests/test_embodied_service_lifecycle.py | 37 ++++++ tests/verification/test_preflight.py | 53 ++++++++ 5 files changed, 202 insertions(+), 29 deletions(-) diff --git a/roboclaw/embodied/service/__init__.py b/roboclaw/embodied/service/__init__.py index 910214c2..f64ce27e 100644 --- a/roboclaw/embodied/service/__init__.py +++ b/roboclaw/embodied/service/__init__.py @@ -218,6 +218,7 @@ def _verify_inference_preflight( result = self._preflight_verifier.verify(VerificationRequest( argv=argv, manifest=self.manifest, + mode="infer", dataset=dataset, num_episodes=num_episodes, episode_time_s=episode_time_s, @@ -226,6 +227,26 @@ def _verify_inference_preflight( if not result.ok: raise ActionError(result.format_violations()) + def _verify_replay_preflight( + self, + *, + argv: list[str], + dataset: Any, + episode: int, + fps: int, + ) -> None: + result = self._preflight_verifier.verify(VerificationRequest( + argv=argv, + manifest=self.manifest, + mode="replay", + dataset=dataset, + episode=episode, + fps=fps, + use_cameras=False, + )) + if not result.ok: + raise ActionError(result.format_violations()) + # -- Operations (Web entry points) -- async def start_teleop(self, *, fps: int = 30, arms: str = "") -> None: @@ -277,6 +298,7 @@ async def start_replay( argv = CommandBuilder.replay( self.manifest, dataset=dataset.runtime, episode=episode, fps=fps, arms=arms, ) + self._verify_replay_preflight(argv=argv, dataset=dataset.runtime, episode=episode, fps=fps) await self._start_managed_session(self.replay, owner="replaying", argv=argv) async def start_inference( @@ -328,6 +350,7 @@ async def run_replay( argv = CommandBuilder.replay( self.manifest, dataset=dataset.runtime, episode=episode, fps=fps, arms=arms, ) + self._verify_replay_preflight(argv=argv, dataset=dataset.runtime, episode=episode, fps=fps) return await self._run_managed_session( self.replay, owner="replaying", argv=argv, tty_handoff=tty_handoff, ) diff --git a/roboclaw/embodied/service/verification/preflight.py b/roboclaw/embodied/service/verification/preflight.py index 065d7f25..24af0911 100644 --- a/roboclaw/embodied/service/verification/preflight.py +++ b/roboclaw/embodied/service/verification/preflight.py @@ -1,4 +1,4 @@ -"""Preflight checks for LeRobot subprocess inference.""" +"""Preflight checks for LeRobot subprocess sessions.""" from __future__ import annotations @@ -26,6 +26,7 @@ ) _MAX_INFERENCE_EPISODES = 1_000 _MAX_EPISODE_TIME_S = 3_600 +_MAX_REPLAY_FPS = 240 class Verifier(Protocol): @@ -36,7 +37,7 @@ def verify(self, request: VerificationRequest) -> VerificationResult: class PreflightVerifier: - """Validate host-visible inference inputs before spawning LeRobot. + """Validate host-visible session inputs before spawning LeRobot. This verifier deliberately does not inspect runtime policy actions. In the current architecture, RoboClaw launches LeRobot as a subprocess and only has @@ -44,44 +45,57 @@ class PreflightVerifier: """ def verify(self, request: VerificationRequest) -> VerificationResult: + mode = (request.mode or "infer").lower() violations: list[Violation] = [] warnings: list[Violation] = [] argv = list(request.argv) - violations.extend(_validate_wrapper_argv(argv)) - policy_path = _policy_path_from_request(request, argv) - violations.extend(_validate_policy_path(policy_path)) - violations.extend(_validate_dataset_args(argv)) - violations.extend(_validate_resource_limits(request)) - violations.extend(_validate_manifest(request.manifest, request.use_cameras, argv)) - - if policy_path and _looks_like_remote_policy_id(policy_path): - warnings.append(Violation( - "remote_policy_unchecked", - f"Policy '{policy_path}' looks like a remote repo id; local checkpoint files were not checked.", - "checkpoint_path", + if mode == "replay": + violations.extend(_validate_wrapper_argv(argv, expected_action="replay", label="Replay")) + violations.extend(_validate_replay_dataset_args(argv)) + violations.extend(_validate_replay_limits(request)) + violations.extend(_validate_manifest( + request.manifest, require_cameras=False, argv=argv, label="Replay", )) + else: + violations.extend(_validate_wrapper_argv(argv, expected_action="record", label="Inference")) + policy_path = _policy_path_from_request(request, argv) + violations.extend(_validate_policy_path(policy_path)) + violations.extend(_validate_inference_dataset_args(argv)) + violations.extend(_validate_inference_limits(request)) + violations.extend(_validate_manifest( + request.manifest, require_cameras=request.use_cameras, argv=argv, label="Inference", + )) + + if policy_path and _looks_like_remote_policy_id(policy_path): + warnings.append(Violation( + "remote_policy_unchecked", + f"Policy '{policy_path}' looks like a remote repo id; local checkpoint files were not checked.", + "checkpoint_path", + )) return VerificationResult(tuple(violations), tuple(warnings)) -def _validate_wrapper_argv(argv: Sequence[str]) -> list[Violation]: +def _validate_wrapper_argv( + argv: Sequence[str], *, expected_action: str, label: str, +) -> list[Violation]: violations: list[Violation] = [] if not argv: - return [Violation("empty_argv", "Inference command argv is empty.", "argv")] + return [Violation("empty_argv", f"{label} command argv is empty.", "argv")] if "roboclaw.embodied.command.wrapper" not in argv: violations.append(Violation( "missing_wrapper", - "Inference command must launch roboclaw.embodied.command.wrapper.", + f"{label} command must launch roboclaw.embodied.command.wrapper.", "argv", )) wrapper_index = _index_or_none(argv, "roboclaw.embodied.command.wrapper") if wrapper_index is not None: action_index = wrapper_index + 1 - if action_index >= len(argv) or argv[action_index] != "record": + if action_index >= len(argv) or argv[action_index] != expected_action: violations.append(Violation( "unexpected_action", - "Inference command must use the LeRobot record action.", + f"{label} command must use the LeRobot {expected_action} action.", "argv", )) return violations @@ -129,7 +143,7 @@ def _validate_policy_path(raw_path: str) -> list[Violation]: return violations -def _validate_dataset_args(argv: Sequence[str]) -> list[Violation]: +def _validate_inference_dataset_args(argv: Sequence[str]) -> list[Violation]: required = ( "--dataset.repo_id=", "--dataset.root=", @@ -143,7 +157,21 @@ def _validate_dataset_args(argv: Sequence[str]) -> list[Violation]: ] -def _validate_resource_limits(request: VerificationRequest) -> list[Violation]: +def _validate_replay_dataset_args(argv: Sequence[str]) -> list[Violation]: + required = ( + "--dataset.repo_id=", + "--dataset.root=", + "--dataset.episode=", + "--dataset.fps=", + ) + return [ + Violation("missing_dataset_arg", f"Replay command is missing {prefix.rstrip('=')}.", "argv") + for prefix in required + if not _has_prefix(argv, prefix) + ] + + +def _validate_inference_limits(request: VerificationRequest) -> list[Violation]: violations: list[Violation] = [] if request.num_episodes < 1: violations.append(Violation( @@ -172,40 +200,69 @@ def _validate_resource_limits(request: VerificationRequest) -> list[Violation]: return violations -def _validate_manifest(manifest: Any, use_cameras: bool, argv: Sequence[str]) -> list[Violation]: +def _validate_replay_limits(request: VerificationRequest) -> list[Violation]: + violations: list[Violation] = [] + if request.episode < 0: + violations.append(Violation( + "invalid_replay_episode", + "episode must be >= 0 for replay.", + "episode", + )) + if request.fps < 1: + violations.append(Violation( + "invalid_replay_fps", + "fps must be at least 1 for replay.", + "fps", + )) + if request.fps > _MAX_REPLAY_FPS: + violations.append(Violation( + "replay_fps_too_high", + f"fps must be <= {_MAX_REPLAY_FPS} for replay preflight.", + "fps", + )) + return violations + + +def _validate_manifest( + manifest: Any, + *, + require_cameras: bool, + argv: Sequence[str], + label: str, +) -> list[Violation]: arms = list(getattr(manifest, "arms", []) or []) followers = [arm for arm in arms if _role_value(getattr(arm, "role", "")) == "follower"] violations: list[Violation] = [] if not followers: violations.append(Violation( "missing_follower", - "Inference requires at least one follower arm in the manifest.", + f"{label} requires at least one follower arm in the manifest.", "manifest.arms", )) if len(followers) not in {0, 1, 2}: violations.append(Violation( "unsupported_follower_count", - f"Inference supports 1 or 2 follower arms, got {len(followers)}.", + f"{label} supports 1 or 2 follower arms, got {len(followers)}.", "manifest.arms", )) if len(followers) == 2 and {getattr(arm, "side", "") for arm in followers} != {"left", "right"}: violations.append(Violation( "invalid_bimanual_sides", - "Bimanual inference requires one left and one right follower arm.", + f"Bimanual {label.lower()} requires one left and one right follower arm.", "manifest.arms", )) cameras = list(getattr(manifest, "cameras", []) or []) - if use_cameras and not cameras: + if require_cameras and not cameras: violations.append(Violation( "missing_cameras", - "Inference requested cameras, but no cameras are configured in the manifest.", + f"{label} requested cameras, but no cameras are configured in the manifest.", "manifest.cameras", )) - if use_cameras and cameras and not _argv_has_camera_config(argv): + if require_cameras and cameras and not _argv_has_camera_config(argv): violations.append(Violation( "missing_camera_argv", - "Inference requested cameras, but argv does not include robot camera configuration.", + f"{label} requested cameras, but argv does not include robot camera configuration.", "argv", )) return violations diff --git a/roboclaw/embodied/service/verification/types.py b/roboclaw/embodied/service/verification/types.py index c2a7b0ab..e0cabc68 100644 --- a/roboclaw/embodied/service/verification/types.py +++ b/roboclaw/embodied/service/verification/types.py @@ -37,9 +37,12 @@ class VerificationRequest: argv: Sequence[str] manifest: Any + mode: str = "infer" checkpoint_path: str | Path | None = None dataset: Any | None = None num_episodes: int = 1 episode_time_s: int = 60 + episode: int = 0 + fps: int = 30 use_cameras: bool = True metadata: dict[str, Any] = field(default_factory=dict) diff --git a/tests/test_embodied_service_lifecycle.py b/tests/test_embodied_service_lifecycle.py index 25123069..ea6b539f 100644 --- a/tests/test_embodied_service_lifecycle.py +++ b/tests/test_embodied_service_lifecycle.py @@ -41,6 +41,7 @@ async def check(self, manifest, kwargs, tty_handoff) -> str: from roboclaw.embodied.embodiment.interface.video import VideoInterface from roboclaw.embodied.embodiment.lock import EmbodimentFileLock from roboclaw.embodied.embodiment.manifest import Manifest +from roboclaw.embodied.command import ActionError from roboclaw.embodied.service import EmbodiedService _MANIFEST_DATA = { @@ -151,6 +152,20 @@ def _infer_argv(tmp_path: Path, checkpoint: Path, *, num_episodes: int = 1) -> l ] +def _replay_argv(tmp_path: Path, *, episode: int = 0, fps: int = 30) -> list[str]: + return [ + sys.executable, + "-m", + "roboclaw.embodied.command.wrapper", + "replay", + "--robot.type=so101_follower", + "--dataset.repo_id=local/demo", + f"--dataset.root={tmp_path / 'datasets' / 'local' / 'demo'}", + f"--dataset.episode={episode}", + f"--dataset.fps={fps}", + ] + + def _single_follower_status() -> list[ArmStatus]: return [ArmStatus("follower", "so101_follower", "follower", True, True)] @@ -205,6 +220,28 @@ async def test_run_replay_waits_for_process_completion_without_tty(tmp_path: Pat assert service._active_session is None +@pytest.mark.asyncio +async def test_run_replay_rejects_preflight_before_session_start(tmp_path: Path) -> None: + service = _make_service(tmp_path) + _bind_replay_setup(service) + _write_runtime_dataset(tmp_path / "datasets", "demo") + service.replay = ControlledSession(service.board, "Replay finished.") + run_replay = getattr(service, "run_replay") + replay_argv = _replay_argv(tmp_path, episode=-1, fps=0) + + with patch("roboclaw.embodied.service.check_arm_status", side_effect=_single_follower_status()), patch( + "roboclaw.embodied.service.CommandBuilder.replay", + return_value=replay_argv, + ): + with pytest.raises(ActionError, match="episode must be >= 0 for replay"): + await run_replay(dataset_name="demo", episode=-1, fps=0) + + assert not service.replay.started.is_set() + assert not service.busy + assert not service.embodiment_busy + assert service._active_session is None + + @pytest.mark.asyncio async def test_run_inference_waits_for_process_completion_without_tty(tmp_path: Path) -> None: service = _make_service(tmp_path) diff --git a/tests/verification/test_preflight.py b/tests/verification/test_preflight.py index dd85c4d3..3b3d1f7d 100644 --- a/tests/verification/test_preflight.py +++ b/tests/verification/test_preflight.py @@ -46,6 +46,20 @@ def _argv(policy_path: str, *, cameras: bool = True) -> list[str]: return argv +def _replay_argv(*, episode: int = 0, fps: int = 30) -> list[str]: + return [ + sys.executable, + "-m", + "roboclaw.embodied.command.wrapper", + "replay", + "--robot.type=so101_follower", + "--dataset.repo_id=local/demo", + "--dataset.root=/tmp/demo", + f"--dataset.episode={episode}", + f"--dataset.fps={fps}", + ] + + def _codes(result) -> set[str]: return {violation.code for violation in result.violations} @@ -144,3 +158,42 @@ def test_preflight_allows_remote_policy_ids_without_local_file_check() -> None: assert result.ok assert {warning.code for warning in result.warnings} == {"remote_policy_unchecked"} + + +def test_preflight_accepts_valid_replay_request() -> None: + result = PreflightVerifier().verify(VerificationRequest( + argv=_replay_argv(episode=2, fps=15), + manifest=_manifest(cameras=False), + mode="replay", + episode=2, + fps=15, + use_cameras=False, + )) + + assert result.ok + + +def test_preflight_rejects_bad_replay_argv() -> None: + result = PreflightVerifier().verify(VerificationRequest( + argv=[sys.executable, "-m", "roboclaw.embodied.command.wrapper", "record", "--dataset.root=/tmp/demo"], + manifest=_manifest(cameras=False), + mode="replay", + episode=0, + fps=30, + use_cameras=False, + )) + + assert {"unexpected_action", "missing_dataset_arg"} <= _codes(result) + + +def test_preflight_rejects_invalid_replay_limits() -> None: + result = PreflightVerifier().verify(VerificationRequest( + argv=_replay_argv(episode=-1, fps=0), + manifest=_manifest(cameras=False), + mode="replay", + episode=-1, + fps=0, + use_cameras=False, + )) + + assert {"invalid_replay_episode", "invalid_replay_fps"} <= _codes(result) From 2c9e82770232236aa927fa4ebd87955e1c5e7600 Mon Sep 17 00:00:00 2001 From: Xiaofang Wu <3642115339@qq.com> Date: Wed, 29 Apr 2026 12:29:56 +0800 Subject: [PATCH 2/5] Self-review: tighten max replay fps to a realistic bound --- roboclaw/embodied/service/verification/preflight.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/roboclaw/embodied/service/verification/preflight.py b/roboclaw/embodied/service/verification/preflight.py index 24af0911..c5d72dcd 100644 --- a/roboclaw/embodied/service/verification/preflight.py +++ b/roboclaw/embodied/service/verification/preflight.py @@ -26,7 +26,7 @@ ) _MAX_INFERENCE_EPISODES = 1_000 _MAX_EPISODE_TIME_S = 3_600 -_MAX_REPLAY_FPS = 240 +_MAX_REPLAY_FPS = 120 class Verifier(Protocol): From e242068d8fc53aa86257324442a563dddf15a176 Mon Sep 17 00:00:00 2001 From: Xiaofang Wu <3642115339@qq.com> Date: Wed, 29 Apr 2026 18:19:59 +0800 Subject: [PATCH 3/5] test: make PTY Ctrl+C assertion tolerant to EOF exit path --- tests/integration/test_agent_pty.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/integration/test_agent_pty.py b/tests/integration/test_agent_pty.py index c04df4a4..fe64b4e0 100644 --- a/tests/integration/test_agent_pty.py +++ b/tests/integration/test_agent_pty.py @@ -48,8 +48,16 @@ def test_agent_ctrl_c(simulated_agent_child) -> None: child = simulated_agent_child child.expect(r"You:", timeout=15) child.sendintr() - child.expect(r"Received SIGINT, goodbye!", timeout=10) - child.close(force=True) + idx = child.expect([r"Received SIGINT, goodbye!", r"Goodbye!", pexpect.EOF], timeout=10) + transcript = child.before + + if idx in (0, 1): + transcript += child.after + child.expect([r"Resume this session:", pexpect.EOF, pexpect.TIMEOUT], timeout=5) + transcript += child.before + + assert "Traceback" not in transcript + assert "KeyboardInterrupt" not in transcript @pytest.mark.pty From 344eceb67c918529ff68fd5dacb576ee3da713cc Mon Sep 17 00:00:00 2001 From: Xiaofang Wu <3642115339@qq.com> Date: Sun, 3 May 2026 08:10:07 +0800 Subject: [PATCH 4/5] Expose dev extras for CI installs --- pyproject.toml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index edd8851c..7464a1a0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,6 +67,12 @@ brainco = [ langsmith = [ "langsmith>=0.1.0", ] +dev = [ + "pytest>=9.0.0,<10.0.0", + "pytest-asyncio>=1.3.0,<2.0.0", + "pexpect>=4.9.0,<5.0.0", + "ruff>=0.1.0", +] [dependency-groups] dev = [ From 13a5373d9eec0989beb147d77414f618009e3bb5 Mon Sep 17 00:00:00 2001 From: Xiaofang Wu <3642115339@qq.com> Date: Sun, 3 May 2026 10:27:27 +0800 Subject: [PATCH 5/5] Fix replay preflight and hardware monitor CI regressions --- .../embodied/embodiment/hardware/monitor.py | 105 +++++++++++++++--- tests/test_embodied_service_lifecycle.py | 7 +- 2 files changed, 93 insertions(+), 19 deletions(-) diff --git a/roboclaw/embodied/embodiment/hardware/monitor.py b/roboclaw/embodied/embodiment/hardware/monitor.py index 5e46a1e4..09fed298 100644 --- a/roboclaw/embodied/embodiment/hardware/monitor.py +++ b/roboclaw/embodied/embodiment/hardware/monitor.py @@ -10,6 +10,7 @@ import time from dataclasses import asdict, dataclass from enum import Enum +from pathlib import Path from typing import TYPE_CHECKING, Any if TYPE_CHECKING: @@ -28,6 +29,7 @@ class FaultType(str, Enum): ARM_DISCONNECTED = "arm_disconnected" + ARM_MOTOR_DISCONNECTED = "arm_motor_disconnected" ARM_TIMEOUT = "arm_timeout" ARM_NOT_CALIBRATED = "arm_not_calibrated" CAMERA_DISCONNECTED = "camera_disconnected" @@ -129,6 +131,41 @@ def _fault_key(fault: HardwareFault) -> str: return f"{fault.fault_type.value}:{fault.device_alias}" +def _pretty_motor_name(name: str) -> str: + return name.replace("_", " ") + + +def get_missing_arm_motors(arm: ArmBinding) -> list[str]: + from roboclaw.embodied.embodiment.arm.registry import ( + get_model, + get_probe_config, + get_runtime_spec, + ) + from roboclaw.embodied.embodiment.hardware.motors import _motor_config_from_arm + from roboclaw.embodied.embodiment.hardware.probers import get_prober + + if get_model(arm.arm_type) != "so101" or not arm.port: + return [] + # Only probe real serial devices. Tests and local placeholders often use + # temporary files to stand in for a present port; probing those paths would + # incorrectly report every motor as disconnected. + if not arm.port.startswith("/dev/"): + return [] + if not Path(arm.port).exists(): + return [] + + runtime_spec = get_runtime_spec(arm.arm_type) + motor_config = _motor_config_from_arm(arm, runtime_spec) + probe_cfg = get_probe_config(arm.arm_type) + prober = get_prober(probe_cfg.protocol) + found_id_set = set(prober.probe(arm.port, probe_cfg.baudrate, list(probe_cfg.motor_ids))) + return [ + _pretty_motor_name(name) + for name, (motor_id, _) in motor_config.items() + if motor_id not in found_id_set + ] + + class HardwareMonitor: """Periodically checks hardware health and emits fault events.""" @@ -147,27 +184,41 @@ def __init__( def active_faults(self) -> list[HardwareFault]: return list(self._active_faults.values()) + async def report_fault(self, fault: HardwareFault) -> None: + key = _fault_key(fault) + existing_fault = self._active_faults.get(key) + self._active_faults[key] = fault + if existing_fault is not None: + return + logger.warning("Hardware fault detected: {} — {}", key, fault.message) + if self._board is not None: + await self._board.emit(CH_FAULT_DETECTED, { + "fault_type": fault.fault_type.value, + "device_alias": fault.device_alias, + "message": fault.message, + "timestamp": fault.timestamp, + }) + + async def resolve_fault(self, fault_type: FaultType, device_alias: str) -> None: + key = f"{fault_type.value}:{device_alias}" + resolved_fault = self._active_faults.pop(key, None) + if resolved_fault is None: + return + logger.info("Hardware fault resolved: {}", key) + if self._board is not None: + await self._board.emit(CH_FAULT_RESOLVED, { + "fault_type": resolved_fault.fault_type.value, + "device_alias": resolved_fault.device_alias, + "timestamp": time.time(), + }) + def set_recording_active(self, active: bool) -> None: self._recording_active = active def stop(self) -> None: self._stop_event.set() - async def run(self) -> None: - """Main loop: check hardware every N seconds until stopped.""" - logger.info("Hardware monitor started") - while not self._stop_event.is_set(): - await self._tick() - try: - await asyncio.wait_for( - self._stop_event.wait(), timeout=_CHECK_INTERVAL_SECONDS - ) - break # stop_event was set - except asyncio.TimeoutError: - pass # normal interval elapsed - logger.info("Hardware monitor stopped") - - async def _tick(self) -> None: + async def run_check_once(self) -> None: """Run one check cycle, diff against active faults, emit events.""" current_faults = self.check_hardware() current_keys = {_fault_key(f): f for f in current_faults} @@ -197,6 +248,20 @@ async def _tick(self) -> None: "timestamp": time.time(), }) + async def run(self) -> None: + """Main loop: check hardware every N seconds until stopped.""" + logger.info("Hardware monitor started") + while not self._stop_event.is_set(): + await self.run_check_once() + try: + await asyncio.wait_for( + self._stop_event.wait(), timeout=_CHECK_INTERVAL_SECONDS + ) + break # stop_event was set + except asyncio.TimeoutError: + pass # normal interval elapsed + logger.info("Hardware monitor stopped") + def check_hardware(self) -> list[HardwareFault]: """Check all configured devices and return current faults.""" if self._manifest is not None: @@ -217,7 +282,7 @@ def check_hardware(self) -> list[HardwareFault]: def _check_arms( arms: list[ArmBinding], now: float, faults: list[HardwareFault], ) -> None: - """Check arm connectivity and calibration state.""" + """Check arm connectivity, calibration state, and motor wiring.""" for arm in arms: status = check_arm_status(arm) if arm.port and not status.connected: @@ -235,6 +300,14 @@ def _check_arms( message=f"Arm '{status.alias}' is not calibrated", timestamp=now, )) + missing_motors = get_missing_arm_motors(arm) + if missing_motors: + faults.append(HardwareFault( + fault_type=FaultType.ARM_MOTOR_DISCONNECTED, + device_alias=status.alias, + message=", ".join(missing_motors), + timestamp=now, + )) def _check_cameras( diff --git a/tests/test_embodied_service_lifecycle.py b/tests/test_embodied_service_lifecycle.py index ea6b539f..4cec35ce 100644 --- a/tests/test_embodied_service_lifecycle.py +++ b/tests/test_embodied_service_lifecycle.py @@ -201,7 +201,7 @@ async def test_run_replay_waits_for_process_completion_without_tty(tmp_path: Pat with patch("roboclaw.embodied.service.check_arm_status", side_effect=_single_follower_status()), patch( "roboclaw.embodied.service.CommandBuilder.replay", - return_value=["replay-cmd"], + return_value=_replay_argv(tmp_path, episode=2, fps=15), ): task = asyncio.create_task(run_replay(dataset_name="demo", episode=2, fps=15)) await asyncio.wait_for(service.replay.started.wait(), timeout=1) @@ -213,8 +213,9 @@ async def test_run_replay_waits_for_process_completion_without_tty(tmp_path: Pat service.replay.finish.set() result = await asyncio.wait_for(task, timeout=1) + expected_argv = _replay_argv(tmp_path, episode=2, fps=15) assert result == "Replay finished." - assert service.replay.argv == ["replay-cmd"] + assert service.replay.argv == expected_argv assert not service.busy assert not service.embodiment_busy assert service._active_session is None @@ -318,7 +319,7 @@ async def test_start_replay_releases_lock_on_session_start_failure(tmp_path: Pat with patch("roboclaw.embodied.service.check_arm_status", side_effect=_single_follower_status()), patch( "roboclaw.embodied.service.CommandBuilder.replay", - return_value=["replay-cmd"], + return_value=_replay_argv(tmp_path), ): with pytest.raises(RuntimeError, match="boom"): await service.start_replay(dataset_name="demo")