Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[dev]
pip install -e ".[dev]" pytest pytest-asyncio pexpect || pip install -e . pytest pytest-asyncio pexpect

- name: Run tests
run: python -m pytest tests/ -v
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ dependencies = [
]

[project.optional-dependencies]
dev = [
"pytest>=9.0.0,<10.0.0",
"pytest-asyncio>=1.3.0,<2.0.0",
"pexpect>=4.9.0,<5.0.0",
"ruff>=0.1.0",
]
wecom = [
"wecom-aibot-sdk-python>=0.1.5",
]
Expand Down
76 changes: 74 additions & 2 deletions roboclaw/embodied/service/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
VerificationRequest,
Verifier,
)
from roboclaw.embodied.workflow import WorkflowPlan, WorkflowPlanner, WorkflowSpec


class EmbodiedService:
Expand Down Expand Up @@ -226,6 +227,77 @@ def _verify_inference_preflight(
if not result.ok:
raise ActionError(result.format_violations())

def plan_workflow(self, spec: WorkflowSpec | dict[str, Any]) -> WorkflowPlan:
"""Compile a workflow spec into concrete stage plans and validations."""
return WorkflowPlanner(self.manifest, self.datasets).plan(spec)

async def start_workflow_phase(
self,
spec: WorkflowSpec | dict[str, Any],
phase: str,
) -> dict[str, Any]:
"""Start a workflow phase using the unified workflow spec interface."""
workflow = spec if isinstance(spec, WorkflowSpec) else WorkflowSpec.model_validate(spec)
plan = self.plan_workflow(workflow)
stage = next((item for item in plan.stages if item.stage == phase), None)
if stage is None:
raise RuntimeError(f"Unknown workflow phase '{phase}'.")
if not stage.enabled:
raise RuntimeError(f"Workflow phase '{phase}' is disabled.")
if stage.issues:
raise RuntimeError(" · ".join(issue.message for issue in stage.issues))
if not stage.ready:
if stage.blocked_by:
blocked = ", ".join(stage.blocked_by)
raise RuntimeError(f"Workflow phase '{phase}' is waiting on: {blocked}.")
raise RuntimeError(f"Workflow phase '{phase}' is not ready.")

if phase == "record":
dataset_name = await self.start_recording(
task=workflow.record.task,
num_episodes=workflow.record.num_episodes,
fps=workflow.record.fps,
episode_time_s=workflow.record.episode_time_s,
reset_time_s=workflow.record.reset_time_s,
dataset_name=stage.dataset_name,
use_cameras=workflow.hardware.use_cameras,
arms=workflow.hardware.arms,
)
return {"status": "recording", "dataset_name": dataset_name}

if phase == "train":
result = await self.train.train(
manifest=self.manifest,
kwargs={
"dataset_name": stage.dataset_name,
"policy_type": workflow.train.policy_type,
"steps": workflow.train.steps,
"device": workflow.train.device,
},
tty_handoff=None,
)
job_id = result.rsplit("Job ID:", 1)[-1].strip() if "Job ID:" in result else ""
return {"message": result, "job_id": job_id}

if phase == "infer":
await self.start_inference(
checkpoint_path=stage.checkpoint_path,
source_dataset=stage.source_dataset,
dataset_name=stage.dataset_name,
task=workflow.infer.task,
num_episodes=workflow.infer.num_episodes,
episode_time_s=workflow.infer.episode_time_s,
arms=workflow.hardware.arms,
use_cameras=workflow.hardware.use_cameras,
)
return {
"status": "inferring",
"dataset_name": stage.dataset_name,
"checkpoint_path": stage.checkpoint_path,
}

raise RuntimeError(f"Workflow phase '{phase}' is not supported.")

# -- Operations (Web entry points) --

async def start_teleop(self, *, fps: int = 30, arms: str = "") -> None:
Expand Down Expand Up @@ -293,7 +365,7 @@ async def start_inference(
) -> None:
self._require_capability("infer" if use_cameras else "infer_without_cameras")
output_dataset = self.datasets.prepare_recording_dataset(dataset_name, prefix="eval")
source = self.datasets.resolve_runtime_dataset(source_dataset) if source_dataset else None
source = self.datasets.resolve_runtime_dataset(source_dataset) if source_dataset and not checkpoint_path else None
argv = CommandBuilder.infer(
self.manifest,
dataset=output_dataset.runtime,
Expand Down Expand Up @@ -347,7 +419,7 @@ async def run_inference(
) -> str:
self._require_capability("infer" if use_cameras else "infer_without_cameras")
output_dataset = self.datasets.prepare_recording_dataset(dataset_name, prefix="eval")
source = self.datasets.resolve_runtime_dataset(source_dataset) if source_dataset else None
source = self.datasets.resolve_runtime_dataset(source_dataset) if source_dataset and not checkpoint_path else None
argv = CommandBuilder.infer(
self.manifest,
dataset=output_dataset.runtime,
Expand Down
Loading
Loading