diff --git a/roboclaw/agent/__init__.py b/roboclaw/agent/__init__.py index c4c3bacb..d46e966d 100644 --- a/roboclaw/agent/__init__.py +++ b/roboclaw/agent/__init__.py @@ -1,8 +1,9 @@ """Agent core module.""" from roboclaw.agent.context import ContextBuilder +from roboclaw.agent.experience import ExperienceStore from roboclaw.agent.loop import AgentLoop from roboclaw.agent.memory import MemoryStore from roboclaw.agent.skills import SkillsLoader -__all__ = ["AgentLoop", "ContextBuilder", "MemoryStore", "SkillsLoader"] +__all__ = ["AgentLoop", "ContextBuilder", "ExperienceStore", "MemoryStore", "SkillsLoader"] diff --git a/roboclaw/agent/context.py b/roboclaw/agent/context.py index 6ad2d0f6..1fef62d4 100644 --- a/roboclaw/agent/context.py +++ b/roboclaw/agent/context.py @@ -6,6 +6,7 @@ from pathlib import Path from typing import Any +from roboclaw.agent.experience import ExperienceStore from roboclaw.utils.helpers import current_time_str from roboclaw.agent.memory import MemoryStore @@ -22,9 +23,15 @@ class ContextBuilder: def __init__(self, workspace: Path): self.workspace = workspace self.memory = MemoryStore(workspace) + self.experiences = ExperienceStore(workspace) self.skills = SkillsLoader(workspace) - def build_system_prompt(self, skill_names: list[str] | None = None) -> str: + def build_system_prompt( + self, + skill_names: list[str] | None = None, + *, + current_message: str = "", + ) -> str: """Build the system prompt from identity, bootstrap files, memory, and skills.""" parts = [self._get_identity()] @@ -36,6 +43,10 @@ def build_system_prompt(self, skill_names: list[str] | None = None) -> str: if memory: parts.append(f"# Memory\n\n{memory}") + experience_context = self.experiences.build_context(query=current_message) + if experience_context: + parts.append(f"# Relevant Experience\n\n{experience_context}") + always_skills = self.skills.get_always_skills() if always_skills: always_content = self.skills.load_skills_for_context(always_skills) @@ -138,7 +149,7 @@ def build_messages( merged = [{"type": "text", "text": runtime_ctx}] + user_content return [ - {"role": "system", "content": self.build_system_prompt(skill_names)}, + {"role": "system", "content": self.build_system_prompt(skill_names, current_message=current_message)}, *history, {"role": "user", "content": merged}, ] diff --git a/roboclaw/agent/experience.py b/roboclaw/agent/experience.py new file mode 100644 index 00000000..bbdad58b --- /dev/null +++ b/roboclaw/agent/experience.py @@ -0,0 +1,240 @@ +"""Structured experience storage and retrieval for agent adaptation.""" + +from __future__ import annotations + +from dataclasses import asdict, dataclass +from datetime import UTC, datetime +import json +from pathlib import Path +import re +from typing import Any + +from roboclaw.utils.helpers import ensure_dir + +_TOKEN_RE = re.compile(r"[a-z0-9][a-z0-9_./:-]*") + + +def _normalize_text(value: str | None) -> str: + return (value or "").strip() + + +def _normalize_key(value: str | None) -> str: + return _normalize_text(value).lower() + + +def _tokenize(*values: str) -> set[str]: + tokens: set[str] = set() + for value in values: + for match in _TOKEN_RE.finditer(value.lower()): + token = match.group(0) + if len(token) >= 2: + tokens.add(token) + return tokens + + +@dataclass(frozen=True) +class ExperienceRecord: + timestamp: str + task_type: str + summary: str + outcome: str + lesson: str = "" + dataset: str = "" + replay_datasets: str = "" + policy: str = "" + provider: str = "" + job_id: str = "" + source: str = "" + error: str = "" + checkpoint_path: str = "" + dataset_path: str = "" + task_name: str = "" + + @classmethod + def create( + cls, + *, + task_type: str, + summary: str, + outcome: str, + lesson: str = "", + dataset: str = "", + replay_datasets: str = "", + policy: str = "", + provider: str = "", + job_id: str = "", + source: str = "", + error: str = "", + checkpoint_path: str = "", + dataset_path: str = "", + task_name: str = "", + ) -> "ExperienceRecord": + return cls( + timestamp=datetime.now(UTC).isoformat(), + task_type=_normalize_text(task_type), + summary=_normalize_text(summary), + outcome=_normalize_text(outcome), + lesson=_normalize_text(lesson), + dataset=_normalize_text(dataset), + replay_datasets=_normalize_text(replay_datasets), + policy=_normalize_text(policy), + provider=_normalize_text(provider), + job_id=_normalize_text(job_id), + source=_normalize_text(source), + error=_normalize_text(error), + checkpoint_path=_normalize_text(checkpoint_path), + dataset_path=_normalize_text(dataset_path), + task_name=_normalize_text(task_name), + ) + + @classmethod + def from_dict(cls, payload: dict[str, Any]) -> "ExperienceRecord": + return cls( + timestamp=_normalize_text(str(payload.get("timestamp") or "")), + task_type=_normalize_text(str(payload.get("task_type") or "")), + summary=_normalize_text(str(payload.get("summary") or "")), + outcome=_normalize_text(str(payload.get("outcome") or "")), + lesson=_normalize_text(str(payload.get("lesson") or "")), + dataset=_normalize_text(str(payload.get("dataset") or "")), + replay_datasets=_normalize_text(str(payload.get("replay_datasets") or "")), + policy=_normalize_text(str(payload.get("policy") or "")), + provider=_normalize_text(str(payload.get("provider") or "")), + job_id=_normalize_text(str(payload.get("job_id") or "")), + source=_normalize_text(str(payload.get("source") or "")), + error=_normalize_text(str(payload.get("error") or "")), + checkpoint_path=_normalize_text(str(payload.get("checkpoint_path") or "")), + dataset_path=_normalize_text(str(payload.get("dataset_path") or "")), + task_name=_normalize_text(str(payload.get("task_name") or "")), + ) + + def to_dict(self) -> dict[str, Any]: + return asdict(self) + + def fingerprint(self) -> str: + parts = ( + _normalize_key(self.task_type), + _normalize_key(self.summary), + _normalize_key(self.outcome), + _normalize_key(self.dataset), + _normalize_key(self.replay_datasets), + _normalize_key(self.policy), + _normalize_key(self.provider), + _normalize_key(self.job_id), + _normalize_key(self.error), + _normalize_key(self.lesson), + ) + return "|".join(parts) + + +class ExperienceStore: + """Append-only JSONL store plus lightweight experience retrieval.""" + + def __init__(self, workspace: Path): + self.memory_dir = ensure_dir(workspace / "memory") + self.experience_file = self.memory_dir / "EXPERIENCES.jsonl" + + def read_all(self) -> list[ExperienceRecord]: + if not self.experience_file.exists(): + return [] + records: list[ExperienceRecord] = [] + for raw_line in self.experience_file.read_text(encoding="utf-8").splitlines(): + line = raw_line.strip() + if not line: + continue + try: + payload = json.loads(line) + except json.JSONDecodeError: + continue + if isinstance(payload, dict): + records.append(ExperienceRecord.from_dict(payload)) + return records + + def append(self, record: ExperienceRecord) -> bool: + records = self.read_all() + fingerprint = record.fingerprint() + if any(existing.fingerprint() == fingerprint for existing in records): + return False + with self.experience_file.open("a", encoding="utf-8") as handle: + handle.write(json.dumps(record.to_dict(), ensure_ascii=False) + "\n") + return True + + def search( + self, + *, + query: str = "", + task_type: str = "", + dataset: str = "", + policy: str = "", + provider: str = "", + outcomes: frozenset[str] | None = None, + limit: int = 3, + ) -> list[ExperienceRecord]: + query_tokens = _tokenize(query, dataset, policy, provider, task_type) + scored: list[tuple[int, ExperienceRecord]] = [] + for record in self.read_all(): + if outcomes is not None and record.outcome not in outcomes: + continue + score = 0 + if task_type and _normalize_key(record.task_type) == _normalize_key(task_type): + score += 6 + if dataset and _normalize_key(record.dataset) == _normalize_key(dataset): + score += 8 + if policy and _normalize_key(record.policy) == _normalize_key(policy): + score += 6 + if provider and _normalize_key(record.provider) == _normalize_key(provider): + score += 4 + + record_tokens = _tokenize( + record.summary, + record.lesson, + record.dataset, + record.policy, + record.provider, + record.task_name, + record.error, + ) + score += len(query_tokens & record_tokens) + if score <= 0: + continue + scored.append((score, record)) + + scored.sort(key=lambda item: (item[0], item[1].timestamp), reverse=True) + return [record for _, record in scored[:limit]] + + def build_context( + self, + *, + query: str = "", + task_type: str = "", + dataset: str = "", + policy: str = "", + provider: str = "", + limit: int = 3, + ) -> str: + records = self.search( + query=query, + task_type=task_type, + dataset=dataset, + policy=policy, + provider=provider, + limit=limit, + ) + if not records: + return "" + + lines = [ + "Use these past outcomes as hints. Reuse what worked and avoid repeating failures." + ] + for record in records: + fields = [record.outcome] + if record.dataset: + fields.append(f"dataset={record.dataset}") + if record.replay_datasets: + fields.append(f"replay={record.replay_datasets}") + if record.policy: + fields.append(f"policy={record.policy}") + if record.provider: + fields.append(f"provider={record.provider}") + summary = record.lesson or record.summary + lines.append(f"- [{record.timestamp[:19]}] {'; '.join(fields)} -> {summary}") + return "\n".join(lines) diff --git a/roboclaw/embodied/command/builder.py b/roboclaw/embodied/command/builder.py index a749e73b..885b6e22 100644 --- a/roboclaw/embodied/command/builder.py +++ b/roboclaw/embodied/command/builder.py @@ -18,6 +18,7 @@ ) from roboclaw.embodied.embodiment.arm.registry import get_model from roboclaw.embodied.embodiment.manifest.binding import ArmBinding, ArmRole, CameraBinding +from roboclaw.embodied.policy import policy_registry _BIMANUAL: dict[str, tuple[str, str]] = { "so101": ("bi_so_follower", "bi_so_leader"), @@ -28,23 +29,7 @@ _DEFAULT_REPLAY_ROOT = Path("~/.cache/huggingface/lerobot").expanduser() -TRAIN_POLICY_TYPES = { - "act", - "diffusion", - "groot", - "multi_task_dit", - "pi0", - "pi0_fast", - "pi05", - "reward_classifier", - "sac", - "sarm", - "smolvla", - "tdmpc", - "vqbet", - "wall_x", - "xvla", -} +TRAIN_POLICY_TYPES = policy_registry.supported_types() # ── Private helper functions ───────────────────────────────────────────── @@ -288,12 +273,17 @@ def train( device: str = "cuda", ) -> list[str]: """Build training argv (standalone lerobot-train, not through wrapper).""" - if policy_type not in TRAIN_POLICY_TYPES: - allowed = ", ".join(sorted(TRAIN_POLICY_TYPES)) - raise ActionError(f"Unsupported policy_type '{policy_type}'. Expected one of: {allowed}.") + try: + policy_config = policy_registry.get(policy_type) + except ValueError as exc: + raise ActionError(str(exc)) from exc policies_root = manifest.snapshot.get("policies", {}).get("root", "") - output_dir_name = dataset.name if policy_type == "act" else f"{dataset.name}_{policy_type}" + output_dir_name = ( + dataset.name + if policy_config.policy_type == "act" + else f"{dataset.name}_{policy_config.policy_type}" + ) output_dir = Path(policies_root).expanduser() / output_dir_name argv = [ @@ -301,13 +291,14 @@ def train( f"--dataset.repo_id={dataset.repo_id}", f"--dataset.root={dataset.local_path}", "--dataset.video_backend=pyav", - f"--policy.type={policy_type}", + f"--policy.type={policy_config.policy_type}", "--policy.push_to_hub=false", f"--policy.repo_id={dataset.repo_id}", f"--output_dir={output_dir}", f"--steps={steps}", f"--policy.device={device}", ] + argv.extend(policy_config.extra_train_args()) # Resume if a previous checkpoint exists if output_dir.is_dir(): diff --git a/roboclaw/embodied/policy/__init__.py b/roboclaw/embodied/policy/__init__.py new file mode 100644 index 00000000..6e81644f --- /dev/null +++ b/roboclaw/embodied/policy/__init__.py @@ -0,0 +1,17 @@ +"""Policy config registry for embodied training.""" + +from roboclaw.embodied.policy.base import BasePolicyConfig +from roboclaw.embodied.policy.registry import PolicyRegistry, policy_registry + +from . import act as _act +from . import diffusion as _diffusion +from . import gr00t as _gr00t +from . import pi0 as _pi0 +from . import smolvla as _smolvla + +__all__ = [ + "BasePolicyConfig", + "PolicyRegistry", + "policy_registry", +] + diff --git a/roboclaw/embodied/policy/act.py b/roboclaw/embodied/policy/act.py new file mode 100644 index 00000000..c81314d8 --- /dev/null +++ b/roboclaw/embodied/policy/act.py @@ -0,0 +1,36 @@ +"""ACT-family policy training configs.""" + +from __future__ import annotations + +from dataclasses import dataclass, field + +from roboclaw.embodied.policy.base import BasePolicyConfig +from roboclaw.embodied.policy.registry import policy_registry + + +@policy_registry.register +@dataclass(frozen=True) +class ActPolicyConfig(BasePolicyConfig): + policy_type: str = field(init=False, default="act") + + def extra_train_args(self) -> list[str]: + return [] + + +@policy_registry.register +@dataclass(frozen=True) +class MultiTaskDiTPolicyConfig(BasePolicyConfig): + policy_type: str = field(init=False, default="multi_task_dit") + + def extra_train_args(self) -> list[str]: + return [] + + +@policy_registry.register +@dataclass(frozen=True) +class VQBeTPolicyConfig(BasePolicyConfig): + policy_type: str = field(init=False, default="vqbet") + + def extra_train_args(self) -> list[str]: + return [] + diff --git a/roboclaw/embodied/policy/base.py b/roboclaw/embodied/policy/base.py new file mode 100644 index 00000000..cdcdd0dc --- /dev/null +++ b/roboclaw/embodied/policy/base.py @@ -0,0 +1,18 @@ +"""Base config types for trainable LeRobot policies.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field + + +@dataclass(frozen=True) +class BasePolicyConfig(ABC): + """Train-time policy config registered against a LeRobot policy type.""" + + policy_type: str = field(init=False, default="") + + @abstractmethod + def extra_train_args(self) -> list[str]: + """Return policy-specific ``lerobot-train`` CLI args.""" + diff --git a/roboclaw/embodied/policy/diffusion.py b/roboclaw/embodied/policy/diffusion.py new file mode 100644 index 00000000..2a4bb7b8 --- /dev/null +++ b/roboclaw/embodied/policy/diffusion.py @@ -0,0 +1,63 @@ +"""Diffusion, RL, and auxiliary policy training configs.""" + +from __future__ import annotations + +from dataclasses import dataclass, field + +from roboclaw.embodied.policy.base import BasePolicyConfig +from roboclaw.embodied.policy.registry import policy_registry + + +@policy_registry.register +@dataclass(frozen=True) +class DiffusionPolicyConfig(BasePolicyConfig): + policy_type: str = field(init=False, default="diffusion") + + def extra_train_args(self) -> list[str]: + return [] + + +@policy_registry.register +@dataclass(frozen=True) +class TDMPCPolicyConfig(BasePolicyConfig): + policy_type: str = field(init=False, default="tdmpc") + + def extra_train_args(self) -> list[str]: + return [] + + +@policy_registry.register +@dataclass(frozen=True) +class SACPolicyConfig(BasePolicyConfig): + policy_type: str = field(init=False, default="sac") + + def extra_train_args(self) -> list[str]: + return [] + + +@policy_registry.register +@dataclass(frozen=True) +class RewardClassifierPolicyConfig(BasePolicyConfig): + policy_type: str = field(init=False, default="reward_classifier") + + def extra_train_args(self) -> list[str]: + return [] + + +@policy_registry.register +@dataclass(frozen=True) +class SARMPolicyConfig(BasePolicyConfig): + policy_type: str = field(init=False, default="sarm") + + def extra_train_args(self) -> list[str]: + return [] + + +@policy_registry.register +@dataclass(frozen=True) +class WallXPolicyConfig(BasePolicyConfig): + policy_type: str = field(init=False, default="wall_x") + + def extra_train_args(self) -> list[str]: + return [] + diff --git a/roboclaw/embodied/policy/gr00t.py b/roboclaw/embodied/policy/gr00t.py new file mode 100644 index 00000000..fc41d1ac --- /dev/null +++ b/roboclaw/embodied/policy/gr00t.py @@ -0,0 +1,20 @@ +"""GR00T-family policy training configs.""" + +from __future__ import annotations + +from dataclasses import dataclass, field + +from roboclaw.embodied.policy.base import BasePolicyConfig +from roboclaw.embodied.policy.registry import policy_registry + + +@policy_registry.register +@dataclass(frozen=True) +class GR00TPolicyConfig(BasePolicyConfig): + """NVIDIA GR00T N1.""" + + policy_type: str = field(init=False, default="groot") + + def extra_train_args(self) -> list[str]: + return [] + diff --git a/roboclaw/embodied/policy/pi0.py b/roboclaw/embodied/policy/pi0.py new file mode 100644 index 00000000..de9a883f --- /dev/null +++ b/roboclaw/embodied/policy/pi0.py @@ -0,0 +1,36 @@ +"""PI-family policy training configs.""" + +from __future__ import annotations + +from dataclasses import dataclass, field + +from roboclaw.embodied.policy.base import BasePolicyConfig +from roboclaw.embodied.policy.registry import policy_registry + + +@policy_registry.register +@dataclass(frozen=True) +class Pi0PolicyConfig(BasePolicyConfig): + policy_type: str = field(init=False, default="pi0") + + def extra_train_args(self) -> list[str]: + return [] + + +@policy_registry.register +@dataclass(frozen=True) +class Pi0FastPolicyConfig(BasePolicyConfig): + policy_type: str = field(init=False, default="pi0_fast") + + def extra_train_args(self) -> list[str]: + return [] + + +@policy_registry.register +@dataclass(frozen=True) +class Pi05PolicyConfig(BasePolicyConfig): + policy_type: str = field(init=False, default="pi05") + + def extra_train_args(self) -> list[str]: + return [] + diff --git a/roboclaw/embodied/policy/registry.py b/roboclaw/embodied/policy/registry.py new file mode 100644 index 00000000..506373ac --- /dev/null +++ b/roboclaw/embodied/policy/registry.py @@ -0,0 +1,49 @@ +"""Registry for trainable LeRobot policy configs.""" + +from __future__ import annotations + +from dataclasses import is_dataclass + +from roboclaw.embodied.policy.base import BasePolicyConfig + + +class PolicyRegistry: + """Register and resolve train-time policy configs by policy type.""" + + def __init__(self) -> None: + self._config_types: dict[str, type[BasePolicyConfig]] = {} + + def register(self, config_cls: type[BasePolicyConfig]) -> type[BasePolicyConfig]: + """Register a config dataclass and return it for decorator usage.""" + if not issubclass(config_cls, BasePolicyConfig): + raise TypeError("Policy config must inherit from BasePolicyConfig.") + if not is_dataclass(config_cls): + raise TypeError("Policy config must be a dataclass.") + + config = config_cls() + policy_type = config.policy_type.strip() + if not policy_type: + raise ValueError("Policy config must declare a non-empty policy_type.") + if policy_type in self._config_types: + raise ValueError(f"Policy '{policy_type}' is already registered.") + + self._config_types[policy_type] = config_cls + return config_cls + + def get(self, policy_type: str) -> BasePolicyConfig: + """Instantiate the config registered for ``policy_type``.""" + config_cls = self._config_types.get(policy_type) + if config_cls is None: + allowed = ", ".join(sorted(self.supported_types())) + raise ValueError( + f"Unsupported policy_type '{policy_type}'. Expected one of: {allowed}." + ) + return config_cls() + + def supported_types(self) -> set[str]: + """Return a copy of all registered policy types.""" + return set(self._config_types) + + +policy_registry = PolicyRegistry() + diff --git a/roboclaw/embodied/policy/smolvla.py b/roboclaw/embodied/policy/smolvla.py new file mode 100644 index 00000000..c6b6a89d --- /dev/null +++ b/roboclaw/embodied/policy/smolvla.py @@ -0,0 +1,27 @@ +"""VLA policy training configs.""" + +from __future__ import annotations + +from dataclasses import dataclass, field + +from roboclaw.embodied.policy.base import BasePolicyConfig +from roboclaw.embodied.policy.registry import policy_registry + + +@policy_registry.register +@dataclass(frozen=True) +class SmolVLAConfig(BasePolicyConfig): + policy_type: str = field(init=False, default="smolvla") + + def extra_train_args(self) -> list[str]: + return [] + + +@policy_registry.register +@dataclass(frozen=True) +class XVLAPolicyConfig(BasePolicyConfig): + policy_type: str = field(init=False, default="xvla") + + def extra_train_args(self) -> list[str]: + return [] + diff --git a/roboclaw/embodied/service/__init__.py b/roboclaw/embodied/service/__init__.py index 910214c2..b9b775ec 100644 --- a/roboclaw/embodied/service/__init__.py +++ b/roboclaw/embodied/service/__init__.py @@ -33,6 +33,7 @@ from roboclaw.embodied.service.session.calibrate import CalibrationSession from roboclaw.embodied.service.session.setup import SetupSession from roboclaw.embodied.service.verification import ( + InferenceConfigVerifier, PreflightVerifier, VerificationRequest, Verifier, @@ -66,6 +67,7 @@ def __init__( self._active_session: Session | None = None self._recording_started = False self._preflight_verifier = preflight_verifier or PreflightVerifier() + self._inference_config_verifier = InferenceConfigVerifier() # Sub-services self.calibration = CalibrationSession(self) @@ -226,6 +228,18 @@ def _verify_inference_preflight( if not result.ok: raise ActionError(result.format_violations()) + def _verify_inference_config( + self, + *, + checkpoint_path: str, + dataset_local_path: str, + ) -> None: + self._inference_config_verifier.verify( + checkpoint_path=checkpoint_path, + manifest_snapshot=self.manifest.snapshot, + dataset_local_path=dataset_local_path, + ) + # -- Operations (Web entry points) -- async def start_teleop(self, *, fps: int = 30, arms: str = "") -> None: @@ -312,6 +326,10 @@ async def start_inference( episode_time_s=episode_time_s, use_cameras=use_cameras, ) + self._verify_inference_config( + checkpoint_path=checkpoint_path or _arg_value(argv, "--policy.path="), + dataset_local_path=str(source.runtime.local_path if source else output_dataset.runtime.local_path), + ) await self._start_managed_session(self.infer, owner="inferring", argv=argv) async def run_replay( @@ -543,3 +561,10 @@ async def shutdown(self) -> None: if self._monitor is not None: self._monitor.set_recording_active(False) self.release_embodiment() + + +def _arg_value(argv: list[str], prefix: str) -> str: + for arg in argv: + if arg.startswith(prefix): + return arg.split("=", 1)[1] + return "" diff --git a/roboclaw/embodied/service/session/train.py b/roboclaw/embodied/service/session/train.py index 3c1cffe3..ff713766 100644 --- a/roboclaw/embodied/service/session/train.py +++ b/roboclaw/embodied/service/session/train.py @@ -8,6 +8,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any +from roboclaw.agent.experience import ExperienceRecord, ExperienceStore from roboclaw.embodied.command import CommandBuilder, logs_dir if TYPE_CHECKING: @@ -23,60 +24,120 @@ class TrainSession: def __init__(self, parent: EmbodiedService) -> None: self._parent = parent + self._experiences = ExperienceStore(parent.manifest._path.parent.parent) + self._job_specs: dict[str, dict[str, str]] = {} - async def train( + async def start_job_state( self, manifest: Manifest, kwargs: dict[str, Any], - tty_handoff: Any, - ) -> str: + ) -> dict[str, str | int | bool | None]: from roboclaw.embodied.executor import SubprocessExecutor - dataset_name = kwargs.get("dataset_name", "default") + dataset_name = str(kwargs.get("dataset_name", "default") or "default") + policy_type = str(kwargs.get("policy_type", "act") or "act") + steps = int(kwargs.get("steps", 100_000) or 100_000) + device = str(kwargs.get("device", "cuda") or "cuda") dataset = self._parent.datasets.resolve_runtime_dataset(dataset_name) + experience_hint = self._build_experience_hint( + dataset_name=dataset_name, + policy_type=policy_type, + ) argv = CommandBuilder.train( manifest, dataset=dataset.runtime, - policy_type=kwargs.get("policy_type", "act"), - steps=kwargs.get("steps", 100_000), - device=kwargs.get("device", "cuda"), + policy_type=policy_type, + steps=steps, + device=device, ) job_id = await SubprocessExecutor().run_detached(argv=argv, log_dir=logs_dir()) - return f"Training started. Job ID: {job_id}" + self._job_specs[job_id] = { + "dataset_name": dataset_name, + "policy_type": policy_type, + "dataset_path": str(dataset.runtime.local_path), + "provider": "local", + } + state: dict[str, str | int | bool | None] = { + "job_id": job_id, + "status": "running", + "running": True, + "pid": None, + "log_path": str(SubprocessExecutor()._job_log_path(job_id, logs_dir())), + "log_tail": "", + "dataset_name": dataset_name, + "policy_type": policy_type, + "dataset_path": str(dataset.runtime.local_path), + "provider": "local", + "experience_hint": experience_hint, + } + state["message"] = self._format_status_message(state) + self._record_experience( + job_id=job_id, + status="submitted", + log_tail="", + log_path=str(state.get("log_path") or ""), + ) + return state - async def stop_job( + async def train( self, manifest: Manifest, kwargs: dict[str, Any], tty_handoff: Any, ) -> str: + state = await self.start_job_state(manifest, kwargs) + return str(state["message"]) + + async def stop_job_state(self, job_id: str) -> dict[str, str | int | bool | None]: from roboclaw.embodied.executor import SubprocessExecutor - job_id = kwargs.get("job_id", "") status = await SubprocessExecutor().stop_job(job_id=job_id, log_dir=logs_dir()) - return "\n".join(f"{key}: {value}" for key, value in status.items()) + return self._enrich_state(job_id, status) - async def job_status( + async def stop_job( self, manifest: Manifest, kwargs: dict[str, Any], tty_handoff: Any, ) -> str: + job_id = kwargs.get("job_id", "") + status = await self.stop_job_state(str(job_id)) + return str(status["message"]) + + async def job_status_state(self, job_id: str) -> dict[str, str | int | bool | None]: from roboclaw.embodied.executor import SubprocessExecutor - job_id = kwargs.get("job_id", "") status = await SubprocessExecutor().job_status(job_id=job_id, log_dir=logs_dir()) - return "\n".join(f"{key}: {value}" for key, value in status.items()) + return self._enrich_state(job_id, status) - async def current_job( + async def job_status( self, manifest: Manifest, kwargs: dict[str, Any], tty_handoff: Any, - ) -> dict[str, str | int | bool | None]: + ) -> str: + job_id = kwargs.get("job_id", "") + status = await self.job_status_state(str(job_id)) + return str(status["message"]) + + async def current_job_state(self) -> dict[str, str | int | bool | None]: from roboclaw.embodied.executor import SubprocessExecutor - return await SubprocessExecutor().latest_running_job(log_dir=logs_dir()) + status = await SubprocessExecutor().latest_running_job(log_dir=logs_dir()) + job_id = str(status.get("job_id") or "") + if not job_id: + enriched = self._enrich_state("", status) + enriched["message"] = self._format_status_message(enriched) + return enriched + return self._enrich_state(job_id, status) + + async def current_job( + self, + manifest: Manifest, + kwargs: dict[str, Any], + tty_handoff: Any, + ) -> dict[str, str | int | bool | None]: + return await self.current_job_state() def curve_data(self, job_id: str) -> dict[str, Any]: job_id = job_id.strip() @@ -128,6 +189,114 @@ def list_policies(self, manifest: Manifest | None = None) -> str: return "No policies found." return json.dumps(policies, indent=2, ensure_ascii=False) + def _build_experience_hint( + self, + *, + dataset_name: str, + policy_type: str, + ) -> str: + records = self._experiences.search( + task_type="train", + dataset=dataset_name, + policy=policy_type, + outcomes=frozenset({"success", "failed", "error", "stopped"}), + provider="local", + limit=2, + ) + lines = [ + f"{record.outcome}: {record.lesson or record.summary}" + for record in records + ] + return "\n".join(lines) + + def _enrich_state( + self, + job_id: str, + state: dict[str, str | int | bool | None], + ) -> dict[str, str | int | bool | None]: + enriched = dict(state) + metadata = self._job_specs.get(job_id, {}) + enriched.setdefault("job_id", job_id) + enriched.setdefault("provider", metadata.get("provider", "local")) + enriched.setdefault("dataset_name", metadata.get("dataset_name", "")) + enriched.setdefault("policy_type", metadata.get("policy_type", "")) + enriched.setdefault("dataset_path", metadata.get("dataset_path", "")) + enriched.setdefault("experience_hint", "") + enriched["message"] = self._format_status_message(enriched) + + status_text = str(enriched.get("status") or "").lower() + if status_text in _TERMINAL_TRAIN_STATUSES and job_id: + self._record_experience( + job_id=job_id, + status=status_text, + log_tail=str(enriched.get("log_tail") or ""), + log_path=str(enriched.get("log_path") or ""), + ) + return enriched + + def _format_status_message(self, state: dict[str, str | int | bool | None]) -> str: + order = ( + "job_id", + "status", + "running", + "pid", + "dataset_name", + "policy_type", + "provider", + "dataset_path", + "log_path", + "experience_hint", + ) + lines: list[str] = [] + seen: set[str] = set() + for key in order: + value = state.get(key) + if value in {None, ""}: + continue + seen.add(key) + lines.append(f"{key}: {value}") + for key, value in state.items(): + if key in seen or value in {None, ""}: + continue + lines.append(f"{key}: {value}") + return "\n".join(lines) + + def _record_experience( + self, + *, + job_id: str, + status: str, + log_tail: str, + log_path: str, + ) -> None: + metadata = self._job_specs.get(job_id, {}) + dataset_name = metadata.get("dataset_name", "") + policy_type = metadata.get("policy_type", "") + provider = metadata.get("provider", "local") + if not dataset_name and not policy_type and not job_id: + return + outcome = _experience_outcome(status) + lesson = _status_lesson(status, log_tail) + summary = ( + f"Local training for dataset '{dataset_name or ''}' " + f"with policy '{policy_type or ''}' ended as {status}" + ) + self._experiences.append(ExperienceRecord.create( + task_type="train", + summary=summary, + outcome=outcome, + lesson=lesson, + dataset=dataset_name, + policy=policy_type, + provider=provider, + job_id=job_id, + source="train_session", + error=log_tail if status in {"failed", "error"} else "", + dataset_path=metadata.get("dataset_path", ""), + task_name=job_id, + checkpoint_path=log_path, + )) + def _scan_policies(root: Path) -> list[dict[str, Any]]: """Scan policy directories under *root* and return summary dicts.""" policies: list[dict[str, Any]] = [] @@ -157,6 +326,7 @@ def _enrich_policy_entry(entry: dict[str, Any], checkpoint_dir: Path) -> None: _JOB_ID_RE = re.compile(r"^[A-Za-z0-9-]+$") +_TERMINAL_TRAIN_STATUSES = {"finished", "stopped", "failed", "error", "missing", "idle"} _TRAIN_LOG_RE = re.compile( r"step:(?P\S+).*?" r"ep:(?P\d+).*?" @@ -177,6 +347,12 @@ def _update_best( return best +def _experience_outcome(status: str) -> str: + if status == "finished": + return "success" + return status + + def _parse_training_curve(job_id: str, log_path: Path) -> tuple[dict[str, float | int] | None, list[dict[str, Any]]]: if not log_path.exists(): return _BEST_LOSS_BY_JOB.get(job_id), [] @@ -246,3 +422,21 @@ def _parse_training_curve_line(line: str) -> dict[str, Any] | None: "epoch": epoch, "loss": loss, } + + +def _status_lesson(status: str, log_tail: str) -> str: + normalized = status.strip().lower() + if normalized == "submitted": + return "A similar run was submitted successfully." + if normalized == "finished": + return "A similar run finished previously." + if normalized == "stopped": + return "A similar run had to be stopped manually." + if normalized in {"failed", "error"}: + tail = log_tail.strip() + if tail: + return f"Recent failure signal: {tail.splitlines()[-1]}" + return "A similar run failed previously." + if normalized == "missing": + return "A similar run lost its local process metadata before completion." + return f"A similar run reached status '{status}'." diff --git a/roboclaw/embodied/service/verification/__init__.py b/roboclaw/embodied/service/verification/__init__.py index a828227a..10a21f1e 100644 --- a/roboclaw/embodied/service/verification/__init__.py +++ b/roboclaw/embodied/service/verification/__init__.py @@ -1,6 +1,10 @@ """Embodied verification interfaces and preflight checks.""" -from roboclaw.embodied.service.verification.preflight import PreflightVerifier, Verifier +from roboclaw.embodied.service.verification.preflight import ( + InferenceConfigVerifier, + PreflightVerifier, + Verifier, +) from roboclaw.embodied.service.verification.types import ( VerificationRequest, VerificationResult, @@ -9,6 +13,7 @@ __all__ = [ "PreflightVerifier", + "InferenceConfigVerifier", "VerificationRequest", "VerificationResult", "Verifier", diff --git a/roboclaw/embodied/service/verification/preflight.py b/roboclaw/embodied/service/verification/preflight.py index 065d7f25..3b7c914b 100644 --- a/roboclaw/embodied/service/verification/preflight.py +++ b/roboclaw/embodied/service/verification/preflight.py @@ -2,9 +2,13 @@ from __future__ import annotations +import json +import logging from pathlib import Path +import re from typing import Any, Iterable, Protocol, Sequence +from roboclaw.embodied.embodiment.arm.registry import get_runtime_spec from roboclaw.embodied.service.verification.types import ( VerificationRequest, VerificationResult, @@ -26,6 +30,9 @@ ) _MAX_INFERENCE_EPISODES = 1_000 _MAX_EPISODE_TIME_S = 3_600 +_MIN_DATASET_VERSION = (2, 1) +_VERSION_TOKEN_RE = re.compile(r"\d+") +logger = logging.getLogger(__name__) class Verifier(Protocol): @@ -65,6 +72,53 @@ def verify(self, request: VerificationRequest) -> VerificationResult: return VerificationResult(tuple(violations), tuple(warnings)) +class InferenceConfigVerifier: + """Validate checkpoint/dataset consistency before inference starts.""" + + def verify( + self, + checkpoint_path: str, + manifest_snapshot: dict[str, Any], + dataset_local_path: str, + ) -> None: + errors: list[str] = [] + checkpoint = Path(checkpoint_path).expanduser() + if not checkpoint.exists(): + errors.append(f"Checkpoint path does not exist: {checkpoint}") + self._raise_if_errors(errors) + return + + pretrained_dir = _pretrained_model_dir(checkpoint) + config_path = pretrained_dir / "config.json" + if not config_path.is_file(): + errors.append(f"Checkpoint is missing pretrained_model/config.json: {config_path}") + + train_config_path = pretrained_dir / "train_config.json" + train_config = _load_json_if_exists(train_config_path) + if train_config is not None: + _warn_device_mismatch(train_config, manifest_snapshot) + _warn_dataset_mismatch(train_config, dataset_local_path) + + if config_path.is_file(): + config = json.loads(config_path.read_text(encoding="utf-8")) + expected_action_dim = _manifest_action_dim(manifest_snapshot) + actual_action_dim = _config_action_dim(config) + if actual_action_dim is None: + errors.append(f"Checkpoint config is missing action_dim: {config_path}") + elif actual_action_dim != expected_action_dim: + errors.append( + "Checkpoint action_dim does not match manifest follower motors: " + f"{actual_action_dim} != {expected_action_dim}" + ) + + _warn_if_dataset_version_is_old(dataset_local_path) + self._raise_if_errors(errors) + + def _raise_if_errors(self, errors: list[str]) -> None: + if errors: + raise ValueError(" · ".join(errors)) + + def _validate_wrapper_argv(argv: Sequence[str]) -> list[Violation]: violations: list[Violation] = [] if not argv: @@ -254,3 +308,126 @@ def _index_or_none(argv: Sequence[str], value: str) -> int | None: def _role_value(role: Any) -> str: value = getattr(role, "value", role) return str(value) + + +def _pretrained_model_dir(checkpoint: Path) -> Path: + if checkpoint.name == "pretrained_model": + return checkpoint + return checkpoint / "pretrained_model" + + +def _load_json_if_exists(path: Path) -> dict[str, Any] | None: + if not path.is_file(): + return None + return json.loads(path.read_text(encoding="utf-8")) + + +def _warn_device_mismatch(train_config: dict[str, Any], manifest_snapshot: dict[str, Any]) -> None: + checkpoint_device = str(train_config.get("policy", {}).get("device") or "").strip() + manifest_device = _manifest_device(manifest_snapshot) + if checkpoint_device and manifest_device and checkpoint_device != manifest_device: + logger.warning( + "Inference config warning: checkpoint policy.device=%s but manifest device=%s", + checkpoint_device, + manifest_device, + ) + + +def _warn_dataset_mismatch(train_config: dict[str, Any], dataset_local_path: str) -> None: + checkpoint_dataset = str(train_config.get("dataset", {}).get("repo_id") or "").strip() + current_dataset = _dataset_repo_id_from_local_path(Path(dataset_local_path).expanduser()) + if checkpoint_dataset and current_dataset and checkpoint_dataset != current_dataset: + logger.warning( + "Inference config warning: checkpoint dataset.repo_id=%s but current dataset=%s", + checkpoint_dataset, + current_dataset, + ) + + +def _manifest_device(manifest_snapshot: dict[str, Any]) -> str: + direct = str(manifest_snapshot.get("device") or "").strip() + if direct: + return direct + policies = manifest_snapshot.get("policies", {}) + if isinstance(policies, dict): + return str(policies.get("device") or "").strip() + return "" + + +def _dataset_repo_id_from_local_path(dataset_local_path: Path) -> str: + info_path = dataset_local_path / "meta" / "info.json" + if not info_path.is_file(): + return "" + info = json.loads(info_path.read_text(encoding="utf-8")) + for key in ("source_dataset", "repo_id", "dataset_id"): + value = str(info.get(key) or "").strip() + if value: + return value + return "" + + +def _manifest_action_dim(manifest_snapshot: dict[str, Any]) -> int: + arms = list(manifest_snapshot.get("arms", []) or []) + followers = [arm for arm in arms if _arm_role(arm) == "follower"] + return sum(_arm_motor_count(arm) for arm in followers) + + +def _arm_role(arm: dict[str, Any]) -> str: + role = str(arm.get("role") or "").strip() + if role: + return role + arm_type = str(arm.get("type") or "").strip() + if arm_type.endswith("_follower"): + return "follower" + if arm_type.endswith("_leader"): + return "leader" + return "" + + +def _arm_motor_count(arm: dict[str, Any]) -> int: + arm_type = str(arm.get("type") or "").strip() + return len(get_runtime_spec(arm_type).default_joint_names) + + +def _config_action_dim(config: dict[str, Any]) -> int | None: + for key in ("action_dim", "max_action_dim"): + value = config.get(key) + if isinstance(value, int): + return value + output_features = config.get("output_features", {}) + if not isinstance(output_features, dict): + return None + action = output_features.get("action", {}) + if not isinstance(action, dict): + return None + shape = action.get("shape") + if isinstance(shape, list) and shape and isinstance(shape[0], int): + return shape[0] + return None + + +def _warn_if_dataset_version_is_old(dataset_local_path: str) -> None: + dataset_path = Path(dataset_local_path).expanduser() + if not dataset_path.exists(): + return + info_path = dataset_path / "meta" / "info.json" + if not info_path.is_file(): + return + info = json.loads(info_path.read_text(encoding="utf-8")) + version = str(info.get("codebase_version") or "").strip() + if _version_at_least(version, _MIN_DATASET_VERSION): + return + logger.warning( + "Inference config warning: dataset codebase_version=%s is older than v2.1: %s", + version or "", + dataset_path, + ) + + +def _version_at_least(version: str, minimum: tuple[int, int]) -> bool: + numbers = [int(token) for token in _VERSION_TOKEN_RE.findall(version)] + if not numbers: + return False + major = numbers[0] + minor = numbers[1] if len(numbers) > 1 else 0 + return (major, minor) >= minimum diff --git a/roboclaw/http/routes/train.py b/roboclaw/http/routes/train.py index ed4c8085..995549d5 100644 --- a/roboclaw/http/routes/train.py +++ b/roboclaw/http/routes/train.py @@ -16,6 +16,7 @@ class TrainStartRequest(BaseModel): policy_type: str = "act" steps: int = 100_000 device: str = "cuda" + continual_learning: bool = False class TrainStopRequest(BaseModel): @@ -26,44 +27,29 @@ def register_train_routes(app: FastAPI, service: EmbodiedService) -> None: @app.post("/api/train/start") async def train_start(body: TrainStartRequest) -> dict[str, Any]: - result = await service.train.train( + result = await service.train.start_job_state( manifest=service.manifest, kwargs={ "dataset_name": body.dataset_name, "policy_type": body.policy_type, "steps": body.steps, "device": body.device, + "continual_learning": body.continual_learning, }, - tty_handoff=None, ) - job_id = result.rsplit("Job ID:", 1)[-1].strip() if "Job ID:" in result else "" - return {"message": result, "job_id": job_id} + return result @app.post("/api/train/stop") async def train_stop(body: TrainStopRequest) -> dict[str, Any]: - result = await service.train.stop_job( - manifest=service.manifest, - kwargs={"job_id": body.job_id}, - tty_handoff=None, - ) - return {"message": result} + return await service.train.stop_job_state(body.job_id) @app.get("/api/train/current") async def train_current() -> dict[str, Any]: - return await service.train.current_job( - manifest=service.manifest, - kwargs={}, - tty_handoff=None, - ) + return await service.train.current_job_state() @app.get("/api/train/status/{job_id}") async def train_status(job_id: str) -> dict[str, Any]: - result = await service.train.job_status( - manifest=service.manifest, - kwargs={"job_id": job_id}, - tty_handoff=None, - ) - return {"message": result} + return await service.train.job_status_state(job_id) @app.get("/api/train/curve/{job_id}") async def train_curve(job_id: str) -> dict[str, Any]: diff --git a/tests/test_context_prompt_cache.py b/tests/test_context_prompt_cache.py index cc238e83..8d677f59 100644 --- a/tests/test_context_prompt_cache.py +++ b/tests/test_context_prompt_cache.py @@ -8,6 +8,7 @@ import datetime as datetime_module from roboclaw.agent.context import ContextBuilder +from roboclaw.agent.experience import ExperienceRecord, ExperienceStore class _FakeDatetime(real_datetime): @@ -71,3 +72,31 @@ def test_runtime_context_is_separate_untrusted_user_message(tmp_path) -> None: assert "Channel: cli" in user_content assert "Chat ID: direct" in user_content assert "Return exactly: OK" in user_content + + +def test_system_prompt_includes_relevant_experience_for_similar_request(tmp_path) -> None: + workspace = _make_workspace(tmp_path) + store = ExperienceStore(workspace) + store.append(ExperienceRecord.create( + task_type="train", + summary="Local training for dataset 'demo' with policy 'act' ended as finished", + outcome="finished", + lesson="This dataset/policy pair finished successfully before.", + dataset="demo", + policy="act", + provider="local", + job_id="job-1", + source="test", + )) + + builder = ContextBuilder(workspace) + messages = builder.build_messages( + history=[], + current_message="Please train demo with act again", + ) + + system_prompt = messages[0]["content"] + assert isinstance(system_prompt, str) + assert "# Relevant Experience" in system_prompt + assert "dataset=demo" in system_prompt + assert "finished" in system_prompt diff --git a/tests/test_experience_replay.py b/tests/test_experience_replay.py new file mode 100644 index 00000000..f1c9e513 --- /dev/null +++ b/tests/test_experience_replay.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest + +from roboclaw.agent.experience import ExperienceRecord, ExperienceStore + + +def test_experience_store_records_and_deduplicates(tmp_path: Path) -> None: + store = ExperienceStore(tmp_path) + record = ExperienceRecord( + timestamp="2026-05-07T10:00:00+00:00", + task_type="train", + summary="demo -> success", + outcome="success", + dataset="demo", + policy="act", + provider="local", + ) + assert store.append(record) is True + assert store.append(record) is False + assert len(store.read_all()) == 1 + + +def test_experience_store_search_filters_by_outcome(tmp_path: Path) -> None: + store = ExperienceStore(tmp_path) + store.append(ExperienceRecord( + timestamp="2026-05-07T10:00:00+00:00", + task_type="train", + summary="demo -> submitted", + outcome="submitted", + dataset="demo", + policy="act", + provider="local", + )) + store.append(ExperienceRecord( + timestamp="2026-05-07T11:00:00+00:00", + task_type="train", + summary="demo -> success", + outcome="success", + dataset="demo", + policy="act", + provider="local", + )) + + results = store.search( + task_type="train", + dataset="demo", + policy="act", + outcomes=frozenset({"success", "failed", "error", "stopped"}), + provider="local", + ) + assert all(r.outcome != "submitted" for r in results) + assert any(r.outcome == "success" for r in results) diff --git a/tests/test_policy_registry.py b/tests/test_policy_registry.py new file mode 100644 index 00000000..257eebdd --- /dev/null +++ b/tests/test_policy_registry.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from dataclasses import dataclass, field + +import pytest + +from roboclaw.embodied.policy import BasePolicyConfig, PolicyRegistry, policy_registry + + +def test_policy_registry_registers_custom_config() -> None: + registry = PolicyRegistry() + + @registry.register + @dataclass(frozen=True) + class ExamplePolicyConfig(BasePolicyConfig): + policy_type: str = field(init=False, default="example") + + def extra_train_args(self) -> list[str]: + return ["--policy.example=true"] + + config = registry.get("example") + + assert isinstance(config, ExamplePolicyConfig) + assert config.extra_train_args() == ["--policy.example=true"] + assert registry.supported_types() == {"example"} + + +def test_policy_registry_returns_registered_builtin_policy() -> None: + config = policy_registry.get("act") + + assert config.policy_type == "act" + assert config.extra_train_args() == [] + assert "groot" in policy_registry.supported_types() + + +def test_policy_registry_raises_for_unknown_policy() -> None: + with pytest.raises(ValueError, match="Unsupported policy_type 'unknown'"): + policy_registry.get("unknown") diff --git a/tests/test_train_experience.py b/tests/test_train_experience.py new file mode 100644 index 00000000..e56f71e2 --- /dev/null +++ b/tests/test_train_experience.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +import asyncio +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import AsyncMock, patch + +import pytest + +from roboclaw.agent.experience import ExperienceStore +from roboclaw.embodied.board import Board +from roboclaw.embodied.embodiment.hardware.monitor import HardwareMonitor +from roboclaw.embodied.embodiment.manifest import Manifest +from roboclaw.embodied.service import EmbodiedService + + +@pytest.fixture(autouse=True) +def isolated_roboclaw_home(tmp_path): + with patch( + "roboclaw.embodied.embodiment.lock.get_roboclaw_home", + return_value=tmp_path, + ), patch( + "roboclaw.embodied.embodiment.manifest.helpers.get_roboclaw_home", + return_value=tmp_path, + ): + yield + + +@pytest.fixture() +def embodied_service(tmp_path: Path) -> EmbodiedService: + board = Board() + manifest = Manifest(path=tmp_path / "workspace" / "embodied" / "manifest.json", board=board) + monitor = HardwareMonitor(board=board, manifest=manifest) + return EmbodiedService(hardware_monitor=monitor, board=board, manifest=manifest) + + +def test_train_session_records_experience_and_reuses_it_as_hint(embodied_service: EmbodiedService) -> None: + dataset = SimpleNamespace( + name="demo", + runtime=SimpleNamespace(name="demo", repo_id="local/demo", local_path=Path("/tmp/demo")), + ) + embodied_service.datasets.resolve_runtime_dataset = lambda name: dataset # type: ignore[method-assign] + + with patch( + "roboclaw.embodied.service.session.train.CommandBuilder.train", + return_value=["python3", "-m", "lerobot.scripts.lerobot_train"], + ), patch( + "roboclaw.embodied.executor.SubprocessExecutor.run_detached", + new=AsyncMock(side_effect=["job-1", "job-2"]), + ), patch( + "roboclaw.embodied.executor.SubprocessExecutor.job_status", + new=AsyncMock(return_value={ + "job_id": "job-1", + "status": "finished", + "running": False, + "pid": 123, + "log_path": "/tmp/job-1.log", + "log_tail": "training complete", + }), + ): + first = asyncio.run(embodied_service.train.start_job_state( + embodied_service.manifest, + {"dataset_name": "demo", "policy_type": "act", "steps": 1000, "device": "cuda"}, + )) + assert first["experience_hint"] == "" + + finished = asyncio.run(embodied_service.train.job_status_state("job-1")) + assert finished["status"] == "finished" + + second = asyncio.run(embodied_service.train.start_job_state( + embodied_service.manifest, + {"dataset_name": "demo", "policy_type": "act", "steps": 1000, "device": "cuda"}, + )) + + assert "success" in str(second["experience_hint"]) + assert "demo" in str(second["message"]) + + store = ExperienceStore(embodied_service.manifest._path.parent.parent) + records = store.read_all() + outcomes = {record.outcome for record in records} + assert "submitted" in outcomes + assert "success" in outcomes diff --git a/tests/verification/test_inference_config_verifier.py b/tests/verification/test_inference_config_verifier.py new file mode 100644 index 00000000..4f5d3169 --- /dev/null +++ b/tests/verification/test_inference_config_verifier.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +import json +import logging +from pathlib import Path + +import pytest + +from roboclaw.embodied.service.verification import InferenceConfigVerifier + + +def test_inference_config_verifier_rejects_missing_checkpoint(tmp_path: Path) -> None: + verifier = InferenceConfigVerifier() + + with pytest.raises(ValueError, match="Checkpoint path does not exist"): + verifier.verify( + checkpoint_path=str(tmp_path / "missing"), + manifest_snapshot=_manifest_snapshot(), + dataset_local_path=str(tmp_path / "dataset"), + ) + + +def test_inference_config_verifier_rejects_action_dim_mismatch(tmp_path: Path) -> None: + checkpoint = _checkpoint(tmp_path / "policy", action_dim=7) + verifier = InferenceConfigVerifier() + + with pytest.raises(ValueError, match="action_dim"): + verifier.verify( + checkpoint_path=str(checkpoint), + manifest_snapshot=_manifest_snapshot(), + dataset_local_path=str(_dataset(tmp_path / "dataset", version="v2.1", repo_id="local/demo")), + ) + + +def test_inference_config_verifier_warns_on_old_dataset_version(tmp_path: Path, caplog: pytest.LogCaptureFixture) -> None: + checkpoint = _checkpoint(tmp_path / "policy", action_dim=6) + dataset = _dataset(tmp_path / "dataset", version="v2.0", repo_id="local/demo") + verifier = InferenceConfigVerifier() + + with caplog.at_level(logging.WARNING): + verifier.verify( + checkpoint_path=str(checkpoint), + manifest_snapshot=_manifest_snapshot(), + dataset_local_path=str(dataset), + ) + + assert "codebase_version=v2.0 is older than v2.1" in caplog.text + + +def test_inference_config_verifier_accepts_consistent_config(tmp_path: Path) -> None: + checkpoint = _checkpoint(tmp_path / "policy", action_dim=6) + dataset = _dataset(tmp_path / "dataset", version="v2.1", repo_id="local/demo") + verifier = InferenceConfigVerifier() + + verifier.verify( + checkpoint_path=str(checkpoint), + manifest_snapshot=_manifest_snapshot(), + dataset_local_path=str(dataset), + ) + + +def _checkpoint(path: Path, *, action_dim: int) -> Path: + pretrained = path / "pretrained_model" + pretrained.mkdir(parents=True) + (pretrained / "config.json").write_text(json.dumps({ + "action_dim": action_dim, + }), encoding="utf-8") + (pretrained / "train_config.json").write_text(json.dumps({ + "policy": {"device": "cuda"}, + "dataset": {"repo_id": "local/demo"}, + }), encoding="utf-8") + return path + + +def _dataset(path: Path, *, version: str, repo_id: str) -> Path: + info_dir = path / "meta" + info_dir.mkdir(parents=True) + (info_dir / "info.json").write_text(json.dumps({ + "codebase_version": version, + "source_dataset": repo_id, + }), encoding="utf-8") + return path + + +def _manifest_snapshot() -> dict[str, object]: + return { + "device": "cuda", + "arms": [ + {"alias": "follower", "type": "so101_follower"}, + ], + "cameras": [], + "hands": [], + "datasets": {}, + "policies": {}, + } diff --git a/ui/src/domains/training/pages/TrainingCenterPage.tsx b/ui/src/domains/training/pages/TrainingCenterPage.tsx index d9930fb2..18af5e18 100644 --- a/ui/src/domains/training/pages/TrainingCenterPage.tsx +++ b/ui/src/domains/training/pages/TrainingCenterPage.tsx @@ -48,6 +48,7 @@ export default function TrainingCenterPage() { const [policyType, setPolicyType] = useState('act') const [trainSteps, setTrainSteps] = useState(100000) const [trainDevice, setTrainDevice] = useState('cuda') + const [continualLearning, setContinualLearning] = useState(false) const [pullPolicyRepo, setPullPolicyRepo] = useState('') useEffect(() => { @@ -109,6 +110,15 @@ export default function TrainingCenterPage() { +