From d032d24d87f20892850623b2dad3f18f46275108 Mon Sep 17 00:00:00 2001 From: Xiaofang Wu <3642115339@qq.com> Date: Wed, 6 May 2026 11:17:02 +0800 Subject: [PATCH 1/6] Refactor training policy registry --- roboclaw/embodied/command/builder.py | 35 ++++++--------- roboclaw/embodied/policy/__init__.py | 17 ++++++++ roboclaw/embodied/policy/act.py | 36 +++++++++++++++ roboclaw/embodied/policy/base.py | 18 ++++++++ roboclaw/embodied/policy/diffusion.py | 63 +++++++++++++++++++++++++++ roboclaw/embodied/policy/gr00t.py | 20 +++++++++ roboclaw/embodied/policy/pi0.py | 36 +++++++++++++++ roboclaw/embodied/policy/registry.py | 49 +++++++++++++++++++++ roboclaw/embodied/policy/smolvla.py | 27 ++++++++++++ tests/test_policy_registry.py | 38 ++++++++++++++++ 10 files changed, 317 insertions(+), 22 deletions(-) create mode 100644 roboclaw/embodied/policy/__init__.py create mode 100644 roboclaw/embodied/policy/act.py create mode 100644 roboclaw/embodied/policy/base.py create mode 100644 roboclaw/embodied/policy/diffusion.py create mode 100644 roboclaw/embodied/policy/gr00t.py create mode 100644 roboclaw/embodied/policy/pi0.py create mode 100644 roboclaw/embodied/policy/registry.py create mode 100644 roboclaw/embodied/policy/smolvla.py create mode 100644 tests/test_policy_registry.py diff --git a/roboclaw/embodied/command/builder.py b/roboclaw/embodied/command/builder.py index a749e73b..885b6e22 100644 --- a/roboclaw/embodied/command/builder.py +++ b/roboclaw/embodied/command/builder.py @@ -18,6 +18,7 @@ ) from roboclaw.embodied.embodiment.arm.registry import get_model from roboclaw.embodied.embodiment.manifest.binding import ArmBinding, ArmRole, CameraBinding +from roboclaw.embodied.policy import policy_registry _BIMANUAL: dict[str, tuple[str, str]] = { "so101": ("bi_so_follower", "bi_so_leader"), @@ -28,23 +29,7 @@ _DEFAULT_REPLAY_ROOT = Path("~/.cache/huggingface/lerobot").expanduser() -TRAIN_POLICY_TYPES = { - "act", - "diffusion", - "groot", - "multi_task_dit", - "pi0", - "pi0_fast", - "pi05", - "reward_classifier", - "sac", - "sarm", - "smolvla", - "tdmpc", - "vqbet", - "wall_x", - "xvla", -} +TRAIN_POLICY_TYPES = policy_registry.supported_types() # ── Private helper functions ───────────────────────────────────────────── @@ -288,12 +273,17 @@ def train( device: str = "cuda", ) -> list[str]: """Build training argv (standalone lerobot-train, not through wrapper).""" - if policy_type not in TRAIN_POLICY_TYPES: - allowed = ", ".join(sorted(TRAIN_POLICY_TYPES)) - raise ActionError(f"Unsupported policy_type '{policy_type}'. Expected one of: {allowed}.") + try: + policy_config = policy_registry.get(policy_type) + except ValueError as exc: + raise ActionError(str(exc)) from exc policies_root = manifest.snapshot.get("policies", {}).get("root", "") - output_dir_name = dataset.name if policy_type == "act" else f"{dataset.name}_{policy_type}" + output_dir_name = ( + dataset.name + if policy_config.policy_type == "act" + else f"{dataset.name}_{policy_config.policy_type}" + ) output_dir = Path(policies_root).expanduser() / output_dir_name argv = [ @@ -301,13 +291,14 @@ def train( f"--dataset.repo_id={dataset.repo_id}", f"--dataset.root={dataset.local_path}", "--dataset.video_backend=pyav", - f"--policy.type={policy_type}", + f"--policy.type={policy_config.policy_type}", "--policy.push_to_hub=false", f"--policy.repo_id={dataset.repo_id}", f"--output_dir={output_dir}", f"--steps={steps}", f"--policy.device={device}", ] + argv.extend(policy_config.extra_train_args()) # Resume if a previous checkpoint exists if output_dir.is_dir(): diff --git a/roboclaw/embodied/policy/__init__.py b/roboclaw/embodied/policy/__init__.py new file mode 100644 index 00000000..6e81644f --- /dev/null +++ b/roboclaw/embodied/policy/__init__.py @@ -0,0 +1,17 @@ +"""Policy config registry for embodied training.""" + +from roboclaw.embodied.policy.base import BasePolicyConfig +from roboclaw.embodied.policy.registry import PolicyRegistry, policy_registry + +from . import act as _act +from . import diffusion as _diffusion +from . import gr00t as _gr00t +from . import pi0 as _pi0 +from . import smolvla as _smolvla + +__all__ = [ + "BasePolicyConfig", + "PolicyRegistry", + "policy_registry", +] + diff --git a/roboclaw/embodied/policy/act.py b/roboclaw/embodied/policy/act.py new file mode 100644 index 00000000..c81314d8 --- /dev/null +++ b/roboclaw/embodied/policy/act.py @@ -0,0 +1,36 @@ +"""ACT-family policy training configs.""" + +from __future__ import annotations + +from dataclasses import dataclass, field + +from roboclaw.embodied.policy.base import BasePolicyConfig +from roboclaw.embodied.policy.registry import policy_registry + + +@policy_registry.register +@dataclass(frozen=True) +class ActPolicyConfig(BasePolicyConfig): + policy_type: str = field(init=False, default="act") + + def extra_train_args(self) -> list[str]: + return [] + + +@policy_registry.register +@dataclass(frozen=True) +class MultiTaskDiTPolicyConfig(BasePolicyConfig): + policy_type: str = field(init=False, default="multi_task_dit") + + def extra_train_args(self) -> list[str]: + return [] + + +@policy_registry.register +@dataclass(frozen=True) +class VQBeTPolicyConfig(BasePolicyConfig): + policy_type: str = field(init=False, default="vqbet") + + def extra_train_args(self) -> list[str]: + return [] + diff --git a/roboclaw/embodied/policy/base.py b/roboclaw/embodied/policy/base.py new file mode 100644 index 00000000..cdcdd0dc --- /dev/null +++ b/roboclaw/embodied/policy/base.py @@ -0,0 +1,18 @@ +"""Base config types for trainable LeRobot policies.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field + + +@dataclass(frozen=True) +class BasePolicyConfig(ABC): + """Train-time policy config registered against a LeRobot policy type.""" + + policy_type: str = field(init=False, default="") + + @abstractmethod + def extra_train_args(self) -> list[str]: + """Return policy-specific ``lerobot-train`` CLI args.""" + diff --git a/roboclaw/embodied/policy/diffusion.py b/roboclaw/embodied/policy/diffusion.py new file mode 100644 index 00000000..2a4bb7b8 --- /dev/null +++ b/roboclaw/embodied/policy/diffusion.py @@ -0,0 +1,63 @@ +"""Diffusion, RL, and auxiliary policy training configs.""" + +from __future__ import annotations + +from dataclasses import dataclass, field + +from roboclaw.embodied.policy.base import BasePolicyConfig +from roboclaw.embodied.policy.registry import policy_registry + + +@policy_registry.register +@dataclass(frozen=True) +class DiffusionPolicyConfig(BasePolicyConfig): + policy_type: str = field(init=False, default="diffusion") + + def extra_train_args(self) -> list[str]: + return [] + + +@policy_registry.register +@dataclass(frozen=True) +class TDMPCPolicyConfig(BasePolicyConfig): + policy_type: str = field(init=False, default="tdmpc") + + def extra_train_args(self) -> list[str]: + return [] + + +@policy_registry.register +@dataclass(frozen=True) +class SACPolicyConfig(BasePolicyConfig): + policy_type: str = field(init=False, default="sac") + + def extra_train_args(self) -> list[str]: + return [] + + +@policy_registry.register +@dataclass(frozen=True) +class RewardClassifierPolicyConfig(BasePolicyConfig): + policy_type: str = field(init=False, default="reward_classifier") + + def extra_train_args(self) -> list[str]: + return [] + + +@policy_registry.register +@dataclass(frozen=True) +class SARMPolicyConfig(BasePolicyConfig): + policy_type: str = field(init=False, default="sarm") + + def extra_train_args(self) -> list[str]: + return [] + + +@policy_registry.register +@dataclass(frozen=True) +class WallXPolicyConfig(BasePolicyConfig): + policy_type: str = field(init=False, default="wall_x") + + def extra_train_args(self) -> list[str]: + return [] + diff --git a/roboclaw/embodied/policy/gr00t.py b/roboclaw/embodied/policy/gr00t.py new file mode 100644 index 00000000..fc41d1ac --- /dev/null +++ b/roboclaw/embodied/policy/gr00t.py @@ -0,0 +1,20 @@ +"""GR00T-family policy training configs.""" + +from __future__ import annotations + +from dataclasses import dataclass, field + +from roboclaw.embodied.policy.base import BasePolicyConfig +from roboclaw.embodied.policy.registry import policy_registry + + +@policy_registry.register +@dataclass(frozen=True) +class GR00TPolicyConfig(BasePolicyConfig): + """NVIDIA GR00T N1.""" + + policy_type: str = field(init=False, default="groot") + + def extra_train_args(self) -> list[str]: + return [] + diff --git a/roboclaw/embodied/policy/pi0.py b/roboclaw/embodied/policy/pi0.py new file mode 100644 index 00000000..de9a883f --- /dev/null +++ b/roboclaw/embodied/policy/pi0.py @@ -0,0 +1,36 @@ +"""PI-family policy training configs.""" + +from __future__ import annotations + +from dataclasses import dataclass, field + +from roboclaw.embodied.policy.base import BasePolicyConfig +from roboclaw.embodied.policy.registry import policy_registry + + +@policy_registry.register +@dataclass(frozen=True) +class Pi0PolicyConfig(BasePolicyConfig): + policy_type: str = field(init=False, default="pi0") + + def extra_train_args(self) -> list[str]: + return [] + + +@policy_registry.register +@dataclass(frozen=True) +class Pi0FastPolicyConfig(BasePolicyConfig): + policy_type: str = field(init=False, default="pi0_fast") + + def extra_train_args(self) -> list[str]: + return [] + + +@policy_registry.register +@dataclass(frozen=True) +class Pi05PolicyConfig(BasePolicyConfig): + policy_type: str = field(init=False, default="pi05") + + def extra_train_args(self) -> list[str]: + return [] + diff --git a/roboclaw/embodied/policy/registry.py b/roboclaw/embodied/policy/registry.py new file mode 100644 index 00000000..506373ac --- /dev/null +++ b/roboclaw/embodied/policy/registry.py @@ -0,0 +1,49 @@ +"""Registry for trainable LeRobot policy configs.""" + +from __future__ import annotations + +from dataclasses import is_dataclass + +from roboclaw.embodied.policy.base import BasePolicyConfig + + +class PolicyRegistry: + """Register and resolve train-time policy configs by policy type.""" + + def __init__(self) -> None: + self._config_types: dict[str, type[BasePolicyConfig]] = {} + + def register(self, config_cls: type[BasePolicyConfig]) -> type[BasePolicyConfig]: + """Register a config dataclass and return it for decorator usage.""" + if not issubclass(config_cls, BasePolicyConfig): + raise TypeError("Policy config must inherit from BasePolicyConfig.") + if not is_dataclass(config_cls): + raise TypeError("Policy config must be a dataclass.") + + config = config_cls() + policy_type = config.policy_type.strip() + if not policy_type: + raise ValueError("Policy config must declare a non-empty policy_type.") + if policy_type in self._config_types: + raise ValueError(f"Policy '{policy_type}' is already registered.") + + self._config_types[policy_type] = config_cls + return config_cls + + def get(self, policy_type: str) -> BasePolicyConfig: + """Instantiate the config registered for ``policy_type``.""" + config_cls = self._config_types.get(policy_type) + if config_cls is None: + allowed = ", ".join(sorted(self.supported_types())) + raise ValueError( + f"Unsupported policy_type '{policy_type}'. Expected one of: {allowed}." + ) + return config_cls() + + def supported_types(self) -> set[str]: + """Return a copy of all registered policy types.""" + return set(self._config_types) + + +policy_registry = PolicyRegistry() + diff --git a/roboclaw/embodied/policy/smolvla.py b/roboclaw/embodied/policy/smolvla.py new file mode 100644 index 00000000..c6b6a89d --- /dev/null +++ b/roboclaw/embodied/policy/smolvla.py @@ -0,0 +1,27 @@ +"""VLA policy training configs.""" + +from __future__ import annotations + +from dataclasses import dataclass, field + +from roboclaw.embodied.policy.base import BasePolicyConfig +from roboclaw.embodied.policy.registry import policy_registry + + +@policy_registry.register +@dataclass(frozen=True) +class SmolVLAConfig(BasePolicyConfig): + policy_type: str = field(init=False, default="smolvla") + + def extra_train_args(self) -> list[str]: + return [] + + +@policy_registry.register +@dataclass(frozen=True) +class XVLAPolicyConfig(BasePolicyConfig): + policy_type: str = field(init=False, default="xvla") + + def extra_train_args(self) -> list[str]: + return [] + diff --git a/tests/test_policy_registry.py b/tests/test_policy_registry.py new file mode 100644 index 00000000..257eebdd --- /dev/null +++ b/tests/test_policy_registry.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from dataclasses import dataclass, field + +import pytest + +from roboclaw.embodied.policy import BasePolicyConfig, PolicyRegistry, policy_registry + + +def test_policy_registry_registers_custom_config() -> None: + registry = PolicyRegistry() + + @registry.register + @dataclass(frozen=True) + class ExamplePolicyConfig(BasePolicyConfig): + policy_type: str = field(init=False, default="example") + + def extra_train_args(self) -> list[str]: + return ["--policy.example=true"] + + config = registry.get("example") + + assert isinstance(config, ExamplePolicyConfig) + assert config.extra_train_args() == ["--policy.example=true"] + assert registry.supported_types() == {"example"} + + +def test_policy_registry_returns_registered_builtin_policy() -> None: + config = policy_registry.get("act") + + assert config.policy_type == "act" + assert config.extra_train_args() == [] + assert "groot" in policy_registry.supported_types() + + +def test_policy_registry_raises_for_unknown_policy() -> None: + with pytest.raises(ValueError, match="Unsupported policy_type 'unknown'"): + policy_registry.get("unknown") From 1ec1b006b3a8333db11363306a79635e99175aaa Mon Sep 17 00:00:00 2001 From: Xiaofang Wu <3642115339@qq.com> Date: Thu, 7 May 2026 16:09:06 +0800 Subject: [PATCH 2/6] feat: add experience-guided training hints --- roboclaw/agent/__init__.py | 3 +- roboclaw/agent/context.py | 15 +- roboclaw/agent/experience.py | 230 +++++++++++++++++++++ roboclaw/embodied/service/session/train.py | 214 +++++++++++++++++-- roboclaw/http/routes/train.py | 26 +-- tests/test_context_prompt_cache.py | 29 +++ tests/test_train_experience.py | 82 ++++++++ 7 files changed, 558 insertions(+), 41 deletions(-) create mode 100644 roboclaw/agent/experience.py create mode 100644 tests/test_train_experience.py diff --git a/roboclaw/agent/__init__.py b/roboclaw/agent/__init__.py index c4c3bacb..d46e966d 100644 --- a/roboclaw/agent/__init__.py +++ b/roboclaw/agent/__init__.py @@ -1,8 +1,9 @@ """Agent core module.""" from roboclaw.agent.context import ContextBuilder +from roboclaw.agent.experience import ExperienceStore from roboclaw.agent.loop import AgentLoop from roboclaw.agent.memory import MemoryStore from roboclaw.agent.skills import SkillsLoader -__all__ = ["AgentLoop", "ContextBuilder", "MemoryStore", "SkillsLoader"] +__all__ = ["AgentLoop", "ContextBuilder", "ExperienceStore", "MemoryStore", "SkillsLoader"] diff --git a/roboclaw/agent/context.py b/roboclaw/agent/context.py index 6ad2d0f6..1fef62d4 100644 --- a/roboclaw/agent/context.py +++ b/roboclaw/agent/context.py @@ -6,6 +6,7 @@ from pathlib import Path from typing import Any +from roboclaw.agent.experience import ExperienceStore from roboclaw.utils.helpers import current_time_str from roboclaw.agent.memory import MemoryStore @@ -22,9 +23,15 @@ class ContextBuilder: def __init__(self, workspace: Path): self.workspace = workspace self.memory = MemoryStore(workspace) + self.experiences = ExperienceStore(workspace) self.skills = SkillsLoader(workspace) - def build_system_prompt(self, skill_names: list[str] | None = None) -> str: + def build_system_prompt( + self, + skill_names: list[str] | None = None, + *, + current_message: str = "", + ) -> str: """Build the system prompt from identity, bootstrap files, memory, and skills.""" parts = [self._get_identity()] @@ -36,6 +43,10 @@ def build_system_prompt(self, skill_names: list[str] | None = None) -> str: if memory: parts.append(f"# Memory\n\n{memory}") + experience_context = self.experiences.build_context(query=current_message) + if experience_context: + parts.append(f"# Relevant Experience\n\n{experience_context}") + always_skills = self.skills.get_always_skills() if always_skills: always_content = self.skills.load_skills_for_context(always_skills) @@ -138,7 +149,7 @@ def build_messages( merged = [{"type": "text", "text": runtime_ctx}] + user_content return [ - {"role": "system", "content": self.build_system_prompt(skill_names)}, + {"role": "system", "content": self.build_system_prompt(skill_names, current_message=current_message)}, *history, {"role": "user", "content": merged}, ] diff --git a/roboclaw/agent/experience.py b/roboclaw/agent/experience.py new file mode 100644 index 00000000..c03f4e48 --- /dev/null +++ b/roboclaw/agent/experience.py @@ -0,0 +1,230 @@ +"""Structured experience storage and retrieval for agent adaptation.""" + +from __future__ import annotations + +from dataclasses import asdict, dataclass +from datetime import UTC, datetime +import json +from pathlib import Path +import re +from typing import Any + +from roboclaw.utils.helpers import ensure_dir + +_TOKEN_RE = re.compile(r"[a-z0-9][a-z0-9_./:-]*") + + +def _normalize_text(value: str | None) -> str: + return (value or "").strip() + + +def _normalize_key(value: str | None) -> str: + return _normalize_text(value).lower() + + +def _tokenize(*values: str) -> set[str]: + tokens: set[str] = set() + for value in values: + for match in _TOKEN_RE.finditer(value.lower()): + token = match.group(0) + if len(token) >= 2: + tokens.add(token) + return tokens + + +@dataclass(frozen=True) +class ExperienceRecord: + timestamp: str + task_type: str + summary: str + outcome: str + lesson: str = "" + dataset: str = "" + policy: str = "" + provider: str = "" + job_id: str = "" + source: str = "" + error: str = "" + checkpoint_path: str = "" + dataset_path: str = "" + task_name: str = "" + + @classmethod + def create( + cls, + *, + task_type: str, + summary: str, + outcome: str, + lesson: str = "", + dataset: str = "", + policy: str = "", + provider: str = "", + job_id: str = "", + source: str = "", + error: str = "", + checkpoint_path: str = "", + dataset_path: str = "", + task_name: str = "", + ) -> "ExperienceRecord": + return cls( + timestamp=datetime.now(UTC).isoformat(), + task_type=_normalize_text(task_type), + summary=_normalize_text(summary), + outcome=_normalize_text(outcome), + lesson=_normalize_text(lesson), + dataset=_normalize_text(dataset), + policy=_normalize_text(policy), + provider=_normalize_text(provider), + job_id=_normalize_text(job_id), + source=_normalize_text(source), + error=_normalize_text(error), + checkpoint_path=_normalize_text(checkpoint_path), + dataset_path=_normalize_text(dataset_path), + task_name=_normalize_text(task_name), + ) + + @classmethod + def from_dict(cls, payload: dict[str, Any]) -> "ExperienceRecord": + return cls( + timestamp=_normalize_text(str(payload.get("timestamp") or "")), + task_type=_normalize_text(str(payload.get("task_type") or "")), + summary=_normalize_text(str(payload.get("summary") or "")), + outcome=_normalize_text(str(payload.get("outcome") or "")), + lesson=_normalize_text(str(payload.get("lesson") or "")), + dataset=_normalize_text(str(payload.get("dataset") or "")), + policy=_normalize_text(str(payload.get("policy") or "")), + provider=_normalize_text(str(payload.get("provider") or "")), + job_id=_normalize_text(str(payload.get("job_id") or "")), + source=_normalize_text(str(payload.get("source") or "")), + error=_normalize_text(str(payload.get("error") or "")), + checkpoint_path=_normalize_text(str(payload.get("checkpoint_path") or "")), + dataset_path=_normalize_text(str(payload.get("dataset_path") or "")), + task_name=_normalize_text(str(payload.get("task_name") or "")), + ) + + def to_dict(self) -> dict[str, Any]: + return asdict(self) + + def fingerprint(self) -> str: + parts = ( + _normalize_key(self.task_type), + _normalize_key(self.summary), + _normalize_key(self.outcome), + _normalize_key(self.dataset), + _normalize_key(self.policy), + _normalize_key(self.provider), + _normalize_key(self.job_id), + _normalize_key(self.error), + _normalize_key(self.lesson), + ) + return "|".join(parts) + + +class ExperienceStore: + """Append-only JSONL store plus lightweight experience retrieval.""" + + def __init__(self, workspace: Path): + self.memory_dir = ensure_dir(workspace / "memory") + self.experience_file = self.memory_dir / "EXPERIENCES.jsonl" + + def read_all(self) -> list[ExperienceRecord]: + if not self.experience_file.exists(): + return [] + records: list[ExperienceRecord] = [] + for raw_line in self.experience_file.read_text(encoding="utf-8").splitlines(): + line = raw_line.strip() + if not line: + continue + try: + payload = json.loads(line) + except json.JSONDecodeError: + continue + if isinstance(payload, dict): + records.append(ExperienceRecord.from_dict(payload)) + return records + + def append(self, record: ExperienceRecord) -> bool: + records = self.read_all() + fingerprint = record.fingerprint() + if any(existing.fingerprint() == fingerprint for existing in records): + return False + with self.experience_file.open("a", encoding="utf-8") as handle: + handle.write(json.dumps(record.to_dict(), ensure_ascii=False) + "\n") + return True + + def search( + self, + *, + query: str = "", + task_type: str = "", + dataset: str = "", + policy: str = "", + provider: str = "", + limit: int = 3, + ) -> list[ExperienceRecord]: + query_tokens = _tokenize(query, dataset, policy, provider, task_type) + scored: list[tuple[int, ExperienceRecord]] = [] + for record in self.read_all(): + score = 0 + if task_type and _normalize_key(record.task_type) == _normalize_key(task_type): + score += 6 + if dataset and _normalize_key(record.dataset) == _normalize_key(dataset): + score += 8 + if policy and _normalize_key(record.policy) == _normalize_key(policy): + score += 6 + if provider and _normalize_key(record.provider) == _normalize_key(provider): + score += 4 + + record_tokens = _tokenize( + record.summary, + record.lesson, + record.dataset, + record.policy, + record.provider, + record.task_name, + record.error, + ) + score += len(query_tokens & record_tokens) + if score <= 0: + continue + scored.append((score, record)) + + scored.sort(key=lambda item: (item[0], item[1].timestamp), reverse=True) + return [record for _, record in scored[:limit]] + + def build_context( + self, + *, + query: str = "", + task_type: str = "", + dataset: str = "", + policy: str = "", + provider: str = "", + limit: int = 3, + ) -> str: + records = self.search( + query=query, + task_type=task_type, + dataset=dataset, + policy=policy, + provider=provider, + limit=limit, + ) + if not records: + return "" + + lines = [ + "Use these past outcomes as hints. Reuse what worked and avoid repeating failures." + ] + for record in records: + fields = [record.outcome] + if record.dataset: + fields.append(f"dataset={record.dataset}") + if record.policy: + fields.append(f"policy={record.policy}") + if record.provider: + fields.append(f"provider={record.provider}") + summary = record.lesson or record.summary + lines.append(f"- [{record.timestamp[:19]}] {'; '.join(fields)} -> {summary}") + return "\n".join(lines) diff --git a/roboclaw/embodied/service/session/train.py b/roboclaw/embodied/service/session/train.py index 3c1cffe3..7ceb3ecb 100644 --- a/roboclaw/embodied/service/session/train.py +++ b/roboclaw/embodied/service/session/train.py @@ -8,6 +8,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any +from roboclaw.agent.experience import ExperienceRecord, ExperienceStore from roboclaw.embodied.command import CommandBuilder, logs_dir if TYPE_CHECKING: @@ -23,60 +24,117 @@ class TrainSession: def __init__(self, parent: EmbodiedService) -> None: self._parent = parent + self._experiences = ExperienceStore(parent.manifest._path.parent) + self._job_specs: dict[str, dict[str, str]] = {} - async def train( + async def start_job_state( self, manifest: Manifest, kwargs: dict[str, Any], - tty_handoff: Any, - ) -> str: + ) -> dict[str, str | int | bool | None]: from roboclaw.embodied.executor import SubprocessExecutor - dataset_name = kwargs.get("dataset_name", "default") + dataset_name = str(kwargs.get("dataset_name", "default") or "default") + policy_type = str(kwargs.get("policy_type", "act") or "act") + steps = int(kwargs.get("steps", 100_000) or 100_000) + device = str(kwargs.get("device", "cuda") or "cuda") dataset = self._parent.datasets.resolve_runtime_dataset(dataset_name) + experience_hint = self._build_experience_hint(dataset_name=dataset_name, policy_type=policy_type) argv = CommandBuilder.train( manifest, dataset=dataset.runtime, - policy_type=kwargs.get("policy_type", "act"), - steps=kwargs.get("steps", 100_000), - device=kwargs.get("device", "cuda"), + policy_type=policy_type, + steps=steps, + device=device, ) job_id = await SubprocessExecutor().run_detached(argv=argv, log_dir=logs_dir()) - return f"Training started. Job ID: {job_id}" + self._job_specs[job_id] = { + "dataset_name": dataset_name, + "policy_type": policy_type, + "dataset_path": str(dataset.runtime.local_path), + "provider": "local", + } + state: dict[str, str | int | bool | None] = { + "job_id": job_id, + "status": "running", + "running": True, + "pid": None, + "log_path": str(SubprocessExecutor()._job_log_path(job_id, logs_dir())), + "log_tail": "", + "dataset_name": dataset_name, + "policy_type": policy_type, + "dataset_path": str(dataset.runtime.local_path), + "provider": "local", + "experience_hint": experience_hint, + } + state["message"] = self._format_status_message(state) + self._record_experience( + job_id=job_id, + status="submitted", + log_tail="", + log_path=str(state.get("log_path") or ""), + ) + return state - async def stop_job( + async def train( self, manifest: Manifest, kwargs: dict[str, Any], tty_handoff: Any, ) -> str: + state = await self.start_job_state(manifest, kwargs) + return str(state["message"]) + + async def stop_job_state(self, job_id: str) -> dict[str, str | int | bool | None]: from roboclaw.embodied.executor import SubprocessExecutor - job_id = kwargs.get("job_id", "") status = await SubprocessExecutor().stop_job(job_id=job_id, log_dir=logs_dir()) - return "\n".join(f"{key}: {value}" for key, value in status.items()) + return self._enrich_state(job_id, status) - async def job_status( + async def stop_job( self, manifest: Manifest, kwargs: dict[str, Any], tty_handoff: Any, ) -> str: + job_id = kwargs.get("job_id", "") + status = await self.stop_job_state(str(job_id)) + return str(status["message"]) + + async def job_status_state(self, job_id: str) -> dict[str, str | int | bool | None]: from roboclaw.embodied.executor import SubprocessExecutor - job_id = kwargs.get("job_id", "") status = await SubprocessExecutor().job_status(job_id=job_id, log_dir=logs_dir()) - return "\n".join(f"{key}: {value}" for key, value in status.items()) + return self._enrich_state(job_id, status) - async def current_job( + async def job_status( self, manifest: Manifest, kwargs: dict[str, Any], tty_handoff: Any, - ) -> dict[str, str | int | bool | None]: + ) -> str: + job_id = kwargs.get("job_id", "") + status = await self.job_status_state(str(job_id)) + return str(status["message"]) + + async def current_job_state(self) -> dict[str, str | int | bool | None]: from roboclaw.embodied.executor import SubprocessExecutor - return await SubprocessExecutor().latest_running_job(log_dir=logs_dir()) + status = await SubprocessExecutor().latest_running_job(log_dir=logs_dir()) + job_id = str(status.get("job_id") or "") + if not job_id: + enriched = self._enrich_state("", status) + enriched["message"] = self._format_status_message(enriched) + return enriched + return self._enrich_state(job_id, status) + + async def current_job( + self, + manifest: Manifest, + kwargs: dict[str, Any], + tty_handoff: Any, + ) -> dict[str, str | int | bool | None]: + return await self.current_job_state() def curve_data(self, job_id: str) -> dict[str, Any]: job_id = job_id.strip() @@ -128,6 +186,109 @@ def list_policies(self, manifest: Manifest | None = None) -> str: return "No policies found." return json.dumps(policies, indent=2, ensure_ascii=False) + def _build_experience_hint(self, *, dataset_name: str, policy_type: str) -> str: + records = self._experiences.search( + task_type="train", + dataset=dataset_name, + policy=policy_type, + provider="local", + limit=2, + ) + if not records: + return "" + lines = [] + for record in records: + lesson = record.lesson or record.summary + lines.append(f"{record.outcome}: {lesson}") + return "\n".join(lines) + + def _enrich_state( + self, + job_id: str, + state: dict[str, str | int | bool | None], + ) -> dict[str, str | int | bool | None]: + enriched = dict(state) + metadata = self._job_specs.get(job_id, {}) + enriched.setdefault("job_id", job_id) + enriched.setdefault("provider", metadata.get("provider", "local")) + enriched.setdefault("dataset_name", metadata.get("dataset_name", "")) + enriched.setdefault("policy_type", metadata.get("policy_type", "")) + enriched.setdefault("dataset_path", metadata.get("dataset_path", "")) + enriched.setdefault("experience_hint", "") + enriched["message"] = self._format_status_message(enriched) + + status_text = str(enriched.get("status") or "").lower() + if status_text in _TERMINAL_TRAIN_STATUSES and job_id: + self._record_experience( + job_id=job_id, + status=status_text, + log_tail=str(enriched.get("log_tail") or ""), + log_path=str(enriched.get("log_path") or ""), + ) + return enriched + + def _format_status_message(self, state: dict[str, str | int | bool | None]) -> str: + order = ( + "job_id", + "status", + "running", + "pid", + "dataset_name", + "policy_type", + "provider", + "dataset_path", + "log_path", + "experience_hint", + ) + lines: list[str] = [] + seen: set[str] = set() + for key in order: + value = state.get(key) + if value in {None, ""}: + continue + seen.add(key) + lines.append(f"{key}: {value}") + for key, value in state.items(): + if key in seen or value in {None, ""}: + continue + lines.append(f"{key}: {value}") + return "\n".join(lines) + + def _record_experience( + self, + *, + job_id: str, + status: str, + log_tail: str, + log_path: str, + ) -> None: + metadata = self._job_specs.get(job_id, {}) + dataset_name = metadata.get("dataset_name", "") + policy_type = metadata.get("policy_type", "") + provider = metadata.get("provider", "local") + if not dataset_name and not policy_type and not job_id: + return + lesson = _status_lesson(status, log_tail) + summary = ( + f"Local training for dataset '{dataset_name or ''}' " + f"with policy '{policy_type or ''}' ended as {status}" + ) + self._experiences.append(ExperienceRecord.create( + task_type="train", + summary=summary, + outcome=status, + lesson=lesson, + dataset=dataset_name, + policy=policy_type, + provider=provider, + job_id=job_id, + source="train_session", + error=log_tail if status in {"failed", "error"} else "", + dataset_path=metadata.get("dataset_path", ""), + task_name=job_id, + checkpoint_path=log_path, + )) + def _scan_policies(root: Path) -> list[dict[str, Any]]: """Scan policy directories under *root* and return summary dicts.""" policies: list[dict[str, Any]] = [] @@ -157,6 +318,7 @@ def _enrich_policy_entry(entry: dict[str, Any], checkpoint_dir: Path) -> None: _JOB_ID_RE = re.compile(r"^[A-Za-z0-9-]+$") +_TERMINAL_TRAIN_STATUSES = {"finished", "stopped", "failed", "error", "missing", "idle"} _TRAIN_LOG_RE = re.compile( r"step:(?P\S+).*?" r"ep:(?P\d+).*?" @@ -246,3 +408,21 @@ def _parse_training_curve_line(line: str) -> dict[str, Any] | None: "epoch": epoch, "loss": loss, } + + +def _status_lesson(status: str, log_tail: str) -> str: + normalized = status.strip().lower() + if normalized == "submitted": + return "A similar run was submitted successfully." + if normalized == "finished": + return "A similar run finished previously." + if normalized == "stopped": + return "A similar run had to be stopped manually." + if normalized in {"failed", "error"}: + tail = log_tail.strip() + if tail: + return f"Recent failure signal: {tail.splitlines()[-1]}" + return "A similar run failed previously." + if normalized == "missing": + return "A similar run lost its local process metadata before completion." + return f"A similar run reached status '{status}'." diff --git a/roboclaw/http/routes/train.py b/roboclaw/http/routes/train.py index ed4c8085..5a946f17 100644 --- a/roboclaw/http/routes/train.py +++ b/roboclaw/http/routes/train.py @@ -26,7 +26,7 @@ def register_train_routes(app: FastAPI, service: EmbodiedService) -> None: @app.post("/api/train/start") async def train_start(body: TrainStartRequest) -> dict[str, Any]: - result = await service.train.train( + result = await service.train.start_job_state( manifest=service.manifest, kwargs={ "dataset_name": body.dataset_name, @@ -34,36 +34,20 @@ async def train_start(body: TrainStartRequest) -> dict[str, Any]: "steps": body.steps, "device": body.device, }, - tty_handoff=None, ) - job_id = result.rsplit("Job ID:", 1)[-1].strip() if "Job ID:" in result else "" - return {"message": result, "job_id": job_id} + return result @app.post("/api/train/stop") async def train_stop(body: TrainStopRequest) -> dict[str, Any]: - result = await service.train.stop_job( - manifest=service.manifest, - kwargs={"job_id": body.job_id}, - tty_handoff=None, - ) - return {"message": result} + return await service.train.stop_job_state(body.job_id) @app.get("/api/train/current") async def train_current() -> dict[str, Any]: - return await service.train.current_job( - manifest=service.manifest, - kwargs={}, - tty_handoff=None, - ) + return await service.train.current_job_state() @app.get("/api/train/status/{job_id}") async def train_status(job_id: str) -> dict[str, Any]: - result = await service.train.job_status( - manifest=service.manifest, - kwargs={"job_id": job_id}, - tty_handoff=None, - ) - return {"message": result} + return await service.train.job_status_state(job_id) @app.get("/api/train/curve/{job_id}") async def train_curve(job_id: str) -> dict[str, Any]: diff --git a/tests/test_context_prompt_cache.py b/tests/test_context_prompt_cache.py index cc238e83..8d677f59 100644 --- a/tests/test_context_prompt_cache.py +++ b/tests/test_context_prompt_cache.py @@ -8,6 +8,7 @@ import datetime as datetime_module from roboclaw.agent.context import ContextBuilder +from roboclaw.agent.experience import ExperienceRecord, ExperienceStore class _FakeDatetime(real_datetime): @@ -71,3 +72,31 @@ def test_runtime_context_is_separate_untrusted_user_message(tmp_path) -> None: assert "Channel: cli" in user_content assert "Chat ID: direct" in user_content assert "Return exactly: OK" in user_content + + +def test_system_prompt_includes_relevant_experience_for_similar_request(tmp_path) -> None: + workspace = _make_workspace(tmp_path) + store = ExperienceStore(workspace) + store.append(ExperienceRecord.create( + task_type="train", + summary="Local training for dataset 'demo' with policy 'act' ended as finished", + outcome="finished", + lesson="This dataset/policy pair finished successfully before.", + dataset="demo", + policy="act", + provider="local", + job_id="job-1", + source="test", + )) + + builder = ContextBuilder(workspace) + messages = builder.build_messages( + history=[], + current_message="Please train demo with act again", + ) + + system_prompt = messages[0]["content"] + assert isinstance(system_prompt, str) + assert "# Relevant Experience" in system_prompt + assert "dataset=demo" in system_prompt + assert "finished" in system_prompt diff --git a/tests/test_train_experience.py b/tests/test_train_experience.py new file mode 100644 index 00000000..b2ccd74c --- /dev/null +++ b/tests/test_train_experience.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +import asyncio +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import AsyncMock, patch + +import pytest + +from roboclaw.agent.experience import ExperienceStore +from roboclaw.embodied.board import Board +from roboclaw.embodied.embodiment.hardware.monitor import HardwareMonitor +from roboclaw.embodied.embodiment.manifest import Manifest +from roboclaw.embodied.service import EmbodiedService + + +@pytest.fixture(autouse=True) +def isolated_roboclaw_home(tmp_path): + with patch( + "roboclaw.embodied.embodiment.lock.get_roboclaw_home", + return_value=tmp_path, + ), patch( + "roboclaw.embodied.embodiment.manifest.helpers.get_roboclaw_home", + return_value=tmp_path, + ): + yield + + +@pytest.fixture() +def embodied_service(tmp_path: Path) -> EmbodiedService: + board = Board() + manifest = Manifest(path=tmp_path / "manifest.json", board=board) + monitor = HardwareMonitor(board=board, manifest=manifest) + return EmbodiedService(hardware_monitor=monitor, board=board, manifest=manifest) + + +def test_train_session_records_experience_and_reuses_it_as_hint(embodied_service: EmbodiedService) -> None: + dataset = SimpleNamespace( + name="demo", + runtime=SimpleNamespace(name="demo", repo_id="local/demo", local_path=Path("/tmp/demo")), + ) + embodied_service.datasets.resolve_runtime_dataset = lambda name: dataset # type: ignore[method-assign] + + with patch( + "roboclaw.embodied.service.session.train.CommandBuilder.train", + return_value=["python3", "-m", "lerobot.scripts.lerobot_train"], + ), patch( + "roboclaw.embodied.executor.SubprocessExecutor.run_detached", + new=AsyncMock(side_effect=["job-1", "job-2"]), + ), patch( + "roboclaw.embodied.executor.SubprocessExecutor.job_status", + new=AsyncMock(return_value={ + "job_id": "job-1", + "status": "finished", + "running": False, + "pid": 123, + "log_path": "/tmp/job-1.log", + "log_tail": "training complete", + }), + ): + first = asyncio.run(embodied_service.train.start_job_state( + embodied_service.manifest, + {"dataset_name": "demo", "policy_type": "act", "steps": 1000, "device": "cuda"}, + )) + assert first["experience_hint"] == "" + + finished = asyncio.run(embodied_service.train.job_status_state("job-1")) + assert finished["status"] == "finished" + + second = asyncio.run(embodied_service.train.start_job_state( + embodied_service.manifest, + {"dataset_name": "demo", "policy_type": "act", "steps": 1000, "device": "cuda"}, + )) + + assert "finished" in str(second["experience_hint"]) + assert "demo" in str(second["message"]) + + store = ExperienceStore(embodied_service.manifest._path.parent) + records = store.read_all() + outcomes = {record.outcome for record in records} + assert "submitted" in outcomes + assert "finished" in outcomes From 547ab7062e3d3c4b769614c42d7a054a14646521 Mon Sep 17 00:00:00 2001 From: Xiaofang Wu <3642115339@qq.com> Date: Thu, 7 May 2026 16:45:18 +0800 Subject: [PATCH 3/6] fix: align ExperienceStore path with agent workspace root --- roboclaw/embodied/service/session/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/roboclaw/embodied/service/session/train.py b/roboclaw/embodied/service/session/train.py index 7ceb3ecb..1439b0c1 100644 --- a/roboclaw/embodied/service/session/train.py +++ b/roboclaw/embodied/service/session/train.py @@ -24,7 +24,7 @@ class TrainSession: def __init__(self, parent: EmbodiedService) -> None: self._parent = parent - self._experiences = ExperienceStore(parent.manifest._path.parent) + self._experiences = ExperienceStore(parent.manifest._path.parent.parent) self._job_specs: dict[str, dict[str, str]] = {} async def start_job_state( From 8c698523f3873a038cfcec62bca81f61ed83eb6b Mon Sep 17 00:00:00 2001 From: Xiaofang Wu <3642115339@qq.com> Date: Thu, 7 May 2026 18:32:01 +0800 Subject: [PATCH 4/6] fix: isolate ExperienceStore path in tests, filter submitted outcomes from hints --- roboclaw/agent/experience.py | 36 +++ roboclaw/embodied/service/session/train.py | 140 ++++++++++- roboclaw/http/routes/train.py | 2 + tests/test_experience_replay.py | 233 ++++++++++++++++++ tests/test_train_experience.py | 8 +- .../training/pages/TrainingCenterPage.tsx | 11 + .../training/store/useTrainingStore.ts | 8 +- 7 files changed, 421 insertions(+), 17 deletions(-) create mode 100644 tests/test_experience_replay.py diff --git a/roboclaw/agent/experience.py b/roboclaw/agent/experience.py index c03f4e48..e03f28c5 100644 --- a/roboclaw/agent/experience.py +++ b/roboclaw/agent/experience.py @@ -40,6 +40,7 @@ class ExperienceRecord: outcome: str lesson: str = "" dataset: str = "" + replay_datasets: str = "" policy: str = "" provider: str = "" job_id: str = "" @@ -58,6 +59,7 @@ def create( outcome: str, lesson: str = "", dataset: str = "", + replay_datasets: str = "", policy: str = "", provider: str = "", job_id: str = "", @@ -74,6 +76,7 @@ def create( outcome=_normalize_text(outcome), lesson=_normalize_text(lesson), dataset=_normalize_text(dataset), + replay_datasets=_normalize_text(replay_datasets), policy=_normalize_text(policy), provider=_normalize_text(provider), job_id=_normalize_text(job_id), @@ -93,6 +96,7 @@ def from_dict(cls, payload: dict[str, Any]) -> "ExperienceRecord": outcome=_normalize_text(str(payload.get("outcome") or "")), lesson=_normalize_text(str(payload.get("lesson") or "")), dataset=_normalize_text(str(payload.get("dataset") or "")), + replay_datasets=_normalize_text(str(payload.get("replay_datasets") or "")), policy=_normalize_text(str(payload.get("policy") or "")), provider=_normalize_text(str(payload.get("provider") or "")), job_id=_normalize_text(str(payload.get("job_id") or "")), @@ -112,6 +116,7 @@ def fingerprint(self) -> str: _normalize_key(self.summary), _normalize_key(self.outcome), _normalize_key(self.dataset), + _normalize_key(self.replay_datasets), _normalize_key(self.policy), _normalize_key(self.provider), _normalize_key(self.job_id), @@ -161,11 +166,14 @@ def search( dataset: str = "", policy: str = "", provider: str = "", + outcomes: frozenset[str] | None = None, limit: int = 3, ) -> list[ExperienceRecord]: query_tokens = _tokenize(query, dataset, policy, provider, task_type) scored: list[tuple[int, ExperienceRecord]] = [] for record in self.read_all(): + if outcomes is not None and record.outcome not in outcomes: + continue score = 0 if task_type and _normalize_key(record.task_type) == _normalize_key(task_type): score += 6 @@ -221,6 +229,8 @@ def build_context( fields = [record.outcome] if record.dataset: fields.append(f"dataset={record.dataset}") + if record.replay_datasets: + fields.append(f"replay={record.replay_datasets}") if record.policy: fields.append(f"policy={record.policy}") if record.provider: @@ -228,3 +238,29 @@ def build_context( summary = record.lesson or record.summary lines.append(f"- [{record.timestamp[:19]}] {'; '.join(fields)} -> {summary}") return "\n".join(lines) + + def get_replay_datasets( + self, + current_dataset: str, + policy: str, + max_datasets: int = 3, + ) -> list[str]: + current_key = _normalize_key(current_dataset) + policy_key = _normalize_key(policy) + unique: list[str] = [] + seen: set[str] = set() + records = sorted(self.read_all(), key=lambda record: record.timestamp, reverse=True) + for record in records: + dataset_name = _normalize_text(record.dataset) + dataset_key = _normalize_key(dataset_name) + if record.task_type != "train" or record.outcome != "success": + continue + if not dataset_name or dataset_key == current_key or dataset_key in seen: + continue + if policy_key and _normalize_key(record.policy) != policy_key: + continue + seen.add(dataset_key) + unique.append(dataset_name) + if len(unique) >= max_datasets: + break + return unique diff --git a/roboclaw/embodied/service/session/train.py b/roboclaw/embodied/service/session/train.py index 1439b0c1..0cf6b1b2 100644 --- a/roboclaw/embodied/service/session/train.py +++ b/roboclaw/embodied/service/session/train.py @@ -2,13 +2,16 @@ from __future__ import annotations +import asyncio import json import re from collections import deque +from datetime import UTC, datetime from pathlib import Path from typing import TYPE_CHECKING, Any from roboclaw.agent.experience import ExperienceRecord, ExperienceStore +from roboclaw.data.datasets import DatasetRuntimeRef from roboclaw.embodied.command import CommandBuilder, logs_dir if TYPE_CHECKING: @@ -39,10 +42,21 @@ async def start_job_state( steps = int(kwargs.get("steps", 100_000) or 100_000) device = str(kwargs.get("device", "cuda") or "cuda") dataset = self._parent.datasets.resolve_runtime_dataset(dataset_name) - experience_hint = self._build_experience_hint(dataset_name=dataset_name, policy_type=policy_type) + continual_learning = bool(kwargs.get("continual_learning", False)) + training_runtime, replay_datasets = await self._resolve_training_runtime( + dataset_name=dataset_name, + policy_type=policy_type, + dataset=dataset.runtime, + continual_learning=continual_learning, + ) + experience_hint = self._build_experience_hint( + dataset_name=dataset_name, + policy_type=policy_type, + replay_datasets=replay_datasets, + ) argv = CommandBuilder.train( manifest, - dataset=dataset.runtime, + dataset=training_runtime, policy_type=policy_type, steps=steps, device=device, @@ -51,8 +65,9 @@ async def start_job_state( self._job_specs[job_id] = { "dataset_name": dataset_name, "policy_type": policy_type, - "dataset_path": str(dataset.runtime.local_path), + "dataset_path": str(training_runtime.local_path), "provider": "local", + "replay_datasets": ", ".join(replay_datasets), } state: dict[str, str | int | bool | None] = { "job_id": job_id, @@ -63,9 +78,10 @@ async def start_job_state( "log_tail": "", "dataset_name": dataset_name, "policy_type": policy_type, - "dataset_path": str(dataset.runtime.local_path), + "dataset_path": str(training_runtime.local_path), "provider": "local", "experience_hint": experience_hint, + "replay_datasets": ", ".join(replay_datasets), } state["message"] = self._format_status_message(state) self._record_experience( @@ -186,22 +202,100 @@ def list_policies(self, manifest: Manifest | None = None) -> str: return "No policies found." return json.dumps(policies, indent=2, ensure_ascii=False) - def _build_experience_hint(self, *, dataset_name: str, policy_type: str) -> str: + def _build_experience_hint( + self, + *, + dataset_name: str, + policy_type: str, + replay_datasets: list[str], + ) -> str: records = self._experiences.search( task_type="train", dataset=dataset_name, policy=policy_type, + outcomes=frozenset({"success", "failed", "error", "stopped"}), provider="local", limit=2, ) - if not records: - return "" - lines = [] - for record in records: - lesson = record.lesson or record.summary - lines.append(f"{record.outcome}: {lesson}") + lines = [ + f"{record.outcome}: {record.lesson or record.summary}" + for record in records + ] + if replay_datasets: + lines.append(f"continual replay mixed datasets: {', '.join(replay_datasets)}") return "\n".join(lines) + async def _resolve_training_runtime( + self, + *, + dataset_name: str, + policy_type: str, + dataset: DatasetRuntimeRef, + continual_learning: bool, + ) -> tuple[DatasetRuntimeRef, list[str]]: + if not continual_learning: + return dataset, [] + replay_datasets = self._available_replay_datasets( + current_dataset=dataset_name, + policy=policy_type, + ) + if not replay_datasets: + return dataset, [] + replay_runtime = await asyncio.to_thread( + self._prepare_replay_runtime_dataset, + dataset, + replay_datasets, + ) + return replay_runtime, replay_datasets + + def _available_replay_datasets(self, *, current_dataset: str, policy: str) -> list[str]: + candidates = self._experiences.get_replay_datasets( + current_dataset=current_dataset, + policy=policy, + ) + available: list[str] = [] + for dataset_name in candidates: + ref = self._parent.datasets.get_local_dataset(f"local/{dataset_name}") + if ref is not None and ref.runtime is not None: + available.append(dataset_name) + return available + + def _prepare_replay_runtime_dataset( + self, + dataset: DatasetRuntimeRef, + replay_datasets: list[str], + ) -> DatasetRuntimeRef: + from lerobot.datasets.dataset_tools import merge_datasets + from lerobot.datasets.lerobot_dataset import LeRobotDataset + + replay_slug = _build_replay_slug(dataset.name, replay_datasets) + replay_root = self._parent.datasets.root / "replay" / replay_slug + replay_root.parent.mkdir(parents=True, exist_ok=True) + source_datasets = [dataset, *self._resolve_replay_runtime_refs(replay_datasets)] + merged = [ + LeRobotDataset(repo_id=ref.repo_id, root=ref.local_path) + for ref in source_datasets + ] + merge_datasets( + merged, + output_repo_id=f"local/{replay_slug}", + output_dir=replay_root, + ) + return DatasetRuntimeRef( + name=dataset.name, + repo_id=dataset.repo_id, + local_path=replay_root, + ) + + def _resolve_replay_runtime_refs(self, replay_datasets: list[str]) -> list[DatasetRuntimeRef]: + refs: list[DatasetRuntimeRef] = [] + for dataset_name in replay_datasets: + runtime = self._parent.datasets.resolve_runtime_dataset(dataset_name).runtime + if runtime is None: + continue + refs.append(runtime) + return refs + def _enrich_state( self, job_id: str, @@ -214,6 +308,7 @@ def _enrich_state( enriched.setdefault("dataset_name", metadata.get("dataset_name", "")) enriched.setdefault("policy_type", metadata.get("policy_type", "")) enriched.setdefault("dataset_path", metadata.get("dataset_path", "")) + enriched.setdefault("replay_datasets", metadata.get("replay_datasets", "")) enriched.setdefault("experience_hint", "") enriched["message"] = self._format_status_message(enriched) @@ -237,6 +332,7 @@ def _format_status_message(self, state: dict[str, str | int | bool | None]) -> s "policy_type", "provider", "dataset_path", + "replay_datasets", "log_path", "experience_hint", ) @@ -264,10 +360,12 @@ def _record_experience( ) -> None: metadata = self._job_specs.get(job_id, {}) dataset_name = metadata.get("dataset_name", "") + replay_datasets = metadata.get("replay_datasets", "") policy_type = metadata.get("policy_type", "") provider = metadata.get("provider", "local") if not dataset_name and not policy_type and not job_id: return + outcome = _experience_outcome(status) lesson = _status_lesson(status, log_tail) summary = ( f"Local training for dataset '{dataset_name or ''}' " @@ -276,9 +374,10 @@ def _record_experience( self._experiences.append(ExperienceRecord.create( task_type="train", summary=summary, - outcome=status, + outcome=outcome, lesson=lesson, dataset=dataset_name, + replay_datasets=replay_datasets, policy=policy_type, provider=provider, job_id=job_id, @@ -339,6 +438,23 @@ def _update_best( return best +def _experience_outcome(status: str) -> str: + if status == "finished": + return "success" + return status + + +def _build_replay_slug(current_dataset: str, replay_datasets: list[str]) -> str: + tokens = [_slug_token(current_dataset), *[_slug_token(name) for name in replay_datasets]] + suffix = datetime.now(UTC).strftime("%Y%m%d_%H%M%S_%f") + return "__".join(["replay", *tokens, suffix]) + + +def _slug_token(value: str) -> str: + token = re.sub(r"[^a-zA-Z0-9_-]+", "_", value.strip()).strip("_") + return token or "dataset" + + def _parse_training_curve(job_id: str, log_path: Path) -> tuple[dict[str, float | int] | None, list[dict[str, Any]]]: if not log_path.exists(): return _BEST_LOSS_BY_JOB.get(job_id), [] diff --git a/roboclaw/http/routes/train.py b/roboclaw/http/routes/train.py index 5a946f17..995549d5 100644 --- a/roboclaw/http/routes/train.py +++ b/roboclaw/http/routes/train.py @@ -16,6 +16,7 @@ class TrainStartRequest(BaseModel): policy_type: str = "act" steps: int = 100_000 device: str = "cuda" + continual_learning: bool = False class TrainStopRequest(BaseModel): @@ -33,6 +34,7 @@ async def train_start(body: TrainStartRequest) -> dict[str, Any]: "policy_type": body.policy_type, "steps": body.steps, "device": body.device, + "continual_learning": body.continual_learning, }, ) return result diff --git a/tests/test_experience_replay.py b/tests/test_experience_replay.py new file mode 100644 index 00000000..eed31632 --- /dev/null +++ b/tests/test_experience_replay.py @@ -0,0 +1,233 @@ +from __future__ import annotations + +import asyncio +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import AsyncMock, patch + +import pytest + +from roboclaw.agent.experience import ExperienceRecord, ExperienceStore +from roboclaw.data.datasets import DatasetRuntimeRef +from roboclaw.embodied.board import Board +from roboclaw.embodied.embodiment.hardware.monitor import HardwareMonitor +from roboclaw.embodied.embodiment.manifest import Manifest +from roboclaw.embodied.service import EmbodiedService + + +@pytest.fixture(autouse=True) +def isolated_roboclaw_home(tmp_path): + with patch( + "roboclaw.embodied.embodiment.lock.get_roboclaw_home", + return_value=tmp_path, + ), patch( + "roboclaw.embodied.embodiment.manifest.helpers.get_roboclaw_home", + return_value=tmp_path, + ): + yield + + +@pytest.fixture() +def embodied_service(tmp_path: Path) -> EmbodiedService: + board = Board() + manifest = Manifest(path=tmp_path / "workspace" / "embodied" / "manifest.json", board=board) + monitor = HardwareMonitor(board=board, manifest=manifest) + return EmbodiedService(hardware_monitor=monitor, board=board, manifest=manifest) + + +def test_get_replay_datasets_returns_empty_without_history(tmp_path: Path) -> None: + store = ExperienceStore(tmp_path) + + assert store.get_replay_datasets(current_dataset="demo", policy="act") == [] + + +def test_get_replay_datasets_returns_recent_success_dataset(tmp_path: Path) -> None: + store = ExperienceStore(tmp_path) + _append_record( + store, + timestamp="2026-05-07T10:00:00+00:00", + dataset="old_pick", + policy="act", + outcome="success", + ) + _append_record( + store, + timestamp="2026-05-07T11:00:00+00:00", + dataset="new_place", + policy="act", + outcome="success", + ) + + assert store.get_replay_datasets(current_dataset="demo", policy="act") == ["new_place", "old_pick"] + + +def test_get_replay_datasets_excludes_current_dataset(tmp_path: Path) -> None: + store = ExperienceStore(tmp_path) + _append_record( + store, + timestamp="2026-05-07T11:00:00+00:00", + dataset="demo", + policy="act", + outcome="success", + ) + _append_record( + store, + timestamp="2026-05-07T10:00:00+00:00", + dataset="history_a", + policy="act", + outcome="success", + ) + + assert store.get_replay_datasets(current_dataset="demo", policy="act") == ["history_a"] + + +def test_get_replay_datasets_respects_max_datasets(tmp_path: Path) -> None: + store = ExperienceStore(tmp_path) + _append_record(store, timestamp="2026-05-07T12:00:00+00:00", dataset="history_c", policy="act", outcome="success") + _append_record(store, timestamp="2026-05-07T11:00:00+00:00", dataset="history_b", policy="act", outcome="success") + _append_record(store, timestamp="2026-05-07T10:00:00+00:00", dataset="history_a", policy="act", outcome="success") + + assert store.get_replay_datasets(current_dataset="demo", policy="act", max_datasets=2) == [ + "history_c", + "history_b", + ] + + +def test_start_job_state_keeps_argv_unchanged_when_continual_learning_disabled( + embodied_service: EmbodiedService, +) -> None: + dataset_map = _dataset_map("demo", "history_a") + embodied_service.datasets.resolve_runtime_dataset = lambda name: dataset_map[name] # type: ignore[method-assign] + embodied_service.datasets.get_local_dataset = lambda dataset_id: dataset_map.get(dataset_id.removeprefix("local/")) # type: ignore[method-assign] + _append_record( + ExperienceStore(embodied_service.manifest._path.parent.parent), + timestamp="2026-05-07T12:00:00+00:00", + dataset="history_a", + policy="act", + outcome="success", + ) + captured: dict[str, object] = {} + + def fake_train(manifest, *, dataset, policy_type, steps, device): + argv = [ + "python3", + "-m", + "lerobot.scripts.lerobot_train", + f"--dataset.root={dataset.local_path}", + ] + captured["argv"] = argv + captured["dataset"] = dataset + return argv + + with patch( + "roboclaw.embodied.service.session.train.CommandBuilder.train", + side_effect=fake_train, + ), patch( + "roboclaw.embodied.executor.SubprocessExecutor.run_detached", + new=AsyncMock(return_value="job-1"), + ): + state = asyncio.run(embodied_service.train.start_job_state( + embodied_service.manifest, + { + "dataset_name": "demo", + "policy_type": "act", + "steps": 1000, + "device": "cuda", + "continual_learning": False, + }, + )) + + assert captured["dataset"].local_path == Path("/tmp/demo") + assert "history_a" not in " ".join(captured["argv"]) + assert state["replay_datasets"] == "" + + +def test_start_job_state_includes_replay_dataset_when_continual_learning_enabled( + embodied_service: EmbodiedService, +) -> None: + dataset_map = _dataset_map("demo", "history_a") + embodied_service.datasets.resolve_runtime_dataset = lambda name: dataset_map[name] # type: ignore[method-assign] + embodied_service.datasets.get_local_dataset = lambda dataset_id: dataset_map.get(dataset_id.removeprefix("local/")) # type: ignore[method-assign] + _append_record( + ExperienceStore(embodied_service.manifest._path.parent.parent), + timestamp="2026-05-07T12:00:00+00:00", + dataset="history_a", + policy="act", + outcome="success", + ) + captured: dict[str, object] = {} + replay_runtime = DatasetRuntimeRef( + name="demo", + repo_id="local/demo", + local_path=Path("/tmp/replay_demo__history_a"), + ) + + def fake_train(manifest, *, dataset, policy_type, steps, device): + argv = [ + "python3", + "-m", + "lerobot.scripts.lerobot_train", + f"--dataset.root={dataset.local_path}", + ] + captured["argv"] = argv + captured["dataset"] = dataset + return argv + + with patch( + "roboclaw.embodied.service.session.train.CommandBuilder.train", + side_effect=fake_train, + ), patch( + "roboclaw.embodied.service.session.train.TrainSession._prepare_replay_runtime_dataset", + return_value=replay_runtime, + ), patch( + "roboclaw.embodied.executor.SubprocessExecutor.run_detached", + new=AsyncMock(return_value="job-1"), + ): + state = asyncio.run(embodied_service.train.start_job_state( + embodied_service.manifest, + { + "dataset_name": "demo", + "policy_type": "act", + "steps": 1000, + "device": "cuda", + "continual_learning": True, + }, + )) + + assert captured["dataset"].local_path == replay_runtime.local_path + assert "history_a" in " ".join(captured["argv"]) + assert state["replay_datasets"] == "history_a" + assert "history_a" in str(state["experience_hint"]) + + +def _append_record( + store: ExperienceStore, + *, + timestamp: str, + dataset: str, + policy: str, + outcome: str, +) -> None: + store.append(ExperienceRecord( + timestamp=timestamp, + task_type="train", + summary=f"{dataset} -> {outcome}", + outcome=outcome, + dataset=dataset, + policy=policy, + provider="local", + )) + + +def _dataset_map(*dataset_names: str) -> dict[str, SimpleNamespace]: + mapping: dict[str, SimpleNamespace] = {} + for dataset_name in dataset_names: + mapping[dataset_name] = SimpleNamespace( + id=f"local/{dataset_name}", + runtime=DatasetRuntimeRef( + name=dataset_name, + repo_id=f"local/{dataset_name}", + local_path=Path(f"/tmp/{dataset_name}"), + ), + ) + return mapping diff --git a/tests/test_train_experience.py b/tests/test_train_experience.py index b2ccd74c..e56f71e2 100644 --- a/tests/test_train_experience.py +++ b/tests/test_train_experience.py @@ -29,7 +29,7 @@ def isolated_roboclaw_home(tmp_path): @pytest.fixture() def embodied_service(tmp_path: Path) -> EmbodiedService: board = Board() - manifest = Manifest(path=tmp_path / "manifest.json", board=board) + manifest = Manifest(path=tmp_path / "workspace" / "embodied" / "manifest.json", board=board) monitor = HardwareMonitor(board=board, manifest=manifest) return EmbodiedService(hardware_monitor=monitor, board=board, manifest=manifest) @@ -72,11 +72,11 @@ def test_train_session_records_experience_and_reuses_it_as_hint(embodied_service {"dataset_name": "demo", "policy_type": "act", "steps": 1000, "device": "cuda"}, )) - assert "finished" in str(second["experience_hint"]) + assert "success" in str(second["experience_hint"]) assert "demo" in str(second["message"]) - store = ExperienceStore(embodied_service.manifest._path.parent) + store = ExperienceStore(embodied_service.manifest._path.parent.parent) records = store.read_all() outcomes = {record.outcome for record in records} assert "submitted" in outcomes - assert "finished" in outcomes + assert "success" in outcomes diff --git a/ui/src/domains/training/pages/TrainingCenterPage.tsx b/ui/src/domains/training/pages/TrainingCenterPage.tsx index d9930fb2..18af5e18 100644 --- a/ui/src/domains/training/pages/TrainingCenterPage.tsx +++ b/ui/src/domains/training/pages/TrainingCenterPage.tsx @@ -48,6 +48,7 @@ export default function TrainingCenterPage() { const [policyType, setPolicyType] = useState('act') const [trainSteps, setTrainSteps] = useState(100000) const [trainDevice, setTrainDevice] = useState('cuda') + const [continualLearning, setContinualLearning] = useState(false) const [pullPolicyRepo, setPullPolicyRepo] = useState('') useEffect(() => { @@ -109,6 +110,15 @@ export default function TrainingCenterPage() { +