Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .github/workflows/agent.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
name: gs-agent
on:
push:
branches:
- main
paths:
- "*"
- ".github/**"
- "src/agent/**"
- "src/schemas/**"
pull_request:
branches:
- main
paths:
- "*"
- ".github/**"
Expand Down
10 changes: 10 additions & 0 deletions .github/workflows/env.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
name: gs-env
on:
push:
branches:
- main
paths:
- "*"
- ".github/**"
- "src/env/**"
- "src/schemas/**"
pull_request:
branches:
- main
paths:
- "*"
- ".github/**"
Expand Down
7 changes: 7 additions & 0 deletions .github/workflows/pre-commit.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
name: pre-commit
on:
merge_group:
branches:
- main
push:
branches:
- main
pull_request:
branches:
- main
jobs:
pre-commit:
runs-on: ubuntu-24.04
Expand Down
9 changes: 9 additions & 0 deletions .github/workflows/schemas.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
name: gs-schemas
on:
push:
branches:
- main
paths:
- "*"
- ".github/**"
- "src/schemas/**"
pull_request:
branches:
- main
paths:
- "*"
- ".github/**"
Expand Down
58 changes: 47 additions & 11 deletions src/agent/gs_agent/wrappers/teleop_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import threading
import time
from typing import Any

import numpy as np
import torch
from numpy.typing import NDArray
from pynput import keyboard
from scipy.spatial.transform import Rotation as R

from gs_agent.bases.env_wrapper import BaseEnvWrapper

Expand Down Expand Up @@ -58,11 +60,11 @@ def stop(self) -> None:
self.listener.stop()
self.listener.join()

def on_press(self, key: keyboard.Key) -> None:
def on_press(self, key: keyboard.Key | keyboard.KeyCode | None) -> None:
with self.lock:
self.pressed_keys.add(key)

def on_release(self, key: keyboard.Key) -> None:
def on_release(self, key: keyboard.Key | keyboard.KeyCode | None) -> None:
with self.lock:
self.pressed_keys.discard(key)

Expand Down Expand Up @@ -123,13 +125,16 @@ def __init__(

def set_environment(self, env: Any) -> None:
"""Set the environment after creation."""
self.env = env
# Cannot reassign self.env as it's declared as Final in BaseEnvWrapper
# Instead, we'll work with the env passed in __init__
if not hasattr(self, "_env_initialized"):
self._env_initialized = True

self.target_position, self.target_orientation = self.env.get_ee_pose()
self.target_position = self.target_position
self.target_orientation = self.target_orientation
print("self.target_position", self.target_position.shape)
print("self.target_orientation", self.target_orientation.shape)
self.target_position, self.target_orientation = self.env.get_ee_pose()
self.target_position = self.target_position
self.target_orientation = self.target_orientation
print("self.target_position", self.target_position.shape)
print("self.target_orientation", self.target_orientation.shape)

def start(self) -> None:
"""Start keyboard listener."""
Expand Down Expand Up @@ -173,7 +178,7 @@ def stop(self) -> None:
"""Stop keyboard listener."""
self.running = False
if self.recording:
self.stop_recording()
self._stop_recording()
if self.listener:
self.listener.stop()

Expand Down Expand Up @@ -290,8 +295,6 @@ def _sync_pose_from_env(self) -> None:
obs = self._convert_observation_to_dict()
if obs is None:
return
from scipy.spatial.transform import Rotation as R

self.current_position = obs["end_effector_pos"].copy()
self.current_orientation = R.from_quat(obs["end_effector_quat"]).as_euler("xyz")

Expand Down Expand Up @@ -489,6 +492,39 @@ def critic_obs_dim(self) -> int:
def num_envs(self) -> int:
return 1

def _stop_recording(self) -> None:
"""Stop recording trajectory data."""
if self.recording:
self.recording = False
print(f"Recording stopped. Captured {len(self.trajectory_data)} steps.")
# Could save trajectory data here if needed
self.trajectory_data.clear()
self.recording_start_time = None

def _record_trajectory_step(self, command: KeyboardCommand, obs: dict[str, Any]) -> None:
"""Record a step of trajectory data."""
if not self.recording:
return

# Create trajectory step with timestamp
current_time = time.time()
if self.recording_start_time is None:
self.recording_start_time = current_time

step_data: TrajectoryStep = {
"timestamp": current_time - self.recording_start_time,
"command": {
"position": command.position.copy(),
"orientation": command.orientation.copy(),
"gripper_close": command.gripper_close,
"reset_scene": command.reset_scene,
"quit_teleop": command.quit_teleop,
},
"observation": obs.copy(),
}

self.trajectory_data.append(step_data)

def close(self) -> None:
"""Close the wrapper."""
self.stop()
Expand Down
2 changes: 2 additions & 0 deletions src/agent/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ dependencies = [
"gymnasium[classic_control]>=1.2.0",
"matplotlib>=3.10.0",
"pandas>=2.3.2",
"pynput>=1.7.0",
"scipy>=1.10.0",
"tensordict>=0.9",
"torch>=2.2",
"tqdm>=4.67.1",
Expand Down
5 changes: 5 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading