diff --git a/.github/workflows/agent.yml b/.github/workflows/agent.yml index 5f3a9aa9..fc78aefb 100644 --- a/.github/workflows/agent.yml +++ b/.github/workflows/agent.yml @@ -1,6 +1,16 @@ name: gs-agent on: push: + branches: + - main + paths: + - "*" + - ".github/**" + - "src/agent/**" + - "src/schemas/**" + pull_request: + branches: + - main paths: - "*" - ".github/**" diff --git a/.github/workflows/env.yml b/.github/workflows/env.yml index 36134c6f..6278f2cc 100644 --- a/.github/workflows/env.yml +++ b/.github/workflows/env.yml @@ -1,6 +1,16 @@ name: gs-env on: push: + branches: + - main + paths: + - "*" + - ".github/**" + - "src/env/**" + - "src/schemas/**" + pull_request: + branches: + - main paths: - "*" - ".github/**" diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index e7796198..9bac616d 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -1,7 +1,14 @@ name: pre-commit on: merge_group: + branches: + - main push: + branches: + - main + pull_request: + branches: + - main jobs: pre-commit: runs-on: ubuntu-24.04 diff --git a/.github/workflows/schemas.yml b/.github/workflows/schemas.yml index 20957daf..8db9bc48 100644 --- a/.github/workflows/schemas.yml +++ b/.github/workflows/schemas.yml @@ -1,6 +1,15 @@ name: gs-schemas on: push: + branches: + - main + paths: + - "*" + - ".github/**" + - "src/schemas/**" + pull_request: + branches: + - main paths: - "*" - ".github/**" diff --git a/src/agent/gs_agent/wrappers/teleop_wrapper.py b/src/agent/gs_agent/wrappers/teleop_wrapper.py index 4540f0bf..b5ff1e63 100644 --- a/src/agent/gs_agent/wrappers/teleop_wrapper.py +++ b/src/agent/gs_agent/wrappers/teleop_wrapper.py @@ -1,10 +1,12 @@ import threading +import time from typing import Any import numpy as np import torch from numpy.typing import NDArray from pynput import keyboard +from scipy.spatial.transform import Rotation as R from gs_agent.bases.env_wrapper import BaseEnvWrapper @@ -58,11 +60,11 @@ def stop(self) -> None: self.listener.stop() self.listener.join() - def on_press(self, key: keyboard.Key) -> None: + def on_press(self, key: keyboard.Key | keyboard.KeyCode | None) -> None: with self.lock: self.pressed_keys.add(key) - def on_release(self, key: keyboard.Key) -> None: + def on_release(self, key: keyboard.Key | keyboard.KeyCode | None) -> None: with self.lock: self.pressed_keys.discard(key) @@ -123,13 +125,16 @@ def __init__( def set_environment(self, env: Any) -> None: """Set the environment after creation.""" - self.env = env + # Cannot reassign self.env as it's declared as Final in BaseEnvWrapper + # Instead, we'll work with the env passed in __init__ + if not hasattr(self, "_env_initialized"): + self._env_initialized = True - self.target_position, self.target_orientation = self.env.get_ee_pose() - self.target_position = self.target_position - self.target_orientation = self.target_orientation - print("self.target_position", self.target_position.shape) - print("self.target_orientation", self.target_orientation.shape) + self.target_position, self.target_orientation = self.env.get_ee_pose() + self.target_position = self.target_position + self.target_orientation = self.target_orientation + print("self.target_position", self.target_position.shape) + print("self.target_orientation", self.target_orientation.shape) def start(self) -> None: """Start keyboard listener.""" @@ -173,7 +178,7 @@ def stop(self) -> None: """Stop keyboard listener.""" self.running = False if self.recording: - self.stop_recording() + self._stop_recording() if self.listener: self.listener.stop() @@ -290,8 +295,6 @@ def _sync_pose_from_env(self) -> None: obs = self._convert_observation_to_dict() if obs is None: return - from scipy.spatial.transform import Rotation as R - self.current_position = obs["end_effector_pos"].copy() self.current_orientation = R.from_quat(obs["end_effector_quat"]).as_euler("xyz") @@ -489,6 +492,39 @@ def critic_obs_dim(self) -> int: def num_envs(self) -> int: return 1 + def _stop_recording(self) -> None: + """Stop recording trajectory data.""" + if self.recording: + self.recording = False + print(f"Recording stopped. Captured {len(self.trajectory_data)} steps.") + # Could save trajectory data here if needed + self.trajectory_data.clear() + self.recording_start_time = None + + def _record_trajectory_step(self, command: KeyboardCommand, obs: dict[str, Any]) -> None: + """Record a step of trajectory data.""" + if not self.recording: + return + + # Create trajectory step with timestamp + current_time = time.time() + if self.recording_start_time is None: + self.recording_start_time = current_time + + step_data: TrajectoryStep = { + "timestamp": current_time - self.recording_start_time, + "command": { + "position": command.position.copy(), + "orientation": command.orientation.copy(), + "gripper_close": command.gripper_close, + "reset_scene": command.reset_scene, + "quit_teleop": command.quit_teleop, + }, + "observation": obs.copy(), + } + + self.trajectory_data.append(step_data) + def close(self) -> None: """Close the wrapper.""" self.stop() diff --git a/src/agent/pyproject.toml b/src/agent/pyproject.toml index e90f224b..16091bfd 100644 --- a/src/agent/pyproject.toml +++ b/src/agent/pyproject.toml @@ -11,6 +11,8 @@ dependencies = [ "gymnasium[classic_control]>=1.2.0", "matplotlib>=3.10.0", "pandas>=2.3.2", + "pynput>=1.7.0", + "scipy>=1.10.0", "tensordict>=0.9", "torch>=2.2", "tqdm>=4.67.1", diff --git a/uv.lock b/uv.lock index b6236cc5..0b7053f5 100644 --- a/uv.lock +++ b/uv.lock @@ -733,6 +733,9 @@ dependencies = [ { name = "gymnasium", extra = ["classic-control"] }, { name = "matplotlib" }, { name = "pandas" }, + { name = "pynput" }, + { name = "scipy", version = "1.15.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "scipy", version = "1.16.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "tensordict" }, { name = "torch", version = "2.8.0", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "platform_machine == 'arm64' and sys_platform == 'darwin'" }, { name = "torch", version = "2.8.0+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(platform_machine != 'arm64' and sys_platform == 'darwin') or (platform_machine != 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, @@ -758,6 +761,8 @@ requires-dist = [ { name = "gymnasium", extras = ["classic-control"], specifier = ">=1.2.0" }, { name = "matplotlib", specifier = ">=3.10.0" }, { name = "pandas", specifier = ">=2.3.2" }, + { name = "pynput", specifier = ">=1.7.0" }, + { name = "scipy", specifier = ">=1.10.0" }, { name = "tensordict", specifier = ">=0.9" }, { name = "torch", marker = "sys_platform != 'linux'", specifier = ">=2.2", index = "https://download.pytorch.org/whl/cpu" }, { name = "torch", marker = "platform_machine != 'x86_64' and sys_platform == 'linux'", specifier = ">=2.2", index = "https://download.pytorch.org/whl/cpu" },