diff --git a/benchmark/base.py b/benchmark/base.py index b650aaca..6333512f 100644 --- a/benchmark/base.py +++ b/benchmark/base.py @@ -218,6 +218,8 @@ def normed_mobs_to_samples(self, normed_mobs: MetaObs): sample['image'] = normed_mobs.image # (cameras, C, H, W) else: sample['image'] = normed_mobs.image + else: + sample['image'] = None # Explicitly set to None when no image # State: use specified control space (already normalized) sample['state'] = normed_mobs.state[i] if len(normed_mobs.state.shape) > 1 else normed_mobs.state diff --git a/benchmark/gymnasium/__init__.py b/benchmark/gymnasium/__init__.py new file mode 100644 index 00000000..9b8df58d --- /dev/null +++ b/benchmark/gymnasium/__init__.py @@ -0,0 +1,182 @@ +""" +Gymnasium Environment Adapter + +Provides a MetaEnv wrapper for standard Gymnasium environments, +including MuJoCo environments like HalfCheetah, Ant, Humanoid, etc. + +Usage: + # Via config: + python train_rl.py -a td3 -e gymnasium/halfcheetah + + # Programmatically: + from benchmark.gymnasium import create_env + env = create_env(config) +""" + +import numpy as np +import gymnasium as gym +from ..base import MetaEnv, MetaObs, MetaAction +import torch + + + +def create_env(config): + """Factory function for creating GymnasiumEnv.""" + return GymnasiumEnv(config) + + +class GymnasiumEnv(MetaEnv): + """ + MetaEnv wrapper for standard Gymnasium environments. + + Supports continuous control environments like: + - MuJoCo: HalfCheetah-v5, Ant-v5, Humanoid-v5, Walker2d-v5, Hopper-v5, etc. + - Classic Control: Pendulum-v1, MountainCarContinuous-v0 + - Box2D: LunarLanderContinuous-v3, BipedalWalker-v3 + + Config attributes: + task (str): Gymnasium environment ID (e.g., 'HalfCheetah-v5') + render_mode (str): Render mode ('human', 'rgb_array', None). Default: None + use_camera (bool): Whether to capture images. Default: False + max_episode_steps (int): Override max steps. Default: use env default + raw_lang (str): Language instruction. 
Default: task name + """ + + def __init__(self, config, *args): + self.config = config + self.task = config.task + self.render_mode = getattr(config, 'render_mode', None) + self.use_camera = getattr(config, 'use_camera', False) + self.max_episode_steps = getattr(config, 'max_episode_steps', None) + self.raw_lang = getattr(config, 'raw_lang', self.task) + + # Control space info (most MuJoCo envs use normalized actions) + self.ctrl_space = getattr(config, 'ctrl_space', 'joint') + self.ctrl_type = getattr(config, 'ctrl_type', 'abs') + + env = self._create_env() + super().__init__(env) + + # Store action space bounds for reference + self.action_low = self.env.action_space.low + self.action_high = self.env.action_space.high + self.action_dim = self.env.action_space.shape[0] + self.state_dim = self.env.observation_space.shape[0] + self.debug_step=0 + + def _create_env(self): + """Create the underlying Gymnasium environment.""" + kwargs = {} + if self.render_mode: + kwargs['render_mode'] = self.render_mode + if self.max_episode_steps: + kwargs['max_episode_steps'] = self.max_episode_steps + + env = gym.make(self.task, **kwargs) + return env + + @property + def action_space(self): + """Return the action space of the environment.""" + return self.env.action_space + + @property + def observation_space(self): + """Return the observation space of the environment.""" + return self.env.observation_space + + def meta2act(self, maction: MetaAction): + """Convert MetaAction to Gymnasium action.""" + actions = maction['action'] # (action_dim, ) + return actions + + def obs2meta(self, obs) -> MetaObs: + """Convert Gymnasium observation to MetaObs.""" + # Handle different observation types + if isinstance(obs, dict): + # For goal-conditioned environments + state = obs.get('observation', obs.get('state', None)) + if state is None: + state = np.concatenate([v.flatten() for v in obs.values()]) + else: + state = obs + + state = np.asarray(state, dtype=np.float32) + + # Capture image if 
requested + image = None + if self.use_camera and self.render_mode == 'rgb_array': + image = self.env.render() + if image is not None: + # Convert to (N, C, H, W) format + if len(image.shape) == 3: # (H, W, C) + image = image[np.newaxis, ...] # -> (1, H, W, C) + if image.shape[-1] == 3: # (N, H, W, C) -> (N, C, H, W) + image = image.transpose(0, 3, 1, 2) + + return MetaObs(state=state, image=image, raw_lang=self.raw_lang) + + def step(self, maction): + """Execute action and return (obs, reward, done, info).""" + action = self.meta2act(maction) + observation, reward, terminated, truncated, info = self.env.step(action) + + obs = self.obs2meta(observation) + + # Combine terminated and truncated into done + # Store both for proper handling in RL algorithms + done = terminated or truncated + info['terminated'] = terminated + info['truncated'] = truncated + info['TimeLimit.truncated'] = truncated and not terminated + self.debug_step += 1 + + return obs, reward, done, info + + def reset(self, **kwargs): + """Reset environment and return initial observation.""" + obs, info = self.env.reset(**kwargs) + self.debug_step=0 + return self.obs2meta(obs) + + def render(self): + """Render the environment.""" + return self.env.render() + + def close(self): + """Close the environment.""" + self.env.close() + + def ensure_action_reasonable(self, action): + """Ensure action is reasonable.""" + if self.action_low is not None or self.action_high is not None: + # Convert Tensor to numpy array if needed + if torch is not None and torch.is_tensor(action): + action = action.detach().cpu().numpy() + action = np.asarray(action) + + if np.any(action <= self.action_high) and np.any(action >= self.action_low): + return True + + return False + + +# Common MuJoCo environment configurations +MUJOCO_ENVS = { + 'halfcheetah': 'HalfCheetah-v5', + 'ant': 'Ant-v5', + 'humanoid': 'Humanoid-v5', + 'walker2d': 'Walker2d-v5', + 'hopper': 'Hopper-v5', + 'swimmer': 'Swimmer-v5', + 'reacher': 'Reacher-v5', + 
'pusher': 'Pusher-v5', + 'inverted_pendulum': 'InvertedPendulum-v5', + 'inverted_double_pendulum': 'InvertedDoublePendulum-v5', +} + + +def get_env_id(name: str) -> str: + """Get full environment ID from short name.""" + return MUJOCO_ENVS.get(name.lower(), name) + diff --git a/configs/env/gymnasium/ant.yaml b/configs/env/gymnasium/ant.yaml new file mode 100644 index 00000000..3a48e810 --- /dev/null +++ b/configs/env/gymnasium/ant.yaml @@ -0,0 +1,26 @@ +# Ant-v5 environment configuration +# MuJoCo locomotion task: make a 3D ant robot run forward + +type: benchmark.gymnasium.GymnasiumEnv +name: ant + +# Environment settings +task: Ant-v5 +render_mode: null # Set to 'rgb_array' for video recording, 'human' for visualization +use_camera: false +max_episode_steps: 1000 # Default for Ant + +# Dimensions (explicit specification) +state_dim: 27 # qpos: 13, qvel: 14, excluding x,y position +action_dim: 8 # torques for 8 joints + +# Control settings +ctrl_space: joint +ctrl_type: abs + +# Task description +raw_lang: "Make the ant walk forward as fast as possible" + +# Notes: +# - Action range: [-1, 1] (normalized) +# - Reward: forward velocity - 0.5 * control_cost + healthy_reward diff --git a/configs/env/gymnasium/halfcheetah.yaml b/configs/env/gymnasium/halfcheetah.yaml new file mode 100644 index 00000000..8a600791 --- /dev/null +++ b/configs/env/gymnasium/halfcheetah.yaml @@ -0,0 +1,26 @@ +# HalfCheetah-v5 environment configuration +# MuJoCo locomotion task: make a 2D cheetah robot run as fast as possible + +type: benchmark.gymnasium.GymnasiumEnv +name: halfcheetah + +# Environment settings +task: HalfCheetah-v5 +render_mode: null # Set to 'rgb_array' for video recording, 'human' for visualization +use_camera: false +max_episode_steps: 1000 # Default for HalfCheetah + +# Dimensions (explicit specification) +state_dim: 17 # qpos: 9, qvel: 8, excluding x-position +action_dim: 6 # torques for 6 joints + +# Control settings +ctrl_space: joint +ctrl_type: abs + +# Task 
description +raw_lang: "Make the cheetah run as fast as possible" + +# Notes: +# - Action range: [-1, 1] (normalized) +# - Reward: forward velocity - 0.1 * control_cost diff --git a/configs/env/gymnasium/hopper.yaml b/configs/env/gymnasium/hopper.yaml new file mode 100644 index 00000000..7cdd06ee --- /dev/null +++ b/configs/env/gymnasium/hopper.yaml @@ -0,0 +1,26 @@ +# Hopper-v5 environment configuration +# MuJoCo locomotion task: make a 2D one-legged robot hop forward + +type: benchmark.gymnasium.GymnasiumEnv +name: hopper + +# Environment settings +task: Hopper-v5 +render_mode: null # Set to 'rgb_array' for video recording, 'human' for visualization +use_camera: false +max_episode_steps: 1000 # Default for Hopper + +# Dimensions (explicit specification) +state_dim: 11 # qpos: 5, qvel: 6, excluding x-position +action_dim: 3 # torques for 3 joints + +# Control settings +ctrl_space: joint +ctrl_type: abs + +# Task description +raw_lang: "Make the hopper hop forward as fast as possible" + +# Notes: +# - Action range: [-1, 1] (normalized) +# - Reward: forward velocity - 0.001 * control_cost + healthy_reward diff --git a/configs/env/gymnasium/walker2d.yaml b/configs/env/gymnasium/walker2d.yaml new file mode 100644 index 00000000..e4708dd3 --- /dev/null +++ b/configs/env/gymnasium/walker2d.yaml @@ -0,0 +1,26 @@ +# Walker2d-v5 environment configuration +# MuJoCo locomotion task: make a 2D bipedal robot walk forward + +type: benchmark.gymnasium.GymnasiumEnv +name: walker2d + +# Environment settings +task: Walker2d-v5 +render_mode: null # Set to 'rgb_array' for video recording, 'human' for visualization +use_camera: false +max_episode_steps: 1000 # Default for Walker2d + +# Dimensions (explicit specification) +state_dim: 17 # qpos: 8, qvel: 9, excluding x-position +action_dim: 6 # torques for 6 joints + +# Control settings +ctrl_space: joint +ctrl_type: abs + +# Task description +raw_lang: "Make the walker walk forward as fast as possible" + +# Notes: +# - Action range: 
[-1, 1] (normalized) +# - Reward: forward velocity - 0.001 * control_cost + healthy_reward diff --git a/configs/rl/td3.yaml b/configs/rl/td3.yaml new file mode 100644 index 00000000..1d365c24 --- /dev/null +++ b/configs/rl/td3.yaml @@ -0,0 +1,32 @@ +# TD3 Algorithm Configuration +# Twin Delayed Deep Deterministic Policy Gradient + +type: td3 + +# Network architecture +hidden_dims: [256, 256] + +# Discount factor (gamma) +discount: 0.99 + +# Soft update coefficient for target networks +tau: 0.005 + +# Learning rates +actor_lr: 0.0003 +critic_lr: 0.0003 + +# Target policy smoothing +policy_noise: 0.2 # Noise added to target policy during critic update +noise_clip: 0.5 # Range to clip target policy noise + +# Delayed policy updates +policy_freq: 2 # Update policy every N critic updates + +# Exploration noise +expl_noise: 0.1 # Gaussian noise scale for exploration + +# Control settings +ctrl_space: joint # Control space: 'ee' (end-effector) or 'joint' +ctrl_type: absolute # Control type: 'delta' or 'absolute' + diff --git a/docs/rl_framework_integrated_design.md b/docs/rl_framework_integrated_design.md new file mode 100644 index 00000000..bacb1981 --- /dev/null +++ b/docs/rl_framework_integrated_design.md @@ -0,0 +1,1674 @@ +# ILStudio RL框架集成设计 + +## 设计原则 + +1. **足够抽象**:只定义核心接口,不涉及具体实现 +2. **通用性**:支持各种RL算法和训练模式 +3. **兼容性**:直接使用MetaEnv和MetaPolicy,无需适配 +4. **扩展性**:易于扩展支持并行训练和分布式训练 +5. 
**模块化**:奖励函数、配置系统独立可替换 + +--- + +## 目录结构 + +``` +ILStudio/ +├── rl/ # 强化学习模块 +│ ├── __init__.py +│ ├── base.py # RL 基类定义 +│ ├── algorithms/ # RL 算法实现 +│ │ ├── __init__.py +│ │ ├── ppo.py +│ │ ├── sac.py +│ │ ├── td3.py +│ │ ├── dpo.py # Direct Preference Optimization (VLA) +│ │ ├── grpo.py # Group Relative Policy Optimization (VLA) +│ │ └── reinforce.py +│ ├── buffer/ # 经验回放 +│ │ ├── __init__.py +│ │ ├── base_replay.py # BaseReplay基类 +│ │ ├── memory_replay.py # 内存存储实现 +│ │ ├── rollout_buffer.py # On-policy Rollout Buffer +│ │ └── priority_buffer.py # 优先级采样Buffer +│ ├── rewards/ # 奖励函数模块 +│ │ ├── __init__.py +│ │ ├── base_reward.py # 奖励函数基类 +│ │ ├── sparse_reward.py # 稀疏奖励 +│ │ ├── dense_reward.py # 密集奖励 +│ │ ├── learned_reward.py # 学习的奖励模型 +│ │ └── language_reward.py # 语言条件奖励(VLA) +│ ├── collectors/ # 数据收集器模块(新增) +│ │ ├── __init__.py +│ │ ├── base_collector.py # Collector基类 +│ │ ├── sim_collector.py # 仿真环境收集器 +│ │ └── real_collector.py # 真实机器人收集器 +│ ├── trainers/ # 训练器实现 +│ │ ├── __init__.py +│ │ ├── base_trainer.py # BaseTrainer基类 +│ │ ├── simple_trainer.py # 单机训练器 +│ │ ├── parallel_trainer.py # 并行环境训练器 +│ │ └── distributed_trainer.py # 分布式训练器 +│ └── utils/ # 工具函数 +│ ├── __init__.py +│ └── data_processor.py # 数据处理器(对齐ILStudio pipeline) +├── configs/ +│ ├── rl/ # RL 配置(新增) +│ │ ├── ppo.yaml +│ │ ├── sac.yaml +│ │ ├── dpo.yaml +│ │ └── grpo.yaml +│ └── ... +├── train_rl.py # RL 训练入口(新增) +└── ... +``` + +--- + +## 核心基类设计 + +### 1. 
Replay基类 (`BaseReplay`) + +**职责:** 存储和管理经验数据(transitions) + +**设计理念:** +- **存储原始Meta数据**:在buffer中存储原始的MetaObs和MetaAction,保持数据的原始性 +- **完整信息存储**:支持存储MetaObs和MetaAction的所有字段(state、image、raw_lang、state_ee、state_joint、depth、pc等) +- **扩展性**:支持存储额外的自定义字段(value、log_prob、advantage、trajectory_id等),方便后续扩展 +- **采样时转换**:采样时通过转换函数对齐ILStudio的data pipeline(normalization等) +- **兼容性**:既保持数据的原始性,又兼容ILStudio的normalization pipeline + +**核心接口:** + +```python +class BaseReplay: + """Replay Buffer基类""" + + def __init__( + self, + capacity: int = 1000000, + device: Union[str, torch.device] = 'cpu', + **kwargs + ): + """ + 初始化Replay Buffer + + Args: + capacity: Buffer容量(最大存储的transition数量) + device: 数据存储设备('cpu'或'cuda',默认'cpu') + - 'cpu': 数据存储在CPU内存中 + - 'cuda'或'cuda:0': 数据存储在GPU内存中 + **kwargs: 其他初始化参数 + """ + self.capacity = capacity + self.device = torch.device(device) if isinstance(device, str) else device + + def add(self, transition: Dict[str, Any]) -> None: + """ + 添加一个transition到buffer + + 存储原始Meta数据(MetaObs、MetaAction),不进行任何normalization + 支持存储MetaObs和MetaAction的所有字段,以及额外的自定义信息 + + Args: + transition: 包含以下字段的字典: + - state: MetaObs格式的当前状态(原始数据,包含所有字段) + - action: MetaAction格式的动作(原始数据,包含所有字段) + - reward: float奖励 + - next_state: MetaObs格式的下一个状态(原始数据) + - done: bool是否结束 + - info: 可选,额外信息字典 + - **其他自定义字段**: 可以存储任何额外的信息 + """ + raise NotImplementedError + + def sample(self, batch_size: int) -> Dict[str, Any]: + """ + 从buffer中采样一个batch(原始数据) + + Args: + batch_size: batch大小 + + Returns: + 包含原始Meta数据的字典(未经过normalization) + """ + raise NotImplementedError + + def sample_for_training( + self, + batch_size: int, + data_processor: Optional[Callable] = None + ) -> Dict[str, Any]: + """ + 采样并转换为ILStudio训练格式 + + Args: + batch_size: batch大小 + data_processor: 可选的数据处理函数,用于对齐ILStudio pipeline + - 如果为None,返回原始数据 + - 如果提供,应该是一个函数:batch -> processed_batch + + Returns: + 处理后的batch数据(符合ILStudio训练格式) + """ + batch = self.sample(batch_size) + if data_processor is not None: + batch = data_processor(batch) + 
return batch + + def __len__(self) -> int: + """返回buffer当前大小""" + raise NotImplementedError + + def clear(self) -> None: + """清空buffer""" + raise NotImplementedError + + def save(self, path: str, **kwargs) -> None: + """ + 保存buffer中的数据到文件 + + Args: + path: 保存路径(可以是文件路径或目录路径) + **kwargs: 保存选项 + - format: 保存格式(如'pkl'、'hdf5'、'npz'等,可选) + - compress: 是否压缩(可选) + """ + raise NotImplementedError + + def load(self, path: str, **kwargs) -> None: + """ + 从文件加载数据到buffer + + Args: + path: 加载路径(可以是文件路径或目录路径) + **kwargs: 加载选项 + - format: 加载格式(可选,可自动推断) + - append: 是否追加到现有buffer(默认False,清空后加载) + """ + raise NotImplementedError +``` + +--- + +### 2. 算法基类 (`BaseAlgorithm`) + +**职责:** 定义RL算法的核心逻辑 + +**设计理念:** 参考SKRL的设计,将replay放到算法中,这样: +- 每个算法可以有自己的replay配置 +- 一个Trainer可以训练多个不同的算法(每个算法有自己的replay) +- 更灵活,支持多智能体场景 + +**核心接口:** + +```python +class BaseAlgorithm: + """RL算法基类""" + + def __init__( + self, + meta_policy: MetaPolicy, + replay: Optional[Union[BaseReplay, Dict[str, BaseReplay]]] = None, + **kwargs + ): + """ + 初始化算法 + + Args: + meta_policy: ILStudio的MetaPolicy实例(必需属性) + replay: 支持两种格式: + - BaseReplay实例:单个replay buffer(所有环境共享) + - Dict[str, BaseReplay]:多个replay buffer(按环境类型分离) + - None:不使用replay buffer(on-policy算法) + **kwargs: 算法特定的参数 + """ + self.meta_policy = meta_policy # 必需属性 + self.replay = replay # 可选属性(off-policy算法需要) + + def update(self, batch: Optional[Dict[str, Any]] = None, **kwargs) -> Dict[str, Any]: + """ + 使用一个batch的数据更新策略 + + Args: + batch: 可选,batch数据 + - 如果为None且replay存在,从replay采样 + - 如果提供,直接使用提供的batch + **kwargs: 更新参数,可以包括: + - batch_size: 从replay采样时的batch大小 + - env_types: 指定从哪些环境类型的replay采样(当replay是Dict[str, BaseReplay]时) + - 例如:['indoor', 'outdoor'] + - 如果为None,从所有环境类型的replay采样 + - env_weights: 不同环境类型的数据权重(当使用多个环境类型的replay时) + - 例如:{'indoor': 0.6, 'outdoor': 0.4} + - 如果为None,使用均匀权重 + + Returns: + 包含loss、metrics等信息的字典 + + 示例: + # 场景1:单个replay buffer + algorithm.update(batch_size=256) + + # 场景2:多个replay buffer(按环境类型分离) + algorithm.update( + batch_size=256, + 
env_types=['indoor', 'outdoor'], + env_weights={'indoor': 0.6, 'outdoor': 0.4} + ) + """ + raise NotImplementedError + + def compute_loss(self, batch: Dict[str, Any]) -> torch.Tensor: + """ + 计算损失(可选,某些算法可能需要) + + Args: + batch: batch数据 + + Returns: + 损失值 + """ + raise NotImplementedError + + def select_action(self, obs: MetaObs, **kwargs) -> MetaAction: + """ + 选择动作(可选,某些算法可能需要) + + Args: + obs: MetaObs格式的observation + **kwargs: 其他参数(如exploration等) + + Returns: + MetaAction格式的动作 + """ + # 默认使用meta_policy的select_action + return self.meta_policy.select_action(obs, **kwargs) + + def record_transition( + self, + state: MetaObs, + action: MetaAction, + reward: float, + next_state: MetaObs, + done: bool, + info: Optional[Dict[str, Any]] = None, + **kwargs + ) -> None: + """ + 记录transition到replay buffer(如果存在) + + 支持存储完整的MetaObs和MetaAction信息,以及额外的自定义字段 + 如果使用多个replay buffer(按环境类型分离),会根据kwargs中的env_type选择对应的replay + + Args: + state: 当前状态(MetaObs,包含所有字段) + action: 动作(MetaAction,包含所有字段) + reward: 奖励 + next_state: 下一个状态(MetaObs,包含所有字段) + done: 是否结束 + info: 额外信息字典 + **kwargs: 其他自定义字段,可以存储任何额外信息 + - env_type: 环境类型标识(如果replay是Dict[str, BaseReplay]) + - 例如:value、log_prob、advantage、trajectory_id等 + """ + if self.replay is not None: + transition = { + 'state': state, + 'action': action, + 'reward': reward, + 'next_state': next_state, + 'done': done, + 'info': info or {}, + **kwargs + } + + env_type = kwargs.get('env_type', None) + + if isinstance(self.replay, dict): + if env_type is None: + raise ValueError("env_type must be provided when using multiple replay buffers") + if env_type not in self.replay: + raise ValueError(f"env_type '{env_type}' not found in replay buffers") + self.replay[env_type].add(transition) + else: + self.replay.add(transition) +``` + +--- + +### 3. 
奖励函数基类 (`BaseReward`) + +**职责:** 定义奖励函数的接口 + +**设计理念:** +- **模块化**:奖励函数独立模块,易于替换和扩展 +- **可组合**:支持多个奖励函数组合使用 +- **语言条件**:支持VLA模型的语言条件奖励 + +**核心接口:** + +```python +class BaseReward: + """奖励函数基类""" + + def __init__(self, **kwargs): + """ + 初始化奖励函数 + + Args: + **kwargs: 奖励函数特定的参数 + """ + pass + + def compute( + self, + state: MetaObs, + action: MetaAction, + next_state: MetaObs, + env_reward: float, + info: Optional[Dict[str, Any]] = None + ) -> float: + """ + 计算奖励 + + Args: + state: 当前状态(MetaObs) + action: 动作(MetaAction) + next_state: 下一个状态(MetaObs) + env_reward: 环境原始奖励 + info: 额外信息字典 + + Returns: + 计算后的奖励值 + """ + raise NotImplementedError + + def reset(self, **kwargs) -> None: + """ + 重置奖励函数状态(如果需要) + + Args: + **kwargs: 重置参数 + """ + pass +``` + +**具体实现示例:** + +```python +# rl/rewards/sparse_reward.py +class SparseReward(BaseReward): + """稀疏奖励:只在任务完成时给予奖励""" + + def __init__(self, success_reward: float = 1.0, **kwargs): + super().__init__(**kwargs) + self.success_reward = success_reward + + def compute(self, state, action, next_state, env_reward, info): + if info and info.get('success', False): + return self.success_reward + return 0.0 + +# rl/rewards/dense_reward.py +class DenseReward(BaseReward): + """密集奖励:基于状态距离的奖励""" + + def __init__(self, goal_key: str = 'goal', distance_scale: float = 1.0, **kwargs): + super().__init__(**kwargs) + self.goal_key = goal_key + self.distance_scale = distance_scale + + def compute(self, state, action, next_state, env_reward, info): + # 计算到目标的距离 + if hasattr(state, 'state') and hasattr(next_state, 'state'): + goal = info.get(self.goal_key, None) + if goal is not None: + prev_dist = np.linalg.norm(state.state - goal) + curr_dist = np.linalg.norm(next_state.state - goal) + progress = (prev_dist - curr_dist) * self.distance_scale + return env_reward + progress + return env_reward + +# rl/rewards/language_reward.py +class LanguageReward(BaseReward): + """语言条件奖励:基于语言指令的奖励(VLA)""" + + def __init__(self, reward_model=None, **kwargs): + 
super().__init__(**kwargs) + self.reward_model = reward_model + + def compute(self, state, action, next_state, env_reward, info): + if self.reward_model is None: + return env_reward + + # 使用奖励模型计算语言条件奖励 + if hasattr(state, 'raw_lang') and state.raw_lang is not None: + lang_reward = self.reward_model.compute( + state=state, + action=action, + next_state=next_state, + instruction=state.raw_lang + ) + return env_reward + lang_reward + return env_reward +``` + +--- + +### 4. 数据收集器基类 (`BaseCollector`) + +**职责:** 从环境中收集交互数据并存储到replay buffer + +**设计理念:** +- **职责分离**:将数据收集逻辑从trainer中分离,使trainer更专注于训练循环协调 +- **环境抽象**:支持单个环境、并行环境、多环境类型 +- **原始数据存储**:只存储环境原始奖励,不进行奖励函数计算,保证数据的原始性 +- **统计信息**:收集并返回episode统计信息 + +**核心接口:** + +```python +class BaseCollector: + """数据收集器基类""" + + def __init__( + self, + meta_envs: Union[MetaEnv, List[MetaEnv], Callable, Dict[str, Any]], + algorithm: BaseAlgorithm, + **kwargs + ): + """ + 初始化收集器 + + Args: + meta_envs: 支持多种格式: + - MetaEnv实例:单个环境 + - List[MetaEnv]:环境列表(同类型环境) + - Callable:环境工厂函数 + - Dict[str, Any]:多环境配置字典(支持不同类型环境) + algorithm: BaseAlgorithm实例(必需属性) + - 用于选择动作和记录transition + **kwargs: 收集器特定的参数 + + 注意:Collector只存储环境原始奖励,不进行奖励函数计算 + 奖励函数在trainer中用于训练时的奖励计算 + """ + self.meta_envs = meta_envs + self.algorithm = algorithm + + def collect(self, n_steps: int, env_type: Optional[str] = None) -> Dict[str, Any]: + """ + 收集 n_steps 的交互数据 + + Args: + n_steps: 收集的步数 + env_type: 可选,环境类型标识(用于多环境场景) + - 如果提供,会在record_transition时传入env_type + - 用于支持单个算法在多个不同环境的数据存储 + + Returns: + 包含统计信息的字典,如: + - episode_rewards: episode奖励列表 + - episode_lengths: episode长度列表 + - total_steps: 总步数 + - env_type_stats: 按环境类型分组的统计信息(如果支持多环境) + """ + raise NotImplementedError + + def reset(self, **kwargs) -> None: + """ + 重置收集器状态(如重置环境) + + Args: + **kwargs: 重置参数 + """ + raise NotImplementedError +``` + +**具体实现示例:** + +```python +# rl/collectors/sim_collector.py +class SimCollector(BaseCollector): + """仿真环境数据收集器""" + + def __init__( + self, + meta_envs: Union[MetaEnv, 
List[MetaEnv]], + algorithm: BaseAlgorithm, + n_envs: int = 1, + **kwargs + ): + super().__init__(meta_envs, algorithm, **kwargs) + self.n_envs = n_envs + + # 初始化环境 + if isinstance(meta_envs, list): + self.envs = meta_envs + elif isinstance(meta_envs, MetaEnv): + self.envs = [meta_envs] + else: + raise ValueError(f"Unsupported env type: {type(meta_envs)}") + + self._last_obs = None + self._last_dones = None + + def reset(self, **kwargs) -> None: + """重置所有环境""" + self._last_obs = [] + self._last_dones = [] + for env in self.envs: + obs = env.reset() + self._last_obs.append(obs) + self._last_dones.append(False) + + def collect(self, n_steps: int, env_type: Optional[str] = None) -> Dict[str, Any]: + """ + 收集 n_steps 的交互数据 + + Args: + n_steps: 收集的步数 + env_type: 可选,环境类型标识(用于多环境场景) + - 如果提供,会在record_transition时传入env_type + - 用于支持单个算法在多个不同环境的数据存储 + + Returns: + 统计信息字典 + """ + if self._last_obs is None: + self.reset() + + stats = { + 'episode_rewards': [], + 'episode_lengths': [], + 'total_steps': 0 + } + + for step in range(n_steps): + # 获取动作 + actions = [] + for i, obs in enumerate(self._last_obs): + if not self._last_dones[i]: + with torch.no_grad(): + action = self.algorithm.select_action(obs) + actions.append(action) + else: + # 如果环境已结束,使用dummy action + actions.append(None) + + # 环境交互 + new_obs_list = [] + rewards = [] + dones = [] + infos = [] + + for i, (env, action) in enumerate(zip(self.envs, actions)): + if action is not None: + new_obs, reward, done, info = env.step(action) + + # 记录transition(只存储环境原始奖励,不进行奖励函数计算) + # 如果提供了env_type,会在record_transition时传入,用于多环境数据分离 + transition_kwargs = {} + if env_type is not None: + transition_kwargs['env_type'] = env_type + + self.algorithm.record_transition( + state=self._last_obs[i], + action=action, + reward=reward, # 存储原始奖励 + next_state=new_obs, + done=done, + info=info, + **transition_kwargs # 传入env_type等额外信息 + ) + + new_obs_list.append(new_obs) + rewards.append(reward) # 统计使用原始奖励 + dones.append(done) + infos.append(info) 
+ + # 统计episode信息 + if done and 'episode' in info: + stats['episode_rewards'].append(info['episode'].get('r', 0)) + stats['episode_lengths'].append(info['episode'].get('l', 0)) + + # 如果episode结束,重置环境 + if done: + new_obs = env.reset() + new_obs_list[i] = new_obs + else: + new_obs_list.append(self._last_obs[i]) + rewards.append(0) + dones.append(True) + infos.append({}) + + self._last_obs = new_obs_list + self._last_dones = dones + stats['total_steps'] += len([d for d in dones if not d]) + + return stats + +# rl/collectors/real_collector.py +class RealCollector(BaseCollector): + """真实机器人数据收集器""" + + def __init__( + self, + meta_envs: MetaEnv, # 真实机器人环境 + algorithm: BaseAlgorithm, + **kwargs + ): + super().__init__(meta_envs, algorithm, **kwargs) + # 真实机器人特定的初始化... + + def collect(self, n_steps: int) -> Dict[str, Any]: + """ + 从真实机器人收集数据 + + 注意:真实机器人收集可能需要特殊的安全检查和限制 + """ + # 实现真实机器人数据收集逻辑 + # 可能需要添加安全限制、速度限制等 + pass +``` + +--- + +### 5. 训练器基类 (`BaseTrainer`) + +**职责:** 协调环境、策略和算法,执行训练循环 + +**核心接口:** + +```python +class BaseTrainer: + """RL训练器基类""" + + def __init__( + self, + meta_envs: Union[MetaEnv, List[MetaEnv], Callable, Dict[str, Any]], + algorithm: Union[BaseAlgorithm, List[BaseAlgorithm]], + collector: Optional[BaseCollector] = None, + reward_fn: Optional[Union[BaseReward, Callable]] = None, + **kwargs + ): + """ + 初始化训练器 + + Args: + meta_envs: 支持多种格式: + - MetaEnv实例:单个环境 + - List[MetaEnv]:环境列表(同类型环境) + - Callable:环境工厂函数 + - Dict[str, Any]:多环境配置字典(支持不同类型环境) + algorithm: BaseAlgorithm实例或BaseAlgorithm列表(必需属性) + - 单个算法:单个智能体训练 + - 算法列表:多个算法在同一个环境中独立训练(每个算法有自己的replay buffer) + collector: 可选的数据收集器(如果为None,trainer会创建默认collector) + - 单个算法:单个collector + - 多个算法:可以是collector列表,每个算法对应一个collector + - 如果为None,trainer会为每个算法创建默认collector + reward_fn: 可选的奖励函数(如果为None,使用环境原始奖励) + - 可以是BaseReward实例或Callable函数 + - 在训练时用于计算奖励(用于算法更新) + - 注意:replay buffer中存储的是原始奖励,奖励函数只在训练时应用 + **kwargs: 训练器特定的参数 + """ + self.meta_envs = meta_envs + self.algorithm = algorithm + self.reward_fn = 
reward_fn + + # 处理collector初始化 + if collector is None: + from rl.collectors import SimCollector + if isinstance(algorithm, BaseAlgorithm): + # 单个算法:创建单个collector + collector = SimCollector( + meta_envs=meta_envs, + algorithm=algorithm + ) + else: + # 多个算法:为每个算法创建collector + collector = [ + SimCollector( + meta_envs=meta_envs, + algorithm=alg + ) + for alg in algorithm + ] + + self.collector = collector + + def train(self, **kwargs) -> None: + """ + 执行训练循环 + + Args: + **kwargs: 训练参数,可以包括: + - total_steps: 总训练步数(可选) + - total_episodes: 总episode数(可选) + - max_time: 最大训练时间(可选) + - log_interval: 日志记录间隔(可选) + - save_interval: 模型保存间隔(可选) + - eval_interval: 评估间隔(可选) + """ + raise NotImplementedError + + def compute_reward( + self, + state: MetaObs, + action: MetaAction, + next_state: MetaObs, + env_reward: float, + info: Optional[Dict[str, Any]] = None + ) -> float: + """ + 计算奖励(支持自定义奖励函数) + + 在训练时使用,用于算法更新时的奖励计算 + replay buffer中存储的是原始奖励,奖励函数只在训练时应用 + + Args: + state: 当前状态 + action: 动作 + next_state: 下一个状态 + env_reward: 环境原始奖励 + info: 额外信息字典 + + Returns: + 计算后的奖励值 + """ + if self.reward_fn is not None: + if isinstance(self.reward_fn, BaseReward): + return self.reward_fn.compute(state, action, next_state, env_reward, info) + else: + return self.reward_fn(state, action, next_state, env_reward, info) + return env_reward + + def collect_rollout(self, n_steps: int, env_type: Optional[str] = None) -> Union[Dict[str, Any], List[Dict[str, Any]]]: + """ + 收集rollout数据(使用collector) + + Args: + n_steps: 收集步数 + env_type: 可选,环境类型标识(用于多环境场景) + - 用于支持单个算法在多个不同环境的数据存储 + + Returns: + - 单个算法:rollout统计信息字典 + - 多个算法:rollout统计信息字典列表 + """ + if isinstance(self.collector, list): + # 多个算法:每个算法独立收集数据 + return [col.collect(n_steps, env_type=env_type) for col in self.collector] + else: + # 单个算法 + return self.collector.collect(n_steps, env_type=env_type) + + def evaluate( + self, + n_episodes: int = 10, + render: bool = False, + env_type: Optional[str] = None, + **kwargs + ) -> Dict[str, Any]: + """ + 
评估策略性能 + + Args: + n_episodes: 评估的episode数量 + render: 是否渲染环境(可选) + env_type: 可选,指定评估的环境类型 + **kwargs: 其他评估参数 + + Returns: + 包含评估指标的字典 + """ + raise NotImplementedError + + def save(self, path: str) -> None: + """保存模型和训练状态""" + raise NotImplementedError + + def load(self, path: str) -> None: + """加载模型和训练状态""" + raise NotImplementedError +``` + +--- + +## 配置系统 + +### 配置结构 + +RL配置遵循ILStudio的配置系统设计,使用YAML格式,支持命令行覆盖。 + +#### `configs/rl/ppo.yaml` + +```yaml +name: ppo +type: rl.algorithms.ppo.PPOAlgorithm + +# 算法参数 +algorithm: + gamma: 0.99 # 折扣因子 + gae_lambda: 0.95 # GAE lambda + clip_range: 0.2 # PPO clip range + value_coef: 0.5 # Value loss 系数 + entropy_coef: 0.01 # 熵正则化系数 + max_grad_norm: 0.5 # 梯度裁剪 + n_steps: 2048 # 每次更新的步数 + batch_size: 64 # Mini-batch 大小 + n_epochs: 10 # 每次更新的 epoch 数 + learning_rate: 3e-4 + +# Replay Buffer配置(on-policy算法使用RolloutBuffer) +replay: + type: rl.buffer.rollout_buffer.RolloutBuffer + capacity: 2048 + device: cpu + n_envs: 8 + gae_lambda: 0.95 + gamma: 0.99 + +# 奖励函数配置(可选,在trainer中使用,用于训练时的奖励计算) +reward: + type: rl.rewards.dense_reward.DenseReward + goal_key: goal + distance_scale: 1.0 + # 或者使用组合奖励 + # type: rl.rewards.composite_reward.CompositeReward + # components: + # - type: rl.rewards.sparse_reward.SparseReward + # success_reward: 1.0 + # - type: rl.rewards.dense_reward.DenseReward + # distance_scale: 0.1 + +# 数据收集器配置(可选,trainer会创建默认collector) +collector: + type: rl.collectors.sim_collector.SimCollector + n_envs: 8 + +# 策略网络配置(复用现有 policy 配置) +policy: + type: policy.act # 或 policy.diffusion_policy, policy.openvla 等 + # 继承对应 policy 的配置... 
+ +# 价值网络配置(可选,Actor-Critic算法需要) +value_network: + type: mlp + hidden_dims: [256, 256] + activation: relu + +# 训练配置 +training: + total_timesteps: 1000000 + save_freq: 10000 + eval_freq: 5000 + log_interval: 100 +``` + +#### `configs/rl/sac.yaml` + +```yaml +name: sac +type: rl.algorithms.sac.SACAlgorithm + +# 算法参数 +algorithm: + gamma: 0.99 + tau: 0.005 # Soft update coefficient + learning_rate: 3e-4 + batch_size: 256 + target_update_interval: 1 + alpha: 0.2 # Temperature parameter (auto-tuned if None) + +# Replay Buffer配置(off-policy算法使用ReplayBuffer) +replay: + type: rl.buffer.memory_replay.MemoryReplay + capacity: 1000000 + device: cpu + prioritized: false # 是否使用优先级采样 + +# 奖励函数配置(可选) +reward: + type: rl.rewards.dense_reward.DenseReward + goal_key: goal + distance_scale: 1.0 + +# 策略网络配置 +policy: + type: policy.act + +# Q网络配置 +q_network: + type: mlp + hidden_dims: [256, 256] + activation: relu + +# 训练配置 +training: + total_timesteps: 1000000 + save_freq: 10000 + eval_freq: 5000 + log_interval: 100 +``` + +#### `configs/rl/dpo.yaml` (VLA Fine-tuning) + +```yaml +name: dpo +type: rl.algorithms.dpo.DPOAlgorithm + +# 算法参数 +algorithm: + beta: 0.1 # KL penalty coefficient + learning_rate: 1e-5 + batch_size: 32 + reference_free: false # 是否使用无参考模型的 DPO + label_smoothing: 0.0 + +# 策略网络配置(VLA模型) +policy: + type: policy.openvla + # 继承对应 policy 的配置... + +# 参考策略配置(用于KL散度计算) +reference_policy: + type: policy.openvla + # 通常是预训练模型的frozen副本 + +# 奖励函数配置(语言条件奖励) +reward: + type: rl.rewards.language_reward.LanguageReward + reward_model: + type: learned_reward + checkpoint: ckpt/reward_model.pth + +# 训练配置 +training: + total_episodes: 1000 + save_freq: 100 + eval_freq: 50 + log_interval: 10 +``` + +### 配置加载 + +```python +# 在ConfigLoader中添加RL配置加载方法 +class ConfigLoader: + # ... 现有方法 ... 
+ + def load_rl(self, name_or_path: str) -> Tuple[Dict[str, Any], str]: + """ + 加载RL配置 + + Args: + name_or_path: RL配置名称或路径(如'ppo'或'rl/ppo') + + Returns: + (配置字典, 配置文件路径) + """ + return self.load_yaml_config('rl', name_or_path) +``` + +--- + +## VLA 强化学习特殊处理 + +### DPO (Direct Preference Optimization) + +```python +# rl/algorithms/dpo.py +from dataclasses import dataclass +import torch +import torch.nn.functional as F +from rl.base import BaseAlgorithm +from benchmark.base import MetaPolicy + +@dataclass +class DPOConfig: + """DPO 特定配置""" + beta: float = 0.1 + reference_free: bool = False + label_smoothing: float = 0.0 + learning_rate: float = 1e-5 + batch_size: int = 32 + +class DPOAlgorithm(BaseAlgorithm): + """ + Direct Preference Optimization for VLA + + 适用于: + - 有偏好数据 (chosen vs rejected trajectories) + - Fine-tuning 预训练 VLA 模型 + """ + + def __init__( + self, + meta_policy: MetaPolicy, + ref_policy: MetaPolicy, # 参考策略 (frozen) + config: DPOConfig, + **kwargs + ): + super().__init__(meta_policy=meta_policy, **kwargs) + self.ref_policy = ref_policy + self.config = config + + # 冻结参考策略 + self.ref_policy.eval() + for p in self.ref_policy.parameters(): + p.requires_grad = False + + def compute_dpo_loss( + self, + chosen_obs: MetaObs, + chosen_actions: MetaAction, + rejected_obs: MetaObs, + rejected_actions: MetaAction + ) -> Dict[str, torch.Tensor]: + """ + 计算 DPO 损失 + + L_DPO = -log(σ(β * (log π(a_w|s) - log π_ref(a_w|s) + - log π(a_l|s) + log π_ref(a_l|s)))) + """ + # 计算当前策略的 log prob + chosen_logps = self.meta_policy.get_log_prob(chosen_obs, chosen_actions) + rejected_logps = self.meta_policy.get_log_prob(rejected_obs, rejected_actions) + + # 计算参考策略的 log prob + with torch.no_grad(): + ref_chosen_logps = self.ref_policy.get_log_prob(chosen_obs, chosen_actions) + ref_rejected_logps = self.ref_policy.get_log_prob(rejected_obs, rejected_actions) + + # DPO 损失 + chosen_rewards = self.config.beta * (chosen_logps - ref_chosen_logps) + rejected_rewards = self.config.beta * 
(rejected_logps - ref_rejected_logps) + + loss = -F.logsigmoid(chosen_rewards - rejected_rewards).mean() + + return { + 'loss': loss, + 'chosen_rewards': chosen_rewards.mean(), + 'rejected_rewards': rejected_rewards.mean(), + 'reward_margin': (chosen_rewards - rejected_rewards).mean() + } + + def update(self, batch: Optional[Dict[str, Any]] = None, **kwargs) -> Dict[str, Any]: + """DPO更新逻辑""" + if batch is None: + raise ValueError("DPO requires batch with chosen/rejected pairs") + + # 提取chosen和rejected数据 + chosen_obs = batch['chosen_obs'] + chosen_actions = batch['chosen_actions'] + rejected_obs = batch['rejected_obs'] + rejected_actions = batch['rejected_actions'] + + # 计算损失 + loss_dict = self.compute_dpo_loss( + chosen_obs, chosen_actions, + rejected_obs, rejected_actions + ) + + # 反向传播 + loss_dict['loss'].backward() + + return loss_dict +``` + +### GRPO (Group Relative Policy Optimization) + +```python +# rl/algorithms/grpo.py +from dataclasses import dataclass +import torch +from rl.base import BaseAlgorithm + +@dataclass +class GRPOConfig: + """GRPO 配置""" + group_size: int = 4 + kl_coef: float = 0.1 + reward_scale: float = 1.0 + learning_rate: float = 1e-5 + +class GRPOAlgorithm(BaseAlgorithm): + """ + Group Relative Policy Optimization + + 适用于 VLA 的在线强化学习: + 1. 对每个任务/指令采样多个轨迹 + 2. 使用奖励对轨迹进行排序 + 3. 
使用组内相对奖励进行策略优化 + """ + + def __init__( + self, + meta_policy: MetaPolicy, + ref_policy: MetaPolicy, # 参考策略 (frozen) + config: GRPOConfig, + **kwargs + ): + super().__init__(meta_policy=meta_policy, **kwargs) + self.ref_policy = ref_policy + self.config = config + + # 冻结参考策略 + self.ref_policy.eval() + for p in self.ref_policy.parameters(): + p.requires_grad = False + + def collect_group_rollouts(self, obs: MetaObs, n_samples: int) -> List[Dict]: + """ + 对同一观测采样多个动作序列 + + Args: + obs: 初始观测 + n_samples: 采样数量 + + Returns: + 轨迹列表 + """ + trajectories = [] + for _ in range(n_samples): + traj = self._rollout_episode(obs) + trajectories.append(traj) + return trajectories + + def compute_grpo_loss( + self, + trajectories: List[Dict], + rewards: List[float] + ) -> Dict[str, torch.Tensor]: + """ + 计算 GRPO 损失 + + 使用组内相对奖励作为 advantage + """ + # 归一化组内奖励 + rewards_tensor = torch.tensor(rewards) + normalized_rewards = (rewards_tensor - rewards_tensor.mean()) / (rewards_tensor.std() + 1e-8) + + loss = 0 + for traj, adv in zip(trajectories, normalized_rewards): + log_probs = self.meta_policy.get_trajectory_log_prob(traj) + loss -= (log_probs * adv).mean() + + # KL 惩罚 + kl_loss = self._compute_kl_penalty(trajectories) + + total_loss = loss + self.config.kl_coef * kl_loss + + return { + 'loss': total_loss, + 'policy_loss': loss, + 'kl_loss': kl_loss, + 'mean_reward': rewards_tensor.mean() + } + + def _compute_kl_penalty(self, trajectories: List[Dict]) -> torch.Tensor: + """计算KL散度惩罚""" + kl_loss = 0 + for traj in trajectories: + current_logps = self.meta_policy.get_trajectory_log_prob(traj) + with torch.no_grad(): + ref_logps = self.ref_policy.get_trajectory_log_prob(traj) + kl = (current_logps - ref_logps).mean() + kl_loss += kl + return kl_loss / len(trajectories) +``` + +--- + +## 训练入口 `train_rl.py` + +```python +#!/usr/bin/env python3 +""" +ILStudio Reinforcement Learning Training Script + +支持: +- 传统 RL (PPO, SAC) 训练机器人控制策略 +- VLA fine-tuning (DPO, GRPO) 使用强化学习 +- 混合训练 (IL + RL) 
+""" + +import argparse +from loguru import logger +from configs.loader import ConfigLoader +from data_utils.utils import set_seed +from policy.policy_loader import PolicyLoader +from rl.algorithms import get_algorithm_class +from rl.buffer import get_replay_class +from rl.rewards import get_reward_class +from rl.collectors import get_collector_class +from rl.trainers import get_trainer_class + +def parse_args(): + parser = argparse.ArgumentParser(description='RL Training for ILStudio') + + # 基础配置 + parser.add_argument('-p', '--policy', type=str, required=True, + help='Policy config (e.g., act, diffusion_policy, openvla)') + parser.add_argument('-r', '--rl_config', type=str, default='ppo', + help='RL algorithm config (e.g., ppo, sac, dpo)') + parser.add_argument('-e', '--env', type=str, required=True, + help='Environment config') + parser.add_argument('-o', '--output_dir', type=str, default='ckpt/rl_output') + + # 训练模式 + parser.add_argument('--mode', type=str, default='online', + choices=['online', 'offline', 'hybrid'], + help='Training mode: online (env interaction), offline (dataset), hybrid') + + # 预训练模型 (用于 fine-tuning) + parser.add_argument('--pretrained', type=str, default=None, + help='Path to pretrained model checkpoint') + + # 其他 + parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--device', type=str, default='cuda') + + args, unknown = parser.parse_known_args() + args.unknown_args = unknown + return args + + +def create_algorithm(rl_config, policy, env_config, args): + """创建RL算法实例""" + # 创建replay buffer(如果配置中有) + replay = None + if 'replay' in rl_config: + replay_class = get_replay_class(rl_config['replay']['type']) + replay = replay_class(**rl_config['replay']) + + # 创建算法实例(奖励函数不在算法中,在trainer中) + algorithm_class = get_algorithm_class(rl_config['type']) + algorithm = algorithm_class( + meta_policy=policy, + replay=replay, + **rl_config.get('algorithm', {}) + ) + + return algorithm + + +def create_reward_fn(rl_config): + 
"""创建奖励函数实例(在trainer中使用,用于训练时的奖励计算)""" + reward_fn = None + if 'reward' in rl_config: + reward_class = get_reward_class(rl_config['reward']['type']) + reward_fn = reward_class(**rl_config['reward']) + return reward_fn + + +def create_collector(rl_config, env, algorithm): + """创建数据收集器实例(不包含reward_fn,只存储原始奖励)""" + collector = None + if 'collector' in rl_config: + collector_class = get_collector_class(rl_config['collector']['type']) + collector = collector_class( + meta_envs=env, + algorithm=algorithm, + **rl_config['collector'] + ) + return collector + + +def main(): + args = parse_args() + set_seed(args.seed) + + # 加载配置 + cfg_loader = ConfigLoader(args=args, unknown_args=args.unknown_args) + policy_config, _ = cfg_loader.load_policy(args.policy) + rl_config, _ = cfg_loader.load_rl(args.rl_config) + env_config, _ = cfg_loader.load_env(args.env) + + # 创建环境 + env = create_env_from_config(env_config) + + # 加载/创建策略 + policy_loader = PolicyLoader() + if args.pretrained: + logger.info(f"Loading pretrained model from {args.pretrained}") + policy = policy_loader.load_pretrained(args.pretrained, policy_config) + else: + logger.info("Creating policy from scratch") + policy = policy_loader.create_policy(policy_config, args) + + # 创建RL算法 + algorithm = create_algorithm(rl_config, policy, env_config, args) + + # 创建奖励函数(在trainer中使用,用于训练时的奖励计算) + reward_fn = create_reward_fn(rl_config) + + # 创建数据收集器(可选,trainer会创建默认的,不包含reward_fn) + collector = create_collector(rl_config, env, algorithm) + + # 创建训练器 + trainer_class = get_trainer_class('simple') # 或从配置中读取 + trainer = trainer_class( + meta_envs=env, + algorithm=algorithm, + collector=collector, # collector在trainer中(只存储原始奖励) + reward_fn=reward_fn # reward_fn在trainer中(用于训练时的奖励计算) + ) + + # 训练 + logger.info("="*60) + logger.info(f"🚀 Starting RL Training: {rl_config['name']}") + logger.info(f" Policy: {policy_config.get('name', args.policy)}") + logger.info(f" Environment: {env_config.get('name', args.env)}") + logger.info(f" Mode: 
{args.mode}") + logger.info("="*60) + + trainer.train( + total_steps=rl_config['training']['total_timesteps'], + log_interval=rl_config['training'].get('log_interval', 100), + save_interval=rl_config['training'].get('save_freq', 10000), + eval_interval=rl_config['training'].get('eval_freq', 5000), + output_dir=args.output_dir + ) + + logger.info("✓ Training completed!") + + +if __name__ == '__main__': + main() +``` + +--- + +## 使用示例 + +### 示例1:传统RL训练(PPO + ACT) + +```bash +python train_rl.py \ + -p act \ + -r ppo \ + -e libero.example \ + -o ckpt/rl_act_ppo +``` + +### 示例2:VLA Fine-tuning(DPO + OpenVLA) + +```bash +python train_rl.py \ + -p openvla \ + -r dpo \ + -e behavior1k.example \ + --pretrained ckpt/openvla_pretrained \ + -o ckpt/openvla_dpo +``` + +### 示例3:使用自定义奖励函数 + +```python +# 在配置文件中指定奖励函数 +# configs/rl/ppo.yaml +reward: + type: rl.rewards.composite_reward.CompositeReward + components: + - type: rl.rewards.sparse_reward.SparseReward + success_reward: 1.0 + weight: 0.5 + - type: rl.rewards.dense_reward.DenseReward + goal_key: goal + distance_scale: 0.1 + weight: 0.5 +``` + +### 示例4:单个算法在多个不同环境的训练(场景1) + +```python +# 场景1:单个算法在多个不同环境(如室内、室外)中训练 +# 数据按环境类型分离存储到不同的replay buffer + +from rl.buffer import MemoryReplay +from rl.algorithms import SACAlgorithm +from rl.collectors import SimCollector +from rl.trainers import SimpleTrainer + +# 创建多个环境的replay buffer(按环境类型分离) +replay = { + 'indoor': MemoryReplay(capacity=1000000, device='cpu'), + 'outdoor': MemoryReplay(capacity=1000000, device='cpu') +} + +# 创建算法(使用多个replay buffer) +algorithm = SACAlgorithm( + meta_policy=policy, + replay=replay # Dict[str, BaseReplay],按环境类型分离 +) + +# 创建多个环境的collector +indoor_collector = SimCollector( + meta_envs=indoor_envs, # 室内环境 + algorithm=algorithm +) + +outdoor_collector = SimCollector( + meta_envs=outdoor_envs, # 室外环境 + algorithm=algorithm +) + +# 创建trainer +trainer = SimpleTrainer( + meta_envs={'indoor': indoor_envs, 'outdoor': outdoor_envs}, + algorithm=algorithm, + 
reward_fn=reward_fn +) + +# 训练时,分别从不同环境收集数据 +# 数据会自动存储到对应环境类型的replay buffer中 +for step in range(total_steps): + # 从室内环境收集数据(env_type='indoor') + indoor_stats = indoor_collector.collect(n_steps=1000, env_type='indoor') + + # 从室外环境收集数据(env_type='outdoor') + outdoor_stats = outdoor_collector.collect(n_steps=1000, env_type='outdoor') + + # 更新算法(可以从不同环境类型的replay混合采样) + loss = algorithm.update( + batch_size=256, + env_types=['indoor', 'outdoor'], # 指定从哪些环境采样 + env_weights={'indoor': 0.6, 'outdoor': 0.4} # 不同环境的数据权重 + ) +``` + +### 示例5:多个算法在同一个环境中独立训练(场景2) + +```python +# 场景2:多个算法在同一个环境中独立训练 +# 每个算法有自己的replay buffer和collector + +from rl.buffer import MemoryReplay +from rl.algorithms import SACAlgorithm, PPOAlgorithm +from rl.collectors import SimCollector +from rl.trainers import SimpleTrainer + +# 创建多个算法,每个算法有自己的replay buffer +algorithm1 = SACAlgorithm( + meta_policy=policy1, + replay=MemoryReplay(capacity=1000000, device='cpu') +) + +algorithm2 = PPOAlgorithm( + meta_policy=policy2, + replay=MemoryReplay(capacity=100000, device='cpu') +) + +# 创建多个collector,每个算法对应一个collector +collector1 = SimCollector( + meta_envs=env, # 同一个环境 + algorithm=algorithm1 +) + +collector2 = SimCollector( + meta_envs=env, # 同一个环境 + algorithm=algorithm2 +) + +# 创建trainer(传入算法列表和collector列表) +trainer = SimpleTrainer( + meta_envs=env, + algorithm=[algorithm1, algorithm2], # 多个算法 + collector=[collector1, collector2], # 每个算法对应一个collector + reward_fn=reward_fn +) + +# 训练时,每个算法独立收集数据和更新 +for step in range(total_steps): + # 收集rollout(返回每个算法的统计信息) + stats_list = trainer.collect_rollout(n_steps=1000) + # stats_list[0] 是 algorithm1 的统计信息 + # stats_list[1] 是 algorithm2 的统计信息 + + # 每个算法独立更新 + loss1 = algorithm1.update(batch_size=256) + loss2 = algorithm2.update(batch_size=64) +``` + +### 示例6:配置文件中支持多环境场景 + +```yaml +# configs/rl/multi_env_sac.yaml +name: multi_env_sac +type: rl.algorithms.sac.SACAlgorithm + +# 算法参数 +algorithm: + gamma: 0.99 + learning_rate: 3e-4 + batch_size: 256 + +# Replay 
Buffer配置(按环境类型分离) +replay: + indoor: + type: rl.buffer.memory_replay.MemoryReplay + capacity: 1000000 + device: cpu + outdoor: + type: rl.buffer.memory_replay.MemoryReplay + capacity: 1000000 + device: cpu + +# 策略网络配置 +policy: + type: policy.act + +# 训练配置 +training: + total_timesteps: 1000000 + env_types: ['indoor', 'outdoor'] + env_weights: + indoor: 0.6 + outdoor: 0.4 +``` + +--- + +## 接口关系图 + +``` +┌─────────────────────────────────────────┐ +│ BaseReplay │ +│ (存储经验数据) │ +│ - 完整MetaObs/MetaAction │ +│ - 自定义字段(value、log_prob等) │ +└─────────────────────────────────────────┘ + ▲ + │ (可选,在算法中) + │ +┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ +│ BaseAlgorithm │ │ BaseCollector │ │ BaseTrainer │ +│ │ │ │ │ │ +│ + meta_policy │◄─────┤ + algorithm │◄─────┤ + collector │ +│ + replay │ │ + meta_envs │ │ + algorithm │ +│ + update() │ │ + collect() │ │ + reward_fn │ +│ + record_ │ │ (只存储原始奖励) │ │ + compute_ │ +│ transition() │ └─────────────────┘ │ reward() │ +└─────────────────┘ │ │ + train() │ + │ │ + evaluate() │ + │ └─────────────────┘ + │ │ + ┌────────────┼────────────┐ │ + │ │ │ │ 使用 + │ 使用 │ 使用 │ │ + ▼ ▼ ▼ ▼ + ┌──────────┐ ┌──────────┐ ┌─────────────────┐ + │ MetaEnv │ │MetaPolicy│ │ BaseReward │ + │ │ │ │ │ (奖励函数) │ + └──────────┘ └──────────┘ │ - compute() │ + └─────────────────┘ +``` + +--- + +## 设计亮点 + +### 1. **完整的Meta数据存储** +- Replay buffer存储完整的MetaObs和MetaAction,包括所有字段 +- 不丢失任何信息,方便后续分析和扩展 + +### 2. **模块化奖励函数** +- 奖励函数独立模块,易于替换和扩展 +- 支持组合多个奖励函数 +- 支持VLA的语言条件奖励 +- **原始数据存储**:replay buffer存储环境原始奖励,奖励函数只在训练时应用 +- **数据可复用性**:同样的数据可以用不同的奖励函数进行训练 + +### 3. **灵活的配置系统** +- 使用YAML配置,支持命令行覆盖 +- 复用ILStudio现有的配置加载机制 +- 配置驱动,易于实验和部署 + +### 4. **兼容ILStudio pipeline** +- **原始数据存储**:replay buffer存储原始MetaObs、MetaAction和环境原始奖励 +- **奖励函数分离**:奖励函数在trainer中应用,不影响数据收集和存储 +- 采样时通过data_processor对齐ILStudio的normalization pipeline +- 既保持灵活性,又兼容现有系统 + +### 5. **VLA特殊支持** +- 专门的DPO和GRPO算法实现 +- 支持语言条件奖励 +- 支持参考策略的KL散度约束 + +### 6. 
**数据集持久化** +- 支持保存和加载数据集,方便数据管理和复用 +- 支持多种保存格式(pkl、hdf5、npz等) +- 支持追加加载,可以合并多个数据集 + +### 7. **灵活的数据存储策略** +- **场景1:单个算法在多个不同环境** + - 算法可以使用 `Dict[str, BaseReplay]` 按环境类型分离存储 + - 在 `record_transition` 时传入 `env_type` 参数 + - 支持从不同环境类型的replay混合采样,可配置数据权重 + - 适用于跨域学习、多任务学习等场景 + +- **场景2:多个算法在同一个环境** + - Trainer 可以接受 `List[BaseAlgorithm]`,每个算法独立训练 + - 每个算法有自己的 replay buffer 和 collector + - 支持算法对比、ensemble训练等场景 + - 数据完全隔离,互不干扰 + +--- + +## 总结 + +这个整合设计: + +1. ✅ **保持设计2的核心架构**:BaseReplay、BaseAlgorithm、BaseTrainer三个基类 +2. ✅ **添加奖励函数模块化**:独立的BaseReward基类和具体实现 +3. ✅ **添加配置系统**:YAML配置文件和配置加载机制 +4. ✅ **添加VLA支持**:DPO和GRPO算法实现 +5. ✅ **添加训练入口**:train_rl.py脚本设计 + +通过组合这些组件,可以灵活地实现各种RL训练场景,同时保持与ILStudio现有系统的兼容性。 + diff --git a/rl/__init__.py b/rl/__init__.py new file mode 100644 index 00000000..bc650612 --- /dev/null +++ b/rl/__init__.py @@ -0,0 +1,173 @@ +""" +ILStudio Reinforcement Learning Module + +This module provides a modular RL framework for robot learning, supporting: +- Traditional RL algorithms (PPO, SAC, TD3, etc.) +- VLA fine-tuning algorithms (DPO, GRPO) +- Flexible replay buffers for experience storage +- Modular reward functions +- Data collectors for various environments +- Trainers for different training scenarios +- Infrastructure for reproducibility and stability + +The framework is designed to: +1. Be highly abstract - only core interfaces, no specific implementations +2. Be universal - support various RL algorithms and training modes +3. Be compatible - directly use MetaEnv and MetaPolicy without adapters +4. Be extensible - easy to extend for parallel and distributed training +5. Be modular - reward functions and config systems are independently replaceable +6. 
Be reproducible - comprehensive infrastructure for reproducibility

+Directory Structure:
+    rl/
+    ├── __init__.py              # This file
+    ├── algorithms/              # RL algorithm implementations
+    │   ├── __init__.py          # Algorithm registry
+    │   ├── base.py              # BaseAlgorithm class
+    │   └── td3.py               # TD3 algorithm (auto-registered)
+    ├── buffer/                  # Replay buffer implementations
+    │   ├── __init__.py
+    │   └── base_replay.py       # BaseReplay class
+    ├── rewards/                 # Reward function implementations
+    │   ├── __init__.py
+    │   └── base_reward.py       # BaseReward class
+    ├── collectors/              # Data collector implementations
+    │   ├── __init__.py
+    │   └── base_collector.py    # BaseCollector class
+    ├── trainers/                # Trainer implementations
+    │   ├── __init__.py
+    │   └── base_trainer.py      # BaseTrainer class
+    ├── infra/                   # Infrastructure for reproducibility
+    │   ├── __init__.py
+    │   ├── seed_manager.py      # Random seed management
+    │   ├── logger.py            # Logging system
+    │   ├── checkpoint.py        # Checkpoint management
+    │   ├── callback.py          # Callback system
+    │   └── distributed.py       # Distributed training support
+    └── utils/                   # Utility functions
+        └── __init__.py
+"""
+
+# Base classes
+from .algorithms.base import BaseAlgorithm
+from .buffer import BaseReplay
+from .rewards import BaseReward
+from .collectors import BaseCollector
+from .trainers import BaseTrainer
+
+# Factory functions
+from .algorithms import get_algorithm_class, register_algorithm, list_algorithms
+from .rewards import get_reward_class, register_reward, list_rewards
+from .collectors import get_collector_class, register_collector, list_collectors
+from .trainers import get_trainer_class, register_trainer, list_trainers
+
+# Utility functions
+from .utils import (
+    compute_gae,
+    compute_returns,
+    RunningMeanStd,
+    explained_variance,
+    polyak_update,
+    hard_update,
+)
+
+# Infrastructure - Seed management
+from .infra import (
+    SeedManager,
+    set_global_seed,
+    get_global_seed,
+)
+
+# Infrastructure - Logging
+from .infra import (
+    BaseLogger,
+    ConsoleLogger,
+    TensorBoardLogger,
+    
CompositeLogger, +) + +# Infrastructure - Checkpoint +from .infra import CheckpointManager + +# Infrastructure - Callbacks +from .infra import ( + Callback, + CallbackList, + ProgressCallback, + EvalCallback, + CheckpointCallback, + EarlyStoppingCallback, +) + +# Infrastructure - Distributed +from .infra import ( + DistributedContext, + get_world_size, + get_rank, + is_main_process, +) + +__all__ = [ + # Base classes + 'BaseAlgorithm', + 'BaseReplay', + 'BaseReward', + 'BaseCollector', + 'BaseTrainer', + + # Factory functions for algorithms + 'get_algorithm_class', + 'register_algorithm', + 'list_algorithms', + + # Factory functions for rewards + 'get_reward_class', + 'register_reward', + 'list_rewards', + + # Factory functions for collectors + 'get_collector_class', + 'register_collector', + 'list_collectors', + + # Factory functions for trainers + 'get_trainer_class', + 'register_trainer', + 'list_trainers', + + # Utility functions + 'compute_gae', + 'compute_returns', + 'RunningMeanStd', + 'explained_variance', + 'polyak_update', + 'hard_update', + + # Infrastructure - Seed management + 'SeedManager', + 'set_global_seed', + 'get_global_seed', + + # Infrastructure - Logging + 'BaseLogger', + 'ConsoleLogger', + 'TensorBoardLogger', + 'CompositeLogger', + + # Infrastructure - Checkpoint + 'CheckpointManager', + + # Infrastructure - Callbacks + 'Callback', + 'CallbackList', + 'ProgressCallback', + 'EvalCallback', + 'CheckpointCallback', + 'EarlyStoppingCallback', + + # Infrastructure - Distributed + 'DistributedContext', + 'get_world_size', + 'get_rank', + 'is_main_process', +] diff --git a/rl/algorithms/__init__.py b/rl/algorithms/__init__.py new file mode 100644 index 00000000..d93542bf --- /dev/null +++ b/rl/algorithms/__init__.py @@ -0,0 +1,127 @@ +""" +RL Algorithms Module + +This module provides various RL algorithm implementations. 
+ +Available algorithms: +- PPO: Proximal Policy Optimization +- SAC: Soft Actor-Critic +- TD3: Twin Delayed DDPG +- DPO: Direct Preference Optimization (for VLA) +- GRPO: Group Relative Policy Optimization (for VLA) +- REINFORCE: Basic policy gradient + +Note: Implementations are provided in separate files. +This __init__.py provides factory functions for creating algorithms. +""" + +from typing import Type, Dict, Any, Optional, Tuple + +# Registry for algorithm classes and their config classes +_ALGORITHM_REGISTRY: Dict[str, Type] = {} +_CONFIG_REGISTRY: Dict[str, Type] = {} + + +def register_algorithm(name: str, algorithm_class: Type, config_class: Type = None) -> None: + """ + Register an algorithm class and its config class. + + Args: + name: Algorithm name (e.g., 'ppo', 'sac') + algorithm_class: Algorithm class to register + config_class: Config class for the algorithm (optional) + """ + _ALGORITHM_REGISTRY[name.lower()] = algorithm_class + if config_class is not None: + _CONFIG_REGISTRY[name.lower()] = config_class + + +def get_algorithm_class(name_or_type: str) -> Type: + """ + Get algorithm class by name or type string. + + Args: + name_or_type: Algorithm name (e.g., 'ppo') or full type path + (e.g., 'rl.algorithms.ppo.PPOAlgorithm') + + Returns: + Algorithm class + + Raises: + ValueError: If algorithm not found + """ + # First check registry + if name_or_type.lower() in _ALGORITHM_REGISTRY: + return _ALGORITHM_REGISTRY[name_or_type.lower()] + + # Try to import from type path + if '.' in name_or_type: + try: + parts = name_or_type.rsplit('.', 1) + module_path = parts[0] + class_name = parts[1] + + import importlib + module = importlib.import_module(module_path) + return getattr(module, class_name) + except (ImportError, AttributeError) as e: + raise ValueError(f"Cannot import algorithm from '{name_or_type}': {e}") + + raise ValueError(f"Unknown algorithm: '{name_or_type}'. 
Available: {list(_ALGORITHM_REGISTRY.keys())}") + + +def get_config_class(name: str) -> Optional[Type]: + """ + Get config class for an algorithm. + + Args: + name: Algorithm name (e.g., 'td3') + + Returns: + Config class or None if not registered + """ + return _CONFIG_REGISTRY.get(name.lower()) + + +def get_algorithm_and_config(name: str) -> Tuple[Type, Optional[Type]]: + """ + Get both algorithm class and config class. + + Args: + name: Algorithm name + + Returns: + Tuple of (algorithm_class, config_class) + """ + return get_algorithm_class(name), get_config_class(name) + + +def list_algorithms() -> list: + """List all registered algorithms.""" + return list(_ALGORITHM_REGISTRY.keys()) + + +__all__ = [ + 'register_algorithm', + 'get_algorithm_class', + 'get_config_class', + 'get_algorithm_and_config', + 'list_algorithms', +] + + +# Auto-register built-in algorithms when this module is imported +def _register_builtin_algorithms(): + """Import algorithm submodules to trigger their registration.""" + try: + from . import td3 # noqa: F401 + except ImportError: + pass + # Add more algorithms here as they are implemented + # try: + # from . import sac + # except ImportError: + # pass + + +_register_builtin_algorithms() diff --git a/rl/algorithms/base.py b/rl/algorithms/base.py new file mode 100644 index 00000000..2a4bcb68 --- /dev/null +++ b/rl/algorithms/base.py @@ -0,0 +1,514 @@ +""" +Base Algorithm Class + +This module defines the base class for all RL algorithms in the framework. 
+ +Design Philosophy (inspired by SKRL): +- Replay buffer is placed inside the algorithm, allowing: + - Each algorithm to have its own replay configuration + - A single Trainer to train multiple different algorithms (each with own replay) + - More flexibility, supporting multi-agent scenarios +""" + +import torch +import numpy as np +from typing import Dict, Any, Optional, Union, List, Callable +from abc import ABC, abstractmethod +from dataclasses import asdict, fields + +# Type hints for Meta classes (imported at runtime to avoid circular imports) +from benchmark.base import MetaObs, MetaAction, MetaPolicy +from rl.buffer.transition import RLTransition + +class BaseAlgorithm(ABC): + """ + Base class for RL algorithms. + + This class defines the core interface for all RL algorithms. + It holds a reference to the policy (MetaPolicy) and optionally a replay buffer. + + Attributes: + meta_policy: The MetaPolicy instance used for action selection + replay: Optional replay buffer(s) for experience storage + """ + + def __init__( + self, + meta_policy: MetaPolicy, + replay: Optional[Union['BaseReplay', Dict[str, 'BaseReplay']]] = None, + **kwargs + ): + """ + Initialize the algorithm. + + Args: + meta_policy: ILStudio's MetaPolicy instance (required) + replay: Supports two formats: + - BaseReplay instance: Single replay buffer (shared by all environments) + - Dict[str, BaseReplay]: Multiple replay buffers (separated by environment type) + - None: No replay buffer (for on-policy algorithms) + **kwargs: Algorithm-specific parameters + """ + self.meta_policy = meta_policy # Required attribute + self.replay = replay # Optional attribute (off-policy algorithms need this) + self._kwargs = kwargs + + @abstractmethod + def update(self, batch: Optional[Dict[str, Any]] = None, **kwargs) -> Dict[str, Any]: + """ + Update the policy using a batch of data. 
+ + Args: + batch: Optional, batch data + - If None and replay exists, sample from replay + - If provided, use the provided batch directly + **kwargs: Update parameters, can include: + - batch_size: Batch size when sampling from replay + - env_types: Specify which environment types to sample from + (when replay is Dict[str, BaseReplay]) + - e.g., ['indoor', 'outdoor'] + - If None, sample from all environment types + - env_weights: Data weights for different environment types + (when using multiple replay buffers) + - e.g., {'indoor': 0.6, 'outdoor': 0.4} + - If None, use uniform weights + + Returns: + Dictionary containing loss, metrics, and other information + + Examples: + # Scenario 1: Single replay buffer + algorithm.update(batch_size=256) + + # Scenario 2: Multiple replay buffers (by environment type) + algorithm.update( + batch_size=256, + env_types=['indoor', 'outdoor'], + env_weights={'indoor': 0.6, 'outdoor': 0.4} + ) + """ + raise NotImplementedError + + def compute_loss(self, batch: Dict[str, Any]) -> torch.Tensor: + """ + Compute loss (optional, some algorithms may need this). + + Args: + batch: Batch data + + Returns: + Loss value + """ + raise NotImplementedError("Subclass should implement compute_loss if needed") + + def select_action( + self, + obs: Union[MetaObs, List[MetaObs], np.ndarray], + **kwargs + ) -> Union[MetaAction, List[MetaAction]]: + """ + Select action(s) for given observation(s). + + Supports both single and batched inputs for vectorized environments. + The recommended way for vectorized envs is to use organized batch obs + (a single MetaObs with batched fields like (n_envs, state_dim)). 
+ + Args: + obs: Observation(s), can be: + - MetaObs: Single observation OR organized batch obs + (with batched fields like state: (n_envs, state_dim)) + - List[MetaObs] or np.ndarray: List of individual observations + (will be organized into batch internally) + **kwargs: Other parameters (e.g., t for timestep) + + Returns: + MetaAction: Single action or batched action + (with action field shape (n_envs, action_dim) for batch) + + Note: + For vectorized environments, it's recommended to: + 1. Use organize_obs() to convert List[MetaObs] -> batched MetaObs + 2. Pass the organized obs to this method + 3. The policy handles batched inference internally + """ + # Check if input is a list of observations that needs organizing + if isinstance(obs, (list, np.ndarray)) and len(obs) > 0: + first = obs[0] if isinstance(obs, list) else obs.flat[0] + if hasattr(first, '__dataclass_fields__'): # List of MetaObs + # Organize into batched MetaObs + obs = self._organize_obs(obs) + + # Pass to meta_policy (handles both single and batch) + return self.meta_policy.select_action(obs, **kwargs) + + def _organize_obs(self, obs_list: Union[List[MetaObs], np.ndarray]) -> MetaObs: + """ + Organize list of MetaObs into a single batched MetaObs. + + Converts List[MetaObs] -> MetaObs with batched fields. 
+ e.g., [MetaObs(state=(10,)), MetaObs(state=(10,))] -> MetaObs(state=(2, 10)) + + Args: + obs_list: List or array of MetaObs objects + + Returns: + MetaObs with batched fields + """ + if len(obs_list) == 0: + return None + + # Convert to list of dicts + obs_dicts = [] + for o in obs_list: + if hasattr(o, '__dataclass_fields__'): + obs_dicts.append(asdict(o)) + elif isinstance(o, dict): + obs_dicts.append(o) + else: + obs_dicts.append(vars(o) if hasattr(o, '__dict__') else {}) + + # Stack each field + all_keys = list(obs_dicts[0].keys()) + batched = {} + for k in all_keys: + values = [d[k] for d in obs_dicts] + if values[0] is None: + batched[k] = None + elif isinstance(values[0], np.ndarray): + batched[k] = np.stack(values) + elif isinstance(values[0], (int, float)): + batched[k] = np.array(values) + elif isinstance(values[0], str): + batched[k] = values # Keep as list for strings + else: + batched[k] = values # Keep as list for other types + + # Convert back to MetaObs + return MetaObs(**{k: v for k, v in batched.items() if k in [f.name for f in fields(MetaObs)]}) + + def record_transition( + self, + transition: 'RLTransition', + **kwargs + ) -> None: + """ + Record transition to replay buffer (if exists). 
+ + Args: + transition: RLTransition created by collector + **kwargs: Additional fields: + - env_type: Environment type (required for multi-buffer) + """ + if self.replay is None: + raise ValueError("Replay buffer is not set") + + from rl.buffer.transition import RLTransition + if not isinstance(transition, RLTransition): + raise TypeError("record_transition expects RLTransition") + + env_type = kwargs.get('env_type', None) + if isinstance(self.replay, dict): + if env_type is None: + raise ValueError("env_type must be provided when using multiple replay buffers") + if env_type not in self.replay: + raise ValueError(f"env_type '{env_type}' not found in replay buffers") + target_replay = self.replay[env_type] + else: + target_replay = self.replay + + target_replay.add(transition) + + def _stack_dicts(self, dict_list: List[Dict]) -> Dict: + """ + Stack a list of dicts into a single dict with batched values. + + Args: + dict_list: List of dicts with same keys + + Returns: + Dict with stacked values (numpy arrays where applicable) + """ + if not dict_list: + return {} + + result = {} + for key in dict_list[0].keys(): + values = [d.get(key) for d in dict_list] + if values[0] is None: + result[key] = None + elif isinstance(values[0], np.ndarray): + result[key] = np.stack(values) + elif isinstance(values[0], (int, float, bool)): + result[key] = np.array(values) + elif isinstance(values[0], str): + result[key] = values # Keep strings as list + else: + result[key] = values # Keep other types as list + + return result + + def get_policy(self) -> MetaPolicy: + """Get the underlying MetaPolicy.""" + return self.meta_policy + + def train_mode(self) -> None: + """Set policy to training mode.""" + if hasattr(self.meta_policy, 'policy') and hasattr(self.meta_policy.policy, 'train'): + self.meta_policy.policy.train() + + def eval_mode(self) -> None: + """Set policy to evaluation mode.""" + if hasattr(self.meta_policy, 'policy') and hasattr(self.meta_policy.policy, 'eval'): + 
self.meta_policy.policy.eval() + + def save(self, path: str, **kwargs) -> None: + """ + Save algorithm state (model, optimizer, etc.). + + Args: + path: Save path + **kwargs: Additional save options + """ + raise NotImplementedError("Subclass should implement save") + + def load(self, path: str, **kwargs) -> None: + """ + Load algorithm state. + + Args: + path: Load path + **kwargs: Additional load options + """ + raise NotImplementedError("Subclass should implement load") + + def __repr__(self) -> str: + replay_info = "None" + if self.replay is not None: + if isinstance(self.replay, dict): + replay_info = f"Dict with {len(self.replay)} buffers" + else: + replay_info = repr(self.replay) + return f"{self.__class__.__name__}(meta_policy={self.meta_policy.__class__.__name__}, replay={replay_info})" + + +if __name__ == '__main__': + """ + Test code for BaseAlgorithm class. + + Since BaseAlgorithm is abstract, we create a simple concrete implementation for testing. + """ + import sys + sys.path.insert(0, '/home/zhang/robot/126/ILStudio') + + from benchmark.base import MetaObs, MetaAction, MetaPolicy + from rl.buffer.base_replay import BaseReplay + from rl.algorithms.base import BaseAlgorithm + from dataclasses import asdict + + # Simple replay buffer for testing + class SimpleReplay(BaseReplay): + def __init__(self, capacity=1000, device='cpu', **kwargs): + super().__init__(capacity=capacity, device=device, **kwargs) + self._storage = [] + + def add(self, transition): + if self._size < self.capacity: + self._storage.append(transition) + self._size += 1 + else: + self._storage[self._position] = transition + self._position = (self._position + 1) % self.capacity + + def sample(self, batch_size): + if self._size == 0: + return {} + indices = np.random.randint(0, self._size, size=min(batch_size, self._size)) + return { + 'states': [self._storage[i]['state'] for i in indices], + 'actions': [self._storage[i]['action'] for i in indices], + 'rewards': 
np.array([self._storage[i]['reward'] for i in indices]), + 'next_states': [self._storage[i]['next_state'] for i in indices], + 'dones': np.array([self._storage[i]['done'] for i in indices]), + } + + def clear(self): + self._storage = [] + self._size = 0 + self._position = 0 + + def save(self, path, **kwargs): + pass + + def load(self, path, **kwargs): + pass + + # Simple policy for testing + class DummyPolicy: + def select_action(self, obs): + return MetaAction( + action=np.random.randn(7).astype(np.float32), + ctrl_space='ee', + ctrl_type='delta' + ) + + def train(self): + pass + + def eval(self): + pass + + # Simple MetaPolicy wrapper for testing + class DummyMetaPolicy(MetaPolicy): + def __init__(self): + self.policy = DummyPolicy() + self.chunk_size = 1 + self.ctrl_space = 'ee' + self.ctrl_type = 'delta' + self.action_queue = [] + self.action_normalizer = None + self.state_normalizer = None + + def select_action(self, mobs, t=0, **kwargs): + return self.policy.select_action(mobs) + + # Simple concrete algorithm for testing + class SimpleAlgorithm(BaseAlgorithm): + def __init__(self, meta_policy, replay=None, learning_rate=1e-3, **kwargs): + super().__init__(meta_policy=meta_policy, replay=replay, **kwargs) + self.learning_rate = learning_rate + self.update_count = 0 + + def update(self, batch=None, **kwargs): + batch_size = kwargs.get('batch_size', 32) + + if batch is None and self.replay is not None: + if isinstance(self.replay, dict): + # Sample from multiple replay buffers + env_types = kwargs.get('env_types', list(self.replay.keys())) + env_weights = kwargs.get('env_weights', {k: 1.0/len(env_types) for k in env_types}) + + batches = {} + for env_type in env_types: + n_samples = int(batch_size * env_weights.get(env_type, 1.0/len(env_types))) + batches[env_type] = self.replay[env_type].sample(n_samples) + batch = batches + else: + batch = self.replay.sample(batch_size) + + self.update_count += 1 + + return { + 'loss': np.random.randn(), + 'update_count': 
self.update_count, + 'batch_size': batch_size + } + + def compute_loss(self, batch): + return torch.tensor(np.random.randn()) + + # Test the implementation + print("=" * 60) + print("Testing BaseAlgorithm (SimpleAlgorithm implementation)") + print("=" * 60) + + # Test 1: Single replay buffer + print("\n1. Testing with single replay buffer...") + meta_policy = DummyMetaPolicy() + replay = SimpleReplay(capacity=100) + algorithm = SimpleAlgorithm(meta_policy=meta_policy, replay=replay, learning_rate=1e-4) + print(f" Created algorithm: {algorithm}") + + # Add transitions + print("\n2. Recording transitions...") + for i in range(10): + state = MetaObs( + state=np.random.randn(10).astype(np.float32), + state_ee=np.random.randn(7).astype(np.float32), + raw_lang="test instruction" + ) + action = MetaAction( + action=np.random.randn(7).astype(np.float32), + ctrl_space='ee', + ctrl_type='delta' + ) + next_state = MetaObs( + state=np.random.randn(10).astype(np.float32), + state_ee=np.random.randn(7).astype(np.float32), + raw_lang="test instruction" + ) + + transition = RLTransition( + obs=state, + action=action, + next_obs=next_state, + reward=np.random.randn(), + done=(i == 9), + info={'step': i, 'value': np.random.randn()}, + ) + algorithm.record_transition(transition) + print(f" Replay buffer size: {len(algorithm.replay)}") + + # Update algorithm + print("\n3. Testing update...") + result = algorithm.update(batch_size=5) + print(f" Update result: {result}") + + # Test 2: Multiple replay buffers (by environment type) + print("\n4. 
Testing with multiple replay buffers...") + multi_replay = { + 'indoor': SimpleReplay(capacity=100), + 'outdoor': SimpleReplay(capacity=100) + } + multi_algorithm = SimpleAlgorithm(meta_policy=meta_policy, replay=multi_replay) + + # Add transitions to different environment types + for i in range(5): + state = MetaObs(state=np.random.randn(10).astype(np.float32)) + action = MetaAction(action=np.random.randn(7).astype(np.float32)) + next_state = MetaObs(state=np.random.randn(10).astype(np.float32)) + + transition = RLTransition( + obs=state, + action=action, + next_obs=next_state, + reward=1.0, + done=False, + ) + multi_algorithm.record_transition(transition, env_type='indoor') + transition = RLTransition( + obs=state, + action=action, + next_obs=next_state, + reward=0.5, + done=False, + ) + multi_algorithm.record_transition(transition, env_type='outdoor') + + print(f" Indoor replay size: {len(multi_algorithm.replay['indoor'])}") + print(f" Outdoor replay size: {len(multi_algorithm.replay['outdoor'])}") + + # Update with weighted sampling + result = multi_algorithm.update( + batch_size=8, + env_types=['indoor', 'outdoor'], + env_weights={'indoor': 0.6, 'outdoor': 0.4} + ) + print(f" Update with weighted sampling: {result}") + + # Test train/eval mode + print("\n5. Testing train/eval mode...") + algorithm.train_mode() + print(" Set to train mode") + algorithm.eval_mode() + print(" Set to eval mode") + + # Test select_action + print("\n6. Testing select_action...") + obs = MetaObs(state=np.random.randn(10).astype(np.float32)) + action = algorithm.select_action(obs) + print(f" Selected action shape: {action.action.shape}") + + print("\n" + "=" * 60) + print("All tests passed!") + print("=" * 60) + diff --git a/rl/algorithms/td3/__init__.py b/rl/algorithms/td3/__init__.py new file mode 100644 index 00000000..2660f4aa --- /dev/null +++ b/rl/algorithms/td3/__init__.py @@ -0,0 +1,8 @@ +from .td3 import TD3Algorithm, TD3Config +from .. 
import register_algorithm + +# Register algorithm with its config class +register_algorithm("td3", TD3Algorithm, TD3Config) + +__all__ = ["TD3Algorithm", "TD3Config"] + diff --git a/rl/algorithms/td3/td3.py b/rl/algorithms/td3/td3.py new file mode 100644 index 00000000..9570e9fe --- /dev/null +++ b/rl/algorithms/td3/td3.py @@ -0,0 +1,303 @@ +""" +TD3 algorithm implementation based on the official reference: +https://arxiv.org/abs/1802.09477 +""" + +from __future__ import annotations + +import copy +from dataclasses import dataclass +from typing import Any, Dict, Optional + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import rl.utils.action_utils as action_utils +from benchmark.base import MetaAction, MetaObs, MetaPolicy +from policy.mlp.mlp import MLPPolicy, MLPPolicyConfig +from rl.utils import polyak_update +from rl.utils.action_utils import ensure_action + +from ..base import BaseAlgorithm + + +@dataclass +class TD3Config: + state_dim: int + action_dim: int + discount: float = 0.99 + tau: float = 0.005 + policy_noise: float = 0.2 + noise_clip: float = 0.5 + policy_freq: int = 2 + actor_lr: float = 3e-4 + critic_lr: float = 3e-4 + device: str = "cpu" + state_key: str = "state" + next_state_key: str = "next_state" + action_key: str = "action" + + +class _IdentityStateNormalizer: + def normalize_metaobs(self, mobs: MetaObs, ctrl_space: str): + return mobs + + +class _IdentityActionNormalizer: + def denormalize_metaact(self, mact: MetaAction): + return mact + + +class Critic(nn.Module): + def __init__(self, state_dim: int, action_dim: int): + super().__init__() + # Q1 architecture + self.l1 = nn.Linear(state_dim + action_dim, 256) + self.l2 = nn.Linear(256, 256) + self.l3 = nn.Linear(256, 1) + + # Q2 architecture + self.l4 = nn.Linear(state_dim + action_dim, 256) + self.l5 = nn.Linear(256, 256) + self.l6 = nn.Linear(256, 1) + + def forward(self, state: torch.Tensor, action: torch.Tensor): + sa = torch.cat([state, 
action], dim=1) + + q1 = F.relu(self.l1(sa)) + q1 = F.relu(self.l2(q1)) + q1 = self.l3(q1) + + q2 = F.relu(self.l4(sa)) + q2 = F.relu(self.l5(q2)) + q2 = self.l6(q2) + return q1, q2 + + def Q1(self, state: torch.Tensor, action: torch.Tensor): + sa = torch.cat([state, action], dim=1) + q1 = F.relu(self.l1(sa)) + q1 = F.relu(self.l2(q1)) + q1 = self.l3(q1) + return q1 + + +class TD3Algorithm(BaseAlgorithm): + """ + TD3 algorithm using MLPPolicy as actor and a twin-critic network. + """ + + def __init__( + self, + replay: Optional[Any], + config: TD3Config, + actor_config: Optional[MLPPolicyConfig] = None, + meta_policy: Optional[MetaPolicy] = None, + ensure_refine_fn= action_utils.tanh_action_to_space, + ctrl_space: str = "ee", + ctrl_type: str = "delta", + gripper_continuous: bool = False, + **kwargs, + ): + if actor_config is None: + actor_config = MLPPolicyConfig( + state_dim=config.state_dim, + action_dim=config.action_dim, + chunk_size=1, + ) + else: + actor_config.chunk_size = 1 + + self.device = torch.device(config.device) + self._env = None + self.config = config + self.ensure_refine_fn = ensure_refine_fn + self.ctrl_space = ctrl_space + self.ctrl_type = ctrl_type + self.gripper_continuous = gripper_continuous + + self.actor = MLPPolicy(actor_config).to(self.device) + self.actor_target = copy.deepcopy(self.actor) + self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=config.actor_lr) + + self.critic = Critic(config.state_dim, config.action_dim).to(self.device) + self.critic_target = copy.deepcopy(self.critic) + self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=config.critic_lr) + + self.total_it = 0 + + if meta_policy is None: + meta_policy = MetaPolicy( + self.actor, + chunk_size=1, + action_normalizer=_IdentityActionNormalizer(), + state_normalizer=_IdentityStateNormalizer(), + ctrl_space=ctrl_space, + ctrl_type=ctrl_type, + ) + + super().__init__(meta_policy=meta_policy, replay=replay, **kwargs) + + def set_env(self, env: 
Any) -> None: + """Attach an environment for action post-processing.""" + self._env = env + + def _actor_forward(self, state: torch.Tensor, model: Optional[MLPPolicy] = None) -> torch.Tensor: + model = self.actor if model is None else model + output = model(state) + if isinstance(output, dict): + action = output.get("action") + else: + action = output + if action is None: + raise ValueError("Actor output does not contain 'action'") + if action.dim() == 3: + action = action[:, 0, :] + return action + + def _to_tensor(self, value, dtype=torch.float32) -> torch.Tensor: + if torch.is_tensor(value): + return value.to(device=self.device, dtype=dtype) + return torch.as_tensor(value, device=self.device, dtype=dtype) + + def _extract_batch_action(self, batch: Dict[str, Any]) -> Any: + action = batch.get(self.config.action_key, None) + if action is None: + raise KeyError(f"Missing '{self.config.action_key}' in batch") + if hasattr(action, "action"): + return action.action + if isinstance(action, dict) and "action" in action: + return action["action"] + return action + + def _extract_batch_state(self, batch: Dict[str, Any], key: str) -> Any: + state = batch.get(key, None) + if state is None: + raise KeyError(f"Missing '{key}' in batch") + return state + + def select_action( + self, + obs: Any, + noise_scale: float = 0.0, + env: Optional[Any] = None, + **kwargs, + ) -> MetaAction: + if isinstance(obs, (list, np.ndarray)) and len(obs) > 0: + first = obs[0] if isinstance(obs, list) else obs.flat[0] + if hasattr(first, "__dataclass_fields__"): + obs = self._organize_obs(obs) + + if hasattr(obs, self.config.state_key): + state = getattr(obs, self.config.state_key) + elif hasattr(obs, "state"): + state = obs.state + elif isinstance(obs, dict) and self.config.state_key in obs: + state = obs[self.config.state_key] + else: + state = obs + state_t = self._to_tensor(state) + if state_t.dim() == 1: + state_t = state_t.unsqueeze(0) + + with torch.no_grad(): + action = 
self._actor_forward(state_t) + if noise_scale > 0: + action = action + noise_scale * torch.randn_like(action) + action = ensure_action(env or self._env, action, refine_fn=self.ensure_refine_fn) + + action_np = action.detach().cpu().numpy() + return MetaAction( + action=action_np, + ctrl_space=self.ctrl_space, + ctrl_type=self.ctrl_type, + gripper_continuous=self.gripper_continuous, + ) + + def update( + self, + batch: Optional[Dict[str, Any]] = None, + env: Optional[Any] = None, + **kwargs, + ) -> Dict[str, Any]: + if batch is None: + if self.replay is None: + raise ValueError("Replay buffer is not set") + batch_size = kwargs.get("batch_size", 256) + if hasattr(self.replay, "sample_for_training"): + batch = self.replay.sample_for_training(batch_size) + else: + raise ValueError("Replay buffer does not support sample_for_training") + batch = self.replay.sample(batch_size) + + state = self._to_tensor(self._extract_batch_state(batch, self.config.state_key)) + next_state = self._to_tensor(self._extract_batch_state(batch, self.config.next_state_key)) + action = self._to_tensor(self._extract_batch_action(batch)) + reward = self._to_tensor(batch.get("reward"), dtype=torch.float32).unsqueeze(-1) + done = self._to_tensor(batch.get("done"), dtype=torch.float32).unsqueeze(-1) + truncated_raw = batch.get("truncated") + if truncated_raw is None: + truncated_raw = np.zeros_like(batch.get("done")) + truncated = self._to_tensor(truncated_raw, dtype=torch.float32).unsqueeze(-1) + + terminal = done * (1.0 - truncated) + not_done = 1.0 - terminal + + self.total_it += 1 + + with torch.no_grad(): + noise = ( + torch.randn_like(action) * self.config.policy_noise + ).clamp(-self.config.noise_clip, self.config.noise_clip) + next_action = self._actor_forward(next_state, model=self.actor_target) + next_action = next_action + noise + next_action = ensure_action(env or self._env, next_action, refine_fn=self.ensure_refine_fn) + target_q1, target_q2 = self.critic_target(next_state, next_action) 
+ target_q = torch.min(target_q1, target_q2) + target_q = reward + not_done * self.config.discount * target_q + + current_q1, current_q2 = self.critic(state, action) + critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q) + + self.critic_optimizer.zero_grad() + critic_loss.backward() + self.critic_optimizer.step() + + actor_loss = None + if self.total_it % self.config.policy_freq == 0: + actor_action = self._actor_forward(state) + actor_action = ensure_action(env or self._env, actor_action, refine_fn=self.ensure_refine_fn) + actor_loss = -self.critic.Q1(state, actor_action).mean() + + self.actor_optimizer.zero_grad() + actor_loss.backward() + self.actor_optimizer.step() + + polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.config.tau) + polyak_update(self.actor.parameters(), self.actor_target.parameters(), self.config.tau) + + return { + "critic_loss": critic_loss.item(), + "actor_loss": None if actor_loss is None else actor_loss.item(), + "update_step": self.total_it, + } + + def save(self, path: str, **kwargs) -> None: + payload = { + "actor": self.actor.state_dict(), + "critic": self.critic.state_dict(), + "actor_optimizer": self.actor_optimizer.state_dict(), + "critic_optimizer": self.critic_optimizer.state_dict(), + "config": self.config, + } + torch.save(payload, path) + + def load(self, path: str, **kwargs) -> None: + payload = torch.load(path, map_location=self.device) + self.actor.load_state_dict(payload["actor"]) + self.critic.load_state_dict(payload["critic"]) + self.actor_optimizer.load_state_dict(payload["actor_optimizer"]) + self.critic_optimizer.load_state_dict(payload["critic_optimizer"]) + self.actor_target = copy.deepcopy(self.actor) + self.critic_target = copy.deepcopy(self.critic) + diff --git a/rl/buffer/__init__.py b/rl/buffer/__init__.py new file mode 100644 index 00000000..d23f69b9 --- /dev/null +++ b/rl/buffer/__init__.py @@ -0,0 +1,20 @@ +""" +Replay Buffer Module + +This module 
provides experience replay buffers for RL algorithms. + +Classes: + BaseReplay: Base class for all replay buffers + MetaReplay: Replay buffer for MetaObs and MetaAction fields +""" + +from .base_replay import BaseReplay +from .meta_replay import MetaReplay +from .transition import RLTransition + +__all__ = [ + 'BaseReplay', + 'MetaReplay', + 'RLTransition', +] + diff --git a/rl/buffer/base_replay.py b/rl/buffer/base_replay.py new file mode 100644 index 00000000..8a1beb13 --- /dev/null +++ b/rl/buffer/base_replay.py @@ -0,0 +1,553 @@ +""" +Base Replay Buffer Class + +This module defines the base class for all replay buffers in the RL framework. + +Design Philosophy: +- Store raw Meta data: Store original MetaObs and MetaAction in buffer, keeping data in its raw form +- Complete information storage: Support storing all fields of MetaObs and MetaAction +- Extensibility: Support storing additional custom fields (value, log_prob, advantage, trajectory_id, etc.) +- Conversion during sampling: Convert to ILStudio data pipeline format (normalization, etc.) during sampling +- Compatibility: Maintain data integrity while being compatible with ILStudio's normalization pipeline +- Vectorized environment support: Each buffer corresponds to one environment type with multiple parallel envs + - Storage shape: (capacity, n_envs, ...) where capacity is number of time steps + - Each add() stores data from all parallel envs at one time step + - Sample supports selecting specific env indices +""" + +import torch +import numpy as np +import pickle +from typing import Dict, Any, Optional, Union, Callable, List, Set +from abc import ABC, abstractmethod + + +class BaseReplay(ABC): + """ + Base class for Replay Buffer. + + This class defines the interface for all replay buffers in the RL framework. + Replay buffers store experience data (transitions) for training RL algorithms. 
+ + Supports vectorized environments where one buffer stores data from multiple parallel + environments of the same type. Storage shape is (capacity, n_envs, ...). + + Attributes: + capacity: Maximum number of time steps to store (not total transitions) + device: Device to store data on ('cpu' or 'cuda') + env_type: Environment type identifier for this buffer + n_envs: Number of parallel environments (default 1 for non-vectorized) + """ + + def __init__( + self, + capacity: int = 1000000, + device: Union[str, torch.device] = 'cpu', + env_type: Optional[str] = None, + n_envs: int = 1, + **kwargs + ): + """ + Initialize the Replay Buffer. + + Args: + capacity: Buffer capacity (maximum number of time steps to store) + - Total transitions stored = capacity * n_envs + device: Data storage device ('cpu' or 'cuda', default 'cpu') + - 'cpu': Store data in CPU memory + - 'cuda' or 'cuda:0': Store data in GPU memory + env_type: Environment type identifier (e.g., 'sim', 'real', 'indoor', 'outdoor') + - Used to distinguish buffers for different environment types + - If None, defaults to 'default' + n_envs: Number of parallel environments for this buffer + - Storage shape will be (capacity, n_envs, ...) + - Default 1 for non-vectorized single environment + **kwargs: Other initialization parameters + """ + self.capacity = capacity + self.device = torch.device(device) if isinstance(device, str) else device + self.env_type = env_type if env_type is not None else 'default' + self.n_envs = n_envs + self._size = 0 + self._position = 0 + + @abstractmethod + def add(self, transition: Dict[str, Any]) -> None: + """ + Add a transition (one time step from all parallel envs) to the buffer. + + Stores raw Meta data (MetaObs, MetaAction) without any normalization. + Supports storing all fields of MetaObs and MetaAction, plus additional custom information. + + For vectorized environments (n_envs > 1): + - Each field should have shape (n_envs, ...) 
+ - e.g., state: (n_envs, state_dim), action: (n_envs, action_dim) + - reward: (n_envs,), done: (n_envs,), truncated: (n_envs,) + + For single environment (n_envs = 1): + - Fields can be either (1, ...) or (...) shape + - Will be automatically expanded to (1, ...) if needed + + Note on done vs truncated (Gymnasium API): + - done (terminated): True if episode ended naturally (goal reached, failure, etc.) + - truncated: True if episode was cut off due to time limit or other external reasons + - For bootstrap value calculation: + - If done=True: V(next_state) = 0 (true terminal state) + - If truncated=True: V(next_state) should be bootstrapped (not a true terminal) + + Args: + transition: Dictionary containing the following fields: + - state: Current state, shape (n_envs, state_dim) or dict with MetaObs fields + - action: Action, shape (n_envs, action_dim) or dict with MetaAction fields + - reward: Reward, shape (n_envs,) + - next_state: Next state, shape (n_envs, state_dim) or dict with MetaObs fields + - done: Terminated flag, shape (n_envs,) - True if episode ended naturally + - truncated: Truncated flag, shape (n_envs,) - True if episode was cut off + - info: Optional, additional information (list of dicts, one per env) + - **other custom fields**: Can store any additional information + """ + raise NotImplementedError + + @abstractmethod + def sample( + self, + batch_size: int, + env_indices: Optional[Union[List[int], np.ndarray]] = None, + keys: Optional[Set[str]] = None + ) -> Dict[str, Any]: + """ + Sample a batch from the buffer (raw data). + + Sampling is done in two steps: + 1. Select which parallel environments to sample from (env_indices) + 2. 
Sample batch_size transitions from (time_idx, env_idx) combinations + + Args: + batch_size: Number of transitions to sample + env_indices: Optional, indices of parallel environments to sample from + - If None, sample from all environments uniformly + - e.g., [0, 2] means only sample from env 0 and env 2 + keys: Optional set of keys to sample. If None, uses default keys + (typically state, action, next_state, reward, done, truncated). + Subclasses may define their own default keys and available keys. + + Returns: + Dictionary containing raw Meta data (without normalization) + - All fields have shape (batch_size, ...) with env dimension flattened + - e.g., state: (batch_size, state_dim), action: (batch_size, action_dim) + """ + raise NotImplementedError + + def sample_for_training( + self, + batch_size: int, + env_indices: Optional[Union[List[int], np.ndarray]] = None, + keys: Optional[Set[str]] = None, + data_processor: Optional[Callable] = None + ) -> Dict[str, Any]: + """ + Sample and convert to ILStudio training format. + + Args: + batch_size: Batch size + env_indices: Optional, indices of parallel environments to sample from + keys: Optional set of keys to sample + data_processor: Optional data processing function to align with ILStudio pipeline + - If None, return raw data + - If provided, should be a function: batch -> processed_batch + + Returns: + Processed batch data (conforming to ILStudio training format) + """ + batch = self.sample(batch_size, env_indices=env_indices, keys=keys) + if data_processor is not None: + batch = data_processor(batch) + return batch + + def __len__(self) -> int: + """Return current buffer size (number of time steps stored).""" + return self._size + + @property + def total_transitions(self) -> int: + """ + Total number of transitions stored (size * n_envs). 
+ + Returns: + Total transition count across all parallel environments + """ + return self._size * self.n_envs + + def get_env_type(self) -> str: + """Get the environment type identifier.""" + return self.env_type + + def get_n_envs(self) -> int: + """Get the number of parallel environments.""" + return self.n_envs + + @abstractmethod + def clear(self) -> None: + """Clear the buffer.""" + raise NotImplementedError + + @abstractmethod + def save(self, path: str, **kwargs) -> None: + """ + Save buffer data to file. + + Args: + path: Save path (can be file path or directory path) + **kwargs: Save options + - format: Save format (e.g., 'pkl', 'hdf5', 'npz', optional) + - compress: Whether to compress (optional) + """ + raise NotImplementedError + + @abstractmethod + def load(self, path: str, **kwargs) -> None: + """ + Load data from file into buffer. + + Args: + path: Load path (can be file path or directory path) + **kwargs: Load options + - format: Load format (optional, can auto-detect) + - append: Whether to append to existing buffer (default False, clear before load) + """ + raise NotImplementedError + + def is_full(self) -> bool: + """Check if buffer is full (time steps, not total transitions).""" + return self._size >= self.capacity + + def get_all( + self, + env_indices: Optional[Union[List[int], np.ndarray]] = None + ) -> Dict[str, Any]: + """ + Get all data in the buffer. 
+ + Args: + env_indices: Optional, indices of parallel environments to get data from + - If None, get data from all environments + + Returns: + Dictionary containing all stored data + """ + return self.sample(self.total_transitions, env_indices=env_indices) if self._size > 0 else {} + + def __repr__(self) -> str: + return (f"{self.__class__.__name__}(capacity={self.capacity}, size={self._size}, " + f"n_envs={self.n_envs}, total_transitions={self.total_transitions}, " + f"env_type='{self.env_type}', device={self.device})") + + +if __name__ == '__main__': + """ + Test code for BaseReplay class. + + Since BaseReplay is abstract, we create a simple concrete implementation for testing. + Tests include both single-env and vectorized (multi-env) scenarios. + """ + import sys + sys.path.insert(0, '/home/zhang/robot/126/ILStudio') + + from benchmark.base import MetaObs, MetaAction + from dataclasses import asdict + + # Simple concrete implementation for testing (supports vectorized envs) + class SimpleReplay(BaseReplay): + """Simple in-memory replay buffer for testing with vectorized env support.""" + + def __init__( + self, + capacity: int = 1000, + device: str = 'cpu', + env_type: Optional[str] = None, + n_envs: int = 1, + state_dim: int = 10, + action_dim: int = 7, + **kwargs + ): + super().__init__( + capacity=capacity, + device=device, + env_type=env_type, + n_envs=n_envs, + **kwargs + ) + self.state_dim = state_dim + self.action_dim = action_dim + + # Pre-allocate storage: (capacity, n_envs, dim) + self._state = np.zeros((capacity, n_envs, state_dim), dtype=np.float32) + self._action = np.zeros((capacity, n_envs, action_dim), dtype=np.float32) + self._reward = np.zeros((capacity, n_envs), dtype=np.float32) + self._next_state = np.zeros((capacity, n_envs, state_dim), dtype=np.float32) + self._done = np.zeros((capacity, n_envs), dtype=np.bool_) + self._truncated = np.zeros((capacity, n_envs), dtype=np.bool_) + + def add(self, transition: Dict[str, Any]) -> None: + 
"""Add transition from all parallel envs at one time step.""" + idx = self._position + + # Store data (expecting shape (n_envs, dim)) + self._state[idx] = transition['state'] + self._action[idx] = transition['action'] + self._reward[idx] = transition['reward'] + self._next_state[idx] = transition['next_state'] + self._done[idx] = transition['done'] + self._truncated[idx] = transition.get('truncated', np.zeros(self.n_envs, dtype=bool)) + + self._position = (self._position + 1) % self.capacity + self._size = min(self._size + 1, self.capacity) + + # Default sample keys + DEFAULT_SAMPLE_KEYS = {'state', 'action', 'next_state', 'reward', 'done', 'truncated'} + + def sample( + self, + batch_size: int, + env_indices: Optional[Union[List[int], np.ndarray]] = None, + keys: Optional[Set[str]] = None + ) -> Dict[str, Any]: + """Sample batch_size transitions from buffer. + + Args: + batch_size: Number of transitions to sample + env_indices: Optional, indices of parallel environments to sample from + keys: Optional set of keys to sample. 
If None, uses DEFAULT_SAMPLE_KEYS + """ + if self._size == 0: + return {} + + if keys is None: + keys = self.DEFAULT_SAMPLE_KEYS + + # Determine which envs to sample from + if env_indices is None: + env_indices = np.arange(self.n_envs) + env_indices = np.asarray(env_indices) + + # Sample (time_idx, env_idx) pairs + time_indices = np.random.randint(0, self._size, size=batch_size) + env_sample_indices = np.random.choice(env_indices, size=batch_size) + + batch = {} + if 'state' in keys: + batch['state'] = self._state[time_indices, env_sample_indices] # (batch_size, state_dim) + if 'action' in keys: + batch['action'] = self._action[time_indices, env_sample_indices] # (batch_size, action_dim) + if 'reward' in keys: + batch['reward'] = self._reward[time_indices, env_sample_indices] # (batch_size,) + if 'next_state' in keys: + batch['next_state'] = self._next_state[time_indices, env_sample_indices] + if 'done' in keys: + batch['done'] = self._done[time_indices, env_sample_indices] + if 'truncated' in keys: + batch['truncated'] = self._truncated[time_indices, env_sample_indices] + + # Always include indices for debugging + batch['time_indices'] = time_indices + batch['env_indices'] = env_sample_indices + + return batch + + def clear(self) -> None: + self._state.fill(0) + self._action.fill(0) + self._reward.fill(0) + self._next_state.fill(0) + self._done.fill(False) + self._truncated.fill(False) + self._size = 0 + self._position = 0 + + def save(self, path: str, **kwargs) -> None: + data = { + 'state': self._state[:self._size], + 'action': self._action[:self._size], + 'reward': self._reward[:self._size], + 'next_state': self._next_state[:self._size], + 'done': self._done[:self._size], + 'truncated': self._truncated[:self._size], + 'size': self._size, + 'env_type': self.env_type, + 'n_envs': self.n_envs, + } + with open(path, 'wb') as f: + pickle.dump(data, f) + print(f"Saved {self._size} time steps ({self.total_transitions} transitions) to {path}") + + def load(self, path: 
str, **kwargs) -> None: + append = kwargs.get('append', False) + if not append: + self.clear() + with open(path, 'rb') as f: + data = pickle.load(f) + + size = data['size'] + for i in range(size): + self.add({ + 'state': data['state'][i], + 'action': data['action'][i], + 'reward': data['reward'][i], + 'next_state': data['next_state'][i], + 'done': data['done'][i], + 'truncated': data.get('truncated', np.zeros(self.n_envs, dtype=bool))[i] if 'truncated' in data else np.zeros(self.n_envs, dtype=bool), + }) + print(f"Loaded {self._size} time steps ({self.total_transitions} transitions) from {path}") + + # Test the implementation + print("=" * 60) + print("Testing BaseReplay (SimpleReplay with Vectorized Env Support)") + print("=" * 60) + + # Test 1: Single environment (n_envs=1) + print("\n" + "-" * 40) + print("Test 1: Single Environment (n_envs=1)") + print("-" * 40) + + buffer = SimpleReplay(capacity=100, device='cpu', env_type='single', n_envs=1) + print(f"\n1. Created buffer: {buffer}") + print(f" env_type: {buffer.get_env_type()}") + print(f" n_envs: {buffer.get_n_envs()}") + + # Add transitions (n_envs=1, so shapes are (1, dim)) + print("\n2. Adding transitions...") + for i in range(10): + transition = { + 'state': np.random.randn(1, 10).astype(np.float32), + 'action': np.random.randn(1, 7).astype(np.float32), + 'reward': np.array([np.random.randn()]), + 'next_state': np.random.randn(1, 10).astype(np.float32), + 'done': np.array([i == 9]), # True terminal state + 'truncated': np.array([False]), # Not truncated + } + buffer.add(transition) + + print(f" Buffer size (time steps): {len(buffer)}") + print(f" Total transitions: {buffer.total_transitions}") + + # Sample + print("\n3. 
Sampling from buffer...") + batch = buffer.sample(batch_size=5) + print(f" Batch keys: {batch.keys()}") + print(f" State shape: {batch['state'].shape}") + print(f" Reward shape: {batch['reward'].shape}") + print(f" Done shape: {batch['done'].shape}") + print(f" Truncated shape: {batch['truncated'].shape}") + + # Test 2: Vectorized environment (n_envs=4) + print("\n" + "-" * 40) + print("Test 2: Vectorized Environment (n_envs=4)") + print("-" * 40) + + vec_buffer = SimpleReplay( + capacity=100, + device='cpu', + env_type='sim', + n_envs=4, + state_dim=10, + action_dim=7 + ) + print(f"\n1. Created buffer: {vec_buffer}") + + # Add transitions from all 4 envs at once + print("\n2. Adding transitions from 4 parallel envs...") + for i in range(20): + transition = { + 'state': np.random.randn(4, 10).astype(np.float32), # (n_envs, state_dim) + 'action': np.random.randn(4, 7).astype(np.float32), # (n_envs, action_dim) + 'reward': np.random.randn(4).astype(np.float32), # (n_envs,) + 'next_state': np.random.randn(4, 10).astype(np.float32), + 'done': np.random.choice([True, False], size=4), # terminated + 'truncated': np.random.choice([True, False], size=4, p=[0.1, 0.9]), # truncated (10% chance) + } + vec_buffer.add(transition) + + print(f" Buffer size (time steps): {len(vec_buffer)}") + print(f" Total transitions: {vec_buffer.total_transitions}") + + # Sample from all envs + print("\n3. Sampling from all envs...") + batch = vec_buffer.sample(batch_size=16) + print(f" State shape: {batch['state'].shape}") # (16, 10) + print(f" Sampled from envs: {np.unique(batch['env_indices'])}") + + # Sample from specific envs only + print("\n4. 
Sampling from specific envs [0, 2] only...") + batch = vec_buffer.sample(batch_size=16, env_indices=[0, 2]) + print(f" State shape: {batch['state'].shape}") + print(f" Sampled from envs: {np.unique(batch['env_indices'])}") + assert all(e in [0, 2] for e in batch['env_indices']), "Should only sample from envs 0 and 2" + + # Test sample_for_training with env_indices + print("\n5. Testing sample_for_training with env_indices and processor...") + def simple_processor(batch): + batch['processed'] = True + return batch + + processed_batch = vec_buffer.sample_for_training( + batch_size=8, + env_indices=[1, 3], + data_processor=simple_processor + ) + print(f" Processed: {processed_batch.get('processed', False)}") + print(f" Sampled from envs: {np.unique(processed_batch['env_indices'])}") + + # Test save and load + print("\n6. Testing save and load...") + import tempfile + import os + with tempfile.TemporaryDirectory() as tmpdir: + save_path = os.path.join(tmpdir, 'vec_buffer.pkl') + vec_buffer.save(save_path) + + new_buffer = SimpleReplay(capacity=100, n_envs=4, state_dim=10, action_dim=7) + new_buffer.load(save_path) + print(f" Loaded buffer size: {len(new_buffer)}") + print(f" Loaded total transitions: {new_buffer.total_transitions}") + + # Test 3: Multiple buffers for different env types + print("\n" + "-" * 40) + print("Test 3: Multiple Buffers for Different Env Types") + print("-" * 40) + + buffers = { + 'indoor': SimpleReplay(capacity=50, env_type='indoor', n_envs=2), + 'outdoor': SimpleReplay(capacity=50, env_type='outdoor', n_envs=4), + 'sim': SimpleReplay(capacity=100, env_type='sim', n_envs=8), + } + + print("\nCreated buffers for different env types:") + for name, buf in buffers.items(): + print(f" {name}: env_type='{buf.env_type}', n_envs={buf.n_envs}, capacity={buf.capacity}") + + # Add some data to each + for name, buf in buffers.items(): + for _ in range(10): + buf.add({ + 'state': np.random.randn(buf.n_envs, 10).astype(np.float32), + 'action': 
np.random.randn(buf.n_envs, 7).astype(np.float32), + 'reward': np.random.randn(buf.n_envs).astype(np.float32), + 'next_state': np.random.randn(buf.n_envs, 10).astype(np.float32), + 'done': np.zeros(buf.n_envs, dtype=bool), + 'truncated': np.zeros(buf.n_envs, dtype=bool), + }) + + print("\nBuffer statistics:") + for name, buf in buffers.items(): + print(f" {name}: {len(buf)} time steps, {buf.total_transitions} total transitions") + + # Test clear + print("\n7. Testing clear...") + vec_buffer.clear() + print(f" Buffer size after clear: {len(vec_buffer)}") + print(f" Total transitions after clear: {vec_buffer.total_transitions}") + + print("\n" + "=" * 60) + print("All tests passed!") + print("=" * 60) + diff --git a/rl/buffer/meta_replay.py b/rl/buffer/meta_replay.py new file mode 100644 index 00000000..7994b753 --- /dev/null +++ b/rl/buffer/meta_replay.py @@ -0,0 +1,726 @@ +""" +Meta Replay Buffer (rewritten). + +Stores raw MetaObs/MetaAction/MetaObs(next) in env-first layout: + (n_envs, capacity, ...) + +Sampling for training returns normalized and processor-aligned data. 
+""" + +from __future__ import annotations + +import pickle +from dataclasses import asdict +from typing import Any, Callable, Dict, List, Optional, Set, Union + +import numpy as np +import torch + +from benchmark.base import MetaObs, MetaAction, dict2meta +from rl.utils import RunningMeanStd +from .base_replay import BaseReplay +from .transition import RLTransition + + +CTRL_SPACE_MAP = {'ee': 0, 'joint': 1} +CTRL_SPACE_INV_MAP = {v: k for k, v in CTRL_SPACE_MAP.items()} + +CTRL_TYPE_MAP = {'delta': 0, 'absolute': 1, 'relative': 2} +CTRL_TYPE_INV_MAP = {v: k for k, v in CTRL_TYPE_MAP.items()} + +GRIPPER_CONTINUOUS_MAP = {False: 0, True: 1} +GRIPPER_CONTINUOUS_INV_MAP = {v: k for k, v in GRIPPER_CONTINUOUS_MAP.items()} + + +class _MetaObsStorage: + def __init__( + self, + n_envs: int, + capacity: int, + state_dim: Optional[int], + state_ee_dim: Optional[int], + state_joint_dim: Optional[int], + state_obj_dim: Optional[int], + image_shape: Optional[tuple], + depth_shape: Optional[tuple], + pc_shape: Optional[tuple], + store_images: bool, + store_depth: bool, + store_pc: bool, + store_lang: bool, + ) -> None: + self.n_envs = n_envs + self.capacity = capacity + self.store_images = store_images + self.store_depth = store_depth + self.store_pc = store_pc + self.store_lang = store_lang + + self.state = np.zeros((n_envs, capacity, state_dim), dtype=np.float32) if state_dim else None + self.state_ee = np.zeros((n_envs, capacity, state_ee_dim), dtype=np.float32) if state_ee_dim else None + self.state_joint = np.zeros((n_envs, capacity, state_joint_dim), dtype=np.float32) if state_joint_dim else None + self.state_obj = np.zeros((n_envs, capacity, state_obj_dim), dtype=np.float32) if state_obj_dim else None + self.image = np.zeros((n_envs, capacity, *image_shape), dtype=np.uint8) if (store_images and image_shape) else None + self.depth = np.zeros((n_envs, capacity, *depth_shape), dtype=np.float32) if (store_depth and depth_shape) else None + self.pc = np.zeros((n_envs, 
capacity, *pc_shape), dtype=np.float32) if (store_pc and pc_shape) else None + self.raw_lang: Optional[List[List[str]]] = [['' for _ in range(capacity)] for _ in range(n_envs)] if store_lang else None + self.timestep = np.zeros((n_envs, capacity), dtype=np.int32) + + def set(self, idx: int, obs: Union[MetaObs, Dict[str, Any], List[Any]]) -> None: + obs_dict = self._coerce_obs(obs) + if obs_dict.get('state') is not None and self.state is not None: + self.state[:, idx] = self._ensure_env_batch(obs_dict['state']) + if obs_dict.get('state_ee') is not None and self.state_ee is not None: + self.state_ee[:, idx] = self._ensure_env_batch(obs_dict['state_ee']) + if obs_dict.get('state_joint') is not None and self.state_joint is not None: + self.state_joint[:, idx] = self._ensure_env_batch(obs_dict['state_joint']) + if obs_dict.get('state_obj') is not None and self.state_obj is not None: + self.state_obj[:, idx] = self._ensure_env_batch(obs_dict['state_obj']) + if obs_dict.get('image') is not None and self.image is not None: + self.image[:, idx] = self._ensure_env_batch(obs_dict['image']) + if obs_dict.get('depth') is not None and self.depth is not None: + self.depth[:, idx] = self._ensure_env_batch(obs_dict['depth']) + if obs_dict.get('pc') is not None and self.pc is not None: + self.pc[:, idx] = self._ensure_env_batch(obs_dict['pc']) + + if self.store_lang and self.raw_lang is not None and obs_dict.get('raw_lang') is not None: + raw_lang = obs_dict['raw_lang'] + if isinstance(raw_lang, str): + raw_lang = [raw_lang] * self.n_envs + for env_i, value in enumerate(raw_lang): + self.raw_lang[env_i][idx] = value + + if obs_dict.get('timestep') is not None: + timestep = np.asarray(obs_dict['timestep']) + if timestep.ndim == 0: + timestep = np.array([timestep]) + self.timestep[:, idx] = timestep + + def get(self, time_indices: np.ndarray, env_indices: np.ndarray, keys: Set[str]) -> Dict[str, Any]: + batch: Dict[str, Any] = {} + if 'state' in keys and self.state is not None: + 
batch['state'] = self.state[env_indices, time_indices].copy() + if 'state_ee' in keys and self.state_ee is not None: + batch['state_ee'] = self.state_ee[env_indices, time_indices].copy() + if 'state_joint' in keys and self.state_joint is not None: + batch['state_joint'] = self.state_joint[env_indices, time_indices].copy() + if 'state_obj' in keys and self.state_obj is not None: + batch['state_obj'] = self.state_obj[env_indices, time_indices].copy() + if 'image' in keys and self.image is not None: + batch['image'] = self.image[env_indices, time_indices].copy() + if 'depth' in keys and self.depth is not None: + batch['depth'] = self.depth[env_indices, time_indices].copy() + if 'pc' in keys and self.pc is not None: + batch['pc'] = self.pc[env_indices, time_indices].copy() + if 'raw_lang' in keys and self.raw_lang is not None: + batch['raw_lang'] = [self.raw_lang[e][t] for t, e in zip(time_indices, env_indices)] + if 'timestep' in keys: + batch['timestep'] = self.timestep[env_indices, time_indices].copy() + return batch + + def _coerce_obs(self, obs: Union[MetaObs, Dict[str, Any], List[Any]]) -> Dict[str, Any]: + if isinstance(obs, list): + obs_dicts = [self._coerce_obs(o) for o in obs] + return self._stack_dicts(obs_dicts) + if hasattr(obs, '__dataclass_fields__'): + return asdict(obs) + if isinstance(obs, dict): + return obs + if hasattr(obs, '__dict__'): + return vars(obs) + return {} + + def _stack_dicts(self, dict_list: List[Dict[str, Any]]) -> Dict[str, Any]: + if not dict_list: + return {} + result = {} + for key in dict_list[0].keys(): + values = [d.get(key) for d in dict_list] + if values[0] is None: + result[key] = None + elif isinstance(values[0], np.ndarray): + result[key] = np.stack(values) + elif isinstance(values[0], (int, float, bool)): + result[key] = np.array(values) + elif isinstance(values[0], str): + result[key] = values + else: + result[key] = values + return result + + def _ensure_env_batch(self, data: Any) -> np.ndarray: + arr = np.asarray(data) 
+ if self.n_envs == 1 and arr.ndim >= 1 and arr.shape[0] != 1: + arr = arr[np.newaxis, ...] + return arr + + +class _MetaActionStorage: + def __init__(self, n_envs: int, capacity: int, action_dim: Optional[int]) -> None: + self.n_envs = n_envs + self.capacity = capacity + self.action = np.zeros((n_envs, capacity, action_dim), dtype=np.float32) if action_dim else None + self.ctrl_space = np.zeros((n_envs, capacity), dtype=np.int8) + self.ctrl_type = np.zeros((n_envs, capacity), dtype=np.int8) + self.gripper_continuous = np.zeros((n_envs, capacity), dtype=np.int8) + + def set(self, idx: int, action: Union[MetaAction, Dict[str, Any], List[Any]]) -> None: + action_dict = self._coerce_action(action) + if action_dict.get('action') is not None and self.action is not None: + self.action[:, idx] = self._ensure_env_batch(action_dict['action']) + + ctrl_space = action_dict.get('ctrl_space', 'ee') + if isinstance(ctrl_space, str): + ctrl_space = [ctrl_space] * self.n_envs + self.ctrl_space[:, idx] = np.array([CTRL_SPACE_MAP.get(cs, 0) for cs in ctrl_space], dtype=np.int8) + + ctrl_type = action_dict.get('ctrl_type', 'delta') + if isinstance(ctrl_type, str): + ctrl_type = [ctrl_type] * self.n_envs + self.ctrl_type[:, idx] = np.array([CTRL_TYPE_MAP.get(ct, 0) for ct in ctrl_type], dtype=np.int8) + + gripper_continuous = action_dict.get('gripper_continuous', False) + if isinstance(gripper_continuous, bool): + gripper_continuous = [gripper_continuous] * self.n_envs + self.gripper_continuous[:, idx] = np.array([GRIPPER_CONTINUOUS_MAP.get(gc, 0) for gc in gripper_continuous], dtype=np.int8) + + def get(self, time_indices: np.ndarray, env_indices: np.ndarray, keys: Set[str]) -> Dict[str, Any]: + batch: Dict[str, Any] = {} + if 'action' in keys and self.action is not None: + batch['action'] = self.action[env_indices, time_indices].copy() + if 'ctrl_space' in keys: + batch['ctrl_space'] = self.ctrl_space[env_indices, time_indices].copy() + batch['ctrl_space_str'] = 
[CTRL_SPACE_INV_MAP[v] for v in batch['ctrl_space']] + if 'ctrl_type' in keys: + batch['ctrl_type'] = self.ctrl_type[env_indices, time_indices].copy() + batch['ctrl_type_str'] = [CTRL_TYPE_INV_MAP[v] for v in batch['ctrl_type']] + if 'gripper_continuous' in keys: + batch['gripper_continuous'] = self.gripper_continuous[env_indices, time_indices].copy() + batch['gripper_continuous_bool'] = [GRIPPER_CONTINUOUS_INV_MAP[v] for v in batch['gripper_continuous']] + return batch + + def _coerce_action(self, action: Union[MetaAction, Dict[str, Any], List[Any]]) -> Dict[str, Any]: + if isinstance(action, list): + action_dicts = [self._coerce_action(a) for a in action] + return self._stack_dicts(action_dicts) + if hasattr(action, '__dataclass_fields__'): + return asdict(action) + if isinstance(action, dict): + return action + if hasattr(action, '__dict__'): + return vars(action) + return {} + + def _stack_dicts(self, dict_list: List[Dict[str, Any]]) -> Dict[str, Any]: + if not dict_list: + return {} + result = {} + for key in dict_list[0].keys(): + values = [d.get(key) for d in dict_list] + if values[0] is None: + result[key] = None + elif isinstance(values[0], np.ndarray): + result[key] = np.stack(values) + elif isinstance(values[0], (int, float, bool)): + result[key] = np.array(values) + elif isinstance(values[0], str): + result[key] = values + else: + result[key] = values + return result + + def _ensure_env_batch(self, data: Any) -> np.ndarray: + arr = np.asarray(data) + if self.n_envs == 1 and arr.ndim >= 1 and arr.shape[0] != 1: + arr = arr[np.newaxis, ...] + return arr + + +class MetaReplay(BaseReplay): + """ + Replay buffer for MetaObs/MetaAction with env-first storage. 
+ """ + + DEFAULT_SAMPLE_KEYS = {'state', 'action', 'next_state', 'reward', 'done', 'truncated'} + ALL_SAMPLE_KEYS = { + 'state', 'state_ee', 'state_joint', 'state_obj', 'image', 'depth', 'pc', 'raw_lang', 'timestep', + 'next_state', 'next_state_ee', 'next_state_joint', 'next_state_obj', 'next_image', 'next_depth', + 'next_pc', 'next_raw_lang', 'next_timestep', + 'action', 'ctrl_space', 'ctrl_type', 'gripper_continuous', + 'reward', 'done', 'truncated', 'trajectory_id', + } + + def __init__( + self, + capacity: int = 100000, + device: Union[str, torch.device] = 'cpu', + env_type: Optional[str] = None, + n_envs: int = 1, + state_dim: Optional[int] = None, + state_ee_dim: Optional[int] = None, + state_joint_dim: Optional[int] = None, + state_obj_dim: Optional[int] = None, + image_shape: Optional[tuple] = None, + depth_shape: Optional[tuple] = None, + pc_shape: Optional[tuple] = None, + action_dim: Optional[int] = None, + store_images: bool = True, + store_depth: bool = False, + store_pc: bool = False, + store_lang: bool = False, + data_processor: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + data_collator: Optional[Callable[[List[Dict[str, Any]]], Dict[str, Any]]] = None, + state_normalizer: Optional[RunningMeanStd] = None, + action_normalizer: Optional[RunningMeanStd] = None, + update_normalizers: bool = True, + **kwargs, + ) -> None: + super().__init__(capacity=capacity, device=device, env_type=env_type, n_envs=n_envs, **kwargs) + + self.state_dim = state_dim + self.state_ee_dim = state_ee_dim + self.state_joint_dim = state_joint_dim + self.state_obj_dim = state_obj_dim + self.image_shape = image_shape + self.depth_shape = depth_shape + self.pc_shape = pc_shape + self.action_dim = action_dim + self.store_images = store_images + self.store_depth = store_depth + self.store_pc = store_pc + self.store_lang = store_lang + + self.data_processor = data_processor + self.data_collator = data_collator + self.update_normalizers = update_normalizers + 
self.state_normalizer = state_normalizer or (RunningMeanStd((state_dim,)) if state_dim else None) + self.action_normalizer = action_normalizer or (RunningMeanStd((action_dim,)) if action_dim else None) + + self._obs = _MetaObsStorage( + n_envs=n_envs, + capacity=capacity, + state_dim=state_dim, + state_ee_dim=state_ee_dim, + state_joint_dim=state_joint_dim, + state_obj_dim=state_obj_dim, + image_shape=image_shape, + depth_shape=depth_shape, + pc_shape=pc_shape, + store_images=store_images, + store_depth=store_depth, + store_pc=store_pc, + store_lang=store_lang, + ) + self._next_obs = _MetaObsStorage( + n_envs=n_envs, + capacity=capacity, + state_dim=state_dim, + state_ee_dim=state_ee_dim, + state_joint_dim=state_joint_dim, + state_obj_dim=state_obj_dim, + image_shape=image_shape, + depth_shape=depth_shape, + pc_shape=pc_shape, + store_images=store_images, + store_depth=store_depth, + store_pc=store_pc, + store_lang=store_lang, + ) + self._action = _MetaActionStorage(n_envs=n_envs, capacity=capacity, action_dim=action_dim) + self._reward = np.zeros((n_envs, capacity), dtype=np.float32) + self._done = np.zeros((n_envs, capacity), dtype=np.bool_) + self._truncated = np.zeros((n_envs, capacity), dtype=np.bool_) + self._trajectory_id = np.zeros((n_envs, capacity), dtype=np.int32) + + def add(self, transition: Union[RLTransition, Dict[str, Any]]) -> None: + if not isinstance(transition, RLTransition): + transition = self._coerce_transition(transition) + + idx = self._position + + + self._obs.set(idx, transition.obs) + self._action.set(idx, transition.action) + self._next_obs.set(idx, transition.next_obs) + + reward = np.asarray(transition.reward) + done = np.asarray(transition.done) + truncated = np.asarray(transition.truncated) if transition.truncated is not None else np.zeros(self.n_envs, dtype=bool) + if self.n_envs == 1: + reward = np.atleast_1d(reward) + done = np.atleast_1d(done) + truncated = np.atleast_1d(truncated) + + self._reward[:, idx] = reward + 
self._done[:, idx] = done + self._truncated[:, idx] = truncated + + if isinstance(transition.info, dict) and 'trajectory_id' in transition.info: + traj_id = transition.info['trajectory_id'] + if self.n_envs == 1: + traj_id = np.atleast_1d(traj_id) + self._trajectory_id[:, idx] = traj_id + + self._position = (self._position + 1) % self.capacity + self._size = min(self._size + 1, self.capacity) + + def sample( + self, + batch_size: int, + env_indices: Optional[Union[List[int], np.ndarray]] = None, + keys: Optional[Set[str]] = None, + ) -> Dict[str, Any]: + if self._size == 0: + return {} + + if keys is None: + keys = self.DEFAULT_SAMPLE_KEYS + + if env_indices is None: + env_indices = np.arange(self.n_envs) + env_indices = np.asarray(env_indices) + + time_indices = np.random.randint(0, self._size, size=batch_size) + env_sample_indices = np.random.choice(env_indices, size=batch_size) + + batch = {} + obs_keys = {k for k in keys if not k.startswith('next_')} + next_obs_keys = {k[5:] for k in keys if k.startswith('next_')} + + batch.update(self._obs.get(time_indices, env_sample_indices, obs_keys)) + next_obs_batch = self._next_obs.get(time_indices, env_sample_indices, next_obs_keys) + for key, value in next_obs_batch.items(): + batch[f"next_{key}"] = value + batch.update(self._action.get(time_indices, env_sample_indices, keys)) + + if 'reward' in keys: + batch['reward'] = self._reward[env_sample_indices, time_indices].copy() + if 'done' in keys: + batch['done'] = self._done[env_sample_indices, time_indices].copy() + if 'truncated' in keys: + batch['truncated'] = self._truncated[env_sample_indices, time_indices].copy() + if 'trajectory_id' in keys: + batch['trajectory_id'] = self._trajectory_id[env_sample_indices, time_indices].copy() + + batch['time_indices'] = time_indices.copy() + batch['env_indices'] = env_sample_indices.copy() + return batch + + def sample_as_tensor( + self, + batch_size: int, + env_indices: Optional[Union[List[int], np.ndarray]] = None, + keys: 
Optional[Set[str]] = None, + ) -> Dict[str, Any]: + batch = self.sample(batch_size, env_indices=env_indices, keys=keys) + if not batch: + return {} + tensor_batch: Dict[str, Any] = {} + for key, value in batch.items(): + if isinstance(value, np.ndarray): + tensor_batch[key] = torch.from_numpy(value).to(self.device) + else: + tensor_batch[key] = value + return tensor_batch + + def sample_for_training( + self, + batch_size: int, + env_indices: Optional[Union[List[int], np.ndarray]] = None, + keys: Optional[Set[str]] = None, + data_processor: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + data_collator: Optional[Callable[[List[Dict[str, Any]]], Dict[str, Any]]] = None, + apply_normalization: bool = True, + ) -> Dict[str, Any]: + batch = self.sample(batch_size, env_indices=env_indices, keys=keys) + if not batch: + return {} + + if apply_normalization: + batch = self._apply_normalization(batch) + + processor = data_processor or self.data_processor + collator = data_collator or self.data_collator + if processor is not None or collator is not None: + obs_samples = self._build_samples(batch, prefix='') + next_obs_samples = self._build_samples(batch, prefix='next_') + + if processor is not None: + obs_samples = [processor(sample) for sample in obs_samples] + next_obs_samples = [processor(sample) for sample in next_obs_samples] + + if collator is not None: + batch['processed_obs'] = collator(obs_samples) + batch['processed_next_obs'] = collator(next_obs_samples) + else: + batch['processed_obs'] = self._default_collate(obs_samples) + batch['processed_next_obs'] = self._default_collate(next_obs_samples) + + return batch + + def clear(self) -> None: + self._size = 0 + self._position = 0 + self.__init__( + capacity=self.capacity, + device=self.device, + env_type=self.env_type, + n_envs=self.n_envs, + state_dim=self.state_dim, + state_ee_dim=self.state_ee_dim, + state_joint_dim=self.state_joint_dim, + state_obj_dim=self.state_obj_dim, + 
image_shape=self.image_shape, + depth_shape=self.depth_shape, + pc_shape=self.pc_shape, + action_dim=self.action_dim, + store_images=self.store_images, + store_depth=self.store_depth, + store_pc=self.store_pc, + store_lang=self.store_lang, + data_processor=self.data_processor, + data_collator=self.data_collator, + state_normalizer=self.state_normalizer, + action_normalizer=self.action_normalizer, + update_normalizers=self.update_normalizers, + ) + + def save(self, path: str, **kwargs) -> None: + fmt = kwargs.get('format', 'pkl') + if path.endswith('.npz'): + fmt = 'npz' + + data = { + 'size': self._size, + 'position': self._position, + 'capacity': self.capacity, + 'n_envs': self.n_envs, + 'env_type': self.env_type, + 'state_dim': self.state_dim, + 'state_ee_dim': self.state_ee_dim, + 'state_joint_dim': self.state_joint_dim, + 'state_obj_dim': self.state_obj_dim, + 'image_shape': self.image_shape, + 'depth_shape': self.depth_shape, + 'pc_shape': self.pc_shape, + 'action_dim': self.action_dim, + 'store_images': self.store_images, + 'store_depth': self.store_depth, + 'store_pc': self.store_pc, + 'store_lang': self.store_lang, + 'storage_layout': 'env_first', + } + + size = self._size + data.update({ + '_state': None if self._obs.state is None else self._obs.state[:, :size], + '_state_ee': None if self._obs.state_ee is None else self._obs.state_ee[:, :size], + '_state_joint': None if self._obs.state_joint is None else self._obs.state_joint[:, :size], + '_state_obj': None if self._obs.state_obj is None else self._obs.state_obj[:, :size], + '_image': None if self._obs.image is None else self._obs.image[:, :size], + '_depth': None if self._obs.depth is None else self._obs.depth[:, :size], + '_pc': None if self._obs.pc is None else self._obs.pc[:, :size], + '_timestep': self._obs.timestep[:, :size], + '_next_state': None if self._next_obs.state is None else self._next_obs.state[:, :size], + '_next_state_ee': None if self._next_obs.state_ee is None else 
self._next_obs.state_ee[:, :size], + '_next_state_joint': None if self._next_obs.state_joint is None else self._next_obs.state_joint[:, :size], + '_next_state_obj': None if self._next_obs.state_obj is None else self._next_obs.state_obj[:, :size], + '_next_image': None if self._next_obs.image is None else self._next_obs.image[:, :size], + '_next_depth': None if self._next_obs.depth is None else self._next_obs.depth[:, :size], + '_next_pc': None if self._next_obs.pc is None else self._next_obs.pc[:, :size], + '_next_timestep': self._next_obs.timestep[:, :size], + '_action': None if self._action.action is None else self._action.action[:, :size], + '_ctrl_space': self._action.ctrl_space[:, :size], + '_ctrl_type': self._action.ctrl_type[:, :size], + '_gripper_continuous': self._action.gripper_continuous[:, :size], + '_reward': self._reward[:, :size], + '_done': self._done[:, :size], + '_truncated': self._truncated[:, :size], + '_trajectory_id': self._trajectory_id[:, :size], + }) + if self.store_lang and self._obs.raw_lang is not None: + data['_raw_lang'] = [row[:size] for row in self._obs.raw_lang] + if self.store_lang and self._next_obs.raw_lang is not None: + data['_next_raw_lang'] = [row[:size] for row in self._next_obs.raw_lang] + + if fmt == 'npz': + np_data = {k: v for k, v in data.items() if isinstance(v, np.ndarray)} + if '_raw_lang' in data: + np_data['_raw_lang'] = np.array(data['_raw_lang'], dtype=object) + if '_next_raw_lang' in data: + np_data['_next_raw_lang'] = np.array(data['_next_raw_lang'], dtype=object) + np_data['_metadata'] = np.array([data['size'], data['position'], data['capacity'], data['n_envs']]) + np_data['_env_type'] = np.array([data['env_type']], dtype=object) + np.savez_compressed(path, **np_data) + else: + with open(path, 'wb') as f: + pickle.dump(data, f) + + def load(self, path: str, **kwargs) -> None: + append = kwargs.get('append', False) + if not append: + self.clear() + + if path.endswith('.npz'): + data = np.load(path, 
allow_pickle=True) + size = int(data['_metadata'][0]) + for i in range(size): + transition = self._extract_transition_from_npz(data, i) + self.add(transition) + return + + with open(path, 'rb') as f: + data = pickle.load(f) + size = data['size'] + for i in range(size): + transition = self._extract_transition_from_data(data, i) + self.add(transition) + + def _apply_normalization(self, batch: Dict[str, Any]) -> Dict[str, Any]: + if self.state_normalizer is not None: + if 'state' in batch: + batch['state'] = self.state_normalizer.normalize(batch['state']) + if 'next_state' in batch: + batch['next_state'] = self.state_normalizer.normalize(batch['next_state']) + if self.action_normalizer is not None and 'action' in batch: + batch['action'] = self.action_normalizer.normalize(batch['action']) + return batch + + def _build_samples(self, batch: Dict[str, Any], prefix: str = '') -> List[Dict[str, Any]]: + samples: List[Dict[str, Any]] = [] + state_key = f'{prefix}state' + image_key = f'{prefix}image' + raw_lang_key = f'{prefix}raw_lang' + timestep_key = f'{prefix}timestep' + + state_arr = batch.get(state_key, None) + batch_size = state_arr.shape[0] if isinstance(state_arr, np.ndarray) and state_arr.ndim > 1 else len(batch.get(raw_lang_key, [])) + if batch_size == 0: + return samples + + for i in range(batch_size): + sample = {} + if image_key in batch: + sample['image'] = batch[image_key][i] + if state_key in batch: + sample['state'] = batch[state_key][i] if isinstance(batch[state_key], np.ndarray) else batch[state_key] + if raw_lang_key in batch: + sample['raw_lang'] = batch[raw_lang_key][i] + if timestep_key in batch: + sample['timestamp'] = batch[timestep_key][i] + samples.append(sample) + return samples + + def _default_collate(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]: + if not samples: + return {} + batch: Dict[str, Any] = {} + keys = samples[0].keys() + for key in keys: + values = [s[key] for s in samples] + if values[0] is None: + batch[key] = None + 
elif isinstance(values[0], np.ndarray): + batch[key] = torch.from_numpy(np.stack(values)) + elif isinstance(values[0], torch.Tensor): + batch[key] = torch.stack(values) + elif isinstance(values[0], str): + batch[key] = values + elif isinstance(values[0], (int, float)): + batch[key] = torch.tensor(values) + else: + batch[key] = values + return batch + + def _maybe_update_normalizers(self, obs: MetaObs, action: MetaAction) -> None: + if not self.update_normalizers: + return + if self.state_normalizer is not None and obs is not None and getattr(obs, 'state', None) is not None: + state_arr = np.asarray(getattr(obs, 'state')) + if self.n_envs == 1 and state_arr.ndim == 1: + state_arr = state_arr[np.newaxis, :] + if state_arr.ndim >= 2: + self.state_normalizer.update(state_arr) + if self.action_normalizer is not None and action is not None and getattr(action, 'action', None) is not None: + action_arr = np.asarray(getattr(action, 'action')) + if self.n_envs == 1 and action_arr.ndim == 1: + action_arr = action_arr[np.newaxis, :] + if action_arr.ndim >= 2: + self.action_normalizer.update(action_arr) + + def _coerce_transition(self, transition: Dict[str, Any]) -> RLTransition: + obs = transition.get('state') or transition.get('obs') + action = transition.get('action') + next_obs = transition.get('next_state') or transition.get('next_obs') + reward = transition.get('reward') + done = transition.get('done') + truncated = transition.get('truncated', None) + info = transition.get('info', None) + + if not isinstance(obs, MetaObs): + obs = dict2meta(obs or {}, mtype='obs') + if not isinstance(action, MetaAction): + action = dict2meta(action or {}, mtype='act') + if not isinstance(next_obs, MetaObs): + next_obs = dict2meta(next_obs or {}, mtype='obs') + + return RLTransition(obs=obs, action=action, next_obs=next_obs, reward=reward, done=done, truncated=truncated, info=info) + + def _extract_transition_from_data(self, data: Dict[str, Any], idx: int) -> RLTransition: + obs = MetaObs( 
+ state=data.get('_state')[:, idx] if data.get('_state') is not None else None, + state_ee=data.get('_state_ee')[:, idx] if data.get('_state_ee') is not None else None, + state_joint=data.get('_state_joint')[:, idx] if data.get('_state_joint') is not None else None, + state_obj=data.get('_state_obj')[:, idx] if data.get('_state_obj') is not None else None, + image=data.get('_image')[:, idx] if data.get('_image') is not None else None, + depth=data.get('_depth')[:, idx] if data.get('_depth') is not None else None, + pc=data.get('_pc')[:, idx] if data.get('_pc') is not None else None, + raw_lang=[row[idx] for row in data.get('_raw_lang', [])] if data.get('_raw_lang') is not None else None, + timestep=data.get('_timestep')[:, idx] if data.get('_timestep') is not None else None, + ) + next_obs = MetaObs( + state=data.get('_next_state')[:, idx] if data.get('_next_state') is not None else None, + state_ee=data.get('_next_state_ee')[:, idx] if data.get('_next_state_ee') is not None else None, + state_joint=data.get('_next_state_joint')[:, idx] if data.get('_next_state_joint') is not None else None, + state_obj=data.get('_next_state_obj')[:, idx] if data.get('_next_state_obj') is not None else None, + image=data.get('_next_image')[:, idx] if data.get('_next_image') is not None else None, + depth=data.get('_next_depth')[:, idx] if data.get('_next_depth') is not None else None, + pc=data.get('_next_pc')[:, idx] if data.get('_next_pc') is not None else None, + raw_lang=[row[idx] for row in data.get('_next_raw_lang', [])] if data.get('_next_raw_lang') is not None else None, + timestep=data.get('_next_timestep')[:, idx] if data.get('_next_timestep') is not None else None, + ) + action = MetaAction( + action=data.get('_action')[:, idx] if data.get('_action') is not None else None, + ctrl_space=[CTRL_SPACE_INV_MAP[int(v)] for v in np.atleast_1d(data.get('_ctrl_space')[:, idx])], + ctrl_type=[CTRL_TYPE_INV_MAP[int(v)] for v in np.atleast_1d(data.get('_ctrl_type')[:, idx])], + 
gripper_continuous=[GRIPPER_CONTINUOUS_INV_MAP[int(v)] for v in np.atleast_1d(data.get('_gripper_continuous')[:, idx])], + ) + return RLTransition( + obs=obs, + action=action, + next_obs=next_obs, + reward=data.get('_reward')[:, idx], + done=data.get('_done')[:, idx], + truncated=data.get('_truncated')[:, idx] if data.get('_truncated') is not None else None, + ) + + def _extract_transition_from_npz(self, npz_data, idx: int) -> RLTransition: + data = {k: npz_data[k] for k in npz_data.files} + return self._extract_transition_from_data(data, idx) + + def get_trajectory(self, trajectory_id: int, env_idx: int = 0) -> Dict[str, Any]: + mask = self._trajectory_id[env_idx, :self._size] == trajectory_id + time_indices = np.where(mask)[0] + if len(time_indices) == 0: + return {} + env_indices = np.full(len(time_indices), env_idx, dtype=np.int64) + return self.sample(len(time_indices), env_indices=env_indices) + + def __repr__(self) -> str: + return (f"MetaReplay(capacity={self.capacity}, size={self._size}, " + f"n_envs={self.n_envs}, total_transitions={self.total_transitions}, " + f"env_type='{self.env_type}', state_dim={self.state_dim}, " + f"action_dim={self.action_dim}, store_images={self.store_images}, " + f"store_lang={self.store_lang})") + + diff --git a/rl/buffer/transition.py b/rl/buffer/transition.py new file mode 100644 index 00000000..ab708045 --- /dev/null +++ b/rl/buffer/transition.py @@ -0,0 +1,38 @@ +""" +RL transition data structures. +""" + +from dataclasses import dataclass +from typing import Any, Optional + +import numpy as np + +from benchmark.base import MetaObs, MetaAction + + +@dataclass +class RLTransition: + """ + One-step transition for RL replay (raw environment meta data). + """ + + obs: MetaObs + action: MetaAction + next_obs: MetaObs + reward: Any + done: Any + truncated: Optional[Any] = None + info: Optional[Any] = None + + def to_batch(self) -> "RLTransition": + """ + Ensure obs/action/next_obs are batched along env dimension. 
+ """ + if hasattr(self.obs, "to_batch"): + self.obs.to_batch() + if hasattr(self.action, "to_batch"): + self.action.to_batch() + if hasattr(self.next_obs, "to_batch"): + self.next_obs.to_batch() + return self + diff --git a/rl/collectors/__init__.py b/rl/collectors/__init__.py new file mode 100644 index 00000000..b74a1ea8 --- /dev/null +++ b/rl/collectors/__init__.py @@ -0,0 +1,85 @@ +""" +Data Collectors Module + +This module provides data collectors for gathering experience from environments. + +Design Philosophy: +- Responsibility separation: Separate data collection logic from trainer +- Environment abstraction: Support single environment, parallel environments, multiple environment types +- Raw data storage: Only store raw environment rewards, no reward function computation +- Statistics: Collect and return episode statistics + +Available collectors: +- SimCollector: Collector for simulation environments +- RealCollector: Collector for real robot environments + +Note: Implementations are provided in separate files. +This __init__.py provides factory functions for creating collectors. +""" + +from typing import Type, Dict, Any + +from .base_collector import BaseCollector, DummyCollector + +# Registry for collector classes +_COLLECTOR_REGISTRY: Dict[str, Type] = {} + + +def register_collector(name: str, collector_class: Type) -> None: + """ + Register a collector class. + + Args: + name: Collector name (e.g., 'sim', 'real') + collector_class: Collector class to register + """ + _COLLECTOR_REGISTRY[name.lower()] = collector_class + + +def get_collector_class(name_or_type: str) -> Type: + """ + Get collector class by name or type string. 
+ + Args: + name_or_type: Collector name (e.g., 'sim') or full type path + (e.g., 'rl.collectors.sim_collector.SimCollector') + + Returns: + Collector class + + Raises: + ValueError: If collector not found + """ + # First check registry + if name_or_type.lower() in _COLLECTOR_REGISTRY: + return _COLLECTOR_REGISTRY[name_or_type.lower()] + + # Try to import from type path + if '.' in name_or_type: + try: + parts = name_or_type.rsplit('.', 1) + module_path = parts[0] + class_name = parts[1] + + import importlib + module = importlib.import_module(module_path) + return getattr(module, class_name) + except (ImportError, AttributeError) as e: + raise ValueError(f"Cannot import collector from '{name_or_type}': {e}") + + raise ValueError(f"Unknown collector: '{name_or_type}'. Available: {list(_COLLECTOR_REGISTRY.keys())}") + + +def list_collectors() -> list: + """List all registered collectors.""" + return list(_COLLECTOR_REGISTRY.keys()) + + +__all__ = [ + 'BaseCollector', + 'DummyCollector', + 'register_collector', + 'get_collector_class', + 'list_collectors', +] + diff --git a/rl/collectors/base_collector.py b/rl/collectors/base_collector.py new file mode 100644 index 00000000..829ac6f9 --- /dev/null +++ b/rl/collectors/base_collector.py @@ -0,0 +1,541 @@ +""" +Base Collector Class + +This module defines the base class for all data collectors in the RL framework. + +Design Philosophy: +- Responsibility separation: Separate data collection logic from trainer + so that trainer can focus on training loop coordination +- Environment abstraction: Support vectorized environments (SequentialVectorEnv, SubprocVectorEnv, etc.) 
# --- rl/collectors/base_collector.py (reconstructed from collapsed diff view) ---
"""
Base Collector Class

Defines the interface for all data collectors in the RL framework.

Design philosophy:
- Responsibility separation: collection logic is kept out of the trainer.
- Environment abstraction: works with any vectorized env (SequentialVectorEnv,
  SubprocVectorEnv, DummyVectorEnv, ...).
- Raw data storage: only raw environment rewards are stored; reward functions
  are applied later, in the trainer.
- Statistics: episode statistics are gathered and returned.
- Exploration: supports an initial random phase plus noise-based exploration.
"""

import numpy as np
from typing import Dict, Any, Optional, Union, List, TYPE_CHECKING
from abc import ABC, abstractmethod

# FIX(review): the original repeated this import on two consecutive lines;
# the duplicate has been removed.
from benchmark.base import MetaObs, MetaAction
from benchmark.utils import organize_obs
from rl.buffer.transition import RLTransition
from rl.envs import VectorEnvProtocol, EnvsType

if TYPE_CHECKING:
    from utils.exploration import ExplorationScheduler


class BaseCollector(ABC):
    """
    Abstract base class for data collectors.

    Collectors gather experience from vectorized environments using the
    algorithm's policy. Only raw environment rewards are stored; reward
    functions are applied in the trainer at training time.

    Attributes:
        envs: Vectorized environment(s) to collect data from.
        algorithm: Algorithm used for action selection and transition recording.
        exploration: Optional exploration strategy.
    """

    def __init__(
        self,
        envs: Union[VectorEnvProtocol, Dict[str, VectorEnvProtocol]],
        algorithm: 'BaseAlgorithm',
        exploration: Optional['ExplorationScheduler'] = None,
        **kwargs
    ):
        """
        Args:
            envs: A single VectorEnvProtocol, or a dict mapping an env-type
                name to one (e.g. {'sim': sim_vec_env, 'real': real_vec_env}).
            algorithm: BaseAlgorithm instance used for action selection and
                transition recording.
            exploration: Optional ExplorationScheduler (initial random steps
                plus noise-based exploration); None disables exploration.
            **kwargs: Collector-specific parameters, stored for subclasses.
        """
        self.envs = envs
        self.algorithm = algorithm
        self.exploration = exploration
        self._kwargs = kwargs

        # Global step counter, used to schedule exploration decay.
        self._total_steps = 0

        # Normalize storage: internally envs are always indexed by type name.
        if isinstance(envs, dict):
            self._envs_dict: Dict[str, VectorEnvProtocol] = envs
            self._is_multi_env = True
        else:
            self._envs_dict = {'default': envs}
            self._is_multi_env = False

    @abstractmethod
    def collect(self, n_steps: int, env_type: Optional[str] = None) -> Dict[str, Any]:
        """
        Collect ``n_steps`` of interaction data.

        Args:
            n_steps: Number of steps to collect.
            env_type: Optional environment-type identifier, forwarded to
                record_transition for multi-environment setups.

        Returns:
            Statistics dict (episode_rewards, episode_lengths, total_steps, ...).
        """
        raise NotImplementedError

    @abstractmethod
    def reset(self, **kwargs) -> None:
        """Reset collector state (e.g. reset the environments)."""
        raise NotImplementedError

    def get_env(self, env_type: Optional[str] = None) -> VectorEnvProtocol:
        """
        Look up a vectorized environment by type.

        Falls back to the 'default' entry (or the first registered env) when
        ``env_type`` is None; raises KeyError for an unknown type.
        """
        if env_type is not None:
            return self._envs_dict[env_type]
        if 'default' in self._envs_dict:
            return self._envs_dict['default']
        return list(self._envs_dict.values())[0]

    def get_total_env_num(self) -> int:
        """Total number of parallel environments across all types."""
        return sum(len(env) for env in self._envs_dict.values())

    def get_env_types(self) -> List[str]:
        """All registered environment-type identifiers."""
        return list(self._envs_dict.keys())

    def get_envs(self) -> Union[VectorEnvProtocol, Dict[str, VectorEnvProtocol]]:
        """Return the environment(s) exactly as passed to __init__."""
        return self.envs

    # ==================== Exploration Methods ====================

    def set_exploration(self, exploration: Optional['ExplorationScheduler']) -> None:
        """Install (or clear, with None) the exploration strategy."""
        self.exploration = exploration

    def apply_exploration(self, action: Any, obs: Any = None, **kwargs) -> Any:
        """
        Apply the exploration strategy to policy action(s).

        Accepts a MetaAction, a list of MetaActions, a dict with an 'action'
        key, or a raw ndarray; the return value keeps the input structure.
        No-op when no exploration strategy is installed.
        """
        if self.exploration is None:
            return action

        # List of MetaAction: stack, perturb once, then scatter back.
        if isinstance(action, list) and len(action) > 0 and hasattr(action[0], 'action'):
            stacked = np.stack([a.action for a in action])  # (n_envs, action_dim)
            explored = self.exploration(stacked, step=self._total_steps, obs=obs, **kwargs)
            for i, a in enumerate(action):
                a.action = explored[i]
            return action

        # Single MetaAction (its action field may itself be batched).
        elif hasattr(action, 'action') and action.action is not None:
            action.action = self.exploration(
                action.action, step=self._total_steps, obs=obs, **kwargs
            )
            return action

        # Dict carrying the array under 'action'.
        elif isinstance(action, dict) and 'action' in action:
            action['action'] = self.exploration(
                action['action'], step=self._total_steps, obs=obs, **kwargs
            )
            return action

        # Raw numpy array.
        elif isinstance(action, np.ndarray):
            return self.exploration(action, step=self._total_steps, obs=obs, **kwargs)

        # Unknown type: pass through untouched.
        else:
            return action

    def reset_exploration(self, env_idx: Optional[int] = None) -> None:
        """Reset stateful exploration noise (e.g. OU), optionally for one env."""
        if self.exploration is not None:
            self.exploration.reset(env_idx)

    @property
    def total_steps(self) -> int:
        """Total environment steps collected so far."""
        return self._total_steps

    @property
    def is_exploring_randomly(self) -> bool:
        """True while the scheduler is still in its random warm-up phase."""
        if self.exploration is None:
            return False
        return self.exploration.is_random_phase

    def _unpack_actions(self, actions: Any, n_envs: int) -> List[Any]:
        """
        Turn batched policy output into a per-env list for vec_env.step.

        Handles List[...] (returned as-is), MetaAction with a batched array,
        and raw ndarrays; single actions are replicated across envs.
        """
        if isinstance(actions, list):
            return actions

        if hasattr(actions, 'action') and actions.action is not None:
            arr = actions.action
            if isinstance(arr, np.ndarray) and arr.ndim >= 2:
                # Batched: (n_envs, action_dim) or (n_envs, chunk, action_dim).
                return [
                    MetaAction(
                        action=arr[i],
                        ctrl_space=getattr(actions, 'ctrl_space', 'ee'),
                        ctrl_type=getattr(actions, 'ctrl_type', 'delta'),
                        gripper_continuous=getattr(actions, 'gripper_continuous', False),
                    )
                    for i in range(min(n_envs, len(arr)))
                ]
            # Single action: replicate for all envs.
            return [actions] * n_envs

        if isinstance(actions, np.ndarray):
            if actions.ndim >= 2:
                return [actions[i] for i in range(min(n_envs, len(actions)))]
            return [actions] * n_envs

        return [actions] * n_envs

    def get_algorithm(self) -> 'BaseAlgorithm':
        """Return the algorithm instance."""
        return self.algorithm

    @property
    def env_num(self) -> int:
        """Number of parallel envs in the default (or first) environment."""
        if 'default' in self._envs_dict:
            return len(self._envs_dict['default'])
        return len(list(self._envs_dict.values())[0])

    def __repr__(self) -> str:
        env_info = f"{self.get_total_env_num()} envs" if self._is_multi_env else f"{self.env_num} envs"
        return f"{self.__class__.__name__}(envs={env_info}, algorithm={self.algorithm.__class__.__name__})"


class DummyCollector(BaseCollector):
    """
    Simple collector for off-policy RL on vectorized environments.

    Handles batched observations/actions, builds RLTransition objects for the
    replay buffer, and tracks per-env episode statistics. Supports both
    step-by-step and batch collection.
    """

    def __init__(self, envs, algorithm, ctrl_space='joint', action_dim=None, **kwargs):
        super().__init__(envs, algorithm, **kwargs)
        self.vec_env = self.get_env()
        self._last_obs = None
        self.ctrl_space = ctrl_space
        self.action_dim = action_dim

        # Per-env running episode totals (allocated on reset()).
        self._episode_rewards = None
        self._episode_lengths = None

    def reset(self, **kwargs):
        """Reset the environments and the per-env episode accumulators."""
        self.vec_env = self.get_env()
        self._last_obs = self.vec_env.reset()

        num_envs = len(self.vec_env)
        self._episode_rewards = np.zeros(num_envs, dtype=np.float32)
        self._episode_lengths = np.zeros(num_envs, dtype=np.int32)

    # NOTE(review): collect_step()/collect() continue in the following
    # collapsed region of the diff.
+ + Args: + noise_scale: Exploration noise scale (for policy actions) + use_random: If True, use random actions (for initial exploration) + env_type: Optional environment type identifier + + Returns: + Dictionary with statistics for this step + """ + if self._last_obs is None: + self.reset() + + from benchmark.base import MetaObs, MetaAction + from benchmark.utils import organize_obs + from rl.buffer.transition import RLTransition + + stats = {'episode_rewards': [], 'episode_lengths': [], 'total_steps': 0} + + # Organize observations into batched MetaObs + obs = self._last_obs + if not isinstance(obs, MetaObs): + obs = organize_obs(obs, self.ctrl_space) + + num_envs = len(self.vec_env) + + # Select action + if use_random: + # Random exploration + action_dim = self.action_dim + if action_dim is None: + # Try to infer from environment + single_env = self.vec_env.envs[0] if hasattr(self.vec_env, 'envs') else self.vec_env + if hasattr(single_env, 'action_space'): + action_dim = single_env.action_space.shape[0] + elif hasattr(single_env, 'action_dim'): + action_dim = single_env.action_dim + else: + raise ValueError("Cannot determine action_dim for random exploration") + action_array = np.random.uniform(-1, 1, (num_envs, action_dim)).astype(np.float32) + action = MetaAction(action=action_array) + else: + # Policy action with exploration noise + action = self.algorithm.select_action(obs, noise_scale=noise_scale, env=self.vec_env) + + # Unpack action to numpy array + if hasattr(action, 'action'): + action_array = action.action + else: + action_array = action + + if action_array.ndim == 1: + action_array = action_array[np.newaxis, :] + + # Convert to list of dicts for vec_env.step + step_actions = [{'action': action_array[i]} for i in range(num_envs)] + + # Step environment + next_obs, rewards, dones, infos = self.vec_env.step(step_actions) + + # Handle infos which may be None, list, or numpy array + if infos is not None and len(infos) > 0: + truncated = np.array([ + 
info.get('TimeLimit.truncated', False) if isinstance(info, dict) else False + for info in infos + ]) + else: + truncated = np.zeros_like(dones) + + # Organize next observations + if not isinstance(next_obs, MetaObs): + next_obs = organize_obs(next_obs, self.ctrl_space) + + # Ensure action is MetaAction + if not isinstance(action, MetaAction): + action = MetaAction(action=action_array) + + # Create transition and record + transition = RLTransition( + obs=obs, + action=action, + next_obs=next_obs, + reward=rewards, + done=dones, + truncated=truncated, + ) + + kwargs_trans = {'env_type': env_type} if env_type else {} + self.algorithm.record_transition(transition, **kwargs_trans) + + # Update episode statistics + self._episode_rewards += rewards + self._episode_lengths += 1 + stats['total_steps'] = num_envs + self._total_steps += num_envs + + # Handle done episodes + done_indices = np.where(dones)[0] + if len(done_indices) > 0: + stats['episode_rewards'] = self._episode_rewards[done_indices].tolist() + stats['episode_lengths'] = self._episode_lengths[done_indices].tolist() + + # Reset tracking for done envs + self._episode_rewards[done_indices] = 0 + self._episode_lengths[done_indices] = 0 + + # Reset done environments and update next_obs + reset_obs = self.vec_env.reset(id=done_indices) + if reset_obs is not None: + reset_obs_organized = organize_obs(reset_obs, self.ctrl_space) if not isinstance(reset_obs, MetaObs) else reset_obs + if isinstance(next_obs, MetaObs) and next_obs.state is not None: + if hasattr(reset_obs_organized, 'state') and reset_obs_organized.state is not None: + next_obs.state[done_indices] = reset_obs_organized.state + + # Update last observation + self._last_obs = next_obs + + return stats + + def collect(self, n_steps, env_type=None): + """ + Collect n_steps of interaction data. 
+ + Args: + n_steps: Number of steps to collect + env_type: Optional environment type identifier + + Returns: + Dictionary with statistics: + - episode_rewards: List of episode rewards + - episode_lengths: List of episode lengths + - total_steps: Total number of steps collected + - env_type: Environment type identifier + """ + stats = {'episode_rewards': [], 'episode_lengths': [], 'total_steps': 0, 'env_type': env_type} + + for _ in range(n_steps): + step_stats = self.collect_step(env_type=env_type) + stats['episode_rewards'].extend(step_stats.get('episode_rewards', [])) + stats['episode_lengths'].extend(step_stats.get('episode_lengths', [])) + stats['total_steps'] += step_stats.get('total_steps', 0) + + return stats + diff --git a/rl/envs/__init__.py b/rl/envs/__init__.py new file mode 100644 index 00000000..6d8cacbd --- /dev/null +++ b/rl/envs/__init__.py @@ -0,0 +1,19 @@ +""" +RL Environments Module + +Provides protocols and utilities for vectorized environments. +""" + +from .protocols import VectorEnvProtocol, VectorEnv, EnvsType +from .utils import make_vector_env, get_env_info + +__all__ = [ + # Protocols + 'VectorEnvProtocol', + 'VectorEnv', + 'EnvsType', + # Utilities + 'make_vector_env', + 'get_env_info', +] + diff --git a/rl/envs/protocols.py b/rl/envs/protocols.py new file mode 100644 index 00000000..35ddedc1 --- /dev/null +++ b/rl/envs/protocols.py @@ -0,0 +1,77 @@ +""" +Vector Environment Protocols + +Defines the protocol (interface) for vectorized environments. +Supports SequentialVectorEnv, SubprocVectorEnv, and custom implementations. +""" + +from typing import Protocol, runtime_checkable, Any, Optional, Union, List, Dict +import numpy as np + + +@runtime_checkable +class VectorEnvProtocol(Protocol): + """ + Vectorized environment protocol. + + Any class implementing this protocol can be used as a vectorized environment + in the RL framework. 
# NOTE(review): collapsed-diff region reconstructed as clean Python. Covers the
# tail of rl/envs/protocols.py and all of rl/envs/utils.py; the original
# `from .protocols import VectorEnvProtocol` in utils.py is satisfied by the
# definitions above.

@runtime_checkable
class VectorEnvProtocol(Protocol):
    """
    Structural interface for vectorized environments.

    Any object with this shape can be used as a vectorized environment:
    benchmark.utils.SequentialVectorEnv, tianshou's SubprocVectorEnv /
    DummyVectorEnv / ShmemVectorEnv, or a custom implementation.

    Attributes:
        env_num: Number of parallel environments.
    """
    env_num: int

    def reset(
        self,
        id: Optional[Union[int, List[int], np.ndarray]] = None
    ) -> Any:
        """Reset all environments (id=None) or the given subset; returns observations."""
        ...

    def step(
        self,
        action: Any,
        id: Optional[Union[int, List[int], np.ndarray]] = None
    ) -> tuple:
        """Execute action(s) in the (selected) environments; returns (obs, reward, done, info)."""
        ...

    def close(self) -> None:
        """Close all environments and release resources."""
        ...

    def __len__(self) -> int:
        """Number of environments."""
        ...


# Type aliases for convenience.
VectorEnv = VectorEnvProtocol
EnvsType = Union[VectorEnvProtocol, Dict[str, VectorEnvProtocol]]


def make_vector_env(
    env_fn: Callable[[], Any],
    num_envs: int = 1,
    vector_type: str = 'sequential',
    **kwargs
) -> VectorEnvProtocol:
    """
    Create a vectorized environment from a single-env factory.

    Args:
        env_fn: Zero-arg factory returning one environment (e.g. a MetaEnv);
            called num_envs times.
        num_envs: Number of parallel environments. Default 1.
        vector_type: Vectorization backend:
            - 'sequential': SequentialVectorEnv (no multiprocessing)
            - 'subproc': SubprocVectorEnv (multiprocessing)
            - 'dummy': DummyVectorEnv (tianshou's sequential implementation)
            - 'shmem': ShmemVectorEnv (shared memory)
        **kwargs: Forwarded to the vector-environment constructor.

    Returns:
        VectorEnvProtocol: A vectorized environment instance.

    Raises:
        ValueError: If vector_type is not recognized.
        ImportError: If the backend module for vector_type is unavailable.
    """
    env_fns = [env_fn for _ in range(num_envs)]

    if vector_type == 'sequential':
        from benchmark.utils import SequentialVectorEnv
        # FIX(review): kwargs were silently dropped for 'sequential' and
        # 'dummy' in the original, contradicting the documented contract;
        # they are now forwarded to every backend.
        return SequentialVectorEnv(env_fns, **kwargs)

    elif vector_type == 'subproc':
        from tianshou.env import SubprocVectorEnv
        return SubprocVectorEnv(env_fns, **kwargs)

    elif vector_type == 'dummy':
        from tianshou.env import DummyVectorEnv
        return DummyVectorEnv(env_fns, **kwargs)

    elif vector_type == 'shmem':
        from tianshou.env import ShmemVectorEnv
        return ShmemVectorEnv(env_fns, **kwargs)

    else:
        raise ValueError(
            f"Unknown vector_type: {vector_type}. "
            f"Supported types: 'sequential', 'subproc', 'dummy', 'shmem'"
        )


def get_env_info(envs: VectorEnvProtocol) -> dict:
    """
    Summarize a vectorized environment.

    Returns:
        Dict with:
        - env_num: number of parallel envs (env_num attribute, else len())
        - type: class name of the vector environment
    """
    return {
        'env_num': envs.env_num if hasattr(envs, 'env_num') else len(envs),
        'type': type(envs).__name__,
    }
+""" + +from .seed_manager import SeedManager, set_global_seed, get_global_seed +from .logger import BaseLogger, ConsoleLogger, TensorBoardLogger, CompositeLogger +from .checkpoint import CheckpointManager +from .callback import ( + Callback, CallbackList, + ProgressCallback, EvalCallback, CheckpointCallback, EarlyStoppingCallback +) +from .distributed import DistributedContext, get_world_size, get_rank, is_main_process + +__all__ = [ + # Seed management + 'SeedManager', + 'set_global_seed', + 'get_global_seed', + + # Logging + 'BaseLogger', + 'ConsoleLogger', + 'TensorBoardLogger', + 'CompositeLogger', + + # Checkpoint + 'CheckpointManager', + + # Callbacks + 'Callback', + 'CallbackList', + 'ProgressCallback', + 'EvalCallback', + 'CheckpointCallback', + 'EarlyStoppingCallback', + + # Distributed + 'DistributedContext', + 'get_world_size', + 'get_rank', + 'is_main_process', +] + diff --git a/rl/infra/callback.py b/rl/infra/callback.py new file mode 100644 index 00000000..5e9ffb42 --- /dev/null +++ b/rl/infra/callback.py @@ -0,0 +1,624 @@ +""" +Callback System for RL Training + +This module provides a flexible callback system for hooking into the training loop: +- Progress tracking and logging +- Evaluation during training +- Checkpoint saving +- Early stopping +- Custom callbacks + +Design Philosophy: +- Non-intrusive: Callbacks don't modify training logic +- Composable: Multiple callbacks can be combined +- Extensible: Easy to add custom callbacks +- Event-driven: Callbacks respond to training events +""" + +import time +import numpy as np +from abc import ABC, abstractmethod +from typing import Dict, Any, Optional, List, Callable, Union +from dataclasses import dataclass, field + + +@dataclass +class TrainingState: + """ + Container for training state passed to callbacks. + + This provides a standardized interface for callbacks to access training info. 
+ """ + step: int = 0 + episode: int = 0 + total_timesteps: int = 0 + + # Episode info + episode_reward: float = 0.0 + episode_length: int = 0 + episode_rewards: List[float] = field(default_factory=list) + episode_lengths: List[int] = field(default_factory=list) + + # Training info + loss: Optional[float] = None + learning_rate: Optional[float] = None + + # Timing + fps: float = 0.0 + time_elapsed: float = 0.0 + + # Extra info + info: Dict[str, Any] = field(default_factory=dict) + + # Control flags + should_stop: bool = False + + +class Callback(ABC): + """ + Base class for training callbacks. + + Callbacks can hook into various points in the training loop: + - on_training_start: Called once at the beginning of training + - on_training_end: Called once at the end of training + - on_step_start: Called before each training step + - on_step_end: Called after each training step + - on_episode_start: Called at the start of each episode + - on_episode_end: Called at the end of each episode + - on_rollout_start: Called before data collection + - on_rollout_end: Called after data collection + - on_update_start: Called before policy update + - on_update_end: Called after policy update + + Override only the methods you need. + """ + + def __init__(self): + self.training_state: Optional[TrainingState] = None + self.trainer = None + self.logger = None + + def set_trainer(self, trainer) -> None: + """Set reference to the trainer.""" + self.trainer = trainer + + def set_logger(self, logger) -> None: + """Set reference to the logger.""" + self.logger = logger + + def on_training_start(self, state: TrainingState) -> None: + """Called at the beginning of training.""" + pass + + def on_training_end(self, state: TrainingState) -> None: + """Called at the end of training.""" + pass + + def on_step_start(self, state: TrainingState) -> None: + """Called before each training step.""" + pass + + def on_step_end(self, state: TrainingState) -> bool: + """ + Called after each training step. 
+ + Returns: + True to continue training, False to stop + """ + return True + + def on_episode_start(self, state: TrainingState) -> None: + """Called at the start of each episode.""" + pass + + def on_episode_end(self, state: TrainingState) -> None: + """Called at the end of each episode.""" + pass + + def on_rollout_start(self, state: TrainingState) -> None: + """Called before data collection (rollout).""" + pass + + def on_rollout_end(self, state: TrainingState) -> None: + """Called after data collection (rollout).""" + pass + + def on_update_start(self, state: TrainingState) -> None: + """Called before policy update.""" + pass + + def on_update_end(self, state: TrainingState) -> None: + """Called after policy update.""" + pass + + +class CallbackList(Callback): + """ + Container for multiple callbacks. + + Forwards all events to contained callbacks in order. + """ + + def __init__(self, callbacks: Optional[List[Callback]] = None): + super().__init__() + self.callbacks = callbacks or [] + + def append(self, callback: Callback) -> None: + """Add a callback.""" + self.callbacks.append(callback) + + def set_trainer(self, trainer) -> None: + """Set trainer for all callbacks.""" + self.trainer = trainer + for callback in self.callbacks: + callback.set_trainer(trainer) + + def set_logger(self, logger) -> None: + """Set logger for all callbacks.""" + self.logger = logger + for callback in self.callbacks: + callback.set_logger(logger) + + def on_training_start(self, state: TrainingState) -> None: + for callback in self.callbacks: + callback.on_training_start(state) + + def on_training_end(self, state: TrainingState) -> None: + for callback in self.callbacks: + callback.on_training_end(state) + + def on_step_start(self, state: TrainingState) -> None: + for callback in self.callbacks: + callback.on_step_start(state) + + def on_step_end(self, state: TrainingState) -> bool: + continue_training = True + for callback in self.callbacks: + if not callback.on_step_end(state): + 
continue_training = False + return continue_training + + def on_episode_start(self, state: TrainingState) -> None: + for callback in self.callbacks: + callback.on_episode_start(state) + + def on_episode_end(self, state: TrainingState) -> None: + for callback in self.callbacks: + callback.on_episode_end(state) + + def on_rollout_start(self, state: TrainingState) -> None: + for callback in self.callbacks: + callback.on_rollout_start(state) + + def on_rollout_end(self, state: TrainingState) -> None: + for callback in self.callbacks: + callback.on_rollout_end(state) + + def on_update_start(self, state: TrainingState) -> None: + for callback in self.callbacks: + callback.on_update_start(state) + + def on_update_end(self, state: TrainingState) -> None: + for callback in self.callbacks: + callback.on_update_end(state) + + +class ProgressCallback(Callback): + """ + Callback for logging training progress. + + Logs metrics at specified intervals. + """ + + def __init__( + self, + log_interval: int = 100, + verbose: int = 1 + ): + """ + Initialize progress callback. 
class ProgressCallback(Callback):
    """
    Periodically prints (and optionally logs) training progress.

    Emits a one-line summary every ``log_interval`` steps, plus banners
    at the start and end of training.
    """

    def __init__(
        self,
        log_interval: int = 100,
        verbose: int = 1
    ):
        """
        Args:
            log_interval: Steps between log outputs
            verbose: Verbosity level (0=silent, 1=normal, 2=detailed)
        """
        super().__init__()
        self.log_interval = log_interval
        self.verbose = verbose
        self._start_time = None
        self._last_log_step = 0

    def on_training_start(self, state: TrainingState) -> None:
        self._start_time = time.time()
        if self.verbose >= 1:
            banner = "=" * 60
            print(banner)
            print("Training started")
            print(banner)

    def on_training_end(self, state: TrainingState) -> None:
        if self.verbose < 1:
            return
        elapsed = time.time() - self._start_time
        print("=" * 60)
        print(f"Training completed in {elapsed:.1f}s")
        print(f"Total steps: {state.step}")
        print(f"Total episodes: {state.episode}")
        if state.episode_rewards:
            print(f"Final mean reward: {np.mean(state.episode_rewards[-100:]):.2f}")
        print("=" * 60)

    def on_step_end(self, state: TrainingState) -> bool:
        # Log only on exact interval multiples, and never twice for the
        # same step (step 0 is skipped because of the strict > check).
        due = state.step % self.log_interval == 0 and state.step > self._last_log_step
        if due:
            self._last_log_step = state.step
            self._log_progress(state)
        return True

    def _log_progress(self, state: TrainingState) -> None:
        """Print one progress line and forward metrics to the logger."""
        if self.verbose == 0:
            return

        elapsed = time.time() - self._start_time
        fps = state.step / elapsed if elapsed > 0 else 0

        fields = [f"Step: {state.step:>8}"]
        if state.episode_rewards:
            mean_reward = np.mean(state.episode_rewards[-100:])
            fields.append(f"Reward: {mean_reward:>8.2f}")
        if state.loss is not None:
            fields.append(f"Loss: {state.loss:>8.4f}")
        fields.append(f"FPS: {fps:>6.0f}")
        fields.append(f"Time: {elapsed:>6.0f}s")

        print(" | ".join(fields))

        if self.logger:
            metrics = {
                'progress/fps': fps,
                'progress/time_elapsed': elapsed
            }
            if state.episode_rewards:
                metrics['progress/mean_reward'] = np.mean(state.episode_rewards[-100:])
            self.logger.log_scalars(metrics, step=state.step)


class EvalCallback(Callback):
    """
    Periodically evaluates the current policy during training.

    Invokes a user-supplied evaluation function every ``eval_interval``
    steps and records/logs the returned metrics.
    """

    def __init__(
        self,
        eval_fn: Callable[[int], Dict[str, float]],
        eval_interval: int = 10000,
        n_eval_episodes: int = 10,
        verbose: int = 1
    ):
        """
        Args:
            eval_fn: Function that takes the step and returns eval metrics
            eval_interval: Steps between evaluations
            n_eval_episodes: Number of episodes for evaluation
            verbose: Verbosity level
        """
        super().__init__()
        self.eval_fn = eval_fn
        self.eval_interval = eval_interval
        # NOTE(review): stored but never forwarded to eval_fn — presumably
        # eval_fn decides the episode count itself; confirm with callers.
        self.n_eval_episodes = n_eval_episodes
        self.verbose = verbose

        self._last_eval_step = 0
        self.eval_results: List[Dict[str, Any]] = []

    def on_step_end(self, state: TrainingState) -> bool:
        due = state.step - self._last_eval_step >= self.eval_interval
        if due:
            self._last_eval_step = state.step
            self._evaluate(state)
        return True

    def _evaluate(self, state: TrainingState) -> None:
        """Run one evaluation pass, store and report the metrics."""
        if self.verbose >= 1:
            print(f"\n[Eval @ step {state.step}]", end=" ")

        eval_metrics = self.eval_fn(state.step)
        self.eval_results.append({'step': state.step, **eval_metrics})

        if self.verbose >= 1:
            print(" | ".join(f"{k}: {v:.2f}" for k, v in eval_metrics.items()))

        if self.logger:
            self.logger.log_scalars(
                {f"eval/{k}": v for k, v in eval_metrics.items()},
                step=state.step
            )
class CheckpointCallback(Callback):
    """
    Saves checkpoints during training.

    Writes a checkpoint every ``save_interval`` steps, optionally saves an
    additional "best" checkpoint whenever the tracked metric reaches a new
    maximum, and saves once more when training ends.
    """

    def __init__(
        self,
        checkpoint_manager: 'CheckpointManager',
        save_interval: int = 10000,
        save_on_best: bool = True,
        metric_name: str = "episode_reward",
        verbose: int = 1
    ):
        """
        Args:
            checkpoint_manager: CheckpointManager instance
            save_interval: Steps between checkpoint saves
            save_on_best: Whether to save when best metric is achieved
            metric_name: Metric to track for best model
            verbose: Verbosity level
        """
        super().__init__()
        self.checkpoint_manager = checkpoint_manager
        self.save_interval = save_interval
        self.save_on_best = save_on_best
        self.metric_name = metric_name
        self.verbose = verbose

        self._last_save_step = 0
        self._best_metric: Optional[float] = None

    def on_step_end(self, state: TrainingState) -> bool:
        # Periodic save.
        if state.step >= self._last_save_step + self.save_interval:
            self._last_save_step = state.step
            self._save_checkpoint(state, is_best=False)

        # Best-model save (strictly-greater comparison, so ties don't re-save).
        if self.save_on_best:
            metric = self._get_metric(state)
            if metric is not None and (self._best_metric is None or metric > self._best_metric):
                self._best_metric = metric
                self._save_checkpoint(state, is_best=True)

        return True

    def on_training_end(self, state: TrainingState) -> None:
        """Save one final checkpoint when training finishes."""
        self._save_checkpoint(state, is_best=False)

    def _get_metric(self, state: TrainingState) -> Optional[float]:
        """Metric used for best-model tracking (None if unavailable)."""
        if self.metric_name == "episode_reward" and state.episode_rewards:
            return np.mean(state.episode_rewards[-10:])
        elif self.metric_name in state.info:
            return state.info[self.metric_name]
        return None

    def _save_checkpoint(self, state: TrainingState, is_best: bool) -> None:
        """Collect model state from the trainer (if any) and save."""
        if self.trainer is None:
            return

        metric = self._get_metric(state)

        # Pull the policy weights out of the trainer when the expected
        # attribute chain exists; otherwise save only the training counters.
        save_dict = {}
        if hasattr(self.trainer, 'algorithm'):
            alg = self.trainer.algorithm
            if hasattr(alg, 'meta_policy') and hasattr(alg.meta_policy, 'policy'):
                policy = alg.meta_policy.policy
                if hasattr(policy, 'state_dict'):
                    save_dict['model'] = policy.state_dict()

        self.checkpoint_manager.save(
            step=state.step,
            episode=state.episode,
            total_timesteps=state.total_timesteps,
            reward=metric,
            is_best=is_best,
            **save_dict
        )

        if self.verbose >= 1 and is_best:
            print(f"\n[Checkpoint] Saved best model @ step {state.step} (metric: {metric:.2f})")


class EarlyStoppingCallback(Callback):
    """
    Stops training when a monitored metric stops improving.

    The metric is sampled at most once every 1000 steps; if it fails to
    improve by more than ``min_delta`` for ``patience`` consecutive steps,
    ``on_step_end`` returns False and sets ``state.should_stop``.
    """

    def __init__(
        self,
        patience: int = 50000,
        min_delta: float = 0.0,
        metric_name: str = "episode_reward",
        mode: str = "max",
        verbose: int = 1
    ):
        """
        Args:
            patience: Number of steps to wait for improvement
            min_delta: Minimum change to qualify as improvement
            metric_name: Metric to monitor
            mode: 'max' or 'min' - whether higher or lower is better
            verbose: Verbosity level

        Raises:
            ValueError: if ``mode`` is neither 'max' nor 'min'.
        """
        super().__init__()
        if mode not in ("max", "min"):
            raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")
        self.patience = patience
        self.min_delta = min_delta
        self.metric_name = metric_name
        self.mode = mode
        self.verbose = verbose

        self._best_metric = None
        self._steps_without_improvement = 0
        self._last_check_step = 0

    def on_step_end(self, state: TrainingState) -> bool:
        # Throttle: only check at most once every 1000 steps.
        if state.step < self._last_check_step + 1000:
            return True
        # BUGFIX: count the actual number of elapsed steps instead of a
        # hard-coded 1000 — with coarser step increments the old code
        # under-counted and silently inflated the patience window.
        elapsed = state.step - self._last_check_step
        self._last_check_step = state.step

        metric = self._get_metric(state)
        if metric is None:
            return True

        if self._best_metric is None:
            improved = True
        elif self.mode == "max":
            improved = metric > self._best_metric + self.min_delta
        else:  # mode == "min"
            improved = metric < self._best_metric - self.min_delta

        if improved:
            self._best_metric = metric
            self._steps_without_improvement = 0
        else:
            self._steps_without_improvement += elapsed

        # Trigger early stopping once patience is exhausted.
        if self._steps_without_improvement >= self.patience:
            if self.verbose >= 1:
                print(f"\n[Early Stopping] No improvement for {self.patience} steps. Stopping training.")
            state.should_stop = True
            return False

        return True

    def _get_metric(self, state: TrainingState) -> Optional[float]:
        """Return the monitored value, or None if it is not yet available."""
        if self.metric_name == "episode_reward" and state.episode_rewards:
            return np.mean(state.episode_rewards[-10:])
        elif self.metric_name in state.info:
            return state.info[self.metric_name]
        return None
Testing CallbackList...") + + class CountingCallback(Callback): + def __init__(self): + super().__init__() + self.step_count = 0 + self.episode_count = 0 + + def on_step_end(self, state): + self.step_count += 1 + return True + + def on_episode_end(self, state): + self.episode_count += 1 + + cb1 = CountingCallback() + cb2 = CountingCallback() + callback_list = CallbackList([cb1, cb2]) + + state = TrainingState() + for _ in range(10): + callback_list.on_step_end(state) + + callback_list.on_episode_end(state) + callback_list.on_episode_end(state) + + print(f" CB1 step count: {cb1.step_count}, episode count: {cb1.episode_count}") + print(f" CB2 step count: {cb2.step_count}, episode count: {cb2.episode_count}") + assert cb1.step_count == 10 and cb2.step_count == 10 + assert cb1.episode_count == 2 and cb2.episode_count == 2 + + # Test 4: EarlyStoppingCallback + print("\n4. Testing EarlyStoppingCallback...") + early_stop = EarlyStoppingCallback(patience=5000, verbose=0) + + state = TrainingState() + # Simulate no improvement + for step in range(10): + state.step = step * 1000 + state.episode_rewards = [50.0] * 10 # Constant reward + result = early_stop.on_step_end(state) + + print(f" Steps without improvement: {early_stop._steps_without_improvement}") + print(f" Should continue: {result}") + + # Test 5: EvalCallback + print("\n5. 
@dataclass
class CheckpointMetadata:
    """Descriptive metadata stored alongside each checkpoint."""
    version: str = "1.0"
    timestamp: str = ""  # ISO-8601; filled in automatically when left empty
    step: int = 0
    episode: int = 0
    total_timesteps: int = 0
    best_reward: Optional[float] = None
    extra_info: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        if not self.timestamp:
            self.timestamp = datetime.now().isoformat()


class CheckpointManager:
    """
    Manager for saving and loading training checkpoints.

    Features:
    - Save/load model weights, optimizer states, training state
    - Atomic saves with temporary files
    - Automatic cleanup of old checkpoints
    - Best model tracking (persisted across restarts via a JSON sidecar)
    - Version tagging of each checkpoint

    Usage:
        manager = CheckpointManager(checkpoint_dir="checkpoints/", max_to_keep=5)
        manager.save(step=1000, model=policy.state_dict(),
                     optimizer=optimizer.state_dict(), config=config_dict)
        state = manager.load_latest()
        policy.load_state_dict(state['model'])
    """

    VERSION = "1.0"

    def __init__(
        self,
        checkpoint_dir: str,
        max_to_keep: int = 5,
        keep_best: bool = True,
        save_optimizer: bool = True,
        save_replay_buffer: bool = False,
        checkpoint_prefix: str = "ckpt",
        **kwargs
    ):
        """
        Args:
            checkpoint_dir: Directory to save checkpoints
            max_to_keep: Maximum number of checkpoints to keep (0 = unlimited)
            keep_best: Whether to always keep the best checkpoint
            save_optimizer: Whether to save optimizer state by default
            save_replay_buffer: Whether to save replay buffer by default
            checkpoint_prefix: Prefix for checkpoint filenames
        """
        self.checkpoint_dir = Path(checkpoint_dir)
        self.max_to_keep = max_to_keep
        self.keep_best = keep_best
        self.save_optimizer = save_optimizer
        self.save_replay_buffer = save_replay_buffer
        self.checkpoint_prefix = checkpoint_prefix

        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        self._best_reward: Optional[float] = None
        self._best_checkpoint: Optional[str] = None
        # BUGFIX: restore best-model bookkeeping from a previous run; the
        # metadata file was previously written but never read automatically,
        # so best tracking was lost on restart. No-op when the file is absent.
        self._load_metadata()

    # ---- path helpers -------------------------------------------------

    def _get_checkpoint_path(self, step: int) -> Path:
        """Checkpoint file path for a given step (zero-padded for sorting)."""
        return self.checkpoint_dir / f"{self.checkpoint_prefix}_{step:08d}.pt"

    def _get_best_checkpoint_path(self) -> Path:
        """Path of the always-kept copy of the best checkpoint."""
        return self.checkpoint_dir / f"{self.checkpoint_prefix}_best.pt"

    def _get_metadata_path(self) -> Path:
        """Path of the JSON sidecar holding best-model bookkeeping."""
        return self.checkpoint_dir / "checkpoint_metadata.json"

    def _step_from_path(self, path: str) -> int:
        """Extract the step number encoded in a checkpoint filename (-1 if unparsable)."""
        name = os.path.basename(path)
        step_str = name.replace(f"{self.checkpoint_prefix}_", "").replace(".pt", "")
        try:
            return int(step_str)
        except ValueError:
            return -1

    def _list_step_checkpoints(self) -> List[str]:
        """All numbered checkpoints (best/buffer files excluded), sorted by step.

        Sorting numerically (not lexically) keeps ordering correct even if a
        step ever exceeds the 8-digit zero padding.
        """
        pattern = str(self.checkpoint_dir / f"{self.checkpoint_prefix}_*.pt")
        files = [c for c in glob.glob(pattern)
                 if not c.endswith('_best.pt') and '_buffer.pt' not in c]
        files.sort(key=self._step_from_path)
        return files

    # ---- save / load --------------------------------------------------

    def save(
        self,
        step: int,
        model: Optional[Dict[str, Any]] = None,
        optimizer: Optional[Dict[str, Any]] = None,
        scheduler: Optional[Dict[str, Any]] = None,
        replay_buffer: Optional[Any] = None,
        config: Optional[Dict[str, Any]] = None,
        episode: int = 0,
        total_timesteps: int = 0,
        reward: Optional[float] = None,
        extra: Optional[Dict[str, Any]] = None,
        is_best: bool = False,
        **kwargs
    ) -> str:
        """
        Save a checkpoint atomically and update best-model bookkeeping.

        Args:
            step: Current training step
            model: Model state dict (or dict of state dicts)
            optimizer: Optimizer state dict (or dict of state dicts)
            scheduler: Learning rate scheduler state dict
            replay_buffer: Replay buffer to save (if save_replay_buffer=True)
            config: Configuration dictionary
            episode: Current episode number
            total_timesteps: Total environment timesteps
            reward: Current reward (for best model tracking)
            extra: Extra data to save
            is_best: Force save as best checkpoint
            **kwargs: Additional state to save

        Returns:
            Path to the saved checkpoint file.
        """
        checkpoint = {
            'metadata': asdict(CheckpointMetadata(
                version=self.VERSION,
                step=step,
                episode=episode,
                total_timesteps=total_timesteps,
                best_reward=reward,
                extra_info=extra
            )),
            'step': step,
            'episode': episode,
            'total_timesteps': total_timesteps,
        }

        if model is not None:
            checkpoint['model'] = model
        if optimizer is not None and self.save_optimizer:
            checkpoint['optimizer'] = optimizer
        if scheduler is not None:
            checkpoint['scheduler'] = scheduler
        if config is not None:
            checkpoint['config'] = config

        if replay_buffer is not None and self.save_replay_buffer:
            # The buffer can be huge, so it goes in its own file.
            buffer_path = self.checkpoint_dir / f"{self.checkpoint_prefix}_{step:08d}_buffer.pt"
            torch.save(replay_buffer, buffer_path)
            checkpoint['replay_buffer_path'] = str(buffer_path)

        checkpoint.update(kwargs)

        # Atomic save: write to a temp file, then rename into place so a
        # crash mid-write never leaves a truncated checkpoint behind.
        checkpoint_path = self._get_checkpoint_path(step)
        temp_path = checkpoint_path.with_suffix('.tmp')
        torch.save(checkpoint, temp_path)
        temp_path.rename(checkpoint_path)

        # Best-model tracking.
        if is_best or (reward is not None and (self._best_reward is None or reward > self._best_reward)):
            # BUGFIX: only overwrite the tracked best reward when a reward
            # was actually supplied; forcing is_best with reward=None used
            # to reset _best_reward to None, invalidating later comparisons.
            if reward is not None:
                self._best_reward = reward
            self._best_checkpoint = str(checkpoint_path)
            shutil.copy(checkpoint_path, self._get_best_checkpoint_path())

        self._save_metadata()
        self._cleanup_old_checkpoints()

        return str(checkpoint_path)

    def load(
        self,
        path: Optional[str] = None,
        step: Optional[int] = None,
        load_best: bool = False,
        load_optimizer: bool = True,
        load_replay_buffer: bool = False,
        map_location: Optional[Union[str, torch.device]] = None
    ) -> Dict[str, Any]:
        """
        Load a checkpoint, resolved in priority order: explicit ``path``,
        then ``load_best``, then ``step``, else the most recent one.

        Args:
            path: Direct path to checkpoint file
            step: Load checkpoint from specific step
            load_best: Load best checkpoint
            load_optimizer: Whether to keep the optimizer state
            load_replay_buffer: Whether to also load the replay buffer
            map_location: Device to map tensors to

        Returns:
            Checkpoint dictionary.

        Raises:
            FileNotFoundError: if the resolved checkpoint does not exist.
        """
        if path is not None:
            checkpoint_path = Path(path)
        elif load_best:
            checkpoint_path = self._get_best_checkpoint_path()
        elif step is not None:
            checkpoint_path = self._get_checkpoint_path(step)
        else:
            checkpoint_path = self._get_latest_checkpoint_path()

        if not checkpoint_path or not checkpoint_path.exists():
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=map_location)

        if not load_optimizer and 'optimizer' in checkpoint:
            del checkpoint['optimizer']

        if load_replay_buffer and 'replay_buffer_path' in checkpoint:
            buffer_path = checkpoint['replay_buffer_path']
            if os.path.exists(buffer_path):
                checkpoint['replay_buffer'] = torch.load(buffer_path, map_location=map_location)

        return checkpoint

    def load_latest(self, **kwargs) -> Dict[str, Any]:
        """Load the most recent checkpoint."""
        return self.load(**kwargs)

    def load_best(self, **kwargs) -> Dict[str, Any]:
        """Load the best checkpoint."""
        return self.load(load_best=True, **kwargs)

    def _get_latest_checkpoint_path(self) -> Optional[Path]:
        """Path of the checkpoint with the highest step, or None."""
        files = self._list_step_checkpoints()
        return Path(files[-1]) if files else None

    def _cleanup_old_checkpoints(self) -> None:
        """Delete old checkpoints, keeping the ``max_to_keep`` most recent
        (and, when ``keep_best``, the current best one regardless of age)."""
        if self.max_to_keep <= 0:
            return

        files = self._list_step_checkpoints()
        to_remove = files[:-self.max_to_keep] if len(files) > self.max_to_keep else []

        for checkpoint in to_remove:
            if self.keep_best and checkpoint == self._best_checkpoint:
                continue
            os.remove(checkpoint)
            # Also remove the associated replay-buffer file, if any.
            buffer_path = checkpoint.replace('.pt', '_buffer.pt')
            if os.path.exists(buffer_path):
                os.remove(buffer_path)

    def _save_metadata(self) -> None:
        """Persist best-model bookkeeping to the JSON sidecar."""
        metadata = {
            'best_reward': self._best_reward,
            'best_checkpoint': self._best_checkpoint,
            'version': self.VERSION
        }
        with open(self._get_metadata_path(), 'w') as f:
            json.dump(metadata, f, indent=2)

    def _load_metadata(self) -> None:
        """Restore best-model bookkeeping from the JSON sidecar (no-op if absent)."""
        metadata_path = self._get_metadata_path()
        if metadata_path.exists():
            with open(metadata_path, 'r') as f:
                metadata = json.load(f)
            self._best_reward = metadata.get('best_reward')
            self._best_checkpoint = metadata.get('best_checkpoint')

    def list_checkpoints(self) -> List[Dict[str, Any]]:
        """
        List all available numbered checkpoints, ordered by step.

        Returns:
            List of dicts with 'path', 'step', 'is_best' and 'size_mb'.
        """
        result = []
        for ckpt_path in self._list_step_checkpoints():
            result.append({
                'path': ckpt_path,
                'step': self._step_from_path(ckpt_path),
                'is_best': ckpt_path == self._best_checkpoint,
                'size_mb': os.path.getsize(ckpt_path) / (1024 * 1024)
            })
        return result

    def has_checkpoint(self) -> bool:
        """True if any numbered checkpoint exists."""
        return self._get_latest_checkpoint_path() is not None

    def get_best_reward(self) -> Optional[float]:
        """Best reward seen so far (None before any rewarded save)."""
        return self._best_reward
Testing basic save/load...") + manager = CheckpointManager( + checkpoint_dir=tmpdir, + max_to_keep=3 + ) + + # Create dummy model and optimizer state + model_state = {'layer1.weight': torch.randn(10, 5), 'layer1.bias': torch.randn(10)} + optimizer_state = {'state': {}, 'param_groups': [{'lr': 0.001}]} + config = {'learning_rate': 0.001, 'batch_size': 32} + + # Save checkpoint + path = manager.save( + step=100, + model=model_state, + optimizer=optimizer_state, + config=config, + episode=10, + reward=50.0 + ) + print(f" Saved checkpoint to: {path}") + + # Load checkpoint + loaded = manager.load() + print(f" Loaded step: {loaded['step']}") + print(f" Loaded episode: {loaded['episode']}") + print(f" Model keys: {list(loaded['model'].keys())}") + + # Verify model weights match + assert torch.allclose(loaded['model']['layer1.weight'], model_state['layer1.weight']) + print(" Model weights match!") + + # Test 2: Multiple checkpoints and cleanup + print("\n2. Testing multiple checkpoints and cleanup...") + for step in [200, 300, 400, 500]: + manager.save( + step=step, + model=model_state, + reward=step * 0.1 + ) + + checkpoints = manager.list_checkpoints() + print(f" Number of checkpoints: {len(checkpoints)}") + print(f" Checkpoint steps: {[c['step'] for c in checkpoints]}") + + # Should keep at most max_to_keep (3) + possibly the first checkpoint (if it's best) + # The cleanup keeps max_to_keep most recent, plus best if keep_best=True + assert len(checkpoints) <= 4, f"Expected at most 4 checkpoints, got {len(checkpoints)}" + + # Test 3: Best checkpoint + print("\n3. Testing best checkpoint...") + print(f" Best reward: {manager.get_best_reward()}") + + best_ckpt = manager.load_best() + print(f" Best checkpoint step: {best_ckpt['step']}") + + # Test 4: Load specific step + print("\n4. Testing load specific step...") + ckpt_400 = manager.load(step=400) + print(f" Loaded step 400: {ckpt_400['step']}") + assert ckpt_400['step'] == 400 + + # Test 5: has_checkpoint + print("\n5. 
# Global handle to the most recently initialized context
# (None when no distributed context has been registered).
_DISTRIBUTED_CONTEXT: Optional['DistributedContext'] = None


class DistributedContext:
    """
    Context manager for distributed training.

    Wraps ``torch.distributed`` process-group setup/teardown and exposes
    rank/world-size info plus collective helpers. In a single-process run
    every collective degrades to a no-op, so the same training code works
    with or without distribution.

    Usage:
        # Single process (no distribution)
        ctx = DistributedContext()

        # Multi-GPU DDP
        ctx = DistributedContext(backend='nccl', init_method='env://')

        with ctx:
            # Training code
            pass
    """

    def __init__(
        self,
        backend: str = 'nccl',
        init_method: Optional[str] = None,
        world_size: Optional[int] = None,
        rank: Optional[int] = None,
        local_rank: Optional[int] = None,
        timeout_minutes: int = 30,
        auto_init: bool = True,
        **kwargs
    ):
        """
        Args:
            backend: Distributed backend ('nccl', 'gloo', 'mpi')
            init_method: URL specifying how to initialize the process group
            world_size: Total number of processes (env WORLD_SIZE if None)
            rank: Global rank of this process (env RANK if None)
            local_rank: Local rank on this node (env LOCAL_RANK if None)
            timeout_minutes: Timeout for distributed operations
            auto_init: Whether to auto-initialize when world_size > 1
        """
        self.backend = backend
        self.init_method = init_method
        self.timeout_minutes = timeout_minutes
        self._initialized = False
        self._kwargs = kwargs

        # BUGFIX: the previous `value or env_default` pattern treated an
        # explicit rank/local_rank of 0 (the common main-process case!) as
        # "not provided" and silently fell back to the environment
        # variables. Compare against None instead.
        self.world_size = world_size if world_size is not None else self._get_env_int('WORLD_SIZE', 1)
        self.rank = rank if rank is not None else self._get_env_int('RANK', 0)
        self.local_rank = local_rank if local_rank is not None else self._get_env_int('LOCAL_RANK', 0)

        # Auto-initialize if running in distributed mode.
        if auto_init and self.world_size > 1:
            self.init()

    @staticmethod
    def _get_env_int(key: str, default: int) -> int:
        """Read an int from an env var, falling back on missing/invalid values."""
        val = os.environ.get(key)
        if val is not None:
            try:
                return int(val)
            except ValueError:
                pass
        return default

    def init(self) -> None:
        """Initialize the process group (idempotent; no-op in single process)."""
        if self._initialized:
            return

        if self.world_size <= 1:
            # Single process: nothing to set up.
            self._initialized = True
            return

        if dist.is_initialized():
            # Respect a process group initialized elsewhere.
            self.world_size = dist.get_world_size()
            self.rank = dist.get_rank()
            self._initialized = True
            return

        # (The old `hasattr(torch.distributed, 'init_process_group')` guard
        # was always true and only shadowed a dead default-timeout line.)
        from datetime import timedelta
        dist.init_process_group(
            backend=self.backend,
            init_method=self.init_method or 'env://',
            world_size=self.world_size,
            rank=self.rank,
            timeout=timedelta(minutes=self.timeout_minutes)
        )

        # Pin this process to its local GPU when one is available.
        if torch.cuda.is_available() and self.local_rank < torch.cuda.device_count():
            torch.cuda.set_device(self.local_rank)

        self._initialized = True

        global _DISTRIBUTED_CONTEXT
        _DISTRIBUTED_CONTEXT = self

    def cleanup(self) -> None:
        """Tear down the process group and clear the global context."""
        if self._initialized and dist.is_initialized():
            dist.destroy_process_group()
        # BUGFIX: always clear the flag; previously a single-process context
        # stayed marked as initialized after cleanup().
        self._initialized = False

        global _DISTRIBUTED_CONTEXT
        if _DISTRIBUTED_CONTEXT is self:
            _DISTRIBUTED_CONTEXT = None

    def __enter__(self):
        self.init()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.cleanup()
        return False

    @property
    def is_initialized(self) -> bool:
        """Whether init() has run (and cleanup() has not)."""
        return self._initialized

    @property
    def is_distributed(self) -> bool:
        """Whether more than one process participates."""
        return self.world_size > 1

    @property
    def is_main_process(self) -> bool:
        """Whether this process has global rank 0."""
        return self.rank == 0

    @property
    def device(self) -> torch.device:
        """Device assigned to this process (CPU when CUDA is unavailable)."""
        if torch.cuda.is_available() and self.local_rank < torch.cuda.device_count():
            return torch.device(f'cuda:{self.local_rank}')
        return torch.device('cpu')

    def barrier(self) -> None:
        """Synchronization barrier (no-op single process)."""
        if self.is_distributed and dist.is_initialized():
            dist.barrier()

    def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
        """Broadcast tensor from src to all processes (identity single process)."""
        if self.is_distributed and dist.is_initialized():
            dist.broadcast(tensor, src=src)
        return tensor

    def all_reduce(
        self,
        tensor: torch.Tensor,
        op: dist.ReduceOp = dist.ReduceOp.SUM
    ) -> torch.Tensor:
        """All-reduce tensor across all processes (identity single process)."""
        if self.is_distributed and dist.is_initialized():
            dist.all_reduce(tensor, op=op)
        return tensor

    def all_gather(self, tensor: torch.Tensor) -> List[torch.Tensor]:
        """Gather tensors from all processes ([tensor] single process)."""
        if not self.is_distributed or not dist.is_initialized():
            return [tensor]
        tensor_list = [torch.zeros_like(tensor) for _ in range(self.world_size)]
        dist.all_gather(tensor_list, tensor)
        return tensor_list

    def reduce_mean(self, tensor: torch.Tensor) -> torch.Tensor:
        """All-reduce with mean across processes (identity single process)."""
        if self.is_distributed and dist.is_initialized():
            dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
            tensor = tensor / self.world_size
        return tensor

    def sync_params(self, model: torch.nn.Module, src: int = 0) -> None:
        """Broadcast model parameters from src to all processes."""
        if not self.is_distributed or not dist.is_initialized():
            return
        for param in model.parameters():
            dist.broadcast(param.data, src=src)

    def sync_gradients(self, model: torch.nn.Module) -> None:
        """Average gradients across all processes in place."""
        if not self.is_distributed or not dist.is_initialized():
            return
        for param in model.parameters():
            if param.grad is not None:
                dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
                param.grad = param.grad / self.world_size


# Convenience functions
def get_distributed_context() -> Optional[DistributedContext]:
    """Return the globally registered context, if any."""
    return _DISTRIBUTED_CONTEXT


def get_world_size() -> int:
    """Total number of processes (1 when not distributed)."""
    if _DISTRIBUTED_CONTEXT is not None:
        return _DISTRIBUTED_CONTEXT.world_size
    if dist.is_initialized():
        return dist.get_world_size()
    return 1
def get_rank() -> int:
    """Global rank of the current process (0 when not distributed)."""
    ctx = _DISTRIBUTED_CONTEXT
    if ctx is not None:
        return ctx.rank
    return dist.get_rank() if dist.is_initialized() else 0


def get_local_rank() -> int:
    """Local (per-node) rank of the current process."""
    ctx = _DISTRIBUTED_CONTEXT
    if ctx is not None:
        return ctx.local_rank
    return int(os.environ.get('LOCAL_RANK', 0))


def is_main_process() -> bool:
    """True iff this process has global rank 0."""
    return get_rank() == 0


def is_distributed() -> bool:
    """True iff more than one process is participating."""
    return get_world_size() > 1


def barrier() -> None:
    """Block until every process reaches this point (no-op single process)."""
    ctx = _DISTRIBUTED_CONTEXT
    if ctx is not None:
        ctx.barrier()
    elif dist.is_initialized():
        dist.barrier()


@contextmanager
def main_process_first():
    """
    Context manager ensuring the main process runs the guarded body first.

    Useful for downloading files or creating directories: all other ranks
    wait at a barrier until rank 0 is done, then rank 0 waits for them.

    Usage:
        with main_process_first():
            download_dataset()
    """
    if not is_main_process():
        barrier()

    yield

    if is_main_process():
        barrier()


def print_once(*args, **kwargs) -> None:
    """print() that is silent on every rank except 0."""
    if is_main_process():
        print(*args, **kwargs)
+ + Args: + data: Dictionary with scalar values + average: If True, compute average; if False, compute sum + + Returns: + Reduced dictionary + """ + if not is_distributed() or not dist.is_initialized(): + return data + + world_size = get_world_size() + + # Stack values into tensor + keys = sorted(data.keys()) + values = torch.tensor([data[k] for k in keys], dtype=torch.float32) + + if torch.cuda.is_available(): + values = values.cuda() + + dist.all_reduce(values, op=dist.ReduceOp.SUM) + + if average: + values = values / world_size + + return {k: v.item() for k, v in zip(keys, values)} + + +class DistributedSampler: + """ + Simple distributed sampler for replay buffers. + + Ensures each process samples different data. + """ + + def __init__( + self, + dataset_size: int, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + seed: int = 0 + ): + """ + Initialize distributed sampler. + + Args: + dataset_size: Total size of the dataset + num_replicas: Number of processes (auto-detect if None) + rank: Rank of current process (auto-detect if None) + shuffle: Whether to shuffle indices + seed: Random seed for shuffling + """ + self.dataset_size = dataset_size + self.num_replicas = num_replicas or get_world_size() + self.rank = rank or get_rank() + self.shuffle = shuffle + self.seed = seed + self.epoch = 0 + + def __iter__(self): + """Generate indices for this process.""" + if self.shuffle: + # Deterministic shuffling based on epoch and seed + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = torch.randperm(self.dataset_size, generator=g).tolist() + else: + indices = list(range(self.dataset_size)) + + # Subsample for this replica + indices = indices[self.rank::self.num_replicas] + return iter(indices) + + def __len__(self): + """Return number of samples for this process.""" + return (self.dataset_size + self.num_replicas - 1) // self.num_replicas + + def set_epoch(self, epoch: int) -> None: + """Set epoch for 
deterministic shuffling.""" + self.epoch = epoch + + +if __name__ == '__main__': + """ + Test code for Distributed module. + """ + print("=" * 60) + print("Testing Distributed Module") + print("=" * 60) + + # Test 1: Basic context (single process) + print("\n1. Testing DistributedContext (single process)...") + ctx = DistributedContext(auto_init=False) + ctx.init() + + print(f" World size: {ctx.world_size}") + print(f" Rank: {ctx.rank}") + print(f" Local rank: {ctx.local_rank}") + print(f" Is distributed: {ctx.is_distributed}") + print(f" Is main process: {ctx.is_main_process}") + print(f" Device: {ctx.device}") + + ctx.cleanup() + + # Test 2: Convenience functions + print("\n2. Testing convenience functions...") + print(f" get_world_size(): {get_world_size()}") + print(f" get_rank(): {get_rank()}") + print(f" get_local_rank(): {get_local_rank()}") + print(f" is_main_process(): {is_main_process()}") + print(f" is_distributed(): {is_distributed()}") + + # Test 3: print_once + print("\n3. Testing print_once...") + print_once(" This should print (main process)") + + # Test 4: reduce_dict (single process - no-op) + print("\n4. Testing reduce_dict (single process)...") + data = {'loss': 0.5, 'reward': 100.0} + reduced = reduce_dict(data) + print(f" Input: {data}") + print(f" Output: {reduced}") + + # Test 5: DistributedSampler + print("\n5. 
Testing DistributedSampler...") + sampler = DistributedSampler( + dataset_size=100, + num_replicas=4, + rank=0, + shuffle=True, + seed=42 + ) + + indices = list(sampler) + print(f" Dataset size: 100, Replicas: 4") + print(f" Samples for rank 0: {len(indices)}") + print(f" First 10 indices: {indices[:10]}") + + # Test different rank + sampler_rank1 = DistributedSampler( + dataset_size=100, + num_replicas=4, + rank=1, + shuffle=True, + seed=42 + ) + indices_rank1 = list(sampler_rank1) + print(f" First 10 indices (rank 1): {indices_rank1[:10]}") + + # Verify no overlap + overlap = set(indices) & set(indices_rank1) + print(f" Overlap between rank 0 and 1: {len(overlap)} (should be 0)") + assert len(overlap) == 0, "Samplers should not overlap!" + + # Test 6: Context manager + print("\n6. Testing context manager...") + with DistributedContext(auto_init=False) as ctx: + print(f" Inside context: rank={ctx.rank}") + print(" Context exited successfully") + + # Test 7: main_process_first context + print("\n7. 
Testing main_process_first...") + with main_process_first(): + print(" Main process first block executed") + + print("\n" + "=" * 60) + print("All tests passed!") + print("=" * 60) + diff --git a/rl/infra/logger.py b/rl/infra/logger.py new file mode 100644 index 00000000..f9243480 --- /dev/null +++ b/rl/infra/logger.py @@ -0,0 +1,593 @@ +""" +Logger Module for RL Training + +This module provides a flexible logging system for RL training, supporting: +- Console logging with progress bars +- TensorBoard logging +- WandB logging (optional) +- Composite logging (multiple backends) + +Design Philosophy: +- Unified interface: All loggers implement the same BaseLogger interface +- Lazy initialization: Heavy backends (TensorBoard, WandB) are initialized lazily +- Thread-safe: Support concurrent logging from multiple processes +- Metric aggregation: Support for rolling window statistics +""" + +import os +import sys +import time +import json +import numpy as np +from abc import ABC, abstractmethod +from typing import Dict, Any, Optional, Union, List +from collections import defaultdict, deque +from datetime import datetime +from pathlib import Path + + +class MetricTracker: + """ + Track metrics with rolling window statistics. + + Supports computing mean, std, min, max over a sliding window. + """ + + def __init__(self, window_size: int = 100): + """ + Initialize metric tracker. 
+ + Args: + window_size: Size of the rolling window for statistics + """ + self.window_size = window_size + self._values: Dict[str, deque] = defaultdict(lambda: deque(maxlen=window_size)) + self._total_counts: Dict[str, int] = defaultdict(int) + self._total_sums: Dict[str, float] = defaultdict(float) + + def add(self, key: str, value: float) -> None: + """Add a value to a metric.""" + self._values[key].append(value) + self._total_counts[key] += 1 + self._total_sums[key] += value + + def get_mean(self, key: str) -> Optional[float]: + """Get rolling mean of a metric.""" + if key not in self._values or len(self._values[key]) == 0: + return None + return np.mean(self._values[key]) + + def get_std(self, key: str) -> Optional[float]: + """Get rolling std of a metric.""" + if key not in self._values or len(self._values[key]) < 2: + return None + return np.std(self._values[key]) + + def get_min(self, key: str) -> Optional[float]: + """Get rolling min of a metric.""" + if key not in self._values or len(self._values[key]) == 0: + return None + return np.min(self._values[key]) + + def get_max(self, key: str) -> Optional[float]: + """Get rolling max of a metric.""" + if key not in self._values or len(self._values[key]) == 0: + return None + return np.max(self._values[key]) + + def get_total_mean(self, key: str) -> Optional[float]: + """Get total mean (not windowed) of a metric.""" + if self._total_counts[key] == 0: + return None + return self._total_sums[key] / self._total_counts[key] + + def get_latest(self, key: str) -> Optional[float]: + """Get latest value of a metric.""" + if key not in self._values or len(self._values[key]) == 0: + return None + return self._values[key][-1] + + def get_count(self, key: str) -> int: + """Get total count of a metric.""" + return self._total_counts[key] + + def get_stats(self, key: str) -> Dict[str, Optional[float]]: + """Get all statistics for a metric.""" + return { + 'mean': self.get_mean(key), + 'std': self.get_std(key), + 'min': 
self.get_min(key), + 'max': self.get_max(key), + 'latest': self.get_latest(key), + 'total_mean': self.get_total_mean(key), + 'count': self.get_count(key) + } + + def keys(self) -> List[str]: + """Get all tracked metric keys.""" + return list(self._values.keys()) + + def clear(self) -> None: + """Clear all metrics.""" + self._values.clear() + self._total_counts.clear() + self._total_sums.clear() + + +class BaseLogger(ABC): + """ + Base class for all loggers. + + Provides a unified interface for logging training metrics. + """ + + def __init__( + self, + log_dir: Optional[str] = None, + name: str = "rl_training", + **kwargs + ): + """ + Initialize logger. + + Args: + log_dir: Directory to save logs + name: Name of the experiment/run + **kwargs: Additional logger-specific arguments + """ + self.log_dir = log_dir + self.name = name + self._step = 0 + self._start_time = time.time() + self.metrics = MetricTracker() + + if log_dir: + os.makedirs(log_dir, exist_ok=True) + + @abstractmethod + def log_scalar(self, key: str, value: float, step: Optional[int] = None) -> None: + """ + Log a scalar value. + + Args: + key: Metric name + value: Metric value + step: Step number (uses internal counter if None) + """ + raise NotImplementedError + + @abstractmethod + def log_scalars(self, data: Dict[str, float], step: Optional[int] = None) -> None: + """ + Log multiple scalar values. 
+ + Args: + data: Dictionary of metric name -> value + step: Step number (uses internal counter if None) + """ + raise NotImplementedError + + def log_histogram(self, key: str, values: np.ndarray, step: Optional[int] = None) -> None: + """Log a histogram of values (optional, not all loggers support this).""" + pass + + def log_image(self, key: str, image: np.ndarray, step: Optional[int] = None) -> None: + """Log an image (optional, not all loggers support this).""" + pass + + def log_video(self, key: str, video: np.ndarray, step: Optional[int] = None, fps: int = 30) -> None: + """Log a video (optional, not all loggers support this).""" + pass + + def log_text(self, key: str, text: str, step: Optional[int] = None) -> None: + """Log text (optional, not all loggers support this).""" + pass + + def log_hyperparams(self, params: Dict[str, Any]) -> None: + """Log hyperparameters.""" + pass + + def set_step(self, step: int) -> None: + """Set the current step.""" + self._step = step + + def get_step(self) -> int: + """Get the current step.""" + return self._step + + def increment_step(self, n: int = 1) -> int: + """Increment step and return new value.""" + self._step += n + return self._step + + def get_elapsed_time(self) -> float: + """Get elapsed time since logger creation.""" + return time.time() - self._start_time + + def close(self) -> None: + """Close the logger and release resources.""" + pass + + def flush(self) -> None: + """Flush any buffered data.""" + pass + + +class ConsoleLogger(BaseLogger): + """ + Simple console logger with optional progress bar. + + Outputs training metrics to console/terminal. + """ + + def __init__( + self, + log_dir: Optional[str] = None, + name: str = "rl_training", + log_interval: int = 100, + verbose: int = 1, + **kwargs + ): + """ + Initialize console logger. 
+ + Args: + log_dir: Directory to save logs (also saves to file if provided) + name: Name of the experiment + log_interval: Steps between log outputs + verbose: Verbosity level (0=silent, 1=normal, 2=detailed) + """ + super().__init__(log_dir=log_dir, name=name, **kwargs) + self.log_interval = log_interval + self.verbose = verbose + self._file = None + + if log_dir: + log_file = os.path.join(log_dir, f"{name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log") + self._file = open(log_file, 'w') + + def _format_value(self, value: float) -> str: + """Format a value for display.""" + if abs(value) < 0.001 or abs(value) > 10000: + return f"{value:.3e}" + return f"{value:.4f}" + + def _write(self, message: str) -> None: + """Write message to console and optionally to file.""" + if self.verbose > 0: + print(message) + if self._file: + self._file.write(message + "\n") + self._file.flush() + + def log_scalar(self, key: str, value: float, step: Optional[int] = None) -> None: + """Log a scalar value.""" + step = step if step is not None else self._step + self.metrics.add(key, value) + + if self.verbose >= 2 or (step % self.log_interval == 0): + elapsed = self.get_elapsed_time() + self._write(f"[{step:>8}] {key}: {self._format_value(value)} (elapsed: {elapsed:.1f}s)") + + def log_scalars(self, data: Dict[str, float], step: Optional[int] = None) -> None: + """Log multiple scalar values.""" + step = step if step is not None else self._step + + for key, value in data.items(): + self.metrics.add(key, value) + + if step % self.log_interval == 0: + elapsed = self.get_elapsed_time() + metrics_str = " | ".join([f"{k}: {self._format_value(v)}" for k, v in data.items()]) + self._write(f"[{step:>8}] {metrics_str} (elapsed: {elapsed:.1f}s)") + + def log_hyperparams(self, params: Dict[str, Any]) -> None: + """Log hyperparameters.""" + self._write("\n" + "=" * 60) + self._write("Hyperparameters:") + self._write("=" * 60) + for key, value in params.items(): + self._write(f" {key}: {value}") + 
self._write("=" * 60 + "\n") + + # Also save to JSON file + if self.log_dir: + params_file = os.path.join(self.log_dir, "hyperparams.json") + with open(params_file, 'w') as f: + json.dump(params, f, indent=2, default=str) + + def close(self) -> None: + """Close the logger.""" + if self._file: + self._file.close() + self._file = None + + +class TensorBoardLogger(BaseLogger): + """ + TensorBoard logger for rich visualization. + + Supports scalars, histograms, images, and more. + """ + + def __init__( + self, + log_dir: str, + name: str = "rl_training", + flush_secs: int = 30, + **kwargs + ): + """ + Initialize TensorBoard logger. + + Args: + log_dir: Directory to save TensorBoard logs + name: Name of the experiment + flush_secs: Flush interval in seconds + """ + super().__init__(log_dir=log_dir, name=name, **kwargs) + self.flush_secs = flush_secs + self._writer = None + + def _get_writer(self): + """Lazy initialization of SummaryWriter.""" + if self._writer is None: + try: + from torch.utils.tensorboard import SummaryWriter + except ImportError: + raise ImportError( + "TensorBoard not installed. 
Install with: pip install tensorboard" + ) + + log_path = os.path.join(self.log_dir, self.name) + self._writer = SummaryWriter(log_dir=log_path, flush_secs=self.flush_secs) + return self._writer + + def log_scalar(self, key: str, value: float, step: Optional[int] = None) -> None: + """Log a scalar value.""" + step = step if step is not None else self._step + self.metrics.add(key, value) + self._get_writer().add_scalar(key, value, step) + + def log_scalars(self, data: Dict[str, float], step: Optional[int] = None) -> None: + """Log multiple scalar values.""" + step = step if step is not None else self._step + for key, value in data.items(): + self.metrics.add(key, value) + self._get_writer().add_scalar(key, value, step) + + def log_histogram(self, key: str, values: np.ndarray, step: Optional[int] = None) -> None: + """Log a histogram.""" + step = step if step is not None else self._step + self._get_writer().add_histogram(key, values, step) + + def log_image(self, key: str, image: np.ndarray, step: Optional[int] = None) -> None: + """Log an image (expects HWC or CHW format).""" + step = step if step is not None else self._step + # Convert HWC to CHW if needed + if image.ndim == 3 and image.shape[-1] in [1, 3, 4]: + image = np.transpose(image, (2, 0, 1)) + self._get_writer().add_image(key, image, step) + + def log_video(self, key: str, video: np.ndarray, step: Optional[int] = None, fps: int = 30) -> None: + """Log a video (expects THWC or NTCHW format).""" + step = step if step is not None else self._step + # Add batch dimension if needed (THWC -> NTCHW) + if video.ndim == 4: + video = video[np.newaxis, ...] 
# Add N dimension + video = np.transpose(video, (0, 1, 4, 2, 3)) # NTHWC -> NTCHW + self._get_writer().add_video(key, video, step, fps=fps) + + def log_text(self, key: str, text: str, step: Optional[int] = None) -> None: + """Log text.""" + step = step if step is not None else self._step + self._get_writer().add_text(key, text, step) + + def log_hyperparams(self, params: Dict[str, Any]) -> None: + """Log hyperparameters.""" + # Filter out non-serializable values + filtered_params = {} + for k, v in params.items(): + if isinstance(v, (int, float, str, bool, type(None))): + filtered_params[k] = v + else: + filtered_params[k] = str(v) + + self._get_writer().add_hparams(filtered_params, {}) + + def flush(self) -> None: + """Flush buffered data.""" + if self._writer: + self._writer.flush() + + def close(self) -> None: + """Close the logger.""" + if self._writer: + self._writer.close() + self._writer = None + + +class CompositeLogger(BaseLogger): + """ + Composite logger that forwards logs to multiple backends. + + Usage: + logger = CompositeLogger([ + ConsoleLogger(verbose=1), + TensorBoardLogger(log_dir="logs/tb") + ]) + """ + + def __init__( + self, + loggers: List[BaseLogger], + log_dir: Optional[str] = None, + name: str = "rl_training", + **kwargs + ): + """ + Initialize composite logger. 
+ + Args: + loggers: List of logger instances to forward to + log_dir: Optional directory (passed to base class) + name: Name of the experiment + """ + super().__init__(log_dir=log_dir, name=name, **kwargs) + self.loggers = loggers + + def log_scalar(self, key: str, value: float, step: Optional[int] = None) -> None: + """Log a scalar to all backends.""" + step = step if step is not None else self._step + self.metrics.add(key, value) + for logger in self.loggers: + logger.log_scalar(key, value, step) + + def log_scalars(self, data: Dict[str, float], step: Optional[int] = None) -> None: + """Log scalars to all backends.""" + step = step if step is not None else self._step + for key, value in data.items(): + self.metrics.add(key, value) + for logger in self.loggers: + logger.log_scalars(data, step) + + def log_histogram(self, key: str, values: np.ndarray, step: Optional[int] = None) -> None: + """Log histogram to all backends that support it.""" + step = step if step is not None else self._step + for logger in self.loggers: + logger.log_histogram(key, values, step) + + def log_image(self, key: str, image: np.ndarray, step: Optional[int] = None) -> None: + """Log image to all backends that support it.""" + step = step if step is not None else self._step + for logger in self.loggers: + logger.log_image(key, image, step) + + def log_video(self, key: str, video: np.ndarray, step: Optional[int] = None, fps: int = 30) -> None: + """Log video to all backends that support it.""" + step = step if step is not None else self._step + for logger in self.loggers: + logger.log_video(key, video, step, fps) + + def log_text(self, key: str, text: str, step: Optional[int] = None) -> None: + """Log text to all backends that support it.""" + step = step if step is not None else self._step + for logger in self.loggers: + logger.log_text(key, text, step) + + def log_hyperparams(self, params: Dict[str, Any]) -> None: + """Log hyperparameters to all backends.""" + for logger in self.loggers: 
+ logger.log_hyperparams(params) + + def set_step(self, step: int) -> None: + """Set step for all loggers.""" + self._step = step + for logger in self.loggers: + logger.set_step(step) + + def flush(self) -> None: + """Flush all loggers.""" + for logger in self.loggers: + logger.flush() + + def close(self) -> None: + """Close all loggers.""" + for logger in self.loggers: + logger.close() + + +if __name__ == '__main__': + """ + Test code for Logger module. + """ + import tempfile + import shutil + + print("=" * 60) + print("Testing Logger Module") + print("=" * 60) + + # Test 1: MetricTracker + print("\n1. Testing MetricTracker...") + tracker = MetricTracker(window_size=5) + + for i in range(10): + tracker.add("loss", 1.0 - i * 0.1) + tracker.add("reward", i * 10) + + print(f" Loss stats: {tracker.get_stats('loss')}") + print(f" Reward stats: {tracker.get_stats('reward')}") + print(f" Tracked keys: {tracker.keys()}") + + # Test 2: ConsoleLogger + print("\n2. Testing ConsoleLogger...") + with tempfile.TemporaryDirectory() as tmpdir: + console_logger = ConsoleLogger( + log_dir=tmpdir, + name="test_console", + log_interval=2, + verbose=1 + ) + + console_logger.log_hyperparams({ + 'learning_rate': 0.001, + 'batch_size': 32, + 'algorithm': 'PPO' + }) + + for step in range(5): + console_logger.log_scalars({ + 'loss': 1.0 - step * 0.1, + 'reward': step * 10 + }, step=step) + console_logger.increment_step() + + print(f" Elapsed time: {console_logger.get_elapsed_time():.2f}s") + console_logger.close() + + # Test 3: TensorBoardLogger (if tensorboard is installed) + print("\n3. 
Testing TensorBoardLogger...") + try: + with tempfile.TemporaryDirectory() as tmpdir: + tb_logger = TensorBoardLogger( + log_dir=tmpdir, + name="test_tb" + ) + + for step in range(10): + tb_logger.log_scalars({ + 'loss': 1.0 - step * 0.05, + 'reward': step * 5 + }, step=step) + + # Test histogram + tb_logger.log_histogram("weights", np.random.randn(1000), step=0) + + tb_logger.flush() + tb_logger.close() + print(" TensorBoard logger test passed!") + except ImportError: + print(" TensorBoard not installed, skipping...") + + # Test 4: CompositeLogger + print("\n4. Testing CompositeLogger...") + with tempfile.TemporaryDirectory() as tmpdir: + composite_logger = CompositeLogger( + loggers=[ + ConsoleLogger(log_interval=5, verbose=1), + ], + log_dir=tmpdir, + name="test_composite" + ) + + for step in range(10): + composite_logger.log_scalar("test_metric", step * 0.5, step=step) + + composite_logger.close() + print(" Composite logger test passed!") + + print("\n" + "=" * 60) + print("All tests passed!") + print("=" * 60) + diff --git a/rl/infra/seed_manager.py b/rl/infra/seed_manager.py new file mode 100644 index 00000000..f3654b61 --- /dev/null +++ b/rl/infra/seed_manager.py @@ -0,0 +1,357 @@ +""" +Seed Manager for Reproducibility + +This module provides unified random seed management to ensure reproducibility +across different libraries (NumPy, PyTorch, Python random, etc.). 
+ +Features: +- Unified seed setting for all random number generators +- Support for deterministic operations in PyTorch +- Worker seed management for parallel data loading +- Global seed state tracking +""" + +import os +import random +import numpy as np +import torch +from typing import Optional, Callable +from dataclasses import dataclass + + +# Global seed state +_GLOBAL_SEED: Optional[int] = None + + +@dataclass +class SeedState: + """Container for random state from different libraries.""" + python_state: tuple + numpy_state: dict + torch_state: torch.Tensor + torch_cuda_state: Optional[list] = None + + +class SeedManager: + """ + Unified random seed manager for reproducibility. + + This class manages random seeds for: + - Python's random module + - NumPy's random number generator + - PyTorch's CPU and GPU random number generators + - CUDA deterministic operations + + Usage: + # Simple usage + SeedManager.set_seed(42) + + # With deterministic mode (slower but fully reproducible) + SeedManager.set_seed(42, deterministic=True) + + # Save and restore state + state = SeedManager.get_state() + # ... do something ... + SeedManager.set_state(state) + """ + + @staticmethod + def set_seed( + seed: int, + deterministic: bool = False, + benchmark: bool = True, + warn_only: bool = False + ) -> None: + """ + Set random seed for all libraries. 
+ + Args: + seed: Random seed value + deterministic: If True, enable deterministic algorithms in PyTorch + (may reduce performance but ensures reproducibility) + benchmark: If True, enable cuDNN benchmark mode for faster training + (disable when deterministic=True for full reproducibility) + warn_only: If True, only warn instead of error when deterministic + operations are not available + """ + global _GLOBAL_SEED + _GLOBAL_SEED = seed + + # Python random + random.seed(seed) + + # NumPy + np.random.seed(seed) + + # PyTorch + torch.manual_seed(seed) + + # PyTorch CUDA + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) # For multi-GPU + + # Deterministic mode + if deterministic: + # Disable benchmark for determinism + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + + # Use deterministic algorithms + if hasattr(torch, 'use_deterministic_algorithms'): + torch.use_deterministic_algorithms(True, warn_only=warn_only) + elif hasattr(torch, 'set_deterministic'): + torch.set_deterministic(True) + else: + # Enable benchmark for performance + torch.backends.cudnn.benchmark = benchmark + torch.backends.cudnn.deterministic = False + + # Environment variable for hash seed + os.environ['PYTHONHASHSEED'] = str(seed) + + @staticmethod + def get_seed() -> Optional[int]: + """Get the current global seed.""" + return _GLOBAL_SEED + + @staticmethod + def get_state() -> SeedState: + """ + Get current random state from all libraries. + + Returns: + SeedState containing states from all RNGs + """ + cuda_state = None + if torch.cuda.is_available(): + cuda_state = [torch.cuda.get_rng_state(i) for i in range(torch.cuda.device_count())] + + return SeedState( + python_state=random.getstate(), + numpy_state=np.random.get_state(), + torch_state=torch.get_rng_state(), + torch_cuda_state=cuda_state + ) + + @staticmethod + def set_state(state: SeedState) -> None: + """ + Restore random state for all libraries. 
+ + Args: + state: SeedState to restore + """ + random.setstate(state.python_state) + np.random.set_state(state.numpy_state) + torch.set_rng_state(state.torch_state) + + if state.torch_cuda_state is not None and torch.cuda.is_available(): + for i, cuda_state in enumerate(state.torch_cuda_state): + if i < torch.cuda.device_count(): + torch.cuda.set_rng_state(cuda_state, i) + + @staticmethod + def worker_init_fn(worker_id: int) -> None: + """ + Worker initialization function for PyTorch DataLoader. + + Use this as the `worker_init_fn` argument in DataLoader to ensure + each worker has a different but reproducible random seed. + + Args: + worker_id: Worker ID (0, 1, 2, ...) + + Usage: + DataLoader(..., worker_init_fn=SeedManager.worker_init_fn) + """ + global _GLOBAL_SEED + if _GLOBAL_SEED is not None: + worker_seed = _GLOBAL_SEED + worker_id + else: + worker_seed = torch.initial_seed() % 2**32 + + random.seed(worker_seed) + np.random.seed(worker_seed) + + @staticmethod + def fork_rng(devices: Optional[list] = None, enabled: bool = True) -> 'RNGForkContext': + """ + Fork the RNG state for a context (useful for dropout in eval mode). 
+ + Args: + devices: List of CUDA devices to fork RNG state for + enabled: Whether to actually fork (useful for conditional forking) + + Returns: + Context manager that restores RNG state on exit + + Usage: + with SeedManager.fork_rng(): + # Operations here won't affect global RNG state + pass + """ + return RNGForkContext(devices=devices, enabled=enabled) + + +class RNGForkContext: + """Context manager for forking RNG state.""" + + def __init__(self, devices: Optional[list] = None, enabled: bool = True): + self.devices = devices + self.enabled = enabled + self._state: Optional[SeedState] = None + + def __enter__(self): + if self.enabled: + self._state = SeedManager.get_state() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.enabled and self._state is not None: + SeedManager.set_state(self._state) + return False + + +# Convenience functions +def set_global_seed( + seed: int, + deterministic: bool = False, + benchmark: bool = True +) -> None: + """ + Set global random seed for reproducibility. + + Convenience function that calls SeedManager.set_seed(). + + Args: + seed: Random seed value + deterministic: Enable deterministic mode (slower but fully reproducible) + benchmark: Enable cuDNN benchmark (disable when deterministic=True) + """ + SeedManager.set_seed(seed, deterministic=deterministic, benchmark=benchmark) + + +def get_global_seed() -> Optional[int]: + """Get the current global seed.""" + return SeedManager.get_seed() + + +if __name__ == '__main__': + """ + Test code for SeedManager. + """ + print("=" * 60) + print("Testing SeedManager") + print("=" * 60) + + # Test 1: Basic seed setting + print("\n1. 
Testing basic seed setting...") + SeedManager.set_seed(42) + + # Generate some random numbers + py_random1 = [random.random() for _ in range(5)] + np_random1 = np.random.rand(5).tolist() + torch_random1 = torch.rand(5).tolist() + + # Reset and generate again + SeedManager.set_seed(42) + py_random2 = [random.random() for _ in range(5)] + np_random2 = np.random.rand(5).tolist() + torch_random2 = torch.rand(5).tolist() + + print(f" Python random match: {py_random1 == py_random2}") + print(f" NumPy random match: {np_random1 == np_random2}") + print(f" PyTorch random match: {torch_random1 == torch_random2}") + + assert py_random1 == py_random2, "Python random not reproducible" + assert np_random1 == np_random2, "NumPy random not reproducible" + assert torch_random1 == torch_random2, "PyTorch random not reproducible" + + # Test 2: Get/Set state + print("\n2. Testing state save/restore...") + SeedManager.set_seed(123) + + # Generate some numbers + _ = random.random() + _ = np.random.rand() + _ = torch.rand(1) + + # Save state + state = SeedManager.get_state() + + # Generate more numbers + val1 = random.random() + val2 = np.random.rand() + val3 = torch.rand(1).item() + + # Restore state + SeedManager.set_state(state) + + # Generate again - should match + val1_restored = random.random() + val2_restored = np.random.rand() + val3_restored = torch.rand(1).item() + + print(f" Python state restored: {val1 == val1_restored}") + print(f" NumPy state restored: {val2 == val2_restored}") + print(f" PyTorch state restored: {val3 == val3_restored}") + + assert val1 == val1_restored + assert val2 == val2_restored + assert val3 == val3_restored + + # Test 3: Fork RNG context + print("\n3. 
Testing RNG fork context...") + SeedManager.set_seed(456) + + before = random.random() + + # Save current position + SeedManager.set_seed(456) + _ = random.random() # Advance to same position + + with SeedManager.fork_rng(): + # This should not affect the outer state + for _ in range(10): + random.random() + + after = random.random() + + # Reset and check + SeedManager.set_seed(456) + _ = random.random() + expected_after = random.random() + + print(f" Fork context preserved state: {after == expected_after}") + + # Test 4: Global seed tracking + print("\n4. Testing global seed tracking...") + SeedManager.set_seed(789) + print(f" Global seed: {SeedManager.get_seed()}") + assert SeedManager.get_seed() == 789 + + # Test convenience functions + set_global_seed(999) + print(f" After set_global_seed: {get_global_seed()}") + assert get_global_seed() == 999 + + # Test 5: Deterministic mode + print("\n5. Testing deterministic mode...") + SeedManager.set_seed(42, deterministic=True) + print(" Deterministic mode enabled (slower but fully reproducible)") + + # Test 6: Worker init function + print("\n6. Testing worker init function...") + SeedManager.set_seed(42) + + # Simulate worker initialization + for worker_id in range(3): + SeedManager.worker_init_fn(worker_id) + worker_val = random.random() + print(f" Worker {worker_id} random value: {worker_val:.6f}") + + print("\n" + "=" * 60) + print("All tests passed!") + print("=" * 60) + diff --git a/rl/rewards/__init__.py b/rl/rewards/__init__.py new file mode 100644 index 00000000..7aca62ae --- /dev/null +++ b/rl/rewards/__init__.py @@ -0,0 +1,86 @@ +""" +Reward Functions Module + +This module provides modular reward functions for RL training. 
+ +Design Philosophy: +- Modular: Reward functions are independent modules, easy to replace and extend +- Composable: Support combining multiple reward functions +- Language-conditioned: Support VLA's language-conditioned rewards + +Available reward functions: +- SparseReward: Only give reward when task is completed +- DenseReward: Distance-based reward +- LearnedReward: Learned reward model +- LanguageReward: Language-conditioned reward (for VLA) +- CompositeReward: Combine multiple reward functions + +Note: Implementations are provided in separate files. +This __init__.py provides factory functions for creating reward functions. +""" + +from typing import Type, Dict, Any + +from .base_reward import BaseReward + +# Registry for reward classes +_REWARD_REGISTRY: Dict[str, Type] = {} + + +def register_reward(name: str, reward_class: Type) -> None: + """ + Register a reward function class. + + Args: + name: Reward function name (e.g., 'sparse', 'dense') + reward_class: Reward function class to register + """ + _REWARD_REGISTRY[name.lower()] = reward_class + + +def get_reward_class(name_or_type: str) -> Type: + """ + Get reward function class by name or type string. + + Args: + name_or_type: Reward name (e.g., 'sparse') or full type path + (e.g., 'rl.rewards.sparse_reward.SparseReward') + + Returns: + Reward function class + + Raises: + ValueError: If reward function not found + """ + # First check registry + if name_or_type.lower() in _REWARD_REGISTRY: + return _REWARD_REGISTRY[name_or_type.lower()] + + # Try to import from type path + if '.' in name_or_type: + try: + parts = name_or_type.rsplit('.', 1) + module_path = parts[0] + class_name = parts[1] + + import importlib + module = importlib.import_module(module_path) + return getattr(module, class_name) + except (ImportError, AttributeError) as e: + raise ValueError(f"Cannot import reward function from '{name_or_type}': {e}") + + raise ValueError(f"Unknown reward function: '{name_or_type}'. 
Available: {list(_REWARD_REGISTRY.keys())}") + + +def list_rewards() -> list: + """List all registered reward functions.""" + return list(_REWARD_REGISTRY.keys()) + + +__all__ = [ + 'BaseReward', + 'register_reward', + 'get_reward_class', + 'list_rewards', +] + diff --git a/rl/rewards/base_reward.py b/rl/rewards/base_reward.py new file mode 100644 index 00000000..d204dd48 --- /dev/null +++ b/rl/rewards/base_reward.py @@ -0,0 +1,346 @@ +""" +Base Reward Function Class + +This module defines the base class for all reward functions in the RL framework. + +Design Philosophy: +- Modular: Reward functions are independent modules, easy to replace and extend +- Composable: Support combining multiple reward functions +- Language-conditioned: Support VLA's language-conditioned rewards +""" + +import numpy as np +from typing import Dict, Any, Optional +from abc import ABC, abstractmethod + +# Type hints for Meta classes +from benchmark.base import MetaObs, MetaAction + + +class BaseReward(ABC): + """ + Base class for reward functions. + + This class defines the interface for all reward functions in the RL framework. + Reward functions compute custom rewards based on state, action, and other information. + + Note: The reward function is used in the Trainer during training time. + The replay buffer stores raw environment rewards, and the reward function + is applied during the update step for computing the training reward. + """ + + def __init__(self, **kwargs): + """ + Initialize the reward function. + + Args: + **kwargs: Reward function specific parameters + """ + self._kwargs = kwargs + + @abstractmethod + def compute( + self, + state: MetaObs, + action: MetaAction, + next_state: MetaObs, + env_reward: float, + info: Optional[Dict[str, Any]] = None + ) -> float: + """ + Compute the reward. 
+ + Args: + state: Current state (MetaObs) + action: Action (MetaAction) + next_state: Next state (MetaObs) + env_reward: Environment's raw reward + info: Additional information dictionary + + Returns: + Computed reward value + """ + raise NotImplementedError + + def reset(self, **kwargs) -> None: + """ + Reset reward function state (if needed). + + Some reward functions may have internal state (e.g., running statistics, + episode counters) that need to be reset at the beginning of an episode. + + Args: + **kwargs: Reset parameters + """ + pass + + def __call__( + self, + state: MetaObs, + action: MetaAction, + next_state: MetaObs, + env_reward: float, + info: Optional[Dict[str, Any]] = None + ) -> float: + """ + Callable interface for computing reward. + + This allows the reward function to be used as a callable: + reward = reward_fn(state, action, next_state, env_reward, info) + + Args: + state: Current state (MetaObs) + action: Action (MetaAction) + next_state: Next state (MetaObs) + env_reward: Environment's raw reward + info: Additional information dictionary + + Returns: + Computed reward value + """ + return self.compute(state, action, next_state, env_reward, info) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self._kwargs})" + + +class IdentityReward(BaseReward): + """ + Identity reward function - returns the environment reward unchanged. + + This is the default reward function when no custom reward is specified. + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def compute( + self, + state: MetaObs, + action: MetaAction, + next_state: MetaObs, + env_reward: float, + info: Optional[Dict[str, Any]] = None + ) -> float: + """Return the environment reward unchanged.""" + return env_reward + + +class ScaledReward(BaseReward): + """ + Scaled reward function - scales the environment reward by a factor. + """ + + def __init__(self, scale: float = 1.0, offset: float = 0.0, **kwargs): + """ + Initialize scaled reward. 
+ + Args: + scale: Scaling factor for the reward + offset: Offset to add to the scaled reward + **kwargs: Additional parameters + """ + super().__init__(**kwargs) + self.scale = scale + self.offset = offset + + def compute( + self, + state: MetaObs, + action: MetaAction, + next_state: MetaObs, + env_reward: float, + info: Optional[Dict[str, Any]] = None + ) -> float: + """Return scaled and offset reward.""" + return env_reward * self.scale + self.offset + + +class CompositeReward(BaseReward): + """ + Composite reward function - combines multiple reward functions. + + This allows combining different reward signals with weights. + """ + + def __init__( + self, + reward_fns: list, + weights: Optional[list] = None, + **kwargs + ): + """ + Initialize composite reward. + + Args: + reward_fns: List of reward function instances + weights: Optional list of weights for each reward function. + If None, uses uniform weights. + **kwargs: Additional parameters + """ + super().__init__(**kwargs) + self.reward_fns = reward_fns + if weights is None: + weights = [1.0] * len(reward_fns) + assert len(weights) == len(reward_fns), "Number of weights must match number of reward functions" + self.weights = weights + + def compute( + self, + state: MetaObs, + action: MetaAction, + next_state: MetaObs, + env_reward: float, + info: Optional[Dict[str, Any]] = None + ) -> float: + """Compute weighted sum of all reward functions.""" + total_reward = 0.0 + for reward_fn, weight in zip(self.reward_fns, self.weights): + total_reward += weight * reward_fn.compute( + state, action, next_state, env_reward, info + ) + return total_reward + + def reset(self, **kwargs) -> None: + """Reset all component reward functions.""" + for reward_fn in self.reward_fns: + reward_fn.reset(**kwargs) + + +if __name__ == '__main__': + """ + Test code for BaseReward class and its implementations. 
+ """ + import sys + sys.path.insert(0, '/home/zhang/robot/126/ILStudio') + + from benchmark.base import MetaObs, MetaAction + + # Test IdentityReward + print("=" * 60) + print("Testing BaseReward and implementations") + print("=" * 60) + + # Create sample states and actions + state = MetaObs( + state=np.random.randn(10).astype(np.float32), + state_ee=np.random.randn(7).astype(np.float32), + raw_lang="pick up the red block" + ) + action = MetaAction( + action=np.random.randn(7).astype(np.float32), + ctrl_space='ee', + ctrl_type='delta' + ) + next_state = MetaObs( + state=np.random.randn(10).astype(np.float32), + state_ee=np.random.randn(7).astype(np.float32), + raw_lang="pick up the red block" + ) + + # Test 1: IdentityReward + print("\n1. Testing IdentityReward...") + identity_reward = IdentityReward() + env_reward = 1.5 + reward = identity_reward.compute(state, action, next_state, env_reward, {}) + print(f" IdentityReward: {identity_reward}") + print(f" Env reward: {env_reward}, Computed reward: {reward}") + assert reward == env_reward, "IdentityReward should return env_reward unchanged" + + # Test callable interface + reward_callable = identity_reward(state, action, next_state, env_reward, {}) + print(f" Callable interface result: {reward_callable}") + assert reward_callable == env_reward, "Callable interface should work the same" + + # Test 2: ScaledReward + print("\n2. Testing ScaledReward...") + scaled_reward = ScaledReward(scale=2.0, offset=0.5) + env_reward = 1.0 + reward = scaled_reward.compute(state, action, next_state, env_reward, {}) + expected = 1.0 * 2.0 + 0.5 # 2.5 + print(f" ScaledReward: {scaled_reward}") + print(f" Env reward: {env_reward}, Scale: 2.0, Offset: 0.5") + print(f" Computed reward: {reward}, Expected: {expected}") + assert abs(reward - expected) < 1e-6, "ScaledReward should scale and offset correctly" + + # Test 3: CompositeReward + print("\n3. 
Testing CompositeReward...") + reward_fn1 = IdentityReward() + reward_fn2 = ScaledReward(scale=0.5, offset=0.0) + composite_reward = CompositeReward( + reward_fns=[reward_fn1, reward_fn2], + weights=[0.6, 0.4] + ) + env_reward = 2.0 + reward = composite_reward.compute(state, action, next_state, env_reward, {}) + # Expected: 0.6 * 2.0 + 0.4 * (2.0 * 0.5) = 1.2 + 0.4 = 1.6 + expected = 0.6 * 2.0 + 0.4 * (2.0 * 0.5) + print(f" CompositeReward: {composite_reward}") + print(f" Env reward: {env_reward}") + print(f" Component 1 (IdentityReward, weight=0.6): {reward_fn1.compute(state, action, next_state, env_reward, {})}") + print(f" Component 2 (ScaledReward*0.5, weight=0.4): {reward_fn2.compute(state, action, next_state, env_reward, {})}") + print(f" Computed reward: {reward}, Expected: {expected}") + assert abs(reward - expected) < 1e-6, "CompositeReward should compute weighted sum correctly" + + # Test 4: Custom reward function (abstract class implementation) + print("\n4. Testing custom reward function...") + + class SuccessBonus(BaseReward): + """Give bonus reward when task is successful.""" + + def __init__(self, bonus: float = 10.0, **kwargs): + super().__init__(**kwargs) + self.bonus = bonus + + def compute(self, state, action, next_state, env_reward, info): + if info and info.get('success', False): + return env_reward + self.bonus + return env_reward + + success_bonus = SuccessBonus(bonus=5.0) + + # Without success + info_no_success = {'success': False} + reward_no_success = success_bonus.compute(state, action, next_state, 1.0, info_no_success) + print(f" SuccessBonus (no success): reward = {reward_no_success}") + assert reward_no_success == 1.0 + + # With success + info_success = {'success': True} + reward_success = success_bonus.compute(state, action, next_state, 1.0, info_success) + print(f" SuccessBonus (success): reward = {reward_success}") + assert reward_success == 6.0 # 1.0 + 5.0 + + # Test 5: Reset functionality + print("\n5. 
Testing reset functionality...") + + class StatefulReward(BaseReward): + """Reward function with internal state.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.step_count = 0 + + def compute(self, state, action, next_state, env_reward, info): + self.step_count += 1 + return env_reward + self.step_count * 0.01 + + def reset(self, **kwargs): + self.step_count = 0 + + stateful_reward = StatefulReward() + + # Compute a few rewards + for i in range(5): + r = stateful_reward.compute(state, action, next_state, 1.0, {}) + print(f" After 5 steps, step_count = {stateful_reward.step_count}") + + # Reset + stateful_reward.reset() + print(f" After reset, step_count = {stateful_reward.step_count}") + assert stateful_reward.step_count == 0, "Reset should clear step_count" + + print("\n" + "=" * 60) + print("All tests passed!") + print("=" * 60) + diff --git a/rl/trainers/__init__.py b/rl/trainers/__init__.py new file mode 100644 index 00000000..00c94d85 --- /dev/null +++ b/rl/trainers/__init__.py @@ -0,0 +1,90 @@ +""" +Trainers Module + +This module provides trainers for coordinating RL training. + +Design Philosophy: +- Coordinate environment, policy, and algorithm for training loop +- Support single algorithm and multiple algorithms training +- Support custom reward functions (applied during training, not data collection) +- Support evaluation during training + +Available trainers: +- SimpleTrainer: Single machine trainer +- ParallelTrainer: Parallel environment trainer +- DistributedTrainer: Distributed training + +Note: Implementations are provided in separate files. +This __init__.py provides factory functions for creating trainers. 
+""" + +from typing import Type, Dict, Any + +from .base_trainer import BaseTrainer +from .offpolicy_trainer import OffPolicyTrainer, OffPolicyTrainerConfig + +# Registry for trainer classes +_TRAINER_REGISTRY: Dict[str, Type] = { + 'offpolicy': OffPolicyTrainer, +} + + +def register_trainer(name: str, trainer_class: Type) -> None: + """ + Register a trainer class. + + Args: + name: Trainer name (e.g., 'simple', 'parallel') + trainer_class: Trainer class to register + """ + _TRAINER_REGISTRY[name.lower()] = trainer_class + + +def get_trainer_class(name_or_type: str) -> Type: + """ + Get trainer class by name or type string. + + Args: + name_or_type: Trainer name (e.g., 'simple') or full type path + (e.g., 'rl.trainers.simple_trainer.SimpleTrainer') + + Returns: + Trainer class + + Raises: + ValueError: If trainer not found + """ + # First check registry + if name_or_type.lower() in _TRAINER_REGISTRY: + return _TRAINER_REGISTRY[name_or_type.lower()] + + # Try to import from type path + if '.' in name_or_type: + try: + parts = name_or_type.rsplit('.', 1) + module_path = parts[0] + class_name = parts[1] + + import importlib + module = importlib.import_module(module_path) + return getattr(module, class_name) + except (ImportError, AttributeError) as e: + raise ValueError(f"Cannot import trainer from '{name_or_type}': {e}") + + raise ValueError(f"Unknown trainer: '{name_or_type}'. Available: {list(_TRAINER_REGISTRY.keys())}") + + +def list_trainers() -> list: + """List all registered trainers.""" + return list(_TRAINER_REGISTRY.keys()) + + +__all__ = [ + 'BaseTrainer', + 'OffPolicyTrainer', + 'OffPolicyTrainerConfig', + 'register_trainer', + 'get_trainer_class', + 'list_trainers', +] + diff --git a/rl/trainers/base_trainer.py b/rl/trainers/base_trainer.py new file mode 100644 index 00000000..17076a98 --- /dev/null +++ b/rl/trainers/base_trainer.py @@ -0,0 +1,858 @@ +""" +Base Trainer Class + +This module defines the base class for all RL trainers in the framework. 
+ +Design Philosophy: +- Coordinate algorithm and collectors for executing training loop +- Support a single algorithm collecting data from multiple different environment types +- Each environment type has its own Collector and Replay Buffer +- Collectors manage environments internally (Trainer doesn't need envs directly) +- Support custom reward functions (applied during training, not data collection) +- Support evaluation during training + +Architecture: + Trainer + ├── algorithm: BaseAlgorithm (single algorithm) + │ └── replay: Dict[env_type, BaseReplay] (multiple replays, one per env type) + ├── collectors: Dict[env_type, BaseCollector] (multiple collectors, one per env type) + │ ├── collector['env_type_A']: manages VectorEnv A and stores to replay['env_type_A'] + │ ├── collector['env_type_B']: manages VectorEnv B and stores to replay['env_type_B'] + │ └── ... + └── reward_fn: Optional[BaseReward] +""" + +import numpy as np +import torch +from typing import Dict, Any, Optional, Union, List, Callable, TYPE_CHECKING +from abc import ABC, abstractmethod + +# Type hints for Meta classes +from benchmark.base import MetaObs, MetaAction + +# Vector environment protocol +from rl.envs import VectorEnvProtocol, EnvsType + +if TYPE_CHECKING: + from rl.rewards.base_reward import BaseReward +else: + # Import at runtime to avoid potential circular imports + from rl.rewards.base_reward import BaseReward + + +class BaseTrainer(ABC): + """ + Base class for RL trainers. + + This class defines the interface for all trainers in the RL framework. 
+ Trainers coordinate the training loop by managing: + - Algorithm (with policy and multiple replay buffers for different env types) + - Collectors (for data collection from different env types) + - Reward functions (applied during training) + + Key Design: + - Trainer does NOT directly manage environments + - Each Collector manages its own VectorEnv internally + - Algorithm has Dict[env_type, Replay] for storing data from different env types + - Number of collectors must match the number of env types in algorithm's replay dict + + Attributes: + algorithm: Algorithm for training + collectors: Dict of collectors, keyed by env_type + reward_fn: Optional custom reward function + """ + + def __init__( + self, + algorithm: 'BaseAlgorithm', + collectors: Union['BaseCollector', Dict[str, 'BaseCollector']], + reward_fn: Optional[Union['BaseReward', Callable]] = None, + **kwargs + ): + """ + Initialize the trainer. + + Args: + algorithm: BaseAlgorithm instance (required) + - Should have replay as Dict[env_type, BaseReplay] for multi-env scenarios + - Each replay buffer stores data from its corresponding env type + collectors: Data collectors, supports: + - BaseCollector: Single collector (for single env type, uses 'default') + - Dict[str, BaseCollector]: Multi-collector dict for different env types + e.g., {'sim': sim_collector, 'real': real_collector} + - Each collector manages its own VectorEnv + - Keys should match algorithm's replay dict keys + reward_fn: Optional reward function (if None, use raw environment reward) + - Can be BaseReward instance or Callable function + - Applied during training for algorithm updates + - Note: Replay buffer stores raw rewards, reward function only applied during training + **kwargs: Trainer-specific parameters + """ + self.algorithm = algorithm + self.reward_fn = reward_fn + self._kwargs = kwargs + + # Normalize collector storage: always use dict internally + if isinstance(collectors, dict): + self._collectors_dict: Dict[str, 
'BaseCollector'] = collectors + else: + self._collectors_dict = {'default': collectors} + + # Store reference for easier access + self.collectors = self._collectors_dict + + def get_collector(self, env_type: Optional[str] = None) -> 'BaseCollector': + """ + Get the collector by env_type. + + Args: + env_type: Environment type identifier. If None, returns 'default' collector + or the first collector if 'default' doesn't exist. + + Returns: + The collector for the specified env type + + Raises: + KeyError: If specified env_type is not found + """ + if env_type is not None: + return self._collectors_dict[env_type] + + # Return 'default' if exists, otherwise return first collector + if 'default' in self._collectors_dict: + return self._collectors_dict['default'] + return list(self._collectors_dict.values())[0] + + def get_env(self, env_type: Optional[str] = None) -> VectorEnvProtocol: + """ + Get the vectorized environment by type (from collector). + + Args: + env_type: Environment type identifier. + + Returns: + The vectorized environment managed by the collector + """ + collector = self.get_collector(env_type) + return collector.get_env() + + def get_total_env_num(self) -> int: + """ + Get total number of parallel environments across all types. + + Returns: + Total count of parallel environments across all collectors + """ + return sum(col.get_total_env_num() for col in self._collectors_dict.values()) + + def get_env_types(self) -> List[str]: + """ + Get all environment type identifiers. + + Returns: + List of environment type strings (collector keys) + """ + return list(self._collectors_dict.keys()) + + @property + def env_num(self) -> int: + """ + Number of environments in the default (or first) collector. + + Returns: + Number of parallel environments + """ + collector = self.get_collector() + return collector.env_num + + @abstractmethod + def train(self, **kwargs) -> None: + """ + Execute training loop. 
+ + Args: + **kwargs: Training parameters, can include: + - total_steps: Total training steps (optional) + - total_episodes: Total episodes (optional) + - max_time: Maximum training time (optional) + - log_interval: Logging interval (optional) + - save_interval: Model save interval (optional) + - eval_interval: Evaluation interval (optional) + """ + raise NotImplementedError + + def compute_reward( + self, + state: MetaObs, + action: MetaAction, + next_state: MetaObs, + env_reward: float, + info: Optional[Dict[str, Any]] = None + ) -> float: + """ + Compute reward (support custom reward function). + + Used during training for algorithm update reward computation. + Replay buffer stores raw rewards, reward function only applied during training. + + Args: + state: Current state + action: Action + next_state: Next state + env_reward: Environment raw reward + info: Additional information dictionary + + Returns: + Computed reward value + """ + if self.reward_fn is not None: + if isinstance(self.reward_fn, BaseReward): + return self.reward_fn.compute(state, action, next_state, env_reward, info) + else: + # Assume it's a callable + return self.reward_fn(state, action, next_state, env_reward, info) + return env_reward + + def collect_rollout( + self, + n_steps: int, + env_type: Optional[str] = None + ) -> Union[Dict[str, Any], Dict[str, Dict[str, Any]]]: + """ + Collect rollout data (using collectors). 
+ + Args: + n_steps: Number of steps to collect + env_type: Optional environment type identifier + - If specified, only collect from that env type's collector + - If None, collect from ALL collectors + + Returns: + - If env_type specified: Single stats dict from that collector + - If env_type is None: Dict[env_type, stats] from all collectors + """ + if env_type is not None: + # Collect from specific env type + collector = self.get_collector(env_type) + return collector.collect(n_steps, env_type=env_type) + else: + # Collect from all collectors + all_stats = {} + for env_type_key, collector in self._collectors_dict.items(): + all_stats[env_type_key] = collector.collect(n_steps, env_type=env_type_key) + return all_stats + + def collect_rollout_all(self, n_steps: int) -> Dict[str, Dict[str, Any]]: + """ + Collect rollout data from ALL collectors. + + Args: + n_steps: Number of steps to collect from each collector + + Returns: + Dict[env_type, stats] from all collectors + """ + return self.collect_rollout(n_steps, env_type=None) + + def evaluate( + self, + n_episodes: int = 10, + render: bool = False, + env_type: Optional[str] = None, + **kwargs + ) -> Dict[str, Any]: + """ + Evaluate policy performance. + + Default implementation uses collector's environment and algorithm's select_action. + Subclasses can override for custom evaluation logic. + + Args: + n_episodes: Number of episodes to evaluate + render: Whether to render environment (optional) + env_type: Optional, specify evaluation environment type + **kwargs: Other evaluation parameters (vec_env, max_timesteps, ctrl_space, etc.) 
+ + Returns: + Dictionary containing evaluation metrics + """ + from benchmark.utils import organize_obs + + # Get evaluation environment + vec_env = kwargs.get('vec_env', None) + if vec_env is None: + # Try to get from collector + collector = self.get_collector(env_type) + if collector is not None: + vec_env = collector.get_env() + else: + return {'error': 'No evaluation environment available'} + + max_timesteps = kwargs.get('max_timesteps', 1000) + ctrl_space = kwargs.get('ctrl_space', 'joint') + + num_envs = len(vec_env) + all_returns = [] + all_lengths = [] + episodes_completed = 0 + + # Set algorithm to eval mode + self.algorithm.eval_mode() + + while episodes_completed < n_episodes: + obs = vec_env.reset() + obs = organize_obs(obs, ctrl_space) + + episode_returns = np.zeros(num_envs, dtype=np.float32) + episode_lengths = np.zeros(num_envs, dtype=np.int32) + done_mask = np.zeros(num_envs, dtype=bool) + + for t in range(max_timesteps): + with torch.no_grad(): + action = self.algorithm.select_action(obs, noise_scale=0.0, env=vec_env) + + # Unpack action + if hasattr(action, 'action'): + action_array = action.action + else: + action_array = action + + if action_array.ndim == 1: + action_array = action_array[np.newaxis, :] + + # Convert to list of dicts for vec_env.step + step_actions = [{'action': action_array[i]} for i in range(len(action_array))] + + # Step environment + next_obs, rewards, dones, infos = vec_env.step(step_actions) + next_obs = organize_obs(next_obs, ctrl_space) + + # Accumulate rewards for active episodes + episode_returns += rewards * (~done_mask) + episode_lengths += (~done_mask).astype(np.int32) + + # Check for newly done episodes + newly_done = dones & (~done_mask) + if newly_done.any(): + for idx in np.where(newly_done)[0]: + all_returns.append(episode_returns[idx]) + all_lengths.append(episode_lengths[idx]) + episodes_completed += 1 + + done_mask = done_mask | dones + + # Reset done environments + if dones.any(): + done_indices = 
np.where(dones)[0] + reset_obs = vec_env.reset(id=done_indices) + if reset_obs is not None: + reset_obs_organized = organize_obs(reset_obs, ctrl_space) + if isinstance(next_obs, MetaObs) and next_obs.state is not None: + if hasattr(reset_obs_organized, 'state') and reset_obs_organized.state is not None: + next_obs.state[done_indices] = reset_obs_organized.state + episode_returns[done_indices] = 0 + episode_lengths[done_indices] = 0 + done_mask[done_indices] = False + + obs = next_obs + + if episodes_completed >= n_episodes: + break + + # Handle truncated episodes + for idx in range(num_envs): + if not done_mask[idx] and episode_lengths[idx] > 0: + all_returns.append(episode_returns[idx]) + all_lengths.append(episode_lengths[idx]) + episodes_completed += 1 + if episodes_completed >= n_episodes: + break + + # Set back to train mode + self.algorithm.train_mode() + + # Compute statistics + all_returns = np.array(all_returns[:n_episodes]) + all_lengths = np.array(all_lengths[:n_episodes]) + + return { + 'returns': all_returns.tolist(), + 'mean_return': float(np.mean(all_returns)), + 'std_return': float(np.std(all_returns)), + 'min_return': float(np.min(all_returns)), + 'max_return': float(np.max(all_returns)), + 'episode_lengths': all_lengths.tolist(), + 'mean_length': float(np.mean(all_lengths)), + 'num_episodes': len(all_returns), + } + + def save(self, path: str) -> None: + """ + Save model and training state. + + Default implementation saves the algorithm's model. + Subclasses can override to save additional state. + + Args: + path: Path to save the model + """ + self.algorithm.save(path) + + def load(self, path: str) -> None: + """ + Load model and training state. + + Default implementation loads the algorithm's model. + Subclasses can override to load additional state. + + Args: + path: Path to load the model from + """ + self.algorithm.load(path) + + def save_checkpoint(self, path: str, step: int) -> None: + """ + Save training checkpoint with step information. 
+ + Args: + path: Directory or full path to save checkpoint + step: Current training step number + """ + import os + if os.path.isdir(path): + ckpt_path = os.path.join(path, f'checkpoint_{step}.pt') + else: + ckpt_path = path + self.save(ckpt_path) + + def load_checkpoint(self, path: str) -> int: + """ + Load checkpoint and return the step number. + + Extracts step number from checkpoint filename (e.g., checkpoint_10000.pt -> 10000). + + Args: + path: Path to checkpoint file + + Returns: + Step number extracted from checkpoint filename, or 0 if not found + """ + import os + self.load(path) + # Try to extract step from checkpoint name + try: + step = int(os.path.basename(path).split('_')[-1].split('.')[0]) + except (ValueError, IndexError): + step = 0 + return step + + def get_algorithm(self) -> 'BaseAlgorithm': + """Get the algorithm.""" + return self.algorithm + + def get_collectors(self) -> Dict[str, 'BaseCollector']: + """Get all collectors as a dict.""" + return self._collectors_dict + + def get_replay(self, env_type: Optional[str] = None) -> 'BaseReplay': + """ + Get the replay buffer for a specific env type. + + Args: + env_type: Environment type. If None, returns default or first replay. + + Returns: + The replay buffer + """ + replay = self.algorithm.replay + if isinstance(replay, dict): + if env_type is not None: + return replay[env_type] + elif 'default' in replay: + return replay['default'] + else: + return list(replay.values())[0] + return replay + + def set_reward_fn(self, reward_fn: Union['BaseReward', Callable]) -> None: + """ + Set or update the reward function. 
+ + Args: + reward_fn: New reward function to use + """ + self.reward_fn = reward_fn + + def reset_collectors(self) -> None: + """Reset all collectors.""" + for collector in self._collectors_dict.values(): + collector.reset() + + def __repr__(self) -> str: + algo_info = self.algorithm.__class__.__name__ + env_types = list(self._collectors_dict.keys()) + collector_info = f"Dict[{len(self._collectors_dict)}]: {env_types}" + reward_info = "None" if self.reward_fn is None else self.reward_fn.__class__.__name__ + total_envs = self.get_total_env_num() + return (f"{self.__class__.__name__}(algorithm={algo_info}, " + f"collectors={collector_info}, total_envs={total_envs}, reward_fn={reward_info})") + + +if __name__ == '__main__': + """ + Test code for BaseTrainer class. + + Since BaseTrainer is abstract, we create a simple concrete implementation for testing. + Tests the new architecture where Trainer uses Collectors (which manage envs internally). + """ + import sys + sys.path.insert(0, '/home/zhang/robot/126/ILStudio') + + from benchmark.base import MetaObs, MetaAction, MetaEnv, MetaPolicy + from rl.buffer.base_replay import BaseReplay + from rl.algorithms.base import BaseAlgorithm + from rl.rewards.base_reward import BaseReward, IdentityReward, ScaledReward + from rl.collectors.base_collector import BaseCollector + from rl.envs import VectorEnvProtocol + from dataclasses import asdict + + # Use MetaReplay for efficient env-first storage with vectorized sampling + from rl.buffer.meta_replay import MetaReplay + + # Simple dummy environment for testing + class DummyEnv: + def __init__(self, state_dim=10, action_dim=7, max_steps=100): + self.state_dim = state_dim + self.action_dim = action_dim + self._step_count = 0 + self._max_steps = max_steps + + def reset(self): + self._step_count = 0 + return {'state': np.random.randn(self.state_dim).astype(np.float32)} + + def step(self, action): + self._step_count += 1 + obs = {'state': 
np.random.randn(self.state_dim).astype(np.float32)} + reward = np.random.randn() + done = self._step_count >= self._max_steps + info = {'step': self._step_count} + return obs, reward, done, info + + def close(self): + pass + + # Simple MetaEnv wrapper for testing + class DummyMetaEnv(MetaEnv): + def __init__(self, state_dim=10, action_dim=7, max_steps=100): + self.env = DummyEnv(state_dim=state_dim, action_dim=action_dim, max_steps=max_steps) + self.prev_obs = None + + def obs2meta(self, raw_obs): + return MetaObs(state=raw_obs['state'], raw_lang="test") + + def meta2act(self, action, *args): + if hasattr(action, 'action'): + return action.action + return action + + def reset(self): + init_obs = self.env.reset() + self.prev_obs = self.obs2meta(init_obs) + return self.prev_obs + + def step(self, action): + act = self.meta2act(action) + obs, reward, done, info = self.env.step(act) + self.prev_obs = self.obs2meta(obs) + return self.prev_obs, reward, done, info + + # Use SequentialVectorEnv from benchmark/utils.py for testing + from benchmark.utils import SequentialVectorEnv + + # Simple policy for testing + class DummyPolicy: + def __init__(self, action_dim=7): + self.action_dim = action_dim + + def select_action(self, obs, n_envs=1): + # Return batched actions (n_envs, action_dim) + return MetaAction(action=np.random.randn(n_envs, self.action_dim).astype(np.float32)) + + def train(self): + pass + + def eval(self): + pass + + # Simple MetaPolicy wrapper for testing + class DummyMetaPolicy(MetaPolicy): + def __init__(self, action_dim=7): + self.policy = DummyPolicy(action_dim=action_dim) + self.chunk_size = 1 + self.ctrl_space = 'ee' + self.ctrl_type = 'delta' + self.action_queue = [] + self.action_normalizer = None + self.state_normalizer = None + + def select_action(self, mobs, t=0, **kwargs): + # Infer batch size from obs + n_envs = 1 + if hasattr(mobs, 'state') and mobs.state is not None: + n_envs = mobs.state.shape[0] if mobs.state.ndim > 1 else 1 + + # Get 
MetaAction from policy + mact = self.policy.select_action(mobs, n_envs=n_envs) + + # Simulate MetaPolicy.inference output format: + # It returns a numpy array of dicts (one dict per env) + # We must convert MetaAction to this format + action_array = mact.action + if action_array.ndim == 1: + action_array = action_array[np.newaxis, :] + + # Create list of dicts, then convert to object array + action_dicts = [{'action': action_array[i]} for i in range(len(action_array))] + return np.array(action_dicts, dtype=object) + + # Simple algorithm for testing (supports Dict[env_type, replay]) + class DummyAlgorithm(BaseAlgorithm): + def __init__(self, meta_policy, replay=None, **kwargs): + super().__init__(meta_policy=meta_policy, replay=replay, **kwargs) + self._timestep = 0 + self._update_count = 0 + + def update(self, batch=None, **kwargs): + self._update_count += 1 + batch_size = kwargs.get('batch_size', 32) + env_type = kwargs.get('env_type', None) + + if batch is None and self.replay is not None: + if isinstance(self.replay, dict): + if env_type: + batch = self.replay[env_type].sample(batch_size) + else: + # Sample from all replays + batch = {} + for k, v in self.replay.items(): + batch[k] = v.sample(batch_size) + else: + batch = self.replay.sample(batch_size) + return {'loss': np.random.randn(), 'update_count': self._update_count} + + def select_action(self, obs, **kwargs): + return self.meta_policy.select_action(obs, t=self._timestep) + + # Import DummyCollector from base_collector + from rl.collectors.base_collector import DummyCollector + + # Simple trainer implementation for testing (new architecture) + class SimpleTrainer(BaseTrainer): + """Simple trainer for testing - uses collectors (which manage envs internally).""" + + def __init__( + self, + algorithm, + collectors, + reward_fn=None, + **kwargs + ): + super().__init__(algorithm, collectors, reward_fn, **kwargs) + self._total_steps = 0 + self._training_logs = [] + + def train(self, **kwargs): + total_steps = 
kwargs.get('total_steps', 1000) + log_interval = kwargs.get('log_interval', 100) + update_interval = kwargs.get('update_interval', 50) + batch_size = kwargs.get('batch_size', 32) + + print(f"Starting training for {total_steps} steps...") + + while self._total_steps < total_steps: + # Collect data from all env types + all_stats = self.collect_rollout_all(n_steps=update_interval) + + # Sum up total steps from all collectors + total_collected = sum(stats['total_steps'] for stats in all_stats.values()) + self._total_steps += total_collected + + # Update algorithm (can sample from any/all replays) + result = self.algorithm.update(batch_size=batch_size) + + # Log + if self._total_steps % log_interval == 0: + print(f" Step {self._total_steps}: loss={result.get('loss', 0.0):.4f}") + self._training_logs.append({ + 'step': self._total_steps, + 'loss': result.get('loss', 0.0) + }) + + print(f"Training completed. Total steps: {self._total_steps}") + + def evaluate(self, n_episodes=10, render=False, env_type=None, **kwargs): + print(f"Evaluating for {n_episodes} episodes on env_type={env_type}...") + + # Get env from collector + vec_env = self.get_env(env_type) + alg = self.algorithm + + episode_rewards = [] + episode_lengths = [] + + for ep in range(n_episodes): + obs = vec_env.reset(id=0) + episode_reward = 0.0 + episode_length = 0 + done = False + + while not done: + action = alg.select_action(obs) + obs, reward, done, info = vec_env.step(action, id=0) + episode_reward += reward + episode_length += 1 + + episode_rewards.append(episode_reward) + episode_lengths.append(episode_length) + + results = { + 'mean_reward': np.mean(episode_rewards), + 'std_reward': np.std(episode_rewards), + 'mean_length': np.mean(episode_lengths), + 'episode_rewards': episode_rewards, + 'episode_lengths': episode_lengths + } + print(f"Evaluation: mean_reward={results['mean_reward']:.2f} ± {results['std_reward']:.2f}") + return results + + def save(self, path): + print(f"Saving to {path} (mock)") + + 
def load(self, path): + print(f"Loading from {path} (mock)") + + # ========================================================================== + # Test the implementation + # ========================================================================== + print("=" * 60) + print("Testing BaseTrainer with New Architecture") + print("(Trainer uses Collectors which manage Envs internally)") + print("=" * 60) + + # Test 1: Single environment type + print("\n" + "-" * 40) + print("Test 1: Single Environment Type") + print("-" * 40) + + # Create env, replay, algorithm, collector + env_fns = [lambda: DummyMetaEnv(state_dim=10, action_dim=7, max_steps=20) for _ in range(4)] + vec_env = SequentialVectorEnv(env_fns) + print(f"\n1. Created SequentialVectorEnv with {vec_env.env_num} parallel envs") + + meta_policy = DummyMetaPolicy(action_dim=7) + replay = MetaReplay(capacity=10000, env_type='default', n_envs=4, state_dim=10, action_dim=7) + algorithm = DummyAlgorithm(meta_policy=meta_policy, replay=replay) + + # Create collector (it manages the env) + collector = DummyCollector(envs=vec_env, algorithm=algorithm) + + # Create trainer (NO envs parameter - collector manages it!) + trainer = SimpleTrainer( + algorithm=algorithm, + collectors=collector, # Single collector -> becomes {'default': collector} + reward_fn=None + ) + print(f"\n2. Created trainer: {trainer}") + print(f" env_types: {trainer.get_env_types()}") + print(f" total_env_num: {trainer.get_total_env_num()}") + + # Test compute_reward + print("\n3. Testing compute_reward...") + state = MetaObs(state=np.random.randn(10).astype(np.float32)) + action = MetaAction(action=np.random.randn(7).astype(np.float32)) + next_state = MetaObs(state=np.random.randn(10).astype(np.float32)) + env_reward = 1.5 + computed_reward = trainer.compute_reward(state, action, next_state, env_reward, {}) + print(f" Env reward: {env_reward}, Computed reward: {computed_reward}") + + # Test training + print("\n4. 
Testing training...") + trainer.train(total_steps=200, log_interval=100, update_interval=20, batch_size=16) + print(f" Replay buffer size: {len(algorithm.replay)}") + + # Test 2: Multiple environment types (the main use case!) + print("\n" + "-" * 40) + print("Test 2: Multiple Environment Types") + print("(One algorithm, multiple env types, each with own collector & replay)") + print("-" * 40) + + # Create envs for different types + sim_env_fns = [lambda: DummyMetaEnv(state_dim=10, action_dim=7, max_steps=20) for _ in range(4)] + real_env_fns = [lambda: DummyMetaEnv(state_dim=10, action_dim=7, max_steps=10) for _ in range(2)] + + sim_vec_env = SequentialVectorEnv(sim_env_fns) + real_vec_env = SequentialVectorEnv(real_env_fns) + print(f"\n1. Created sim env with {len(sim_vec_env)} parallel envs") + print(f" Created real env with {len(real_vec_env)} parallel envs") + + # Create replays for each env type + replays = { + 'sim': MetaReplay(capacity=5000, env_type='sim', n_envs=4, state_dim=10, action_dim=7), + 'real': MetaReplay(capacity=2000, env_type='real', n_envs=2, state_dim=10, action_dim=7), + } + + # Create algorithm with multi-replay + meta_policy2 = DummyMetaPolicy(action_dim=7) + algorithm2 = DummyAlgorithm(meta_policy=meta_policy2, replay=replays) + + # Create collectors for each env type + collectors = { + 'sim': DummyCollector(envs=sim_vec_env, algorithm=algorithm2), + 'real': DummyCollector(envs=real_vec_env, algorithm=algorithm2), + } + print(f"\n2. Created collectors for env types: {list(collectors.keys())}") + + # Create trainer with multiple collectors + multi_trainer = SimpleTrainer( + algorithm=algorithm2, + collectors=collectors, + reward_fn=IdentityReward() + ) + print(f"\n3. 
Created multi-env trainer: {multi_trainer}") + print(f" env_types: {multi_trainer.get_env_types()}") + print(f" total_env_num: {multi_trainer.get_total_env_num()}") + print(f" sim env_num: {len(multi_trainer.get_env('sim'))}") + print(f" real env_num: {len(multi_trainer.get_env('real'))}") + + # Test collect from specific env type + print("\n4. Testing collect from specific env type...") + sim_stats = multi_trainer.collect_rollout(n_steps=50, env_type='sim') + print(f" Sim collected: {sim_stats['total_steps']} steps") + print(f" Sim replay size: {len(replays['sim'])}") + print(f" Real replay size: {len(replays['real'])}") + + # Test collect from all env types + print("\n5. Testing collect from ALL env types...") + all_stats = multi_trainer.collect_rollout_all(n_steps=30) + print(f" Collected from: {list(all_stats.keys())}") + for env_type, stats in all_stats.items(): + print(f" {env_type}: {stats['total_steps']} steps") + print(f" Sim replay size: {len(replays['sim'])}") + print(f" Real replay size: {len(replays['real'])}") + + # Test training with multi-env + print("\n6. Testing training with multi-env...") + multi_trainer.train(total_steps=200, log_interval=100, update_interval=20, batch_size=16) + print(f" Final sim replay size: {len(replays['sim'])}") + print(f" Final real replay size: {len(replays['real'])}") + + # Test evaluation on specific env type + print("\n7. Testing evaluation on 'sim' env...") + eval_results = multi_trainer.evaluate(n_episodes=3, env_type='sim') + print(f" Mean reward: {eval_results['mean_reward']:.2f}") + + print("\n" + "=" * 60) + print("All tests passed!") + print("=" * 60) + diff --git a/rl/trainers/offpolicy_trainer.py b/rl/trainers/offpolicy_trainer.py new file mode 100644 index 00000000..cadd7545 --- /dev/null +++ b/rl/trainers/offpolicy_trainer.py @@ -0,0 +1,260 @@ +""" +Off-Policy Trainer + +Trainer for off-policy RL algorithms (TD3, SAC, DDPG, etc.) 
+ +Design: +- Separates data collection (Collector) from policy updates (Trainer) +- Supports evaluation during training +- Handles checkpointing and logging +""" + +import os +import json +import numpy as np +from typing import Dict, Any, Optional, Callable, Type +from dataclasses import dataclass +from loguru import logger +from tqdm import tqdm + +from .base_trainer import BaseTrainer +from rl.collectors import BaseCollector +from rl.algorithms.base import BaseAlgorithm +from benchmark.utils import SequentialVectorEnv + + +@dataclass +class OffPolicyTrainerConfig: + """Configuration for OffPolicyTrainer.""" + total_steps: int = 1000000 + start_steps: int = 25000 # Random exploration steps + update_after: int = 1000 # Start updating after this many steps + update_every: int = 50 # Update every N steps + batch_size: int = 256 + + # Logging and evaluation + log_freq: int = 1000 + eval_freq: int = 10000 + eval_episodes: int = 10 + save_freq: int = 50000 + + # Output + output_dir: str = 'ckpt/rl_training' + + # Environment settings + max_timesteps: int = 1000 + ctrl_space: str = 'joint' + + # Exploration + expl_noise: float = 0.1 + + +class OffPolicyTrainer(BaseTrainer): + """ + Trainer for off-policy RL algorithms. + + Coordinates: + - Collector: Gathers experience from environment + - Algorithm: Updates policy using collected data + - Evaluation: Periodically evaluates policy performance + """ + + def __init__( + self, + algorithm: BaseAlgorithm, + collector: BaseCollector, + config: Optional[OffPolicyTrainerConfig] = None, + eval_env_fn: Optional[Callable] = None, + eval_vec_env_cls: Optional[Type] = None, + **kwargs + ): + """ + Initialize OffPolicyTrainer. + + Args: + algorithm: RL algorithm (TD3, SAC, etc.) 
+ collector: Data collector for environment interaction + config: Trainer configuration + eval_env_fn: Factory function for creating evaluation environments + eval_vec_env_cls: Vector environment class for evaluation (default: SequentialVectorEnv) + **kwargs: Additional arguments + """ + super().__init__(algorithm=algorithm, collectors=collector, **kwargs) + + self.config = config or OffPolicyTrainerConfig() + self.collector = collector + self.eval_env_fn = eval_env_fn + self.eval_vec_env_cls = eval_vec_env_cls or SequentialVectorEnv + + # Training state + self._current_step = 0 + self._episode_count = 0 + self._recent_rewards = [] + + def train(self, resume_step: int = 0) -> Dict[str, Any]: + """ + Run the training loop. + + Args: + resume_step: Step to resume from (for checkpoint resumption) + + Returns: + Training statistics + """ + os.makedirs(self.config.output_dir, exist_ok=True) + + # Save config + config_path = os.path.join(self.config.output_dir, 'config.json') + with open(config_path, 'w') as f: + config_dict = {k: v for k, v in vars(self.config).items()} + json.dump(config_dict, f, indent=2, default=str) + + self._current_step = resume_step + + logger.info("=" * 60) + logger.info(f"Starting training for {self.config.total_steps} steps") + logger.info("=" * 60) + + # Reset collector + self.collector.reset() + + pbar = tqdm( + range(resume_step, self.config.total_steps), + desc="Training", + initial=resume_step, + total=self.config.total_steps + ) + + for step in pbar: + self._current_step = step + + # Collect one step of data + # Use random exploration for initial steps + use_random = step < self.config.start_steps + collect_stats = self.collector.collect_step( + noise_scale=self.config.expl_noise if not use_random else None, + use_random=use_random, + ) + + # Update episode tracking + if collect_stats.get('episode_rewards'): + self._recent_rewards.extend(collect_stats['episode_rewards']) + self._episode_count += len(collect_stats['episode_rewards']) + # 
Keep only recent 100 rewards + self._recent_rewards = self._recent_rewards[-100:] + + # Policy update + if step >= self.config.update_after and step % self.config.update_every == 0: + for _ in range(self.config.update_every): + self.algorithm.update( + batch_size=self.config.batch_size, + env=self.collector.get_env() + ) + + # Logging + if step > 0 and step % self.config.log_freq == 0: + avg_reward = np.mean(self._recent_rewards) if self._recent_rewards else 0.0 + pbar.set_postfix({ + 'episodes': self._episode_count, + 'avg_reward': f'{avg_reward:.2f}', + 'buffer': len(self.algorithm.replay) if self.algorithm.replay else 0, + }) + + # Evaluation + if step > 0 and step % self.config.eval_freq == 0: + self._run_evaluation(step) + + # Save checkpoint + if step > 0 and step % self.config.save_freq == 0: + self._save_checkpoint(step) + + # Final save + final_path = os.path.join(self.config.output_dir, 'final_model.pt') + self.algorithm.save(final_path) + logger.info(f"Training complete. Final model saved to {final_path}") + + return { + 'total_steps': self.config.total_steps, + 'episode_count': self._episode_count, + 'final_avg_reward': np.mean(self._recent_rewards) if self._recent_rewards else 0.0, + } + + def _run_evaluation(self, step: int) -> Dict[str, Any]: + """Run evaluation and log results.""" + logger.info(f"Step {step}: Running evaluation...") + + if self.eval_env_fn is None: + logger.warning("No eval_env_fn provided, skipping evaluation") + return {} + + # Run evaluation (evaluate method handles env creation/closing) + eval_result = self.evaluate(n_episodes=self.config.eval_episodes) + + if not eval_result or 'error' in eval_result: + return eval_result + + # Log results + logger.info( + f"Step {step}: Eval return = {eval_result['mean_return']:.2f} ± {eval_result['std_return']:.2f}, " + f"length = {eval_result['mean_length']:.1f}, " + f"episodes = {eval_result['num_episodes']}" + ) + + # Save eval results + eval_path = os.path.join(self.config.output_dir, 
'eval_results.json') + eval_data = {'step': step, **eval_result} + with open(eval_path, 'a') as f: + f.write(json.dumps(eval_data, default=lambda x: x.tolist() if hasattr(x, 'tolist') else x) + '\n') + + return eval_result + + def evaluate( + self, + n_episodes: int = 10, + render: bool = False, + env_type: Optional[str] = None, + **kwargs + ) -> Dict[str, Any]: + """ + Evaluate policy performance. + + Overrides base class to support creating eval env from eval_env_fn. + + Args: + n_episodes: Number of episodes to evaluate + render: Whether to render environment (not used currently) + env_type: Optional environment type (not used currently) + **kwargs: Additional arguments (e.g., vec_env, max_timesteps) + + Returns: + Dictionary containing evaluation metrics + """ + vec_env = kwargs.get('vec_env', None) + should_close = False + + # Create evaluation environment if not provided + if vec_env is None and self.eval_env_fn is not None: + eval_env_fns = [self.eval_env_fn for _ in range(n_episodes)] + vec_env = self.eval_vec_env_cls(eval_env_fns) + should_close = True + + # Set default max_timesteps and ctrl_space from config + if 'max_timesteps' not in kwargs: + kwargs['max_timesteps'] = self.config.max_timesteps + if 'ctrl_space' not in kwargs: + kwargs['ctrl_space'] = self.config.ctrl_space + + # Pass vec_env to base class evaluate + kwargs['vec_env'] = vec_env + result = super().evaluate(n_episodes=n_episodes, render=render, env_type=env_type, **kwargs) + + if should_close and vec_env is not None: + vec_env.close() + + return result + + def _save_checkpoint(self, step: int) -> None: + """Save training checkpoint.""" + self.save_checkpoint(self.config.output_dir, step) + logger.info(f"Saved checkpoint to {self.config.output_dir}/checkpoint_{step}.pt") + diff --git a/rl/utils/__init__.py b/rl/utils/__init__.py new file mode 100644 index 00000000..5c773d5a --- /dev/null +++ b/rl/utils/__init__.py @@ -0,0 +1,327 @@ +""" +RL Utilities Module + +This module provides utility 
functions for the RL framework. + +Available utilities: +- DataProcessor: Data processor for aligning with ILStudio pipeline +- Running statistics (mean, variance) for normalization +- Advantage computation (GAE) +- Discount reward computation +- Action post-processing helpers (ensure/clip actions) +""" + +import numpy as np +import torch +from typing import Dict, Any, Optional, List, Union + + + +def compute_gae( + rewards: np.ndarray, + values: np.ndarray, + dones: np.ndarray, + next_value: float, + gamma: float = 0.99, + gae_lambda: float = 0.95 +) -> np.ndarray: + """ + Compute Generalized Advantage Estimation (GAE). + + Args: + rewards: Array of rewards [T] + values: Array of value estimates [T] + dones: Array of done flags [T] + next_value: Value estimate for the next state + gamma: Discount factor + gae_lambda: GAE lambda parameter + + Returns: + Array of advantages [T] + """ + T = len(rewards) + advantages = np.zeros(T, dtype=np.float32) + last_gae = 0.0 + + for t in reversed(range(T)): + if t == T - 1: + next_non_terminal = 1.0 - dones[t] + next_val = next_value + else: + next_non_terminal = 1.0 - dones[t] + next_val = values[t + 1] + + delta = rewards[t] + gamma * next_val * next_non_terminal - values[t] + advantages[t] = last_gae = delta + gamma * gae_lambda * next_non_terminal * last_gae + + return advantages + + +def compute_returns( + rewards: np.ndarray, + dones: np.ndarray, + next_value: float = 0.0, + gamma: float = 0.99 +) -> np.ndarray: + """ + Compute discounted returns. 
+ + Args: + rewards: Array of rewards [T] + dones: Array of done flags [T] + next_value: Value estimate for the next state + gamma: Discount factor + + Returns: + Array of discounted returns [T] + """ + T = len(rewards) + returns = np.zeros(T, dtype=np.float32) + running_return = next_value + + for t in reversed(range(T)): + running_return = rewards[t] + gamma * running_return * (1.0 - dones[t]) + returns[t] = running_return + + return returns + + +class RunningMeanStd: + """ + Running mean and standard deviation tracker. + + Useful for observation normalization in RL. + """ + + def __init__(self, shape: tuple = (), epsilon: float = 1e-8): + """ + Initialize running statistics. + + Args: + shape: Shape of the data to track + epsilon: Small value for numerical stability + """ + self.mean = np.zeros(shape, dtype=np.float64) + self.var = np.ones(shape, dtype=np.float64) + self.count = epsilon + self.epsilon = epsilon + + def update(self, x: np.ndarray) -> None: + """ + Update running statistics with new data. + + Args: + x: New data batch [batch_size, *shape] + """ + batch_mean = np.mean(x, axis=0) + batch_var = np.var(x, axis=0) + batch_count = x.shape[0] + + self._update_from_moments(batch_mean, batch_var, batch_count) + + def _update_from_moments( + self, + batch_mean: np.ndarray, + batch_var: np.ndarray, + batch_count: int + ) -> None: + """Update from batch moments.""" + delta = batch_mean - self.mean + total_count = self.count + batch_count + + new_mean = self.mean + delta * batch_count / total_count + m_a = self.var * self.count + m_b = batch_var * batch_count + M2 = m_a + m_b + np.square(delta) * self.count * batch_count / total_count + new_var = M2 / total_count + + self.mean = new_mean + self.var = new_var + self.count = total_count + + def normalize(self, x: np.ndarray) -> np.ndarray: + """ + Normalize data using running statistics. 
+ + Args: + x: Data to normalize + + Returns: + Normalized data + """ + return (x - self.mean) / np.sqrt(self.var + self.epsilon) + + def denormalize(self, x: np.ndarray) -> np.ndarray: + """ + Denormalize data using running statistics. + + Args: + x: Normalized data + + Returns: + Denormalized data + """ + return x * np.sqrt(self.var + self.epsilon) + self.mean + + +def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> float: + """ + Compute explained variance. + + Args: + y_pred: Predicted values + y_true: True values + + Returns: + Explained variance (1.0 is perfect prediction) + """ + var_y = np.var(y_true) + if var_y == 0: + return np.nan + return 1.0 - np.var(y_true - y_pred) / var_y + + +def polyak_update( + source_params: List[torch.nn.Parameter], + target_params: List[torch.nn.Parameter], + tau: float = 0.005 +) -> None: + """ + Perform Polyak (soft) update of target network parameters. + + target = tau * source + (1 - tau) * target + + Args: + source_params: Source network parameters + target_params: Target network parameters + tau: Interpolation factor (0.0 = no update, 1.0 = full copy) + """ + with torch.no_grad(): + for source_param, target_param in zip(source_params, target_params): + target_param.data.mul_(1.0 - tau) + target_param.data.add_(tau * source_param.data) + + +def hard_update( + source_params: List[torch.nn.Parameter], + target_params: List[torch.nn.Parameter] +) -> None: + """ + Perform hard update of target network parameters (full copy). + + Args: + source_params: Source network parameters + target_params: Target network parameters + """ + polyak_update(source_params, target_params, tau=1.0) + + +__all__ = [ + 'compute_gae', + 'compute_returns', + 'RunningMeanStd', + 'explained_variance', + 'polyak_update', + 'hard_update', +] + + +if __name__ == '__main__': + """ + Test code for RL utilities. + """ + print("=" * 60) + print("Testing RL Utilities") + print("=" * 60) + + # Test compute_gae + print("\n1. 
Testing compute_gae...") + rewards = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32) + values = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32) + dones = np.array([0, 0, 0, 0, 1], dtype=np.float32) + next_value = 0.0 + + advantages = compute_gae(rewards, values, dones, next_value, gamma=0.99, gae_lambda=0.95) + print(f" Rewards: {rewards}") + print(f" Values: {values}") + print(f" Advantages: {advantages}") + print(f" Advantages shape: {advantages.shape}") + + # Test compute_returns + print("\n2. Testing compute_returns...") + returns = compute_returns(rewards, dones, next_value=0.0, gamma=0.99) + print(f" Returns: {returns}") + print(f" Returns shape: {returns.shape}") + + # Test RunningMeanStd + print("\n3. Testing RunningMeanStd...") + rms = RunningMeanStd(shape=(5,)) + + # Update with random data + for i in range(10): + data = np.random.randn(32, 5) + rms.update(data) + + print(f" Mean: {rms.mean}") + print(f" Var: {rms.var}") + print(f" Count: {rms.count}") + + # Test normalization + test_data = np.random.randn(10, 5) + normalized = rms.normalize(test_data) + denormalized = rms.denormalize(normalized) + print(f" Normalization error: {np.max(np.abs(test_data - denormalized)):.2e}") + + # Test explained_variance + print("\n4. Testing explained_variance...") + y_true = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) + y_pred = np.array([1.1, 1.9, 3.1, 4.1, 4.9]) + ev = explained_variance(y_pred, y_true) + print(f" y_true: {y_true}") + print(f" y_pred: {y_pred}") + print(f" Explained variance: {ev:.4f}") + + # Test polyak_update + print("\n5. 
Testing polyak_update...") + + class SimpleNet(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc = torch.nn.Linear(10, 5) + + source_net = SimpleNet() + target_net = SimpleNet() + + # Initialize target differently + with torch.no_grad(): + target_net.fc.weight.fill_(0.0) + target_net.fc.bias.fill_(0.0) + + print(f" Source weight mean: {source_net.fc.weight.mean().item():.4f}") + print(f" Target weight mean before update: {target_net.fc.weight.mean().item():.4f}") + + polyak_update( + list(source_net.parameters()), + list(target_net.parameters()), + tau=0.5 + ) + print(f" Target weight mean after polyak update (tau=0.5): {target_net.fc.weight.mean().item():.4f}") + + # Test hard_update + print("\n6. Testing hard_update...") + with torch.no_grad(): + target_net.fc.weight.fill_(0.0) + + hard_update( + list(source_net.parameters()), + list(target_net.parameters()) + ) + + diff = (source_net.fc.weight - target_net.fc.weight).abs().max().item() + print(f" Weight difference after hard update: {diff:.2e}") + assert diff < 1e-6, "Hard update should copy exactly" + + print("\n" + "=" * 60) + print("All tests passed!") + print("=" * 60) + diff --git a/rl/utils/action_utils.py b/rl/utils/action_utils.py new file mode 100644 index 00000000..c9fe21af --- /dev/null +++ b/rl/utils/action_utils.py @@ -0,0 +1,109 @@ +""" +Action utility helpers for RL algorithms. + +These helpers keep action post-processing outside MetaEnv, so algorithms can +apply different refinement strategies when needed. 
+""" + +from __future__ import annotations + +from typing import Callable, Optional, Tuple + +import numpy as np +import torch +from loguru import logger + +def _get_action_space(env) -> Optional[object]: + """Best-effort extraction of an action_space from env or wrappers.""" + if env is None: + return None + if hasattr(env, "action_space"): + return env.action_space + if hasattr(env, "env") and hasattr(env.env, "action_space"): + return env.env.action_space + if hasattr(env, "envs") and env.envs: + first_env = env.envs[0] + if hasattr(first_env, "action_space"): + return first_env.action_space + if hasattr(first_env, "env") and hasattr(first_env.env, "action_space"): + return first_env.env.action_space + return None + + +def _get_action_bounds(env) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]: + action_space = _get_action_space(env) + if action_space is None: + return None, None + low = getattr(action_space, "low", None) + high = getattr(action_space, "high", None) + if low is None or high is None: + return None, None + return np.asarray(low), np.asarray(high) + + +def clip_action_to_space(env, action): + """Clip action to env action space bounds if available.""" + low, high = _get_action_bounds(env) + if low is None or high is None: + return action + if torch.is_tensor(action): + low_t = torch.as_tensor(low, device=action.device, dtype=action.dtype) + high_t = torch.as_tensor(high, device=action.device, dtype=action.dtype) + return torch.clamp(action, low_t, high_t) + return np.clip(action, low, high) + +def tanh_action_to_space(env, action): + """ + Apply tanh to action and scale to env action space bounds. + + Maps tanh output [-1, 1] to [action_space.low, action_space.high]. + If no action space bounds available, just applies tanh. 
+ """ + # Apply tanh to bound to [-1, 1] + if torch.is_tensor(action): + tanh_action = torch.tanh(action) + else: + tanh_action = np.tanh(action) + + # Scale to action space bounds + low, high = _get_action_bounds(env) + if low is None or high is None: + return tanh_action + + # Map [-1, 1] -> [low, high]: scaled = low + (tanh + 1) * (high - low) / 2 + if torch.is_tensor(tanh_action): + low_t = torch.as_tensor(low, device=tanh_action.device, dtype=tanh_action.dtype) + high_t = torch.as_tensor(high, device=tanh_action.device, dtype=tanh_action.dtype) + return low_t + (tanh_action + 1.0) * (high_t - low_t) / 2.0 + else: + return low + (tanh_action + 1.0) * (high - low) / 2.0 + +def ensure_action( + env, + action, + refine_fn: Optional[Callable[[object, object], object]] = None, +): + """ + Ensure actions are valid for the environment. + + - Always applies the refine function (e.g., tanh_action_to_space) to map + network outputs to valid action space. + - Clips to env action space bounds if available. + + Args: + env: Environment (or wrapper) providing action_space. + action: Tensor or numpy array action. + refine_fn: Optional callable (env, action) -> action for custom refinement. + """ + try: + reasonable = env.envs[0].ensure_action_reasonable(action) + except ValueError: + logger.warning(f"Environment {env.__class__.__name__} does not support ensure_action_reasonable, using tanh_action_to_space") + # Always apply refine_fn to map network outputs to action space + if refine_fn is not None: + action = refine_fn(env, action) + + # Final clip to ensure bounds (safety measure) + action = clip_action_to_space(env, action) + + return action \ No newline at end of file diff --git a/train_rl.py b/train_rl.py new file mode 100644 index 00000000..bce66dce --- /dev/null +++ b/train_rl.py @@ -0,0 +1,500 @@ +""" +Reinforcement Learning Training Script + +This script provides a unified interface for training RL algorithms. 
+It supports: +- Multiple registered algorithms (TD3, SAC, PPO, etc.) +- Vectorized environments for parallel data collection +- Flexible configuration via YAML files and CLI overrides +- Checkpoint saving and resuming + +Usage: + python train_rl.py -a td3 -e aloha --num_envs 4 --total_steps 1000000 + python train_rl.py -a configs/rl/td3.yaml -e configs/env/custom.yaml -o ckpt/td3_experiment +""" + +import configs # Must be first to suppress TensorFlow logs +import os +import argparse +import json +import importlib +from loguru import logger +import numpy as np +import torch +from tqdm import tqdm + +from data_utils.utils import set_seed +from benchmark.utils import SequentialVectorEnv, organize_obs + + +def parse_args(): + """Parse command line arguments for RL training.""" + parser = argparse.ArgumentParser(description='Train an RL algorithm') + + # Algorithm arguments + parser.add_argument('-a', '--algorithm', type=str, default='td3', + help='Algorithm name (registered) or path to YAML config file') + + # Environment arguments + parser.add_argument('-e', '--env', type=str, default='aloha', + help='Env config (name under configs/env or absolute path to yaml)') + parser.add_argument('--num_envs', type=int, default=1, + help='Number of parallel environments') + parser.add_argument('--use_subproc', action='store_true', + help='Use SubprocVectorEnv instead of SequentialVectorEnv') + + # Training arguments + parser.add_argument('--total_steps', type=int, default=1000000, + help='Total environment steps for training') + parser.add_argument('--start_steps', type=int, default=25000, + help='Number of random exploration steps before policy is used') + parser.add_argument('--update_after', type=int, default=1000, + help='Number of steps before starting policy updates') + parser.add_argument('--update_every', type=int, default=50, + help='Update policy every N environment steps') + parser.add_argument('--batch_size', type=int, default=256, + help='Batch size for policy 
updates') + parser.add_argument('--replay_size', type=int, default=1000000, + help='Replay buffer capacity') + parser.add_argument('--expl_noise', type=float, default=0.1, + help='Exploration noise scale (can be overridden by algo config)') + + # Output and logging + parser.add_argument('-o', '--output_dir', type=str, default='ckpt/rl_training', + help='Output directory for checkpoints and logs') + parser.add_argument('--save_freq', type=int, default=50000, + help='Save checkpoint every N steps') + parser.add_argument('--eval_freq', type=int, default=10000, + help='Evaluate policy every N steps') + parser.add_argument('--eval_episodes', type=int, default=10, + help='Number of episodes for evaluation') + parser.add_argument('--log_freq', type=int, default=1000, + help='Log training stats every N steps') + + # Misc + parser.add_argument('-s', '--seed', type=int, default=0, + help='Random seed') + parser.add_argument('--device', type=str, default='cuda', + help='Device to use (cuda or cpu)') + parser.add_argument('--resume', type=str, default=None, + help='Path to checkpoint to resume from') + + args, unknown = parser.parse_known_args() + args._unknown = unknown + return args + + +def load_env_config(args): + """Load environment configuration from YAML.""" + from configs.loader import ConfigLoader + cfg_loader = ConfigLoader(args=args, unknown_args=getattr(args, '_unknown', [])) + env_cfg, env_cfg_path = cfg_loader.load_env(args.env) + return env_cfg, env_cfg_path + + +def load_algo_config(args): + """ + Load algorithm configuration from YAML. + + Priority: + 1. If path ends with .yaml/.yml, load directly + 2. If algorithm name (e.g., 'td3'), try to load from configs/rl/{name}.yaml + 3. 
If no config file found, use defaults
+    """
+    import yaml
+
+    algo_name_or_path = args.algorithm
+
+    # Check if it's a path to a YAML config file
+    if algo_name_or_path.endswith('.yaml') or algo_name_or_path.endswith('.yml'):
+        # Load from specified YAML config file
+        if os.path.exists(algo_name_or_path):
+            with open(algo_name_or_path, 'r') as f:
+                algo_cfg = yaml.safe_load(f) or {}
+            algo_cfg_path = algo_name_or_path
+        else:
+            # Try configs/rl directory
+            config_path = os.path.join('configs/rl', algo_name_or_path)
+            if os.path.exists(config_path):
+                with open(config_path, 'r') as f:
+                    algo_cfg = yaml.safe_load(f) or {}
+                algo_cfg_path = config_path
+            else:
+                raise FileNotFoundError(f"Algorithm config file not found: {algo_name_or_path}")
+
+        logger.info(f"Loaded algorithm config from: {algo_cfg_path}")
+        return algo_cfg, algo_cfg_path
+    else:
+        # Try to load from configs/rl/{algo_name}.yaml
+        config_path = os.path.join('configs/rl', f'{algo_name_or_path}.yaml')
+        if os.path.exists(config_path):
+            with open(config_path, 'r') as f:
+                algo_cfg = yaml.safe_load(f) or {}
+            logger.info(f"Loaded algorithm config from: {config_path}")
+            return algo_cfg, config_path
+        else:
+            # No config file, use defaults
+            # Returns a minimal {'type': name} dict and a None path so callers can
+            # distinguish "defaults" from a loaded config file.
+            logger.info(f"No config file found for '{algo_name_or_path}', using default parameters")
+            return {'type': algo_name_or_path}, None
+
+
+def create_env_fn(env_cfg, env_module):
+    """Create a factory function for environment creation."""
+    # Zero-argument closure: vectorized-env wrappers construct instances lazily.
+    def _create():
+        return env_module.create_env(env_cfg)
+    return _create
+
+
+def create_vector_env(env_cfg, num_envs, use_subproc=False):
+    """Create vectorized environment."""
+    # Parse env type
+    env_type = env_cfg.type
+    if '.' in env_type:
+        # NOTE(review): class_name is unpacked but never used afterwards — confirm
+        # whether a class lookup inside the imported module was intended here.
+        module_path, class_name = env_type.rsplit('.', 1)
+        env_module = importlib.import_module(module_path)
+        env_name = module_path.split('.')[-1] if '.' in module_path else module_path
+    else:
+        env_module = importlib.import_module(f"benchmark.{env_type}")
+        env_name = env_type
+
+    if not hasattr(env_module, 'create_env'):
+        raise AttributeError(f"env module {env_type} has no 'create_env'")
+
+    env_fns = [create_env_fn(env_cfg, env_module) for _ in range(num_envs)]
+
+    if use_subproc and num_envs > 1:
+        from tianshou.env import SubprocVectorEnv
+        vec_env = SubprocVectorEnv(env_fns)
+    else:
+        vec_env = SequentialVectorEnv(env_fns)
+
+    return vec_env, env_name, env_module
+
+
+def get_env_dims(env_cfg, vec_env=None):
+    """
+    Get state and action dimensions from environment config.
+
+    Priority:
+    1. Read from env_cfg (state_dim, action_dim) - preferred
+    2. Infer from vec_env if not specified in config - fallback
+
+    Args:
+        env_cfg: Environment configuration (namespace or dict)
+        vec_env: Optional vectorized environment for fallback inference
+
+    Returns:
+        (state_dim, action_dim): Tuple of dimensions
+    """
+    state_dim = None
+    action_dim = None
+
+    # 1. Try to get from env_cfg first (preferred)
+    if hasattr(env_cfg, 'state_dim'):
+        state_dim = env_cfg.state_dim
+        logger.info(f"Read state_dim={state_dim} from env config")
+    if hasattr(env_cfg, 'action_dim'):
+        action_dim = env_cfg.action_dim
+        logger.info(f"Read action_dim={action_dim} from env config")
+
+    # 2. Fallback: infer from environment if not specified
+    if (state_dim is None or action_dim is None) and vec_env is not None:
+        logger.warning("state_dim or action_dim not specified in env config, inferring from environment...")
+
+        if state_dim is None:
+            # Reset to get observation
+            obs = vec_env.reset()
+
+            # Handle different observation formats
+            if hasattr(obs, 'state'):
+                state_dim = obs.state.shape[-1] if hasattr(obs.state, 'shape') else len(obs.state)
+            elif isinstance(obs, np.ndarray):
+                if obs.ndim == 1:
+                    state_dim = obs.shape[0]
+                else:
+                    # For vectorized env, obs[0] is single env obs
+                    # NOTE(review): both branches of this conditional yield obs[0]; the
+                    # dtype check is redundant as written — confirm the intended
+                    # special-casing for object arrays.
+                    first_obs = obs[0] if obs.dtype == np.object_ else obs[0]
+                    if hasattr(first_obs, 'state'):
+                        state_dim = first_obs.state.shape[-1]
+                    else:
+                        state_dim = first_obs.shape[-1] if hasattr(first_obs, 'shape') else len(first_obs)
+            elif isinstance(obs, dict):
+                state_dim = obs.get('state', obs.get('observation')).shape[-1]
+            else:
+                # Try to get first env's observation
+                first_obs = obs[0] if hasattr(obs, '__getitem__') else obs
+                if hasattr(first_obs, 'state'):
+                    state_dim = first_obs.state.shape[-1]
+                else:
+                    state_dim = len(first_obs) if hasattr(first_obs, '__len__') else 1
+            logger.info(f"Inferred state_dim={state_dim} from environment")
+
+        if action_dim is None:
+            # Get action dimension from action_space if available
+            single_env = vec_env.envs[0] if hasattr(vec_env, 'envs') else vec_env
+            if hasattr(single_env, 'action_space'):
+                action_dim = single_env.action_space.shape[0]
+            elif hasattr(single_env, 'env') and hasattr(single_env.env, 'action_space'):
+                action_dim = single_env.env.action_space.shape[0]
+            else:
+                # Default fallback
+                logger.warning("Could not determine action_dim from env, defaulting to state_dim")
+                action_dim = state_dim
+            logger.info(f"Inferred action_dim={action_dim} from environment")
+
+    if state_dim is None or action_dim is None:
+        raise ValueError("Could not determine state_dim and action_dim. Please specify them in env config.")
+
+    return state_dim, action_dim
+
+
+def create_meta_policy(args, algo_cfg, state_dim, action_dim):
+    """
+    Create MetaPolicy for the algorithm.
+
+    Similar to how load_policy() works in policy/utils.py.
+    """
+    from benchmark.base import MetaPolicy
+    from data_utils.normalize import Identity
+
+    # Get control space and type from config or defaults
+    ctrl_space = 'ee'
+    ctrl_type = 'delta'
+
+    if algo_cfg is not None:
+        ctrl_space = algo_cfg.get('ctrl_space', ctrl_space)
+        ctrl_type = algo_cfg.get('ctrl_type', ctrl_type)
+
+    # For RL, we typically use identity normalizers (actions are already in [-1, 1])
+    action_normalizer = Identity()
+    state_normalizer = Identity()
+
+    # Store in args for later use
+    args.ctrl_space = ctrl_space
+    args.ctrl_type = ctrl_type
+
+    return {
+        'action_normalizer': action_normalizer,
+        'state_normalizer': state_normalizer,
+        'ctrl_space': ctrl_space,
+        'ctrl_type': ctrl_type,
+    }
+
+
+def create_replay_buffer(args, state_dim, action_dim, num_envs):
+    """Create replay buffer."""
+    from rl.buffer import MetaReplay
+
+    replay = MetaReplay(
+        capacity=args.replay_size,
+        state_dim=state_dim,
+        action_dim=action_dim,
+        n_envs=num_envs,
+        device='cpu',  # Store on CPU, move to GPU during training
+    )
+    return replay
+
+
+def create_algorithm(args, algo_cfg, state_dim, action_dim, replay, meta_policy_params):
+    """
+    Create RL algorithm from config dynamically.
+
+    Args:
+        args: Command line arguments
+        algo_cfg: Algorithm config dict (from YAML)
+        state_dim: State dimension
+        action_dim: Action dimension
+        replay: Replay buffer
+        meta_policy_params: Dict with normalizers and ctrl settings for MetaPolicy
+    """
+    import dataclasses
+    from rl.algorithms import get_algorithm_class, get_config_class, list_algorithms
+
+    logger.info(f"Available algorithms: {list_algorithms()}")
+
+    # Determine algorithm type from config
+    algo_args = algo_cfg if algo_cfg is not None else {}
+    algo_type = algo_args.get('type', algo_args.get('algorithm', args.algorithm))
+
+    # Get algorithm class and config class dynamically
+    AlgorithmClass = get_algorithm_class(algo_type)
+    ConfigClass = get_config_class(algo_type)
+    logger.info(f"Using algorithm: {AlgorithmClass.__name__}")
+
+    # Build config if ConfigClass is available
+    config = None
+    if ConfigClass is not None:
+        # Required parameters
+        config_params = {
+            'state_dim': state_dim,
+            'action_dim': action_dim,
+            'device': args.device,
+        }
+
+        # Get config field names and add matching parameters from algo_cfg
+        if dataclasses.is_dataclass(ConfigClass):
+            config_fields = {f.name for f in dataclasses.fields(ConfigClass) if f.name not in config_params}
+            for key in config_fields:
+                if key in algo_args:
+                    config_params[key] = algo_args[key]
+
+        config = ConfigClass(**config_params)
+
+        # Log config (dynamically get attributes)
+        log_attrs = ['discount', 'tau', 'actor_lr', 'critic_lr', 'lr']
+        log_parts = [f"{attr}={getattr(config, attr)}" for attr in log_attrs if hasattr(config, attr)]
+        if log_parts:
+            logger.info(f"{AlgorithmClass.__name__} config: {', '.join(log_parts)}")
+
+    # Get exploration noise from algo_cfg (overrides command line)
+    if 'expl_noise' in algo_args:
+        args.expl_noise = algo_args['expl_noise']
+        logger.info(f"Using expl_noise={args.expl_noise} from algo config")
+
+    # Create algorithm
+    if config is not None:
+        algorithm = AlgorithmClass(
+            replay=replay,
+            config=config,
+            meta_policy=None,
+            ctrl_space=meta_policy_params['ctrl_space'],
+            ctrl_type=meta_policy_params['ctrl_type'],
+        )
+    else:
+        # Fallback for algorithms without config class
+        excluded_keys = {'type', 'algorithm', 'ctrl_space', 'ctrl_type', 'expl_noise'}
+        filtered_args = {k: v for k, v in algo_args.items() if k not in excluded_keys}
+
+        algorithm = AlgorithmClass(
+            replay=replay,
+            state_dim=state_dim,
+            action_dim=action_dim,
+            device=args.device,
+            **filtered_args
+        )
+
+    return algorithm
+
+
+def extract_state(obs, state_key='state'):
+    """Extract state array from observation."""
+    if hasattr(obs, state_key):
+        return getattr(obs, state_key)
+    elif isinstance(obs, dict):
+        return obs.get(state_key, obs.get('observation'))
+    elif isinstance(obs, np.ndarray):
+        if obs.dtype == np.object_:
+            # Array of MetaObs
+            return np.stack([getattr(o, state_key) if hasattr(o, state_key) else o for o in obs])
+        return obs
+    return obs
+
+
+def train(args):
+    """Main training loop using OffPolicyTrainer and DummyCollector."""
+    # Set seed
+    set_seed(args.seed)
+    logger.info(f"Set random seed to {args.seed}")
+
+    # Load environment config
+    env_cfg, env_cfg_path = load_env_config(args)
+    logger.info(f"Loaded env config from: {env_cfg_path}")
+
+    # Load algorithm config (similar to how eval_sim.py loads policy)
+    algo_cfg, algo_cfg_path = load_algo_config(args)
+
+    # Create vectorized environment
+    vec_env, env_name, env_module = create_vector_env(
+        env_cfg, args.num_envs, args.use_subproc
+    )
+    logger.info(f"Created {args.num_envs} parallel environments: {env_name}")
+
+    # Sync derived values from env config (like eval_sim.py does)
+    # NOTE(review): this unconditionally overwrites any command-line max_timesteps —
+    # confirm that the env config is meant to take precedence.
+    if hasattr(env_cfg, 'max_timesteps'):
+        args.max_timesteps = env_cfg.max_timesteps
+    else:
+        args.max_timesteps = 1000  # Default
+
+    if hasattr(env_cfg, 'task'):
+        args.task = env_cfg.task
+
+    # Get environment dimensions from config (preferred) or infer from env (fallback)
+    state_dim, action_dim = get_env_dims(env_cfg, vec_env)
+    logger.info(f"State dim: {state_dim}, Action dim: {action_dim}")
+
+    # Create MetaPolicy parameters (normalizers, ctrl settings)
+    meta_policy_params = create_meta_policy(args, algo_cfg, state_dim, action_dim)
+    logger.info(f"Control space: {args.ctrl_space}, Control type: {args.ctrl_type}")
+
+    # Create replay buffer
+    replay = create_replay_buffer(args, state_dim, action_dim, args.num_envs)
+    logger.info(f"Created replay buffer with capacity {args.replay_size}")
+
+    # Create algorithm (with MetaPolicy)
+    algorithm = create_algorithm(args, algo_cfg, state_dim, action_dim, replay, meta_policy_params)
+    algorithm.set_env(vec_env)  # Bind environment for action processing
+    logger.info(f"Created algorithm: {algorithm}")
+
+    # Create Collector for environment interaction
+    from rl.collectors import DummyCollector
+    collector = DummyCollector(
+        envs=vec_env,
+        algorithm=algorithm,
+        ctrl_space=args.ctrl_space,
+        action_dim=action_dim,
+    )
+
+    # Create Trainer configuration
+    from rl.trainers.offpolicy_trainer import OffPolicyTrainer, OffPolicyTrainerConfig
+    trainer_config = OffPolicyTrainerConfig(
+        total_steps=args.total_steps,
+        start_steps=args.start_steps,
+        update_after=args.update_after,
+        update_every=args.update_every,
+        batch_size=args.batch_size,
+        log_freq=args.log_freq,
+        eval_freq=args.eval_freq,
+        eval_episodes=args.eval_episodes,
+        save_freq=args.save_freq,
+        output_dir=args.output_dir,
+        max_timesteps=args.max_timesteps,
+        ctrl_space=args.ctrl_space,
+        expl_noise=args.expl_noise,
+    )
+
+    # Create evaluation environment factory
+    # NOTE(review): this wraps create_env_fn and immediately calls its result;
+    # `eval_env_fn = create_env_fn(env_cfg, env_module)` would be equivalent.
+    def eval_env_fn():
+        return create_env_fn(env_cfg, env_module)()
+
+    # Select vector environment class based on args
+    if args.use_subproc:
+        # NOTE(review): create_vector_env imports SubprocVectorEnv from tianshou.env,
+        # but evaluation uses benchmark.utils.SubprocVectorEnv — confirm these are
+        # intentionally different implementations.
+        from benchmark.utils import SubprocVectorEnv
+        eval_vec_env_cls = SubprocVectorEnv
+    else:
+        eval_vec_env_cls = SequentialVectorEnv
+
+    # Create Trainer
+    trainer = OffPolicyTrainer(
+        algorithm=algorithm,
+        collector=collector,
+        config=trainer_config,
+        eval_env_fn=eval_env_fn,
+        eval_vec_env_cls=eval_vec_env_cls,
+    )
+
+    # Resume from checkpoint if specified
+    start_step = 0
+    if args.resume:
+        start_step = trainer.load_checkpoint(args.resume)
+
+    # Run training
+    trainer.train(resume_step=start_step)
+
+    # Cleanup
+    vec_env.close()
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    train(args)
+
diff --git a/utils/__init__.py b/utils/__init__.py
index 4adaa8f7..c93083f3 100644
--- a/utils/__init__.py
+++ b/utils/__init__.py
@@ -6,5 +6,26 @@
 # Configure logging on package import
 from .logger import logger
 
-__all__ = ['logger']
+# Exploration strategies
+from .exploration import (
+    BaseExploration,
+    NoExploration,
+    GaussianNoise,
+    OUNoise,
+    CustomNoise,
+    RandomExploration,
+    ExplorationScheduler,
+)
+
+__all__ = [
+    'logger',
+    # Exploration
+    'BaseExploration',
+    'NoExploration',
+    'GaussianNoise',
+    'OUNoise',
+    'CustomNoise',
+    'RandomExploration',
+    'ExplorationScheduler',
+]
diff --git a/utils/exploration.py b/utils/exploration.py
new file mode 100644
index 00000000..11a7d612
--- /dev/null
+++ b/utils/exploration.py
@@ -0,0 +1,415 @@
+"""
+Exploration Strategies for RL
+
+This module defines exploration strategies for action selection during training.
+Supports:
+- Random exploration (for initial exploration phase)
+- Gaussian noise
+- Ornstein-Uhlenbeck noise
+- Custom noise functions (e.g., uncertainty-based)
+
+Usage:
+    # Gaussian noise with initial random exploration
+    exploration = ExplorationScheduler(
+        exploration_strategy=GaussianNoise(sigma=0.1),
+        random_steps=10000,
+        action_low=np.array([-1.0] * 7),
+        action_high=np.array([1.0] * 7),
+    )
+
+    # Apply exploration to action
+    explored_action = exploration(action, step=current_step)
+"""
+
+import numpy as np
+from abc import ABC, abstractmethod
+from typing import Optional, Callable, Any
+
+
+class BaseExploration(ABC):
+    """
+    Base class for exploration strategies.
+
+    Exploration strategies modify actions to encourage exploration during training.
+ """ + + @abstractmethod + def __call__( + self, + action: np.ndarray, + step: int = 0, + **kwargs + ) -> np.ndarray: + """ + Apply exploration to action. + + Args: + action: Original action from policy, shape (action_dim,) or (n_envs, action_dim) + step: Current training step (for annealing) + **kwargs: Additional arguments (e.g., obs for uncertainty-based exploration) + + Returns: + Explored action with same shape as input + """ + raise NotImplementedError + + def reset(self, env_idx: Optional[int] = None) -> None: + """Reset exploration state (e.g., for OU noise).""" + pass + + +class NoExploration(BaseExploration): + """No exploration - return action as-is.""" + + def __call__(self, action: np.ndarray, step: int = 0, **kwargs) -> np.ndarray: + return action + + +class GaussianNoise(BaseExploration): + """ + Gaussian noise exploration. + + Adds zero-mean Gaussian noise to actions, with optional annealing. + Used in TD3, SAC, etc. + + Example: + noise = GaussianNoise(sigma=0.1, sigma_min=0.01, decay_steps=100000) + noisy_action = noise(action, step=current_step) + """ + + def __init__( + self, + sigma: float = 0.1, + sigma_min: float = 0.01, + sigma_decay: float = 1.0, # No decay by default + decay_steps: int = 0, # 0 means no decay + ): + """ + Args: + sigma: Initial noise standard deviation + sigma_min: Minimum noise standard deviation (for annealing) + sigma_decay: Decay factor per step (exponential decay, < 1.0 to enable) + decay_steps: Total steps for linear decay (> 0 to enable, takes priority over sigma_decay) + """ + self.sigma_init = sigma + self.sigma = sigma + self.sigma_min = sigma_min + self.sigma_decay = sigma_decay + self.decay_steps = decay_steps + + def __call__(self, action: np.ndarray, step: int = 0, **kwargs) -> np.ndarray: + # Anneal sigma + if self.decay_steps > 0: + # Linear decay + decay_ratio = max(0.0, 1.0 - step / self.decay_steps) + self.sigma = self.sigma_min + (self.sigma_init - self.sigma_min) * decay_ratio + elif self.sigma_decay 
< 1.0: + # Exponential decay + self.sigma = max(self.sigma_min, self.sigma_init * (self.sigma_decay ** step)) + + # Add Gaussian noise + noise = np.random.normal(0, self.sigma, size=action.shape) + noisy_action = action + noise + + return noisy_action.astype(action.dtype) + + def reset(self, env_idx: Optional[int] = None) -> None: + pass # Gaussian noise is stateless + + def __repr__(self) -> str: + return (f"GaussianNoise(sigma={self.sigma:.4f}, sigma_init={self.sigma_init}, " + f"sigma_min={self.sigma_min}, decay_steps={self.decay_steps})") + + +class OUNoise(BaseExploration): + """ + Ornstein-Uhlenbeck noise exploration. + + Temporally correlated noise, often used in continuous control tasks. + Used in DDPG, etc. + + The OU process: dx = theta * (mu - x) * dt + sigma * dW + + Example: + noise = OUNoise(action_dim=7, n_envs=4, sigma=0.2, theta=0.15) + noisy_action = noise(action, step=current_step) + """ + + def __init__( + self, + action_dim: int, + n_envs: int = 1, + mu: float = 0.0, + theta: float = 0.15, + sigma: float = 0.2, + sigma_min: float = 0.01, + sigma_decay: float = 1.0, + decay_steps: int = 0, + ): + """ + Args: + action_dim: Dimension of action space + n_envs: Number of parallel environments + mu: Mean of the noise (typically 0) + theta: Rate of mean reversion (how fast noise returns to mu) + sigma: Volatility of the noise + sigma_min: Minimum sigma for annealing + sigma_decay: Exponential decay factor + decay_steps: Steps for linear decay + """ + self.action_dim = action_dim + self.n_envs = n_envs + self.mu = mu + self.theta = theta + self.sigma_init = sigma + self.sigma = sigma + self.sigma_min = sigma_min + self.sigma_decay = sigma_decay + self.decay_steps = decay_steps + + # Initialize state for each environment + self._state = np.ones((n_envs, action_dim)) * mu + + def __call__(self, action: np.ndarray, step: int = 0, **kwargs) -> np.ndarray: + # Anneal sigma + if self.decay_steps > 0: + decay_ratio = max(0.0, 1.0 - step / self.decay_steps) + 
self.sigma = self.sigma_min + (self.sigma_init - self.sigma_min) * decay_ratio + elif self.sigma_decay < 1.0: + self.sigma = max(self.sigma_min, self.sigma_init * (self.sigma_decay ** step)) + + # Handle both single and batched actions + is_batched = action.ndim == 2 + if not is_batched: + action = action[np.newaxis, :] # (1, action_dim) + + batch_size = action.shape[0] + + # Update OU state: dx = theta * (mu - x) + sigma * noise + dx = (self.theta * (self.mu - self._state[:batch_size]) + + self.sigma * np.random.randn(batch_size, self.action_dim)) + self._state[:batch_size] += dx + + # Add noise to action + noisy_action = action + self._state[:batch_size] + + if not is_batched: + noisy_action = noisy_action[0] + + return noisy_action.astype(action.dtype) + + def reset(self, env_idx: Optional[int] = None) -> None: + """Reset OU state.""" + if env_idx is None: + self._state = np.ones((self.n_envs, self.action_dim)) * self.mu + else: + self._state[env_idx] = self.mu + + def __repr__(self) -> str: + return (f"OUNoise(action_dim={self.action_dim}, n_envs={self.n_envs}, " + f"theta={self.theta}, sigma={self.sigma:.4f})") + + +class CustomNoise(BaseExploration): + """ + Custom noise exploration using a user-provided function. 
+ + Allows implementing advanced exploration strategies like: + - Uncertainty-based exploration (using model epistemic uncertainty) + - Curiosity-driven exploration + - Parameter noise + - Any other custom exploration logic + + Example: + def uncertainty_noise_fn(action, step, obs=None, model=None, **kwargs): + if model is not None and obs is not None: + uncertainty = model.get_uncertainty(obs) + noise_scale = uncertainty * 0.5 + else: + noise_scale = 0.1 + noise = np.random.normal(0, noise_scale, size=action.shape) + return action + noise + + noise = CustomNoise(noise_fn=uncertainty_noise_fn) + noisy_action = noise(action, step=step, obs=obs, model=model) + """ + + def __init__( + self, + noise_fn: Callable[[np.ndarray, int, Any], np.ndarray], + reset_fn: Optional[Callable[[Optional[int]], None]] = None, + ): + """ + Args: + noise_fn: Custom noise function with signature: + noise_fn(action, step, **kwargs) -> noisy_action + - action: Original action from policy + - step: Current training step + - **kwargs: Additional info (obs, model uncertainty, etc.) + reset_fn: Optional reset function for stateful noise + """ + self.noise_fn = noise_fn + self.reset_fn = reset_fn + + def __call__(self, action: np.ndarray, step: int = 0, **kwargs) -> np.ndarray: + noisy_action = self.noise_fn(action, step, **kwargs) + return noisy_action.astype(action.dtype) + + def reset(self, env_idx: Optional[int] = None) -> None: + if self.reset_fn is not None: + self.reset_fn(env_idx) + + def __repr__(self) -> str: + return f"CustomNoise(noise_fn={self.noise_fn.__name__})" + + +class RandomExploration(BaseExploration): + """ + Random action exploration. + + Samples random actions from action space, completely ignoring policy output. + Used for initial exploration phase. 
+
+    Example:
+        random_explore = RandomExploration(
+            action_low=np.array([-1.0] * 7),
+            action_high=np.array([1.0] * 7),
+        )
+        random_action = random_explore(policy_action)  # policy_action is ignored
+    """
+
+    def __init__(
+        self,
+        action_low: np.ndarray,
+        action_high: np.ndarray,
+    ):
+        """
+        Args:
+            action_low: Lower bound of action space
+            action_high: Upper bound of action space
+        """
+        self.action_low = np.asarray(action_low)
+        self.action_high = np.asarray(action_high)
+
+    def __call__(self, action: np.ndarray, step: int = 0, **kwargs) -> np.ndarray:
+        # Completely ignore input action, sample random action
+        # (the input is used only for its shape and dtype)
+        random_action = np.random.uniform(
+            self.action_low,
+            self.action_high,
+            size=action.shape
+        )
+        return random_action.astype(action.dtype)
+
+    def __repr__(self) -> str:
+        return f"RandomExploration(action_low={self.action_low}, action_high={self.action_high})"
+
+
+class ExplorationScheduler:
+    """
+    Scheduler for switching between exploration strategies.
+
+    Supports:
+    - Initial random exploration phase (before a certain step)
+    - Transition to noise-based exploration
+    - Annealing exploration over time
+
+    Example:
+        # Create scheduler with 10000 random steps, then Gaussian noise
+        exploration = ExplorationScheduler(
+            exploration_strategy=GaussianNoise(sigma=0.1, decay_steps=100000),
+            random_steps=10000,
+            action_low=np.array([-1.0] * 7),
+            action_high=np.array([1.0] * 7),
+        )
+
+        # In training loop:
+        for step in range(total_steps):
+            action = policy(obs)
+            explored_action = exploration(action, step=step)
+            # or use internal counter:
+            explored_action = exploration(action)
+    """
+
+    def __init__(
+        self,
+        exploration_strategy: BaseExploration,
+        random_steps: int = 0,
+        action_low: Optional[np.ndarray] = None,
+        action_high: Optional[np.ndarray] = None,
+    ):
+        """
+        Args:
+            exploration_strategy: Main exploration strategy (e.g., GaussianNoise, OUNoise)
+            random_steps: Number of initial steps to use pure random exploration
+                - Set to 0 to disable random exploration phase
+            action_low: Lower bound for random exploration (required if random_steps > 0)
+            action_high: Upper bound for random exploration (required if random_steps > 0)
+        """
+        self.exploration_strategy = exploration_strategy
+        self.random_steps = random_steps
+        self.action_low = action_low
+        self.action_high = action_high
+
+        # Create random exploration for initial phase
+        if random_steps > 0:
+            if action_low is None or action_high is None:
+                raise ValueError("action_low and action_high required when random_steps > 0")
+            self.random_exploration = RandomExploration(action_low, action_high)
+        else:
+            self.random_exploration = None
+
+        self._current_step = 0
+
+    def __call__(
+        self,
+        action: np.ndarray,
+        step: Optional[int] = None,
+        **kwargs
+    ) -> np.ndarray:
+        """
+        Apply exploration based on current step.
+
+        Args:
+            action: Original action from policy
+            step: Optional step override (if None, uses internal counter)
+            **kwargs: Additional arguments for exploration strategy
+
+        Returns:
+            Explored action
+        """
+        # The internal counter advances ONLY when `step` is omitted; mixing
+        # explicit-step and counter-based calls will desynchronize the schedule.
+        if step is None:
+            step = self._current_step
+            self._current_step += 1
+
+        # Initial random exploration phase
+        if step < self.random_steps and self.random_exploration is not None:
+            return self.random_exploration(action, step, **kwargs)
+
+        # Main exploration strategy (pass adjusted step for proper annealing)
+        return self.exploration_strategy(action, step - self.random_steps, **kwargs)
+
+    def reset(self, env_idx: Optional[int] = None) -> None:
+        """Reset exploration state."""
+        # Delegates to both phases; does NOT reset the internal step counter
+        # (use reset_step_counter for that).
+        self.exploration_strategy.reset(env_idx)
+        if self.random_exploration is not None:
+            self.random_exploration.reset(env_idx)
+
+    def reset_step_counter(self) -> None:
+        """Reset internal step counter."""
+        self._current_step = 0
+
+    @property
+    def is_random_phase(self) -> bool:
+        """Check if still in random exploration phase."""
+        return self._current_step < self.random_steps
+
+    @property
+    def current_step(self) -> int:
+        """Get current step count."""
+        return self._current_step
+
+    def __repr__(self) -> str:
+        return (f"ExplorationScheduler(strategy={self.exploration_strategy}, "
+                f"random_steps={self.random_steps}, current_step={self._current_step})")
+