From cbe01d51f194ae64335c5cf485275d45755c2e16 Mon Sep 17 00:00:00 2001 From: yesen-chen <840419490@qq.com> Date: Mon, 26 Jan 2026 17:04:44 +0800 Subject: [PATCH 1/6] [feat] add rl framework --- docs/rl_framework_integrated_design.md | 1674 ++++++++++++++++++++++++ rl/__init__.py | 172 +++ rl/algorithms/__init__.py | 79 ++ rl/base.py | 431 ++++++ rl/buffer/__init__.py | 15 + rl/buffer/base_replay.py | 304 +++++ rl/collectors/__init__.py | 84 ++ rl/collectors/base_collector.py | 439 +++++++ rl/infra/__init__.py | 56 + rl/infra/callback.py | 624 +++++++++ rl/infra/checkpoint.py | 468 +++++++ rl/infra/distributed.py | 487 +++++++ rl/infra/logger.py | 593 +++++++++ rl/infra/seed_manager.py | 357 +++++ rl/rewards/__init__.py | 86 ++ rl/rewards/base_reward.py | 346 +++++ rl/trainers/__init__.py | 85 ++ rl/trainers/base_trainer.py | 587 +++++++++ rl/utils/__init__.py | 325 +++++ 19 files changed, 7212 insertions(+) create mode 100644 docs/rl_framework_integrated_design.md create mode 100644 rl/__init__.py create mode 100644 rl/algorithms/__init__.py create mode 100644 rl/base.py create mode 100644 rl/buffer/__init__.py create mode 100644 rl/buffer/base_replay.py create mode 100644 rl/collectors/__init__.py create mode 100644 rl/collectors/base_collector.py create mode 100644 rl/infra/__init__.py create mode 100644 rl/infra/callback.py create mode 100644 rl/infra/checkpoint.py create mode 100644 rl/infra/distributed.py create mode 100644 rl/infra/logger.py create mode 100644 rl/infra/seed_manager.py create mode 100644 rl/rewards/__init__.py create mode 100644 rl/rewards/base_reward.py create mode 100644 rl/trainers/__init__.py create mode 100644 rl/trainers/base_trainer.py create mode 100644 rl/utils/__init__.py diff --git a/docs/rl_framework_integrated_design.md b/docs/rl_framework_integrated_design.md new file mode 100644 index 00000000..bacb1981 --- /dev/null +++ b/docs/rl_framework_integrated_design.md @@ -0,0 +1,1674 @@ +# ILStudio RL框架集成设计 + +## 设计原则 + +1. **足够抽象**:只定义核心接口,不涉及具体实现 +2. **通用性**:支持各种RL算法和训练模式 +3. **兼容性**:直接使用MetaEnv和MetaPolicy,无需适配 +4. **扩展性**:易于扩展支持并行训练和分布式训练 +5. **模块化**:奖励函数、配置系统独立可替换 + +--- + +## 目录结构 + +``` +ILStudio/ +├── rl/ # 强化学习模块 +│ ├── __init__.py +│ ├── base.py # RL 基类定义 +│ ├── algorithms/ # RL 算法实现 +│ │ ├── __init__.py +│ │ ├── ppo.py +│ │ ├── sac.py +│ │ ├── td3.py +│ │ ├── dpo.py # Direct Preference Optimization (VLA) +│ │ ├── grpo.py # Group Relative Policy Optimization (VLA) +│ │ └── reinforce.py +│ ├── buffer/ # 经验回放 +│ │ ├── __init__.py +│ │ ├── base_replay.py # BaseReplay基类 +│ │ ├── memory_replay.py # 内存存储实现 +│ │ ├── rollout_buffer.py # On-policy Rollout Buffer +│ │ └── priority_buffer.py # 优先级采样Buffer +│ ├── rewards/ # 奖励函数模块 +│ │ ├── __init__.py +│ │ ├── base_reward.py # 奖励函数基类 +│ │ ├── sparse_reward.py # 稀疏奖励 +│ │ ├── dense_reward.py # 密集奖励 +│ │ ├── learned_reward.py # 学习的奖励模型 +│ │ └── language_reward.py # 语言条件奖励(VLA) +│ ├── collectors/ # 数据收集器模块(新增) +│ │ ├── __init__.py +│ │ ├── base_collector.py # Collector基类 +│ │ ├── sim_collector.py # 仿真环境收集器 +│ │ └── real_collector.py # 真实机器人收集器 +│ ├── trainers/ # 训练器实现 +│ │ ├── __init__.py +│ │ ├── base_trainer.py # BaseTrainer基类 +│ │ ├── simple_trainer.py # 单机训练器 +│ │ ├── parallel_trainer.py # 并行环境训练器 +│ │ └── distributed_trainer.py # 分布式训练器 +│ └── utils/ # 工具函数 +│ ├── __init__.py +│ └── data_processor.py # 数据处理器(对齐ILStudio pipeline) +├── configs/ +│ ├── rl/ # RL 配置(新增) +│ │ ├── ppo.yaml +│ │ ├── sac.yaml +│ │ ├── dpo.yaml +│ │ └── grpo.yaml +│ └── ... +├── train_rl.py # RL 训练入口(新增) +└── ... +``` + +--- + +## 核心基类设计 + +### 1. Replay基类 (`BaseReplay`) + +**职责:** 存储和管理经验数据(transitions) + +**设计理念:** +- **存储原始Meta数据**:在buffer中存储原始的MetaObs和MetaAction,保持数据的原始性 +- **完整信息存储**:支持存储MetaObs和MetaAction的所有字段(state、image、raw_lang、state_ee、state_joint、depth、pc等) +- **扩展性**:支持存储额外的自定义字段(value、log_prob、advantage、trajectory_id等),方便后续扩展 +- **采样时转换**:采样时通过转换函数对齐ILStudio的data pipeline(normalization等) +- **兼容性**:既保持数据的原始性,又兼容ILStudio的normalization pipeline + +**核心接口:** + +```python +class BaseReplay: + """Replay Buffer基类""" + + def __init__( + self, + capacity: int = 1000000, + device: Union[str, torch.device] = 'cpu', + **kwargs + ): + """ + 初始化Replay Buffer + + Args: + capacity: Buffer容量(最大存储的transition数量) + device: 数据存储设备('cpu'或'cuda',默认'cpu') + - 'cpu': 数据存储在CPU内存中 + - 'cuda'或'cuda:0': 数据存储在GPU内存中 + **kwargs: 其他初始化参数 + """ + self.capacity = capacity + self.device = torch.device(device) if isinstance(device, str) else device + + def add(self, transition: Dict[str, Any]) -> None: + """ + 添加一个transition到buffer + + 存储原始Meta数据(MetaObs、MetaAction),不进行任何normalization + 支持存储MetaObs和MetaAction的所有字段,以及额外的自定义信息 + + Args: + transition: 包含以下字段的字典: + - state: MetaObs格式的当前状态(原始数据,包含所有字段) + - action: MetaAction格式的动作(原始数据,包含所有字段) + - reward: float奖励 + - next_state: MetaObs格式的下一个状态(原始数据) + - done: bool是否结束 + - info: 可选,额外信息字典 + - **其他自定义字段**: 可以存储任何额外的信息 + """ + raise NotImplementedError + + def sample(self, batch_size: int) -> Dict[str, Any]: + """ + 从buffer中采样一个batch(原始数据) + + Args: + batch_size: batch大小 + + Returns: + 包含原始Meta数据的字典(未经过normalization) + """ + raise NotImplementedError + + def sample_for_training( + self, + batch_size: int, + data_processor: Optional[Callable] = None + ) -> Dict[str, Any]: + """ + 采样并转换为ILStudio训练格式 + + Args: + batch_size: batch大小 + data_processor: 可选的数据处理函数,用于对齐ILStudio pipeline + - 如果为None,返回原始数据 + - 如果提供,应该是一个函数:batch -> processed_batch + + Returns: + 处理后的batch数据(符合ILStudio训练格式) + """ + batch = self.sample(batch_size) + if data_processor is not None: + batch = data_processor(batch) + return batch + + def __len__(self) -> int: + """返回buffer当前大小""" + raise NotImplementedError + + def clear(self) -> None: + """清空buffer""" + raise NotImplementedError + + def save(self, path: str, **kwargs) -> None: + """ + 保存buffer中的数据到文件 + + Args: + path: 保存路径(可以是文件路径或目录路径) + **kwargs: 保存选项 + - format: 保存格式(如'pkl'、'hdf5'、'npz'等,可选) + - compress: 是否压缩(可选) + """ + raise NotImplementedError + + def load(self, path: str, **kwargs) -> None: + """ + 从文件加载数据到buffer + + Args: + path: 加载路径(可以是文件路径或目录路径) + **kwargs: 加载选项 + - format: 加载格式(可选,可自动推断) + - append: 是否追加到现有buffer(默认False,清空后加载) + """ + raise NotImplementedError +``` + +--- + +### 2. 算法基类 (`BaseAlgorithm`) + +**职责:** 定义RL算法的核心逻辑 + +**设计理念:** 参考SKRL的设计,将replay放到算法中,这样: +- 每个算法可以有自己的replay配置 +- 一个Trainer可以训练多个不同的算法(每个算法有自己的replay) +- 更灵活,支持多智能体场景 + +**核心接口:** + +```python +class BaseAlgorithm: + """RL算法基类""" + + def __init__( + self, + meta_policy: MetaPolicy, + replay: Optional[Union[BaseReplay, Dict[str, BaseReplay]]] = None, + **kwargs + ): + """ + 初始化算法 + + Args: + meta_policy: ILStudio的MetaPolicy实例(必需属性) + replay: 支持两种格式: + - BaseReplay实例:单个replay buffer(所有环境共享) + - Dict[str, BaseReplay]:多个replay buffer(按环境类型分离) + - None:不使用replay buffer(on-policy算法) + **kwargs: 算法特定的参数 + """ + self.meta_policy = meta_policy # 必需属性 + self.replay = replay # 可选属性(off-policy算法需要) + + def update(self, batch: Optional[Dict[str, Any]] = None, **kwargs) -> Dict[str, Any]: + """ + 使用一个batch的数据更新策略 + + Args: + batch: 可选,batch数据 + - 如果为None且replay存在,从replay采样 + - 如果提供,直接使用提供的batch + **kwargs: 更新参数,可以包括: + - batch_size: 从replay采样时的batch大小 + - env_types: 指定从哪些环境类型的replay采样(当replay是Dict[str, BaseReplay]时) + - 例如:['indoor', 'outdoor'] + - 如果为None,从所有环境类型的replay采样 + - env_weights: 不同环境类型的数据权重(当使用多个环境类型的replay时) + - 例如:{'indoor': 0.6, 'outdoor': 0.4} + - 如果为None,使用均匀权重 + + Returns: + 包含loss、metrics等信息的字典 + + 示例: + # 场景1:单个replay buffer + algorithm.update(batch_size=256) + + # 场景2:多个replay buffer(按环境类型分离) + algorithm.update( + batch_size=256, + env_types=['indoor', 'outdoor'], + env_weights={'indoor': 0.6, 'outdoor': 0.4} + ) + """ + raise NotImplementedError + + def compute_loss(self, batch: Dict[str, Any]) -> torch.Tensor: + """ + 计算损失(可选,某些算法可能需要) + + Args: + batch: batch数据 + + Returns: + 损失值 + """ + raise NotImplementedError + + def select_action(self, obs: MetaObs, **kwargs) -> MetaAction: + """ + 选择动作(可选,某些算法可能需要) + + Args: + obs: MetaObs格式的observation + **kwargs: 其他参数(如exploration等) + + Returns: + MetaAction格式的动作 + """ + # 默认使用meta_policy的select_action + return self.meta_policy.select_action(obs, **kwargs) + + def record_transition( + self, + state: MetaObs, + action: MetaAction, + reward: float, + next_state: MetaObs, + done: bool, + info: Optional[Dict[str, Any]] = None, + **kwargs + ) -> None: + """ + 记录transition到replay buffer(如果存在) + + 支持存储完整的MetaObs和MetaAction信息,以及额外的自定义字段 + 如果使用多个replay buffer(按环境类型分离),会根据kwargs中的env_type选择对应的replay + + Args: + state: 当前状态(MetaObs,包含所有字段) + action: 动作(MetaAction,包含所有字段) + reward: 奖励 + next_state: 下一个状态(MetaObs,包含所有字段) + done: 是否结束 + info: 额外信息字典 + **kwargs: 其他自定义字段,可以存储任何额外信息 + - env_type: 环境类型标识(如果replay是Dict[str, BaseReplay]) + - 例如:value、log_prob、advantage、trajectory_id等 + """ + if self.replay is not None: + transition = { + 'state': state, + 'action': action, + 'reward': reward, + 'next_state': next_state, + 'done': done, + 'info': info or {}, + **kwargs + } + + env_type = kwargs.get('env_type', None) + + if isinstance(self.replay, dict): + if env_type is None: + raise ValueError("env_type must be provided when using multiple replay buffers") + if env_type not in self.replay: + raise ValueError(f"env_type '{env_type}' not found in replay buffers") + self.replay[env_type].add(transition) + else: + self.replay.add(transition) +``` + +--- + +### 3. 奖励函数基类 (`BaseReward`) + +**职责:** 定义奖励函数的接口 + +**设计理念:** +- **模块化**:奖励函数独立模块,易于替换和扩展 +- **可组合**:支持多个奖励函数组合使用 +- **语言条件**:支持VLA模型的语言条件奖励 + +**核心接口:** + +```python +class BaseReward: + """奖励函数基类""" + + def __init__(self, **kwargs): + """ + 初始化奖励函数 + + Args: + **kwargs: 奖励函数特定的参数 + """ + pass + + def compute( + self, + state: MetaObs, + action: MetaAction, + next_state: MetaObs, + env_reward: float, + info: Optional[Dict[str, Any]] = None + ) -> float: + """ + 计算奖励 + + Args: + state: 当前状态(MetaObs) + action: 动作(MetaAction) + next_state: 下一个状态(MetaObs) + env_reward: 环境原始奖励 + info: 额外信息字典 + + Returns: + 计算后的奖励值 + """ + raise NotImplementedError + + def reset(self, **kwargs) -> None: + """ + 重置奖励函数状态(如果需要) + + Args: + **kwargs: 重置参数 + """ + pass +``` + +**具体实现示例:** + +```python +# rl/rewards/sparse_reward.py +class SparseReward(BaseReward): + """稀疏奖励:只在任务完成时给予奖励""" + + def __init__(self, success_reward: float = 1.0, **kwargs): + super().__init__(**kwargs) + self.success_reward = success_reward + + def compute(self, state, action, next_state, env_reward, info): + if info and info.get('success', False): + return self.success_reward + return 0.0 + +# rl/rewards/dense_reward.py +class DenseReward(BaseReward): + """密集奖励:基于状态距离的奖励""" + + def __init__(self, goal_key: str = 'goal', distance_scale: float = 1.0, **kwargs): + super().__init__(**kwargs) + self.goal_key = goal_key + self.distance_scale = distance_scale + + def compute(self, state, action, next_state, env_reward, info): + # 计算到目标的距离 + if hasattr(state, 'state') and hasattr(next_state, 'state'): + goal = info.get(self.goal_key, None) + if goal is not None: + prev_dist = np.linalg.norm(state.state - goal) + curr_dist = np.linalg.norm(next_state.state - goal) + progress = (prev_dist - curr_dist) * self.distance_scale + return env_reward + progress + return env_reward + +# rl/rewards/language_reward.py +class LanguageReward(BaseReward): + """语言条件奖励:基于语言指令的奖励(VLA)""" + + def __init__(self, reward_model=None, **kwargs): + super().__init__(**kwargs) + self.reward_model = reward_model + + def compute(self, state, action, next_state, env_reward, info): + if self.reward_model is None: + return env_reward + + # 使用奖励模型计算语言条件奖励 + if hasattr(state, 'raw_lang') and state.raw_lang is not None: + lang_reward = self.reward_model.compute( + state=state, + action=action, + next_state=next_state, + instruction=state.raw_lang + ) + return env_reward + lang_reward + return env_reward +``` + +--- + +### 4. 数据收集器基类 (`BaseCollector`) + +**职责:** 从环境中收集交互数据并存储到replay buffer + +**设计理念:** +- **职责分离**:将数据收集逻辑从trainer中分离,使trainer更专注于训练循环协调 +- **环境抽象**:支持单个环境、并行环境、多环境类型 +- **原始数据存储**:只存储环境原始奖励,不进行奖励函数计算,保证数据的原始性 +- **统计信息**:收集并返回episode统计信息 + +**核心接口:** + +```python +class BaseCollector: + """数据收集器基类""" + + def __init__( + self, + meta_envs: Union[MetaEnv, List[MetaEnv], Callable, Dict[str, Any]], + algorithm: BaseAlgorithm, + **kwargs + ): + """ + 初始化收集器 + + Args: + meta_envs: 支持多种格式: + - MetaEnv实例:单个环境 + - List[MetaEnv]:环境列表(同类型环境) + - Callable:环境工厂函数 + - Dict[str, Any]:多环境配置字典(支持不同类型环境) + algorithm: BaseAlgorithm实例(必需属性) + - 用于选择动作和记录transition + **kwargs: 收集器特定的参数 + + 注意:Collector只存储环境原始奖励,不进行奖励函数计算 + 奖励函数在trainer中用于训练时的奖励计算 + """ + self.meta_envs = meta_envs + self.algorithm = algorithm + + def collect(self, n_steps: int, env_type: Optional[str] = None) -> Dict[str, Any]: + """ + 收集 n_steps 的交互数据 + + Args: + n_steps: 收集的步数 + env_type: 可选,环境类型标识(用于多环境场景) + - 如果提供,会在record_transition时传入env_type + - 用于支持单个算法在多个不同环境的数据存储 + + Returns: + 包含统计信息的字典,如: + - episode_rewards: episode奖励列表 + - episode_lengths: episode长度列表 + - total_steps: 总步数 + - env_type_stats: 按环境类型分组的统计信息(如果支持多环境) + """ + raise NotImplementedError + + def reset(self, **kwargs) -> None: + """ + 重置收集器状态(如重置环境) + + Args: + **kwargs: 重置参数 + """ + raise NotImplementedError +``` + +**具体实现示例:** + +```python +# rl/collectors/sim_collector.py +class SimCollector(BaseCollector): + """仿真环境数据收集器""" + + def __init__( + self, + meta_envs: Union[MetaEnv, List[MetaEnv]], + algorithm: BaseAlgorithm, + n_envs: int = 1, + **kwargs + ): + super().__init__(meta_envs, algorithm, **kwargs) + self.n_envs = n_envs + + # 初始化环境 + if isinstance(meta_envs, list): + self.envs = meta_envs + elif isinstance(meta_envs, MetaEnv): + self.envs = [meta_envs] + else: + raise ValueError(f"Unsupported env type: {type(meta_envs)}") + + self._last_obs = None + self._last_dones = None + + def reset(self, **kwargs) -> None: + """重置所有环境""" + self._last_obs = [] + self._last_dones = [] + for env in self.envs: + obs = env.reset() + self._last_obs.append(obs) + self._last_dones.append(False) + + def collect(self, n_steps: int, env_type: Optional[str] = None) -> Dict[str, Any]: + """ + 收集 n_steps 的交互数据 + + Args: + n_steps: 收集的步数 + env_type: 可选,环境类型标识(用于多环境场景) + - 如果提供,会在record_transition时传入env_type + - 用于支持单个算法在多个不同环境的数据存储 + + Returns: + 统计信息字典 + """ + if self._last_obs is None: + self.reset() + + stats = { + 'episode_rewards': [], + 'episode_lengths': [], + 'total_steps': 0 + } + + for step in range(n_steps): + # 获取动作 + actions = [] + for i, obs in enumerate(self._last_obs): + if not self._last_dones[i]: + with torch.no_grad(): + action = self.algorithm.select_action(obs) + actions.append(action) + else: + # 如果环境已结束,使用dummy action + actions.append(None) + + # 环境交互 + new_obs_list = [] + rewards = [] + dones = [] + infos = [] + + for i, (env, action) in enumerate(zip(self.envs, actions)): + if action is not None: + new_obs, reward, done, info = env.step(action) + + # 记录transition(只存储环境原始奖励,不进行奖励函数计算) + # 如果提供了env_type,会在record_transition时传入,用于多环境数据分离 + transition_kwargs = {} + if env_type is not None: + transition_kwargs['env_type'] = env_type + + self.algorithm.record_transition( + state=self._last_obs[i], + action=action, + reward=reward, # 存储原始奖励 + next_state=new_obs, + done=done, + info=info, + **transition_kwargs # 传入env_type等额外信息 + ) + + new_obs_list.append(new_obs) + rewards.append(reward) # 统计使用原始奖励 + dones.append(done) + infos.append(info) + + # 统计episode信息 + if done and 'episode' in info: + stats['episode_rewards'].append(info['episode'].get('r', 0)) + stats['episode_lengths'].append(info['episode'].get('l', 0)) + + # 如果episode结束,重置环境 + if done: + new_obs = env.reset() + new_obs_list[i] = new_obs + else: + new_obs_list.append(self._last_obs[i]) + rewards.append(0) + dones.append(True) + infos.append({}) + + self._last_obs = new_obs_list + self._last_dones = dones + stats['total_steps'] += len([d for d in dones if not d]) + + return stats + +# rl/collectors/real_collector.py +class RealCollector(BaseCollector): + """真实机器人数据收集器""" + + def __init__( + self, + meta_envs: MetaEnv, # 真实机器人环境 + algorithm: BaseAlgorithm, + **kwargs + ): + super().__init__(meta_envs, algorithm, **kwargs) + # 真实机器人特定的初始化... + + def collect(self, n_steps: int) -> Dict[str, Any]: + """ + 从真实机器人收集数据 + + 注意:真实机器人收集可能需要特殊的安全检查和限制 + """ + # 实现真实机器人数据收集逻辑 + # 可能需要添加安全限制、速度限制等 + pass +``` + +--- + +### 5. 训练器基类 (`BaseTrainer`) + +**职责:** 协调环境、策略和算法,执行训练循环 + +**核心接口:** + +```python +class BaseTrainer: + """RL训练器基类""" + + def __init__( + self, + meta_envs: Union[MetaEnv, List[MetaEnv], Callable, Dict[str, Any]], + algorithm: Union[BaseAlgorithm, List[BaseAlgorithm]], + collector: Optional[BaseCollector] = None, + reward_fn: Optional[Union[BaseReward, Callable]] = None, + **kwargs + ): + """ + 初始化训练器 + + Args: + meta_envs: 支持多种格式: + - MetaEnv实例:单个环境 + - List[MetaEnv]:环境列表(同类型环境) + - Callable:环境工厂函数 + - Dict[str, Any]:多环境配置字典(支持不同类型环境) + algorithm: BaseAlgorithm实例或BaseAlgorithm列表(必需属性) + - 单个算法:单个智能体训练 + - 算法列表:多个算法在同一个环境中独立训练(每个算法有自己的replay buffer) + collector: 可选的数据收集器(如果为None,trainer会创建默认collector) + - 单个算法:单个collector + - 多个算法:可以是collector列表,每个算法对应一个collector + - 如果为None,trainer会为每个算法创建默认collector + reward_fn: 可选的奖励函数(如果为None,使用环境原始奖励) + - 可以是BaseReward实例或Callable函数 + - 在训练时用于计算奖励(用于算法更新) + - 注意:replay buffer中存储的是原始奖励,奖励函数只在训练时应用 + **kwargs: 训练器特定的参数 + """ + self.meta_envs = meta_envs + self.algorithm = algorithm + self.reward_fn = reward_fn + + # 处理collector初始化 + if collector is None: + from rl.collectors import SimCollector + if isinstance(algorithm, BaseAlgorithm): + # 单个算法:创建单个collector + collector = SimCollector( + meta_envs=meta_envs, + algorithm=algorithm + ) + else: + # 多个算法:为每个算法创建collector + collector = [ + SimCollector( + meta_envs=meta_envs, + algorithm=alg + ) + for alg in algorithm + ] + + self.collector = collector + + def train(self, **kwargs) -> None: + """ + 执行训练循环 + + Args: + **kwargs: 训练参数,可以包括: + - total_steps: 总训练步数(可选) + - total_episodes: 总episode数(可选) + - max_time: 最大训练时间(可选) + - log_interval: 日志记录间隔(可选) + - save_interval: 模型保存间隔(可选) + - eval_interval: 评估间隔(可选) + """ + raise NotImplementedError + + def compute_reward( + self, + state: MetaObs, + action: MetaAction, + next_state: MetaObs, + env_reward: float, + info: Optional[Dict[str, Any]] = None + ) -> float: + """ + 计算奖励(支持自定义奖励函数) + + 在训练时使用,用于算法更新时的奖励计算 + replay buffer中存储的是原始奖励,奖励函数只在训练时应用 + + Args: + state: 当前状态 + action: 动作 + next_state: 下一个状态 + env_reward: 环境原始奖励 + info: 额外信息字典 + + Returns: + 计算后的奖励值 + """ + if self.reward_fn is not None: + if isinstance(self.reward_fn, BaseReward): + return self.reward_fn.compute(state, action, next_state, env_reward, info) + else: + return self.reward_fn(state, action, next_state, env_reward, info) + return env_reward + + def collect_rollout(self, n_steps: int, env_type: Optional[str] = None) -> Union[Dict[str, Any], List[Dict[str, Any]]]: + """ + 收集rollout数据(使用collector) + + Args: + n_steps: 收集步数 + env_type: 可选,环境类型标识(用于多环境场景) + - 用于支持单个算法在多个不同环境的数据存储 + + Returns: + - 单个算法:rollout统计信息字典 + - 多个算法:rollout统计信息字典列表 + """ + if isinstance(self.collector, list): + # 多个算法:每个算法独立收集数据 + return [col.collect(n_steps, env_type=env_type) for col in self.collector] + else: + # 单个算法 + return self.collector.collect(n_steps, env_type=env_type) + + def evaluate( + self, + n_episodes: int = 10, + render: bool = False, + env_type: Optional[str] = None, + **kwargs + ) -> Dict[str, Any]: + """ + 评估策略性能 + + Args: + n_episodes: 评估的episode数量 + render: 是否渲染环境(可选) + env_type: 可选,指定评估的环境类型 + **kwargs: 其他评估参数 + + Returns: + 包含评估指标的字典 + """ + raise NotImplementedError + + def save(self, path: str) -> None: + """保存模型和训练状态""" + raise NotImplementedError + + def load(self, path: str) -> None: + """加载模型和训练状态""" + raise NotImplementedError +``` + +--- + +## 配置系统 + +### 配置结构 + +RL配置遵循ILStudio的配置系统设计,使用YAML格式,支持命令行覆盖。 + +#### `configs/rl/ppo.yaml` + +```yaml +name: ppo +type: rl.algorithms.ppo.PPOAlgorithm + +# 算法参数 +algorithm: + gamma: 0.99 # 折扣因子 + gae_lambda: 0.95 # GAE lambda + clip_range: 0.2 # PPO clip range + value_coef: 0.5 # Value loss 系数 + entropy_coef: 0.01 # 熵正则化系数 + max_grad_norm: 0.5 # 梯度裁剪 + n_steps: 2048 # 每次更新的步数 + batch_size: 64 # Mini-batch 大小 + n_epochs: 10 # 每次更新的 epoch 数 + learning_rate: 3e-4 + +# Replay Buffer配置(on-policy算法使用RolloutBuffer) +replay: + type: rl.buffer.rollout_buffer.RolloutBuffer + capacity: 2048 + device: cpu + n_envs: 8 + gae_lambda: 0.95 + gamma: 0.99 + +# 奖励函数配置(可选,在trainer中使用,用于训练时的奖励计算) +reward: + type: rl.rewards.dense_reward.DenseReward + goal_key: goal + distance_scale: 1.0 + # 或者使用组合奖励 + # type: rl.rewards.composite_reward.CompositeReward + # components: + # - type: rl.rewards.sparse_reward.SparseReward + # success_reward: 1.0 + # - type: rl.rewards.dense_reward.DenseReward + # distance_scale: 0.1 + +# 数据收集器配置(可选,trainer会创建默认collector) +collector: + type: rl.collectors.sim_collector.SimCollector + n_envs: 8 + +# 策略网络配置(复用现有 policy 配置) +policy: + type: policy.act # 或 policy.diffusion_policy, policy.openvla 等 + # 继承对应 policy 的配置... + +# 价值网络配置(可选,Actor-Critic算法需要) +value_network: + type: mlp + hidden_dims: [256, 256] + activation: relu + +# 训练配置 +training: + total_timesteps: 1000000 + save_freq: 10000 + eval_freq: 5000 + log_interval: 100 +``` + +#### `configs/rl/sac.yaml` + +```yaml +name: sac +type: rl.algorithms.sac.SACAlgorithm + +# 算法参数 +algorithm: + gamma: 0.99 + tau: 0.005 # Soft update coefficient + learning_rate: 3e-4 + batch_size: 256 + target_update_interval: 1 + alpha: 0.2 # Temperature parameter (auto-tuned if None) + +# Replay Buffer配置(off-policy算法使用ReplayBuffer) +replay: + type: rl.buffer.memory_replay.MemoryReplay + capacity: 1000000 + device: cpu + prioritized: false # 是否使用优先级采样 + +# 奖励函数配置(可选) +reward: + type: rl.rewards.dense_reward.DenseReward + goal_key: goal + distance_scale: 1.0 + +# 策略网络配置 +policy: + type: policy.act + +# Q网络配置 +q_network: + type: mlp + hidden_dims: [256, 256] + activation: relu + +# 训练配置 +training: + total_timesteps: 1000000 + save_freq: 10000 + eval_freq: 5000 + log_interval: 100 +``` + +#### `configs/rl/dpo.yaml` (VLA Fine-tuning) + +```yaml +name: dpo +type: rl.algorithms.dpo.DPOAlgorithm + +# 算法参数 +algorithm: + beta: 0.1 # KL penalty coefficient + learning_rate: 1e-5 + batch_size: 32 + reference_free: false # 是否使用无参考模型的 DPO + label_smoothing: 0.0 + +# 策略网络配置(VLA模型) +policy: + type: policy.openvla + # 继承对应 policy 的配置... + +# 参考策略配置(用于KL散度计算) +reference_policy: + type: policy.openvla + # 通常是预训练模型的frozen副本 + +# 奖励函数配置(语言条件奖励) +reward: + type: rl.rewards.language_reward.LanguageReward + reward_model: + type: learned_reward + checkpoint: ckpt/reward_model.pth + +# 训练配置 +training: + total_episodes: 1000 + save_freq: 100 + eval_freq: 50 + log_interval: 10 +``` + +### 配置加载 + +```python +# 在ConfigLoader中添加RL配置加载方法 +class ConfigLoader: + # ... 现有方法 ... + + def load_rl(self, name_or_path: str) -> Tuple[Dict[str, Any], str]: + """ + 加载RL配置 + + Args: + name_or_path: RL配置名称或路径(如'ppo'或'rl/ppo') + + Returns: + (配置字典, 配置文件路径) + """ + return self.load_yaml_config('rl', name_or_path) +``` + +--- + +## VLA 强化学习特殊处理 + +### DPO (Direct Preference Optimization) + +```python +# rl/algorithms/dpo.py +from dataclasses import dataclass +import torch +import torch.nn.functional as F +from rl.base import BaseAlgorithm +from benchmark.base import MetaPolicy + +@dataclass +class DPOConfig: + """DPO 特定配置""" + beta: float = 0.1 + reference_free: bool = False + label_smoothing: float = 0.0 + learning_rate: float = 1e-5 + batch_size: int = 32 + +class DPOAlgorithm(BaseAlgorithm): + """ + Direct Preference Optimization for VLA + + 适用于: + - 有偏好数据 (chosen vs rejected trajectories) + - Fine-tuning 预训练 VLA 模型 + """ + + def __init__( + self, + meta_policy: MetaPolicy, + ref_policy: MetaPolicy, # 参考策略 (frozen) + config: DPOConfig, + **kwargs + ): + super().__init__(meta_policy=meta_policy, **kwargs) + self.ref_policy = ref_policy + self.config = config + + # 冻结参考策略 + self.ref_policy.eval() + for p in self.ref_policy.parameters(): + p.requires_grad = False + + def compute_dpo_loss( + self, + chosen_obs: MetaObs, + chosen_actions: MetaAction, + rejected_obs: MetaObs, + rejected_actions: MetaAction + ) -> Dict[str, torch.Tensor]: + """ + 计算 DPO 损失 + + L_DPO = -log(σ(β * (log π(a_w|s) - log π_ref(a_w|s) + - log π(a_l|s) + log π_ref(a_l|s)))) + """ + # 计算当前策略的 log prob + chosen_logps = self.meta_policy.get_log_prob(chosen_obs, chosen_actions) + rejected_logps = self.meta_policy.get_log_prob(rejected_obs, rejected_actions) + + # 计算参考策略的 log prob + with torch.no_grad(): + ref_chosen_logps = self.ref_policy.get_log_prob(chosen_obs, chosen_actions) + ref_rejected_logps = self.ref_policy.get_log_prob(rejected_obs, rejected_actions) + + # DPO 损失 + chosen_rewards = self.config.beta * (chosen_logps - ref_chosen_logps) + rejected_rewards = self.config.beta * (rejected_logps - ref_rejected_logps) + + loss = -F.logsigmoid(chosen_rewards - rejected_rewards).mean() + + return { + 'loss': loss, + 'chosen_rewards': chosen_rewards.mean(), + 'rejected_rewards': rejected_rewards.mean(), + 'reward_margin': (chosen_rewards - rejected_rewards).mean() + } + + def update(self, batch: Optional[Dict[str, Any]] = None, **kwargs) -> Dict[str, Any]: + """DPO更新逻辑""" + if batch is None: + raise ValueError("DPO requires batch with chosen/rejected pairs") + + # 提取chosen和rejected数据 + chosen_obs = batch['chosen_obs'] + chosen_actions = batch['chosen_actions'] + rejected_obs = batch['rejected_obs'] + rejected_actions = batch['rejected_actions'] + + # 计算损失 + loss_dict = self.compute_dpo_loss( + chosen_obs, chosen_actions, + rejected_obs, rejected_actions + ) + + # 反向传播 + loss_dict['loss'].backward() + + return loss_dict +``` + +### GRPO (Group Relative Policy Optimization) + +```python +# rl/algorithms/grpo.py +from dataclasses import dataclass +import torch +from rl.base import BaseAlgorithm + +@dataclass +class GRPOConfig: + """GRPO 配置""" + group_size: int = 4 + kl_coef: float = 0.1 + reward_scale: float = 1.0 + learning_rate: float = 1e-5 + +class GRPOAlgorithm(BaseAlgorithm): + """ + Group Relative Policy Optimization + + 适用于 VLA 的在线强化学习: + 1. 对每个任务/指令采样多个轨迹 + 2. 使用奖励对轨迹进行排序 + 3. 使用组内相对奖励进行策略优化 + """ + + def __init__( + self, + meta_policy: MetaPolicy, + ref_policy: MetaPolicy, # 参考策略 (frozen) + config: GRPOConfig, + **kwargs + ): + super().__init__(meta_policy=meta_policy, **kwargs) + self.ref_policy = ref_policy + self.config = config + + # 冻结参考策略 + self.ref_policy.eval() + for p in self.ref_policy.parameters(): + p.requires_grad = False + + def collect_group_rollouts(self, obs: MetaObs, n_samples: int) -> List[Dict]: + """ + 对同一观测采样多个动作序列 + + Args: + obs: 初始观测 + n_samples: 采样数量 + + Returns: + 轨迹列表 + """ + trajectories = [] + for _ in range(n_samples): + traj = self._rollout_episode(obs) + trajectories.append(traj) + return trajectories + + def compute_grpo_loss( + self, + trajectories: List[Dict], + rewards: List[float] + ) -> Dict[str, torch.Tensor]: + """ + 计算 GRPO 损失 + + 使用组内相对奖励作为 advantage + """ + # 归一化组内奖励 + rewards_tensor = torch.tensor(rewards) + normalized_rewards = (rewards_tensor - rewards_tensor.mean()) / (rewards_tensor.std() + 1e-8) + + loss = 0 + for traj, adv in zip(trajectories, normalized_rewards): + log_probs = self.meta_policy.get_trajectory_log_prob(traj) + loss -= (log_probs * adv).mean() + + # KL 惩罚 + kl_loss = self._compute_kl_penalty(trajectories) + + total_loss = loss + self.config.kl_coef * kl_loss + + return { + 'loss': total_loss, + 'policy_loss': loss, + 'kl_loss': kl_loss, + 'mean_reward': rewards_tensor.mean() + } + + def _compute_kl_penalty(self, trajectories: List[Dict]) -> torch.Tensor: + """计算KL散度惩罚""" + kl_loss = 0 + for traj in trajectories: + current_logps = self.meta_policy.get_trajectory_log_prob(traj) + with torch.no_grad(): + ref_logps = self.ref_policy.get_trajectory_log_prob(traj) + kl = (current_logps - ref_logps).mean() + kl_loss += kl + return kl_loss / len(trajectories) +``` + +--- + +## 训练入口 `train_rl.py` + +```python +#!/usr/bin/env python3 +""" +ILStudio Reinforcement Learning Training Script + +支持: +- 传统 RL (PPO, SAC) 训练机器人控制策略 +- VLA fine-tuning (DPO, GRPO) 使用强化学习 +- 混合训练 (IL + RL) +""" + +import argparse +from loguru import logger +from configs.loader import ConfigLoader +from data_utils.utils import set_seed +from policy.policy_loader import PolicyLoader +from rl.algorithms import get_algorithm_class +from rl.buffer import get_replay_class +from rl.rewards import get_reward_class +from rl.collectors import get_collector_class +from rl.trainers import get_trainer_class + +def parse_args(): + parser = argparse.ArgumentParser(description='RL Training for ILStudio') + + # 基础配置 + parser.add_argument('-p', '--policy', type=str, required=True, + help='Policy config (e.g., act, diffusion_policy, openvla)') + parser.add_argument('-r', '--rl_config', type=str, default='ppo', + help='RL algorithm config (e.g., ppo, sac, dpo)') + parser.add_argument('-e', '--env', type=str, required=True, + help='Environment config') + parser.add_argument('-o', '--output_dir', type=str, default='ckpt/rl_output') + + # 训练模式 + parser.add_argument('--mode', type=str, default='online', + choices=['online', 'offline', 'hybrid'], + help='Training mode: online (env interaction), offline (dataset), hybrid') + + # 预训练模型 (用于 fine-tuning) + parser.add_argument('--pretrained', type=str, default=None, + help='Path to pretrained model checkpoint') + + # 其他 + parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--device', type=str, default='cuda') + + args, unknown = parser.parse_known_args() + args.unknown_args = unknown + return args + + +def create_algorithm(rl_config, policy, env_config, args): + """创建RL算法实例""" + # 创建replay buffer(如果配置中有) + replay = None + if 'replay' in rl_config: + replay_class = get_replay_class(rl_config['replay']['type']) + replay = replay_class(**rl_config['replay']) + + # 创建算法实例(奖励函数不在算法中,在trainer中) + algorithm_class = get_algorithm_class(rl_config['type']) + algorithm = algorithm_class( + meta_policy=policy, + replay=replay, + **rl_config.get('algorithm', {}) + ) + + return algorithm + + +def create_reward_fn(rl_config): + """创建奖励函数实例(在trainer中使用,用于训练时的奖励计算)""" + reward_fn = None + if 'reward' in rl_config: + reward_class = get_reward_class(rl_config['reward']['type']) + reward_fn = reward_class(**rl_config['reward']) + return reward_fn + + +def create_collector(rl_config, env, algorithm): + """创建数据收集器实例(不包含reward_fn,只存储原始奖励)""" + collector = None + if 'collector' in rl_config: + collector_class = get_collector_class(rl_config['collector']['type']) + collector = collector_class( + meta_envs=env, + algorithm=algorithm, + **rl_config['collector'] + ) + return collector + + +def main(): + args = parse_args() + set_seed(args.seed) + + # 加载配置 + cfg_loader = ConfigLoader(args=args, unknown_args=args.unknown_args) + policy_config, _ = cfg_loader.load_policy(args.policy) + rl_config, _ = cfg_loader.load_rl(args.rl_config) + env_config, _ = cfg_loader.load_env(args.env) + + # 创建环境 + env = create_env_from_config(env_config) + + # 加载/创建策略 + policy_loader = PolicyLoader() + if args.pretrained: + logger.info(f"Loading pretrained model from {args.pretrained}") + policy = policy_loader.load_pretrained(args.pretrained, policy_config) + else: + logger.info("Creating policy from scratch") + policy = policy_loader.create_policy(policy_config, args) + + # 创建RL算法 + algorithm = create_algorithm(rl_config, policy, env_config, args) + + # 创建奖励函数(在trainer中使用,用于训练时的奖励计算) + reward_fn = create_reward_fn(rl_config) + + # 创建数据收集器(可选,trainer会创建默认的,不包含reward_fn) + collector = create_collector(rl_config, env, algorithm) + + # 创建训练器 + trainer_class = get_trainer_class('simple') # 或从配置中读取 + trainer = trainer_class( + meta_envs=env, + algorithm=algorithm, + collector=collector, # collector在trainer中(只存储原始奖励) + reward_fn=reward_fn # reward_fn在trainer中(用于训练时的奖励计算) + ) + + # 训练 + logger.info("="*60) + logger.info(f"🚀 Starting RL Training: {rl_config['name']}") + logger.info(f" Policy: {policy_config.get('name', args.policy)}") + logger.info(f" Environment: {env_config.get('name', args.env)}") + logger.info(f" Mode: {args.mode}") + logger.info("="*60) + + trainer.train( + total_steps=rl_config['training']['total_timesteps'], + log_interval=rl_config['training'].get('log_interval', 100), + save_interval=rl_config['training'].get('save_freq', 10000), + eval_interval=rl_config['training'].get('eval_freq', 5000), + output_dir=args.output_dir + ) + + logger.info("✓ Training completed!") + + +if __name__ == '__main__': + main() +``` + +--- + +## 使用示例 + +### 示例1:传统RL训练(PPO + ACT) + +```bash +python train_rl.py \ + -p act \ + -r ppo \ + -e libero.example \ + -o ckpt/rl_act_ppo +``` + +### 示例2:VLA Fine-tuning(DPO + OpenVLA) + +```bash +python train_rl.py \ + -p openvla \ + -r dpo \ + -e behavior1k.example \ + --pretrained ckpt/openvla_pretrained \ + -o ckpt/openvla_dpo +``` + +### 示例3:使用自定义奖励函数 + +```python +# 在配置文件中指定奖励函数 +# configs/rl/ppo.yaml +reward: + type: rl.rewards.composite_reward.CompositeReward + components: + - type: rl.rewards.sparse_reward.SparseReward + success_reward: 1.0 + weight: 0.5 + - type: rl.rewards.dense_reward.DenseReward + goal_key: goal + distance_scale: 0.1 + weight: 0.5 +``` + +### 示例4:单个算法在多个不同环境的训练(场景1) + +```python +# 场景1:单个算法在多个不同环境(如室内、室外)中训练 +# 数据按环境类型分离存储到不同的replay buffer + +from rl.buffer import MemoryReplay +from rl.algorithms import SACAlgorithm +from rl.collectors import SimCollector +from rl.trainers import SimpleTrainer + +# 创建多个环境的replay buffer(按环境类型分离) +replay = { + 'indoor': MemoryReplay(capacity=1000000, device='cpu'), + 'outdoor': MemoryReplay(capacity=1000000, device='cpu') +} + +# 创建算法(使用多个replay buffer) +algorithm = SACAlgorithm( + meta_policy=policy, + replay=replay # Dict[str, BaseReplay],按环境类型分离 +) + +# 创建多个环境的collector +indoor_collector = SimCollector( + meta_envs=indoor_envs, # 室内环境 + algorithm=algorithm +) + +outdoor_collector = SimCollector( + meta_envs=outdoor_envs, # 室外环境 + algorithm=algorithm +) + +# 创建trainer +trainer = SimpleTrainer( + meta_envs={'indoor': indoor_envs, 'outdoor': outdoor_envs}, + algorithm=algorithm, + reward_fn=reward_fn +) + +# 训练时,分别从不同环境收集数据 +# 数据会自动存储到对应环境类型的replay buffer中 +for step in range(total_steps): + # 从室内环境收集数据(env_type='indoor') + indoor_stats = indoor_collector.collect(n_steps=1000, env_type='indoor') + + # 从室外环境收集数据(env_type='outdoor') + outdoor_stats = outdoor_collector.collect(n_steps=1000, env_type='outdoor') + + # 更新算法(可以从不同环境类型的replay混合采样) + loss = algorithm.update( + batch_size=256, + env_types=['indoor', 'outdoor'], # 指定从哪些环境采样 + env_weights={'indoor': 0.6, 'outdoor': 0.4} # 不同环境的数据权重 + ) +``` + +### 示例5:多个算法在同一个环境中独立训练(场景2) + +```python +# 场景2:多个算法在同一个环境中独立训练 +# 每个算法有自己的replay buffer和collector + +from rl.buffer import MemoryReplay +from rl.algorithms import SACAlgorithm, PPOAlgorithm +from rl.collectors import SimCollector +from rl.trainers import SimpleTrainer + +# 创建多个算法,每个算法有自己的replay buffer +algorithm1 = SACAlgorithm( + meta_policy=policy1, + replay=MemoryReplay(capacity=1000000, device='cpu') +) + +algorithm2 = PPOAlgorithm( + meta_policy=policy2, + replay=MemoryReplay(capacity=100000, device='cpu') +) + +# 创建多个collector,每个算法对应一个collector +collector1 = SimCollector( + meta_envs=env, # 同一个环境 + algorithm=algorithm1 +) + +collector2 = SimCollector( + meta_envs=env, # 同一个环境 + algorithm=algorithm2 +) + +# 创建trainer(传入算法列表和collector列表) +trainer = SimpleTrainer( + meta_envs=env, + algorithm=[algorithm1, algorithm2], # 多个算法 + collector=[collector1, collector2], # 每个算法对应一个collector + reward_fn=reward_fn +) + +# 训练时,每个算法独立收集数据和更新 +for step in range(total_steps): + # 收集rollout(返回每个算法的统计信息) + stats_list = trainer.collect_rollout(n_steps=1000) + # stats_list[0] 是 algorithm1 的统计信息 + # stats_list[1] 是 algorithm2 的统计信息 + + # 每个算法独立更新 + loss1 = algorithm1.update(batch_size=256) + loss2 = algorithm2.update(batch_size=64) +``` + +### 示例6:配置文件中支持多环境场景 + +```yaml +# configs/rl/multi_env_sac.yaml +name: multi_env_sac +type: rl.algorithms.sac.SACAlgorithm + +# 算法参数 +algorithm: + gamma: 0.99 + learning_rate: 3e-4 + batch_size: 256 + +# Replay Buffer配置(按环境类型分离) +replay: + indoor: + type: rl.buffer.memory_replay.MemoryReplay + capacity: 1000000 + device: cpu + outdoor: + type: rl.buffer.memory_replay.MemoryReplay + capacity: 1000000 + device: cpu + +# 策略网络配置 +policy: + type: policy.act + +# 训练配置 +training: + total_timesteps: 1000000 + env_types: ['indoor', 'outdoor'] + env_weights: + indoor: 0.6 + outdoor: 0.4 +``` + +--- + +## 接口关系图 + +``` +┌─────────────────────────────────────────┐ +│ BaseReplay │ +│ (存储经验数据) │ +│ - 完整MetaObs/MetaAction │ +│ - 自定义字段(value、log_prob等) │ +└─────────────────────────────────────────┘ + ▲ + │ (可选,在算法中) + │ +┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ +│ BaseAlgorithm │ │ BaseCollector │ │ BaseTrainer │ +│ │ │ │ │ │ +│ + meta_policy │◄─────┤ + algorithm │◄─────┤ + collector │ +│ + replay │ │ + meta_envs │ │ + algorithm │ +│ + update() │ │ + collect() │ │ + reward_fn │ +│ + record_ │ │ (只存储原始奖励) │ │ + compute_ │ +│ transition() │ └─────────────────┘ │ reward() │ +└─────────────────┘ │ │ + train() │ + │ │ + evaluate() │ + │ └─────────────────┘ + │ │ + ┌────────────┼────────────┐ │ + │ │ │ │ 使用 + │ 使用 │ 使用 │ │ + ▼ ▼ ▼ ▼ + ┌──────────┐ ┌──────────┐ ┌─────────────────┐ + │ MetaEnv │ │MetaPolicy│ │ BaseReward │ + │ │ │ │ │ (奖励函数) │ + └──────────┘ └──────────┘ │ - compute() │ + └─────────────────┘ +``` + +--- + +## 设计亮点 + +### 1. **完整的Meta数据存储** +- Replay buffer存储完整的MetaObs和MetaAction,包括所有字段 +- 不丢失任何信息,方便后续分析和扩展 + +### 2. **模块化奖励函数** +- 奖励函数独立模块,易于替换和扩展 +- 支持组合多个奖励函数 +- 支持VLA的语言条件奖励 +- **原始数据存储**:replay buffer存储环境原始奖励,奖励函数只在训练时应用 +- **数据可复用性**:同样的数据可以用不同的奖励函数进行训练 + +### 3. **灵活的配置系统** +- 使用YAML配置,支持命令行覆盖 +- 复用ILStudio现有的配置加载机制 +- 配置驱动,易于实验和部署 + +### 4. **兼容ILStudio pipeline** +- **原始数据存储**:replay buffer存储原始MetaObs、MetaAction和环境原始奖励 +- **奖励函数分离**:奖励函数在trainer中应用,不影响数据收集和存储 +- 采样时通过data_processor对齐ILStudio的normalization pipeline +- 既保持灵活性,又兼容现有系统 + +### 5. **VLA特殊支持** +- 专门的DPO和GRPO算法实现 +- 支持语言条件奖励 +- 支持参考策略的KL散度约束 + +### 6. **数据集持久化** +- 支持保存和加载数据集,方便数据管理和复用 +- 支持多种保存格式(pkl、hdf5、npz等) +- 支持追加加载,可以合并多个数据集 + +### 7. **灵活的数据存储策略** +- **场景1:单个算法在多个不同环境** + - 算法可以使用 `Dict[str, BaseReplay]` 按环境类型分离存储 + - 在 `record_transition` 时传入 `env_type` 参数 + - 支持从不同环境类型的replay混合采样,可配置数据权重 + - 适用于跨域学习、多任务学习等场景 + +- **场景2:多个算法在同一个环境** + - Trainer 可以接受 `List[BaseAlgorithm]`,每个算法独立训练 + - 每个算法有自己的 replay buffer 和 collector + - 支持算法对比、ensemble训练等场景 + - 数据完全隔离,互不干扰 + +--- + +## 总结 + +这个整合设计: + +1. ✅ **保持设计2的核心架构**:BaseReplay、BaseAlgorithm、BaseTrainer三个基类 +2. ✅ **添加奖励函数模块化**:独立的BaseReward基类和具体实现 +3. ✅ **添加配置系统**:YAML配置文件和配置加载机制 +4. ✅ **添加VLA支持**:DPO和GRPO算法实现 +5. ✅ **添加训练入口**:train_rl.py脚本设计 + +通过组合这些组件,可以灵活地实现各种RL训练场景,同时保持与ILStudio现有系统的兼容性。 + diff --git a/rl/__init__.py b/rl/__init__.py new file mode 100644 index 00000000..7ffb8010 --- /dev/null +++ b/rl/__init__.py @@ -0,0 +1,172 @@ +""" +ILStudio Reinforcement Learning Module + +This module provides a modular RL framework for robot learning, supporting: +- Traditional RL algorithms (PPO, SAC, TD3, etc.) +- VLA fine-tuning algorithms (DPO, GRPO) +- Flexible replay buffers for experience storage +- Modular reward functions +- Data collectors for various environments +- Trainers for different training scenarios +- Infrastructure for reproducibility and stability + +The framework is designed to: +1. Be highly abstract - only core interfaces, no specific implementations +2. Be universal - support various RL algorithms and training modes +3. Be compatible - directly use MetaEnv and MetaPolicy without adapters +4. Be extensible - easy to extend for parallel and distributed training +5. Be modular - reward functions and config systems are independently replaceable +6. Be reproducible - comprehensive infrastructure for reproducibility + +Directory Structure: + rl/ + ├── __init__.py # This file + ├── base.py # BaseAlgorithm class + ├── algorithms/ # RL algorithm implementations + │ └── __init__.py # Algorithm registry + ├── buffer/ # Replay buffer implementations + │ ├── __init__.py + │ └── base_replay.py # BaseReplay class + ├── rewards/ # Reward function implementations + │ ├── __init__.py + │ └── base_reward.py # BaseReward class + ├── collectors/ # Data collector implementations + │ ├── __init__.py + │ └── base_collector.py # BaseCollector class + ├── trainers/ # Trainer implementations + │ ├── __init__.py + │ └── base_trainer.py # BaseTrainer class + ├── infra/ # Infrastructure for reproducibility + │ ├── __init__.py + │ ├── seed_manager.py # Random seed management + │ ├── logger.py # Logging system + │ ├── checkpoint.py # Checkpoint management + │ ├── callback.py # Callback system + │ └── distributed.py # Distributed training support + └── utils/ # Utility functions + └── __init__.py +""" + +# Base classes +from .base import BaseAlgorithm +from .buffer import BaseReplay +from .rewards import BaseReward +from .collectors import BaseCollector +from .trainers import BaseTrainer + +# Factory functions +from .algorithms import get_algorithm_class, register_algorithm, list_algorithms +from .rewards import get_reward_class, register_reward, list_rewards +from .collectors import get_collector_class, register_collector, list_collectors +from .trainers import get_trainer_class, register_trainer, list_trainers + +# Utility functions +from .utils import ( + compute_gae, + compute_returns, + RunningMeanStd, + explained_variance, + polyak_update, + hard_update, +) + +# Infrastructure - Seed management +from .infra import ( + SeedManager, + set_global_seed, + get_global_seed, +) + +# Infrastructure - Logging +from .infra import ( + BaseLogger, + ConsoleLogger, + TensorBoardLogger, + CompositeLogger, +) + +# Infrastructure - Checkpoint +from .infra import CheckpointManager + +# Infrastructure - Callbacks +from .infra import ( + Callback, + CallbackList, + ProgressCallback, + EvalCallback, + CheckpointCallback, + EarlyStoppingCallback, +) + +# Infrastructure - Distributed +from .infra import ( + DistributedContext, + get_world_size, + get_rank, + is_main_process, +) + +__all__ = [ + # Base classes + 'BaseAlgorithm', + 'BaseReplay', + 'BaseReward', + 'BaseCollector', + 'BaseTrainer', + + # Factory functions for algorithms + 'get_algorithm_class', + 'register_algorithm', + 'list_algorithms', + + # Factory functions for rewards + 'get_reward_class', + 'register_reward', + 'list_rewards', + + # Factory functions for collectors + 'get_collector_class', + 'register_collector', + 'list_collectors', + + # Factory functions for trainers + 'get_trainer_class', + 'register_trainer', + 'list_trainers', + + # Utility functions + 'compute_gae', + 'compute_returns', + 'RunningMeanStd', + 'explained_variance', + 'polyak_update', + 'hard_update', + + # Infrastructure - Seed management + 'SeedManager', + 'set_global_seed', + 'get_global_seed', + + # Infrastructure - Logging + 'BaseLogger', + 'ConsoleLogger', + 'TensorBoardLogger', + 'CompositeLogger', + + # Infrastructure - Checkpoint + 'CheckpointManager', + + # Infrastructure - Callbacks + 'Callback', + 'CallbackList', + 'ProgressCallback', + 'EvalCallback', + 'CheckpointCallback', + 'EarlyStoppingCallback', + + # Infrastructure - Distributed + 'DistributedContext', + 'get_world_size', + 'get_rank', + 'is_main_process', +] diff --git a/rl/algorithms/__init__.py b/rl/algorithms/__init__.py new file mode 100644 index 00000000..0c7d3b3e --- /dev/null +++ b/rl/algorithms/__init__.py @@ -0,0 +1,79 @@ +""" +RL Algorithms Module + +This module provides various RL algorithm implementations. + +Available algorithms: +- PPO: Proximal Policy Optimization +- SAC: Soft Actor-Critic +- TD3: Twin Delayed DDPG +- DPO: Direct Preference Optimization (for VLA) +- GRPO: Group Relative Policy Optimization (for VLA) +- REINFORCE: Basic policy gradient + +Note: Implementations are provided in separate files. +This __init__.py provides factory functions for creating algorithms. +""" + +from typing import Type, Dict, Any + +# Registry for algorithm classes +_ALGORITHM_REGISTRY: Dict[str, Type] = {} + + +def register_algorithm(name: str, algorithm_class: Type) -> None: + """ + Register an algorithm class. + + Args: + name: Algorithm name (e.g., 'ppo', 'sac') + algorithm_class: Algorithm class to register + """ + _ALGORITHM_REGISTRY[name.lower()] = algorithm_class + + +def get_algorithm_class(name_or_type: str) -> Type: + """ + Get algorithm class by name or type string. + + Args: + name_or_type: Algorithm name (e.g., 'ppo') or full type path + (e.g., 'rl.algorithms.ppo.PPOAlgorithm') + + Returns: + Algorithm class + + Raises: + ValueError: If algorithm not found + """ + # First check registry + if name_or_type.lower() in _ALGORITHM_REGISTRY: + return _ALGORITHM_REGISTRY[name_or_type.lower()] + + # Try to import from type path + if '.' in name_or_type: + try: + parts = name_or_type.rsplit('.', 1) + module_path = parts[0] + class_name = parts[1] + + import importlib + module = importlib.import_module(module_path) + return getattr(module, class_name) + except (ImportError, AttributeError) as e: + raise ValueError(f"Cannot import algorithm from '{name_or_type}': {e}") + + raise ValueError(f"Unknown algorithm: '{name_or_type}'. Available: {list(_ALGORITHM_REGISTRY.keys())}") + + +def list_algorithms() -> list: + """List all registered algorithms.""" + return list(_ALGORITHM_REGISTRY.keys()) + + +__all__ = [ + 'register_algorithm', + 'get_algorithm_class', + 'list_algorithms', +] + diff --git a/rl/base.py b/rl/base.py new file mode 100644 index 00000000..8c9997dd --- /dev/null +++ b/rl/base.py @@ -0,0 +1,431 @@ +""" +Base Algorithm Class + +This module defines the base class for all RL algorithms in the framework. + +Design Philosophy (inspired by SKRL): +- Replay buffer is placed inside the algorithm, allowing: + - Each algorithm to have its own replay configuration + - A single Trainer to train multiple different algorithms (each with own replay) + - More flexibility, supporting multi-agent scenarios +""" + +import torch +import numpy as np +from typing import Dict, Any, Optional, Union, List, Callable +from abc import ABC, abstractmethod + +# Type hints for Meta classes (imported at runtime to avoid circular imports) +from benchmark.base import MetaObs, MetaAction, MetaPolicy + + +class BaseAlgorithm(ABC): + """ + Base class for RL algorithms. + + This class defines the core interface for all RL algorithms. + It holds a reference to the policy (MetaPolicy) and optionally a replay buffer. + + Attributes: + meta_policy: The MetaPolicy instance used for action selection + replay: Optional replay buffer(s) for experience storage + """ + + def __init__( + self, + meta_policy: MetaPolicy, + replay: Optional[Union['BaseReplay', Dict[str, 'BaseReplay']]] = None, + **kwargs + ): + """ + Initialize the algorithm. + + Args: + meta_policy: ILStudio's MetaPolicy instance (required) + replay: Supports two formats: + - BaseReplay instance: Single replay buffer (shared by all environments) + - Dict[str, BaseReplay]: Multiple replay buffers (separated by environment type) + - None: No replay buffer (for on-policy algorithms) + **kwargs: Algorithm-specific parameters + """ + self.meta_policy = meta_policy # Required attribute + self.replay = replay # Optional attribute (off-policy algorithms need this) + self._kwargs = kwargs + + @abstractmethod + def update(self, batch: Optional[Dict[str, Any]] = None, **kwargs) -> Dict[str, Any]: + """ + Update the policy using a batch of data. + + Args: + batch: Optional, batch data + - If None and replay exists, sample from replay + - If provided, use the provided batch directly + **kwargs: Update parameters, can include: + - batch_size: Batch size when sampling from replay + - env_types: Specify which environment types to sample from + (when replay is Dict[str, BaseReplay]) + - e.g., ['indoor', 'outdoor'] + - If None, sample from all environment types + - env_weights: Data weights for different environment types + (when using multiple replay buffers) + - e.g., {'indoor': 0.6, 'outdoor': 0.4} + - If None, use uniform weights + + Returns: + Dictionary containing loss, metrics, and other information + + Examples: + # Scenario 1: Single replay buffer + algorithm.update(batch_size=256) + + # Scenario 2: Multiple replay buffers (by environment type) + algorithm.update( + batch_size=256, + env_types=['indoor', 'outdoor'], + env_weights={'indoor': 0.6, 'outdoor': 0.4} + ) + """ + raise NotImplementedError + + def compute_loss(self, batch: Dict[str, Any]) -> torch.Tensor: + """ + Compute loss (optional, some algorithms may need this). + + Args: + batch: Batch data + + Returns: + Loss value + """ + raise NotImplementedError("Subclass should implement compute_loss if needed") + + def select_action(self, obs: MetaObs, **kwargs) -> MetaAction: + """ + Select action (optional, some algorithms may need this). + + Args: + obs: MetaObs format observation + **kwargs: Other parameters (e.g., exploration settings) + + Returns: + MetaAction format action + """ + # Default implementation uses meta_policy's select_action + return self.meta_policy.select_action(obs, **kwargs) + + def record_transition( + self, + state: MetaObs, + action: MetaAction, + reward: float, + next_state: MetaObs, + done: bool, + info: Optional[Dict[str, Any]] = None, + **kwargs + ) -> None: + """ + Record transition to replay buffer (if exists). + + Supports storing complete MetaObs and MetaAction information, plus additional custom fields. + If using multiple replay buffers (by environment type), selects corresponding replay + based on env_type in kwargs. + + Args: + state: Current state (MetaObs, including all fields) + action: Action (MetaAction, including all fields) + reward: Reward + next_state: Next state (MetaObs, including all fields) + done: Whether episode ended + info: Additional information dictionary + **kwargs: Other custom fields, can store any additional information + - env_type: Environment type identifier (if replay is Dict[str, BaseReplay]) + - e.g., value, log_prob, advantage, trajectory_id, etc. + """ + if self.replay is not None: + from dataclasses import asdict + + # Convert MetaObs and MetaAction to dict if they are dataclass instances + state_dict = asdict(state) if hasattr(state, '__dataclass_fields__') else state + action_dict = asdict(action) if hasattr(action, '__dataclass_fields__') else action + next_state_dict = asdict(next_state) if hasattr(next_state, '__dataclass_fields__') else next_state + + transition = { + 'state': state_dict, + 'action': action_dict, + 'reward': reward, + 'next_state': next_state_dict, + 'done': done, + 'info': info or {}, + **kwargs + } + + env_type = kwargs.get('env_type', None) + + if isinstance(self.replay, dict): + if env_type is None: + raise ValueError("env_type must be provided when using multiple replay buffers") + if env_type not in self.replay: + raise ValueError(f"env_type '{env_type}' not found in replay buffers") + self.replay[env_type].add(transition) + else: + self.replay.add(transition) + + def get_policy(self) -> MetaPolicy: + """Get the underlying MetaPolicy.""" + return self.meta_policy + + def train_mode(self) -> None: + """Set policy to training mode.""" + if hasattr(self.meta_policy, 'policy') and hasattr(self.meta_policy.policy, 'train'): + self.meta_policy.policy.train() + + def eval_mode(self) -> None: + """Set policy to evaluation mode.""" + if hasattr(self.meta_policy, 'policy') and hasattr(self.meta_policy.policy, 'eval'): + self.meta_policy.policy.eval() + + def save(self, path: str, **kwargs) -> None: + """ + Save algorithm state (model, optimizer, etc.). + + Args: + path: Save path + **kwargs: Additional save options + """ + raise NotImplementedError("Subclass should implement save") + + def load(self, path: str, **kwargs) -> None: + """ + Load algorithm state. + + Args: + path: Load path + **kwargs: Additional load options + """ + raise NotImplementedError("Subclass should implement load") + + def __repr__(self) -> str: + replay_info = "None" + if self.replay is not None: + if isinstance(self.replay, dict): + replay_info = f"Dict with {len(self.replay)} buffers" + else: + replay_info = repr(self.replay) + return f"{self.__class__.__name__}(meta_policy={self.meta_policy.__class__.__name__}, replay={replay_info})" + + +if __name__ == '__main__': + """ + Test code for BaseAlgorithm class. + + Since BaseAlgorithm is abstract, we create a simple concrete implementation for testing. + """ + import sys + sys.path.insert(0, '/home/zhang/robot/126/ILStudio') + + from benchmark.base import MetaObs, MetaAction, MetaPolicy + from rl.buffer.base_replay import BaseReplay + from dataclasses import asdict + + # Simple replay buffer for testing + class SimpleReplay(BaseReplay): + def __init__(self, capacity=1000, device='cpu', **kwargs): + super().__init__(capacity=capacity, device=device, **kwargs) + self._storage = [] + + def add(self, transition): + if self._size < self.capacity: + self._storage.append(transition) + self._size += 1 + else: + self._storage[self._position] = transition + self._position = (self._position + 1) % self.capacity + + def sample(self, batch_size): + if self._size == 0: + return {} + indices = np.random.randint(0, self._size, size=min(batch_size, self._size)) + return { + 'states': [self._storage[i]['state'] for i in indices], + 'actions': [self._storage[i]['action'] for i in indices], + 'rewards': np.array([self._storage[i]['reward'] for i in indices]), + 'next_states': [self._storage[i]['next_state'] for i in indices], + 'dones': np.array([self._storage[i]['done'] for i in indices]), + } + + def clear(self): + self._storage = [] + self._size = 0 + self._position = 0 + + def save(self, path, **kwargs): + pass + + def load(self, path, **kwargs): + pass + + # Simple policy for testing + class DummyPolicy: + def select_action(self, obs): + return MetaAction( + action=np.random.randn(7).astype(np.float32), + ctrl_space='ee', + ctrl_type='delta' + ) + + def train(self): + pass + + def eval(self): + pass + + # Simple MetaPolicy wrapper for testing + class DummyMetaPolicy(MetaPolicy): + def __init__(self): + self.policy = DummyPolicy() + self.chunk_size = 1 + self.ctrl_space = 'ee' + self.ctrl_type = 'delta' + self.action_queue = [] + self.action_normalizer = None + self.state_normalizer = None + + def select_action(self, mobs, t=0, **kwargs): + return self.policy.select_action(mobs) + + # Simple concrete algorithm for testing + class SimpleAlgorithm(BaseAlgorithm): + def __init__(self, meta_policy, replay=None, learning_rate=1e-3, **kwargs): + super().__init__(meta_policy=meta_policy, replay=replay, **kwargs) + self.learning_rate = learning_rate + self.update_count = 0 + + def update(self, batch=None, **kwargs): + batch_size = kwargs.get('batch_size', 32) + + if batch is None and self.replay is not None: + if isinstance(self.replay, dict): + # Sample from multiple replay buffers + env_types = kwargs.get('env_types', list(self.replay.keys())) + env_weights = kwargs.get('env_weights', {k: 1.0/len(env_types) for k in env_types}) + + batches = {} + for env_type in env_types: + n_samples = int(batch_size * env_weights.get(env_type, 1.0/len(env_types))) + batches[env_type] = self.replay[env_type].sample(n_samples) + batch = batches + else: + batch = self.replay.sample(batch_size) + + self.update_count += 1 + + return { + 'loss': np.random.randn(), + 'update_count': self.update_count, + 'batch_size': batch_size + } + + def compute_loss(self, batch): + return torch.tensor(np.random.randn()) + + # Test the implementation + print("=" * 60) + print("Testing BaseAlgorithm (SimpleAlgorithm implementation)") + print("=" * 60) + + # Test 1: Single replay buffer + print("\n1. Testing with single replay buffer...") + meta_policy = DummyMetaPolicy() + replay = SimpleReplay(capacity=100) + algorithm = SimpleAlgorithm(meta_policy=meta_policy, replay=replay, learning_rate=1e-4) + print(f" Created algorithm: {algorithm}") + + # Add transitions + print("\n2. Recording transitions...") + for i in range(10): + state = MetaObs( + state=np.random.randn(10).astype(np.float32), + state_ee=np.random.randn(7).astype(np.float32), + raw_lang="test instruction" + ) + action = MetaAction( + action=np.random.randn(7).astype(np.float32), + ctrl_space='ee', + ctrl_type='delta' + ) + next_state = MetaObs( + state=np.random.randn(10).astype(np.float32), + state_ee=np.random.randn(7).astype(np.float32), + raw_lang="test instruction" + ) + + algorithm.record_transition( + state=state, + action=action, + reward=np.random.randn(), + next_state=next_state, + done=(i == 9), + info={'step': i}, + value=np.random.randn(), # Custom field + ) + print(f" Replay buffer size: {len(algorithm.replay)}") + + # Update algorithm + print("\n3. Testing update...") + result = algorithm.update(batch_size=5) + print(f" Update result: {result}") + + # Test 2: Multiple replay buffers (by environment type) + print("\n4. Testing with multiple replay buffers...") + multi_replay = { + 'indoor': SimpleReplay(capacity=100), + 'outdoor': SimpleReplay(capacity=100) + } + multi_algorithm = SimpleAlgorithm(meta_policy=meta_policy, replay=multi_replay) + + # Add transitions to different environment types + for i in range(5): + state = MetaObs(state=np.random.randn(10).astype(np.float32)) + action = MetaAction(action=np.random.randn(7).astype(np.float32)) + next_state = MetaObs(state=np.random.randn(10).astype(np.float32)) + + multi_algorithm.record_transition( + state=state, action=action, reward=1.0, + next_state=next_state, done=False, + env_type='indoor' + ) + multi_algorithm.record_transition( + state=state, action=action, reward=0.5, + next_state=next_state, done=False, + env_type='outdoor' + ) + + print(f" Indoor replay size: {len(multi_algorithm.replay['indoor'])}") + print(f" Outdoor replay size: {len(multi_algorithm.replay['outdoor'])}") + + # Update with weighted sampling + result = multi_algorithm.update( + batch_size=8, + env_types=['indoor', 'outdoor'], + env_weights={'indoor': 0.6, 'outdoor': 0.4} + ) + print(f" Update with weighted sampling: {result}") + + # Test train/eval mode + print("\n5. Testing train/eval mode...") + algorithm.train_mode() + print(" Set to train mode") + algorithm.eval_mode() + print(" Set to eval mode") + + # Test select_action + print("\n6. Testing select_action...") + obs = MetaObs(state=np.random.randn(10).astype(np.float32)) + action = algorithm.select_action(obs) + print(f" Selected action shape: {action.action.shape}") + + print("\n" + "=" * 60) + print("All tests passed!") + print("=" * 60) + diff --git a/rl/buffer/__init__.py b/rl/buffer/__init__.py new file mode 100644 index 00000000..8a2bf6e1 --- /dev/null +++ b/rl/buffer/__init__.py @@ -0,0 +1,15 @@ +""" +Replay Buffer Module + +This module provides experience replay buffers for RL algorithms. + +Classes: + BaseReplay: Base class for all replay buffers +""" + +from .base_replay import BaseReplay + +__all__ = [ + 'BaseReplay', +] + diff --git a/rl/buffer/base_replay.py b/rl/buffer/base_replay.py new file mode 100644 index 00000000..40ddcf88 --- /dev/null +++ b/rl/buffer/base_replay.py @@ -0,0 +1,304 @@ +""" +Base Replay Buffer Class + +This module defines the base class for all replay buffers in the RL framework. + +Design Philosophy: +- Store raw Meta data: Store original MetaObs and MetaAction in buffer, keeping data in its raw form +- Complete information storage: Support storing all fields of MetaObs and MetaAction +- Extensibility: Support storing additional custom fields (value, log_prob, advantage, trajectory_id, etc.) +- Conversion during sampling: Convert to ILStudio data pipeline format (normalization, etc.) during sampling +- Compatibility: Maintain data integrity while being compatible with ILStudio's normalization pipeline +""" + +import torch +import numpy as np +from typing import Dict, Any, Optional, Union, Callable +from abc import ABC, abstractmethod + + +class BaseReplay(ABC): + """ + Base class for Replay Buffer. + + This class defines the interface for all replay buffers in the RL framework. + Replay buffers store experience data (transitions) for training RL algorithms. + + Attributes: + capacity: Maximum number of transitions to store + device: Device to store data on ('cpu' or 'cuda') + """ + + def __init__( + self, + capacity: int = 1000000, + device: Union[str, torch.device] = 'cpu', + **kwargs + ): + """ + Initialize the Replay Buffer. + + Args: + capacity: Buffer capacity (maximum number of transitions to store) + device: Data storage device ('cpu' or 'cuda', default 'cpu') + - 'cpu': Store data in CPU memory + - 'cuda' or 'cuda:0': Store data in GPU memory + **kwargs: Other initialization parameters + """ + self.capacity = capacity + self.device = torch.device(device) if isinstance(device, str) else device + self._size = 0 + self._position = 0 + + @abstractmethod + def add(self, transition: Dict[str, Any]) -> None: + """ + Add a transition to the buffer. + + Stores raw Meta data (MetaObs, MetaAction) without any normalization. + Supports storing all fields of MetaObs and MetaAction, plus additional custom information. + + Args: + transition: Dictionary containing the following fields: + - state: MetaObs format current state (raw data, including all fields) + - action: MetaAction format action (raw data, including all fields) + - reward: float reward + - next_state: MetaObs format next state (raw data) + - done: bool whether episode ended + - info: Optional, additional information dictionary + - **other custom fields**: Can store any additional information + """ + raise NotImplementedError + + @abstractmethod + def sample(self, batch_size: int) -> Dict[str, Any]: + """ + Sample a batch from the buffer (raw data). + + Args: + batch_size: Batch size + + Returns: + Dictionary containing raw Meta data (without normalization) + """ + raise NotImplementedError + + def sample_for_training( + self, + batch_size: int, + data_processor: Optional[Callable] = None + ) -> Dict[str, Any]: + """ + Sample and convert to ILStudio training format. + + Args: + batch_size: Batch size + data_processor: Optional data processing function to align with ILStudio pipeline + - If None, return raw data + - If provided, should be a function: batch -> processed_batch + + Returns: + Processed batch data (conforming to ILStudio training format) + """ + batch = self.sample(batch_size) + if data_processor is not None: + batch = data_processor(batch) + return batch + + def __len__(self) -> int: + """Return current buffer size.""" + return self._size + + @abstractmethod + def clear(self) -> None: + """Clear the buffer.""" + raise NotImplementedError + + @abstractmethod + def save(self, path: str, **kwargs) -> None: + """ + Save buffer data to file. + + Args: + path: Save path (can be file path or directory path) + **kwargs: Save options + - format: Save format (e.g., 'pkl', 'hdf5', 'npz', optional) + - compress: Whether to compress (optional) + """ + raise NotImplementedError + + @abstractmethod + def load(self, path: str, **kwargs) -> None: + """ + Load data from file into buffer. + + Args: + path: Load path (can be file path or directory path) + **kwargs: Load options + - format: Load format (optional, can auto-detect) + - append: Whether to append to existing buffer (default False, clear before load) + """ + raise NotImplementedError + + def is_full(self) -> bool: + """Check if buffer is full.""" + return self._size >= self.capacity + + def get_all(self) -> Dict[str, Any]: + """ + Get all data in the buffer. + + Returns: + Dictionary containing all stored data + """ + return self.sample(self._size) if self._size > 0 else {} + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(capacity={self.capacity}, size={self._size}, device={self.device})" + + +if __name__ == '__main__': + """ + Test code for BaseReplay class. + + Since BaseReplay is abstract, we create a simple concrete implementation for testing. + """ + import sys + sys.path.insert(0, '/home/zhang/robot/126/ILStudio') + + from benchmark.base import MetaObs, MetaAction + from dataclasses import asdict + + # Simple concrete implementation for testing + class SimpleReplay(BaseReplay): + """Simple in-memory replay buffer for testing.""" + + def __init__(self, capacity: int = 1000, device: str = 'cpu', **kwargs): + super().__init__(capacity=capacity, device=device, **kwargs) + self._storage = [] + + def add(self, transition: Dict[str, Any]) -> None: + if self._size < self.capacity: + self._storage.append(transition) + self._size += 1 + else: + self._storage[self._position] = transition + self._position = (self._position + 1) % self.capacity + + def sample(self, batch_size: int) -> Dict[str, Any]: + if self._size == 0: + return {} + indices = np.random.randint(0, self._size, size=min(batch_size, self._size)) + batch = { + 'states': [self._storage[i]['state'] for i in indices], + 'actions': [self._storage[i]['action'] for i in indices], + 'rewards': np.array([self._storage[i]['reward'] for i in indices]), + 'next_states': [self._storage[i]['next_state'] for i in indices], + 'dones': np.array([self._storage[i]['done'] for i in indices]), + } + return batch + + def clear(self) -> None: + self._storage = [] + self._size = 0 + self._position = 0 + + def save(self, path: str, **kwargs) -> None: + import pickle + with open(path, 'wb') as f: + pickle.dump(self._storage[:self._size], f) + print(f"Saved {self._size} transitions to {path}") + + def load(self, path: str, **kwargs) -> None: + import pickle + append = kwargs.get('append', False) + if not append: + self.clear() + with open(path, 'rb') as f: + data = pickle.load(f) + for transition in data: + self.add(transition) + print(f"Loaded {len(data)} transitions from {path}") + + # Test the implementation + print("=" * 60) + print("Testing BaseReplay (SimpleReplay implementation)") + print("=" * 60) + + # Create buffer + buffer = SimpleReplay(capacity=100, device='cpu') + print(f"\n1. Created buffer: {buffer}") + + # Create sample transitions + print("\n2. Adding transitions...") + for i in range(10): + state = MetaObs( + state=np.random.randn(10).astype(np.float32), + state_ee=np.random.randn(7).astype(np.float32), + raw_lang="pick up the red block" + ) + action = MetaAction( + action=np.random.randn(7).astype(np.float32), + ctrl_space='ee', + ctrl_type='delta' + ) + next_state = MetaObs( + state=np.random.randn(10).astype(np.float32), + state_ee=np.random.randn(7).astype(np.float32), + raw_lang="pick up the red block" + ) + + transition = { + 'state': asdict(state), + 'action': asdict(action), + 'reward': np.random.randn(), + 'next_state': asdict(next_state), + 'done': i == 9, + 'info': {'step': i}, + 'value': np.random.randn(), # Custom field + 'log_prob': np.random.randn(), # Custom field + } + buffer.add(transition) + + print(f" Buffer size after adding: {len(buffer)}") + print(f" Buffer is full: {buffer.is_full()}") + + # Sample from buffer + print("\n3. Sampling from buffer...") + batch = buffer.sample(batch_size=5) + print(f" Batch keys: {batch.keys()}") + print(f" Batch rewards shape: {batch['rewards'].shape}") + print(f" Number of states in batch: {len(batch['states'])}") + + # Test sample_for_training with processor + print("\n4. Testing sample_for_training with data processor...") + def simple_processor(batch): + """Simple processor that adds a 'processed' flag.""" + batch['processed'] = True + return batch + + processed_batch = buffer.sample_for_training(batch_size=5, data_processor=simple_processor) + print(f" Processed batch has 'processed' key: {'processed' in processed_batch}") + + # Test save and load + print("\n5. Testing save and load...") + import tempfile + import os + with tempfile.TemporaryDirectory() as tmpdir: + save_path = os.path.join(tmpdir, 'buffer.pkl') + buffer.save(save_path) + + # Create new buffer and load + new_buffer = SimpleReplay(capacity=100) + new_buffer.load(save_path) + print(f" Loaded buffer size: {len(new_buffer)}") + + # Test clear + print("\n6. Testing clear...") + buffer.clear() + print(f" Buffer size after clear: {len(buffer)}") + + print("\n" + "=" * 60) + print("All tests passed!") + print("=" * 60) + diff --git a/rl/collectors/__init__.py b/rl/collectors/__init__.py new file mode 100644 index 00000000..b7446018 --- /dev/null +++ b/rl/collectors/__init__.py @@ -0,0 +1,84 @@ +""" +Data Collectors Module + +This module provides data collectors for gathering experience from environments. + +Design Philosophy: +- Responsibility separation: Separate data collection logic from trainer +- Environment abstraction: Support single environment, parallel environments, multiple environment types +- Raw data storage: Only store raw environment rewards, no reward function computation +- Statistics: Collect and return episode statistics + +Available collectors: +- SimCollector: Collector for simulation environments +- RealCollector: Collector for real robot environments + +Note: Implementations are provided in separate files. +This __init__.py provides factory functions for creating collectors. +""" + +from typing import Type, Dict, Any + +from .base_collector import BaseCollector + +# Registry for collector classes +_COLLECTOR_REGISTRY: Dict[str, Type] = {} + + +def register_collector(name: str, collector_class: Type) -> None: + """ + Register a collector class. + + Args: + name: Collector name (e.g., 'sim', 'real') + collector_class: Collector class to register + """ + _COLLECTOR_REGISTRY[name.lower()] = collector_class + + +def get_collector_class(name_or_type: str) -> Type: + """ + Get collector class by name or type string. + + Args: + name_or_type: Collector name (e.g., 'sim') or full type path + (e.g., 'rl.collectors.sim_collector.SimCollector') + + Returns: + Collector class + + Raises: + ValueError: If collector not found + """ + # First check registry + if name_or_type.lower() in _COLLECTOR_REGISTRY: + return _COLLECTOR_REGISTRY[name_or_type.lower()] + + # Try to import from type path + if '.' in name_or_type: + try: + parts = name_or_type.rsplit('.', 1) + module_path = parts[0] + class_name = parts[1] + + import importlib + module = importlib.import_module(module_path) + return getattr(module, class_name) + except (ImportError, AttributeError) as e: + raise ValueError(f"Cannot import collector from '{name_or_type}': {e}") + + raise ValueError(f"Unknown collector: '{name_or_type}'. Available: {list(_COLLECTOR_REGISTRY.keys())}") + + +def list_collectors() -> list: + """List all registered collectors.""" + return list(_COLLECTOR_REGISTRY.keys()) + + +__all__ = [ + 'BaseCollector', + 'register_collector', + 'get_collector_class', + 'list_collectors', +] + diff --git a/rl/collectors/base_collector.py b/rl/collectors/base_collector.py new file mode 100644 index 00000000..f1f05e32 --- /dev/null +++ b/rl/collectors/base_collector.py @@ -0,0 +1,439 @@ +""" +Base Collector Class + +This module defines the base class for all data collectors in the RL framework. + +Design Philosophy: +- Responsibility separation: Separate data collection logic from trainer + so that trainer can focus on training loop coordination +- Environment abstraction: Support single environment, parallel environments, multiple environment types +- Raw data storage: Only store raw environment rewards, no reward function computation, + ensuring data integrity +- Statistics: Collect and return episode statistics +""" + +import numpy as np +from typing import Dict, Any, Optional, Union, List, Callable +from abc import ABC, abstractmethod + +# Type hints for Meta classes +from benchmark.base import MetaObs, MetaAction, MetaEnv + + +class BaseCollector(ABC): + """ + Base class for data collectors. + + This class defines the interface for all data collectors in the RL framework. + Collectors gather experience data from environments by interacting with them + using the algorithm's policy. + + Attributes: + meta_envs: Environment(s) to collect data from + algorithm: Algorithm instance for action selection and transition recording + + Note: Collector only stores raw environment rewards, no reward function computation. + Reward functions are applied in the trainer during training time. + """ + + def __init__( + self, + meta_envs: Union['MetaEnv', List['MetaEnv'], Callable, Dict[str, Any]], + algorithm: 'BaseAlgorithm', + **kwargs + ): + """ + Initialize the collector. + + Args: + meta_envs: Supports multiple formats: + - MetaEnv instance: Single environment + - List[MetaEnv]: Environment list (same type environments) + - Callable: Environment factory function + - Dict[str, Any]: Multi-environment config dict (supports different env types) + algorithm: BaseAlgorithm instance (required) + - Used for action selection and transition recording + **kwargs: Collector-specific parameters + + Note: Collector only stores raw environment rewards, no reward function computation. + Reward functions are applied in trainer during training time. + """ + self.meta_envs = meta_envs + self.algorithm = algorithm + self._kwargs = kwargs + + @abstractmethod + def collect(self, n_steps: int, env_type: Optional[str] = None) -> Dict[str, Any]: + """ + Collect n_steps of interaction data. + + Args: + n_steps: Number of steps to collect + env_type: Optional, environment type identifier (for multi-environment scenarios) + - If provided, will be passed to record_transition with env_type + - Used to support a single algorithm storing data from multiple different environments + + Returns: + Dictionary containing statistics, such as: + - episode_rewards: List of episode rewards + - episode_lengths: List of episode lengths + - total_steps: Total number of steps collected + - env_type_stats: Statistics grouped by environment type (if multi-environment is supported) + """ + raise NotImplementedError + + @abstractmethod + def reset(self, **kwargs) -> None: + """ + Reset collector state (e.g., reset environments). + + Args: + **kwargs: Reset parameters + """ + raise NotImplementedError + + def get_envs(self) -> Union['MetaEnv', List['MetaEnv'], Dict[str, Any]]: + """Get the underlying environment(s).""" + return self.meta_envs + + def get_algorithm(self) -> 'BaseAlgorithm': + """Get the algorithm instance.""" + return self.algorithm + + def __repr__(self) -> str: + env_info = type(self.meta_envs).__name__ if not isinstance(self.meta_envs, (list, dict)) else f"List[{len(self.meta_envs)}]" if isinstance(self.meta_envs, list) else f"Dict[{len(self.meta_envs)}]" + return f"{self.__class__.__name__}(envs={env_info}, algorithm={self.algorithm.__class__.__name__})" + + +if __name__ == '__main__': + """ + Test code for BaseCollector class. + + Since BaseCollector is abstract, we create a simple concrete implementation for testing. + """ + import sys + sys.path.insert(0, '/home/zhang/robot/126/ILStudio') + + import torch + from benchmark.base import MetaObs, MetaAction, MetaEnv, MetaPolicy + from rl.buffer.base_replay import BaseReplay + from rl.base import BaseAlgorithm + from dataclasses import asdict + + # Simple replay buffer for testing + class SimpleReplay(BaseReplay): + def __init__(self, capacity=1000, device='cpu', **kwargs): + super().__init__(capacity=capacity, device=device, **kwargs) + self._storage = [] + + def add(self, transition): + if self._size < self.capacity: + self._storage.append(transition) + self._size += 1 + else: + self._storage[self._position] = transition + self._position = (self._position + 1) % self.capacity + + def sample(self, batch_size): + if self._size == 0: + return {} + indices = np.random.randint(0, self._size, size=min(batch_size, self._size)) + return { + 'states': [self._storage[i]['state'] for i in indices], + 'actions': [self._storage[i]['action'] for i in indices], + 'rewards': np.array([self._storage[i]['reward'] for i in indices]), + } + + def clear(self): + self._storage = [] + self._size = 0 + self._position = 0 + + def save(self, path, **kwargs): + pass + + def load(self, path, **kwargs): + pass + + # Simple dummy environment for testing + class DummyEnv: + def __init__(self, state_dim=10, action_dim=7): + self.state_dim = state_dim + self.action_dim = action_dim + self._step_count = 0 + self._max_steps = 100 + + def reset(self): + self._step_count = 0 + return {'state': np.random.randn(self.state_dim).astype(np.float32)} + + def step(self, action): + self._step_count += 1 + obs = {'state': np.random.randn(self.state_dim).astype(np.float32)} + reward = np.random.randn() + done = self._step_count >= self._max_steps + info = {'step': self._step_count} + if done: + info['episode'] = {'r': np.random.randn() * 10, 'l': self._step_count} + return obs, reward, done, info + + def close(self): + pass + + # Simple MetaEnv wrapper for testing + class DummyMetaEnv(MetaEnv): + def __init__(self, state_dim=10, action_dim=7, max_steps=100): + self.env = DummyEnv(state_dim=state_dim, action_dim=action_dim) + self.env._max_steps = max_steps + self.prev_obs = None + + def obs2meta(self, raw_obs): + return MetaObs( + state=raw_obs['state'], + raw_lang="test instruction" + ) + + def meta2act(self, action, *args): + if hasattr(action, 'action'): + return action.action + return action + + def reset(self): + init_obs = self.env.reset() + self.prev_obs = self.obs2meta(init_obs) + return self.prev_obs + + def step(self, action): + act = self.meta2act(action) + obs, reward, done, info = self.env.step(act) + self.prev_obs = self.obs2meta(obs) + return self.prev_obs, reward, done, info + + # Simple policy for testing + class DummyPolicy: + def __init__(self, action_dim=7): + self.action_dim = action_dim + + def select_action(self, obs): + return MetaAction( + action=np.random.randn(self.action_dim).astype(np.float32), + ctrl_space='ee', + ctrl_type='delta' + ) + + def train(self): + pass + + def eval(self): + pass + + # Simple MetaPolicy wrapper for testing + class DummyMetaPolicy(MetaPolicy): + def __init__(self, action_dim=7): + self.policy = DummyPolicy(action_dim=action_dim) + self.chunk_size = 1 + self.ctrl_space = 'ee' + self.ctrl_type = 'delta' + self.action_queue = [] + self.action_normalizer = None + self.state_normalizer = None + + def select_action(self, mobs, t=0, **kwargs): + return self.policy.select_action(mobs) + + # Simple algorithm for testing + class DummyAlgorithm(BaseAlgorithm): + def __init__(self, meta_policy, replay=None, **kwargs): + super().__init__(meta_policy=meta_policy, replay=replay, **kwargs) + self._timestep = 0 + + def update(self, batch=None, **kwargs): + return {'loss': 0.0} + + def select_action(self, obs, **kwargs): + return self.meta_policy.select_action(obs, t=self._timestep) + + # Simple collector implementation for testing + class SimpleCollector(BaseCollector): + """Simple collector for testing.""" + + def __init__( + self, + meta_envs: Union[MetaEnv, List[MetaEnv]], + algorithm: BaseAlgorithm, + **kwargs + ): + super().__init__(meta_envs, algorithm, **kwargs) + + # Initialize environments + if isinstance(meta_envs, list): + self.envs = meta_envs + elif isinstance(meta_envs, MetaEnv): + self.envs = [meta_envs] + else: + raise ValueError(f"Unsupported env type: {type(meta_envs)}") + + self._last_obs = None + self._last_dones = None + self._episode_rewards = [] + self._episode_lengths = [] + self._current_episode_reward = [0.0] * len(self.envs) + self._current_episode_length = [0] * len(self.envs) + + def reset(self, **kwargs) -> None: + """Reset all environments.""" + self._last_obs = [] + self._last_dones = [] + for env in self.envs: + obs = env.reset() + self._last_obs.append(obs) + self._last_dones.append(False) + self._current_episode_reward = [0.0] * len(self.envs) + self._current_episode_length = [0] * len(self.envs) + + def collect(self, n_steps: int, env_type: Optional[str] = None) -> Dict[str, Any]: + """ + Collect n_steps of interaction data. + + Args: + n_steps: Number of steps to collect + env_type: Optional environment type identifier + + Returns: + Statistics dictionary + """ + if self._last_obs is None: + self.reset() + + stats = { + 'episode_rewards': [], + 'episode_lengths': [], + 'total_steps': 0 + } + + for step in range(n_steps): + for i, (env, obs) in enumerate(zip(self.envs, self._last_obs)): + if self._last_dones[i]: + continue + + # Get action + with torch.no_grad(): + action = self.algorithm.select_action(obs) + + # Environment interaction + new_obs, reward, done, info = env.step(action) + + # Record transition (only store raw reward) + transition_kwargs = {} + if env_type is not None: + transition_kwargs['env_type'] = env_type + + self.algorithm.record_transition( + state=obs, + action=action, + reward=reward, # Store raw reward + next_state=new_obs, + done=done, + info=info, + **transition_kwargs + ) + + # Update episode statistics + self._current_episode_reward[i] += reward + self._current_episode_length[i] += 1 + + # If episode ended + if done: + stats['episode_rewards'].append(self._current_episode_reward[i]) + stats['episode_lengths'].append(self._current_episode_length[i]) + + # Reset environment and episode stats + new_obs = env.reset() + self._current_episode_reward[i] = 0.0 + self._current_episode_length[i] = 0 + + self._last_obs[i] = new_obs + self._last_dones[i] = done if not done else False # Reset done flag after reset + stats['total_steps'] += 1 + + return stats + + # Test the implementation + print("=" * 60) + print("Testing BaseCollector (SimpleCollector implementation)") + print("=" * 60) + + # Create environment, policy, algorithm + print("\n1. Creating components...") + env = DummyMetaEnv(state_dim=10, action_dim=7, max_steps=20) + meta_policy = DummyMetaPolicy(action_dim=7) + replay = SimpleReplay(capacity=1000) + algorithm = DummyAlgorithm(meta_policy=meta_policy, replay=replay) + + print(f" Environment: {env}") + print(f" MetaPolicy: {meta_policy}") + print(f" Algorithm: {algorithm}") + + # Create collector + print("\n2. Creating collector...") + collector = SimpleCollector( + meta_envs=env, + algorithm=algorithm + ) + print(f" Collector: {collector}") + + # Test reset + print("\n3. Testing reset...") + collector.reset() + print(" Reset successful") + + # Test collect + print("\n4. Testing collect...") + stats = collector.collect(n_steps=50) + print(f" Collected {stats['total_steps']} steps") + print(f" Episode rewards: {stats['episode_rewards']}") + print(f" Episode lengths: {stats['episode_lengths']}") + print(f" Replay buffer size: {len(algorithm.replay)}") + + # Test with env_type + print("\n5. Testing collect with env_type...") + # Create algorithm with multi-replay + multi_replay = { + 'indoor': SimpleReplay(capacity=1000), + 'outdoor': SimpleReplay(capacity=1000) + } + multi_algorithm = DummyAlgorithm(meta_policy=meta_policy, replay=multi_replay) + + indoor_env = DummyMetaEnv(state_dim=10, action_dim=7, max_steps=20) + indoor_collector = SimpleCollector( + meta_envs=indoor_env, + algorithm=multi_algorithm + ) + + # Collect to indoor replay + indoor_collector.reset() + indoor_stats = indoor_collector.collect(n_steps=30, env_type='indoor') + print(f" Indoor collected {indoor_stats['total_steps']} steps") + print(f" Indoor replay size: {len(multi_algorithm.replay['indoor'])}") + print(f" Outdoor replay size: {len(multi_algorithm.replay['outdoor'])}") + + # Test with multiple environments + print("\n6. Testing with multiple environments...") + envs = [ + DummyMetaEnv(state_dim=10, action_dim=7, max_steps=15), + DummyMetaEnv(state_dim=10, action_dim=7, max_steps=15), + ] + multi_env_collector = SimpleCollector( + meta_envs=envs, + algorithm=DummyAlgorithm(meta_policy=meta_policy, replay=SimpleReplay(capacity=1000)) + ) + multi_env_collector.reset() + multi_env_stats = multi_env_collector.collect(n_steps=50) + print(f" Collected from {len(envs)} environments") + print(f" Total steps: {multi_env_stats['total_steps']}") + print(f" Episodes completed: {len(multi_env_stats['episode_rewards'])}") + + print("\n" + "=" * 60) + print("All tests passed!") + print("=" * 60) + diff --git a/rl/infra/__init__.py b/rl/infra/__init__.py new file mode 100644 index 00000000..155f64da --- /dev/null +++ b/rl/infra/__init__.py @@ -0,0 +1,56 @@ +""" +RL Infrastructure Module + +This module provides infrastructure components for reproducibility and stability: +- SeedManager: Unified random seed management +- Logger: Training logging system (TensorBoard, WandB, etc.) +- Checkpoint: Model and training state checkpoint management +- Callback: Training callback system for hooks and monitoring +- Distributed: Distributed training support utilities + +Design Philosophy: +- Reproducibility: Ensure experiments can be reproduced exactly +- Stability: Provide robust training infrastructure +- Modularity: Each component can be used independently +- Extensibility: Easy to add new logging backends, callbacks, etc. +""" + +from .seed_manager import SeedManager, set_global_seed, get_global_seed +from .logger import BaseLogger, ConsoleLogger, TensorBoardLogger, CompositeLogger +from .checkpoint import CheckpointManager +from .callback import ( + Callback, CallbackList, + ProgressCallback, EvalCallback, CheckpointCallback, EarlyStoppingCallback +) +from .distributed import DistributedContext, get_world_size, get_rank, is_main_process + +__all__ = [ + # Seed management + 'SeedManager', + 'set_global_seed', + 'get_global_seed', + + # Logging + 'BaseLogger', + 'ConsoleLogger', + 'TensorBoardLogger', + 'CompositeLogger', + + # Checkpoint + 'CheckpointManager', + + # Callbacks + 'Callback', + 'CallbackList', + 'ProgressCallback', + 'EvalCallback', + 'CheckpointCallback', + 'EarlyStoppingCallback', + + # Distributed + 'DistributedContext', + 'get_world_size', + 'get_rank', + 'is_main_process', +] + diff --git a/rl/infra/callback.py b/rl/infra/callback.py new file mode 100644 index 00000000..5e9ffb42 --- /dev/null +++ b/rl/infra/callback.py @@ -0,0 +1,624 @@ +""" +Callback System for RL Training + +This module provides a flexible callback system for hooking into the training loop: +- Progress tracking and logging +- Evaluation during training +- Checkpoint saving +- Early stopping +- Custom callbacks + +Design Philosophy: +- Non-intrusive: Callbacks don't modify training logic +- Composable: Multiple callbacks can be combined +- Extensible: Easy to add custom callbacks +- Event-driven: Callbacks respond to training events +""" + +import time +import numpy as np +from abc import ABC, abstractmethod +from typing import Dict, Any, Optional, List, Callable, Union +from dataclasses import dataclass, field + + +@dataclass +class TrainingState: + """ + Container for training state passed to callbacks. + + This provides a standardized interface for callbacks to access training info. + """ + step: int = 0 + episode: int = 0 + total_timesteps: int = 0 + + # Episode info + episode_reward: float = 0.0 + episode_length: int = 0 + episode_rewards: List[float] = field(default_factory=list) + episode_lengths: List[int] = field(default_factory=list) + + # Training info + loss: Optional[float] = None + learning_rate: Optional[float] = None + + # Timing + fps: float = 0.0 + time_elapsed: float = 0.0 + + # Extra info + info: Dict[str, Any] = field(default_factory=dict) + + # Control flags + should_stop: bool = False + + +class Callback(ABC): + """ + Base class for training callbacks. + + Callbacks can hook into various points in the training loop: + - on_training_start: Called once at the beginning of training + - on_training_end: Called once at the end of training + - on_step_start: Called before each training step + - on_step_end: Called after each training step + - on_episode_start: Called at the start of each episode + - on_episode_end: Called at the end of each episode + - on_rollout_start: Called before data collection + - on_rollout_end: Called after data collection + - on_update_start: Called before policy update + - on_update_end: Called after policy update + + Override only the methods you need. + """ + + def __init__(self): + self.training_state: Optional[TrainingState] = None + self.trainer = None + self.logger = None + + def set_trainer(self, trainer) -> None: + """Set reference to the trainer.""" + self.trainer = trainer + + def set_logger(self, logger) -> None: + """Set reference to the logger.""" + self.logger = logger + + def on_training_start(self, state: TrainingState) -> None: + """Called at the beginning of training.""" + pass + + def on_training_end(self, state: TrainingState) -> None: + """Called at the end of training.""" + pass + + def on_step_start(self, state: TrainingState) -> None: + """Called before each training step.""" + pass + + def on_step_end(self, state: TrainingState) -> bool: + """ + Called after each training step. + + Returns: + True to continue training, False to stop + """ + return True + + def on_episode_start(self, state: TrainingState) -> None: + """Called at the start of each episode.""" + pass + + def on_episode_end(self, state: TrainingState) -> None: + """Called at the end of each episode.""" + pass + + def on_rollout_start(self, state: TrainingState) -> None: + """Called before data collection (rollout).""" + pass + + def on_rollout_end(self, state: TrainingState) -> None: + """Called after data collection (rollout).""" + pass + + def on_update_start(self, state: TrainingState) -> None: + """Called before policy update.""" + pass + + def on_update_end(self, state: TrainingState) -> None: + """Called after policy update.""" + pass + + +class CallbackList(Callback): + """ + Container for multiple callbacks. + + Forwards all events to contained callbacks in order. + """ + + def __init__(self, callbacks: Optional[List[Callback]] = None): + super().__init__() + self.callbacks = callbacks or [] + + def append(self, callback: Callback) -> None: + """Add a callback.""" + self.callbacks.append(callback) + + def set_trainer(self, trainer) -> None: + """Set trainer for all callbacks.""" + self.trainer = trainer + for callback in self.callbacks: + callback.set_trainer(trainer) + + def set_logger(self, logger) -> None: + """Set logger for all callbacks.""" + self.logger = logger + for callback in self.callbacks: + callback.set_logger(logger) + + def on_training_start(self, state: TrainingState) -> None: + for callback in self.callbacks: + callback.on_training_start(state) + + def on_training_end(self, state: TrainingState) -> None: + for callback in self.callbacks: + callback.on_training_end(state) + + def on_step_start(self, state: TrainingState) -> None: + for callback in self.callbacks: + callback.on_step_start(state) + + def on_step_end(self, state: TrainingState) -> bool: + continue_training = True + for callback in self.callbacks: + if not callback.on_step_end(state): + continue_training = False + return continue_training + + def on_episode_start(self, state: TrainingState) -> None: + for callback in self.callbacks: + callback.on_episode_start(state) + + def on_episode_end(self, state: TrainingState) -> None: + for callback in self.callbacks: + callback.on_episode_end(state) + + def on_rollout_start(self, state: TrainingState) -> None: + for callback in self.callbacks: + callback.on_rollout_start(state) + + def on_rollout_end(self, state: TrainingState) -> None: + for callback in self.callbacks: + callback.on_rollout_end(state) + + def on_update_start(self, state: TrainingState) -> None: + for callback in self.callbacks: + callback.on_update_start(state) + + def on_update_end(self, state: TrainingState) -> None: + for callback in self.callbacks: + callback.on_update_end(state) + + +class ProgressCallback(Callback): + """ + Callback for logging training progress. + + Logs metrics at specified intervals. + """ + + def __init__( + self, + log_interval: int = 100, + verbose: int = 1 + ): + """ + Initialize progress callback. + + Args: + log_interval: Steps between log outputs + verbose: Verbosity level (0=silent, 1=normal, 2=detailed) + """ + super().__init__() + self.log_interval = log_interval + self.verbose = verbose + self._start_time = None + self._last_log_step = 0 + + def on_training_start(self, state: TrainingState) -> None: + self._start_time = time.time() + if self.verbose >= 1: + print("=" * 60) + print("Training started") + print("=" * 60) + + def on_training_end(self, state: TrainingState) -> None: + if self.verbose >= 1: + elapsed = time.time() - self._start_time + print("=" * 60) + print(f"Training completed in {elapsed:.1f}s") + print(f"Total steps: {state.step}") + print(f"Total episodes: {state.episode}") + if state.episode_rewards: + print(f"Final mean reward: {np.mean(state.episode_rewards[-100:]):.2f}") + print("=" * 60) + + def on_step_end(self, state: TrainingState) -> bool: + if state.step % self.log_interval == 0 and state.step > self._last_log_step: + self._last_log_step = state.step + self._log_progress(state) + return True + + def _log_progress(self, state: TrainingState) -> None: + """Log current progress.""" + if self.verbose == 0: + return + + elapsed = time.time() - self._start_time + fps = state.step / elapsed if elapsed > 0 else 0 + + # Build log message + parts = [f"Step: {state.step:>8}"] + + if state.episode_rewards: + mean_reward = np.mean(state.episode_rewards[-100:]) + parts.append(f"Reward: {mean_reward:>8.2f}") + + if state.loss is not None: + parts.append(f"Loss: {state.loss:>8.4f}") + + parts.append(f"FPS: {fps:>6.0f}") + parts.append(f"Time: {elapsed:>6.0f}s") + + print(" | ".join(parts)) + + if self.logger: + metrics = { + 'progress/fps': fps, + 'progress/time_elapsed': elapsed + } + if state.episode_rewards: + metrics['progress/mean_reward'] = np.mean(state.episode_rewards[-100:]) + self.logger.log_scalars(metrics, step=state.step) + + +class EvalCallback(Callback): + """ + Callback for periodic evaluation during training. + + Evaluates the policy on a separate environment at specified intervals. + """ + + def __init__( + self, + eval_fn: Callable[[int], Dict[str, float]], + eval_interval: int = 10000, + n_eval_episodes: int = 10, + verbose: int = 1 + ): + """ + Initialize evaluation callback. + + Args: + eval_fn: Function that takes step and returns eval metrics + eval_interval: Steps between evaluations + n_eval_episodes: Number of episodes for evaluation + verbose: Verbosity level + """ + super().__init__() + self.eval_fn = eval_fn + self.eval_interval = eval_interval + self.n_eval_episodes = n_eval_episodes + self.verbose = verbose + + self._last_eval_step = 0 + self.eval_results: List[Dict[str, Any]] = [] + + def on_step_end(self, state: TrainingState) -> bool: + if state.step >= self._last_eval_step + self.eval_interval: + self._last_eval_step = state.step + self._evaluate(state) + return True + + def _evaluate(self, state: TrainingState) -> None: + """Run evaluation.""" + if self.verbose >= 1: + print(f"\n[Eval @ step {state.step}]", end=" ") + + # Run evaluation + eval_metrics = self.eval_fn(state.step) + + # Store results + result = { + 'step': state.step, + **eval_metrics + } + self.eval_results.append(result) + + # Log + if self.verbose >= 1: + metrics_str = " | ".join([f"{k}: {v:.2f}" for k, v in eval_metrics.items()]) + print(metrics_str) + + if self.logger: + self.logger.log_scalars( + {f"eval/{k}": v for k, v in eval_metrics.items()}, + step=state.step + ) + + +class CheckpointCallback(Callback): + """ + Callback for saving checkpoints during training. + + Saves checkpoints at specified intervals and keeps best checkpoint. + """ + + def __init__( + self, + checkpoint_manager: 'CheckpointManager', + save_interval: int = 10000, + save_on_best: bool = True, + metric_name: str = "episode_reward", + verbose: int = 1 + ): + """ + Initialize checkpoint callback. + + Args: + checkpoint_manager: CheckpointManager instance + save_interval: Steps between checkpoint saves + save_on_best: Whether to save when best metric is achieved + metric_name: Metric to track for best model + verbose: Verbosity level + """ + super().__init__() + self.checkpoint_manager = checkpoint_manager + self.save_interval = save_interval + self.save_on_best = save_on_best + self.metric_name = metric_name + self.verbose = verbose + + self._last_save_step = 0 + self._best_metric = None + + def on_step_end(self, state: TrainingState) -> bool: + # Save at interval + if state.step >= self._last_save_step + self.save_interval: + self._last_save_step = state.step + self._save_checkpoint(state, is_best=False) + + # Save on best + if self.save_on_best: + metric = self._get_metric(state) + if metric is not None: + if self._best_metric is None or metric > self._best_metric: + self._best_metric = metric + self._save_checkpoint(state, is_best=True) + + return True + + def on_training_end(self, state: TrainingState) -> None: + """Save final checkpoint.""" + self._save_checkpoint(state, is_best=False) + + def _get_metric(self, state: TrainingState) -> Optional[float]: + """Get the metric value for best model tracking.""" + if self.metric_name == "episode_reward" and state.episode_rewards: + return np.mean(state.episode_rewards[-10:]) + elif self.metric_name in state.info: + return state.info[self.metric_name] + return None + + def _save_checkpoint(self, state: TrainingState, is_best: bool) -> None: + """Save checkpoint using the trainer's save method.""" + if self.trainer is None: + return + + metric = self._get_metric(state) + + # Get model and optimizer state from trainer + save_dict = {} + if hasattr(self.trainer, 'algorithm'): + alg = self.trainer.algorithm + if hasattr(alg, 'meta_policy') and hasattr(alg.meta_policy, 'policy'): + policy = alg.meta_policy.policy + if hasattr(policy, 'state_dict'): + save_dict['model'] = policy.state_dict() + + path = self.checkpoint_manager.save( + step=state.step, + episode=state.episode, + total_timesteps=state.total_timesteps, + reward=metric, + is_best=is_best, + **save_dict + ) + + if self.verbose >= 1 and is_best: + print(f"\n[Checkpoint] Saved best model @ step {state.step} (metric: {metric:.2f})") + + +class EarlyStoppingCallback(Callback): + """ + Callback for early stopping based on a metric. + + Stops training if metric doesn't improve for a specified number of steps. + """ + + def __init__( + self, + patience: int = 50000, + min_delta: float = 0.0, + metric_name: str = "episode_reward", + mode: str = "max", + verbose: int = 1 + ): + """ + Initialize early stopping callback. + + Args: + patience: Number of steps to wait for improvement + min_delta: Minimum change to qualify as improvement + metric_name: Metric to monitor + mode: 'max' or 'min' - whether higher or lower is better + verbose: Verbosity level + """ + super().__init__() + self.patience = patience + self.min_delta = min_delta + self.metric_name = metric_name + self.mode = mode + self.verbose = verbose + + self._best_metric = None + self._steps_without_improvement = 0 + self._last_check_step = 0 + + def on_step_end(self, state: TrainingState) -> bool: + # Check every 1000 steps + if state.step < self._last_check_step + 1000: + return True + self._last_check_step = state.step + + metric = self._get_metric(state) + if metric is None: + return True + + improved = False + if self._best_metric is None: + improved = True + elif self.mode == "max" and metric > self._best_metric + self.min_delta: + improved = True + elif self.mode == "min" and metric < self._best_metric - self.min_delta: + improved = True + + if improved: + self._best_metric = metric + self._steps_without_improvement = 0 + else: + self._steps_without_improvement += 1000 + + # Check for early stopping + if self._steps_without_improvement >= self.patience: + if self.verbose >= 1: + print(f"\n[Early Stopping] No improvement for {self.patience} steps. Stopping training.") + state.should_stop = True + return False + + return True + + def _get_metric(self, state: TrainingState) -> Optional[float]: + """Get the metric value.""" + if self.metric_name == "episode_reward" and state.episode_rewards: + return np.mean(state.episode_rewards[-10:]) + elif self.metric_name in state.info: + return state.info[self.metric_name] + return None + + +if __name__ == '__main__': + """ + Test code for Callback module. + """ + print("=" * 60) + print("Testing Callback Module") + print("=" * 60) + + # Test 1: TrainingState + print("\n1. Testing TrainingState...") + state = TrainingState( + step=100, + episode=5, + episode_reward=50.0, + episode_rewards=[40.0, 45.0, 50.0] + ) + print(f" Step: {state.step}, Episode: {state.episode}") + print(f" Episode rewards: {state.episode_rewards}") + + # Test 2: ProgressCallback + print("\n2. Testing ProgressCallback...") + progress_cb = ProgressCallback(log_interval=50, verbose=1) + + # Simulate training + state = TrainingState() + progress_cb.on_training_start(state) + + for step in range(200): + state.step = step + state.episode_rewards.append(np.random.randn() * 10 + 50) + state.loss = 1.0 - step * 0.001 + progress_cb.on_step_end(state) + + progress_cb.on_training_end(state) + + # Test 3: CallbackList + print("\n3. Testing CallbackList...") + + class CountingCallback(Callback): + def __init__(self): + super().__init__() + self.step_count = 0 + self.episode_count = 0 + + def on_step_end(self, state): + self.step_count += 1 + return True + + def on_episode_end(self, state): + self.episode_count += 1 + + cb1 = CountingCallback() + cb2 = CountingCallback() + callback_list = CallbackList([cb1, cb2]) + + state = TrainingState() + for _ in range(10): + callback_list.on_step_end(state) + + callback_list.on_episode_end(state) + callback_list.on_episode_end(state) + + print(f" CB1 step count: {cb1.step_count}, episode count: {cb1.episode_count}") + print(f" CB2 step count: {cb2.step_count}, episode count: {cb2.episode_count}") + assert cb1.step_count == 10 and cb2.step_count == 10 + assert cb1.episode_count == 2 and cb2.episode_count == 2 + + # Test 4: EarlyStoppingCallback + print("\n4. Testing EarlyStoppingCallback...") + early_stop = EarlyStoppingCallback(patience=5000, verbose=0) + + state = TrainingState() + # Simulate no improvement + for step in range(10): + state.step = step * 1000 + state.episode_rewards = [50.0] * 10 # Constant reward + result = early_stop.on_step_end(state) + + print(f" Steps without improvement: {early_stop._steps_without_improvement}") + print(f" Should continue: {result}") + + # Test 5: EvalCallback + print("\n5. Testing EvalCallback...") + + def dummy_eval_fn(step): + return {'mean_reward': 50 + step * 0.001, 'success_rate': 0.8} + + eval_cb = EvalCallback( + eval_fn=dummy_eval_fn, + eval_interval=5000, + verbose=1 + ) + + state = TrainingState(step=10000) + eval_cb.on_step_end(state) + + print(f" Eval results: {eval_cb.eval_results}") + + print("\n" + "=" * 60) + print("All tests passed!") + print("=" * 60) + diff --git a/rl/infra/checkpoint.py b/rl/infra/checkpoint.py new file mode 100644 index 00000000..6c96b10d --- /dev/null +++ b/rl/infra/checkpoint.py @@ -0,0 +1,468 @@ +""" +Checkpoint Manager for RL Training + +This module provides checkpoint management for saving and loading training state: +- Model weights (policy, value network, etc.) +- Optimizer states +- Training progress (step, episode, etc.) +- Replay buffer (optional) +- Configuration and hyperparameters + +Design Philosophy: +- Atomic saves: Use temporary files to prevent corruption +- Version tracking: Track checkpoint versions for compatibility +- Selective loading: Load only specific components if needed +- Automatic cleanup: Keep only N most recent checkpoints +""" + +import os +import json +import shutil +import glob +import torch +import numpy as np +from typing import Dict, Any, Optional, Union, List, Callable +from datetime import datetime +from pathlib import Path +from dataclasses import dataclass, asdict + + +@dataclass +class CheckpointMetadata: + """Metadata for a checkpoint.""" + version: str = "1.0" + timestamp: str = "" + step: int = 0 + episode: int = 0 + total_timesteps: int = 0 + best_reward: Optional[float] = None + extra_info: Optional[Dict[str, Any]] = None + + def __post_init__(self): + if not self.timestamp: + self.timestamp = datetime.now().isoformat() + + +class CheckpointManager: + """ + Manager for saving and loading training checkpoints. + + Features: + - Save/load model weights, optimizer states, training state + - Atomic saves with temporary files + - Automatic cleanup of old checkpoints + - Best model tracking + - Version compatibility checking + + Usage: + manager = CheckpointManager( + checkpoint_dir="checkpoints/", + max_to_keep=5 + ) + + # Save checkpoint + manager.save( + step=1000, + model=policy.state_dict(), + optimizer=optimizer.state_dict(), + config=config_dict + ) + + # Load checkpoint + state = manager.load_latest() + policy.load_state_dict(state['model']) + """ + + VERSION = "1.0" + + def __init__( + self, + checkpoint_dir: str, + max_to_keep: int = 5, + keep_best: bool = True, + save_optimizer: bool = True, + save_replay_buffer: bool = False, + checkpoint_prefix: str = "ckpt", + **kwargs + ): + """ + Initialize checkpoint manager. + + Args: + checkpoint_dir: Directory to save checkpoints + max_to_keep: Maximum number of checkpoints to keep (0 = unlimited) + keep_best: Whether to always keep the best checkpoint + save_optimizer: Whether to save optimizer state by default + save_replay_buffer: Whether to save replay buffer by default + checkpoint_prefix: Prefix for checkpoint filenames + """ + self.checkpoint_dir = Path(checkpoint_dir) + self.max_to_keep = max_to_keep + self.keep_best = keep_best + self.save_optimizer = save_optimizer + self.save_replay_buffer = save_replay_buffer + self.checkpoint_prefix = checkpoint_prefix + + self.checkpoint_dir.mkdir(parents=True, exist_ok=True) + + self._best_reward: Optional[float] = None + self._best_checkpoint: Optional[str] = None + + def _get_checkpoint_path(self, step: int) -> Path: + """Get checkpoint file path for a given step.""" + return self.checkpoint_dir / f"{self.checkpoint_prefix}_{step:08d}.pt" + + def _get_best_checkpoint_path(self) -> Path: + """Get path for best checkpoint.""" + return self.checkpoint_dir / f"{self.checkpoint_prefix}_best.pt" + + def _get_metadata_path(self) -> Path: + """Get path for metadata file.""" + return self.checkpoint_dir / "checkpoint_metadata.json" + + def save( + self, + step: int, + model: Optional[Dict[str, Any]] = None, + optimizer: Optional[Dict[str, Any]] = None, + scheduler: Optional[Dict[str, Any]] = None, + replay_buffer: Optional[Any] = None, + config: Optional[Dict[str, Any]] = None, + episode: int = 0, + total_timesteps: int = 0, + reward: Optional[float] = None, + extra: Optional[Dict[str, Any]] = None, + is_best: bool = False, + **kwargs + ) -> str: + """ + Save a checkpoint. + + Args: + step: Current training step + model: Model state dict (or dict of state dicts) + optimizer: Optimizer state dict (or dict of state dicts) + scheduler: Learning rate scheduler state dict + replay_buffer: Replay buffer to save (if save_replay_buffer=True) + config: Configuration dictionary + episode: Current episode number + total_timesteps: Total environment timesteps + reward: Current reward (for best model tracking) + extra: Extra data to save + is_best: Force save as best checkpoint + **kwargs: Additional state to save + + Returns: + Path to saved checkpoint + """ + # Create checkpoint data + checkpoint = { + 'metadata': asdict(CheckpointMetadata( + version=self.VERSION, + step=step, + episode=episode, + total_timesteps=total_timesteps, + best_reward=reward, + extra_info=extra + )), + 'step': step, + 'episode': episode, + 'total_timesteps': total_timesteps, + } + + if model is not None: + checkpoint['model'] = model + + if optimizer is not None and self.save_optimizer: + checkpoint['optimizer'] = optimizer + + if scheduler is not None: + checkpoint['scheduler'] = scheduler + + if config is not None: + checkpoint['config'] = config + + if replay_buffer is not None and self.save_replay_buffer: + # Save replay buffer separately (can be large) + buffer_path = self.checkpoint_dir / f"{self.checkpoint_prefix}_{step:08d}_buffer.pt" + torch.save(replay_buffer, buffer_path) + checkpoint['replay_buffer_path'] = str(buffer_path) + + # Add any extra kwargs + checkpoint.update(kwargs) + + # Save to temporary file first (atomic save) + checkpoint_path = self._get_checkpoint_path(step) + temp_path = checkpoint_path.with_suffix('.tmp') + + torch.save(checkpoint, temp_path) + temp_path.rename(checkpoint_path) + + # Update best checkpoint + if is_best or (reward is not None and (self._best_reward is None or reward > self._best_reward)): + self._best_reward = reward + self._best_checkpoint = str(checkpoint_path) + + # Copy to best checkpoint + best_path = self._get_best_checkpoint_path() + shutil.copy(checkpoint_path, best_path) + + # Save metadata + self._save_metadata() + + # Cleanup old checkpoints + self._cleanup_old_checkpoints() + + return str(checkpoint_path) + + def load( + self, + path: Optional[str] = None, + step: Optional[int] = None, + load_best: bool = False, + load_optimizer: bool = True, + load_replay_buffer: bool = False, + map_location: Optional[Union[str, torch.device]] = None + ) -> Dict[str, Any]: + """ + Load a checkpoint. + + Args: + path: Direct path to checkpoint file + step: Load checkpoint from specific step + load_best: Load best checkpoint + load_optimizer: Whether to load optimizer state + load_replay_buffer: Whether to load replay buffer + map_location: Device to map tensors to + + Returns: + Checkpoint dictionary + """ + # Determine which checkpoint to load + if path is not None: + checkpoint_path = Path(path) + elif load_best: + checkpoint_path = self._get_best_checkpoint_path() + elif step is not None: + checkpoint_path = self._get_checkpoint_path(step) + else: + checkpoint_path = self._get_latest_checkpoint_path() + + if not checkpoint_path or not checkpoint_path.exists(): + raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") + + # Load checkpoint + checkpoint = torch.load(checkpoint_path, map_location=map_location) + + # Optionally remove optimizer state + if not load_optimizer and 'optimizer' in checkpoint: + del checkpoint['optimizer'] + + # Optionally load replay buffer + if load_replay_buffer and 'replay_buffer_path' in checkpoint: + buffer_path = checkpoint['replay_buffer_path'] + if os.path.exists(buffer_path): + checkpoint['replay_buffer'] = torch.load(buffer_path, map_location=map_location) + + return checkpoint + + def load_latest(self, **kwargs) -> Dict[str, Any]: + """Load the most recent checkpoint.""" + return self.load(**kwargs) + + def load_best(self, **kwargs) -> Dict[str, Any]: + """Load the best checkpoint.""" + return self.load(load_best=True, **kwargs) + + def _get_latest_checkpoint_path(self) -> Optional[Path]: + """Get path to the most recent checkpoint.""" + pattern = str(self.checkpoint_dir / f"{self.checkpoint_prefix}_*.pt") + checkpoints = glob.glob(pattern) + + # Filter out best checkpoint and buffer files + checkpoints = [c for c in checkpoints if not c.endswith('_best.pt') and '_buffer.pt' not in c] + + if not checkpoints: + return None + + # Sort by step number and return latest + checkpoints.sort() + return Path(checkpoints[-1]) + + def _cleanup_old_checkpoints(self) -> None: + """Remove old checkpoints, keeping only max_to_keep most recent.""" + if self.max_to_keep <= 0: + return + + pattern = str(self.checkpoint_dir / f"{self.checkpoint_prefix}_*.pt") + checkpoints = glob.glob(pattern) + + # Filter out best checkpoint and buffer files + checkpoints = [c for c in checkpoints if not c.endswith('_best.pt') and '_buffer.pt' not in c] + checkpoints.sort() + + # Keep only the most recent + to_remove = checkpoints[:-self.max_to_keep] if len(checkpoints) > self.max_to_keep else [] + + for checkpoint in to_remove: + # Don't remove best checkpoint + if self.keep_best and checkpoint == self._best_checkpoint: + continue + + os.remove(checkpoint) + + # Also remove associated buffer file + buffer_path = checkpoint.replace('.pt', '_buffer.pt') + if os.path.exists(buffer_path): + os.remove(buffer_path) + + def _save_metadata(self) -> None: + """Save checkpoint metadata to JSON file.""" + metadata = { + 'best_reward': self._best_reward, + 'best_checkpoint': self._best_checkpoint, + 'version': self.VERSION + } + + metadata_path = self._get_metadata_path() + with open(metadata_path, 'w') as f: + json.dump(metadata, f, indent=2) + + def _load_metadata(self) -> None: + """Load checkpoint metadata from JSON file.""" + metadata_path = self._get_metadata_path() + if metadata_path.exists(): + with open(metadata_path, 'r') as f: + metadata = json.load(f) + self._best_reward = metadata.get('best_reward') + self._best_checkpoint = metadata.get('best_checkpoint') + + def list_checkpoints(self) -> List[Dict[str, Any]]: + """ + List all available checkpoints. + + Returns: + List of checkpoint info dictionaries + """ + pattern = str(self.checkpoint_dir / f"{self.checkpoint_prefix}_*.pt") + checkpoints = glob.glob(pattern) + checkpoints = [c for c in checkpoints if not c.endswith('_best.pt') and '_buffer.pt' not in c] + checkpoints.sort() + + result = [] + for ckpt_path in checkpoints: + # Extract step from filename + filename = os.path.basename(ckpt_path) + step_str = filename.replace(f"{self.checkpoint_prefix}_", "").replace(".pt", "") + try: + step = int(step_str) + except ValueError: + step = -1 + + result.append({ + 'path': ckpt_path, + 'step': step, + 'is_best': ckpt_path == self._best_checkpoint, + 'size_mb': os.path.getsize(ckpt_path) / (1024 * 1024) + }) + + return result + + def has_checkpoint(self) -> bool: + """Check if any checkpoint exists.""" + return self._get_latest_checkpoint_path() is not None + + def get_best_reward(self) -> Optional[float]: + """Get the best reward seen so far.""" + return self._best_reward + + +if __name__ == '__main__': + """ + Test code for CheckpointManager. + """ + import tempfile + + print("=" * 60) + print("Testing CheckpointManager") + print("=" * 60) + + with tempfile.TemporaryDirectory() as tmpdir: + # Test 1: Basic save/load + print("\n1. Testing basic save/load...") + manager = CheckpointManager( + checkpoint_dir=tmpdir, + max_to_keep=3 + ) + + # Create dummy model and optimizer state + model_state = {'layer1.weight': torch.randn(10, 5), 'layer1.bias': torch.randn(10)} + optimizer_state = {'state': {}, 'param_groups': [{'lr': 0.001}]} + config = {'learning_rate': 0.001, 'batch_size': 32} + + # Save checkpoint + path = manager.save( + step=100, + model=model_state, + optimizer=optimizer_state, + config=config, + episode=10, + reward=50.0 + ) + print(f" Saved checkpoint to: {path}") + + # Load checkpoint + loaded = manager.load() + print(f" Loaded step: {loaded['step']}") + print(f" Loaded episode: {loaded['episode']}") + print(f" Model keys: {list(loaded['model'].keys())}") + + # Verify model weights match + assert torch.allclose(loaded['model']['layer1.weight'], model_state['layer1.weight']) + print(" Model weights match!") + + # Test 2: Multiple checkpoints and cleanup + print("\n2. Testing multiple checkpoints and cleanup...") + for step in [200, 300, 400, 500]: + manager.save( + step=step, + model=model_state, + reward=step * 0.1 + ) + + checkpoints = manager.list_checkpoints() + print(f" Number of checkpoints: {len(checkpoints)}") + print(f" Checkpoint steps: {[c['step'] for c in checkpoints]}") + + # Should keep at most max_to_keep (3) + possibly the first checkpoint (if it's best) + # The cleanup keeps max_to_keep most recent, plus best if keep_best=True + assert len(checkpoints) <= 4, f"Expected at most 4 checkpoints, got {len(checkpoints)}" + + # Test 3: Best checkpoint + print("\n3. Testing best checkpoint...") + print(f" Best reward: {manager.get_best_reward()}") + + best_ckpt = manager.load_best() + print(f" Best checkpoint step: {best_ckpt['step']}") + + # Test 4: Load specific step + print("\n4. Testing load specific step...") + ckpt_400 = manager.load(step=400) + print(f" Loaded step 400: {ckpt_400['step']}") + assert ckpt_400['step'] == 400 + + # Test 5: has_checkpoint + print("\n5. Testing has_checkpoint...") + print(f" Has checkpoint: {manager.has_checkpoint()}") + assert manager.has_checkpoint() + + # Test 6: New manager loads metadata + print("\n6. Testing metadata persistence...") + new_manager = CheckpointManager(checkpoint_dir=tmpdir, max_to_keep=3) + new_manager._load_metadata() + print(f" Loaded best reward: {new_manager.get_best_reward()}") + + print("\n" + "=" * 60) + print("All tests passed!") + print("=" * 60) + diff --git a/rl/infra/distributed.py b/rl/infra/distributed.py new file mode 100644 index 00000000..2c2a343b --- /dev/null +++ b/rl/infra/distributed.py @@ -0,0 +1,487 @@ +""" +Distributed Training Support for RL + +This module provides utilities for distributed and parallel training: +- Multi-GPU training support +- Multi-process data collection +- Weight synchronization +- Gradient aggregation + +Design Philosophy: +- Transparent: Code works with or without distributed setup +- Compatible: Works with PyTorch DDP and other backends +- Flexible: Support various parallelism strategies +""" + +import os +import torch +import torch.distributed as dist +from typing import Optional, List, Any, Dict, Union, Callable +from contextlib import contextmanager + + +# Global distributed state +_DISTRIBUTED_CONTEXT: Optional['DistributedContext'] = None + + +class DistributedContext: + """ + Context manager for distributed training. + + Handles initialization and cleanup of distributed training environment. + Supports PyTorch DDP and manual multi-process setups. + + Usage: + # Single process (no distribution) + ctx = DistributedContext() + + # Multi-GPU DDP + ctx = DistributedContext( + backend='nccl', + init_method='env://' + ) + + with ctx: + # Training code + pass + """ + + def __init__( + self, + backend: str = 'nccl', + init_method: Optional[str] = None, + world_size: Optional[int] = None, + rank: Optional[int] = None, + local_rank: Optional[int] = None, + timeout_minutes: int = 30, + auto_init: bool = True, + **kwargs + ): + """ + Initialize distributed context. + + Args: + backend: Distributed backend ('nccl', 'gloo', 'mpi') + init_method: URL specifying how to initialize process group + world_size: Total number of processes (auto-detect from env if None) + rank: Global rank of this process (auto-detect from env if None) + local_rank: Local rank on this node (auto-detect from env if None) + timeout_minutes: Timeout for distributed operations + auto_init: Whether to auto-initialize if env vars are set + """ + self.backend = backend + self.init_method = init_method + self.timeout_minutes = timeout_minutes + self._initialized = False + self._kwargs = kwargs + + # Try to get from environment variables + self.world_size = world_size or self._get_env_int('WORLD_SIZE', 1) + self.rank = rank or self._get_env_int('RANK', 0) + self.local_rank = local_rank or self._get_env_int('LOCAL_RANK', 0) + + # Auto-initialize if running in distributed mode + if auto_init and self.world_size > 1: + self.init() + + @staticmethod + def _get_env_int(key: str, default: int) -> int: + """Get integer from environment variable.""" + val = os.environ.get(key) + if val is not None: + try: + return int(val) + except ValueError: + pass + return default + + def init(self) -> None: + """Initialize distributed process group.""" + if self._initialized: + return + + if self.world_size <= 1: + # Single process, no distribution needed + self._initialized = True + return + + if dist.is_initialized(): + # Already initialized elsewhere + self.world_size = dist.get_world_size() + self.rank = dist.get_rank() + self._initialized = True + return + + # Initialize process group + init_method = self.init_method or 'env://' + timeout = torch.distributed.default_pg_timeout + if hasattr(torch.distributed, 'init_process_group'): + from datetime import timedelta + timeout = timedelta(minutes=self.timeout_minutes) + + dist.init_process_group( + backend=self.backend, + init_method=init_method, + world_size=self.world_size, + rank=self.rank, + timeout=timeout + ) + + # Set CUDA device if available + if torch.cuda.is_available() and self.local_rank < torch.cuda.device_count(): + torch.cuda.set_device(self.local_rank) + + self._initialized = True + + # Set global context + global _DISTRIBUTED_CONTEXT + _DISTRIBUTED_CONTEXT = self + + def cleanup(self) -> None: + """Cleanup distributed process group.""" + if self._initialized and dist.is_initialized(): + dist.destroy_process_group() + self._initialized = False + + global _DISTRIBUTED_CONTEXT + if _DISTRIBUTED_CONTEXT is self: + _DISTRIBUTED_CONTEXT = None + + def __enter__(self): + self.init() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.cleanup() + return False + + @property + def is_initialized(self) -> bool: + """Check if distributed is initialized.""" + return self._initialized + + @property + def is_distributed(self) -> bool: + """Check if running in distributed mode.""" + return self.world_size > 1 + + @property + def is_main_process(self) -> bool: + """Check if this is the main process (rank 0).""" + return self.rank == 0 + + @property + def device(self) -> torch.device: + """Get the device for this process.""" + if torch.cuda.is_available() and self.local_rank < torch.cuda.device_count(): + return torch.device(f'cuda:{self.local_rank}') + return torch.device('cpu') + + def barrier(self) -> None: + """Synchronization barrier across all processes.""" + if self.is_distributed and dist.is_initialized(): + dist.barrier() + + def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor: + """Broadcast tensor from src to all processes.""" + if self.is_distributed and dist.is_initialized(): + dist.broadcast(tensor, src=src) + return tensor + + def all_reduce( + self, + tensor: torch.Tensor, + op: dist.ReduceOp = dist.ReduceOp.SUM + ) -> torch.Tensor: + """All-reduce tensor across all processes.""" + if self.is_distributed and dist.is_initialized(): + dist.all_reduce(tensor, op=op) + return tensor + + def all_gather(self, tensor: torch.Tensor) -> List[torch.Tensor]: + """Gather tensors from all processes.""" + if not self.is_distributed or not dist.is_initialized(): + return [tensor] + + tensor_list = [torch.zeros_like(tensor) for _ in range(self.world_size)] + dist.all_gather(tensor_list, tensor) + return tensor_list + + def reduce_mean(self, tensor: torch.Tensor) -> torch.Tensor: + """Reduce tensor with mean across all processes.""" + if self.is_distributed and dist.is_initialized(): + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + tensor = tensor / self.world_size + return tensor + + def sync_params(self, model: torch.nn.Module, src: int = 0) -> None: + """Synchronize model parameters from src to all processes.""" + if not self.is_distributed or not dist.is_initialized(): + return + + for param in model.parameters(): + dist.broadcast(param.data, src=src) + + def sync_gradients(self, model: torch.nn.Module) -> None: + """Synchronize gradients across all processes (average).""" + if not self.is_distributed or not dist.is_initialized(): + return + + for param in model.parameters(): + if param.grad is not None: + dist.all_reduce(param.grad, op=dist.ReduceOp.SUM) + param.grad = param.grad / self.world_size + + +# Convenience functions +def get_distributed_context() -> Optional[DistributedContext]: + """Get the global distributed context.""" + return _DISTRIBUTED_CONTEXT + + +def get_world_size() -> int: + """Get world size (total number of processes).""" + if _DISTRIBUTED_CONTEXT is not None: + return _DISTRIBUTED_CONTEXT.world_size + if dist.is_initialized(): + return dist.get_world_size() + return 1 + + +def get_rank() -> int: + """Get global rank of current process.""" + if _DISTRIBUTED_CONTEXT is not None: + return _DISTRIBUTED_CONTEXT.rank + if dist.is_initialized(): + return dist.get_rank() + return 0 + + +def get_local_rank() -> int: + """Get local rank of current process.""" + if _DISTRIBUTED_CONTEXT is not None: + return _DISTRIBUTED_CONTEXT.local_rank + return int(os.environ.get('LOCAL_RANK', 0)) + + +def is_main_process() -> bool: + """Check if this is the main process (rank 0).""" + return get_rank() == 0 + + +def is_distributed() -> bool: + """Check if running in distributed mode.""" + return get_world_size() > 1 + + +def barrier() -> None: + """Synchronization barrier across all processes.""" + if _DISTRIBUTED_CONTEXT is not None: + _DISTRIBUTED_CONTEXT.barrier() + elif dist.is_initialized(): + dist.barrier() + + +@contextmanager +def main_process_first(): + """ + Context manager to ensure main process runs first. + + Useful for downloading files or creating directories. + + Usage: + with main_process_first(): + # Only main process runs this first, others wait + download_dataset() + """ + if not is_main_process(): + barrier() + + yield + + if is_main_process(): + barrier() + + +def print_once(*args, **kwargs) -> None: + """Print only on main process.""" + if is_main_process(): + print(*args, **kwargs) + + +def reduce_dict( + data: Dict[str, float], + average: bool = True +) -> Dict[str, float]: + """ + Reduce a dictionary of values across all processes. + + Args: + data: Dictionary with scalar values + average: If True, compute average; if False, compute sum + + Returns: + Reduced dictionary + """ + if not is_distributed() or not dist.is_initialized(): + return data + + world_size = get_world_size() + + # Stack values into tensor + keys = sorted(data.keys()) + values = torch.tensor([data[k] for k in keys], dtype=torch.float32) + + if torch.cuda.is_available(): + values = values.cuda() + + dist.all_reduce(values, op=dist.ReduceOp.SUM) + + if average: + values = values / world_size + + return {k: v.item() for k, v in zip(keys, values)} + + +class DistributedSampler: + """ + Simple distributed sampler for replay buffers. + + Ensures each process samples different data. + """ + + def __init__( + self, + dataset_size: int, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + seed: int = 0 + ): + """ + Initialize distributed sampler. + + Args: + dataset_size: Total size of the dataset + num_replicas: Number of processes (auto-detect if None) + rank: Rank of current process (auto-detect if None) + shuffle: Whether to shuffle indices + seed: Random seed for shuffling + """ + self.dataset_size = dataset_size + self.num_replicas = num_replicas or get_world_size() + self.rank = rank or get_rank() + self.shuffle = shuffle + self.seed = seed + self.epoch = 0 + + def __iter__(self): + """Generate indices for this process.""" + if self.shuffle: + # Deterministic shuffling based on epoch and seed + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = torch.randperm(self.dataset_size, generator=g).tolist() + else: + indices = list(range(self.dataset_size)) + + # Subsample for this replica + indices = indices[self.rank::self.num_replicas] + return iter(indices) + + def __len__(self): + """Return number of samples for this process.""" + return (self.dataset_size + self.num_replicas - 1) // self.num_replicas + + def set_epoch(self, epoch: int) -> None: + """Set epoch for deterministic shuffling.""" + self.epoch = epoch + + +if __name__ == '__main__': + """ + Test code for Distributed module. + """ + print("=" * 60) + print("Testing Distributed Module") + print("=" * 60) + + # Test 1: Basic context (single process) + print("\n1. Testing DistributedContext (single process)...") + ctx = DistributedContext(auto_init=False) + ctx.init() + + print(f" World size: {ctx.world_size}") + print(f" Rank: {ctx.rank}") + print(f" Local rank: {ctx.local_rank}") + print(f" Is distributed: {ctx.is_distributed}") + print(f" Is main process: {ctx.is_main_process}") + print(f" Device: {ctx.device}") + + ctx.cleanup() + + # Test 2: Convenience functions + print("\n2. Testing convenience functions...") + print(f" get_world_size(): {get_world_size()}") + print(f" get_rank(): {get_rank()}") + print(f" get_local_rank(): {get_local_rank()}") + print(f" is_main_process(): {is_main_process()}") + print(f" is_distributed(): {is_distributed()}") + + # Test 3: print_once + print("\n3. Testing print_once...") + print_once(" This should print (main process)") + + # Test 4: reduce_dict (single process - no-op) + print("\n4. Testing reduce_dict (single process)...") + data = {'loss': 0.5, 'reward': 100.0} + reduced = reduce_dict(data) + print(f" Input: {data}") + print(f" Output: {reduced}") + + # Test 5: DistributedSampler + print("\n5. Testing DistributedSampler...") + sampler = DistributedSampler( + dataset_size=100, + num_replicas=4, + rank=0, + shuffle=True, + seed=42 + ) + + indices = list(sampler) + print(f" Dataset size: 100, Replicas: 4") + print(f" Samples for rank 0: {len(indices)}") + print(f" First 10 indices: {indices[:10]}") + + # Test different rank + sampler_rank1 = DistributedSampler( + dataset_size=100, + num_replicas=4, + rank=1, + shuffle=True, + seed=42 + ) + indices_rank1 = list(sampler_rank1) + print(f" First 10 indices (rank 1): {indices_rank1[:10]}") + + # Verify no overlap + overlap = set(indices) & set(indices_rank1) + print(f" Overlap between rank 0 and 1: {len(overlap)} (should be 0)") + assert len(overlap) == 0, "Samplers should not overlap!" + + # Test 6: Context manager + print("\n6. Testing context manager...") + with DistributedContext(auto_init=False) as ctx: + print(f" Inside context: rank={ctx.rank}") + print(" Context exited successfully") + + # Test 7: main_process_first context + print("\n7. Testing main_process_first...") + with main_process_first(): + print(" Main process first block executed") + + print("\n" + "=" * 60) + print("All tests passed!") + print("=" * 60) + diff --git a/rl/infra/logger.py b/rl/infra/logger.py new file mode 100644 index 00000000..f9243480 --- /dev/null +++ b/rl/infra/logger.py @@ -0,0 +1,593 @@ +""" +Logger Module for RL Training + +This module provides a flexible logging system for RL training, supporting: +- Console logging with progress bars +- TensorBoard logging +- WandB logging (optional) +- Composite logging (multiple backends) + +Design Philosophy: +- Unified interface: All loggers implement the same BaseLogger interface +- Lazy initialization: Heavy backends (TensorBoard, WandB) are initialized lazily +- Thread-safe: Support concurrent logging from multiple processes +- Metric aggregation: Support for rolling window statistics +""" + +import os +import sys +import time +import json +import numpy as np +from abc import ABC, abstractmethod +from typing import Dict, Any, Optional, Union, List +from collections import defaultdict, deque +from datetime import datetime +from pathlib import Path + + +class MetricTracker: + """ + Track metrics with rolling window statistics. + + Supports computing mean, std, min, max over a sliding window. + """ + + def __init__(self, window_size: int = 100): + """ + Initialize metric tracker. + + Args: + window_size: Size of the rolling window for statistics + """ + self.window_size = window_size + self._values: Dict[str, deque] = defaultdict(lambda: deque(maxlen=window_size)) + self._total_counts: Dict[str, int] = defaultdict(int) + self._total_sums: Dict[str, float] = defaultdict(float) + + def add(self, key: str, value: float) -> None: + """Add a value to a metric.""" + self._values[key].append(value) + self._total_counts[key] += 1 + self._total_sums[key] += value + + def get_mean(self, key: str) -> Optional[float]: + """Get rolling mean of a metric.""" + if key not in self._values or len(self._values[key]) == 0: + return None + return np.mean(self._values[key]) + + def get_std(self, key: str) -> Optional[float]: + """Get rolling std of a metric.""" + if key not in self._values or len(self._values[key]) < 2: + return None + return np.std(self._values[key]) + + def get_min(self, key: str) -> Optional[float]: + """Get rolling min of a metric.""" + if key not in self._values or len(self._values[key]) == 0: + return None + return np.min(self._values[key]) + + def get_max(self, key: str) -> Optional[float]: + """Get rolling max of a metric.""" + if key not in self._values or len(self._values[key]) == 0: + return None + return np.max(self._values[key]) + + def get_total_mean(self, key: str) -> Optional[float]: + """Get total mean (not windowed) of a metric.""" + if self._total_counts[key] == 0: + return None + return self._total_sums[key] / self._total_counts[key] + + def get_latest(self, key: str) -> Optional[float]: + """Get latest value of a metric.""" + if key not in self._values or len(self._values[key]) == 0: + return None + return self._values[key][-1] + + def get_count(self, key: str) -> int: + """Get total count of a metric.""" + return self._total_counts[key] + + def get_stats(self, key: str) -> Dict[str, Optional[float]]: + """Get all statistics for a metric.""" + return { + 'mean': self.get_mean(key), + 'std': self.get_std(key), + 'min': self.get_min(key), + 'max': self.get_max(key), + 'latest': self.get_latest(key), + 'total_mean': self.get_total_mean(key), + 'count': self.get_count(key) + } + + def keys(self) -> List[str]: + """Get all tracked metric keys.""" + return list(self._values.keys()) + + def clear(self) -> None: + """Clear all metrics.""" + self._values.clear() + self._total_counts.clear() + self._total_sums.clear() + + +class BaseLogger(ABC): + """ + Base class for all loggers. + + Provides a unified interface for logging training metrics. + """ + + def __init__( + self, + log_dir: Optional[str] = None, + name: str = "rl_training", + **kwargs + ): + """ + Initialize logger. + + Args: + log_dir: Directory to save logs + name: Name of the experiment/run + **kwargs: Additional logger-specific arguments + """ + self.log_dir = log_dir + self.name = name + self._step = 0 + self._start_time = time.time() + self.metrics = MetricTracker() + + if log_dir: + os.makedirs(log_dir, exist_ok=True) + + @abstractmethod + def log_scalar(self, key: str, value: float, step: Optional[int] = None) -> None: + """ + Log a scalar value. + + Args: + key: Metric name + value: Metric value + step: Step number (uses internal counter if None) + """ + raise NotImplementedError + + @abstractmethod + def log_scalars(self, data: Dict[str, float], step: Optional[int] = None) -> None: + """ + Log multiple scalar values. + + Args: + data: Dictionary of metric name -> value + step: Step number (uses internal counter if None) + """ + raise NotImplementedError + + def log_histogram(self, key: str, values: np.ndarray, step: Optional[int] = None) -> None: + """Log a histogram of values (optional, not all loggers support this).""" + pass + + def log_image(self, key: str, image: np.ndarray, step: Optional[int] = None) -> None: + """Log an image (optional, not all loggers support this).""" + pass + + def log_video(self, key: str, video: np.ndarray, step: Optional[int] = None, fps: int = 30) -> None: + """Log a video (optional, not all loggers support this).""" + pass + + def log_text(self, key: str, text: str, step: Optional[int] = None) -> None: + """Log text (optional, not all loggers support this).""" + pass + + def log_hyperparams(self, params: Dict[str, Any]) -> None: + """Log hyperparameters.""" + pass + + def set_step(self, step: int) -> None: + """Set the current step.""" + self._step = step + + def get_step(self) -> int: + """Get the current step.""" + return self._step + + def increment_step(self, n: int = 1) -> int: + """Increment step and return new value.""" + self._step += n + return self._step + + def get_elapsed_time(self) -> float: + """Get elapsed time since logger creation.""" + return time.time() - self._start_time + + def close(self) -> None: + """Close the logger and release resources.""" + pass + + def flush(self) -> None: + """Flush any buffered data.""" + pass + + +class ConsoleLogger(BaseLogger): + """ + Simple console logger with optional progress bar. + + Outputs training metrics to console/terminal. + """ + + def __init__( + self, + log_dir: Optional[str] = None, + name: str = "rl_training", + log_interval: int = 100, + verbose: int = 1, + **kwargs + ): + """ + Initialize console logger. + + Args: + log_dir: Directory to save logs (also saves to file if provided) + name: Name of the experiment + log_interval: Steps between log outputs + verbose: Verbosity level (0=silent, 1=normal, 2=detailed) + """ + super().__init__(log_dir=log_dir, name=name, **kwargs) + self.log_interval = log_interval + self.verbose = verbose + self._file = None + + if log_dir: + log_file = os.path.join(log_dir, f"{name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log") + self._file = open(log_file, 'w') + + def _format_value(self, value: float) -> str: + """Format a value for display.""" + if abs(value) < 0.001 or abs(value) > 10000: + return f"{value:.3e}" + return f"{value:.4f}" + + def _write(self, message: str) -> None: + """Write message to console and optionally to file.""" + if self.verbose > 0: + print(message) + if self._file: + self._file.write(message + "\n") + self._file.flush() + + def log_scalar(self, key: str, value: float, step: Optional[int] = None) -> None: + """Log a scalar value.""" + step = step if step is not None else self._step + self.metrics.add(key, value) + + if self.verbose >= 2 or (step % self.log_interval == 0): + elapsed = self.get_elapsed_time() + self._write(f"[{step:>8}] {key}: {self._format_value(value)} (elapsed: {elapsed:.1f}s)") + + def log_scalars(self, data: Dict[str, float], step: Optional[int] = None) -> None: + """Log multiple scalar values.""" + step = step if step is not None else self._step + + for key, value in data.items(): + self.metrics.add(key, value) + + if step % self.log_interval == 0: + elapsed = self.get_elapsed_time() + metrics_str = " | ".join([f"{k}: {self._format_value(v)}" for k, v in data.items()]) + self._write(f"[{step:>8}] {metrics_str} (elapsed: {elapsed:.1f}s)") + + def log_hyperparams(self, params: Dict[str, Any]) -> None: + """Log hyperparameters.""" + self._write("\n" + "=" * 60) + self._write("Hyperparameters:") + self._write("=" * 60) + for key, value in params.items(): + self._write(f" {key}: {value}") + self._write("=" * 60 + "\n") + + # Also save to JSON file + if self.log_dir: + params_file = os.path.join(self.log_dir, "hyperparams.json") + with open(params_file, 'w') as f: + json.dump(params, f, indent=2, default=str) + + def close(self) -> None: + """Close the logger.""" + if self._file: + self._file.close() + self._file = None + + +class TensorBoardLogger(BaseLogger): + """ + TensorBoard logger for rich visualization. + + Supports scalars, histograms, images, and more. + """ + + def __init__( + self, + log_dir: str, + name: str = "rl_training", + flush_secs: int = 30, + **kwargs + ): + """ + Initialize TensorBoard logger. + + Args: + log_dir: Directory to save TensorBoard logs + name: Name of the experiment + flush_secs: Flush interval in seconds + """ + super().__init__(log_dir=log_dir, name=name, **kwargs) + self.flush_secs = flush_secs + self._writer = None + + def _get_writer(self): + """Lazy initialization of SummaryWriter.""" + if self._writer is None: + try: + from torch.utils.tensorboard import SummaryWriter + except ImportError: + raise ImportError( + "TensorBoard not installed. Install with: pip install tensorboard" + ) + + log_path = os.path.join(self.log_dir, self.name) + self._writer = SummaryWriter(log_dir=log_path, flush_secs=self.flush_secs) + return self._writer + + def log_scalar(self, key: str, value: float, step: Optional[int] = None) -> None: + """Log a scalar value.""" + step = step if step is not None else self._step + self.metrics.add(key, value) + self._get_writer().add_scalar(key, value, step) + + def log_scalars(self, data: Dict[str, float], step: Optional[int] = None) -> None: + """Log multiple scalar values.""" + step = step if step is not None else self._step + for key, value in data.items(): + self.metrics.add(key, value) + self._get_writer().add_scalar(key, value, step) + + def log_histogram(self, key: str, values: np.ndarray, step: Optional[int] = None) -> None: + """Log a histogram.""" + step = step if step is not None else self._step + self._get_writer().add_histogram(key, values, step) + + def log_image(self, key: str, image: np.ndarray, step: Optional[int] = None) -> None: + """Log an image (expects HWC or CHW format).""" + step = step if step is not None else self._step + # Convert HWC to CHW if needed + if image.ndim == 3 and image.shape[-1] in [1, 3, 4]: + image = np.transpose(image, (2, 0, 1)) + self._get_writer().add_image(key, image, step) + + def log_video(self, key: str, video: np.ndarray, step: Optional[int] = None, fps: int = 30) -> None: + """Log a video (expects THWC or NTCHW format).""" + step = step if step is not None else self._step + # Add batch dimension if needed (THWC -> NTCHW) + if video.ndim == 4: + video = video[np.newaxis, ...] # Add N dimension + video = np.transpose(video, (0, 1, 4, 2, 3)) # NTHWC -> NTCHW + self._get_writer().add_video(key, video, step, fps=fps) + + def log_text(self, key: str, text: str, step: Optional[int] = None) -> None: + """Log text.""" + step = step if step is not None else self._step + self._get_writer().add_text(key, text, step) + + def log_hyperparams(self, params: Dict[str, Any]) -> None: + """Log hyperparameters.""" + # Filter out non-serializable values + filtered_params = {} + for k, v in params.items(): + if isinstance(v, (int, float, str, bool, type(None))): + filtered_params[k] = v + else: + filtered_params[k] = str(v) + + self._get_writer().add_hparams(filtered_params, {}) + + def flush(self) -> None: + """Flush buffered data.""" + if self._writer: + self._writer.flush() + + def close(self) -> None: + """Close the logger.""" + if self._writer: + self._writer.close() + self._writer = None + + +class CompositeLogger(BaseLogger): + """ + Composite logger that forwards logs to multiple backends. + + Usage: + logger = CompositeLogger([ + ConsoleLogger(verbose=1), + TensorBoardLogger(log_dir="logs/tb") + ]) + """ + + def __init__( + self, + loggers: List[BaseLogger], + log_dir: Optional[str] = None, + name: str = "rl_training", + **kwargs + ): + """ + Initialize composite logger. + + Args: + loggers: List of logger instances to forward to + log_dir: Optional directory (passed to base class) + name: Name of the experiment + """ + super().__init__(log_dir=log_dir, name=name, **kwargs) + self.loggers = loggers + + def log_scalar(self, key: str, value: float, step: Optional[int] = None) -> None: + """Log a scalar to all backends.""" + step = step if step is not None else self._step + self.metrics.add(key, value) + for logger in self.loggers: + logger.log_scalar(key, value, step) + + def log_scalars(self, data: Dict[str, float], step: Optional[int] = None) -> None: + """Log scalars to all backends.""" + step = step if step is not None else self._step + for key, value in data.items(): + self.metrics.add(key, value) + for logger in self.loggers: + logger.log_scalars(data, step) + + def log_histogram(self, key: str, values: np.ndarray, step: Optional[int] = None) -> None: + """Log histogram to all backends that support it.""" + step = step if step is not None else self._step + for logger in self.loggers: + logger.log_histogram(key, values, step) + + def log_image(self, key: str, image: np.ndarray, step: Optional[int] = None) -> None: + """Log image to all backends that support it.""" + step = step if step is not None else self._step + for logger in self.loggers: + logger.log_image(key, image, step) + + def log_video(self, key: str, video: np.ndarray, step: Optional[int] = None, fps: int = 30) -> None: + """Log video to all backends that support it.""" + step = step if step is not None else self._step + for logger in self.loggers: + logger.log_video(key, video, step, fps) + + def log_text(self, key: str, text: str, step: Optional[int] = None) -> None: + """Log text to all backends that support it.""" + step = step if step is not None else self._step + for logger in self.loggers: + logger.log_text(key, text, step) + + def log_hyperparams(self, params: Dict[str, Any]) -> None: + """Log hyperparameters to all backends.""" + for logger in self.loggers: + logger.log_hyperparams(params) + + def set_step(self, step: int) -> None: + """Set step for all loggers.""" + self._step = step + for logger in self.loggers: + logger.set_step(step) + + def flush(self) -> None: + """Flush all loggers.""" + for logger in self.loggers: + logger.flush() + + def close(self) -> None: + """Close all loggers.""" + for logger in self.loggers: + logger.close() + + +if __name__ == '__main__': + """ + Test code for Logger module. + """ + import tempfile + import shutil + + print("=" * 60) + print("Testing Logger Module") + print("=" * 60) + + # Test 1: MetricTracker + print("\n1. Testing MetricTracker...") + tracker = MetricTracker(window_size=5) + + for i in range(10): + tracker.add("loss", 1.0 - i * 0.1) + tracker.add("reward", i * 10) + + print(f" Loss stats: {tracker.get_stats('loss')}") + print(f" Reward stats: {tracker.get_stats('reward')}") + print(f" Tracked keys: {tracker.keys()}") + + # Test 2: ConsoleLogger + print("\n2. Testing ConsoleLogger...") + with tempfile.TemporaryDirectory() as tmpdir: + console_logger = ConsoleLogger( + log_dir=tmpdir, + name="test_console", + log_interval=2, + verbose=1 + ) + + console_logger.log_hyperparams({ + 'learning_rate': 0.001, + 'batch_size': 32, + 'algorithm': 'PPO' + }) + + for step in range(5): + console_logger.log_scalars({ + 'loss': 1.0 - step * 0.1, + 'reward': step * 10 + }, step=step) + console_logger.increment_step() + + print(f" Elapsed time: {console_logger.get_elapsed_time():.2f}s") + console_logger.close() + + # Test 3: TensorBoardLogger (if tensorboard is installed) + print("\n3. Testing TensorBoardLogger...") + try: + with tempfile.TemporaryDirectory() as tmpdir: + tb_logger = TensorBoardLogger( + log_dir=tmpdir, + name="test_tb" + ) + + for step in range(10): + tb_logger.log_scalars({ + 'loss': 1.0 - step * 0.05, + 'reward': step * 5 + }, step=step) + + # Test histogram + tb_logger.log_histogram("weights", np.random.randn(1000), step=0) + + tb_logger.flush() + tb_logger.close() + print(" TensorBoard logger test passed!") + except ImportError: + print(" TensorBoard not installed, skipping...") + + # Test 4: CompositeLogger + print("\n4. Testing CompositeLogger...") + with tempfile.TemporaryDirectory() as tmpdir: + composite_logger = CompositeLogger( + loggers=[ + ConsoleLogger(log_interval=5, verbose=1), + ], + log_dir=tmpdir, + name="test_composite" + ) + + for step in range(10): + composite_logger.log_scalar("test_metric", step * 0.5, step=step) + + composite_logger.close() + print(" Composite logger test passed!") + + print("\n" + "=" * 60) + print("All tests passed!") + print("=" * 60) + diff --git a/rl/infra/seed_manager.py b/rl/infra/seed_manager.py new file mode 100644 index 00000000..f3654b61 --- /dev/null +++ b/rl/infra/seed_manager.py @@ -0,0 +1,357 @@ +""" +Seed Manager for Reproducibility + +This module provides unified random seed management to ensure reproducibility +across different libraries (NumPy, PyTorch, Python random, etc.). + +Features: +- Unified seed setting for all random number generators +- Support for deterministic operations in PyTorch +- Worker seed management for parallel data loading +- Global seed state tracking +""" + +import os +import random +import numpy as np +import torch +from typing import Optional, Callable +from dataclasses import dataclass + + +# Global seed state +_GLOBAL_SEED: Optional[int] = None + + +@dataclass +class SeedState: + """Container for random state from different libraries.""" + python_state: tuple + numpy_state: dict + torch_state: torch.Tensor + torch_cuda_state: Optional[list] = None + + +class SeedManager: + """ + Unified random seed manager for reproducibility. + + This class manages random seeds for: + - Python's random module + - NumPy's random number generator + - PyTorch's CPU and GPU random number generators + - CUDA deterministic operations + + Usage: + # Simple usage + SeedManager.set_seed(42) + + # With deterministic mode (slower but fully reproducible) + SeedManager.set_seed(42, deterministic=True) + + # Save and restore state + state = SeedManager.get_state() + # ... do something ... + SeedManager.set_state(state) + """ + + @staticmethod + def set_seed( + seed: int, + deterministic: bool = False, + benchmark: bool = True, + warn_only: bool = False + ) -> None: + """ + Set random seed for all libraries. + + Args: + seed: Random seed value + deterministic: If True, enable deterministic algorithms in PyTorch + (may reduce performance but ensures reproducibility) + benchmark: If True, enable cuDNN benchmark mode for faster training + (disable when deterministic=True for full reproducibility) + warn_only: If True, only warn instead of error when deterministic + operations are not available + """ + global _GLOBAL_SEED + _GLOBAL_SEED = seed + + # Python random + random.seed(seed) + + # NumPy + np.random.seed(seed) + + # PyTorch + torch.manual_seed(seed) + + # PyTorch CUDA + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) # For multi-GPU + + # Deterministic mode + if deterministic: + # Disable benchmark for determinism + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + + # Use deterministic algorithms + if hasattr(torch, 'use_deterministic_algorithms'): + torch.use_deterministic_algorithms(True, warn_only=warn_only) + elif hasattr(torch, 'set_deterministic'): + torch.set_deterministic(True) + else: + # Enable benchmark for performance + torch.backends.cudnn.benchmark = benchmark + torch.backends.cudnn.deterministic = False + + # Environment variable for hash seed + os.environ['PYTHONHASHSEED'] = str(seed) + + @staticmethod + def get_seed() -> Optional[int]: + """Get the current global seed.""" + return _GLOBAL_SEED + + @staticmethod + def get_state() -> SeedState: + """ + Get current random state from all libraries. + + Returns: + SeedState containing states from all RNGs + """ + cuda_state = None + if torch.cuda.is_available(): + cuda_state = [torch.cuda.get_rng_state(i) for i in range(torch.cuda.device_count())] + + return SeedState( + python_state=random.getstate(), + numpy_state=np.random.get_state(), + torch_state=torch.get_rng_state(), + torch_cuda_state=cuda_state + ) + + @staticmethod + def set_state(state: SeedState) -> None: + """ + Restore random state for all libraries. + + Args: + state: SeedState to restore + """ + random.setstate(state.python_state) + np.random.set_state(state.numpy_state) + torch.set_rng_state(state.torch_state) + + if state.torch_cuda_state is not None and torch.cuda.is_available(): + for i, cuda_state in enumerate(state.torch_cuda_state): + if i < torch.cuda.device_count(): + torch.cuda.set_rng_state(cuda_state, i) + + @staticmethod + def worker_init_fn(worker_id: int) -> None: + """ + Worker initialization function for PyTorch DataLoader. + + Use this as the `worker_init_fn` argument in DataLoader to ensure + each worker has a different but reproducible random seed. + + Args: + worker_id: Worker ID (0, 1, 2, ...) + + Usage: + DataLoader(..., worker_init_fn=SeedManager.worker_init_fn) + """ + global _GLOBAL_SEED + if _GLOBAL_SEED is not None: + worker_seed = _GLOBAL_SEED + worker_id + else: + worker_seed = torch.initial_seed() % 2**32 + + random.seed(worker_seed) + np.random.seed(worker_seed) + + @staticmethod + def fork_rng(devices: Optional[list] = None, enabled: bool = True) -> 'RNGForkContext': + """ + Fork the RNG state for a context (useful for dropout in eval mode). + + Args: + devices: List of CUDA devices to fork RNG state for + enabled: Whether to actually fork (useful for conditional forking) + + Returns: + Context manager that restores RNG state on exit + + Usage: + with SeedManager.fork_rng(): + # Operations here won't affect global RNG state + pass + """ + return RNGForkContext(devices=devices, enabled=enabled) + + +class RNGForkContext: + """Context manager for forking RNG state.""" + + def __init__(self, devices: Optional[list] = None, enabled: bool = True): + self.devices = devices + self.enabled = enabled + self._state: Optional[SeedState] = None + + def __enter__(self): + if self.enabled: + self._state = SeedManager.get_state() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.enabled and self._state is not None: + SeedManager.set_state(self._state) + return False + + +# Convenience functions +def set_global_seed( + seed: int, + deterministic: bool = False, + benchmark: bool = True +) -> None: + """ + Set global random seed for reproducibility. + + Convenience function that calls SeedManager.set_seed(). + + Args: + seed: Random seed value + deterministic: Enable deterministic mode (slower but fully reproducible) + benchmark: Enable cuDNN benchmark (disable when deterministic=True) + """ + SeedManager.set_seed(seed, deterministic=deterministic, benchmark=benchmark) + + +def get_global_seed() -> Optional[int]: + """Get the current global seed.""" + return SeedManager.get_seed() + + +if __name__ == '__main__': + """ + Test code for SeedManager. + """ + print("=" * 60) + print("Testing SeedManager") + print("=" * 60) + + # Test 1: Basic seed setting + print("\n1. Testing basic seed setting...") + SeedManager.set_seed(42) + + # Generate some random numbers + py_random1 = [random.random() for _ in range(5)] + np_random1 = np.random.rand(5).tolist() + torch_random1 = torch.rand(5).tolist() + + # Reset and generate again + SeedManager.set_seed(42) + py_random2 = [random.random() for _ in range(5)] + np_random2 = np.random.rand(5).tolist() + torch_random2 = torch.rand(5).tolist() + + print(f" Python random match: {py_random1 == py_random2}") + print(f" NumPy random match: {np_random1 == np_random2}") + print(f" PyTorch random match: {torch_random1 == torch_random2}") + + assert py_random1 == py_random2, "Python random not reproducible" + assert np_random1 == np_random2, "NumPy random not reproducible" + assert torch_random1 == torch_random2, "PyTorch random not reproducible" + + # Test 2: Get/Set state + print("\n2. Testing state save/restore...") + SeedManager.set_seed(123) + + # Generate some numbers + _ = random.random() + _ = np.random.rand() + _ = torch.rand(1) + + # Save state + state = SeedManager.get_state() + + # Generate more numbers + val1 = random.random() + val2 = np.random.rand() + val3 = torch.rand(1).item() + + # Restore state + SeedManager.set_state(state) + + # Generate again - should match + val1_restored = random.random() + val2_restored = np.random.rand() + val3_restored = torch.rand(1).item() + + print(f" Python state restored: {val1 == val1_restored}") + print(f" NumPy state restored: {val2 == val2_restored}") + print(f" PyTorch state restored: {val3 == val3_restored}") + + assert val1 == val1_restored + assert val2 == val2_restored + assert val3 == val3_restored + + # Test 3: Fork RNG context + print("\n3. Testing RNG fork context...") + SeedManager.set_seed(456) + + before = random.random() + + # Save current position + SeedManager.set_seed(456) + _ = random.random() # Advance to same position + + with SeedManager.fork_rng(): + # This should not affect the outer state + for _ in range(10): + random.random() + + after = random.random() + + # Reset and check + SeedManager.set_seed(456) + _ = random.random() + expected_after = random.random() + + print(f" Fork context preserved state: {after == expected_after}") + + # Test 4: Global seed tracking + print("\n4. Testing global seed tracking...") + SeedManager.set_seed(789) + print(f" Global seed: {SeedManager.get_seed()}") + assert SeedManager.get_seed() == 789 + + # Test convenience functions + set_global_seed(999) + print(f" After set_global_seed: {get_global_seed()}") + assert get_global_seed() == 999 + + # Test 5: Deterministic mode + print("\n5. Testing deterministic mode...") + SeedManager.set_seed(42, deterministic=True) + print(" Deterministic mode enabled (slower but fully reproducible)") + + # Test 6: Worker init function + print("\n6. Testing worker init function...") + SeedManager.set_seed(42) + + # Simulate worker initialization + for worker_id in range(3): + SeedManager.worker_init_fn(worker_id) + worker_val = random.random() + print(f" Worker {worker_id} random value: {worker_val:.6f}") + + print("\n" + "=" * 60) + print("All tests passed!") + print("=" * 60) + diff --git a/rl/rewards/__init__.py b/rl/rewards/__init__.py new file mode 100644 index 00000000..7aca62ae --- /dev/null +++ b/rl/rewards/__init__.py @@ -0,0 +1,86 @@ +""" +Reward Functions Module + +This module provides modular reward functions for RL training. + +Design Philosophy: +- Modular: Reward functions are independent modules, easy to replace and extend +- Composable: Support combining multiple reward functions +- Language-conditioned: Support VLA's language-conditioned rewards + +Available reward functions: +- SparseReward: Only give reward when task is completed +- DenseReward: Distance-based reward +- LearnedReward: Learned reward model +- LanguageReward: Language-conditioned reward (for VLA) +- CompositeReward: Combine multiple reward functions + +Note: Implementations are provided in separate files. +This __init__.py provides factory functions for creating reward functions. +""" + +from typing import Type, Dict, Any + +from .base_reward import BaseReward + +# Registry for reward classes +_REWARD_REGISTRY: Dict[str, Type] = {} + + +def register_reward(name: str, reward_class: Type) -> None: + """ + Register a reward function class. + + Args: + name: Reward function name (e.g., 'sparse', 'dense') + reward_class: Reward function class to register + """ + _REWARD_REGISTRY[name.lower()] = reward_class + + +def get_reward_class(name_or_type: str) -> Type: + """ + Get reward function class by name or type string. + + Args: + name_or_type: Reward name (e.g., 'sparse') or full type path + (e.g., 'rl.rewards.sparse_reward.SparseReward') + + Returns: + Reward function class + + Raises: + ValueError: If reward function not found + """ + # First check registry + if name_or_type.lower() in _REWARD_REGISTRY: + return _REWARD_REGISTRY[name_or_type.lower()] + + # Try to import from type path + if '.' in name_or_type: + try: + parts = name_or_type.rsplit('.', 1) + module_path = parts[0] + class_name = parts[1] + + import importlib + module = importlib.import_module(module_path) + return getattr(module, class_name) + except (ImportError, AttributeError) as e: + raise ValueError(f"Cannot import reward function from '{name_or_type}': {e}") + + raise ValueError(f"Unknown reward function: '{name_or_type}'. Available: {list(_REWARD_REGISTRY.keys())}") + + +def list_rewards() -> list: + """List all registered reward functions.""" + return list(_REWARD_REGISTRY.keys()) + + +__all__ = [ + 'BaseReward', + 'register_reward', + 'get_reward_class', + 'list_rewards', +] + diff --git a/rl/rewards/base_reward.py b/rl/rewards/base_reward.py new file mode 100644 index 00000000..d204dd48 --- /dev/null +++ b/rl/rewards/base_reward.py @@ -0,0 +1,346 @@ +""" +Base Reward Function Class + +This module defines the base class for all reward functions in the RL framework. + +Design Philosophy: +- Modular: Reward functions are independent modules, easy to replace and extend +- Composable: Support combining multiple reward functions +- Language-conditioned: Support VLA's language-conditioned rewards +""" + +import numpy as np +from typing import Dict, Any, Optional +from abc import ABC, abstractmethod + +# Type hints for Meta classes +from benchmark.base import MetaObs, MetaAction + + +class BaseReward(ABC): + """ + Base class for reward functions. + + This class defines the interface for all reward functions in the RL framework. + Reward functions compute custom rewards based on state, action, and other information. + + Note: The reward function is used in the Trainer during training time. + The replay buffer stores raw environment rewards, and the reward function + is applied during the update step for computing the training reward. + """ + + def __init__(self, **kwargs): + """ + Initialize the reward function. + + Args: + **kwargs: Reward function specific parameters + """ + self._kwargs = kwargs + + @abstractmethod + def compute( + self, + state: MetaObs, + action: MetaAction, + next_state: MetaObs, + env_reward: float, + info: Optional[Dict[str, Any]] = None + ) -> float: + """ + Compute the reward. + + Args: + state: Current state (MetaObs) + action: Action (MetaAction) + next_state: Next state (MetaObs) + env_reward: Environment's raw reward + info: Additional information dictionary + + Returns: + Computed reward value + """ + raise NotImplementedError + + def reset(self, **kwargs) -> None: + """ + Reset reward function state (if needed). + + Some reward functions may have internal state (e.g., running statistics, + episode counters) that need to be reset at the beginning of an episode. + + Args: + **kwargs: Reset parameters + """ + pass + + def __call__( + self, + state: MetaObs, + action: MetaAction, + next_state: MetaObs, + env_reward: float, + info: Optional[Dict[str, Any]] = None + ) -> float: + """ + Callable interface for computing reward. + + This allows the reward function to be used as a callable: + reward = reward_fn(state, action, next_state, env_reward, info) + + Args: + state: Current state (MetaObs) + action: Action (MetaAction) + next_state: Next state (MetaObs) + env_reward: Environment's raw reward + info: Additional information dictionary + + Returns: + Computed reward value + """ + return self.compute(state, action, next_state, env_reward, info) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self._kwargs})" + + +class IdentityReward(BaseReward): + """ + Identity reward function - returns the environment reward unchanged. + + This is the default reward function when no custom reward is specified. + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def compute( + self, + state: MetaObs, + action: MetaAction, + next_state: MetaObs, + env_reward: float, + info: Optional[Dict[str, Any]] = None + ) -> float: + """Return the environment reward unchanged.""" + return env_reward + + +class ScaledReward(BaseReward): + """ + Scaled reward function - scales the environment reward by a factor. + """ + + def __init__(self, scale: float = 1.0, offset: float = 0.0, **kwargs): + """ + Initialize scaled reward. + + Args: + scale: Scaling factor for the reward + offset: Offset to add to the scaled reward + **kwargs: Additional parameters + """ + super().__init__(**kwargs) + self.scale = scale + self.offset = offset + + def compute( + self, + state: MetaObs, + action: MetaAction, + next_state: MetaObs, + env_reward: float, + info: Optional[Dict[str, Any]] = None + ) -> float: + """Return scaled and offset reward.""" + return env_reward * self.scale + self.offset + + +class CompositeReward(BaseReward): + """ + Composite reward function - combines multiple reward functions. + + This allows combining different reward signals with weights. + """ + + def __init__( + self, + reward_fns: list, + weights: Optional[list] = None, + **kwargs + ): + """ + Initialize composite reward. + + Args: + reward_fns: List of reward function instances + weights: Optional list of weights for each reward function. + If None, uses uniform weights. + **kwargs: Additional parameters + """ + super().__init__(**kwargs) + self.reward_fns = reward_fns + if weights is None: + weights = [1.0] * len(reward_fns) + assert len(weights) == len(reward_fns), "Number of weights must match number of reward functions" + self.weights = weights + + def compute( + self, + state: MetaObs, + action: MetaAction, + next_state: MetaObs, + env_reward: float, + info: Optional[Dict[str, Any]] = None + ) -> float: + """Compute weighted sum of all reward functions.""" + total_reward = 0.0 + for reward_fn, weight in zip(self.reward_fns, self.weights): + total_reward += weight * reward_fn.compute( + state, action, next_state, env_reward, info + ) + return total_reward + + def reset(self, **kwargs) -> None: + """Reset all component reward functions.""" + for reward_fn in self.reward_fns: + reward_fn.reset(**kwargs) + + +if __name__ == '__main__': + """ + Test code for BaseReward class and its implementations. + """ + import sys + sys.path.insert(0, '/home/zhang/robot/126/ILStudio') + + from benchmark.base import MetaObs, MetaAction + + # Test IdentityReward + print("=" * 60) + print("Testing BaseReward and implementations") + print("=" * 60) + + # Create sample states and actions + state = MetaObs( + state=np.random.randn(10).astype(np.float32), + state_ee=np.random.randn(7).astype(np.float32), + raw_lang="pick up the red block" + ) + action = MetaAction( + action=np.random.randn(7).astype(np.float32), + ctrl_space='ee', + ctrl_type='delta' + ) + next_state = MetaObs( + state=np.random.randn(10).astype(np.float32), + state_ee=np.random.randn(7).astype(np.float32), + raw_lang="pick up the red block" + ) + + # Test 1: IdentityReward + print("\n1. Testing IdentityReward...") + identity_reward = IdentityReward() + env_reward = 1.5 + reward = identity_reward.compute(state, action, next_state, env_reward, {}) + print(f" IdentityReward: {identity_reward}") + print(f" Env reward: {env_reward}, Computed reward: {reward}") + assert reward == env_reward, "IdentityReward should return env_reward unchanged" + + # Test callable interface + reward_callable = identity_reward(state, action, next_state, env_reward, {}) + print(f" Callable interface result: {reward_callable}") + assert reward_callable == env_reward, "Callable interface should work the same" + + # Test 2: ScaledReward + print("\n2. Testing ScaledReward...") + scaled_reward = ScaledReward(scale=2.0, offset=0.5) + env_reward = 1.0 + reward = scaled_reward.compute(state, action, next_state, env_reward, {}) + expected = 1.0 * 2.0 + 0.5 # 2.5 + print(f" ScaledReward: {scaled_reward}") + print(f" Env reward: {env_reward}, Scale: 2.0, Offset: 0.5") + print(f" Computed reward: {reward}, Expected: {expected}") + assert abs(reward - expected) < 1e-6, "ScaledReward should scale and offset correctly" + + # Test 3: CompositeReward + print("\n3. Testing CompositeReward...") + reward_fn1 = IdentityReward() + reward_fn2 = ScaledReward(scale=0.5, offset=0.0) + composite_reward = CompositeReward( + reward_fns=[reward_fn1, reward_fn2], + weights=[0.6, 0.4] + ) + env_reward = 2.0 + reward = composite_reward.compute(state, action, next_state, env_reward, {}) + # Expected: 0.6 * 2.0 + 0.4 * (2.0 * 0.5) = 1.2 + 0.4 = 1.6 + expected = 0.6 * 2.0 + 0.4 * (2.0 * 0.5) + print(f" CompositeReward: {composite_reward}") + print(f" Env reward: {env_reward}") + print(f" Component 1 (IdentityReward, weight=0.6): {reward_fn1.compute(state, action, next_state, env_reward, {})}") + print(f" Component 2 (ScaledReward*0.5, weight=0.4): {reward_fn2.compute(state, action, next_state, env_reward, {})}") + print(f" Computed reward: {reward}, Expected: {expected}") + assert abs(reward - expected) < 1e-6, "CompositeReward should compute weighted sum correctly" + + # Test 4: Custom reward function (abstract class implementation) + print("\n4. Testing custom reward function...") + + class SuccessBonus(BaseReward): + """Give bonus reward when task is successful.""" + + def __init__(self, bonus: float = 10.0, **kwargs): + super().__init__(**kwargs) + self.bonus = bonus + + def compute(self, state, action, next_state, env_reward, info): + if info and info.get('success', False): + return env_reward + self.bonus + return env_reward + + success_bonus = SuccessBonus(bonus=5.0) + + # Without success + info_no_success = {'success': False} + reward_no_success = success_bonus.compute(state, action, next_state, 1.0, info_no_success) + print(f" SuccessBonus (no success): reward = {reward_no_success}") + assert reward_no_success == 1.0 + + # With success + info_success = {'success': True} + reward_success = success_bonus.compute(state, action, next_state, 1.0, info_success) + print(f" SuccessBonus (success): reward = {reward_success}") + assert reward_success == 6.0 # 1.0 + 5.0 + + # Test 5: Reset functionality + print("\n5. Testing reset functionality...") + + class StatefulReward(BaseReward): + """Reward function with internal state.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.step_count = 0 + + def compute(self, state, action, next_state, env_reward, info): + self.step_count += 1 + return env_reward + self.step_count * 0.01 + + def reset(self, **kwargs): + self.step_count = 0 + + stateful_reward = StatefulReward() + + # Compute a few rewards + for i in range(5): + r = stateful_reward.compute(state, action, next_state, 1.0, {}) + print(f" After 5 steps, step_count = {stateful_reward.step_count}") + + # Reset + stateful_reward.reset() + print(f" After reset, step_count = {stateful_reward.step_count}") + assert stateful_reward.step_count == 0, "Reset should clear step_count" + + print("\n" + "=" * 60) + print("All tests passed!") + print("=" * 60) + diff --git a/rl/trainers/__init__.py b/rl/trainers/__init__.py new file mode 100644 index 00000000..4dd6dff3 --- /dev/null +++ b/rl/trainers/__init__.py @@ -0,0 +1,85 @@ +""" +Trainers Module + +This module provides trainers for coordinating RL training. + +Design Philosophy: +- Coordinate environment, policy, and algorithm for training loop +- Support single algorithm and multiple algorithms training +- Support custom reward functions (applied during training, not data collection) +- Support evaluation during training + +Available trainers: +- SimpleTrainer: Single machine trainer +- ParallelTrainer: Parallel environment trainer +- DistributedTrainer: Distributed training + +Note: Implementations are provided in separate files. +This __init__.py provides factory functions for creating trainers. +""" + +from typing import Type, Dict, Any + +from .base_trainer import BaseTrainer + +# Registry for trainer classes +_TRAINER_REGISTRY: Dict[str, Type] = {} + + +def register_trainer(name: str, trainer_class: Type) -> None: + """ + Register a trainer class. + + Args: + name: Trainer name (e.g., 'simple', 'parallel') + trainer_class: Trainer class to register + """ + _TRAINER_REGISTRY[name.lower()] = trainer_class + + +def get_trainer_class(name_or_type: str) -> Type: + """ + Get trainer class by name or type string. + + Args: + name_or_type: Trainer name (e.g., 'simple') or full type path + (e.g., 'rl.trainers.simple_trainer.SimpleTrainer') + + Returns: + Trainer class + + Raises: + ValueError: If trainer not found + """ + # First check registry + if name_or_type.lower() in _TRAINER_REGISTRY: + return _TRAINER_REGISTRY[name_or_type.lower()] + + # Try to import from type path + if '.' in name_or_type: + try: + parts = name_or_type.rsplit('.', 1) + module_path = parts[0] + class_name = parts[1] + + import importlib + module = importlib.import_module(module_path) + return getattr(module, class_name) + except (ImportError, AttributeError) as e: + raise ValueError(f"Cannot import trainer from '{name_or_type}': {e}") + + raise ValueError(f"Unknown trainer: '{name_or_type}'. Available: {list(_TRAINER_REGISTRY.keys())}") + + +def list_trainers() -> list: + """List all registered trainers.""" + return list(_TRAINER_REGISTRY.keys()) + + +__all__ = [ + 'BaseTrainer', + 'register_trainer', + 'get_trainer_class', + 'list_trainers', +] + diff --git a/rl/trainers/base_trainer.py b/rl/trainers/base_trainer.py new file mode 100644 index 00000000..a0ef2c31 --- /dev/null +++ b/rl/trainers/base_trainer.py @@ -0,0 +1,587 @@ +""" +Base Trainer Class + +This module defines the base class for all RL trainers in the framework. + +Design Philosophy: +- Coordinate environment, policy, and algorithm for executing training loop +- Support single algorithm and multiple algorithms training scenarios +- Support custom reward functions (applied during training, not data collection) +- Support evaluation during training +""" + +import numpy as np +import torch +from typing import Dict, Any, Optional, Union, List, Callable +from abc import ABC, abstractmethod + +# Type hints for Meta classes +from benchmark.base import MetaObs, MetaAction, MetaEnv + + +class BaseTrainer(ABC): + """ + Base class for RL trainers. + + This class defines the interface for all trainers in the RL framework. + Trainers coordinate the training loop by managing: + - Environments (meta_envs) + - Algorithms (with their policies and replay buffers) + - Collectors (for data collection) + - Reward functions (applied during training) + + Attributes: + meta_envs: Environment(s) for training + algorithm: Algorithm(s) for training + collector: Data collector(s) + reward_fn: Optional custom reward function + """ + + def __init__( + self, + meta_envs: Union[MetaEnv, List[MetaEnv], Callable, Dict[str, Any]], + algorithm: Union['BaseAlgorithm', List['BaseAlgorithm']], + collector: Optional[Union['BaseCollector', List['BaseCollector']]] = None, + reward_fn: Optional[Union['BaseReward', Callable]] = None, + **kwargs + ): + """ + Initialize the trainer. + + Args: + meta_envs: Supports multiple formats: + - MetaEnv instance: Single environment + - List[MetaEnv]: Environment list (same type environments) + - Callable: Environment factory function + - Dict[str, Any]: Multi-environment config dict (different env types) + algorithm: BaseAlgorithm instance or list of BaseAlgorithm (required) + - Single algorithm: Single agent training + - Algorithm list: Multiple algorithms training independently in same env + (each algorithm has its own replay buffer) + collector: Optional data collector (if None, trainer creates default collector) + - Single algorithm: Single collector + - Multiple algorithms: Can be collector list, each for one algorithm + - If None, trainer creates default collector for each algorithm + reward_fn: Optional reward function (if None, use raw environment reward) + - Can be BaseReward instance or Callable function + - Applied during training for algorithm updates + - Note: Replay buffer stores raw rewards, reward function only applied during training + **kwargs: Trainer-specific parameters + """ + self.meta_envs = meta_envs + self.algorithm = algorithm + self.reward_fn = reward_fn + self._kwargs = kwargs + + # Handle collector initialization + self.collector = collector + # Note: Actual collector creation should be done in subclasses + # since it may require specific collector types + + @abstractmethod + def train(self, **kwargs) -> None: + """ + Execute training loop. + + Args: + **kwargs: Training parameters, can include: + - total_steps: Total training steps (optional) + - total_episodes: Total episodes (optional) + - max_time: Maximum training time (optional) + - log_interval: Logging interval (optional) + - save_interval: Model save interval (optional) + - eval_interval: Evaluation interval (optional) + """ + raise NotImplementedError + + def compute_reward( + self, + state: MetaObs, + action: MetaAction, + next_state: MetaObs, + env_reward: float, + info: Optional[Dict[str, Any]] = None + ) -> float: + """ + Compute reward (support custom reward function). + + Used during training for algorithm update reward computation. + Replay buffer stores raw rewards, reward function only applied during training. + + Args: + state: Current state + action: Action + next_state: Next state + env_reward: Environment raw reward + info: Additional information dictionary + + Returns: + Computed reward value + """ + if self.reward_fn is not None: + # Import here to avoid circular imports + from rl.rewards.base_reward import BaseReward + + if isinstance(self.reward_fn, BaseReward): + return self.reward_fn.compute(state, action, next_state, env_reward, info) + else: + # Assume it's a callable + return self.reward_fn(state, action, next_state, env_reward, info) + return env_reward + + def collect_rollout( + self, + n_steps: int, + env_type: Optional[str] = None + ) -> Union[Dict[str, Any], List[Dict[str, Any]]]: + """ + Collect rollout data (using collector). + + Args: + n_steps: Number of steps to collect + env_type: Optional environment type identifier (for multi-environment scenarios) + - Used to support a single algorithm storing data from multiple different environments + + Returns: + - Single algorithm: Rollout statistics dictionary + - Multiple algorithms: List of rollout statistics dictionaries + """ + if self.collector is None: + raise ValueError("Collector is not initialized. Please provide or create a collector.") + + if isinstance(self.collector, list): + # Multiple algorithms: Each algorithm collects independently + return [col.collect(n_steps, env_type=env_type) for col in self.collector] + else: + # Single algorithm + return self.collector.collect(n_steps, env_type=env_type) + + @abstractmethod + def evaluate( + self, + n_episodes: int = 10, + render: bool = False, + env_type: Optional[str] = None, + **kwargs + ) -> Dict[str, Any]: + """ + Evaluate policy performance. + + Args: + n_episodes: Number of episodes to evaluate + render: Whether to render environment (optional) + env_type: Optional, specify evaluation environment type + **kwargs: Other evaluation parameters + + Returns: + Dictionary containing evaluation metrics + """ + raise NotImplementedError + + @abstractmethod + def save(self, path: str) -> None: + """Save model and training state.""" + raise NotImplementedError + + @abstractmethod + def load(self, path: str) -> None: + """Load model and training state.""" + raise NotImplementedError + + def get_algorithm(self) -> Union['BaseAlgorithm', List['BaseAlgorithm']]: + """Get the algorithm(s).""" + return self.algorithm + + def get_collector(self) -> Optional[Union['BaseCollector', List['BaseCollector']]]: + """Get the collector(s).""" + return self.collector + + def set_reward_fn(self, reward_fn: Union['BaseReward', Callable]) -> None: + """ + Set or update the reward function. + + Args: + reward_fn: New reward function to use + """ + self.reward_fn = reward_fn + + def __repr__(self) -> str: + algo_info = self.algorithm.__class__.__name__ if not isinstance(self.algorithm, list) else f"List[{len(self.algorithm)}]" + collector_info = "None" if self.collector is None else ( + self.collector.__class__.__name__ if not isinstance(self.collector, list) else f"List[{len(self.collector)}]" + ) + reward_info = "None" if self.reward_fn is None else self.reward_fn.__class__.__name__ + return f"{self.__class__.__name__}(algorithm={algo_info}, collector={collector_info}, reward_fn={reward_info})" + + +if __name__ == '__main__': + """ + Test code for BaseTrainer class. + + Since BaseTrainer is abstract, we create a simple concrete implementation for testing. + """ + import sys + sys.path.insert(0, '/home/zhang/robot/126/ILStudio') + + from benchmark.base import MetaObs, MetaAction, MetaEnv, MetaPolicy + from rl.buffer.base_replay import BaseReplay + from rl.base import BaseAlgorithm + from rl.rewards.base_reward import BaseReward, IdentityReward, ScaledReward + from rl.collectors.base_collector import BaseCollector + from dataclasses import asdict + + # Simple replay buffer for testing + class SimpleReplay(BaseReplay): + def __init__(self, capacity=1000, device='cpu', **kwargs): + super().__init__(capacity=capacity, device=device, **kwargs) + self._storage = [] + + def add(self, transition): + if self._size < self.capacity: + self._storage.append(transition) + self._size += 1 + else: + self._storage[self._position] = transition + self._position = (self._position + 1) % self.capacity + + def sample(self, batch_size): + if self._size == 0: + return {} + indices = np.random.randint(0, self._size, size=min(batch_size, self._size)) + return { + 'states': [self._storage[i]['state'] for i in indices], + 'actions': [self._storage[i]['action'] for i in indices], + 'rewards': np.array([self._storage[i]['reward'] for i in indices]), + } + + def clear(self): + self._storage = [] + self._size = 0 + self._position = 0 + + def save(self, path, **kwargs): + pass + + def load(self, path, **kwargs): + pass + + # Simple dummy environment for testing + class DummyEnv: + def __init__(self, state_dim=10, action_dim=7, max_steps=100): + self.state_dim = state_dim + self.action_dim = action_dim + self._step_count = 0 + self._max_steps = max_steps + + def reset(self): + self._step_count = 0 + return {'state': np.random.randn(self.state_dim).astype(np.float32)} + + def step(self, action): + self._step_count += 1 + obs = {'state': np.random.randn(self.state_dim).astype(np.float32)} + reward = np.random.randn() + done = self._step_count >= self._max_steps + info = {'step': self._step_count} + return obs, reward, done, info + + def close(self): + pass + + # Simple MetaEnv wrapper for testing + class DummyMetaEnv(MetaEnv): + def __init__(self, state_dim=10, action_dim=7, max_steps=100): + self.env = DummyEnv(state_dim=state_dim, action_dim=action_dim, max_steps=max_steps) + self.prev_obs = None + + def obs2meta(self, raw_obs): + return MetaObs(state=raw_obs['state'], raw_lang="test") + + def meta2act(self, action, *args): + if hasattr(action, 'action'): + return action.action + return action + + def reset(self): + init_obs = self.env.reset() + self.prev_obs = self.obs2meta(init_obs) + return self.prev_obs + + def step(self, action): + act = self.meta2act(action) + obs, reward, done, info = self.env.step(act) + self.prev_obs = self.obs2meta(obs) + return self.prev_obs, reward, done, info + + # Simple policy for testing + class DummyPolicy: + def __init__(self, action_dim=7): + self.action_dim = action_dim + + def select_action(self, obs): + return MetaAction(action=np.random.randn(self.action_dim).astype(np.float32)) + + def train(self): + pass + + def eval(self): + pass + + # Simple MetaPolicy wrapper for testing + class DummyMetaPolicy(MetaPolicy): + def __init__(self, action_dim=7): + self.policy = DummyPolicy(action_dim=action_dim) + self.chunk_size = 1 + self.ctrl_space = 'ee' + self.ctrl_type = 'delta' + self.action_queue = [] + self.action_normalizer = None + self.state_normalizer = None + + def select_action(self, mobs, t=0, **kwargs): + return self.policy.select_action(mobs) + + # Simple algorithm for testing + class DummyAlgorithm(BaseAlgorithm): + def __init__(self, meta_policy, replay=None, **kwargs): + super().__init__(meta_policy=meta_policy, replay=replay, **kwargs) + self._timestep = 0 + self._update_count = 0 + + def update(self, batch=None, **kwargs): + self._update_count += 1 + batch_size = kwargs.get('batch_size', 32) + if batch is None and self.replay is not None: + batch = self.replay.sample(batch_size) + return {'loss': np.random.randn(), 'update_count': self._update_count} + + def select_action(self, obs, **kwargs): + return self.meta_policy.select_action(obs, t=self._timestep) + + # Simple collector for testing + class DummyCollector(BaseCollector): + def __init__(self, meta_envs, algorithm, **kwargs): + super().__init__(meta_envs, algorithm, **kwargs) + if isinstance(meta_envs, list): + self.envs = meta_envs + else: + self.envs = [meta_envs] + self._last_obs = None + + def reset(self, **kwargs): + self._last_obs = [env.reset() for env in self.envs] + + def collect(self, n_steps, env_type=None): + if self._last_obs is None: + self.reset() + + stats = {'episode_rewards': [], 'episode_lengths': [], 'total_steps': 0} + episode_reward = 0.0 + episode_length = 0 + + for step in range(n_steps): + for i, (env, obs) in enumerate(zip(self.envs, self._last_obs)): + action = self.algorithm.select_action(obs) + new_obs, reward, done, info = env.step(action) + + # Record transition + kwargs_trans = {'env_type': env_type} if env_type else {} + self.algorithm.record_transition( + state=obs, action=action, reward=reward, + next_state=new_obs, done=done, info=info, **kwargs_trans + ) + + episode_reward += reward + episode_length += 1 + stats['total_steps'] += 1 + + if done: + stats['episode_rewards'].append(episode_reward) + stats['episode_lengths'].append(episode_length) + episode_reward = 0.0 + episode_length = 0 + new_obs = env.reset() + + self._last_obs[i] = new_obs + + return stats + + # Simple trainer implementation for testing + class SimpleTrainer(BaseTrainer): + """Simple trainer for testing.""" + + def __init__( + self, + meta_envs, + algorithm, + collector=None, + reward_fn=None, + **kwargs + ): + super().__init__(meta_envs, algorithm, collector, reward_fn, **kwargs) + + # Create default collector if not provided + if self.collector is None: + if isinstance(algorithm, list): + self.collector = [ + DummyCollector(meta_envs=meta_envs, algorithm=alg) + for alg in algorithm + ] + else: + self.collector = DummyCollector( + meta_envs=meta_envs, + algorithm=algorithm + ) + + self._total_steps = 0 + self._training_logs = [] + + def train(self, **kwargs): + total_steps = kwargs.get('total_steps', 1000) + log_interval = kwargs.get('log_interval', 100) + update_interval = kwargs.get('update_interval', 50) + batch_size = kwargs.get('batch_size', 32) + + print(f"Starting training for {total_steps} steps...") + + while self._total_steps < total_steps: + # Collect data + stats = self.collect_rollout(n_steps=update_interval) + self._total_steps += stats['total_steps'] if isinstance(stats, dict) else sum(s['total_steps'] for s in stats) + + # Update algorithm + if isinstance(self.algorithm, list): + for alg in self.algorithm: + result = alg.update(batch_size=batch_size) + else: + result = self.algorithm.update(batch_size=batch_size) + + # Log + if self._total_steps % log_interval == 0: + print(f" Step {self._total_steps}: loss={result.get('loss', 0.0):.4f}") + self._training_logs.append({ + 'step': self._total_steps, + 'loss': result.get('loss', 0.0) + }) + + print(f"Training completed. Total steps: {self._total_steps}") + + def evaluate(self, n_episodes=10, render=False, env_type=None, **kwargs): + print(f"Evaluating for {n_episodes} episodes...") + + # Create eval env + if isinstance(self.meta_envs, dict): + env = list(self.meta_envs.values())[0] if env_type is None else self.meta_envs[env_type] + elif isinstance(self.meta_envs, list): + env = self.meta_envs[0] + else: + env = self.meta_envs + + alg = self.algorithm[0] if isinstance(self.algorithm, list) else self.algorithm + + episode_rewards = [] + episode_lengths = [] + + for ep in range(n_episodes): + obs = env.reset() + episode_reward = 0.0 + episode_length = 0 + done = False + + while not done: + action = alg.select_action(obs) + obs, reward, done, info = env.step(action) + episode_reward += reward + episode_length += 1 + + episode_rewards.append(episode_reward) + episode_lengths.append(episode_length) + + results = { + 'mean_reward': np.mean(episode_rewards), + 'std_reward': np.std(episode_rewards), + 'mean_length': np.mean(episode_lengths), + 'episode_rewards': episode_rewards, + 'episode_lengths': episode_lengths + } + print(f"Evaluation: mean_reward={results['mean_reward']:.2f} ± {results['std_reward']:.2f}") + return results + + def save(self, path): + print(f"Saving to {path} (mock)") + + def load(self, path): + print(f"Loading from {path} (mock)") + + # Test the implementation + print("=" * 60) + print("Testing BaseTrainer (SimpleTrainer implementation)") + print("=" * 60) + + # Create components + print("\n1. Creating components...") + env = DummyMetaEnv(state_dim=10, action_dim=7, max_steps=20) + meta_policy = DummyMetaPolicy(action_dim=7) + replay = SimpleReplay(capacity=10000) + algorithm = DummyAlgorithm(meta_policy=meta_policy, replay=replay) + + # Test 1: Basic trainer + print("\n2. Testing basic trainer...") + trainer = SimpleTrainer( + meta_envs=env, + algorithm=algorithm, + reward_fn=None + ) + print(f" Trainer: {trainer}") + + # Test compute_reward without reward_fn + print("\n3. Testing compute_reward without custom reward_fn...") + state = MetaObs(state=np.random.randn(10).astype(np.float32)) + action = MetaAction(action=np.random.randn(7).astype(np.float32)) + next_state = MetaObs(state=np.random.randn(10).astype(np.float32)) + env_reward = 1.5 + computed_reward = trainer.compute_reward(state, action, next_state, env_reward, {}) + print(f" Env reward: {env_reward}, Computed reward: {computed_reward}") + assert computed_reward == env_reward + + # Test compute_reward with custom reward_fn + print("\n4. Testing compute_reward with custom reward_fn...") + trainer.set_reward_fn(ScaledReward(scale=2.0, offset=0.5)) + computed_reward = trainer.compute_reward(state, action, next_state, env_reward, {}) + expected = env_reward * 2.0 + 0.5 + print(f" Env reward: {env_reward}, Computed reward: {computed_reward}, Expected: {expected}") + assert abs(computed_reward - expected) < 1e-6 + + # Test training + print("\n5. Testing training...") + trainer.train(total_steps=200, log_interval=50, update_interval=20, batch_size=16) + print(f" Replay buffer size: {len(algorithm.replay)}") + + # Test evaluation + print("\n6. Testing evaluation...") + eval_results = trainer.evaluate(n_episodes=3) + print(f" Evaluation results: {eval_results}") + + # Test collect_rollout + print("\n7. Testing collect_rollout...") + rollout_stats = trainer.collect_rollout(n_steps=50) + print(f" Rollout stats: {rollout_stats}") + + # Test with multiple algorithms + print("\n8. Testing with multiple algorithms...") + algorithms = [ + DummyAlgorithm(meta_policy=DummyMetaPolicy(action_dim=7), replay=SimpleReplay(capacity=1000)), + DummyAlgorithm(meta_policy=DummyMetaPolicy(action_dim=7), replay=SimpleReplay(capacity=1000)) + ] + multi_trainer = SimpleTrainer( + meta_envs=env, + algorithm=algorithms, + reward_fn=IdentityReward() + ) + print(f" Multi-algorithm trainer: {multi_trainer}") + multi_trainer.train(total_steps=100, log_interval=50, update_interval=20, batch_size=8) + + print("\n" + "=" * 60) + print("All tests passed!") + print("=" * 60) + diff --git a/rl/utils/__init__.py b/rl/utils/__init__.py new file mode 100644 index 00000000..6388a62e --- /dev/null +++ b/rl/utils/__init__.py @@ -0,0 +1,325 @@ +""" +RL Utilities Module + +This module provides utility functions for the RL framework. + +Available utilities: +- DataProcessor: Data processor for aligning with ILStudio pipeline +- Running statistics (mean, variance) for normalization +- Advantage computation (GAE) +- Discount reward computation +""" + +import numpy as np +import torch +from typing import Dict, Any, Optional, List, Union + + +def compute_gae( + rewards: np.ndarray, + values: np.ndarray, + dones: np.ndarray, + next_value: float, + gamma: float = 0.99, + gae_lambda: float = 0.95 +) -> np.ndarray: + """ + Compute Generalized Advantage Estimation (GAE). + + Args: + rewards: Array of rewards [T] + values: Array of value estimates [T] + dones: Array of done flags [T] + next_value: Value estimate for the next state + gamma: Discount factor + gae_lambda: GAE lambda parameter + + Returns: + Array of advantages [T] + """ + T = len(rewards) + advantages = np.zeros(T, dtype=np.float32) + last_gae = 0.0 + + for t in reversed(range(T)): + if t == T - 1: + next_non_terminal = 1.0 - dones[t] + next_val = next_value + else: + next_non_terminal = 1.0 - dones[t] + next_val = values[t + 1] + + delta = rewards[t] + gamma * next_val * next_non_terminal - values[t] + advantages[t] = last_gae = delta + gamma * gae_lambda * next_non_terminal * last_gae + + return advantages + + +def compute_returns( + rewards: np.ndarray, + dones: np.ndarray, + next_value: float = 0.0, + gamma: float = 0.99 +) -> np.ndarray: + """ + Compute discounted returns. + + Args: + rewards: Array of rewards [T] + dones: Array of done flags [T] + next_value: Value estimate for the next state + gamma: Discount factor + + Returns: + Array of discounted returns [T] + """ + T = len(rewards) + returns = np.zeros(T, dtype=np.float32) + running_return = next_value + + for t in reversed(range(T)): + running_return = rewards[t] + gamma * running_return * (1.0 - dones[t]) + returns[t] = running_return + + return returns + + +class RunningMeanStd: + """ + Running mean and standard deviation tracker. + + Useful for observation normalization in RL. + """ + + def __init__(self, shape: tuple = (), epsilon: float = 1e-8): + """ + Initialize running statistics. + + Args: + shape: Shape of the data to track + epsilon: Small value for numerical stability + """ + self.mean = np.zeros(shape, dtype=np.float64) + self.var = np.ones(shape, dtype=np.float64) + self.count = epsilon + self.epsilon = epsilon + + def update(self, x: np.ndarray) -> None: + """ + Update running statistics with new data. + + Args: + x: New data batch [batch_size, *shape] + """ + batch_mean = np.mean(x, axis=0) + batch_var = np.var(x, axis=0) + batch_count = x.shape[0] + + self._update_from_moments(batch_mean, batch_var, batch_count) + + def _update_from_moments( + self, + batch_mean: np.ndarray, + batch_var: np.ndarray, + batch_count: int + ) -> None: + """Update from batch moments.""" + delta = batch_mean - self.mean + total_count = self.count + batch_count + + new_mean = self.mean + delta * batch_count / total_count + m_a = self.var * self.count + m_b = batch_var * batch_count + M2 = m_a + m_b + np.square(delta) * self.count * batch_count / total_count + new_var = M2 / total_count + + self.mean = new_mean + self.var = new_var + self.count = total_count + + def normalize(self, x: np.ndarray) -> np.ndarray: + """ + Normalize data using running statistics. + + Args: + x: Data to normalize + + Returns: + Normalized data + """ + return (x - self.mean) / np.sqrt(self.var + self.epsilon) + + def denormalize(self, x: np.ndarray) -> np.ndarray: + """ + Denormalize data using running statistics. + + Args: + x: Normalized data + + Returns: + Denormalized data + """ + return x * np.sqrt(self.var + self.epsilon) + self.mean + + +def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> float: + """ + Compute explained variance. + + Args: + y_pred: Predicted values + y_true: True values + + Returns: + Explained variance (1.0 is perfect prediction) + """ + var_y = np.var(y_true) + if var_y == 0: + return np.nan + return 1.0 - np.var(y_true - y_pred) / var_y + + +def polyak_update( + source_params: List[torch.nn.Parameter], + target_params: List[torch.nn.Parameter], + tau: float = 0.005 +) -> None: + """ + Perform Polyak (soft) update of target network parameters. + + target = tau * source + (1 - tau) * target + + Args: + source_params: Source network parameters + target_params: Target network parameters + tau: Interpolation factor (0.0 = no update, 1.0 = full copy) + """ + with torch.no_grad(): + for source_param, target_param in zip(source_params, target_params): + target_param.data.mul_(1.0 - tau) + target_param.data.add_(tau * source_param.data) + + +def hard_update( + source_params: List[torch.nn.Parameter], + target_params: List[torch.nn.Parameter] +) -> None: + """ + Perform hard update of target network parameters (full copy). + + Args: + source_params: Source network parameters + target_params: Target network parameters + """ + polyak_update(source_params, target_params, tau=1.0) + + +__all__ = [ + 'compute_gae', + 'compute_returns', + 'RunningMeanStd', + 'explained_variance', + 'polyak_update', + 'hard_update', +] + + +if __name__ == '__main__': + """ + Test code for RL utilities. + """ + print("=" * 60) + print("Testing RL Utilities") + print("=" * 60) + + # Test compute_gae + print("\n1. Testing compute_gae...") + rewards = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32) + values = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32) + dones = np.array([0, 0, 0, 0, 1], dtype=np.float32) + next_value = 0.0 + + advantages = compute_gae(rewards, values, dones, next_value, gamma=0.99, gae_lambda=0.95) + print(f" Rewards: {rewards}") + print(f" Values: {values}") + print(f" Advantages: {advantages}") + print(f" Advantages shape: {advantages.shape}") + + # Test compute_returns + print("\n2. Testing compute_returns...") + returns = compute_returns(rewards, dones, next_value=0.0, gamma=0.99) + print(f" Returns: {returns}") + print(f" Returns shape: {returns.shape}") + + # Test RunningMeanStd + print("\n3. Testing RunningMeanStd...") + rms = RunningMeanStd(shape=(5,)) + + # Update with random data + for i in range(10): + data = np.random.randn(32, 5) + rms.update(data) + + print(f" Mean: {rms.mean}") + print(f" Var: {rms.var}") + print(f" Count: {rms.count}") + + # Test normalization + test_data = np.random.randn(10, 5) + normalized = rms.normalize(test_data) + denormalized = rms.denormalize(normalized) + print(f" Normalization error: {np.max(np.abs(test_data - denormalized)):.2e}") + + # Test explained_variance + print("\n4. Testing explained_variance...") + y_true = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) + y_pred = np.array([1.1, 1.9, 3.1, 4.1, 4.9]) + ev = explained_variance(y_pred, y_true) + print(f" y_true: {y_true}") + print(f" y_pred: {y_pred}") + print(f" Explained variance: {ev:.4f}") + + # Test polyak_update + print("\n5. Testing polyak_update...") + + class SimpleNet(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc = torch.nn.Linear(10, 5) + + source_net = SimpleNet() + target_net = SimpleNet() + + # Initialize target differently + with torch.no_grad(): + target_net.fc.weight.fill_(0.0) + target_net.fc.bias.fill_(0.0) + + print(f" Source weight mean: {source_net.fc.weight.mean().item():.4f}") + print(f" Target weight mean before update: {target_net.fc.weight.mean().item():.4f}") + + polyak_update( + list(source_net.parameters()), + list(target_net.parameters()), + tau=0.5 + ) + print(f" Target weight mean after polyak update (tau=0.5): {target_net.fc.weight.mean().item():.4f}") + + # Test hard_update + print("\n6. Testing hard_update...") + with torch.no_grad(): + target_net.fc.weight.fill_(0.0) + + hard_update( + list(source_net.parameters()), + list(target_net.parameters()) + ) + + diff = (source_net.fc.weight - target_net.fc.weight).abs().max().item() + print(f" Weight difference after hard update: {diff:.2e}") + assert diff < 1e-6, "Hard update should copy exactly" + + print("\n" + "=" * 60) + print("All tests passed!") + print("=" * 60) + From a92753d4fff6814f427d3a1a7c35beb378cb1135 Mon Sep 17 00:00:00 2001 From: yesen-chen <840419490@qq.com> Date: Thu, 29 Jan 2026 00:26:59 +0800 Subject: [PATCH 2/6] [feat] add VectorEnvProtocol to support custom VectorEnv --- rl/collectors/base_collector.py | 300 +++++++++++++++++++++++--------- rl/envs/__init__.py | 19 ++ rl/envs/protocols.py | 77 ++++++++ rl/envs/utils.py | 99 +++++++++++ rl/trainers/base_trainer.py | 271 ++++++++++++++++++++++------- 5 files changed, 614 insertions(+), 152 deletions(-) create mode 100644 rl/envs/__init__.py create mode 100644 rl/envs/protocols.py create mode 100644 rl/envs/utils.py diff --git a/rl/collectors/base_collector.py b/rl/collectors/base_collector.py index f1f05e32..177371db 100644 --- a/rl/collectors/base_collector.py +++ b/rl/collectors/base_collector.py @@ -6,18 +6,21 @@ Design Philosophy: - Responsibility separation: Separate data collection logic from trainer so that trainer can focus on training loop coordination -- Environment abstraction: Support single environment, parallel environments, multiple environment types +- Environment abstraction: Support vectorized environments (SequentialVectorEnv, SubprocVectorEnv, etc.) - Raw data storage: Only store raw environment rewards, no reward function computation, ensuring data integrity - Statistics: Collect and return episode statistics """ import numpy as np -from typing import Dict, Any, Optional, Union, List, Callable +from typing import Dict, Any, Optional, Union, List from abc import ABC, abstractmethod # Type hints for Meta classes -from benchmark.base import MetaObs, MetaAction, MetaEnv +from benchmark.base import MetaObs, MetaAction + +# Vector environment protocol +from rl.envs import VectorEnvProtocol, EnvsType class BaseCollector(ABC): @@ -25,11 +28,11 @@ class BaseCollector(ABC): Base class for data collectors. This class defines the interface for all data collectors in the RL framework. - Collectors gather experience data from environments by interacting with them + Collectors gather experience data from vectorized environments by interacting with them using the algorithm's policy. Attributes: - meta_envs: Environment(s) to collect data from + envs: Vectorized environment(s) to collect data from algorithm: Algorithm instance for action selection and transition recording Note: Collector only stores raw environment rewards, no reward function computation. @@ -38,7 +41,7 @@ class BaseCollector(ABC): def __init__( self, - meta_envs: Union['MetaEnv', List['MetaEnv'], Callable, Dict[str, Any]], + envs: Union[VectorEnvProtocol, Dict[str, VectorEnvProtocol]], algorithm: 'BaseAlgorithm', **kwargs ): @@ -46,11 +49,11 @@ def __init__( Initialize the collector. Args: - meta_envs: Supports multiple formats: - - MetaEnv instance: Single environment - - List[MetaEnv]: Environment list (same type environments) - - Callable: Environment factory function - - Dict[str, Any]: Multi-environment config dict (supports different env types) + envs: Vectorized environment(s), supports: + - VectorEnvProtocol: Single vectorized environment + (SequentialVectorEnv, SubprocVectorEnv, DummyVectorEnv, etc.) + - Dict[str, VectorEnvProtocol]: Multi-environment dict for different env types + e.g., {'sim': sim_vec_env, 'real': real_vec_env} algorithm: BaseAlgorithm instance (required) - Used for action selection and transition recording **kwargs: Collector-specific parameters @@ -58,9 +61,17 @@ def __init__( Note: Collector only stores raw environment rewards, no reward function computation. Reward functions are applied in trainer during training time. """ - self.meta_envs = meta_envs + self.envs = envs self.algorithm = algorithm self._kwargs = kwargs + + # Normalize environment storage: always use dict internally + if isinstance(envs, dict): + self._envs_dict: Dict[str, VectorEnvProtocol] = envs + self._is_multi_env = True + else: + self._envs_dict = {'default': envs} + self._is_multi_env = False @abstractmethod def collect(self, n_steps: int, env_type: Optional[str] = None) -> Dict[str, Any]: @@ -92,16 +103,68 @@ def reset(self, **kwargs) -> None: """ raise NotImplementedError - def get_envs(self) -> Union['MetaEnv', List['MetaEnv'], Dict[str, Any]]: + def get_env(self, env_type: Optional[str] = None) -> VectorEnvProtocol: + """ + Get the vectorized environment by type. + + Args: + env_type: Environment type identifier. If None, returns 'default' env + or the first env if 'default' doesn't exist. + + Returns: + The vectorized environment + + Raises: + KeyError: If specified env_type is not found + """ + if env_type is not None: + return self._envs_dict[env_type] + + # Return 'default' if exists, otherwise return first env + if 'default' in self._envs_dict: + return self._envs_dict['default'] + return list(self._envs_dict.values())[0] + + def get_total_env_num(self) -> int: + """ + Get total number of parallel environments across all types. + + Returns: + Total count of parallel environments + """ + return sum(len(env) for env in self._envs_dict.values()) + + def get_env_types(self) -> List[str]: + """ + Get all environment type identifiers. + + Returns: + List of environment type strings + """ + return list(self._envs_dict.keys()) + + def get_envs(self) -> Union[VectorEnvProtocol, Dict[str, VectorEnvProtocol]]: """Get the underlying environment(s).""" - return self.meta_envs + return self.envs def get_algorithm(self) -> 'BaseAlgorithm': """Get the algorithm instance.""" return self.algorithm + @property + def env_num(self) -> int: + """ + Number of environments in the default (or first) environment. + + Returns: + Number of parallel environments + """ + if 'default' in self._envs_dict: + return len(self._envs_dict['default']) + return len(list(self._envs_dict.values())[0]) + def __repr__(self) -> str: - env_info = type(self.meta_envs).__name__ if not isinstance(self.meta_envs, (list, dict)) else f"List[{len(self.meta_envs)}]" if isinstance(self.meta_envs, list) else f"Dict[{len(self.meta_envs)}]" + env_info = f"{self.get_total_env_num()} envs" if self._is_multi_env else f"{self.env_num} envs" return f"{self.__class__.__name__}(envs={env_info}, algorithm={self.algorithm.__class__.__name__})" @@ -161,7 +224,7 @@ def __init__(self, state_dim=10, action_dim=7): self.state_dim = state_dim self.action_dim = action_dim self._step_count = 0 - self._max_steps = 100 + self._max_steps = 5 def reset(self): self._step_count = 0 @@ -253,43 +316,76 @@ def update(self, batch=None, **kwargs): def select_action(self, obs, **kwargs): return self.meta_policy.select_action(obs, t=self._timestep) + # Simple VectorEnv for testing (similar to SequentialVectorEnv) + class DummyVectorEnv: + """Simple sequential vector environment for testing.""" + def __init__(self, env_fns): + self.envs = [fn() for fn in env_fns] + self.env_num = len(self.envs) + + def reset(self, id=None): + if id is None: + obs_list = [env.reset() for env in self.envs] + return np.array(obs_list, dtype=object) + else: + if np.isscalar(id): + return self.envs[id].reset() + else: + obs_list = [self.envs[i].reset() for i in id] + return np.array(obs_list, dtype=object) + + def step(self, action, id=None): + if id is None: + results = [env.step(act) for env, act in zip(self.envs, action)] + obs = np.array([r[0] for r in results], dtype=object) + rew = np.array([r[1] for r in results]) + done = np.array([r[2] for r in results]) + info = [r[3] for r in results] + return obs, rew, done, info + else: + if np.isscalar(id): + return self.envs[id].step(action) + else: + results = [self.envs[i].step(act) for i, act in zip(id, action)] + obs = np.array([r[0] for r in results], dtype=object) + rew = np.array([r[1] for r in results]) + done = np.array([r[2] for r in results]) + info = [r[3] for r in results] + return obs, rew, done, info + + def close(self): + for env in self.envs: + if hasattr(env, 'close'): + env.close() + + def __len__(self): + return self.env_num + # Simple collector implementation for testing class SimpleCollector(BaseCollector): - """Simple collector for testing.""" + """Simple collector for testing with vectorized environments.""" def __init__( self, - meta_envs: Union[MetaEnv, List[MetaEnv]], + envs: Union[VectorEnvProtocol, Dict[str, VectorEnvProtocol]], algorithm: BaseAlgorithm, **kwargs ): - super().__init__(meta_envs, algorithm, **kwargs) - - # Initialize environments - if isinstance(meta_envs, list): - self.envs = meta_envs - elif isinstance(meta_envs, MetaEnv): - self.envs = [meta_envs] - else: - raise ValueError(f"Unsupported env type: {type(meta_envs)}") + super().__init__(envs, algorithm, **kwargs) + # Get default environment + self.vec_env = self.get_env() self._last_obs = None self._last_dones = None - self._episode_rewards = [] - self._episode_lengths = [] - self._current_episode_reward = [0.0] * len(self.envs) - self._current_episode_length = [0] * len(self.envs) + self._current_episode_reward = None + self._current_episode_length = None def reset(self, **kwargs) -> None: """Reset all environments.""" - self._last_obs = [] - self._last_dones = [] - for env in self.envs: - obs = env.reset() - self._last_obs.append(obs) - self._last_dones.append(False) - self._current_episode_reward = [0.0] * len(self.envs) - self._current_episode_length = [0] * len(self.envs) + self._last_obs = self.vec_env.reset() + self._last_dones = np.zeros(len(self.vec_env), dtype=bool) + self._current_episode_reward = np.zeros(len(self.vec_env), dtype=np.float32) + self._current_episode_length = np.zeros(len(self.vec_env), dtype=int) def collect(self, n_steps: int, env_type: Optional[str] = None) -> Dict[str, Any]: """ @@ -311,76 +407,97 @@ def collect(self, n_steps: int, env_type: Optional[str] = None) -> Dict[str, Any 'total_steps': 0 } - for step in range(n_steps): - for i, (env, obs) in enumerate(zip(self.envs, self._last_obs)): - if self._last_dones[i]: - continue - - # Get action + steps_executed = 0 + + while steps_executed < n_steps: + # Get actions for all environments + actions = [] + for obs in self._last_obs: with torch.no_grad(): action = self.algorithm.select_action(obs) - - # Environment interaction - new_obs, reward, done, info = env.step(action) - + actions.append(action) + + # Step all environments + new_obs, rewards, dones, infos = self.vec_env.step(actions) + + # Record transitions and update stats + for i in range(len(self.vec_env)): # Record transition (only store raw reward) transition_kwargs = {} if env_type is not None: transition_kwargs['env_type'] = env_type self.algorithm.record_transition( - state=obs, - action=action, - reward=reward, # Store raw reward - next_state=new_obs, - done=done, - info=info, + state=self._last_obs[i], + action=actions[i], + reward=rewards[i], # Store raw reward + next_state=new_obs[i], + done=dones[i], + info=infos[i], **transition_kwargs ) # Update episode statistics - self._current_episode_reward[i] += reward + self._current_episode_reward[i] += rewards[i] self._current_episode_length[i] += 1 + stats['total_steps'] += 1 + steps_executed += 1 + + # Reset done environments + done_indices = np.where(dones)[0] + if len(done_indices) > 0: + for idx in done_indices: + stats['episode_rewards'].append(self._current_episode_reward[idx]) + stats['episode_lengths'].append(self._current_episode_length[idx]) + self._current_episode_reward[idx] = 0.0 + self._current_episode_length[idx] = 0 - # If episode ended - if done: - stats['episode_rewards'].append(self._current_episode_reward[i]) - stats['episode_lengths'].append(self._current_episode_length[i]) - - # Reset environment and episode stats - new_obs = env.reset() - self._current_episode_reward[i] = 0.0 - self._current_episode_length[i] = 0 + # Reset done environments + reset_obs = self.vec_env.reset(id=done_indices.tolist()) + if len(done_indices) == 1: + new_obs[done_indices[0]] = reset_obs + else: + for idx, reset_idx in enumerate(done_indices): + new_obs[reset_idx] = reset_obs[idx] - self._last_obs[i] = new_obs - self._last_dones[i] = done if not done else False # Reset done flag after reset - stats['total_steps'] += 1 + # Mark reset environments as not done + dones[done_indices] = False + + self._last_obs = new_obs + self._last_dones = dones return stats # Test the implementation print("=" * 60) - print("Testing BaseCollector (SimpleCollector implementation)") + print("Testing BaseCollector with Vectorized Environments") print("=" * 60) # Create environment, policy, algorithm print("\n1. Creating components...") - env = DummyMetaEnv(state_dim=10, action_dim=7, max_steps=20) + # Create vectorized environment + # Use max_steps=5 to match DummyEnv's _max_steps=5 (user changed it) + env_fn = lambda: DummyMetaEnv(state_dim=10, action_dim=7, max_steps=5) + env_fns = [env_fn for _ in range(4)] + vec_env = DummyVectorEnv(env_fns) + meta_policy = DummyMetaPolicy(action_dim=7) replay = SimpleReplay(capacity=1000) algorithm = DummyAlgorithm(meta_policy=meta_policy, replay=replay) - print(f" Environment: {env}") + print(f" Vectorized Environment: {len(vec_env)} parallel envs") print(f" MetaPolicy: {meta_policy}") print(f" Algorithm: {algorithm}") # Create collector print("\n2. Creating collector...") collector = SimpleCollector( - meta_envs=env, + envs=vec_env, algorithm=algorithm ) print(f" Collector: {collector}") + print(f" env_num: {collector.env_num}") + print(f" total_env_num: {collector.get_total_env_num()}") # Test reset print("\n3. Testing reset...") @@ -391,8 +508,9 @@ def collect(self, n_steps: int, env_type: Optional[str] = None) -> Dict[str, Any print("\n4. Testing collect...") stats = collector.collect(n_steps=50) print(f" Collected {stats['total_steps']} steps") - print(f" Episode rewards: {stats['episode_rewards']}") - print(f" Episode lengths: {stats['episode_lengths']}") + print(f" Episodes completed: {len(stats['episode_rewards'])}") + print(f" Episode rewards: {stats['episode_rewards'][:5]}...") # Show first 5 + print(f" Episode lengths: {stats['episode_lengths'][:5]}...") # Show first 5 print(f" Replay buffer size: {len(algorithm.replay)}") # Test with env_type @@ -404,9 +522,10 @@ def collect(self, n_steps: int, env_type: Optional[str] = None) -> Dict[str, Any } multi_algorithm = DummyAlgorithm(meta_policy=meta_policy, replay=multi_replay) - indoor_env = DummyMetaEnv(state_dim=10, action_dim=7, max_steps=20) + indoor_env_fn = lambda: DummyMetaEnv(state_dim=10, action_dim=7, max_steps=20) + indoor_vec_env = DummyVectorEnv([indoor_env_fn for _ in range(2)]) indoor_collector = SimpleCollector( - meta_envs=indoor_env, + envs=indoor_vec_env, algorithm=multi_algorithm ) @@ -417,20 +536,29 @@ def collect(self, n_steps: int, env_type: Optional[str] = None) -> Dict[str, Any print(f" Indoor replay size: {len(multi_algorithm.replay['indoor'])}") print(f" Outdoor replay size: {len(multi_algorithm.replay['outdoor'])}") - # Test with multiple environments - print("\n6. Testing with multiple environments...") - envs = [ - DummyMetaEnv(state_dim=10, action_dim=7, max_steps=15), - DummyMetaEnv(state_dim=10, action_dim=7, max_steps=15), - ] + # Test with multiple environment types (dict) + print("\n6. Testing with multiple environment types (dict)...") + sim_env_fn = lambda: DummyMetaEnv(state_dim=10, action_dim=7, max_steps=10) + real_env_fn = lambda: DummyMetaEnv(state_dim=10, action_dim=7, max_steps=10) + + multi_envs = { + 'sim': DummyVectorEnv([sim_env_fn for _ in range(4)]), + 'real': DummyVectorEnv([real_env_fn for _ in range(2)]), + } + multi_env_collector = SimpleCollector( - meta_envs=envs, + envs=multi_envs, algorithm=DummyAlgorithm(meta_policy=meta_policy, replay=SimpleReplay(capacity=1000)) ) + print(f" Multi-env collector: {multi_env_collector}") + print(f" env_types: {multi_env_collector.get_env_types()}") + print(f" total_env_num: {multi_env_collector.get_total_env_num()}") + print(f" sim env_num: {len(multi_env_collector.get_env('sim'))}") + print(f" real env_num: {len(multi_env_collector.get_env('real'))}") + multi_env_collector.reset() - multi_env_stats = multi_env_collector.collect(n_steps=50) - print(f" Collected from {len(envs)} environments") - print(f" Total steps: {multi_env_stats['total_steps']}") + multi_env_stats = multi_env_collector.collect(n_steps=1000) + print(f" Collected {multi_env_stats['total_steps']} steps") print(f" Episodes completed: {len(multi_env_stats['episode_rewards'])}") print("\n" + "=" * 60) diff --git a/rl/envs/__init__.py b/rl/envs/__init__.py new file mode 100644 index 00000000..6d8cacbd --- /dev/null +++ b/rl/envs/__init__.py @@ -0,0 +1,19 @@ +""" +RL Environments Module + +Provides protocols and utilities for vectorized environments. +""" + +from .protocols import VectorEnvProtocol, VectorEnv, EnvsType +from .utils import make_vector_env, get_env_info + +__all__ = [ + # Protocols + 'VectorEnvProtocol', + 'VectorEnv', + 'EnvsType', + # Utilities + 'make_vector_env', + 'get_env_info', +] + diff --git a/rl/envs/protocols.py b/rl/envs/protocols.py new file mode 100644 index 00000000..35ddedc1 --- /dev/null +++ b/rl/envs/protocols.py @@ -0,0 +1,77 @@ +""" +Vector Environment Protocols + +Defines the protocol (interface) for vectorized environments. +Supports SequentialVectorEnv, SubprocVectorEnv, and custom implementations. +""" + +from typing import Protocol, runtime_checkable, Any, Optional, Union, List, Dict +import numpy as np + + +@runtime_checkable +class VectorEnvProtocol(Protocol): + """ + Vectorized environment protocol. + + Any class implementing this protocol can be used as a vectorized environment + in the RL framework. This includes: + - benchmark.utils.SequentialVectorEnv + - tianshou.env.SubprocVectorEnv + - tianshou.env.DummyVectorEnv + - tianshou.env.ShmemVectorEnv + - Any custom implementation satisfying this interface + + Attributes: + env_num: Number of parallel environments + """ + env_num: int + + def reset( + self, + id: Optional[Union[int, List[int], np.ndarray]] = None + ) -> Any: + """ + Reset environment(s). + + Args: + id: Optional environment index(es) to reset. + - None: Reset all environments + - int: Reset single environment + - List[int] or np.ndarray: Reset specific environments + + Returns: + Observations from reset environment(s) + """ + ... + + def step( + self, + action: Any, + id: Optional[Union[int, List[int], np.ndarray]] = None + ) -> tuple: + """ + Execute action(s) in environment(s). + + Args: + action: Action(s) to execute + id: Optional environment index(es) to step + + Returns: + Tuple of (obs, reward, done, info) + """ + ... + + def close(self) -> None: + """Close all environments and release resources.""" + ... + + def __len__(self) -> int: + """Return number of environments.""" + ... + + +# Type aliases for convenience +VectorEnv = VectorEnvProtocol +EnvsType = Union[VectorEnvProtocol, Dict[str, VectorEnvProtocol]] + diff --git a/rl/envs/utils.py b/rl/envs/utils.py new file mode 100644 index 00000000..f3c972fa --- /dev/null +++ b/rl/envs/utils.py @@ -0,0 +1,99 @@ +""" +Environment Utilities + +Provides helper functions for creating and managing vectorized environments. +""" + +from typing import Callable, Any, Optional +from .protocols import VectorEnvProtocol + + +def make_vector_env( + env_fn: Callable[[], Any], + num_envs: int = 1, + vector_type: str = 'sequential', + **kwargs +) -> VectorEnvProtocol: + """ + Create a vectorized environment from an environment factory function. + + This is a convenience function for creating vectorized environments. + For more control, create the vectorized environment directly. + + Args: + env_fn: Factory function that returns a single environment (e.g., MetaEnv). + This function will be called num_envs times to create parallel envs. + num_envs: Number of parallel environments to create. Default is 1. + vector_type: Type of vectorization to use: + - 'sequential': SequentialVectorEnv (no multiprocessing, safe for daemon processes) + - 'subproc': SubprocVectorEnv (multiprocessing, faster for CPU-bound envs) + - 'dummy': DummyVectorEnv (tianshou's sequential implementation) + - 'shmem': ShmemVectorEnv (shared memory, fastest for large observations) + **kwargs: Additional arguments passed to the vector environment constructor + + Returns: + VectorEnvProtocol: A vectorized environment instance + + Examples: + >>> from benchmark.aloha import create_env + >>> + >>> # Create 4 parallel environments using sequential vectorization + >>> vec_env = make_vector_env( + ... env_fn=lambda: create_env(config), + ... num_envs=4, + ... vector_type='sequential' + ... ) + >>> + >>> # Create 8 parallel environments using subprocesses + >>> vec_env = make_vector_env( + ... env_fn=lambda: create_env(config), + ... num_envs=8, + ... vector_type='subproc' + ... ) + + Raises: + ValueError: If vector_type is not recognized + ImportError: If required module for vector_type is not available + """ + env_fns = [env_fn for _ in range(num_envs)] + + if vector_type == 'sequential': + from benchmark.utils import SequentialVectorEnv + return SequentialVectorEnv(env_fns) + + elif vector_type == 'subproc': + from tianshou.env import SubprocVectorEnv + return SubprocVectorEnv(env_fns, **kwargs) + + elif vector_type == 'dummy': + from tianshou.env import DummyVectorEnv + return DummyVectorEnv(env_fns) + + elif vector_type == 'shmem': + from tianshou.env import ShmemVectorEnv + return ShmemVectorEnv(env_fns, **kwargs) + + else: + raise ValueError( + f"Unknown vector_type: {vector_type}. " + f"Supported types: 'sequential', 'subproc', 'dummy', 'shmem'" + ) + + +def get_env_info(envs: VectorEnvProtocol) -> dict: + """ + Get information about a vectorized environment. + + Args: + envs: A vectorized environment + + Returns: + Dictionary containing environment information: + - env_num: Number of parallel environments + - type: Type name of the vectorized environment + """ + return { + 'env_num': envs.env_num if hasattr(envs, 'env_num') else len(envs), + 'type': type(envs).__name__, + } + diff --git a/rl/trainers/base_trainer.py b/rl/trainers/base_trainer.py index a0ef2c31..049ce113 100644 --- a/rl/trainers/base_trainer.py +++ b/rl/trainers/base_trainer.py @@ -8,6 +8,7 @@ - Support single algorithm and multiple algorithms training scenarios - Support custom reward functions (applied during training, not data collection) - Support evaluation during training +- Support vectorized environments (SequentialVectorEnv, SubprocVectorEnv, etc.) """ import numpy as np @@ -16,7 +17,10 @@ from abc import ABC, abstractmethod # Type hints for Meta classes -from benchmark.base import MetaObs, MetaAction, MetaEnv +from benchmark.base import MetaObs, MetaAction + +# Vector environment protocol +from rl.envs import VectorEnvProtocol, EnvsType class BaseTrainer(ABC): @@ -25,13 +29,13 @@ class BaseTrainer(ABC): This class defines the interface for all trainers in the RL framework. Trainers coordinate the training loop by managing: - - Environments (meta_envs) + - Vectorized environments (envs) - Algorithms (with their policies and replay buffers) - Collectors (for data collection) - Reward functions (applied during training) Attributes: - meta_envs: Environment(s) for training + envs: Vectorized environment(s) for training algorithm: Algorithm(s) for training collector: Data collector(s) reward_fn: Optional custom reward function @@ -39,7 +43,7 @@ class BaseTrainer(ABC): def __init__( self, - meta_envs: Union[MetaEnv, List[MetaEnv], Callable, Dict[str, Any]], + envs: Union[VectorEnvProtocol, Dict[str, VectorEnvProtocol]], algorithm: Union['BaseAlgorithm', List['BaseAlgorithm']], collector: Optional[Union['BaseCollector', List['BaseCollector']]] = None, reward_fn: Optional[Union['BaseReward', Callable]] = None, @@ -49,11 +53,11 @@ def __init__( Initialize the trainer. Args: - meta_envs: Supports multiple formats: - - MetaEnv instance: Single environment - - List[MetaEnv]: Environment list (same type environments) - - Callable: Environment factory function - - Dict[str, Any]: Multi-environment config dict (different env types) + envs: Vectorized environment(s), supports: + - VectorEnvProtocol: Single vectorized environment + (SequentialVectorEnv, SubprocVectorEnv, DummyVectorEnv, etc.) + - Dict[str, VectorEnvProtocol]: Multi-environment dict for different env types + e.g., {'sim': sim_vec_env, 'real': real_vec_env} algorithm: BaseAlgorithm instance or list of BaseAlgorithm (required) - Single algorithm: Single agent training - Algorithm list: Multiple algorithms training independently in same env @@ -68,16 +72,76 @@ def __init__( - Note: Replay buffer stores raw rewards, reward function only applied during training **kwargs: Trainer-specific parameters """ - self.meta_envs = meta_envs + self.envs = envs self.algorithm = algorithm self.reward_fn = reward_fn self._kwargs = kwargs + # Normalize environment storage: always use dict internally + if isinstance(envs, dict): + self._envs_dict: Dict[str, VectorEnvProtocol] = envs + self._is_multi_env = True + else: + self._envs_dict = {'default': envs} + self._is_multi_env = False + # Handle collector initialization self.collector = collector # Note: Actual collector creation should be done in subclasses # since it may require specific collector types + def get_env(self, env_type: Optional[str] = None) -> VectorEnvProtocol: + """ + Get the vectorized environment by type. + + Args: + env_type: Environment type identifier. If None, returns 'default' env + or the first env if 'default' doesn't exist. + + Returns: + The vectorized environment + + Raises: + KeyError: If specified env_type is not found + """ + if env_type is not None: + return self._envs_dict[env_type] + + # Return 'default' if exists, otherwise return first env + if 'default' in self._envs_dict: + return self._envs_dict['default'] + return list(self._envs_dict.values())[0] + + def get_total_env_num(self) -> int: + """ + Get total number of parallel environments across all types. + + Returns: + Total count of parallel environments + """ + return sum(len(env) for env in self._envs_dict.values()) + + def get_env_types(self) -> List[str]: + """ + Get all environment type identifiers. + + Returns: + List of environment type strings + """ + return list(self._envs_dict.keys()) + + @property + def env_num(self) -> int: + """ + Number of environments in the default (or first) environment. + + Returns: + Number of parallel environments + """ + if 'default' in self._envs_dict: + return len(self._envs_dict['default']) + return len(list(self._envs_dict.values())[0]) + @abstractmethod def train(self, **kwargs) -> None: """ @@ -211,7 +275,8 @@ def __repr__(self) -> str: self.collector.__class__.__name__ if not isinstance(self.collector, list) else f"List[{len(self.collector)}]" ) reward_info = "None" if self.reward_fn is None else self.reward_fn.__class__.__name__ - return f"{self.__class__.__name__}(algorithm={algo_info}, collector={collector_info}, reward_fn={reward_info})" + env_info = f"{self.get_total_env_num()} envs" if self._is_multi_env else f"{self.env_num} envs" + return f"{self.__class__.__name__}(envs={env_info}, algorithm={algo_info}, collector={collector_info}, reward_fn={reward_info})" if __name__ == '__main__': @@ -219,6 +284,7 @@ def __repr__(self) -> str: Test code for BaseTrainer class. Since BaseTrainer is abstract, we create a simple concrete implementation for testing. + Tests now use vectorized environments (VectorEnvProtocol). """ import sys sys.path.insert(0, '/home/zhang/robot/126/ILStudio') @@ -228,6 +294,7 @@ def __repr__(self) -> str: from rl.base import BaseAlgorithm from rl.rewards.base_reward import BaseReward, IdentityReward, ScaledReward from rl.collectors.base_collector import BaseCollector + from rl.envs import VectorEnvProtocol from dataclasses import asdict # Simple replay buffer for testing @@ -313,6 +380,51 @@ def step(self, action): self.prev_obs = self.obs2meta(obs) return self.prev_obs, reward, done, info + # Simple VectorEnv for testing (similar to SequentialVectorEnv) + class DummyVectorEnv: + """Simple sequential vector environment for testing.""" + def __init__(self, env_fns): + self.envs = [fn() for fn in env_fns] + self.env_num = len(self.envs) + + def reset(self, id=None): + if id is None: + obs_list = [env.reset() for env in self.envs] + return np.array(obs_list, dtype=object) + else: + if np.isscalar(id): + return self.envs[id].reset() + else: + obs_list = [self.envs[i].reset() for i in id] + return np.array(obs_list, dtype=object) + + def step(self, action, id=None): + if id is None: + results = [env.step(act) for env, act in zip(self.envs, action)] + obs = np.array([r[0] for r in results], dtype=object) + rew = np.array([r[1] for r in results]) + done = np.array([r[2] for r in results]) + info = [r[3] for r in results] + return obs, rew, done, info + else: + if np.isscalar(id): + return self.envs[id].step(action) + else: + results = [self.envs[i].step(act) for i, act in zip(id, action)] + obs = np.array([r[0] for r in results], dtype=object) + rew = np.array([r[1] for r in results]) + done = np.array([r[2] for r in results]) + info = [r[3] for r in results] + return obs, rew, done, info + + def close(self): + for env in self.envs: + if hasattr(env, 'close'): + env.close() + + def __len__(self): + return self.env_num + # Simple policy for testing class DummyPolicy: def __init__(self, action_dim=7): @@ -358,78 +470,83 @@ def update(self, batch=None, **kwargs): def select_action(self, obs, **kwargs): return self.meta_policy.select_action(obs, t=self._timestep) - # Simple collector for testing + # Simple collector for testing (works with VectorEnv) class DummyCollector(BaseCollector): - def __init__(self, meta_envs, algorithm, **kwargs): - super().__init__(meta_envs, algorithm, **kwargs) - if isinstance(meta_envs, list): - self.envs = meta_envs - else: - self.envs = [meta_envs] + def __init__(self, envs, algorithm, **kwargs): + super().__init__(envs, algorithm, **kwargs) + self.vec_env = envs self._last_obs = None def reset(self, **kwargs): - self._last_obs = [env.reset() for env in self.envs] + self._last_obs = self.vec_env.reset() def collect(self, n_steps, env_type=None): if self._last_obs is None: self.reset() stats = {'episode_rewards': [], 'episode_lengths': [], 'total_steps': 0} - episode_reward = 0.0 - episode_length = 0 + episode_rewards = np.zeros(len(self.vec_env)) + episode_lengths = np.zeros(len(self.vec_env), dtype=int) for step in range(n_steps): - for i, (env, obs) in enumerate(zip(self.envs, self._last_obs)): + # Get actions for all environments + actions = [] + for obs in self._last_obs: action = self.algorithm.select_action(obs) - new_obs, reward, done, info = env.step(action) - - # Record transition + actions.append(action) + + # Step all environments + new_obs, rewards, dones, infos = self.vec_env.step(actions) + + # Record transitions and update stats + for i in range(len(self.vec_env)): kwargs_trans = {'env_type': env_type} if env_type else {} self.algorithm.record_transition( - state=obs, action=action, reward=reward, - next_state=new_obs, done=done, info=info, **kwargs_trans + state=self._last_obs[i], action=actions[i], reward=rewards[i], + next_state=new_obs[i], done=dones[i], info=infos[i], **kwargs_trans ) - episode_reward += reward - episode_length += 1 + episode_rewards[i] += rewards[i] + episode_lengths[i] += 1 stats['total_steps'] += 1 - if done: - stats['episode_rewards'].append(episode_reward) - stats['episode_lengths'].append(episode_length) - episode_reward = 0.0 - episode_length = 0 - new_obs = env.reset() - - self._last_obs[i] = new_obs + if dones[i]: + stats['episode_rewards'].append(episode_rewards[i]) + stats['episode_lengths'].append(episode_lengths[i]) + episode_rewards[i] = 0.0 + episode_lengths[i] = 0 + # Reset this env + new_obs[i] = self.vec_env.reset(id=i) + + self._last_obs = new_obs return stats # Simple trainer implementation for testing class SimpleTrainer(BaseTrainer): - """Simple trainer for testing.""" + """Simple trainer for testing with vectorized environments.""" def __init__( self, - meta_envs, + envs, algorithm, collector=None, reward_fn=None, **kwargs ): - super().__init__(meta_envs, algorithm, collector, reward_fn, **kwargs) + super().__init__(envs, algorithm, collector, reward_fn, **kwargs) # Create default collector if not provided if self.collector is None: + default_env = self.get_env() if isinstance(algorithm, list): self.collector = [ - DummyCollector(meta_envs=meta_envs, algorithm=alg) + DummyCollector(envs=default_env, algorithm=alg) for alg in algorithm ] else: self.collector = DummyCollector( - meta_envs=meta_envs, + envs=default_env, algorithm=algorithm ) @@ -469,28 +586,22 @@ def train(self, **kwargs): def evaluate(self, n_episodes=10, render=False, env_type=None, **kwargs): print(f"Evaluating for {n_episodes} episodes...") - # Create eval env - if isinstance(self.meta_envs, dict): - env = list(self.meta_envs.values())[0] if env_type is None else self.meta_envs[env_type] - elif isinstance(self.meta_envs, list): - env = self.meta_envs[0] - else: - env = self.meta_envs - + # Get eval env + vec_env = self.get_env(env_type) if env_type else self.get_env() alg = self.algorithm[0] if isinstance(self.algorithm, list) else self.algorithm episode_rewards = [] episode_lengths = [] for ep in range(n_episodes): - obs = env.reset() + obs = vec_env.reset(id=0) # Reset first env only for single episode eval episode_reward = 0.0 episode_length = 0 done = False while not done: action = alg.select_action(obs) - obs, reward, done, info = env.step(action) + obs, reward, done, info = vec_env.step(action, id=0) episode_reward += reward episode_length += 1 @@ -515,24 +626,31 @@ def load(self, path): # Test the implementation print("=" * 60) - print("Testing BaseTrainer (SimpleTrainer implementation)") + print("Testing BaseTrainer with Vectorized Environments") print("=" * 60) # Create components print("\n1. Creating components...") - env = DummyMetaEnv(state_dim=10, action_dim=7, max_steps=20) + # Create vectorized environment with 4 parallel envs + env_fns = [lambda: DummyMetaEnv(state_dim=10, action_dim=7, max_steps=20) for _ in range(4)] + vec_env = DummyVectorEnv(env_fns) + print(f" Created DummyVectorEnv with {vec_env.env_num} parallel environments") + meta_policy = DummyMetaPolicy(action_dim=7) replay = SimpleReplay(capacity=10000) algorithm = DummyAlgorithm(meta_policy=meta_policy, replay=replay) - # Test 1: Basic trainer - print("\n2. Testing basic trainer...") + # Test 1: Basic trainer with vectorized env + print("\n2. Testing basic trainer with vectorized env...") trainer = SimpleTrainer( - meta_envs=env, + envs=vec_env, algorithm=algorithm, reward_fn=None ) print(f" Trainer: {trainer}") + print(f" env_num: {trainer.env_num}") + print(f" total_env_num: {trainer.get_total_env_num()}") + print(f" env_types: {trainer.get_env_types()}") # Test compute_reward without reward_fn print("\n3. Testing compute_reward without custom reward_fn...") @@ -553,33 +671,54 @@ def load(self, path): assert abs(computed_reward - expected) < 1e-6 # Test training - print("\n5. Testing training...") - trainer.train(total_steps=200, log_interval=50, update_interval=20, batch_size=16) + print("\n5. Testing training with vectorized env...") + trainer.train(total_steps=400, log_interval=100, update_interval=20, batch_size=16) print(f" Replay buffer size: {len(algorithm.replay)}") # Test evaluation print("\n6. Testing evaluation...") eval_results = trainer.evaluate(n_episodes=3) - print(f" Evaluation results: {eval_results}") + print(f" Evaluation results: mean_reward={eval_results['mean_reward']:.2f}") # Test collect_rollout print("\n7. Testing collect_rollout...") rollout_stats = trainer.collect_rollout(n_steps=50) - print(f" Rollout stats: {rollout_stats}") + print(f" Rollout stats: total_steps={rollout_stats['total_steps']}") + + # Test with multiple environment types (dict) + print("\n8. Testing with multiple environment types (dict)...") + sim_env_fns = [lambda: DummyMetaEnv(state_dim=10, action_dim=7, max_steps=20) for _ in range(4)] + real_env_fns = [lambda: DummyMetaEnv(state_dim=10, action_dim=7, max_steps=10) for _ in range(2)] + + multi_envs = { + 'sim': DummyVectorEnv(sim_env_fns), + 'real': DummyVectorEnv(real_env_fns), + } + + multi_env_trainer = SimpleTrainer( + envs=multi_envs, + algorithm=DummyAlgorithm(meta_policy=DummyMetaPolicy(action_dim=7), replay=SimpleReplay(capacity=1000)), + reward_fn=IdentityReward() + ) + print(f" Multi-env trainer: {multi_env_trainer}") + print(f" env_types: {multi_env_trainer.get_env_types()}") + print(f" total_env_num: {multi_env_trainer.get_total_env_num()}") + print(f" sim env_num: {len(multi_env_trainer.get_env('sim'))}") + print(f" real env_num: {len(multi_env_trainer.get_env('real'))}") # Test with multiple algorithms - print("\n8. Testing with multiple algorithms...") + print("\n9. Testing with multiple algorithms...") algorithms = [ DummyAlgorithm(meta_policy=DummyMetaPolicy(action_dim=7), replay=SimpleReplay(capacity=1000)), DummyAlgorithm(meta_policy=DummyMetaPolicy(action_dim=7), replay=SimpleReplay(capacity=1000)) ] - multi_trainer = SimpleTrainer( - meta_envs=env, + multi_algo_trainer = SimpleTrainer( + envs=vec_env, algorithm=algorithms, reward_fn=IdentityReward() ) - print(f" Multi-algorithm trainer: {multi_trainer}") - multi_trainer.train(total_steps=100, log_interval=50, update_interval=20, batch_size=8) + print(f" Multi-algorithm trainer: {multi_algo_trainer}") + multi_algo_trainer.train(total_steps=100, log_interval=50, update_interval=20, batch_size=8) print("\n" + "=" * 60) print("All tests passed!") From c27b8444809f0037358b9e81caa5dc915da412d0 Mon Sep 17 00:00:00 2001 From: yesen-chen <840419490@qq.com> Date: Wed, 4 Feb 2026 00:29:02 +0800 Subject: [PATCH 3/6] [feat] Update rl framework and add an off-policy algo TD3 --- benchmark/base.py | 2 + benchmark/gymnasium/__init__.py | 167 ++++++++ rl/__init__.py | 5 +- rl/algorithms/__init__.py | 16 + rl/{ => algorithms}/base.py | 215 +++++++--- rl/algorithms/td3/__init__.py | 7 + rl/algorithms/td3/td3.py | 303 +++++++++++++ rl/buffer/__init__.py | 5 + rl/buffer/base_replay.py | 431 +++++++++++++++---- rl/buffer/meta_replay.py | 726 ++++++++++++++++++++++++++++++++ rl/buffer/transition.py | 38 ++ rl/collectors/__init__.py | 3 +- rl/collectors/base_collector.py | 658 ++++++++++++----------------- rl/trainers/base_trainer.py | 602 +++++++++++++------------- rl/utils/__init__.py | 3 + rl/utils/action_utils.py | 80 ++++ utils/__init__.py | 23 +- 17 files changed, 2430 insertions(+), 854 deletions(-) create mode 100644 benchmark/gymnasium/__init__.py rename rl/{ => algorithms}/base.py (68%) create mode 100644 rl/algorithms/td3/__init__.py create mode 100644 rl/algorithms/td3/td3.py create mode 100644 rl/buffer/meta_replay.py create mode 100644 rl/buffer/transition.py create mode 100644 rl/utils/action_utils.py diff --git a/benchmark/base.py b/benchmark/base.py index b650aaca..6333512f 100644 --- a/benchmark/base.py +++ b/benchmark/base.py @@ -218,6 +218,8 @@ def normed_mobs_to_samples(self, normed_mobs: MetaObs): sample['image'] = normed_mobs.image # (cameras, C, H, W) else: sample['image'] = normed_mobs.image + else: + sample['image'] = None # Explicitly set to None when no image # State: use specified control space (already normalized) sample['state'] = normed_mobs.state[i] if len(normed_mobs.state.shape) > 1 else normed_mobs.state diff --git a/benchmark/gymnasium/__init__.py b/benchmark/gymnasium/__init__.py new file mode 100644 index 00000000..e4b9b26b --- /dev/null +++ b/benchmark/gymnasium/__init__.py @@ -0,0 +1,167 @@ +""" +Gymnasium Environment Adapter + +Provides a MetaEnv wrapper for standard Gymnasium environments, +including MuJoCo environments like HalfCheetah, Ant, Humanoid, etc. + +Usage: + # Via config: + python train_rl.py -a td3 -e gymnasium/halfcheetah + + # Programmatically: + from benchmark.gymnasium import create_env + env = create_env(config) +""" + +import numpy as np +import gymnasium as gym +from ..base import MetaEnv, MetaObs, MetaAction + + +def create_env(config): + """Factory function for creating GymnasiumEnv.""" + return GymnasiumEnv(config) + + +class GymnasiumEnv(MetaEnv): + """ + MetaEnv wrapper for standard Gymnasium environments. + + Supports continuous control environments like: + - MuJoCo: HalfCheetah-v5, Ant-v5, Humanoid-v5, Walker2d-v5, Hopper-v5, etc. + - Classic Control: Pendulum-v1, MountainCarContinuous-v0 + - Box2D: LunarLanderContinuous-v3, BipedalWalker-v3 + + Config attributes: + task (str): Gymnasium environment ID (e.g., 'HalfCheetah-v5') + render_mode (str): Render mode ('human', 'rgb_array', None). Default: None + use_camera (bool): Whether to capture images. Default: False + max_episode_steps (int): Override max steps. Default: use env default + raw_lang (str): Language instruction. Default: task name + """ + + def __init__(self, config, *args): + self.config = config + self.task = config.task + self.render_mode = getattr(config, 'render_mode', None) + self.use_camera = getattr(config, 'use_camera', False) + self.max_episode_steps = getattr(config, 'max_episode_steps', None) + self.raw_lang = getattr(config, 'raw_lang', self.task) + + # Control space info (most MuJoCo envs use normalized actions) + self.ctrl_space = getattr(config, 'ctrl_space', 'joint') + self.ctrl_type = getattr(config, 'ctrl_type', 'abs') + + env = self._create_env() + super().__init__(env) + + # Store action space bounds for reference + self.action_low = self.env.action_space.low + self.action_high = self.env.action_space.high + self.action_dim = self.env.action_space.shape[0] + self.state_dim = self.env.observation_space.shape[0] + self.debug_step=0 + + def _create_env(self): + """Create the underlying Gymnasium environment.""" + kwargs = {} + if self.render_mode: + kwargs['render_mode'] = self.render_mode + if self.max_episode_steps: + kwargs['max_episode_steps'] = self.max_episode_steps + + env = gym.make(self.task, **kwargs) + return env + + @property + def action_space(self): + """Return the action space of the environment.""" + return self.env.action_space + + @property + def observation_space(self): + """Return the observation space of the environment.""" + return self.env.observation_space + + def meta2act(self, maction: MetaAction): + """Convert MetaAction to Gymnasium action.""" + actions = maction['action'] # (action_dim, ) + return actions + + def obs2meta(self, obs) -> MetaObs: + """Convert Gymnasium observation to MetaObs.""" + # Handle different observation types + if isinstance(obs, dict): + # For goal-conditioned environments + state = obs.get('observation', obs.get('state', None)) + if state is None: + state = np.concatenate([v.flatten() for v in obs.values()]) + else: + state = obs + + state = np.asarray(state, dtype=np.float32) + + # Capture image if requested + image = None + if self.use_camera and self.render_mode == 'rgb_array': + image = self.env.render() + if image is not None: + # Convert to (N, C, H, W) format + if len(image.shape) == 3: # (H, W, C) + image = image[np.newaxis, ...] # -> (1, H, W, C) + if image.shape[-1] == 3: # (N, H, W, C) -> (N, C, H, W) + image = image.transpose(0, 3, 1, 2) + + return MetaObs(state=state, image=image, raw_lang=self.raw_lang) + + def step(self, maction): + """Execute action and return (obs, reward, done, info).""" + action = self.meta2act(maction) + observation, reward, terminated, truncated, info = self.env.step(action) + + obs = self.obs2meta(observation) + + # Combine terminated and truncated into done + # Store both for proper handling in RL algorithms + done = terminated or truncated + info['terminated'] = terminated + info['truncated'] = truncated + info['TimeLimit.truncated'] = truncated and not terminated + self.debug_step += 1 + + return obs, reward, done, info + + def reset(self, **kwargs): + """Reset environment and return initial observation.""" + obs, info = self.env.reset(**kwargs) + self.debug_step=0 + return self.obs2meta(obs) + + def render(self): + """Render the environment.""" + return self.env.render() + + def close(self): + """Close the environment.""" + self.env.close() + + +# Common MuJoCo environment configurations +MUJOCO_ENVS = { + 'halfcheetah': 'HalfCheetah-v5', + 'ant': 'Ant-v5', + 'humanoid': 'Humanoid-v5', + 'walker2d': 'Walker2d-v5', + 'hopper': 'Hopper-v5', + 'swimmer': 'Swimmer-v5', + 'reacher': 'Reacher-v5', + 'pusher': 'Pusher-v5', + 'inverted_pendulum': 'InvertedPendulum-v5', + 'inverted_double_pendulum': 'InvertedDoublePendulum-v5', +} + + +def get_env_id(name: str) -> str: + """Get full environment ID from short name.""" + return MUJOCO_ENVS.get(name.lower(), name) + diff --git a/rl/__init__.py b/rl/__init__.py index 7ffb8010..bc650612 100644 --- a/rl/__init__.py +++ b/rl/__init__.py @@ -21,8 +21,9 @@ Directory Structure: rl/ ├── __init__.py # This file - ├── base.py # BaseAlgorithm class ├── algorithms/ # RL algorithm implementations + │ ├── __init__.py # Algorithm registry + │ └── base.py # BaseAlgorithm class │ └── __init__.py # Algorithm registry ├── buffer/ # Replay buffer implementations │ ├── __init__.py @@ -48,7 +49,7 @@ """ # Base classes -from .base import BaseAlgorithm +from .algorithms.base import BaseAlgorithm from .buffer import BaseReplay from .rewards import BaseReward from .collectors import BaseCollector diff --git a/rl/algorithms/__init__.py b/rl/algorithms/__init__.py index 0c7d3b3e..6c83330b 100644 --- a/rl/algorithms/__init__.py +++ b/rl/algorithms/__init__.py @@ -77,3 +77,19 @@ def list_algorithms() -> list: 'list_algorithms', ] + +# Auto-register built-in algorithms when this module is imported +def _register_builtin_algorithms(): + """Import algorithm submodules to trigger their registration.""" + try: + from . import td3 # noqa: F401 + except ImportError: + pass + # Add more algorithms here as they are implemented + # try: + # from . import sac + # except ImportError: + # pass + + +_register_builtin_algorithms() diff --git a/rl/base.py b/rl/algorithms/base.py similarity index 68% rename from rl/base.py rename to rl/algorithms/base.py index 8c9997dd..2a4bcb68 100644 --- a/rl/base.py +++ b/rl/algorithms/base.py @@ -14,10 +14,11 @@ import numpy as np from typing import Dict, Any, Optional, Union, List, Callable from abc import ABC, abstractmethod +from dataclasses import asdict, fields # Type hints for Meta classes (imported at runtime to avoid circular imports) from benchmark.base import MetaObs, MetaAction, MetaPolicy - +from rl.buffer.transition import RLTransition class BaseAlgorithm(ABC): """ @@ -100,76 +101,151 @@ def compute_loss(self, batch: Dict[str, Any]) -> torch.Tensor: """ raise NotImplementedError("Subclass should implement compute_loss if needed") - def select_action(self, obs: MetaObs, **kwargs) -> MetaAction: + def select_action( + self, + obs: Union[MetaObs, List[MetaObs], np.ndarray], + **kwargs + ) -> Union[MetaAction, List[MetaAction]]: """ - Select action (optional, some algorithms may need this). + Select action(s) for given observation(s). + + Supports both single and batched inputs for vectorized environments. + The recommended way for vectorized envs is to use organized batch obs + (a single MetaObs with batched fields like (n_envs, state_dim)). Args: - obs: MetaObs format observation - **kwargs: Other parameters (e.g., exploration settings) + obs: Observation(s), can be: + - MetaObs: Single observation OR organized batch obs + (with batched fields like state: (n_envs, state_dim)) + - List[MetaObs] or np.ndarray: List of individual observations + (will be organized into batch internally) + **kwargs: Other parameters (e.g., t for timestep) Returns: - MetaAction format action + MetaAction: Single action or batched action + (with action field shape (n_envs, action_dim) for batch) + + Note: + For vectorized environments, it's recommended to: + 1. Use organize_obs() to convert List[MetaObs] -> batched MetaObs + 2. Pass the organized obs to this method + 3. The policy handles batched inference internally """ - # Default implementation uses meta_policy's select_action + # Check if input is a list of observations that needs organizing + if isinstance(obs, (list, np.ndarray)) and len(obs) > 0: + first = obs[0] if isinstance(obs, list) else obs.flat[0] + if hasattr(first, '__dataclass_fields__'): # List of MetaObs + # Organize into batched MetaObs + obs = self._organize_obs(obs) + + # Pass to meta_policy (handles both single and batch) return self.meta_policy.select_action(obs, **kwargs) + def _organize_obs(self, obs_list: Union[List[MetaObs], np.ndarray]) -> MetaObs: + """ + Organize list of MetaObs into a single batched MetaObs. + + Converts List[MetaObs] -> MetaObs with batched fields. + e.g., [MetaObs(state=(10,)), MetaObs(state=(10,))] -> MetaObs(state=(2, 10)) + + Args: + obs_list: List or array of MetaObs objects + + Returns: + MetaObs with batched fields + """ + if len(obs_list) == 0: + return None + + # Convert to list of dicts + obs_dicts = [] + for o in obs_list: + if hasattr(o, '__dataclass_fields__'): + obs_dicts.append(asdict(o)) + elif isinstance(o, dict): + obs_dicts.append(o) + else: + obs_dicts.append(vars(o) if hasattr(o, '__dict__') else {}) + + # Stack each field + all_keys = list(obs_dicts[0].keys()) + batched = {} + for k in all_keys: + values = [d[k] for d in obs_dicts] + if values[0] is None: + batched[k] = None + elif isinstance(values[0], np.ndarray): + batched[k] = np.stack(values) + elif isinstance(values[0], (int, float)): + batched[k] = np.array(values) + elif isinstance(values[0], str): + batched[k] = values # Keep as list for strings + else: + batched[k] = values # Keep as list for other types + + # Convert back to MetaObs + return MetaObs(**{k: v for k, v in batched.items() if k in [f.name for f in fields(MetaObs)]}) + def record_transition( self, - state: MetaObs, - action: MetaAction, - reward: float, - next_state: MetaObs, - done: bool, - info: Optional[Dict[str, Any]] = None, + transition: 'RLTransition', **kwargs ) -> None: """ Record transition to replay buffer (if exists). - - Supports storing complete MetaObs and MetaAction information, plus additional custom fields. - If using multiple replay buffers (by environment type), selects corresponding replay - based on env_type in kwargs. + + Args: + transition: RLTransition created by collector + **kwargs: Additional fields: + - env_type: Environment type (required for multi-buffer) + """ + if self.replay is None: + raise ValueError("Replay buffer is not set") + + from rl.buffer.transition import RLTransition + if not isinstance(transition, RLTransition): + raise TypeError("record_transition expects RLTransition") + + env_type = kwargs.get('env_type', None) + if isinstance(self.replay, dict): + if env_type is None: + raise ValueError("env_type must be provided when using multiple replay buffers") + if env_type not in self.replay: + raise ValueError(f"env_type '{env_type}' not found in replay buffers") + target_replay = self.replay[env_type] + else: + target_replay = self.replay + + target_replay.add(transition) + + def _stack_dicts(self, dict_list: List[Dict]) -> Dict: + """ + Stack a list of dicts into a single dict with batched values. Args: - state: Current state (MetaObs, including all fields) - action: Action (MetaAction, including all fields) - reward: Reward - next_state: Next state (MetaObs, including all fields) - done: Whether episode ended - info: Additional information dictionary - **kwargs: Other custom fields, can store any additional information - - env_type: Environment type identifier (if replay is Dict[str, BaseReplay]) - - e.g., value, log_prob, advantage, trajectory_id, etc. + dict_list: List of dicts with same keys + + Returns: + Dict with stacked values (numpy arrays where applicable) """ - if self.replay is not None: - from dataclasses import asdict - - # Convert MetaObs and MetaAction to dict if they are dataclass instances - state_dict = asdict(state) if hasattr(state, '__dataclass_fields__') else state - action_dict = asdict(action) if hasattr(action, '__dataclass_fields__') else action - next_state_dict = asdict(next_state) if hasattr(next_state, '__dataclass_fields__') else next_state - - transition = { - 'state': state_dict, - 'action': action_dict, - 'reward': reward, - 'next_state': next_state_dict, - 'done': done, - 'info': info or {}, - **kwargs - } - - env_type = kwargs.get('env_type', None) - - if isinstance(self.replay, dict): - if env_type is None: - raise ValueError("env_type must be provided when using multiple replay buffers") - if env_type not in self.replay: - raise ValueError(f"env_type '{env_type}' not found in replay buffers") - self.replay[env_type].add(transition) + if not dict_list: + return {} + + result = {} + for key in dict_list[0].keys(): + values = [d.get(key) for d in dict_list] + if values[0] is None: + result[key] = None + elif isinstance(values[0], np.ndarray): + result[key] = np.stack(values) + elif isinstance(values[0], (int, float, bool)): + result[key] = np.array(values) + elif isinstance(values[0], str): + result[key] = values # Keep strings as list else: - self.replay.add(transition) + result[key] = values # Keep other types as list + + return result def get_policy(self) -> MetaPolicy: """Get the underlying MetaPolicy.""" @@ -226,6 +302,7 @@ def __repr__(self) -> str: from benchmark.base import MetaObs, MetaAction, MetaPolicy from rl.buffer.base_replay import BaseReplay + from rl.algorithms.base import BaseAlgorithm from dataclasses import asdict # Simple replay buffer for testing @@ -360,15 +437,15 @@ def compute_loss(self, batch): raw_lang="test instruction" ) - algorithm.record_transition( - state=state, + transition = RLTransition( + obs=state, action=action, + next_obs=next_state, reward=np.random.randn(), - next_state=next_state, done=(i == 9), - info={'step': i}, - value=np.random.randn(), # Custom field + info={'step': i, 'value': np.random.randn()}, ) + algorithm.record_transition(transition) print(f" Replay buffer size: {len(algorithm.replay)}") # Update algorithm @@ -390,16 +467,22 @@ def compute_loss(self, batch): action = MetaAction(action=np.random.randn(7).astype(np.float32)) next_state = MetaObs(state=np.random.randn(10).astype(np.float32)) - multi_algorithm.record_transition( - state=state, action=action, reward=1.0, - next_state=next_state, done=False, - env_type='indoor' + transition = RLTransition( + obs=state, + action=action, + next_obs=next_state, + reward=1.0, + done=False, ) - multi_algorithm.record_transition( - state=state, action=action, reward=0.5, - next_state=next_state, done=False, - env_type='outdoor' + multi_algorithm.record_transition(transition, env_type='indoor') + transition = RLTransition( + obs=state, + action=action, + next_obs=next_state, + reward=0.5, + done=False, ) + multi_algorithm.record_transition(transition, env_type='outdoor') print(f" Indoor replay size: {len(multi_algorithm.replay['indoor'])}") print(f" Outdoor replay size: {len(multi_algorithm.replay['outdoor'])}") diff --git a/rl/algorithms/td3/__init__.py b/rl/algorithms/td3/__init__.py new file mode 100644 index 00000000..c42799bb --- /dev/null +++ b/rl/algorithms/td3/__init__.py @@ -0,0 +1,7 @@ +from .td3 import TD3Algorithm, TD3Config +from .. import register_algorithm + +register_algorithm("td3", TD3Algorithm) + +__all__ = ["TD3Algorithm", "TD3Config"] + diff --git a/rl/algorithms/td3/td3.py b/rl/algorithms/td3/td3.py new file mode 100644 index 00000000..6f10df02 --- /dev/null +++ b/rl/algorithms/td3/td3.py @@ -0,0 +1,303 @@ +""" +TD3 algorithm implementation based on the official reference: +https://arxiv.org/abs/1802.09477 +""" + +from __future__ import annotations + +import copy +from dataclasses import dataclass +from typing import Any, Dict, Optional + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from benchmark.base import MetaAction, MetaObs, MetaPolicy +from policy.mlp.mlp import MLPPolicy, MLPPolicyConfig +from rl.utils import polyak_update +from rl.utils.action_utils import ensure_action + +from ..base import BaseAlgorithm + + +@dataclass +class TD3Config: + state_dim: int + action_dim: int + discount: float = 0.99 + tau: float = 0.005 + policy_noise: float = 0.2 + noise_clip: float = 0.5 + policy_freq: int = 2 + actor_lr: float = 3e-4 + critic_lr: float = 3e-4 + device: str = "cpu" + state_key: str = "state" + next_state_key: str = "next_state" + action_key: str = "action" + + +class _IdentityStateNormalizer: + def normalize_metaobs(self, mobs: MetaObs, ctrl_space: str): + return mobs + + +class _IdentityActionNormalizer: + def denormalize_metaact(self, mact: MetaAction): + return mact + + +class Critic(nn.Module): + def __init__(self, state_dim: int, action_dim: int): + super().__init__() + # Q1 architecture + self.l1 = nn.Linear(state_dim + action_dim, 256) + self.l2 = nn.Linear(256, 256) + self.l3 = nn.Linear(256, 1) + + # Q2 architecture + self.l4 = nn.Linear(state_dim + action_dim, 256) + self.l5 = nn.Linear(256, 256) + self.l6 = nn.Linear(256, 1) + + def forward(self, state: torch.Tensor, action: torch.Tensor): + sa = torch.cat([state, action], dim=1) + + q1 = F.relu(self.l1(sa)) + q1 = F.relu(self.l2(q1)) + q1 = self.l3(q1) + + q2 = F.relu(self.l4(sa)) + q2 = F.relu(self.l5(q2)) + q2 = self.l6(q2) + return q1, q2 + + def Q1(self, state: torch.Tensor, action: torch.Tensor): + sa = torch.cat([state, action], dim=1) + q1 = F.relu(self.l1(sa)) + q1 = F.relu(self.l2(q1)) + q1 = self.l3(q1) + return q1 + + +class TD3Algorithm(BaseAlgorithm): + """ + TD3 algorithm using MLPPolicy as actor and a twin-critic network. + """ + + def __init__( + self, + replay: Optional[Any], + config: TD3Config, + actor_config: Optional[MLPPolicyConfig] = None, + meta_policy: Optional[MetaPolicy] = None, + ensure_refine_fn=None, + ctrl_space: str = "ee", + ctrl_type: str = "delta", + gripper_continuous: bool = False, + **kwargs, + ): + if actor_config is None: + actor_config = MLPPolicyConfig( + state_dim=config.state_dim, + action_dim=config.action_dim, + chunk_size=1, + ) + else: + actor_config.chunk_size = 1 + + self.device = torch.device(config.device) + self._env = None + self.config = config + self.ensure_refine_fn = ensure_refine_fn + self.ctrl_space = ctrl_space + self.ctrl_type = ctrl_type + self.gripper_continuous = gripper_continuous + + self.actor = MLPPolicy(actor_config).to(self.device) + self.actor_target = copy.deepcopy(self.actor) + self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=config.actor_lr) + + self.critic = Critic(config.state_dim, config.action_dim).to(self.device) + self.critic_target = copy.deepcopy(self.critic) + self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=config.critic_lr) + + self.total_it = 0 + + if meta_policy is None: + meta_policy = MetaPolicy( + self.actor, + chunk_size=1, + action_normalizer=_IdentityActionNormalizer(), + state_normalizer=_IdentityStateNormalizer(), + ctrl_space=ctrl_space, + ctrl_type=ctrl_type, + ) + + super().__init__(meta_policy=meta_policy, replay=replay, **kwargs) + + def set_env(self, env: Any) -> None: + """Attach an environment for action post-processing.""" + self._env = env + + def _actor_forward(self, state: torch.Tensor, model: Optional[MLPPolicy] = None) -> torch.Tensor: + model = self.actor if model is None else model + output = model(state) + if isinstance(output, dict): + action = output.get("action") + else: + action = output + if action is None: + raise ValueError("Actor output does not contain 'action'") + if action.dim() == 3: + action = action[:, 0, :] + return action + + def _to_tensor(self, value, dtype=torch.float32) -> torch.Tensor: + if torch.is_tensor(value): + return value.to(device=self.device, dtype=dtype) + return torch.as_tensor(value, device=self.device, dtype=dtype) + + def _extract_batch_action(self, batch: Dict[str, Any]) -> Any: + action = batch.get(self.config.action_key, None) + if action is None: + raise KeyError(f"Missing '{self.config.action_key}' in batch") + if hasattr(action, "action"): + return action.action + if isinstance(action, dict) and "action" in action: + return action["action"] + return action + + def _extract_batch_state(self, batch: Dict[str, Any], key: str) -> Any: + state = batch.get(key, None) + if state is None: + raise KeyError(f"Missing '{key}' in batch") + return state + + def select_action( + self, + obs: Any, + noise_scale: float = 0.0, + env: Optional[Any] = None, + **kwargs, + ) -> MetaAction: + if isinstance(obs, (list, np.ndarray)) and len(obs) > 0: + first = obs[0] if isinstance(obs, list) else obs.flat[0] + if hasattr(first, "__dataclass_fields__"): + obs = self._organize_obs(obs) + + if hasattr(obs, self.config.state_key): + state = getattr(obs, self.config.state_key) + elif hasattr(obs, "state"): + state = obs.state + elif isinstance(obs, dict) and self.config.state_key in obs: + state = obs[self.config.state_key] + else: + state = obs + state_t = self._to_tensor(state) + if state_t.dim() == 1: + state_t = state_t.unsqueeze(0) + + with torch.no_grad(): + action = self._actor_forward(state_t) + if noise_scale > 0: + action = action + noise_scale * torch.randn_like(action) + action = ensure_action(env or self._env, action, refine_fn=self.ensure_refine_fn) + + action_np = action.detach().cpu().numpy() + return MetaAction( + action=action_np, + ctrl_space=self.ctrl_space, + ctrl_type=self.ctrl_type, + gripper_continuous=self.gripper_continuous, + ) + + def update( + self, + batch: Optional[Dict[str, Any]] = None, + env: Optional[Any] = None, + **kwargs, + ) -> Dict[str, Any]: + if batch is None: + if self.replay is None: + raise ValueError("Replay buffer is not set") + batch_size = kwargs.get("batch_size", 256) + if hasattr(self.replay, "sample_for_training"): + batch = self.replay.sample_for_training(batch_size) + else: + raise ValueError("Replay buffer does not support sample_for_training") + batch = self.replay.sample(batch_size) + + state = self._to_tensor(self._extract_batch_state(batch, self.config.state_key)) + next_state = self._to_tensor(self._extract_batch_state(batch, self.config.next_state_key)) + action = self._to_tensor(self._extract_batch_action(batch)) + reward = self._to_tensor(batch.get("reward"), dtype=torch.float32).unsqueeze(-1) + done = self._to_tensor(batch.get("done"), dtype=torch.float32).unsqueeze(-1) + truncated_raw = batch.get("truncated") + if truncated_raw is None: + truncated_raw = np.zeros_like(batch.get("done")) + truncated = self._to_tensor(truncated_raw, dtype=torch.float32).unsqueeze(-1) + + terminal = done * (1.0 - truncated) + not_done = 1.0 - terminal + + self.total_it += 1 + + with torch.no_grad(): + noise = ( + torch.randn_like(action) * self.config.policy_noise + ).clamp(-self.config.noise_clip, self.config.noise_clip) + next_action = self._actor_forward(next_state, model=self.actor_target) + next_action = next_action + noise + next_action = ensure_action(env or self._env, next_action, refine_fn=self.ensure_refine_fn) + target_q1, target_q2 = self.critic_target(next_state, next_action) + target_q = torch.min(target_q1, target_q2) + target_q = reward + not_done * self.config.discount * target_q + + current_q1, current_q2 = self.critic(state, action) + critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q) + + self.critic_optimizer.zero_grad() + critic_loss.backward() + self.critic_optimizer.step() + + actor_loss = None + if self.total_it % self.config.policy_freq == 0: + actor_action = self._actor_forward(state) + actor_action = ensure_action(env or self._env, actor_action, refine_fn=self.ensure_refine_fn) + actor_loss = -self.critic.Q1(state, actor_action).mean() + + self.actor_optimizer.zero_grad() + actor_loss.backward() + self.actor_optimizer.step() + + polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.config.tau) + polyak_update(self.actor.parameters(), self.actor_target.parameters(), self.config.tau) + + return { + "critic_loss": critic_loss.item(), + "actor_loss": None if actor_loss is None else actor_loss.item(), + "update_step": self.total_it, + } + + def save(self, path: str, **kwargs) -> None: + payload = { + "actor": self.actor.state_dict(), + "critic": self.critic.state_dict(), + "actor_optimizer": self.actor_optimizer.state_dict(), + "critic_optimizer": self.critic_optimizer.state_dict(), + "config": self.config, + } + torch.save(payload, path) + + def load(self, path: str, **kwargs) -> None: + payload = torch.load(path, map_location=self.device) + self.actor.load_state_dict(payload["actor"]) + self.critic.load_state_dict(payload["critic"]) + self.actor_optimizer.load_state_dict(payload["actor_optimizer"]) + self.critic_optimizer.load_state_dict(payload["critic_optimizer"]) + self.actor_target = copy.deepcopy(self.actor) + self.critic_target = copy.deepcopy(self.critic) + diff --git a/rl/buffer/__init__.py b/rl/buffer/__init__.py index 8a2bf6e1..d23f69b9 100644 --- a/rl/buffer/__init__.py +++ b/rl/buffer/__init__.py @@ -5,11 +5,16 @@ Classes: BaseReplay: Base class for all replay buffers + MetaReplay: Replay buffer for MetaObs and MetaAction fields """ from .base_replay import BaseReplay +from .meta_replay import MetaReplay +from .transition import RLTransition __all__ = [ 'BaseReplay', + 'MetaReplay', + 'RLTransition', ] diff --git a/rl/buffer/base_replay.py b/rl/buffer/base_replay.py index 40ddcf88..8a1beb13 100644 --- a/rl/buffer/base_replay.py +++ b/rl/buffer/base_replay.py @@ -9,11 +9,16 @@ - Extensibility: Support storing additional custom fields (value, log_prob, advantage, trajectory_id, etc.) - Conversion during sampling: Convert to ILStudio data pipeline format (normalization, etc.) during sampling - Compatibility: Maintain data integrity while being compatible with ILStudio's normalization pipeline +- Vectorized environment support: Each buffer corresponds to one environment type with multiple parallel envs + - Storage shape: (capacity, n_envs, ...) where capacity is number of time steps + - Each add() stores data from all parallel envs at one time step + - Sample supports selecting specific env indices """ import torch import numpy as np -from typing import Dict, Any, Optional, Union, Callable +import pickle +from typing import Dict, Any, Optional, Union, Callable, List, Set from abc import ABC, abstractmethod @@ -24,68 +29,120 @@ class BaseReplay(ABC): This class defines the interface for all replay buffers in the RL framework. Replay buffers store experience data (transitions) for training RL algorithms. + Supports vectorized environments where one buffer stores data from multiple parallel + environments of the same type. Storage shape is (capacity, n_envs, ...). + Attributes: - capacity: Maximum number of transitions to store + capacity: Maximum number of time steps to store (not total transitions) device: Device to store data on ('cpu' or 'cuda') + env_type: Environment type identifier for this buffer + n_envs: Number of parallel environments (default 1 for non-vectorized) """ def __init__( self, capacity: int = 1000000, device: Union[str, torch.device] = 'cpu', + env_type: Optional[str] = None, + n_envs: int = 1, **kwargs ): """ Initialize the Replay Buffer. Args: - capacity: Buffer capacity (maximum number of transitions to store) + capacity: Buffer capacity (maximum number of time steps to store) + - Total transitions stored = capacity * n_envs device: Data storage device ('cpu' or 'cuda', default 'cpu') - 'cpu': Store data in CPU memory - 'cuda' or 'cuda:0': Store data in GPU memory + env_type: Environment type identifier (e.g., 'sim', 'real', 'indoor', 'outdoor') + - Used to distinguish buffers for different environment types + - If None, defaults to 'default' + n_envs: Number of parallel environments for this buffer + - Storage shape will be (capacity, n_envs, ...) + - Default 1 for non-vectorized single environment **kwargs: Other initialization parameters """ self.capacity = capacity self.device = torch.device(device) if isinstance(device, str) else device + self.env_type = env_type if env_type is not None else 'default' + self.n_envs = n_envs self._size = 0 self._position = 0 @abstractmethod def add(self, transition: Dict[str, Any]) -> None: """ - Add a transition to the buffer. + Add a transition (one time step from all parallel envs) to the buffer. Stores raw Meta data (MetaObs, MetaAction) without any normalization. Supports storing all fields of MetaObs and MetaAction, plus additional custom information. + For vectorized environments (n_envs > 1): + - Each field should have shape (n_envs, ...) + - e.g., state: (n_envs, state_dim), action: (n_envs, action_dim) + - reward: (n_envs,), done: (n_envs,), truncated: (n_envs,) + + For single environment (n_envs = 1): + - Fields can be either (1, ...) or (...) shape + - Will be automatically expanded to (1, ...) if needed + + Note on done vs truncated (Gymnasium API): + - done (terminated): True if episode ended naturally (goal reached, failure, etc.) + - truncated: True if episode was cut off due to time limit or other external reasons + - For bootstrap value calculation: + - If done=True: V(next_state) = 0 (true terminal state) + - If truncated=True: V(next_state) should be bootstrapped (not a true terminal) + Args: transition: Dictionary containing the following fields: - - state: MetaObs format current state (raw data, including all fields) - - action: MetaAction format action (raw data, including all fields) - - reward: float reward - - next_state: MetaObs format next state (raw data) - - done: bool whether episode ended - - info: Optional, additional information dictionary + - state: Current state, shape (n_envs, state_dim) or dict with MetaObs fields + - action: Action, shape (n_envs, action_dim) or dict with MetaAction fields + - reward: Reward, shape (n_envs,) + - next_state: Next state, shape (n_envs, state_dim) or dict with MetaObs fields + - done: Terminated flag, shape (n_envs,) - True if episode ended naturally + - truncated: Truncated flag, shape (n_envs,) - True if episode was cut off + - info: Optional, additional information (list of dicts, one per env) - **other custom fields**: Can store any additional information """ raise NotImplementedError @abstractmethod - def sample(self, batch_size: int) -> Dict[str, Any]: + def sample( + self, + batch_size: int, + env_indices: Optional[Union[List[int], np.ndarray]] = None, + keys: Optional[Set[str]] = None + ) -> Dict[str, Any]: """ Sample a batch from the buffer (raw data). + Sampling is done in two steps: + 1. Select which parallel environments to sample from (env_indices) + 2. Sample batch_size transitions from (time_idx, env_idx) combinations + Args: - batch_size: Batch size + batch_size: Number of transitions to sample + env_indices: Optional, indices of parallel environments to sample from + - If None, sample from all environments uniformly + - e.g., [0, 2] means only sample from env 0 and env 2 + keys: Optional set of keys to sample. If None, uses default keys + (typically state, action, next_state, reward, done, truncated). + Subclasses may define their own default keys and available keys. Returns: Dictionary containing raw Meta data (without normalization) + - All fields have shape (batch_size, ...) with env dimension flattened + - e.g., state: (batch_size, state_dim), action: (batch_size, action_dim) """ raise NotImplementedError def sample_for_training( self, batch_size: int, + env_indices: Optional[Union[List[int], np.ndarray]] = None, + keys: Optional[Set[str]] = None, data_processor: Optional[Callable] = None ) -> Dict[str, Any]: """ @@ -93,6 +150,8 @@ def sample_for_training( Args: batch_size: Batch size + env_indices: Optional, indices of parallel environments to sample from + keys: Optional set of keys to sample data_processor: Optional data processing function to align with ILStudio pipeline - If None, return raw data - If provided, should be a function: batch -> processed_batch @@ -100,15 +159,33 @@ def sample_for_training( Returns: Processed batch data (conforming to ILStudio training format) """ - batch = self.sample(batch_size) + batch = self.sample(batch_size, env_indices=env_indices, keys=keys) if data_processor is not None: batch = data_processor(batch) return batch def __len__(self) -> int: - """Return current buffer size.""" + """Return current buffer size (number of time steps stored).""" return self._size + @property + def total_transitions(self) -> int: + """ + Total number of transitions stored (size * n_envs). + + Returns: + Total transition count across all parallel environments + """ + return self._size * self.n_envs + + def get_env_type(self) -> str: + """Get the environment type identifier.""" + return self.env_type + + def get_n_envs(self) -> int: + """Get the number of parallel environments.""" + return self.n_envs + @abstractmethod def clear(self) -> None: """Clear the buffer.""" @@ -141,20 +218,29 @@ def load(self, path: str, **kwargs) -> None: raise NotImplementedError def is_full(self) -> bool: - """Check if buffer is full.""" + """Check if buffer is full (time steps, not total transitions).""" return self._size >= self.capacity - def get_all(self) -> Dict[str, Any]: + def get_all( + self, + env_indices: Optional[Union[List[int], np.ndarray]] = None + ) -> Dict[str, Any]: """ Get all data in the buffer. + Args: + env_indices: Optional, indices of parallel environments to get data from + - If None, get data from all environments + Returns: Dictionary containing all stored data """ - return self.sample(self._size) if self._size > 0 else {} + return self.sample(self.total_transitions, env_indices=env_indices) if self._size > 0 else {} def __repr__(self) -> str: - return f"{self.__class__.__name__}(capacity={self.capacity}, size={self._size}, device={self.device})" + return (f"{self.__class__.__name__}(capacity={self.capacity}, size={self._size}, " + f"n_envs={self.n_envs}, total_transitions={self.total_transitions}, " + f"env_type='{self.env_type}', device={self.device})") if __name__ == '__main__': @@ -162,6 +248,7 @@ def __repr__(self) -> str: Test code for BaseReplay class. Since BaseReplay is abstract, we create a simple concrete implementation for testing. + Tests include both single-env and vectorized (multi-env) scenarios. """ import sys sys.path.insert(0, '/home/zhang/robot/126/ILStudio') @@ -169,134 +256,296 @@ def __repr__(self) -> str: from benchmark.base import MetaObs, MetaAction from dataclasses import asdict - # Simple concrete implementation for testing + # Simple concrete implementation for testing (supports vectorized envs) class SimpleReplay(BaseReplay): - """Simple in-memory replay buffer for testing.""" + """Simple in-memory replay buffer for testing with vectorized env support.""" - def __init__(self, capacity: int = 1000, device: str = 'cpu', **kwargs): - super().__init__(capacity=capacity, device=device, **kwargs) - self._storage = [] + def __init__( + self, + capacity: int = 1000, + device: str = 'cpu', + env_type: Optional[str] = None, + n_envs: int = 1, + state_dim: int = 10, + action_dim: int = 7, + **kwargs + ): + super().__init__( + capacity=capacity, + device=device, + env_type=env_type, + n_envs=n_envs, + **kwargs + ) + self.state_dim = state_dim + self.action_dim = action_dim + + # Pre-allocate storage: (capacity, n_envs, dim) + self._state = np.zeros((capacity, n_envs, state_dim), dtype=np.float32) + self._action = np.zeros((capacity, n_envs, action_dim), dtype=np.float32) + self._reward = np.zeros((capacity, n_envs), dtype=np.float32) + self._next_state = np.zeros((capacity, n_envs, state_dim), dtype=np.float32) + self._done = np.zeros((capacity, n_envs), dtype=np.bool_) + self._truncated = np.zeros((capacity, n_envs), dtype=np.bool_) def add(self, transition: Dict[str, Any]) -> None: - if self._size < self.capacity: - self._storage.append(transition) - self._size += 1 - else: - self._storage[self._position] = transition + """Add transition from all parallel envs at one time step.""" + idx = self._position + + # Store data (expecting shape (n_envs, dim)) + self._state[idx] = transition['state'] + self._action[idx] = transition['action'] + self._reward[idx] = transition['reward'] + self._next_state[idx] = transition['next_state'] + self._done[idx] = transition['done'] + self._truncated[idx] = transition.get('truncated', np.zeros(self.n_envs, dtype=bool)) + self._position = (self._position + 1) % self.capacity + self._size = min(self._size + 1, self.capacity) + + # Default sample keys + DEFAULT_SAMPLE_KEYS = {'state', 'action', 'next_state', 'reward', 'done', 'truncated'} - def sample(self, batch_size: int) -> Dict[str, Any]: + def sample( + self, + batch_size: int, + env_indices: Optional[Union[List[int], np.ndarray]] = None, + keys: Optional[Set[str]] = None + ) -> Dict[str, Any]: + """Sample batch_size transitions from buffer. + + Args: + batch_size: Number of transitions to sample + env_indices: Optional, indices of parallel environments to sample from + keys: Optional set of keys to sample. If None, uses DEFAULT_SAMPLE_KEYS + """ if self._size == 0: return {} - indices = np.random.randint(0, self._size, size=min(batch_size, self._size)) - batch = { - 'states': [self._storage[i]['state'] for i in indices], - 'actions': [self._storage[i]['action'] for i in indices], - 'rewards': np.array([self._storage[i]['reward'] for i in indices]), - 'next_states': [self._storage[i]['next_state'] for i in indices], - 'dones': np.array([self._storage[i]['done'] for i in indices]), - } + + if keys is None: + keys = self.DEFAULT_SAMPLE_KEYS + + # Determine which envs to sample from + if env_indices is None: + env_indices = np.arange(self.n_envs) + env_indices = np.asarray(env_indices) + + # Sample (time_idx, env_idx) pairs + time_indices = np.random.randint(0, self._size, size=batch_size) + env_sample_indices = np.random.choice(env_indices, size=batch_size) + + batch = {} + if 'state' in keys: + batch['state'] = self._state[time_indices, env_sample_indices] # (batch_size, state_dim) + if 'action' in keys: + batch['action'] = self._action[time_indices, env_sample_indices] # (batch_size, action_dim) + if 'reward' in keys: + batch['reward'] = self._reward[time_indices, env_sample_indices] # (batch_size,) + if 'next_state' in keys: + batch['next_state'] = self._next_state[time_indices, env_sample_indices] + if 'done' in keys: + batch['done'] = self._done[time_indices, env_sample_indices] + if 'truncated' in keys: + batch['truncated'] = self._truncated[time_indices, env_sample_indices] + + # Always include indices for debugging + batch['time_indices'] = time_indices + batch['env_indices'] = env_sample_indices + return batch def clear(self) -> None: - self._storage = [] + self._state.fill(0) + self._action.fill(0) + self._reward.fill(0) + self._next_state.fill(0) + self._done.fill(False) + self._truncated.fill(False) self._size = 0 self._position = 0 def save(self, path: str, **kwargs) -> None: - import pickle + data = { + 'state': self._state[:self._size], + 'action': self._action[:self._size], + 'reward': self._reward[:self._size], + 'next_state': self._next_state[:self._size], + 'done': self._done[:self._size], + 'truncated': self._truncated[:self._size], + 'size': self._size, + 'env_type': self.env_type, + 'n_envs': self.n_envs, + } with open(path, 'wb') as f: - pickle.dump(self._storage[:self._size], f) - print(f"Saved {self._size} transitions to {path}") + pickle.dump(data, f) + print(f"Saved {self._size} time steps ({self.total_transitions} transitions) to {path}") def load(self, path: str, **kwargs) -> None: - import pickle append = kwargs.get('append', False) if not append: self.clear() with open(path, 'rb') as f: data = pickle.load(f) - for transition in data: - self.add(transition) - print(f"Loaded {len(data)} transitions from {path}") + + size = data['size'] + for i in range(size): + self.add({ + 'state': data['state'][i], + 'action': data['action'][i], + 'reward': data['reward'][i], + 'next_state': data['next_state'][i], + 'done': data['done'][i], + 'truncated': data.get('truncated', np.zeros(self.n_envs, dtype=bool))[i] if 'truncated' in data else np.zeros(self.n_envs, dtype=bool), + }) + print(f"Loaded {self._size} time steps ({self.total_transitions} transitions) from {path}") # Test the implementation print("=" * 60) - print("Testing BaseReplay (SimpleReplay implementation)") + print("Testing BaseReplay (SimpleReplay with Vectorized Env Support)") print("=" * 60) - # Create buffer - buffer = SimpleReplay(capacity=100, device='cpu') + # Test 1: Single environment (n_envs=1) + print("\n" + "-" * 40) + print("Test 1: Single Environment (n_envs=1)") + print("-" * 40) + + buffer = SimpleReplay(capacity=100, device='cpu', env_type='single', n_envs=1) print(f"\n1. Created buffer: {buffer}") + print(f" env_type: {buffer.get_env_type()}") + print(f" n_envs: {buffer.get_n_envs()}") - # Create sample transitions + # Add transitions (n_envs=1, so shapes are (1, dim)) print("\n2. Adding transitions...") for i in range(10): - state = MetaObs( - state=np.random.randn(10).astype(np.float32), - state_ee=np.random.randn(7).astype(np.float32), - raw_lang="pick up the red block" - ) - action = MetaAction( - action=np.random.randn(7).astype(np.float32), - ctrl_space='ee', - ctrl_type='delta' - ) - next_state = MetaObs( - state=np.random.randn(10).astype(np.float32), - state_ee=np.random.randn(7).astype(np.float32), - raw_lang="pick up the red block" - ) - transition = { - 'state': asdict(state), - 'action': asdict(action), - 'reward': np.random.randn(), - 'next_state': asdict(next_state), - 'done': i == 9, - 'info': {'step': i}, - 'value': np.random.randn(), # Custom field - 'log_prob': np.random.randn(), # Custom field + 'state': np.random.randn(1, 10).astype(np.float32), + 'action': np.random.randn(1, 7).astype(np.float32), + 'reward': np.array([np.random.randn()]), + 'next_state': np.random.randn(1, 10).astype(np.float32), + 'done': np.array([i == 9]), # True terminal state + 'truncated': np.array([False]), # Not truncated } buffer.add(transition) - print(f" Buffer size after adding: {len(buffer)}") - print(f" Buffer is full: {buffer.is_full()}") + print(f" Buffer size (time steps): {len(buffer)}") + print(f" Total transitions: {buffer.total_transitions}") - # Sample from buffer + # Sample print("\n3. Sampling from buffer...") batch = buffer.sample(batch_size=5) print(f" Batch keys: {batch.keys()}") - print(f" Batch rewards shape: {batch['rewards'].shape}") - print(f" Number of states in batch: {len(batch['states'])}") + print(f" State shape: {batch['state'].shape}") + print(f" Reward shape: {batch['reward'].shape}") + print(f" Done shape: {batch['done'].shape}") + print(f" Truncated shape: {batch['truncated'].shape}") + + # Test 2: Vectorized environment (n_envs=4) + print("\n" + "-" * 40) + print("Test 2: Vectorized Environment (n_envs=4)") + print("-" * 40) + + vec_buffer = SimpleReplay( + capacity=100, + device='cpu', + env_type='sim', + n_envs=4, + state_dim=10, + action_dim=7 + ) + print(f"\n1. Created buffer: {vec_buffer}") - # Test sample_for_training with processor - print("\n4. Testing sample_for_training with data processor...") + # Add transitions from all 4 envs at once + print("\n2. Adding transitions from 4 parallel envs...") + for i in range(20): + transition = { + 'state': np.random.randn(4, 10).astype(np.float32), # (n_envs, state_dim) + 'action': np.random.randn(4, 7).astype(np.float32), # (n_envs, action_dim) + 'reward': np.random.randn(4).astype(np.float32), # (n_envs,) + 'next_state': np.random.randn(4, 10).astype(np.float32), + 'done': np.random.choice([True, False], size=4), # terminated + 'truncated': np.random.choice([True, False], size=4, p=[0.1, 0.9]), # truncated (10% chance) + } + vec_buffer.add(transition) + + print(f" Buffer size (time steps): {len(vec_buffer)}") + print(f" Total transitions: {vec_buffer.total_transitions}") + + # Sample from all envs + print("\n3. Sampling from all envs...") + batch = vec_buffer.sample(batch_size=16) + print(f" State shape: {batch['state'].shape}") # (16, 10) + print(f" Sampled from envs: {np.unique(batch['env_indices'])}") + + # Sample from specific envs only + print("\n4. Sampling from specific envs [0, 2] only...") + batch = vec_buffer.sample(batch_size=16, env_indices=[0, 2]) + print(f" State shape: {batch['state'].shape}") + print(f" Sampled from envs: {np.unique(batch['env_indices'])}") + assert all(e in [0, 2] for e in batch['env_indices']), "Should only sample from envs 0 and 2" + + # Test sample_for_training with env_indices + print("\n5. Testing sample_for_training with env_indices and processor...") def simple_processor(batch): - """Simple processor that adds a 'processed' flag.""" batch['processed'] = True return batch - processed_batch = buffer.sample_for_training(batch_size=5, data_processor=simple_processor) - print(f" Processed batch has 'processed' key: {'processed' in processed_batch}") + processed_batch = vec_buffer.sample_for_training( + batch_size=8, + env_indices=[1, 3], + data_processor=simple_processor + ) + print(f" Processed: {processed_batch.get('processed', False)}") + print(f" Sampled from envs: {np.unique(processed_batch['env_indices'])}") # Test save and load - print("\n5. Testing save and load...") + print("\n6. Testing save and load...") import tempfile import os with tempfile.TemporaryDirectory() as tmpdir: - save_path = os.path.join(tmpdir, 'buffer.pkl') - buffer.save(save_path) + save_path = os.path.join(tmpdir, 'vec_buffer.pkl') + vec_buffer.save(save_path) - # Create new buffer and load - new_buffer = SimpleReplay(capacity=100) + new_buffer = SimpleReplay(capacity=100, n_envs=4, state_dim=10, action_dim=7) new_buffer.load(save_path) print(f" Loaded buffer size: {len(new_buffer)}") + print(f" Loaded total transitions: {new_buffer.total_transitions}") + + # Test 3: Multiple buffers for different env types + print("\n" + "-" * 40) + print("Test 3: Multiple Buffers for Different Env Types") + print("-" * 40) + + buffers = { + 'indoor': SimpleReplay(capacity=50, env_type='indoor', n_envs=2), + 'outdoor': SimpleReplay(capacity=50, env_type='outdoor', n_envs=4), + 'sim': SimpleReplay(capacity=100, env_type='sim', n_envs=8), + } + + print("\nCreated buffers for different env types:") + for name, buf in buffers.items(): + print(f" {name}: env_type='{buf.env_type}', n_envs={buf.n_envs}, capacity={buf.capacity}") + + # Add some data to each + for name, buf in buffers.items(): + for _ in range(10): + buf.add({ + 'state': np.random.randn(buf.n_envs, 10).astype(np.float32), + 'action': np.random.randn(buf.n_envs, 7).astype(np.float32), + 'reward': np.random.randn(buf.n_envs).astype(np.float32), + 'next_state': np.random.randn(buf.n_envs, 10).astype(np.float32), + 'done': np.zeros(buf.n_envs, dtype=bool), + 'truncated': np.zeros(buf.n_envs, dtype=bool), + }) + + print("\nBuffer statistics:") + for name, buf in buffers.items(): + print(f" {name}: {len(buf)} time steps, {buf.total_transitions} total transitions") # Test clear - print("\n6. Testing clear...") - buffer.clear() - print(f" Buffer size after clear: {len(buffer)}") + print("\n7. Testing clear...") + vec_buffer.clear() + print(f" Buffer size after clear: {len(vec_buffer)}") + print(f" Total transitions after clear: {vec_buffer.total_transitions}") print("\n" + "=" * 60) print("All tests passed!") diff --git a/rl/buffer/meta_replay.py b/rl/buffer/meta_replay.py new file mode 100644 index 00000000..7994b753 --- /dev/null +++ b/rl/buffer/meta_replay.py @@ -0,0 +1,726 @@ +""" +Meta Replay Buffer (rewritten). + +Stores raw MetaObs/MetaAction/MetaObs(next) in env-first layout: + (n_envs, capacity, ...) + +Sampling for training returns normalized and processor-aligned data. +""" + +from __future__ import annotations + +import pickle +from dataclasses import asdict +from typing import Any, Callable, Dict, List, Optional, Set, Union + +import numpy as np +import torch + +from benchmark.base import MetaObs, MetaAction, dict2meta +from rl.utils import RunningMeanStd +from .base_replay import BaseReplay +from .transition import RLTransition + + +CTRL_SPACE_MAP = {'ee': 0, 'joint': 1} +CTRL_SPACE_INV_MAP = {v: k for k, v in CTRL_SPACE_MAP.items()} + +CTRL_TYPE_MAP = {'delta': 0, 'absolute': 1, 'relative': 2} +CTRL_TYPE_INV_MAP = {v: k for k, v in CTRL_TYPE_MAP.items()} + +GRIPPER_CONTINUOUS_MAP = {False: 0, True: 1} +GRIPPER_CONTINUOUS_INV_MAP = {v: k for k, v in GRIPPER_CONTINUOUS_MAP.items()} + + +class _MetaObsStorage: + def __init__( + self, + n_envs: int, + capacity: int, + state_dim: Optional[int], + state_ee_dim: Optional[int], + state_joint_dim: Optional[int], + state_obj_dim: Optional[int], + image_shape: Optional[tuple], + depth_shape: Optional[tuple], + pc_shape: Optional[tuple], + store_images: bool, + store_depth: bool, + store_pc: bool, + store_lang: bool, + ) -> None: + self.n_envs = n_envs + self.capacity = capacity + self.store_images = store_images + self.store_depth = store_depth + self.store_pc = store_pc + self.store_lang = store_lang + + self.state = np.zeros((n_envs, capacity, state_dim), dtype=np.float32) if state_dim else None + self.state_ee = np.zeros((n_envs, capacity, state_ee_dim), dtype=np.float32) if state_ee_dim else None + self.state_joint = np.zeros((n_envs, capacity, state_joint_dim), dtype=np.float32) if state_joint_dim else None + self.state_obj = np.zeros((n_envs, capacity, state_obj_dim), dtype=np.float32) if state_obj_dim else None + self.image = np.zeros((n_envs, capacity, *image_shape), dtype=np.uint8) if (store_images and image_shape) else None + self.depth = np.zeros((n_envs, capacity, *depth_shape), dtype=np.float32) if (store_depth and depth_shape) else None + self.pc = np.zeros((n_envs, capacity, *pc_shape), dtype=np.float32) if (store_pc and pc_shape) else None + self.raw_lang: Optional[List[List[str]]] = [['' for _ in range(capacity)] for _ in range(n_envs)] if store_lang else None + self.timestep = np.zeros((n_envs, capacity), dtype=np.int32) + + def set(self, idx: int, obs: Union[MetaObs, Dict[str, Any], List[Any]]) -> None: + obs_dict = self._coerce_obs(obs) + if obs_dict.get('state') is not None and self.state is not None: + self.state[:, idx] = self._ensure_env_batch(obs_dict['state']) + if obs_dict.get('state_ee') is not None and self.state_ee is not None: + self.state_ee[:, idx] = self._ensure_env_batch(obs_dict['state_ee']) + if obs_dict.get('state_joint') is not None and self.state_joint is not None: + self.state_joint[:, idx] = self._ensure_env_batch(obs_dict['state_joint']) + if obs_dict.get('state_obj') is not None and self.state_obj is not None: + self.state_obj[:, idx] = self._ensure_env_batch(obs_dict['state_obj']) + if obs_dict.get('image') is not None and self.image is not None: + self.image[:, idx] = self._ensure_env_batch(obs_dict['image']) + if obs_dict.get('depth') is not None and self.depth is not None: + self.depth[:, idx] = self._ensure_env_batch(obs_dict['depth']) + if obs_dict.get('pc') is not None and self.pc is not None: + self.pc[:, idx] = self._ensure_env_batch(obs_dict['pc']) + + if self.store_lang and self.raw_lang is not None and obs_dict.get('raw_lang') is not None: + raw_lang = obs_dict['raw_lang'] + if isinstance(raw_lang, str): + raw_lang = [raw_lang] * self.n_envs + for env_i, value in enumerate(raw_lang): + self.raw_lang[env_i][idx] = value + + if obs_dict.get('timestep') is not None: + timestep = np.asarray(obs_dict['timestep']) + if timestep.ndim == 0: + timestep = np.array([timestep]) + self.timestep[:, idx] = timestep + + def get(self, time_indices: np.ndarray, env_indices: np.ndarray, keys: Set[str]) -> Dict[str, Any]: + batch: Dict[str, Any] = {} + if 'state' in keys and self.state is not None: + batch['state'] = self.state[env_indices, time_indices].copy() + if 'state_ee' in keys and self.state_ee is not None: + batch['state_ee'] = self.state_ee[env_indices, time_indices].copy() + if 'state_joint' in keys and self.state_joint is not None: + batch['state_joint'] = self.state_joint[env_indices, time_indices].copy() + if 'state_obj' in keys and self.state_obj is not None: + batch['state_obj'] = self.state_obj[env_indices, time_indices].copy() + if 'image' in keys and self.image is not None: + batch['image'] = self.image[env_indices, time_indices].copy() + if 'depth' in keys and self.depth is not None: + batch['depth'] = self.depth[env_indices, time_indices].copy() + if 'pc' in keys and self.pc is not None: + batch['pc'] = self.pc[env_indices, time_indices].copy() + if 'raw_lang' in keys and self.raw_lang is not None: + batch['raw_lang'] = [self.raw_lang[e][t] for t, e in zip(time_indices, env_indices)] + if 'timestep' in keys: + batch['timestep'] = self.timestep[env_indices, time_indices].copy() + return batch + + def _coerce_obs(self, obs: Union[MetaObs, Dict[str, Any], List[Any]]) -> Dict[str, Any]: + if isinstance(obs, list): + obs_dicts = [self._coerce_obs(o) for o in obs] + return self._stack_dicts(obs_dicts) + if hasattr(obs, '__dataclass_fields__'): + return asdict(obs) + if isinstance(obs, dict): + return obs + if hasattr(obs, '__dict__'): + return vars(obs) + return {} + + def _stack_dicts(self, dict_list: List[Dict[str, Any]]) -> Dict[str, Any]: + if not dict_list: + return {} + result = {} + for key in dict_list[0].keys(): + values = [d.get(key) for d in dict_list] + if values[0] is None: + result[key] = None + elif isinstance(values[0], np.ndarray): + result[key] = np.stack(values) + elif isinstance(values[0], (int, float, bool)): + result[key] = np.array(values) + elif isinstance(values[0], str): + result[key] = values + else: + result[key] = values + return result + + def _ensure_env_batch(self, data: Any) -> np.ndarray: + arr = np.asarray(data) + if self.n_envs == 1 and arr.ndim >= 1 and arr.shape[0] != 1: + arr = arr[np.newaxis, ...] + return arr + + +class _MetaActionStorage: + def __init__(self, n_envs: int, capacity: int, action_dim: Optional[int]) -> None: + self.n_envs = n_envs + self.capacity = capacity + self.action = np.zeros((n_envs, capacity, action_dim), dtype=np.float32) if action_dim else None + self.ctrl_space = np.zeros((n_envs, capacity), dtype=np.int8) + self.ctrl_type = np.zeros((n_envs, capacity), dtype=np.int8) + self.gripper_continuous = np.zeros((n_envs, capacity), dtype=np.int8) + + def set(self, idx: int, action: Union[MetaAction, Dict[str, Any], List[Any]]) -> None: + action_dict = self._coerce_action(action) + if action_dict.get('action') is not None and self.action is not None: + self.action[:, idx] = self._ensure_env_batch(action_dict['action']) + + ctrl_space = action_dict.get('ctrl_space', 'ee') + if isinstance(ctrl_space, str): + ctrl_space = [ctrl_space] * self.n_envs + self.ctrl_space[:, idx] = np.array([CTRL_SPACE_MAP.get(cs, 0) for cs in ctrl_space], dtype=np.int8) + + ctrl_type = action_dict.get('ctrl_type', 'delta') + if isinstance(ctrl_type, str): + ctrl_type = [ctrl_type] * self.n_envs + self.ctrl_type[:, idx] = np.array([CTRL_TYPE_MAP.get(ct, 0) for ct in ctrl_type], dtype=np.int8) + + gripper_continuous = action_dict.get('gripper_continuous', False) + if isinstance(gripper_continuous, bool): + gripper_continuous = [gripper_continuous] * self.n_envs + self.gripper_continuous[:, idx] = np.array([GRIPPER_CONTINUOUS_MAP.get(gc, 0) for gc in gripper_continuous], dtype=np.int8) + + def get(self, time_indices: np.ndarray, env_indices: np.ndarray, keys: Set[str]) -> Dict[str, Any]: + batch: Dict[str, Any] = {} + if 'action' in keys and self.action is not None: + batch['action'] = self.action[env_indices, time_indices].copy() + if 'ctrl_space' in keys: + batch['ctrl_space'] = self.ctrl_space[env_indices, time_indices].copy() + batch['ctrl_space_str'] = [CTRL_SPACE_INV_MAP[v] for v in batch['ctrl_space']] + if 'ctrl_type' in keys: + batch['ctrl_type'] = self.ctrl_type[env_indices, time_indices].copy() + batch['ctrl_type_str'] = [CTRL_TYPE_INV_MAP[v] for v in batch['ctrl_type']] + if 'gripper_continuous' in keys: + batch['gripper_continuous'] = self.gripper_continuous[env_indices, time_indices].copy() + batch['gripper_continuous_bool'] = [GRIPPER_CONTINUOUS_INV_MAP[v] for v in batch['gripper_continuous']] + return batch + + def _coerce_action(self, action: Union[MetaAction, Dict[str, Any], List[Any]]) -> Dict[str, Any]: + if isinstance(action, list): + action_dicts = [self._coerce_action(a) for a in action] + return self._stack_dicts(action_dicts) + if hasattr(action, '__dataclass_fields__'): + return asdict(action) + if isinstance(action, dict): + return action + if hasattr(action, '__dict__'): + return vars(action) + return {} + + def _stack_dicts(self, dict_list: List[Dict[str, Any]]) -> Dict[str, Any]: + if not dict_list: + return {} + result = {} + for key in dict_list[0].keys(): + values = [d.get(key) for d in dict_list] + if values[0] is None: + result[key] = None + elif isinstance(values[0], np.ndarray): + result[key] = np.stack(values) + elif isinstance(values[0], (int, float, bool)): + result[key] = np.array(values) + elif isinstance(values[0], str): + result[key] = values + else: + result[key] = values + return result + + def _ensure_env_batch(self, data: Any) -> np.ndarray: + arr = np.asarray(data) + if self.n_envs == 1 and arr.ndim >= 1 and arr.shape[0] != 1: + arr = arr[np.newaxis, ...] + return arr + + +class MetaReplay(BaseReplay): + """ + Replay buffer for MetaObs/MetaAction with env-first storage. + """ + + DEFAULT_SAMPLE_KEYS = {'state', 'action', 'next_state', 'reward', 'done', 'truncated'} + ALL_SAMPLE_KEYS = { + 'state', 'state_ee', 'state_joint', 'state_obj', 'image', 'depth', 'pc', 'raw_lang', 'timestep', + 'next_state', 'next_state_ee', 'next_state_joint', 'next_state_obj', 'next_image', 'next_depth', + 'next_pc', 'next_raw_lang', 'next_timestep', + 'action', 'ctrl_space', 'ctrl_type', 'gripper_continuous', + 'reward', 'done', 'truncated', 'trajectory_id', + } + + def __init__( + self, + capacity: int = 100000, + device: Union[str, torch.device] = 'cpu', + env_type: Optional[str] = None, + n_envs: int = 1, + state_dim: Optional[int] = None, + state_ee_dim: Optional[int] = None, + state_joint_dim: Optional[int] = None, + state_obj_dim: Optional[int] = None, + image_shape: Optional[tuple] = None, + depth_shape: Optional[tuple] = None, + pc_shape: Optional[tuple] = None, + action_dim: Optional[int] = None, + store_images: bool = True, + store_depth: bool = False, + store_pc: bool = False, + store_lang: bool = False, + data_processor: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + data_collator: Optional[Callable[[List[Dict[str, Any]]], Dict[str, Any]]] = None, + state_normalizer: Optional[RunningMeanStd] = None, + action_normalizer: Optional[RunningMeanStd] = None, + update_normalizers: bool = True, + **kwargs, + ) -> None: + super().__init__(capacity=capacity, device=device, env_type=env_type, n_envs=n_envs, **kwargs) + + self.state_dim = state_dim + self.state_ee_dim = state_ee_dim + self.state_joint_dim = state_joint_dim + self.state_obj_dim = state_obj_dim + self.image_shape = image_shape + self.depth_shape = depth_shape + self.pc_shape = pc_shape + self.action_dim = action_dim + self.store_images = store_images + self.store_depth = store_depth + self.store_pc = store_pc + self.store_lang = store_lang + + self.data_processor = data_processor + self.data_collator = data_collator + self.update_normalizers = update_normalizers + self.state_normalizer = state_normalizer or (RunningMeanStd((state_dim,)) if state_dim else None) + self.action_normalizer = action_normalizer or (RunningMeanStd((action_dim,)) if action_dim else None) + + self._obs = _MetaObsStorage( + n_envs=n_envs, + capacity=capacity, + state_dim=state_dim, + state_ee_dim=state_ee_dim, + state_joint_dim=state_joint_dim, + state_obj_dim=state_obj_dim, + image_shape=image_shape, + depth_shape=depth_shape, + pc_shape=pc_shape, + store_images=store_images, + store_depth=store_depth, + store_pc=store_pc, + store_lang=store_lang, + ) + self._next_obs = _MetaObsStorage( + n_envs=n_envs, + capacity=capacity, + state_dim=state_dim, + state_ee_dim=state_ee_dim, + state_joint_dim=state_joint_dim, + state_obj_dim=state_obj_dim, + image_shape=image_shape, + depth_shape=depth_shape, + pc_shape=pc_shape, + store_images=store_images, + store_depth=store_depth, + store_pc=store_pc, + store_lang=store_lang, + ) + self._action = _MetaActionStorage(n_envs=n_envs, capacity=capacity, action_dim=action_dim) + self._reward = np.zeros((n_envs, capacity), dtype=np.float32) + self._done = np.zeros((n_envs, capacity), dtype=np.bool_) + self._truncated = np.zeros((n_envs, capacity), dtype=np.bool_) + self._trajectory_id = np.zeros((n_envs, capacity), dtype=np.int32) + + def add(self, transition: Union[RLTransition, Dict[str, Any]]) -> None: + if not isinstance(transition, RLTransition): + transition = self._coerce_transition(transition) + + idx = self._position + + + self._obs.set(idx, transition.obs) + self._action.set(idx, transition.action) + self._next_obs.set(idx, transition.next_obs) + + reward = np.asarray(transition.reward) + done = np.asarray(transition.done) + truncated = np.asarray(transition.truncated) if transition.truncated is not None else np.zeros(self.n_envs, dtype=bool) + if self.n_envs == 1: + reward = np.atleast_1d(reward) + done = np.atleast_1d(done) + truncated = np.atleast_1d(truncated) + + self._reward[:, idx] = reward + self._done[:, idx] = done + self._truncated[:, idx] = truncated + + if isinstance(transition.info, dict) and 'trajectory_id' in transition.info: + traj_id = transition.info['trajectory_id'] + if self.n_envs == 1: + traj_id = np.atleast_1d(traj_id) + self._trajectory_id[:, idx] = traj_id + + self._position = (self._position + 1) % self.capacity + self._size = min(self._size + 1, self.capacity) + + def sample( + self, + batch_size: int, + env_indices: Optional[Union[List[int], np.ndarray]] = None, + keys: Optional[Set[str]] = None, + ) -> Dict[str, Any]: + if self._size == 0: + return {} + + if keys is None: + keys = self.DEFAULT_SAMPLE_KEYS + + if env_indices is None: + env_indices = np.arange(self.n_envs) + env_indices = np.asarray(env_indices) + + time_indices = np.random.randint(0, self._size, size=batch_size) + env_sample_indices = np.random.choice(env_indices, size=batch_size) + + batch = {} + obs_keys = {k for k in keys if not k.startswith('next_')} + next_obs_keys = {k[5:] for k in keys if k.startswith('next_')} + + batch.update(self._obs.get(time_indices, env_sample_indices, obs_keys)) + next_obs_batch = self._next_obs.get(time_indices, env_sample_indices, next_obs_keys) + for key, value in next_obs_batch.items(): + batch[f"next_{key}"] = value + batch.update(self._action.get(time_indices, env_sample_indices, keys)) + + if 'reward' in keys: + batch['reward'] = self._reward[env_sample_indices, time_indices].copy() + if 'done' in keys: + batch['done'] = self._done[env_sample_indices, time_indices].copy() + if 'truncated' in keys: + batch['truncated'] = self._truncated[env_sample_indices, time_indices].copy() + if 'trajectory_id' in keys: + batch['trajectory_id'] = self._trajectory_id[env_sample_indices, time_indices].copy() + + batch['time_indices'] = time_indices.copy() + batch['env_indices'] = env_sample_indices.copy() + return batch + + def sample_as_tensor( + self, + batch_size: int, + env_indices: Optional[Union[List[int], np.ndarray]] = None, + keys: Optional[Set[str]] = None, + ) -> Dict[str, Any]: + batch = self.sample(batch_size, env_indices=env_indices, keys=keys) + if not batch: + return {} + tensor_batch: Dict[str, Any] = {} + for key, value in batch.items(): + if isinstance(value, np.ndarray): + tensor_batch[key] = torch.from_numpy(value).to(self.device) + else: + tensor_batch[key] = value + return tensor_batch + + def sample_for_training( + self, + batch_size: int, + env_indices: Optional[Union[List[int], np.ndarray]] = None, + keys: Optional[Set[str]] = None, + data_processor: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + data_collator: Optional[Callable[[List[Dict[str, Any]]], Dict[str, Any]]] = None, + apply_normalization: bool = True, + ) -> Dict[str, Any]: + batch = self.sample(batch_size, env_indices=env_indices, keys=keys) + if not batch: + return {} + + if apply_normalization: + batch = self._apply_normalization(batch) + + processor = data_processor or self.data_processor + collator = data_collator or self.data_collator + if processor is not None or collator is not None: + obs_samples = self._build_samples(batch, prefix='') + next_obs_samples = self._build_samples(batch, prefix='next_') + + if processor is not None: + obs_samples = [processor(sample) for sample in obs_samples] + next_obs_samples = [processor(sample) for sample in next_obs_samples] + + if collator is not None: + batch['processed_obs'] = collator(obs_samples) + batch['processed_next_obs'] = collator(next_obs_samples) + else: + batch['processed_obs'] = self._default_collate(obs_samples) + batch['processed_next_obs'] = self._default_collate(next_obs_samples) + + return batch + + def clear(self) -> None: + self._size = 0 + self._position = 0 + self.__init__( + capacity=self.capacity, + device=self.device, + env_type=self.env_type, + n_envs=self.n_envs, + state_dim=self.state_dim, + state_ee_dim=self.state_ee_dim, + state_joint_dim=self.state_joint_dim, + state_obj_dim=self.state_obj_dim, + image_shape=self.image_shape, + depth_shape=self.depth_shape, + pc_shape=self.pc_shape, + action_dim=self.action_dim, + store_images=self.store_images, + store_depth=self.store_depth, + store_pc=self.store_pc, + store_lang=self.store_lang, + data_processor=self.data_processor, + data_collator=self.data_collator, + state_normalizer=self.state_normalizer, + action_normalizer=self.action_normalizer, + update_normalizers=self.update_normalizers, + ) + + def save(self, path: str, **kwargs) -> None: + fmt = kwargs.get('format', 'pkl') + if path.endswith('.npz'): + fmt = 'npz' + + data = { + 'size': self._size, + 'position': self._position, + 'capacity': self.capacity, + 'n_envs': self.n_envs, + 'env_type': self.env_type, + 'state_dim': self.state_dim, + 'state_ee_dim': self.state_ee_dim, + 'state_joint_dim': self.state_joint_dim, + 'state_obj_dim': self.state_obj_dim, + 'image_shape': self.image_shape, + 'depth_shape': self.depth_shape, + 'pc_shape': self.pc_shape, + 'action_dim': self.action_dim, + 'store_images': self.store_images, + 'store_depth': self.store_depth, + 'store_pc': self.store_pc, + 'store_lang': self.store_lang, + 'storage_layout': 'env_first', + } + + size = self._size + data.update({ + '_state': None if self._obs.state is None else self._obs.state[:, :size], + '_state_ee': None if self._obs.state_ee is None else self._obs.state_ee[:, :size], + '_state_joint': None if self._obs.state_joint is None else self._obs.state_joint[:, :size], + '_state_obj': None if self._obs.state_obj is None else self._obs.state_obj[:, :size], + '_image': None if self._obs.image is None else self._obs.image[:, :size], + '_depth': None if self._obs.depth is None else self._obs.depth[:, :size], + '_pc': None if self._obs.pc is None else self._obs.pc[:, :size], + '_timestep': self._obs.timestep[:, :size], + '_next_state': None if self._next_obs.state is None else self._next_obs.state[:, :size], + '_next_state_ee': None if self._next_obs.state_ee is None else self._next_obs.state_ee[:, :size], + '_next_state_joint': None if self._next_obs.state_joint is None else self._next_obs.state_joint[:, :size], + '_next_state_obj': None if self._next_obs.state_obj is None else self._next_obs.state_obj[:, :size], + '_next_image': None if self._next_obs.image is None else self._next_obs.image[:, :size], + '_next_depth': None if self._next_obs.depth is None else self._next_obs.depth[:, :size], + '_next_pc': None if self._next_obs.pc is None else self._next_obs.pc[:, :size], + '_next_timestep': self._next_obs.timestep[:, :size], + '_action': None if self._action.action is None else self._action.action[:, :size], + '_ctrl_space': self._action.ctrl_space[:, :size], + '_ctrl_type': self._action.ctrl_type[:, :size], + '_gripper_continuous': self._action.gripper_continuous[:, :size], + '_reward': self._reward[:, :size], + '_done': self._done[:, :size], + '_truncated': self._truncated[:, :size], + '_trajectory_id': self._trajectory_id[:, :size], + }) + if self.store_lang and self._obs.raw_lang is not None: + data['_raw_lang'] = [row[:size] for row in self._obs.raw_lang] + if self.store_lang and self._next_obs.raw_lang is not None: + data['_next_raw_lang'] = [row[:size] for row in self._next_obs.raw_lang] + + if fmt == 'npz': + np_data = {k: v for k, v in data.items() if isinstance(v, np.ndarray)} + if '_raw_lang' in data: + np_data['_raw_lang'] = np.array(data['_raw_lang'], dtype=object) + if '_next_raw_lang' in data: + np_data['_next_raw_lang'] = np.array(data['_next_raw_lang'], dtype=object) + np_data['_metadata'] = np.array([data['size'], data['position'], data['capacity'], data['n_envs']]) + np_data['_env_type'] = np.array([data['env_type']], dtype=object) + np.savez_compressed(path, **np_data) + else: + with open(path, 'wb') as f: + pickle.dump(data, f) + + def load(self, path: str, **kwargs) -> None: + append = kwargs.get('append', False) + if not append: + self.clear() + + if path.endswith('.npz'): + data = np.load(path, allow_pickle=True) + size = int(data['_metadata'][0]) + for i in range(size): + transition = self._extract_transition_from_npz(data, i) + self.add(transition) + return + + with open(path, 'rb') as f: + data = pickle.load(f) + size = data['size'] + for i in range(size): + transition = self._extract_transition_from_data(data, i) + self.add(transition) + + def _apply_normalization(self, batch: Dict[str, Any]) -> Dict[str, Any]: + if self.state_normalizer is not None: + if 'state' in batch: + batch['state'] = self.state_normalizer.normalize(batch['state']) + if 'next_state' in batch: + batch['next_state'] = self.state_normalizer.normalize(batch['next_state']) + if self.action_normalizer is not None and 'action' in batch: + batch['action'] = self.action_normalizer.normalize(batch['action']) + return batch + + def _build_samples(self, batch: Dict[str, Any], prefix: str = '') -> List[Dict[str, Any]]: + samples: List[Dict[str, Any]] = [] + state_key = f'{prefix}state' + image_key = f'{prefix}image' + raw_lang_key = f'{prefix}raw_lang' + timestep_key = f'{prefix}timestep' + + state_arr = batch.get(state_key, None) + batch_size = state_arr.shape[0] if isinstance(state_arr, np.ndarray) and state_arr.ndim > 1 else len(batch.get(raw_lang_key, [])) + if batch_size == 0: + return samples + + for i in range(batch_size): + sample = {} + if image_key in batch: + sample['image'] = batch[image_key][i] + if state_key in batch: + sample['state'] = batch[state_key][i] if isinstance(batch[state_key], np.ndarray) else batch[state_key] + if raw_lang_key in batch: + sample['raw_lang'] = batch[raw_lang_key][i] + if timestep_key in batch: + sample['timestamp'] = batch[timestep_key][i] + samples.append(sample) + return samples + + def _default_collate(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]: + if not samples: + return {} + batch: Dict[str, Any] = {} + keys = samples[0].keys() + for key in keys: + values = [s[key] for s in samples] + if values[0] is None: + batch[key] = None + elif isinstance(values[0], np.ndarray): + batch[key] = torch.from_numpy(np.stack(values)) + elif isinstance(values[0], torch.Tensor): + batch[key] = torch.stack(values) + elif isinstance(values[0], str): + batch[key] = values + elif isinstance(values[0], (int, float)): + batch[key] = torch.tensor(values) + else: + batch[key] = values + return batch + + def _maybe_update_normalizers(self, obs: MetaObs, action: MetaAction) -> None: + if not self.update_normalizers: + return + if self.state_normalizer is not None and obs is not None and getattr(obs, 'state', None) is not None: + state_arr = np.asarray(getattr(obs, 'state')) + if self.n_envs == 1 and state_arr.ndim == 1: + state_arr = state_arr[np.newaxis, :] + if state_arr.ndim >= 2: + self.state_normalizer.update(state_arr) + if self.action_normalizer is not None and action is not None and getattr(action, 'action', None) is not None: + action_arr = np.asarray(getattr(action, 'action')) + if self.n_envs == 1 and action_arr.ndim == 1: + action_arr = action_arr[np.newaxis, :] + if action_arr.ndim >= 2: + self.action_normalizer.update(action_arr) + + def _coerce_transition(self, transition: Dict[str, Any]) -> RLTransition: + obs = transition.get('state') or transition.get('obs') + action = transition.get('action') + next_obs = transition.get('next_state') or transition.get('next_obs') + reward = transition.get('reward') + done = transition.get('done') + truncated = transition.get('truncated', None) + info = transition.get('info', None) + + if not isinstance(obs, MetaObs): + obs = dict2meta(obs or {}, mtype='obs') + if not isinstance(action, MetaAction): + action = dict2meta(action or {}, mtype='act') + if not isinstance(next_obs, MetaObs): + next_obs = dict2meta(next_obs or {}, mtype='obs') + + return RLTransition(obs=obs, action=action, next_obs=next_obs, reward=reward, done=done, truncated=truncated, info=info) + + def _extract_transition_from_data(self, data: Dict[str, Any], idx: int) -> RLTransition: + obs = MetaObs( + state=data.get('_state')[:, idx] if data.get('_state') is not None else None, + state_ee=data.get('_state_ee')[:, idx] if data.get('_state_ee') is not None else None, + state_joint=data.get('_state_joint')[:, idx] if data.get('_state_joint') is not None else None, + state_obj=data.get('_state_obj')[:, idx] if data.get('_state_obj') is not None else None, + image=data.get('_image')[:, idx] if data.get('_image') is not None else None, + depth=data.get('_depth')[:, idx] if data.get('_depth') is not None else None, + pc=data.get('_pc')[:, idx] if data.get('_pc') is not None else None, + raw_lang=[row[idx] for row in data.get('_raw_lang', [])] if data.get('_raw_lang') is not None else None, + timestep=data.get('_timestep')[:, idx] if data.get('_timestep') is not None else None, + ) + next_obs = MetaObs( + state=data.get('_next_state')[:, idx] if data.get('_next_state') is not None else None, + state_ee=data.get('_next_state_ee')[:, idx] if data.get('_next_state_ee') is not None else None, + state_joint=data.get('_next_state_joint')[:, idx] if data.get('_next_state_joint') is not None else None, + state_obj=data.get('_next_state_obj')[:, idx] if data.get('_next_state_obj') is not None else None, + image=data.get('_next_image')[:, idx] if data.get('_next_image') is not None else None, + depth=data.get('_next_depth')[:, idx] if data.get('_next_depth') is not None else None, + pc=data.get('_next_pc')[:, idx] if data.get('_next_pc') is not None else None, + raw_lang=[row[idx] for row in data.get('_next_raw_lang', [])] if data.get('_next_raw_lang') is not None else None, + timestep=data.get('_next_timestep')[:, idx] if data.get('_next_timestep') is not None else None, + ) + action = MetaAction( + action=data.get('_action')[:, idx] if data.get('_action') is not None else None, + ctrl_space=[CTRL_SPACE_INV_MAP[int(v)] for v in np.atleast_1d(data.get('_ctrl_space')[:, idx])], + ctrl_type=[CTRL_TYPE_INV_MAP[int(v)] for v in np.atleast_1d(data.get('_ctrl_type')[:, idx])], + gripper_continuous=[GRIPPER_CONTINUOUS_INV_MAP[int(v)] for v in np.atleast_1d(data.get('_gripper_continuous')[:, idx])], + ) + return RLTransition( + obs=obs, + action=action, + next_obs=next_obs, + reward=data.get('_reward')[:, idx], + done=data.get('_done')[:, idx], + truncated=data.get('_truncated')[:, idx] if data.get('_truncated') is not None else None, + ) + + def _extract_transition_from_npz(self, npz_data, idx: int) -> RLTransition: + data = {k: npz_data[k] for k in npz_data.files} + return self._extract_transition_from_data(data, idx) + + def get_trajectory(self, trajectory_id: int, env_idx: int = 0) -> Dict[str, Any]: + mask = self._trajectory_id[env_idx, :self._size] == trajectory_id + time_indices = np.where(mask)[0] + if len(time_indices) == 0: + return {} + env_indices = np.full(len(time_indices), env_idx, dtype=np.int64) + return self.sample(len(time_indices), env_indices=env_indices) + + def __repr__(self) -> str: + return (f"MetaReplay(capacity={self.capacity}, size={self._size}, " + f"n_envs={self.n_envs}, total_transitions={self.total_transitions}, " + f"env_type='{self.env_type}', state_dim={self.state_dim}, " + f"action_dim={self.action_dim}, store_images={self.store_images}, " + f"store_lang={self.store_lang})") + + diff --git a/rl/buffer/transition.py b/rl/buffer/transition.py new file mode 100644 index 00000000..ab708045 --- /dev/null +++ b/rl/buffer/transition.py @@ -0,0 +1,38 @@ +""" +RL transition data structures. +""" + +from dataclasses import dataclass +from typing import Any, Optional + +import numpy as np + +from benchmark.base import MetaObs, MetaAction + + +@dataclass +class RLTransition: + """ + One-step transition for RL replay (raw environment meta data). + """ + + obs: MetaObs + action: MetaAction + next_obs: MetaObs + reward: Any + done: Any + truncated: Optional[Any] = None + info: Optional[Any] = None + + def to_batch(self) -> "RLTransition": + """ + Ensure obs/action/next_obs are batched along env dimension. + """ + if hasattr(self.obs, "to_batch"): + self.obs.to_batch() + if hasattr(self.action, "to_batch"): + self.action.to_batch() + if hasattr(self.next_obs, "to_batch"): + self.next_obs.to_batch() + return self + diff --git a/rl/collectors/__init__.py b/rl/collectors/__init__.py index b7446018..b74a1ea8 100644 --- a/rl/collectors/__init__.py +++ b/rl/collectors/__init__.py @@ -19,7 +19,7 @@ from typing import Type, Dict, Any -from .base_collector import BaseCollector +from .base_collector import BaseCollector, DummyCollector # Registry for collector classes _COLLECTOR_REGISTRY: Dict[str, Type] = {} @@ -77,6 +77,7 @@ def list_collectors() -> list: __all__ = [ 'BaseCollector', + 'DummyCollector', 'register_collector', 'get_collector_class', 'list_collectors', diff --git a/rl/collectors/base_collector.py b/rl/collectors/base_collector.py index 177371db..d8e595b1 100644 --- a/rl/collectors/base_collector.py +++ b/rl/collectors/base_collector.py @@ -10,10 +10,11 @@ - Raw data storage: Only store raw environment rewards, no reward function computation, ensuring data integrity - Statistics: Collect and return episode statistics +- Exploration support: Support random exploration phase and noise-based exploration """ import numpy as np -from typing import Dict, Any, Optional, Union, List +from typing import Dict, Any, Optional, Union, List, TYPE_CHECKING from abc import ABC, abstractmethod # Type hints for Meta classes @@ -22,6 +23,9 @@ # Vector environment protocol from rl.envs import VectorEnvProtocol, EnvsType +if TYPE_CHECKING: + from utils.exploration import ExplorationScheduler + class BaseCollector(ABC): """ @@ -34,6 +38,7 @@ class BaseCollector(ABC): Attributes: envs: Vectorized environment(s) to collect data from algorithm: Algorithm instance for action selection and transition recording + exploration: Optional exploration strategy for action exploration Note: Collector only stores raw environment rewards, no reward function computation. Reward functions are applied in the trainer during training time. @@ -43,6 +48,7 @@ def __init__( self, envs: Union[VectorEnvProtocol, Dict[str, VectorEnvProtocol]], algorithm: 'BaseAlgorithm', + exploration: Optional['ExplorationScheduler'] = None, **kwargs ): """ @@ -56,6 +62,16 @@ def __init__( e.g., {'sim': sim_vec_env, 'real': real_vec_env} algorithm: BaseAlgorithm instance (required) - Used for action selection and transition recording + exploration: Optional ExplorationScheduler for action exploration + - If None, no exploration is applied (use policy actions directly) + - Supports initial random exploration + noise-based exploration + - Example: + exploration = ExplorationScheduler( + exploration_strategy=GaussianNoise(sigma=0.1), + random_steps=10000, + action_low=np.array([-1.0] * 7), + action_high=np.array([1.0] * 7), + ) **kwargs: Collector-specific parameters Note: Collector only stores raw environment rewards, no reward function computation. @@ -63,8 +79,12 @@ def __init__( """ self.envs = envs self.algorithm = algorithm + self.exploration = exploration self._kwargs = kwargs + # Total steps counter for exploration scheduling + self._total_steps = 0 + # Normalize environment storage: always use dict internally if isinstance(envs, dict): self._envs_dict: Dict[str, VectorEnvProtocol] = envs @@ -147,6 +167,168 @@ def get_envs(self) -> Union[VectorEnvProtocol, Dict[str, VectorEnvProtocol]]: """Get the underlying environment(s).""" return self.envs + # ==================== Exploration Methods ==================== + + def set_exploration(self, exploration: Optional['ExplorationScheduler']) -> None: + """ + Set or update the exploration strategy. + + Args: + exploration: ExplorationScheduler instance or None to disable exploration + """ + self.exploration = exploration + + def apply_exploration( + self, + action: Any, + obs: Any = None, + **kwargs + ) -> Any: + """ + Apply exploration to action(s). + + Supports: + - MetaAction: Single action or batched (action.action shape: (action_dim,) or (n_envs, action_dim)) + - List[MetaAction]: List of individual actions + - np.ndarray: Raw action array + - dict with 'action' key + + Args: + action: Action(s) from policy + obs: Optional observation (for uncertainty-based exploration) + **kwargs: Additional arguments for exploration strategy + + Returns: + Explored action(s) (same type/structure as input) + """ + if self.exploration is None: + return action + + # Case 1: List of MetaAction objects + if isinstance(action, list) and len(action) > 0 and hasattr(action[0], 'action'): + # Stack actions, apply exploration, then update each + action_arrays = [a.action for a in action] + stacked = np.stack(action_arrays) # (n_envs, action_dim) + explored = self.exploration( + stacked, + step=self._total_steps, + obs=obs, + **kwargs + ) + # Update each MetaAction + for i, a in enumerate(action): + a.action = explored[i] + return action + + # Case 2: Single MetaAction (possibly with batched action field) + elif hasattr(action, 'action') and action.action is not None: + action_array = action.action + explored_array = self.exploration( + action_array, + step=self._total_steps, + obs=obs, + **kwargs + ) + action.action = explored_array + return action + + # Case 3: Dict with 'action' key + elif isinstance(action, dict) and 'action' in action: + action_array = action['action'] + explored_array = self.exploration( + action_array, + step=self._total_steps, + obs=obs, + **kwargs + ) + action['action'] = explored_array + return action + + # Case 4: Raw numpy array + elif isinstance(action, np.ndarray): + return self.exploration( + action, + step=self._total_steps, + obs=obs, + **kwargs + ) + + # Case 5: Unknown type, return as-is + else: + return action + + def reset_exploration(self, env_idx: Optional[int] = None) -> None: + """ + Reset exploration state (e.g., for OU noise). + + Args: + env_idx: Optional specific environment index to reset + """ + if self.exploration is not None: + self.exploration.reset(env_idx) + + @property + def total_steps(self) -> int: + """Get total steps collected so far.""" + return self._total_steps + + @property + def is_exploring_randomly(self) -> bool: + """Check if currently in random exploration phase.""" + if self.exploration is None: + return False + return self.exploration.is_random_phase + + def _unpack_actions(self, actions: Any, n_envs: int) -> List[Any]: + """ + Unpack batched actions into a list for vec_env.step. + + Handles: + - MetaAction with batched action field: (n_envs, action_dim) -> List[MetaAction] + - List[MetaAction]: Return as-is + - np.ndarray: (n_envs, action_dim) -> List[np.ndarray] + + Args: + actions: Batched actions from policy + n_envs: Number of environments + + Returns: + List of individual actions, one per environment + """ + # Case 1: Already a list + if isinstance(actions, list): + return actions + + # Case 2: MetaAction with batched action field + if hasattr(actions, 'action') and actions.action is not None: + action_array = actions.action + if isinstance(action_array, np.ndarray) and action_array.ndim >= 2: + # Batched: (n_envs, action_dim) or (n_envs, chunk_size, action_dim) + unpacked = [] + for i in range(min(n_envs, len(action_array))): + # Create new MetaAction for each env + single_action = MetaAction( + action=action_array[i], + ctrl_space=getattr(actions, 'ctrl_space', 'ee'), + ctrl_type=getattr(actions, 'ctrl_type', 'delta'), + gripper_continuous=getattr(actions, 'gripper_continuous', False), + ) + unpacked.append(single_action) + return unpacked + else: + # Single action, replicate for all envs + return [actions] * n_envs + + # Case 3: np.ndarray (batched) + if isinstance(actions, np.ndarray): + if actions.ndim >= 2: + return [actions[i] for i in range(min(n_envs, len(actions)))] + else: + return [actions] * n_envs + + # Case 4: Unknown, replicate for all envs + return [actions] * n_envs + def get_algorithm(self) -> 'BaseAlgorithm': """Get the algorithm instance.""" return self.algorithm @@ -168,400 +350,120 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}(envs={env_info}, algorithm={self.algorithm.__class__.__name__})" -if __name__ == '__main__': +class DummyCollector(BaseCollector): """ - Test code for BaseCollector class. + Simple collector implementation for testing and basic use cases. - Since BaseCollector is abstract, we create a simple concrete implementation for testing. + This collector: + - Works with vectorized environments (SequentialVectorEnv, SubprocVectorEnv, etc.) + - Handles batched observations and actions + - Creates RLTransition objects for storage in replay buffer + - Tracks episode statistics """ - import sys - sys.path.insert(0, '/home/zhang/robot/126/ILStudio') - import torch - from benchmark.base import MetaObs, MetaAction, MetaEnv, MetaPolicy - from rl.buffer.base_replay import BaseReplay - from rl.base import BaseAlgorithm - from dataclasses import asdict + def __init__(self, envs, algorithm, **kwargs): + super().__init__(envs, algorithm, **kwargs) + self.vec_env = self.get_env() + self._last_obs = None - # Simple replay buffer for testing - class SimpleReplay(BaseReplay): - def __init__(self, capacity=1000, device='cpu', **kwargs): - super().__init__(capacity=capacity, device=device, **kwargs) - self._storage = [] - - def add(self, transition): - if self._size < self.capacity: - self._storage.append(transition) - self._size += 1 - else: - self._storage[self._position] = transition - self._position = (self._position + 1) % self.capacity - - def sample(self, batch_size): - if self._size == 0: - return {} - indices = np.random.randint(0, self._size, size=min(batch_size, self._size)) - return { - 'states': [self._storage[i]['state'] for i in indices], - 'actions': [self._storage[i]['action'] for i in indices], - 'rewards': np.array([self._storage[i]['reward'] for i in indices]), - } - - def clear(self): - self._storage = [] - self._size = 0 - self._position = 0 - - def save(self, path, **kwargs): - pass - - def load(self, path, **kwargs): - pass - - # Simple dummy environment for testing - class DummyEnv: - def __init__(self, state_dim=10, action_dim=7): - self.state_dim = state_dim - self.action_dim = action_dim - self._step_count = 0 - self._max_steps = 5 - - def reset(self): - self._step_count = 0 - return {'state': np.random.randn(self.state_dim).astype(np.float32)} - - def step(self, action): - self._step_count += 1 - obs = {'state': np.random.randn(self.state_dim).astype(np.float32)} - reward = np.random.randn() - done = self._step_count >= self._max_steps - info = {'step': self._step_count} - if done: - info['episode'] = {'r': np.random.randn() * 10, 'l': self._step_count} - return obs, reward, done, info - - def close(self): - pass + def reset(self, **kwargs): + """Reset the collector and environments.""" + self.vec_env = self.get_env() + self._last_obs = self.vec_env.reset() - # Simple MetaEnv wrapper for testing - class DummyMetaEnv(MetaEnv): - def __init__(self, state_dim=10, action_dim=7, max_steps=100): - self.env = DummyEnv(state_dim=state_dim, action_dim=action_dim) - self.env._max_steps = max_steps - self.prev_obs = None - - def obs2meta(self, raw_obs): - return MetaObs( - state=raw_obs['state'], - raw_lang="test instruction" - ) - - def meta2act(self, action, *args): - if hasattr(action, 'action'): - return action.action - return action - - def reset(self): - init_obs = self.env.reset() - self.prev_obs = self.obs2meta(init_obs) - return self.prev_obs + def collect(self, n_steps, env_type=None): + """ + Collect n_steps of interaction data. - def step(self, action): - act = self.meta2act(action) - obs, reward, done, info = self.env.step(act) - self.prev_obs = self.obs2meta(obs) - return self.prev_obs, reward, done, info - - # Simple policy for testing - class DummyPolicy: - def __init__(self, action_dim=7): - self.action_dim = action_dim - - def select_action(self, obs): - return MetaAction( - action=np.random.randn(self.action_dim).astype(np.float32), - ctrl_space='ee', - ctrl_type='delta' - ) + Args: + n_steps: Number of steps to collect + env_type: Optional environment type identifier - def train(self): - pass + Returns: + Dictionary with statistics: + - episode_rewards: List of episode rewards + - episode_lengths: List of episode lengths + - total_steps: Total number of steps collected + - env_type: Environment type identifier + """ + if self._last_obs is None: + self.reset() - def eval(self): - pass - - # Simple MetaPolicy wrapper for testing - class DummyMetaPolicy(MetaPolicy): - def __init__(self, action_dim=7): - self.policy = DummyPolicy(action_dim=action_dim) - self.chunk_size = 1 - self.ctrl_space = 'ee' - self.ctrl_type = 'delta' - self.action_queue = [] - self.action_normalizer = None - self.state_normalizer = None - - def select_action(self, mobs, t=0, **kwargs): - return self.policy.select_action(mobs) - - # Simple algorithm for testing - class DummyAlgorithm(BaseAlgorithm): - def __init__(self, meta_policy, replay=None, **kwargs): - super().__init__(meta_policy=meta_policy, replay=replay, **kwargs) - self._timestep = 0 + from benchmark.base import dict2meta, MetaObs, MetaAction + from benchmark.utils import organize_obs + from rl.buffer.transition import RLTransition - def update(self, batch=None, **kwargs): - return {'loss': 0.0} + stats = {'episode_rewards': [], 'episode_lengths': [], 'total_steps': 0, 'env_type': env_type} + episode_rewards = np.zeros(len(self.vec_env)) + episode_lengths = np.zeros(len(self.vec_env), dtype=int) - def select_action(self, obs, **kwargs): - return self.meta_policy.select_action(obs, t=self._timestep) - - # Simple VectorEnv for testing (similar to SequentialVectorEnv) - class DummyVectorEnv: - """Simple sequential vector environment for testing.""" - def __init__(self, env_fns): - self.envs = [fn() for fn in env_fns] - self.env_num = len(self.envs) - - def reset(self, id=None): - if id is None: - obs_list = [env.reset() for env in self.envs] - return np.array(obs_list, dtype=object) - else: - if np.isscalar(id): - return self.envs[id].reset() - else: - obs_list = [self.envs[i].reset() for i in id] - return np.array(obs_list, dtype=object) - - def step(self, action, id=None): - if id is None: - results = [env.step(act) for env, act in zip(self.envs, action)] - obs = np.array([r[0] for r in results], dtype=object) - rew = np.array([r[1] for r in results]) - done = np.array([r[2] for r in results]) - info = [r[3] for r in results] - return obs, rew, done, info - else: - if np.isscalar(id): - return self.envs[id].step(action) - else: - results = [self.envs[i].step(act) for i, act in zip(id, action)] - obs = np.array([r[0] for r in results], dtype=object) - rew = np.array([r[1] for r in results]) - done = np.array([r[2] for r in results]) - info = [r[3] for r in results] - return obs, rew, done, info - - def close(self): - for env in self.envs: - if hasattr(env, 'close'): - env.close() - - def __len__(self): - return self.env_num - - # Simple collector implementation for testing - class SimpleCollector(BaseCollector): - """Simple collector for testing with vectorized environments.""" - - def __init__( - self, - envs: Union[VectorEnvProtocol, Dict[str, VectorEnvProtocol]], - algorithm: BaseAlgorithm, - **kwargs - ): - super().__init__(envs, algorithm, **kwargs) + for step in range(n_steps): + # Organize observations into batched MetaObs + batched_obs = organize_obs(self._last_obs) if not isinstance(self._last_obs, MetaObs) else self._last_obs - # Get default environment - self.vec_env = self.get_env() - self._last_obs = None - self._last_dones = None - self._current_episode_reward = None - self._current_episode_length = None - - def reset(self, **kwargs) -> None: - """Reset all environments.""" - self._last_obs = self.vec_env.reset() - self._last_dones = np.zeros(len(self.vec_env), dtype=bool) - self._current_episode_reward = np.zeros(len(self.vec_env), dtype=np.float32) - self._current_episode_length = np.zeros(len(self.vec_env), dtype=int) - - def collect(self, n_steps: int, env_type: Optional[str] = None) -> Dict[str, Any]: - """ - Collect n_steps of interaction data. + # Get batched actions for all environments at once + # actions is expected to be an object array of dicts (from MetaPolicy) + actions = self.algorithm.select_action(batched_obs) - Args: - n_steps: Number of steps to collect - env_type: Optional environment type identifier + # Apply exploration if configured + if self.exploration is not None: + if self.is_exploring_randomly: + stats['random_steps'] = stats.get('random_steps', 0) + len(self.vec_env) + actions = self.apply_exploration(actions, obs=batched_obs) - Returns: - Statistics dictionary - """ - if self._last_obs is None: - self.reset() + # Step all environments (expects array of dicts, one per env) + new_obs, rewards, dones, infos = self.vec_env.step(actions) - stats = { - 'episode_rewards': [], - 'episode_lengths': [], - 'total_steps': 0 - } + # Organize next observations into batched MetaObs + batched_next_obs = organize_obs(new_obs) if not isinstance(new_obs, MetaObs) else new_obs - steps_executed = 0 + # Reconstruct MetaAction for storage (since actions is now an object array) + # We need to extract the raw action arrays from the dicts + if isinstance(actions, np.ndarray) and actions.dtype == object: + # Extract 'action' field from each dict + raw_actions = np.stack([a['action'] for a in actions]) + # Create MetaAction + stored_actions = MetaAction(action=raw_actions) + elif isinstance(actions, MetaAction): + stored_actions = actions + else: + stored_actions = dict2meta(actions, mtype='act') - while steps_executed < n_steps: - # Get actions for all environments - actions = [] - for obs in self._last_obs: - with torch.no_grad(): - action = self.algorithm.select_action(obs) - actions.append(action) - - # Step all environments - new_obs, rewards, dones, infos = self.vec_env.step(actions) - - # Record transitions and update stats - for i in range(len(self.vec_env)): - # Record transition (only store raw reward) - transition_kwargs = {} - if env_type is not None: - transition_kwargs['env_type'] = env_type - - self.algorithm.record_transition( - state=self._last_obs[i], - action=actions[i], - reward=rewards[i], # Store raw reward - next_state=new_obs[i], - done=dones[i], - info=infos[i], - **transition_kwargs - ) - - # Update episode statistics - self._current_episode_reward[i] += rewards[i] - self._current_episode_length[i] += 1 - stats['total_steps'] += 1 - steps_executed += 1 - + # Create single RLTransition for all environments + truncated = np.array([infos[i].get('TimeLimit.truncated', False) for i in range(len(infos))]) if infos else np.zeros_like(dones) + + transition = RLTransition( + obs=batched_obs, + action=stored_actions, + next_obs=batched_next_obs, + reward=rewards, + done=dones, + truncated=truncated, + info=infos if infos else None, + ) + + kwargs_trans = {'env_type': env_type} if env_type else {} + self.algorithm.record_transition(transition, **kwargs_trans) + + # Update episode statistics + episode_rewards += rewards + episode_lengths += 1 + stats['total_steps'] += len(self.vec_env) + self._total_steps += len(self.vec_env) + + # Handle done environments + done_indices = np.where(dones)[0] + if len(done_indices) > 0: + stats['episode_rewards'].extend(episode_rewards[done_indices].tolist()) + stats['episode_lengths'].extend(episode_lengths[done_indices].tolist()) + episode_rewards[done_indices] = 0.0 + episode_lengths[done_indices] = 0 # Reset done environments - done_indices = np.where(dones)[0] - if len(done_indices) > 0: - for idx in done_indices: - stats['episode_rewards'].append(self._current_episode_reward[idx]) - stats['episode_lengths'].append(self._current_episode_length[idx]) - self._current_episode_reward[idx] = 0.0 - self._current_episode_length[idx] = 0 - - # Reset done environments - reset_obs = self.vec_env.reset(id=done_indices.tolist()) - if len(done_indices) == 1: - new_obs[done_indices[0]] = reset_obs - else: - for idx, reset_idx in enumerate(done_indices): - new_obs[reset_idx] = reset_obs[idx] - - # Mark reset environments as not done - dones[done_indices] = False - - self._last_obs = new_obs - self._last_dones = dones + for idx in done_indices: + new_obs[idx] = self.vec_env.reset(id=idx) - return stats - - # Test the implementation - print("=" * 60) - print("Testing BaseCollector with Vectorized Environments") - print("=" * 60) - - # Create environment, policy, algorithm - print("\n1. Creating components...") - # Create vectorized environment - # Use max_steps=5 to match DummyEnv's _max_steps=5 (user changed it) - env_fn = lambda: DummyMetaEnv(state_dim=10, action_dim=7, max_steps=5) - env_fns = [env_fn for _ in range(4)] - vec_env = DummyVectorEnv(env_fns) - - meta_policy = DummyMetaPolicy(action_dim=7) - replay = SimpleReplay(capacity=1000) - algorithm = DummyAlgorithm(meta_policy=meta_policy, replay=replay) - - print(f" Vectorized Environment: {len(vec_env)} parallel envs") - print(f" MetaPolicy: {meta_policy}") - print(f" Algorithm: {algorithm}") - - # Create collector - print("\n2. Creating collector...") - collector = SimpleCollector( - envs=vec_env, - algorithm=algorithm - ) - print(f" Collector: {collector}") - print(f" env_num: {collector.env_num}") - print(f" total_env_num: {collector.get_total_env_num()}") - - # Test reset - print("\n3. Testing reset...") - collector.reset() - print(" Reset successful") - - # Test collect - print("\n4. Testing collect...") - stats = collector.collect(n_steps=50) - print(f" Collected {stats['total_steps']} steps") - print(f" Episodes completed: {len(stats['episode_rewards'])}") - print(f" Episode rewards: {stats['episode_rewards'][:5]}...") # Show first 5 - print(f" Episode lengths: {stats['episode_lengths'][:5]}...") # Show first 5 - print(f" Replay buffer size: {len(algorithm.replay)}") - - # Test with env_type - print("\n5. Testing collect with env_type...") - # Create algorithm with multi-replay - multi_replay = { - 'indoor': SimpleReplay(capacity=1000), - 'outdoor': SimpleReplay(capacity=1000) - } - multi_algorithm = DummyAlgorithm(meta_policy=meta_policy, replay=multi_replay) - - indoor_env_fn = lambda: DummyMetaEnv(state_dim=10, action_dim=7, max_steps=20) - indoor_vec_env = DummyVectorEnv([indoor_env_fn for _ in range(2)]) - indoor_collector = SimpleCollector( - envs=indoor_vec_env, - algorithm=multi_algorithm - ) - - # Collect to indoor replay - indoor_collector.reset() - indoor_stats = indoor_collector.collect(n_steps=30, env_type='indoor') - print(f" Indoor collected {indoor_stats['total_steps']} steps") - print(f" Indoor replay size: {len(multi_algorithm.replay['indoor'])}") - print(f" Outdoor replay size: {len(multi_algorithm.replay['outdoor'])}") - - # Test with multiple environment types (dict) - print("\n6. Testing with multiple environment types (dict)...") - sim_env_fn = lambda: DummyMetaEnv(state_dim=10, action_dim=7, max_steps=10) - real_env_fn = lambda: DummyMetaEnv(state_dim=10, action_dim=7, max_steps=10) - - multi_envs = { - 'sim': DummyVectorEnv([sim_env_fn for _ in range(4)]), - 'real': DummyVectorEnv([real_env_fn for _ in range(2)]), - } - - multi_env_collector = SimpleCollector( - envs=multi_envs, - algorithm=DummyAlgorithm(meta_policy=meta_policy, replay=SimpleReplay(capacity=1000)) - ) - print(f" Multi-env collector: {multi_env_collector}") - print(f" env_types: {multi_env_collector.get_env_types()}") - print(f" total_env_num: {multi_env_collector.get_total_env_num()}") - print(f" sim env_num: {len(multi_env_collector.get_env('sim'))}") - print(f" real env_num: {len(multi_env_collector.get_env('real'))}") - - multi_env_collector.reset() - multi_env_stats = multi_env_collector.collect(n_steps=1000) - print(f" Collected {multi_env_stats['total_steps']} steps") - print(f" Episodes completed: {len(multi_env_stats['episode_rewards'])}") - - print("\n" + "=" * 60) - print("All tests passed!") - print("=" * 60) + # Reorganize observations for next iteration + self._last_obs = organize_obs(new_obs) if not isinstance(new_obs, MetaObs) else new_obs + + return stats diff --git a/rl/trainers/base_trainer.py b/rl/trainers/base_trainer.py index 049ce113..2be75ca0 100644 --- a/rl/trainers/base_trainer.py +++ b/rl/trainers/base_trainer.py @@ -4,16 +4,27 @@ This module defines the base class for all RL trainers in the framework. Design Philosophy: -- Coordinate environment, policy, and algorithm for executing training loop -- Support single algorithm and multiple algorithms training scenarios +- Coordinate algorithm and collectors for executing training loop +- Support a single algorithm collecting data from multiple different environment types +- Each environment type has its own Collector and Replay Buffer +- Collectors manage environments internally (Trainer doesn't need envs directly) - Support custom reward functions (applied during training, not data collection) - Support evaluation during training -- Support vectorized environments (SequentialVectorEnv, SubprocVectorEnv, etc.) + +Architecture: + Trainer + ├── algorithm: BaseAlgorithm (single algorithm) + │ └── replay: Dict[env_type, BaseReplay] (multiple replays, one per env type) + ├── collectors: Dict[env_type, BaseCollector] (multiple collectors, one per env type) + │ ├── collector['env_type_A']: manages VectorEnv A and stores to replay['env_type_A'] + │ ├── collector['env_type_B']: manages VectorEnv B and stores to replay['env_type_B'] + │ └── ... + └── reward_fn: Optional[BaseReward] """ import numpy as np import torch -from typing import Dict, Any, Optional, Union, List, Callable +from typing import Dict, Any, Optional, Union, List, Callable, TYPE_CHECKING from abc import ABC, abstractmethod # Type hints for Meta classes @@ -22,6 +33,12 @@ # Vector environment protocol from rl.envs import VectorEnvProtocol, EnvsType +if TYPE_CHECKING: + from rl.rewards.base_reward import BaseReward +else: + # Import at runtime to avoid potential circular imports + from rl.rewards.base_reward import BaseReward + class BaseTrainer(ABC): """ @@ -29,23 +46,26 @@ class BaseTrainer(ABC): This class defines the interface for all trainers in the RL framework. Trainers coordinate the training loop by managing: - - Vectorized environments (envs) - - Algorithms (with their policies and replay buffers) - - Collectors (for data collection) + - Algorithm (with policy and multiple replay buffers for different env types) + - Collectors (for data collection from different env types) - Reward functions (applied during training) + Key Design: + - Trainer does NOT directly manage environments + - Each Collector manages its own VectorEnv internally + - Algorithm has Dict[env_type, Replay] for storing data from different env types + - Number of collectors must match the number of env types in algorithm's replay dict + Attributes: - envs: Vectorized environment(s) for training - algorithm: Algorithm(s) for training - collector: Data collector(s) + algorithm: Algorithm for training + collectors: Dict of collectors, keyed by env_type reward_fn: Optional custom reward function """ def __init__( self, - envs: Union[VectorEnvProtocol, Dict[str, VectorEnvProtocol]], - algorithm: Union['BaseAlgorithm', List['BaseAlgorithm']], - collector: Optional[Union['BaseCollector', List['BaseCollector']]] = None, + algorithm: 'BaseAlgorithm', + collectors: Union['BaseCollector', Dict[str, 'BaseCollector']], reward_fn: Optional[Union['BaseReward', Callable]] = None, **kwargs ): @@ -53,94 +73,97 @@ def __init__( Initialize the trainer. Args: - envs: Vectorized environment(s), supports: - - VectorEnvProtocol: Single vectorized environment - (SequentialVectorEnv, SubprocVectorEnv, DummyVectorEnv, etc.) - - Dict[str, VectorEnvProtocol]: Multi-environment dict for different env types - e.g., {'sim': sim_vec_env, 'real': real_vec_env} - algorithm: BaseAlgorithm instance or list of BaseAlgorithm (required) - - Single algorithm: Single agent training - - Algorithm list: Multiple algorithms training independently in same env - (each algorithm has its own replay buffer) - collector: Optional data collector (if None, trainer creates default collector) - - Single algorithm: Single collector - - Multiple algorithms: Can be collector list, each for one algorithm - - If None, trainer creates default collector for each algorithm + algorithm: BaseAlgorithm instance (required) + - Should have replay as Dict[env_type, BaseReplay] for multi-env scenarios + - Each replay buffer stores data from its corresponding env type + collectors: Data collectors, supports: + - BaseCollector: Single collector (for single env type, uses 'default') + - Dict[str, BaseCollector]: Multi-collector dict for different env types + e.g., {'sim': sim_collector, 'real': real_collector} + - Each collector manages its own VectorEnv + - Keys should match algorithm's replay dict keys reward_fn: Optional reward function (if None, use raw environment reward) - Can be BaseReward instance or Callable function - Applied during training for algorithm updates - Note: Replay buffer stores raw rewards, reward function only applied during training **kwargs: Trainer-specific parameters """ - self.envs = envs self.algorithm = algorithm self.reward_fn = reward_fn self._kwargs = kwargs - # Normalize environment storage: always use dict internally - if isinstance(envs, dict): - self._envs_dict: Dict[str, VectorEnvProtocol] = envs - self._is_multi_env = True + # Normalize collector storage: always use dict internally + if isinstance(collectors, dict): + self._collectors_dict: Dict[str, 'BaseCollector'] = collectors else: - self._envs_dict = {'default': envs} - self._is_multi_env = False + self._collectors_dict = {'default': collectors} - # Handle collector initialization - self.collector = collector - # Note: Actual collector creation should be done in subclasses - # since it may require specific collector types + # Store reference for easier access + self.collectors = self._collectors_dict - def get_env(self, env_type: Optional[str] = None) -> VectorEnvProtocol: + def get_collector(self, env_type: Optional[str] = None) -> 'BaseCollector': """ - Get the vectorized environment by type. + Get the collector by env_type. Args: - env_type: Environment type identifier. If None, returns 'default' env - or the first env if 'default' doesn't exist. + env_type: Environment type identifier. If None, returns 'default' collector + or the first collector if 'default' doesn't exist. Returns: - The vectorized environment + The collector for the specified env type Raises: KeyError: If specified env_type is not found """ if env_type is not None: - return self._envs_dict[env_type] + return self._collectors_dict[env_type] + + # Return 'default' if exists, otherwise return first collector + if 'default' in self._collectors_dict: + return self._collectors_dict['default'] + return list(self._collectors_dict.values())[0] + + def get_env(self, env_type: Optional[str] = None) -> VectorEnvProtocol: + """ + Get the vectorized environment by type (from collector). + + Args: + env_type: Environment type identifier. - # Return 'default' if exists, otherwise return first env - if 'default' in self._envs_dict: - return self._envs_dict['default'] - return list(self._envs_dict.values())[0] + Returns: + The vectorized environment managed by the collector + """ + collector = self.get_collector(env_type) + return collector.get_env() def get_total_env_num(self) -> int: """ Get total number of parallel environments across all types. Returns: - Total count of parallel environments + Total count of parallel environments across all collectors """ - return sum(len(env) for env in self._envs_dict.values()) + return sum(col.get_total_env_num() for col in self._collectors_dict.values()) def get_env_types(self) -> List[str]: """ Get all environment type identifiers. Returns: - List of environment type strings + List of environment type strings (collector keys) """ - return list(self._envs_dict.keys()) + return list(self._collectors_dict.keys()) @property def env_num(self) -> int: """ - Number of environments in the default (or first) environment. + Number of environments in the default (or first) collector. Returns: Number of parallel environments """ - if 'default' in self._envs_dict: - return len(self._envs_dict['default']) - return len(list(self._envs_dict.values())[0]) + collector = self.get_collector() + return collector.env_num @abstractmethod def train(self, **kwargs) -> None: @@ -183,9 +206,6 @@ def compute_reward( Computed reward value """ if self.reward_fn is not None: - # Import here to avoid circular imports - from rl.rewards.base_reward import BaseReward - if isinstance(self.reward_fn, BaseReward): return self.reward_fn.compute(state, action, next_state, env_reward, info) else: @@ -197,28 +217,42 @@ def collect_rollout( self, n_steps: int, env_type: Optional[str] = None - ) -> Union[Dict[str, Any], List[Dict[str, Any]]]: + ) -> Union[Dict[str, Any], Dict[str, Dict[str, Any]]]: """ - Collect rollout data (using collector). + Collect rollout data (using collectors). Args: n_steps: Number of steps to collect - env_type: Optional environment type identifier (for multi-environment scenarios) - - Used to support a single algorithm storing data from multiple different environments + env_type: Optional environment type identifier + - If specified, only collect from that env type's collector + - If None, collect from ALL collectors Returns: - - Single algorithm: Rollout statistics dictionary - - Multiple algorithms: List of rollout statistics dictionaries + - If env_type specified: Single stats dict from that collector + - If env_type is None: Dict[env_type, stats] from all collectors """ - if self.collector is None: - raise ValueError("Collector is not initialized. Please provide or create a collector.") - - if isinstance(self.collector, list): - # Multiple algorithms: Each algorithm collects independently - return [col.collect(n_steps, env_type=env_type) for col in self.collector] + if env_type is not None: + # Collect from specific env type + collector = self.get_collector(env_type) + return collector.collect(n_steps, env_type=env_type) else: - # Single algorithm - return self.collector.collect(n_steps, env_type=env_type) + # Collect from all collectors + all_stats = {} + for env_type_key, collector in self._collectors_dict.items(): + all_stats[env_type_key] = collector.collect(n_steps, env_type=env_type_key) + return all_stats + + def collect_rollout_all(self, n_steps: int) -> Dict[str, Dict[str, Any]]: + """ + Collect rollout data from ALL collectors. + + Args: + n_steps: Number of steps to collect from each collector + + Returns: + Dict[env_type, stats] from all collectors + """ + return self.collect_rollout(n_steps, env_type=None) @abstractmethod def evaluate( @@ -252,13 +286,33 @@ def load(self, path: str) -> None: """Load model and training state.""" raise NotImplementedError - def get_algorithm(self) -> Union['BaseAlgorithm', List['BaseAlgorithm']]: - """Get the algorithm(s).""" + def get_algorithm(self) -> 'BaseAlgorithm': + """Get the algorithm.""" return self.algorithm - def get_collector(self) -> Optional[Union['BaseCollector', List['BaseCollector']]]: - """Get the collector(s).""" - return self.collector + def get_collectors(self) -> Dict[str, 'BaseCollector']: + """Get all collectors as a dict.""" + return self._collectors_dict + + def get_replay(self, env_type: Optional[str] = None) -> 'BaseReplay': + """ + Get the replay buffer for a specific env type. + + Args: + env_type: Environment type. If None, returns default or first replay. + + Returns: + The replay buffer + """ + replay = self.algorithm.replay + if isinstance(replay, dict): + if env_type is not None: + return replay[env_type] + elif 'default' in replay: + return replay['default'] + else: + return list(replay.values())[0] + return replay def set_reward_fn(self, reward_fn: Union['BaseReward', Callable]) -> None: """ @@ -269,14 +323,19 @@ def set_reward_fn(self, reward_fn: Union['BaseReward', Callable]) -> None: """ self.reward_fn = reward_fn + def reset_collectors(self) -> None: + """Reset all collectors.""" + for collector in self._collectors_dict.values(): + collector.reset() + def __repr__(self) -> str: - algo_info = self.algorithm.__class__.__name__ if not isinstance(self.algorithm, list) else f"List[{len(self.algorithm)}]" - collector_info = "None" if self.collector is None else ( - self.collector.__class__.__name__ if not isinstance(self.collector, list) else f"List[{len(self.collector)}]" - ) + algo_info = self.algorithm.__class__.__name__ + env_types = list(self._collectors_dict.keys()) + collector_info = f"Dict[{len(self._collectors_dict)}]: {env_types}" reward_info = "None" if self.reward_fn is None else self.reward_fn.__class__.__name__ - env_info = f"{self.get_total_env_num()} envs" if self._is_multi_env else f"{self.env_num} envs" - return f"{self.__class__.__name__}(envs={env_info}, algorithm={algo_info}, collector={collector_info}, reward_fn={reward_info})" + total_envs = self.get_total_env_num() + return (f"{self.__class__.__name__}(algorithm={algo_info}, " + f"collectors={collector_info}, total_envs={total_envs}, reward_fn={reward_info})") if __name__ == '__main__': @@ -284,53 +343,21 @@ def __repr__(self) -> str: Test code for BaseTrainer class. Since BaseTrainer is abstract, we create a simple concrete implementation for testing. - Tests now use vectorized environments (VectorEnvProtocol). + Tests the new architecture where Trainer uses Collectors (which manage envs internally). """ import sys sys.path.insert(0, '/home/zhang/robot/126/ILStudio') from benchmark.base import MetaObs, MetaAction, MetaEnv, MetaPolicy from rl.buffer.base_replay import BaseReplay - from rl.base import BaseAlgorithm + from rl.algorithms.base import BaseAlgorithm from rl.rewards.base_reward import BaseReward, IdentityReward, ScaledReward from rl.collectors.base_collector import BaseCollector from rl.envs import VectorEnvProtocol from dataclasses import asdict - # Simple replay buffer for testing - class SimpleReplay(BaseReplay): - def __init__(self, capacity=1000, device='cpu', **kwargs): - super().__init__(capacity=capacity, device=device, **kwargs) - self._storage = [] - - def add(self, transition): - if self._size < self.capacity: - self._storage.append(transition) - self._size += 1 - else: - self._storage[self._position] = transition - self._position = (self._position + 1) % self.capacity - - def sample(self, batch_size): - if self._size == 0: - return {} - indices = np.random.randint(0, self._size, size=min(batch_size, self._size)) - return { - 'states': [self._storage[i]['state'] for i in indices], - 'actions': [self._storage[i]['action'] for i in indices], - 'rewards': np.array([self._storage[i]['reward'] for i in indices]), - } - - def clear(self): - self._storage = [] - self._size = 0 - self._position = 0 - - def save(self, path, **kwargs): - pass - - def load(self, path, **kwargs): - pass + # Use MetaReplay for efficient env-first storage with vectorized sampling + from rl.buffer.meta_replay import MetaReplay # Simple dummy environment for testing class DummyEnv: @@ -380,58 +407,17 @@ def step(self, action): self.prev_obs = self.obs2meta(obs) return self.prev_obs, reward, done, info - # Simple VectorEnv for testing (similar to SequentialVectorEnv) - class DummyVectorEnv: - """Simple sequential vector environment for testing.""" - def __init__(self, env_fns): - self.envs = [fn() for fn in env_fns] - self.env_num = len(self.envs) - - def reset(self, id=None): - if id is None: - obs_list = [env.reset() for env in self.envs] - return np.array(obs_list, dtype=object) - else: - if np.isscalar(id): - return self.envs[id].reset() - else: - obs_list = [self.envs[i].reset() for i in id] - return np.array(obs_list, dtype=object) - - def step(self, action, id=None): - if id is None: - results = [env.step(act) for env, act in zip(self.envs, action)] - obs = np.array([r[0] for r in results], dtype=object) - rew = np.array([r[1] for r in results]) - done = np.array([r[2] for r in results]) - info = [r[3] for r in results] - return obs, rew, done, info - else: - if np.isscalar(id): - return self.envs[id].step(action) - else: - results = [self.envs[i].step(act) for i, act in zip(id, action)] - obs = np.array([r[0] for r in results], dtype=object) - rew = np.array([r[1] for r in results]) - done = np.array([r[2] for r in results]) - info = [r[3] for r in results] - return obs, rew, done, info - - def close(self): - for env in self.envs: - if hasattr(env, 'close'): - env.close() - - def __len__(self): - return self.env_num + # Use SequentialVectorEnv from benchmark/utils.py for testing + from benchmark.utils import SequentialVectorEnv # Simple policy for testing class DummyPolicy: def __init__(self, action_dim=7): self.action_dim = action_dim - def select_action(self, obs): - return MetaAction(action=np.random.randn(self.action_dim).astype(np.float32)) + def select_action(self, obs, n_envs=1): + # Return batched actions (n_envs, action_dim) + return MetaAction(action=np.random.randn(n_envs, self.action_dim).astype(np.float32)) def train(self): pass @@ -451,9 +437,26 @@ def __init__(self, action_dim=7): self.state_normalizer = None def select_action(self, mobs, t=0, **kwargs): - return self.policy.select_action(mobs) + # Infer batch size from obs + n_envs = 1 + if hasattr(mobs, 'state') and mobs.state is not None: + n_envs = mobs.state.shape[0] if mobs.state.ndim > 1 else 1 + + # Get MetaAction from policy + mact = self.policy.select_action(mobs, n_envs=n_envs) + + # Simulate MetaPolicy.inference output format: + # It returns a numpy array of dicts (one dict per env) + # We must convert MetaAction to this format + action_array = mact.action + if action_array.ndim == 1: + action_array = action_array[np.newaxis, :] + + # Create list of dicts, then convert to object array + action_dicts = [{'action': action_array[i]} for i in range(len(action_array))] + return np.array(action_dicts, dtype=object) - # Simple algorithm for testing + # Simple algorithm for testing (supports Dict[env_type, replay]) class DummyAlgorithm(BaseAlgorithm): def __init__(self, meta_policy, replay=None, **kwargs): super().__init__(meta_policy=meta_policy, replay=replay, **kwargs) @@ -463,93 +466,39 @@ def __init__(self, meta_policy, replay=None, **kwargs): def update(self, batch=None, **kwargs): self._update_count += 1 batch_size = kwargs.get('batch_size', 32) + env_type = kwargs.get('env_type', None) + if batch is None and self.replay is not None: - batch = self.replay.sample(batch_size) + if isinstance(self.replay, dict): + if env_type: + batch = self.replay[env_type].sample(batch_size) + else: + # Sample from all replays + batch = {} + for k, v in self.replay.items(): + batch[k] = v.sample(batch_size) + else: + batch = self.replay.sample(batch_size) return {'loss': np.random.randn(), 'update_count': self._update_count} def select_action(self, obs, **kwargs): return self.meta_policy.select_action(obs, t=self._timestep) - # Simple collector for testing (works with VectorEnv) - class DummyCollector(BaseCollector): - def __init__(self, envs, algorithm, **kwargs): - super().__init__(envs, algorithm, **kwargs) - self.vec_env = envs - self._last_obs = None - - def reset(self, **kwargs): - self._last_obs = self.vec_env.reset() - - def collect(self, n_steps, env_type=None): - if self._last_obs is None: - self.reset() - - stats = {'episode_rewards': [], 'episode_lengths': [], 'total_steps': 0} - episode_rewards = np.zeros(len(self.vec_env)) - episode_lengths = np.zeros(len(self.vec_env), dtype=int) - - for step in range(n_steps): - # Get actions for all environments - actions = [] - for obs in self._last_obs: - action = self.algorithm.select_action(obs) - actions.append(action) - - # Step all environments - new_obs, rewards, dones, infos = self.vec_env.step(actions) - - # Record transitions and update stats - for i in range(len(self.vec_env)): - kwargs_trans = {'env_type': env_type} if env_type else {} - self.algorithm.record_transition( - state=self._last_obs[i], action=actions[i], reward=rewards[i], - next_state=new_obs[i], done=dones[i], info=infos[i], **kwargs_trans - ) - - episode_rewards[i] += rewards[i] - episode_lengths[i] += 1 - stats['total_steps'] += 1 - - if dones[i]: - stats['episode_rewards'].append(episode_rewards[i]) - stats['episode_lengths'].append(episode_lengths[i]) - episode_rewards[i] = 0.0 - episode_lengths[i] = 0 - # Reset this env - new_obs[i] = self.vec_env.reset(id=i) - - self._last_obs = new_obs - - return stats + # Import DummyCollector from base_collector + from rl.collectors.base_collector import DummyCollector - # Simple trainer implementation for testing + # Simple trainer implementation for testing (new architecture) class SimpleTrainer(BaseTrainer): - """Simple trainer for testing with vectorized environments.""" + """Simple trainer for testing - uses collectors (which manage envs internally).""" def __init__( self, - envs, algorithm, - collector=None, + collectors, reward_fn=None, **kwargs ): - super().__init__(envs, algorithm, collector, reward_fn, **kwargs) - - # Create default collector if not provided - if self.collector is None: - default_env = self.get_env() - if isinstance(algorithm, list): - self.collector = [ - DummyCollector(envs=default_env, algorithm=alg) - for alg in algorithm - ] - else: - self.collector = DummyCollector( - envs=default_env, - algorithm=algorithm - ) - + super().__init__(algorithm, collectors, reward_fn, **kwargs) self._total_steps = 0 self._training_logs = [] @@ -562,16 +511,15 @@ def train(self, **kwargs): print(f"Starting training for {total_steps} steps...") while self._total_steps < total_steps: - # Collect data - stats = self.collect_rollout(n_steps=update_interval) - self._total_steps += stats['total_steps'] if isinstance(stats, dict) else sum(s['total_steps'] for s in stats) + # Collect data from all env types + all_stats = self.collect_rollout_all(n_steps=update_interval) - # Update algorithm - if isinstance(self.algorithm, list): - for alg in self.algorithm: - result = alg.update(batch_size=batch_size) - else: - result = self.algorithm.update(batch_size=batch_size) + # Sum up total steps from all collectors + total_collected = sum(stats['total_steps'] for stats in all_stats.values()) + self._total_steps += total_collected + + # Update algorithm (can sample from any/all replays) + result = self.algorithm.update(batch_size=batch_size) # Log if self._total_steps % log_interval == 0: @@ -584,17 +532,17 @@ def train(self, **kwargs): print(f"Training completed. Total steps: {self._total_steps}") def evaluate(self, n_episodes=10, render=False, env_type=None, **kwargs): - print(f"Evaluating for {n_episodes} episodes...") + print(f"Evaluating for {n_episodes} episodes on env_type={env_type}...") - # Get eval env - vec_env = self.get_env(env_type) if env_type else self.get_env() - alg = self.algorithm[0] if isinstance(self.algorithm, list) else self.algorithm + # Get env from collector + vec_env = self.get_env(env_type) + alg = self.algorithm episode_rewards = [] episode_lengths = [] for ep in range(n_episodes): - obs = vec_env.reset(id=0) # Reset first env only for single episode eval + obs = vec_env.reset(id=0) episode_reward = 0.0 episode_length = 0 done = False @@ -624,101 +572,125 @@ def save(self, path): def load(self, path): print(f"Loading from {path} (mock)") + # ========================================================================== # Test the implementation + # ========================================================================== print("=" * 60) - print("Testing BaseTrainer with Vectorized Environments") + print("Testing BaseTrainer with New Architecture") + print("(Trainer uses Collectors which manage Envs internally)") print("=" * 60) - # Create components - print("\n1. Creating components...") - # Create vectorized environment with 4 parallel envs + # Test 1: Single environment type + print("\n" + "-" * 40) + print("Test 1: Single Environment Type") + print("-" * 40) + + # Create env, replay, algorithm, collector env_fns = [lambda: DummyMetaEnv(state_dim=10, action_dim=7, max_steps=20) for _ in range(4)] - vec_env = DummyVectorEnv(env_fns) - print(f" Created DummyVectorEnv with {vec_env.env_num} parallel environments") + vec_env = SequentialVectorEnv(env_fns) + print(f"\n1. Created SequentialVectorEnv with {vec_env.env_num} parallel envs") meta_policy = DummyMetaPolicy(action_dim=7) - replay = SimpleReplay(capacity=10000) + replay = MetaReplay(capacity=10000, env_type='default', n_envs=4, state_dim=10, action_dim=7) algorithm = DummyAlgorithm(meta_policy=meta_policy, replay=replay) - # Test 1: Basic trainer with vectorized env - print("\n2. Testing basic trainer with vectorized env...") + # Create collector (it manages the env) + collector = DummyCollector(envs=vec_env, algorithm=algorithm) + + # Create trainer (NO envs parameter - collector manages it!) trainer = SimpleTrainer( - envs=vec_env, algorithm=algorithm, + collectors=collector, # Single collector -> becomes {'default': collector} reward_fn=None ) - print(f" Trainer: {trainer}") - print(f" env_num: {trainer.env_num}") - print(f" total_env_num: {trainer.get_total_env_num()}") + print(f"\n2. Created trainer: {trainer}") print(f" env_types: {trainer.get_env_types()}") + print(f" total_env_num: {trainer.get_total_env_num()}") - # Test compute_reward without reward_fn - print("\n3. Testing compute_reward without custom reward_fn...") + # Test compute_reward + print("\n3. Testing compute_reward...") state = MetaObs(state=np.random.randn(10).astype(np.float32)) action = MetaAction(action=np.random.randn(7).astype(np.float32)) next_state = MetaObs(state=np.random.randn(10).astype(np.float32)) env_reward = 1.5 computed_reward = trainer.compute_reward(state, action, next_state, env_reward, {}) print(f" Env reward: {env_reward}, Computed reward: {computed_reward}") - assert computed_reward == env_reward - - # Test compute_reward with custom reward_fn - print("\n4. Testing compute_reward with custom reward_fn...") - trainer.set_reward_fn(ScaledReward(scale=2.0, offset=0.5)) - computed_reward = trainer.compute_reward(state, action, next_state, env_reward, {}) - expected = env_reward * 2.0 + 0.5 - print(f" Env reward: {env_reward}, Computed reward: {computed_reward}, Expected: {expected}") - assert abs(computed_reward - expected) < 1e-6 # Test training - print("\n5. Testing training with vectorized env...") - trainer.train(total_steps=400, log_interval=100, update_interval=20, batch_size=16) + print("\n4. Testing training...") + trainer.train(total_steps=200, log_interval=100, update_interval=20, batch_size=16) print(f" Replay buffer size: {len(algorithm.replay)}") - # Test evaluation - print("\n6. Testing evaluation...") - eval_results = trainer.evaluate(n_episodes=3) - print(f" Evaluation results: mean_reward={eval_results['mean_reward']:.2f}") - - # Test collect_rollout - print("\n7. Testing collect_rollout...") - rollout_stats = trainer.collect_rollout(n_steps=50) - print(f" Rollout stats: total_steps={rollout_stats['total_steps']}") + # Test 2: Multiple environment types (the main use case!) + print("\n" + "-" * 40) + print("Test 2: Multiple Environment Types") + print("(One algorithm, multiple env types, each with own collector & replay)") + print("-" * 40) - # Test with multiple environment types (dict) - print("\n8. Testing with multiple environment types (dict)...") + # Create envs for different types sim_env_fns = [lambda: DummyMetaEnv(state_dim=10, action_dim=7, max_steps=20) for _ in range(4)] real_env_fns = [lambda: DummyMetaEnv(state_dim=10, action_dim=7, max_steps=10) for _ in range(2)] - multi_envs = { - 'sim': DummyVectorEnv(sim_env_fns), - 'real': DummyVectorEnv(real_env_fns), + sim_vec_env = SequentialVectorEnv(sim_env_fns) + real_vec_env = SequentialVectorEnv(real_env_fns) + print(f"\n1. Created sim env with {len(sim_vec_env)} parallel envs") + print(f" Created real env with {len(real_vec_env)} parallel envs") + + # Create replays for each env type + replays = { + 'sim': MetaReplay(capacity=5000, env_type='sim', n_envs=4, state_dim=10, action_dim=7), + 'real': MetaReplay(capacity=2000, env_type='real', n_envs=2, state_dim=10, action_dim=7), } - multi_env_trainer = SimpleTrainer( - envs=multi_envs, - algorithm=DummyAlgorithm(meta_policy=DummyMetaPolicy(action_dim=7), replay=SimpleReplay(capacity=1000)), - reward_fn=IdentityReward() - ) - print(f" Multi-env trainer: {multi_env_trainer}") - print(f" env_types: {multi_env_trainer.get_env_types()}") - print(f" total_env_num: {multi_env_trainer.get_total_env_num()}") - print(f" sim env_num: {len(multi_env_trainer.get_env('sim'))}") - print(f" real env_num: {len(multi_env_trainer.get_env('real'))}") - - # Test with multiple algorithms - print("\n9. Testing with multiple algorithms...") - algorithms = [ - DummyAlgorithm(meta_policy=DummyMetaPolicy(action_dim=7), replay=SimpleReplay(capacity=1000)), - DummyAlgorithm(meta_policy=DummyMetaPolicy(action_dim=7), replay=SimpleReplay(capacity=1000)) - ] - multi_algo_trainer = SimpleTrainer( - envs=vec_env, - algorithm=algorithms, + # Create algorithm with multi-replay + meta_policy2 = DummyMetaPolicy(action_dim=7) + algorithm2 = DummyAlgorithm(meta_policy=meta_policy2, replay=replays) + + # Create collectors for each env type + collectors = { + 'sim': DummyCollector(envs=sim_vec_env, algorithm=algorithm2), + 'real': DummyCollector(envs=real_vec_env, algorithm=algorithm2), + } + print(f"\n2. Created collectors for env types: {list(collectors.keys())}") + + # Create trainer with multiple collectors + multi_trainer = SimpleTrainer( + algorithm=algorithm2, + collectors=collectors, reward_fn=IdentityReward() ) - print(f" Multi-algorithm trainer: {multi_algo_trainer}") - multi_algo_trainer.train(total_steps=100, log_interval=50, update_interval=20, batch_size=8) + print(f"\n3. Created multi-env trainer: {multi_trainer}") + print(f" env_types: {multi_trainer.get_env_types()}") + print(f" total_env_num: {multi_trainer.get_total_env_num()}") + print(f" sim env_num: {len(multi_trainer.get_env('sim'))}") + print(f" real env_num: {len(multi_trainer.get_env('real'))}") + + # Test collect from specific env type + print("\n4. Testing collect from specific env type...") + sim_stats = multi_trainer.collect_rollout(n_steps=50, env_type='sim') + print(f" Sim collected: {sim_stats['total_steps']} steps") + print(f" Sim replay size: {len(replays['sim'])}") + print(f" Real replay size: {len(replays['real'])}") + + # Test collect from all env types + print("\n5. Testing collect from ALL env types...") + all_stats = multi_trainer.collect_rollout_all(n_steps=30) + print(f" Collected from: {list(all_stats.keys())}") + for env_type, stats in all_stats.items(): + print(f" {env_type}: {stats['total_steps']} steps") + print(f" Sim replay size: {len(replays['sim'])}") + print(f" Real replay size: {len(replays['real'])}") + + # Test training with multi-env + print("\n6. Testing training with multi-env...") + multi_trainer.train(total_steps=200, log_interval=100, update_interval=20, batch_size=16) + print(f" Final sim replay size: {len(replays['sim'])}") + print(f" Final real replay size: {len(replays['real'])}") + + # Test evaluation on specific env type + print("\n7. Testing evaluation on 'sim' env...") + eval_results = multi_trainer.evaluate(n_episodes=3, env_type='sim') + print(f" Mean reward: {eval_results['mean_reward']:.2f}") print("\n" + "=" * 60) print("All tests passed!") diff --git a/rl/utils/__init__.py b/rl/utils/__init__.py index 6388a62e..702a42f7 100644 --- a/rl/utils/__init__.py +++ b/rl/utils/__init__.py @@ -8,12 +8,15 @@ - Running statistics (mean, variance) for normalization - Advantage computation (GAE) - Discount reward computation +- Action post-processing helpers (ensure/clip actions) """ import numpy as np import torch from typing import Dict, Any, Optional, List, Union +from .action_utils import ensure_action, clip_action_to_space + def compute_gae( rewards: np.ndarray, diff --git a/rl/utils/action_utils.py b/rl/utils/action_utils.py new file mode 100644 index 00000000..2a812b60 --- /dev/null +++ b/rl/utils/action_utils.py @@ -0,0 +1,80 @@ +""" +Action utility helpers for RL algorithms. + +These helpers keep action post-processing outside MetaEnv, so algorithms can +apply different refinement strategies when needed. +""" + +from __future__ import annotations + +from typing import Callable, Optional, Tuple + +import numpy as np +import torch + + +def _get_action_space(env) -> Optional[object]: + """Best-effort extraction of an action_space from env or wrappers.""" + if env is None: + return None + if hasattr(env, "action_space"): + return env.action_space + if hasattr(env, "env") and hasattr(env.env, "action_space"): + return env.env.action_space + if hasattr(env, "envs") and env.envs: + first_env = env.envs[0] + if hasattr(first_env, "action_space"): + return first_env.action_space + if hasattr(first_env, "env") and hasattr(first_env.env, "action_space"): + return first_env.env.action_space + return None + + +def _get_action_bounds(env) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]: + action_space = _get_action_space(env) + if action_space is None: + return None, None + low = getattr(action_space, "low", None) + high = getattr(action_space, "high", None) + if low is None or high is None: + return None, None + return np.asarray(low), np.asarray(high) + + +def clip_action_to_space(env, action): + """Clip action to env action space bounds if available.""" + low, high = _get_action_bounds(env) + if low is None or high is None: + return action + if torch.is_tensor(action): + low_t = torch.as_tensor(low, device=action.device, dtype=action.dtype) + high_t = torch.as_tensor(high, device=action.device, dtype=action.dtype) + return torch.clamp(action, low_t, high_t) + return np.clip(action, low, high) + + +def ensure_action( + env, + action, + refine_fn: Optional[Callable[[object, object], object]] = None, + apply_tanh: bool = True, +): + """ + Ensure actions are valid for the environment. + + - Applies tanh to bound outputs to [-1, 1] (as in TD3 policies). + - Applies an optional refine function for env-specific processing. + - Clips to env action space bounds if available. + + Args: + env: Environment (or wrapper) providing action_space. + action: Tensor or numpy array action. + refine_fn: Optional callable (env, action) -> action for custom refinement. + apply_tanh: Whether to apply tanh to the action. + """ + if apply_tanh: + action = torch.tanh(action) if torch.is_tensor(action) else np.tanh(action) + if refine_fn is not None: + action = refine_fn(env, action) + return clip_action_to_space(env, action) + diff --git a/utils/__init__.py b/utils/__init__.py index 4adaa8f7..c93083f3 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -6,5 +6,26 @@ # Configure logging on package import from .logger import logger -__all__ = ['logger'] +# Exploration strategies +from .exploration import ( + BaseExploration, + NoExploration, + GaussianNoise, + OUNoise, + CustomNoise, + RandomExploration, + ExplorationScheduler, +) + +__all__ = [ + 'logger', + # Exploration + 'BaseExploration', + 'NoExploration', + 'GaussianNoise', + 'OUNoise', + 'CustomNoise', + 'RandomExploration', + 'ExplorationScheduler', +] From fa63cc3a841de8be305e9086a7ed37561134b61e Mon Sep 17 00:00:00 2001 From: yesen-chen <840419490@qq.com> Date: Wed, 4 Feb 2026 02:44:36 +0800 Subject: [PATCH 4/6] [feat] Add train_rl script --- rl/algorithms/__init__.py | 40 ++- rl/algorithms/td3/__init__.py | 3 +- rl/collectors/base_collector.py | 230 +++++++++----- rl/trainers/__init__.py | 7 +- rl/trainers/base_trainer.py | 178 ++++++++++- rl/trainers/offpolicy_trainer.py | 260 ++++++++++++++++ rl/utils/__init__.py | 1 - rl/utils/action_utils.py | 43 ++- train_rl.py | 500 +++++++++++++++++++++++++++++++ 9 files changed, 1160 insertions(+), 102 deletions(-) create mode 100644 rl/trainers/offpolicy_trainer.py create mode 100644 train_rl.py diff --git a/rl/algorithms/__init__.py b/rl/algorithms/__init__.py index 6c83330b..d93542bf 100644 --- a/rl/algorithms/__init__.py +++ b/rl/algorithms/__init__.py @@ -15,21 +15,25 @@ This __init__.py provides factory functions for creating algorithms. """ -from typing import Type, Dict, Any +from typing import Type, Dict, Any, Optional, Tuple -# Registry for algorithm classes +# Registry for algorithm classes and their config classes _ALGORITHM_REGISTRY: Dict[str, Type] = {} +_CONFIG_REGISTRY: Dict[str, Type] = {} -def register_algorithm(name: str, algorithm_class: Type) -> None: +def register_algorithm(name: str, algorithm_class: Type, config_class: Type = None) -> None: """ - Register an algorithm class. + Register an algorithm class and its config class. Args: name: Algorithm name (e.g., 'ppo', 'sac') algorithm_class: Algorithm class to register + config_class: Config class for the algorithm (optional) """ _ALGORITHM_REGISTRY[name.lower()] = algorithm_class + if config_class is not None: + _CONFIG_REGISTRY[name.lower()] = config_class def get_algorithm_class(name_or_type: str) -> Type: @@ -66,6 +70,32 @@ def get_algorithm_class(name_or_type: str) -> Type: raise ValueError(f"Unknown algorithm: '{name_or_type}'. Available: {list(_ALGORITHM_REGISTRY.keys())}") +def get_config_class(name: str) -> Optional[Type]: + """ + Get config class for an algorithm. + + Args: + name: Algorithm name (e.g., 'td3') + + Returns: + Config class or None if not registered + """ + return _CONFIG_REGISTRY.get(name.lower()) + + +def get_algorithm_and_config(name: str) -> Tuple[Type, Optional[Type]]: + """ + Get both algorithm class and config class. + + Args: + name: Algorithm name + + Returns: + Tuple of (algorithm_class, config_class) + """ + return get_algorithm_class(name), get_config_class(name) + + def list_algorithms() -> list: """List all registered algorithms.""" return list(_ALGORITHM_REGISTRY.keys()) @@ -74,6 +104,8 @@ def list_algorithms() -> list: __all__ = [ 'register_algorithm', 'get_algorithm_class', + 'get_config_class', + 'get_algorithm_and_config', 'list_algorithms', ] diff --git a/rl/algorithms/td3/__init__.py b/rl/algorithms/td3/__init__.py index c42799bb..2660f4aa 100644 --- a/rl/algorithms/td3/__init__.py +++ b/rl/algorithms/td3/__init__.py @@ -1,7 +1,8 @@ from .td3 import TD3Algorithm, TD3Config from .. import register_algorithm -register_algorithm("td3", TD3Algorithm) +# Register algorithm with its config class +register_algorithm("td3", TD3Algorithm, TD3Config) __all__ = ["TD3Algorithm", "TD3Config"] diff --git a/rl/collectors/base_collector.py b/rl/collectors/base_collector.py index d8e595b1..829ac6f9 100644 --- a/rl/collectors/base_collector.py +++ b/rl/collectors/base_collector.py @@ -19,7 +19,9 @@ # Type hints for Meta classes from benchmark.base import MetaObs, MetaAction - +from benchmark.base import MetaObs, MetaAction +from benchmark.utils import organize_obs +from rl.buffer.transition import RLTransition # Vector environment protocol from rl.envs import VectorEnvProtocol, EnvsType @@ -352,24 +354,165 @@ def __repr__(self) -> str: class DummyCollector(BaseCollector): """ - Simple collector implementation for testing and basic use cases. + Simple collector implementation for off-policy RL algorithms. This collector: - Works with vectorized environments (SequentialVectorEnv, SubprocVectorEnv, etc.) - Handles batched observations and actions - Creates RLTransition objects for storage in replay buffer - Tracks episode statistics + - Supports both step-by-step and batch collection """ - def __init__(self, envs, algorithm, **kwargs): + def __init__(self, envs, algorithm, ctrl_space='joint', action_dim=None, **kwargs): super().__init__(envs, algorithm, **kwargs) self.vec_env = self.get_env() self._last_obs = None + self.ctrl_space = ctrl_space + self.action_dim = action_dim + + # Episode tracking + self._episode_rewards = None + self._episode_lengths = None def reset(self, **kwargs): """Reset the collector and environments.""" self.vec_env = self.get_env() self._last_obs = self.vec_env.reset() + + # Initialize episode tracking + num_envs = len(self.vec_env) + self._episode_rewards = np.zeros(num_envs, dtype=np.float32) + self._episode_lengths = np.zeros(num_envs, dtype=np.int32) + + def collect_step( + self, + noise_scale: float = 0.0, + use_random: bool = False, + env_type: str = None + ) -> Dict[str, Any]: + """ + Collect one step of interaction data. + + This is the primary method for off-policy training where we collect + one step at a time and interleave with policy updates. + + Args: + noise_scale: Exploration noise scale (for policy actions) + use_random: If True, use random actions (for initial exploration) + env_type: Optional environment type identifier + + Returns: + Dictionary with statistics for this step + """ + if self._last_obs is None: + self.reset() + + from benchmark.base import MetaObs, MetaAction + from benchmark.utils import organize_obs + from rl.buffer.transition import RLTransition + + stats = {'episode_rewards': [], 'episode_lengths': [], 'total_steps': 0} + + # Organize observations into batched MetaObs + obs = self._last_obs + if not isinstance(obs, MetaObs): + obs = organize_obs(obs, self.ctrl_space) + + num_envs = len(self.vec_env) + + # Select action + if use_random: + # Random exploration + action_dim = self.action_dim + if action_dim is None: + # Try to infer from environment + single_env = self.vec_env.envs[0] if hasattr(self.vec_env, 'envs') else self.vec_env + if hasattr(single_env, 'action_space'): + action_dim = single_env.action_space.shape[0] + elif hasattr(single_env, 'action_dim'): + action_dim = single_env.action_dim + else: + raise ValueError("Cannot determine action_dim for random exploration") + action_array = np.random.uniform(-1, 1, (num_envs, action_dim)).astype(np.float32) + action = MetaAction(action=action_array) + else: + # Policy action with exploration noise + action = self.algorithm.select_action(obs, noise_scale=noise_scale, env=self.vec_env) + + # Unpack action to numpy array + if hasattr(action, 'action'): + action_array = action.action + else: + action_array = action + + if action_array.ndim == 1: + action_array = action_array[np.newaxis, :] + + # Convert to list of dicts for vec_env.step + step_actions = [{'action': action_array[i]} for i in range(num_envs)] + + # Step environment + next_obs, rewards, dones, infos = self.vec_env.step(step_actions) + + # Handle infos which may be None, list, or numpy array + if infos is not None and len(infos) > 0: + truncated = np.array([ + info.get('TimeLimit.truncated', False) if isinstance(info, dict) else False + for info in infos + ]) + else: + truncated = np.zeros_like(dones) + + # Organize next observations + if not isinstance(next_obs, MetaObs): + next_obs = organize_obs(next_obs, self.ctrl_space) + + # Ensure action is MetaAction + if not isinstance(action, MetaAction): + action = MetaAction(action=action_array) + + # Create transition and record + transition = RLTransition( + obs=obs, + action=action, + next_obs=next_obs, + reward=rewards, + done=dones, + truncated=truncated, + ) + + kwargs_trans = {'env_type': env_type} if env_type else {} + self.algorithm.record_transition(transition, **kwargs_trans) + + # Update episode statistics + self._episode_rewards += rewards + self._episode_lengths += 1 + stats['total_steps'] = num_envs + self._total_steps += num_envs + + # Handle done episodes + done_indices = np.where(dones)[0] + if len(done_indices) > 0: + stats['episode_rewards'] = self._episode_rewards[done_indices].tolist() + stats['episode_lengths'] = self._episode_lengths[done_indices].tolist() + + # Reset tracking for done envs + self._episode_rewards[done_indices] = 0 + self._episode_lengths[done_indices] = 0 + + # Reset done environments and update next_obs + reset_obs = self.vec_env.reset(id=done_indices) + if reset_obs is not None: + reset_obs_organized = organize_obs(reset_obs, self.ctrl_space) if not isinstance(reset_obs, MetaObs) else reset_obs + if isinstance(next_obs, MetaObs) and next_obs.state is not None: + if hasattr(reset_obs_organized, 'state') and reset_obs_organized.state is not None: + next_obs.state[done_indices] = reset_obs_organized.state + + # Update last observation + self._last_obs = next_obs + + return stats def collect(self, n_steps, env_type=None): """ @@ -386,84 +529,13 @@ def collect(self, n_steps, env_type=None): - total_steps: Total number of steps collected - env_type: Environment type identifier """ - if self._last_obs is None: - self.reset() - - from benchmark.base import dict2meta, MetaObs, MetaAction - from benchmark.utils import organize_obs - from rl.buffer.transition import RLTransition - stats = {'episode_rewards': [], 'episode_lengths': [], 'total_steps': 0, 'env_type': env_type} - episode_rewards = np.zeros(len(self.vec_env)) - episode_lengths = np.zeros(len(self.vec_env), dtype=int) - for step in range(n_steps): - # Organize observations into batched MetaObs - batched_obs = organize_obs(self._last_obs) if not isinstance(self._last_obs, MetaObs) else self._last_obs - - # Get batched actions for all environments at once - # actions is expected to be an object array of dicts (from MetaPolicy) - actions = self.algorithm.select_action(batched_obs) - - # Apply exploration if configured - if self.exploration is not None: - if self.is_exploring_randomly: - stats['random_steps'] = stats.get('random_steps', 0) + len(self.vec_env) - actions = self.apply_exploration(actions, obs=batched_obs) - - # Step all environments (expects array of dicts, one per env) - new_obs, rewards, dones, infos = self.vec_env.step(actions) - - # Organize next observations into batched MetaObs - batched_next_obs = organize_obs(new_obs) if not isinstance(new_obs, MetaObs) else new_obs - - # Reconstruct MetaAction for storage (since actions is now an object array) - # We need to extract the raw action arrays from the dicts - if isinstance(actions, np.ndarray) and actions.dtype == object: - # Extract 'action' field from each dict - raw_actions = np.stack([a['action'] for a in actions]) - # Create MetaAction - stored_actions = MetaAction(action=raw_actions) - elif isinstance(actions, MetaAction): - stored_actions = actions - else: - stored_actions = dict2meta(actions, mtype='act') - - # Create single RLTransition for all environments - truncated = np.array([infos[i].get('TimeLimit.truncated', False) for i in range(len(infos))]) if infos else np.zeros_like(dones) - - transition = RLTransition( - obs=batched_obs, - action=stored_actions, - next_obs=batched_next_obs, - reward=rewards, - done=dones, - truncated=truncated, - info=infos if infos else None, - ) - - kwargs_trans = {'env_type': env_type} if env_type else {} - self.algorithm.record_transition(transition, **kwargs_trans) - - # Update episode statistics - episode_rewards += rewards - episode_lengths += 1 - stats['total_steps'] += len(self.vec_env) - self._total_steps += len(self.vec_env) - - # Handle done environments - done_indices = np.where(dones)[0] - if len(done_indices) > 0: - stats['episode_rewards'].extend(episode_rewards[done_indices].tolist()) - stats['episode_lengths'].extend(episode_lengths[done_indices].tolist()) - episode_rewards[done_indices] = 0.0 - episode_lengths[done_indices] = 0 - # Reset done environments - for idx in done_indices: - new_obs[idx] = self.vec_env.reset(id=idx) - - # Reorganize observations for next iteration - self._last_obs = organize_obs(new_obs) if not isinstance(new_obs, MetaObs) else new_obs + for _ in range(n_steps): + step_stats = self.collect_step(env_type=env_type) + stats['episode_rewards'].extend(step_stats.get('episode_rewards', [])) + stats['episode_lengths'].extend(step_stats.get('episode_lengths', [])) + stats['total_steps'] += step_stats.get('total_steps', 0) return stats diff --git a/rl/trainers/__init__.py b/rl/trainers/__init__.py index 4dd6dff3..00c94d85 100644 --- a/rl/trainers/__init__.py +++ b/rl/trainers/__init__.py @@ -21,9 +21,12 @@ from typing import Type, Dict, Any from .base_trainer import BaseTrainer +from .offpolicy_trainer import OffPolicyTrainer, OffPolicyTrainerConfig # Registry for trainer classes -_TRAINER_REGISTRY: Dict[str, Type] = {} +_TRAINER_REGISTRY: Dict[str, Type] = { + 'offpolicy': OffPolicyTrainer, +} def register_trainer(name: str, trainer_class: Type) -> None: @@ -78,6 +81,8 @@ def list_trainers() -> list: __all__ = [ 'BaseTrainer', + 'OffPolicyTrainer', + 'OffPolicyTrainerConfig', 'register_trainer', 'get_trainer_class', 'list_trainers', diff --git a/rl/trainers/base_trainer.py b/rl/trainers/base_trainer.py index 2be75ca0..17076a98 100644 --- a/rl/trainers/base_trainer.py +++ b/rl/trainers/base_trainer.py @@ -254,7 +254,6 @@ def collect_rollout_all(self, n_steps: int) -> Dict[str, Dict[str, Any]]: """ return self.collect_rollout(n_steps, env_type=None) - @abstractmethod def evaluate( self, n_episodes: int = 10, @@ -265,26 +264,187 @@ def evaluate( """ Evaluate policy performance. + Default implementation uses collector's environment and algorithm's select_action. + Subclasses can override for custom evaluation logic. + Args: n_episodes: Number of episodes to evaluate render: Whether to render environment (optional) env_type: Optional, specify evaluation environment type - **kwargs: Other evaluation parameters + **kwargs: Other evaluation parameters (vec_env, max_timesteps, ctrl_space, etc.) Returns: Dictionary containing evaluation metrics """ - raise NotImplementedError + from benchmark.utils import organize_obs + + # Get evaluation environment + vec_env = kwargs.get('vec_env', None) + if vec_env is None: + # Try to get from collector + collector = self.get_collector(env_type) + if collector is not None: + vec_env = collector.get_env() + else: + return {'error': 'No evaluation environment available'} + + max_timesteps = kwargs.get('max_timesteps', 1000) + ctrl_space = kwargs.get('ctrl_space', 'joint') + + num_envs = len(vec_env) + all_returns = [] + all_lengths = [] + episodes_completed = 0 + + # Set algorithm to eval mode + self.algorithm.eval_mode() + + while episodes_completed < n_episodes: + obs = vec_env.reset() + obs = organize_obs(obs, ctrl_space) + + episode_returns = np.zeros(num_envs, dtype=np.float32) + episode_lengths = np.zeros(num_envs, dtype=np.int32) + done_mask = np.zeros(num_envs, dtype=bool) + + for t in range(max_timesteps): + with torch.no_grad(): + action = self.algorithm.select_action(obs, noise_scale=0.0, env=vec_env) + + # Unpack action + if hasattr(action, 'action'): + action_array = action.action + else: + action_array = action + + if action_array.ndim == 1: + action_array = action_array[np.newaxis, :] + + # Convert to list of dicts for vec_env.step + step_actions = [{'action': action_array[i]} for i in range(len(action_array))] + + # Step environment + next_obs, rewards, dones, infos = vec_env.step(step_actions) + next_obs = organize_obs(next_obs, ctrl_space) + + # Accumulate rewards for active episodes + episode_returns += rewards * (~done_mask) + episode_lengths += (~done_mask).astype(np.int32) + + # Check for newly done episodes + newly_done = dones & (~done_mask) + if newly_done.any(): + for idx in np.where(newly_done)[0]: + all_returns.append(episode_returns[idx]) + all_lengths.append(episode_lengths[idx]) + episodes_completed += 1 + + done_mask = done_mask | dones + + # Reset done environments + if dones.any(): + done_indices = np.where(dones)[0] + reset_obs = vec_env.reset(id=done_indices) + if reset_obs is not None: + reset_obs_organized = organize_obs(reset_obs, ctrl_space) + if isinstance(next_obs, MetaObs) and next_obs.state is not None: + if hasattr(reset_obs_organized, 'state') and reset_obs_organized.state is not None: + next_obs.state[done_indices] = reset_obs_organized.state + episode_returns[done_indices] = 0 + episode_lengths[done_indices] = 0 + done_mask[done_indices] = False + + obs = next_obs + + if episodes_completed >= n_episodes: + break + + # Handle truncated episodes + for idx in range(num_envs): + if not done_mask[idx] and episode_lengths[idx] > 0: + all_returns.append(episode_returns[idx]) + all_lengths.append(episode_lengths[idx]) + episodes_completed += 1 + if episodes_completed >= n_episodes: + break + + # Set back to train mode + self.algorithm.train_mode() + + # Compute statistics + all_returns = np.array(all_returns[:n_episodes]) + all_lengths = np.array(all_lengths[:n_episodes]) + + return { + 'returns': all_returns.tolist(), + 'mean_return': float(np.mean(all_returns)), + 'std_return': float(np.std(all_returns)), + 'min_return': float(np.min(all_returns)), + 'max_return': float(np.max(all_returns)), + 'episode_lengths': all_lengths.tolist(), + 'mean_length': float(np.mean(all_lengths)), + 'num_episodes': len(all_returns), + } - @abstractmethod def save(self, path: str) -> None: - """Save model and training state.""" - raise NotImplementedError + """ + Save model and training state. + + Default implementation saves the algorithm's model. + Subclasses can override to save additional state. + + Args: + path: Path to save the model + """ + self.algorithm.save(path) - @abstractmethod def load(self, path: str) -> None: - """Load model and training state.""" - raise NotImplementedError + """ + Load model and training state. + + Default implementation loads the algorithm's model. + Subclasses can override to load additional state. + + Args: + path: Path to load the model from + """ + self.algorithm.load(path) + + def save_checkpoint(self, path: str, step: int) -> None: + """ + Save training checkpoint with step information. + + Args: + path: Directory or full path to save checkpoint + step: Current training step number + """ + import os + if os.path.isdir(path): + ckpt_path = os.path.join(path, f'checkpoint_{step}.pt') + else: + ckpt_path = path + self.save(ckpt_path) + + def load_checkpoint(self, path: str) -> int: + """ + Load checkpoint and return the step number. + + Extracts step number from checkpoint filename (e.g., checkpoint_10000.pt -> 10000). + + Args: + path: Path to checkpoint file + + Returns: + Step number extracted from checkpoint filename, or 0 if not found + """ + import os + self.load(path) + # Try to extract step from checkpoint name + try: + step = int(os.path.basename(path).split('_')[-1].split('.')[0]) + except (ValueError, IndexError): + step = 0 + return step def get_algorithm(self) -> 'BaseAlgorithm': """Get the algorithm.""" diff --git a/rl/trainers/offpolicy_trainer.py b/rl/trainers/offpolicy_trainer.py new file mode 100644 index 00000000..cadd7545 --- /dev/null +++ b/rl/trainers/offpolicy_trainer.py @@ -0,0 +1,260 @@ +""" +Off-Policy Trainer + +Trainer for off-policy RL algorithms (TD3, SAC, DDPG, etc.) + +Design: +- Separates data collection (Collector) from policy updates (Trainer) +- Supports evaluation during training +- Handles checkpointing and logging +""" + +import os +import json +import numpy as np +from typing import Dict, Any, Optional, Callable, Type +from dataclasses import dataclass +from loguru import logger +from tqdm import tqdm + +from .base_trainer import BaseTrainer +from rl.collectors import BaseCollector +from rl.algorithms.base import BaseAlgorithm +from benchmark.utils import SequentialVectorEnv + + +@dataclass +class OffPolicyTrainerConfig: + """Configuration for OffPolicyTrainer.""" + total_steps: int = 1000000 + start_steps: int = 25000 # Random exploration steps + update_after: int = 1000 # Start updating after this many steps + update_every: int = 50 # Update every N steps + batch_size: int = 256 + + # Logging and evaluation + log_freq: int = 1000 + eval_freq: int = 10000 + eval_episodes: int = 10 + save_freq: int = 50000 + + # Output + output_dir: str = 'ckpt/rl_training' + + # Environment settings + max_timesteps: int = 1000 + ctrl_space: str = 'joint' + + # Exploration + expl_noise: float = 0.1 + + +class OffPolicyTrainer(BaseTrainer): + """ + Trainer for off-policy RL algorithms. + + Coordinates: + - Collector: Gathers experience from environment + - Algorithm: Updates policy using collected data + - Evaluation: Periodically evaluates policy performance + """ + + def __init__( + self, + algorithm: BaseAlgorithm, + collector: BaseCollector, + config: Optional[OffPolicyTrainerConfig] = None, + eval_env_fn: Optional[Callable] = None, + eval_vec_env_cls: Optional[Type] = None, + **kwargs + ): + """ + Initialize OffPolicyTrainer. + + Args: + algorithm: RL algorithm (TD3, SAC, etc.) + collector: Data collector for environment interaction + config: Trainer configuration + eval_env_fn: Factory function for creating evaluation environments + eval_vec_env_cls: Vector environment class for evaluation (default: SequentialVectorEnv) + **kwargs: Additional arguments + """ + super().__init__(algorithm=algorithm, collectors=collector, **kwargs) + + self.config = config or OffPolicyTrainerConfig() + self.collector = collector + self.eval_env_fn = eval_env_fn + self.eval_vec_env_cls = eval_vec_env_cls or SequentialVectorEnv + + # Training state + self._current_step = 0 + self._episode_count = 0 + self._recent_rewards = [] + + def train(self, resume_step: int = 0) -> Dict[str, Any]: + """ + Run the training loop. + + Args: + resume_step: Step to resume from (for checkpoint resumption) + + Returns: + Training statistics + """ + os.makedirs(self.config.output_dir, exist_ok=True) + + # Save config + config_path = os.path.join(self.config.output_dir, 'config.json') + with open(config_path, 'w') as f: + config_dict = {k: v for k, v in vars(self.config).items()} + json.dump(config_dict, f, indent=2, default=str) + + self._current_step = resume_step + + logger.info("=" * 60) + logger.info(f"Starting training for {self.config.total_steps} steps") + logger.info("=" * 60) + + # Reset collector + self.collector.reset() + + pbar = tqdm( + range(resume_step, self.config.total_steps), + desc="Training", + initial=resume_step, + total=self.config.total_steps + ) + + for step in pbar: + self._current_step = step + + # Collect one step of data + # Use random exploration for initial steps + use_random = step < self.config.start_steps + collect_stats = self.collector.collect_step( + noise_scale=self.config.expl_noise if not use_random else None, + use_random=use_random, + ) + + # Update episode tracking + if collect_stats.get('episode_rewards'): + self._recent_rewards.extend(collect_stats['episode_rewards']) + self._episode_count += len(collect_stats['episode_rewards']) + # Keep only recent 100 rewards + self._recent_rewards = self._recent_rewards[-100:] + + # Policy update + if step >= self.config.update_after and step % self.config.update_every == 0: + for _ in range(self.config.update_every): + self.algorithm.update( + batch_size=self.config.batch_size, + env=self.collector.get_env() + ) + + # Logging + if step > 0 and step % self.config.log_freq == 0: + avg_reward = np.mean(self._recent_rewards) if self._recent_rewards else 0.0 + pbar.set_postfix({ + 'episodes': self._episode_count, + 'avg_reward': f'{avg_reward:.2f}', + 'buffer': len(self.algorithm.replay) if self.algorithm.replay else 0, + }) + + # Evaluation + if step > 0 and step % self.config.eval_freq == 0: + self._run_evaluation(step) + + # Save checkpoint + if step > 0 and step % self.config.save_freq == 0: + self._save_checkpoint(step) + + # Final save + final_path = os.path.join(self.config.output_dir, 'final_model.pt') + self.algorithm.save(final_path) + logger.info(f"Training complete. Final model saved to {final_path}") + + return { + 'total_steps': self.config.total_steps, + 'episode_count': self._episode_count, + 'final_avg_reward': np.mean(self._recent_rewards) if self._recent_rewards else 0.0, + } + + def _run_evaluation(self, step: int) -> Dict[str, Any]: + """Run evaluation and log results.""" + logger.info(f"Step {step}: Running evaluation...") + + if self.eval_env_fn is None: + logger.warning("No eval_env_fn provided, skipping evaluation") + return {} + + # Run evaluation (evaluate method handles env creation/closing) + eval_result = self.evaluate(n_episodes=self.config.eval_episodes) + + if not eval_result or 'error' in eval_result: + return eval_result + + # Log results + logger.info( + f"Step {step}: Eval return = {eval_result['mean_return']:.2f} ± {eval_result['std_return']:.2f}, " + f"length = {eval_result['mean_length']:.1f}, " + f"episodes = {eval_result['num_episodes']}" + ) + + # Save eval results + eval_path = os.path.join(self.config.output_dir, 'eval_results.json') + eval_data = {'step': step, **eval_result} + with open(eval_path, 'a') as f: + f.write(json.dumps(eval_data, default=lambda x: x.tolist() if hasattr(x, 'tolist') else x) + '\n') + + return eval_result + + def evaluate( + self, + n_episodes: int = 10, + render: bool = False, + env_type: Optional[str] = None, + **kwargs + ) -> Dict[str, Any]: + """ + Evaluate policy performance. + + Overrides base class to support creating eval env from eval_env_fn. + + Args: + n_episodes: Number of episodes to evaluate + render: Whether to render environment (not used currently) + env_type: Optional environment type (not used currently) + **kwargs: Additional arguments (e.g., vec_env, max_timesteps) + + Returns: + Dictionary containing evaluation metrics + """ + vec_env = kwargs.get('vec_env', None) + should_close = False + + # Create evaluation environment if not provided + if vec_env is None and self.eval_env_fn is not None: + eval_env_fns = [self.eval_env_fn for _ in range(n_episodes)] + vec_env = self.eval_vec_env_cls(eval_env_fns) + should_close = True + + # Set default max_timesteps and ctrl_space from config + if 'max_timesteps' not in kwargs: + kwargs['max_timesteps'] = self.config.max_timesteps + if 'ctrl_space' not in kwargs: + kwargs['ctrl_space'] = self.config.ctrl_space + + # Pass vec_env to base class evaluate + kwargs['vec_env'] = vec_env + result = super().evaluate(n_episodes=n_episodes, render=render, env_type=env_type, **kwargs) + + if should_close and vec_env is not None: + vec_env.close() + + return result + + def _save_checkpoint(self, step: int) -> None: + """Save training checkpoint.""" + self.save_checkpoint(self.config.output_dir, step) + logger.info(f"Saved checkpoint to {self.config.output_dir}/checkpoint_{step}.pt") + diff --git a/rl/utils/__init__.py b/rl/utils/__init__.py index 702a42f7..5c773d5a 100644 --- a/rl/utils/__init__.py +++ b/rl/utils/__init__.py @@ -15,7 +15,6 @@ import torch from typing import Dict, Any, Optional, List, Union -from .action_utils import ensure_action, clip_action_to_space def compute_gae( diff --git a/rl/utils/action_utils.py b/rl/utils/action_utils.py index 2a812b60..7e47d92e 100644 --- a/rl/utils/action_utils.py +++ b/rl/utils/action_utils.py @@ -11,7 +11,7 @@ import numpy as np import torch - +from loguru import logger def _get_action_space(env) -> Optional[object]: """Best-effort extraction of an action_space from env or wrappers.""" @@ -52,12 +52,36 @@ def clip_action_to_space(env, action): return torch.clamp(action, low_t, high_t) return np.clip(action, low, high) +def tanh_action_to_space(env, action): + """ + Apply tanh to action and scale to env action space bounds. + + Maps tanh output [-1, 1] to [action_space.low, action_space.high]. + If no action space bounds available, just applies tanh. + """ + # Apply tanh to bound to [-1, 1] + if torch.is_tensor(action): + tanh_action = torch.tanh(action) + else: + tanh_action = np.tanh(action) + + # Scale to action space bounds + low, high = _get_action_bounds(env) + if low is None or high is None: + return tanh_action + + # Map [-1, 1] -> [low, high]: scaled = low + (tanh + 1) * (high - low) / 2 + if torch.is_tensor(tanh_action): + low_t = torch.as_tensor(low, device=tanh_action.device, dtype=tanh_action.dtype) + high_t = torch.as_tensor(high, device=tanh_action.device, dtype=tanh_action.dtype) + return low_t + (tanh_action + 1.0) * (high_t - low_t) / 2.0 + else: + return low + (tanh_action + 1.0) * (high - low) / 2.0 def ensure_action( env, action, refine_fn: Optional[Callable[[object, object], object]] = None, - apply_tanh: bool = True, ): """ Ensure actions are valid for the environment. @@ -72,9 +96,14 @@ def ensure_action( refine_fn: Optional callable (env, action) -> action for custom refinement. apply_tanh: Whether to apply tanh to the action. """ - if apply_tanh: - action = torch.tanh(action) if torch.is_tensor(action) else np.tanh(action) - if refine_fn is not None: - action = refine_fn(env, action) - return clip_action_to_space(env, action) + try: + reasonable = env.envs[0].ensure_action_reasonable(action) + except ValueError: + logger.warning(f"Environment {env.__class__.__name__} does not support ensure_action_reasonable, using tanh_action_to_space") + if reasonable is False: + if refine_fn is not None: + action = refine_fn(env, action) + else: + raise ValueError(f"Action {action} is not reasonable") + return action \ No newline at end of file diff --git a/train_rl.py b/train_rl.py new file mode 100644 index 00000000..bce66dce --- /dev/null +++ b/train_rl.py @@ -0,0 +1,500 @@ +""" +Reinforcement Learning Training Script + +This script provides a unified interface for training RL algorithms. +It supports: +- Multiple registered algorithms (TD3, SAC, PPO, etc.) +- Vectorized environments for parallel data collection +- Flexible configuration via YAML files and CLI overrides +- Checkpoint saving and resuming + +Usage: + python train_rl.py -a td3 -e aloha --num_envs 4 --total_steps 1000000 + python train_rl.py -a configs/rl/td3.yaml -e configs/env/custom.yaml -o ckpt/td3_experiment +""" + +import configs # Must be first to suppress TensorFlow logs +import os +import argparse +import json +import importlib +from loguru import logger +import numpy as np +import torch +from tqdm import tqdm + +from data_utils.utils import set_seed +from benchmark.utils import SequentialVectorEnv, organize_obs + + +def parse_args(): + """Parse command line arguments for RL training.""" + parser = argparse.ArgumentParser(description='Train an RL algorithm') + + # Algorithm arguments + parser.add_argument('-a', '--algorithm', type=str, default='td3', + help='Algorithm name (registered) or path to YAML config file') + + # Environment arguments + parser.add_argument('-e', '--env', type=str, default='aloha', + help='Env config (name under configs/env or absolute path to yaml)') + parser.add_argument('--num_envs', type=int, default=1, + help='Number of parallel environments') + parser.add_argument('--use_subproc', action='store_true', + help='Use SubprocVectorEnv instead of SequentialVectorEnv') + + # Training arguments + parser.add_argument('--total_steps', type=int, default=1000000, + help='Total environment steps for training') + parser.add_argument('--start_steps', type=int, default=25000, + help='Number of random exploration steps before policy is used') + parser.add_argument('--update_after', type=int, default=1000, + help='Number of steps before starting policy updates') + parser.add_argument('--update_every', type=int, default=50, + help='Update policy every N environment steps') + parser.add_argument('--batch_size', type=int, default=256, + help='Batch size for policy updates') + parser.add_argument('--replay_size', type=int, default=1000000, + help='Replay buffer capacity') + parser.add_argument('--expl_noise', type=float, default=0.1, + help='Exploration noise scale (can be overridden by algo config)') + + # Output and logging + parser.add_argument('-o', '--output_dir', type=str, default='ckpt/rl_training', + help='Output directory for checkpoints and logs') + parser.add_argument('--save_freq', type=int, default=50000, + help='Save checkpoint every N steps') + parser.add_argument('--eval_freq', type=int, default=10000, + help='Evaluate policy every N steps') + parser.add_argument('--eval_episodes', type=int, default=10, + help='Number of episodes for evaluation') + parser.add_argument('--log_freq', type=int, default=1000, + help='Log training stats every N steps') + + # Misc + parser.add_argument('-s', '--seed', type=int, default=0, + help='Random seed') + parser.add_argument('--device', type=str, default='cuda', + help='Device to use (cuda or cpu)') + parser.add_argument('--resume', type=str, default=None, + help='Path to checkpoint to resume from') + + args, unknown = parser.parse_known_args() + args._unknown = unknown + return args + + +def load_env_config(args): + """Load environment configuration from YAML.""" + from configs.loader import ConfigLoader + cfg_loader = ConfigLoader(args=args, unknown_args=getattr(args, '_unknown', [])) + env_cfg, env_cfg_path = cfg_loader.load_env(args.env) + return env_cfg, env_cfg_path + + +def load_algo_config(args): + """ + Load algorithm configuration from YAML. + + Priority: + 1. If path ends with .yaml/.yml, load directly + 2. If algorithm name (e.g., 'td3'), try to load from configs/rl/{name}.yaml + 3. If no config file found, use defaults + """ + import yaml + + algo_name_or_path = args.algorithm + + # Check if it's a path to a YAML config file + if algo_name_or_path.endswith('.yaml') or algo_name_or_path.endswith('.yml'): + # Load from specified YAML config file + if os.path.exists(algo_name_or_path): + with open(algo_name_or_path, 'r') as f: + algo_cfg = yaml.safe_load(f) or {} + algo_cfg_path = algo_name_or_path + else: + # Try configs/rl directory + config_path = os.path.join('configs/rl', algo_name_or_path) + if os.path.exists(config_path): + with open(config_path, 'r') as f: + algo_cfg = yaml.safe_load(f) or {} + algo_cfg_path = config_path + else: + raise FileNotFoundError(f"Algorithm config file not found: {algo_name_or_path}") + + logger.info(f"Loaded algorithm config from: {algo_cfg_path}") + return algo_cfg, algo_cfg_path + else: + # Try to load from configs/rl/{algo_name}.yaml + config_path = os.path.join('configs/rl', f'{algo_name_or_path}.yaml') + if os.path.exists(config_path): + with open(config_path, 'r') as f: + algo_cfg = yaml.safe_load(f) or {} + logger.info(f"Loaded algorithm config from: {config_path}") + return algo_cfg, config_path + else: + # No config file, use defaults + logger.info(f"No config file found for '{algo_name_or_path}', using default parameters") + return {'type': algo_name_or_path}, None + + +def create_env_fn(env_cfg, env_module): + """Create a factory function for environment creation.""" + def _create(): + return env_module.create_env(env_cfg) + return _create + + +def create_vector_env(env_cfg, num_envs, use_subproc=False): + """Create vectorized environment.""" + # Parse env type + env_type = env_cfg.type + if '.' in env_type: + module_path, class_name = env_type.rsplit('.', 1) + env_module = importlib.import_module(module_path) + env_name = module_path.split('.')[-1] if '.' in module_path else module_path + else: + env_module = importlib.import_module(f"benchmark.{env_type}") + env_name = env_type + + if not hasattr(env_module, 'create_env'): + raise AttributeError(f"env module {env_type} has no 'create_env'") + + env_fns = [create_env_fn(env_cfg, env_module) for _ in range(num_envs)] + + if use_subproc and num_envs > 1: + from tianshou.env import SubprocVectorEnv + vec_env = SubprocVectorEnv(env_fns) + else: + vec_env = SequentialVectorEnv(env_fns) + + return vec_env, env_name, env_module + + +def get_env_dims(env_cfg, vec_env=None): + """ + Get state and action dimensions from environment config. + + Priority: + 1. Read from env_cfg (state_dim, action_dim) - preferred + 2. Infer from vec_env if not specified in config - fallback + + Args: + env_cfg: Environment configuration (namespace or dict) + vec_env: Optional vectorized environment for fallback inference + + Returns: + (state_dim, action_dim): Tuple of dimensions + """ + state_dim = None + action_dim = None + + # 1. Try to get from env_cfg first (preferred) + if hasattr(env_cfg, 'state_dim'): + state_dim = env_cfg.state_dim + logger.info(f"Read state_dim={state_dim} from env config") + if hasattr(env_cfg, 'action_dim'): + action_dim = env_cfg.action_dim + logger.info(f"Read action_dim={action_dim} from env config") + + # 2. Fallback: infer from environment if not specified + if (state_dim is None or action_dim is None) and vec_env is not None: + logger.warning("state_dim or action_dim not specified in env config, inferring from environment...") + + if state_dim is None: + # Reset to get observation + obs = vec_env.reset() + + # Handle different observation formats + if hasattr(obs, 'state'): + state_dim = obs.state.shape[-1] if hasattr(obs.state, 'shape') else len(obs.state) + elif isinstance(obs, np.ndarray): + if obs.ndim == 1: + state_dim = obs.shape[0] + else: + # For vectorized env, obs[0] is single env obs + first_obs = obs[0] if obs.dtype == np.object_ else obs[0] + if hasattr(first_obs, 'state'): + state_dim = first_obs.state.shape[-1] + else: + state_dim = first_obs.shape[-1] if hasattr(first_obs, 'shape') else len(first_obs) + elif isinstance(obs, dict): + state_dim = obs.get('state', obs.get('observation')).shape[-1] + else: + # Try to get first env's observation + first_obs = obs[0] if hasattr(obs, '__getitem__') else obs + if hasattr(first_obs, 'state'): + state_dim = first_obs.state.shape[-1] + else: + state_dim = len(first_obs) if hasattr(first_obs, '__len__') else 1 + logger.info(f"Inferred state_dim={state_dim} from environment") + + if action_dim is None: + # Get action dimension from action_space if available + single_env = vec_env.envs[0] if hasattr(vec_env, 'envs') else vec_env + if hasattr(single_env, 'action_space'): + action_dim = single_env.action_space.shape[0] + elif hasattr(single_env, 'env') and hasattr(single_env.env, 'action_space'): + action_dim = single_env.env.action_space.shape[0] + else: + # Default fallback + logger.warning("Could not determine action_dim from env, defaulting to state_dim") + action_dim = state_dim + logger.info(f"Inferred action_dim={action_dim} from environment") + + if state_dim is None or action_dim is None: + raise ValueError("Could not determine state_dim and action_dim. Please specify them in env config.") + + return state_dim, action_dim + + +def create_meta_policy(args, algo_cfg, state_dim, action_dim): + """ + Create MetaPolicy for the algorithm. + + Similar to how load_policy() works in policy/utils.py. + """ + from benchmark.base import MetaPolicy + from data_utils.normalize import Identity + + # Get control space and type from config or defaults + ctrl_space = 'ee' + ctrl_type = 'delta' + + if algo_cfg is not None: + ctrl_space = algo_cfg.get('ctrl_space', ctrl_space) + ctrl_type = algo_cfg.get('ctrl_type', ctrl_type) + + # For RL, we typically use identity normalizers (actions are already in [-1, 1]) + action_normalizer = Identity() + state_normalizer = Identity() + + # Store in args for later use + args.ctrl_space = ctrl_space + args.ctrl_type = ctrl_type + + return { + 'action_normalizer': action_normalizer, + 'state_normalizer': state_normalizer, + 'ctrl_space': ctrl_space, + 'ctrl_type': ctrl_type, + } + + +def create_replay_buffer(args, state_dim, action_dim, num_envs): + """Create replay buffer.""" + from rl.buffer import MetaReplay + + replay = MetaReplay( + capacity=args.replay_size, + state_dim=state_dim, + action_dim=action_dim, + n_envs=num_envs, + device='cpu', # Store on CPU, move to GPU during training + ) + return replay + + +def create_algorithm(args, algo_cfg, state_dim, action_dim, replay, meta_policy_params): + """ + Create RL algorithm from config dynamically. + + Args: + args: Command line arguments + algo_cfg: Algorithm config dict (from YAML) + state_dim: State dimension + action_dim: Action dimension + replay: Replay buffer + meta_policy_params: Dict with normalizers and ctrl settings for MetaPolicy + """ + import dataclasses + from rl.algorithms import get_algorithm_class, get_config_class, list_algorithms + + logger.info(f"Available algorithms: {list_algorithms()}") + + # Determine algorithm type from config + algo_args = algo_cfg if algo_cfg is not None else {} + algo_type = algo_args.get('type', algo_args.get('algorithm', args.algorithm)) + + # Get algorithm class and config class dynamically + AlgorithmClass = get_algorithm_class(algo_type) + ConfigClass = get_config_class(algo_type) + logger.info(f"Using algorithm: {AlgorithmClass.__name__}") + + # Build config if ConfigClass is available + config = None + if ConfigClass is not None: + # Required parameters + config_params = { + 'state_dim': state_dim, + 'action_dim': action_dim, + 'device': args.device, + } + + # Get config field names and add matching parameters from algo_cfg + if dataclasses.is_dataclass(ConfigClass): + config_fields = {f.name for f in dataclasses.fields(ConfigClass) if f.name not in config_params} + for key in config_fields: + if key in algo_args: + config_params[key] = algo_args[key] + + config = ConfigClass(**config_params) + + # Log config (dynamically get attributes) + log_attrs = ['discount', 'tau', 'actor_lr', 'critic_lr', 'lr'] + log_parts = [f"{attr}={getattr(config, attr)}" for attr in log_attrs if hasattr(config, attr)] + if log_parts: + logger.info(f"{AlgorithmClass.__name__} config: {', '.join(log_parts)}") + + # Get exploration noise from algo_cfg (overrides command line) + if 'expl_noise' in algo_args: + args.expl_noise = algo_args['expl_noise'] + logger.info(f"Using expl_noise={args.expl_noise} from algo config") + + # Create algorithm + if config is not None: + algorithm = AlgorithmClass( + replay=replay, + config=config, + meta_policy=None, + ctrl_space=meta_policy_params['ctrl_space'], + ctrl_type=meta_policy_params['ctrl_type'], + ) + else: + # Fallback for algorithms without config class + excluded_keys = {'type', 'algorithm', 'ctrl_space', 'ctrl_type', 'expl_noise'} + filtered_args = {k: v for k, v in algo_args.items() if k not in excluded_keys} + + algorithm = AlgorithmClass( + replay=replay, + state_dim=state_dim, + action_dim=action_dim, + device=args.device, + **filtered_args + ) + + return algorithm + + +def extract_state(obs, state_key='state'): + """Extract state array from observation.""" + if hasattr(obs, state_key): + return getattr(obs, state_key) + elif isinstance(obs, dict): + return obs.get(state_key, obs.get('observation')) + elif isinstance(obs, np.ndarray): + if obs.dtype == np.object_: + # Array of MetaObs + return np.stack([getattr(o, state_key) if hasattr(o, state_key) else o for o in obs]) + return obs + return obs + + +def train(args): + """Main training loop using OffPolicyTrainer and DummyCollector.""" + # Set seed + set_seed(args.seed) + logger.info(f"Set random seed to {args.seed}") + + # Load environment config + env_cfg, env_cfg_path = load_env_config(args) + logger.info(f"Loaded env config from: {env_cfg_path}") + + # Load algorithm config (similar to how eval_sim.py loads policy) + algo_cfg, algo_cfg_path = load_algo_config(args) + + # Create vectorized environment + vec_env, env_name, env_module = create_vector_env( + env_cfg, args.num_envs, args.use_subproc + ) + logger.info(f"Created {args.num_envs} parallel environments: {env_name}") + + # Sync derived values from env config (like eval_sim.py does) + if hasattr(env_cfg, 'max_timesteps'): + args.max_timesteps = env_cfg.max_timesteps + else: + args.max_timesteps = 1000 # Default + + if hasattr(env_cfg, 'task'): + args.task = env_cfg.task + + # Get environment dimensions from config (preferred) or infer from env (fallback) + state_dim, action_dim = get_env_dims(env_cfg, vec_env) + logger.info(f"State dim: {state_dim}, Action dim: {action_dim}") + + # Create MetaPolicy parameters (normalizers, ctrl settings) + meta_policy_params = create_meta_policy(args, algo_cfg, state_dim, action_dim) + logger.info(f"Control space: {args.ctrl_space}, Control type: {args.ctrl_type}") + + # Create replay buffer + replay = create_replay_buffer(args, state_dim, action_dim, args.num_envs) + logger.info(f"Created replay buffer with capacity {args.replay_size}") + + # Create algorithm (with MetaPolicy) + algorithm = create_algorithm(args, algo_cfg, state_dim, action_dim, replay, meta_policy_params) + algorithm.set_env(vec_env) # Bind environment for action processing + logger.info(f"Created algorithm: {algorithm}") + + # Create Collector for environment interaction + from rl.collectors import DummyCollector + collector = DummyCollector( + envs=vec_env, + algorithm=algorithm, + ctrl_space=args.ctrl_space, + action_dim=action_dim, + ) + + # Create Trainer configuration + from rl.trainers.offpolicy_trainer import OffPolicyTrainer, OffPolicyTrainerConfig + trainer_config = OffPolicyTrainerConfig( + total_steps=args.total_steps, + start_steps=args.start_steps, + update_after=args.update_after, + update_every=args.update_every, + batch_size=args.batch_size, + log_freq=args.log_freq, + eval_freq=args.eval_freq, + eval_episodes=args.eval_episodes, + save_freq=args.save_freq, + output_dir=args.output_dir, + max_timesteps=args.max_timesteps, + ctrl_space=args.ctrl_space, + expl_noise=args.expl_noise, + ) + + # Create evaluation environment factory + def eval_env_fn(): + return create_env_fn(env_cfg, env_module)() + + # Select vector environment class based on args + if args.use_subproc: + from benchmark.utils import SubprocVectorEnv + eval_vec_env_cls = SubprocVectorEnv + else: + eval_vec_env_cls = SequentialVectorEnv + + # Create Trainer + trainer = OffPolicyTrainer( + algorithm=algorithm, + collector=collector, + config=trainer_config, + eval_env_fn=eval_env_fn, + eval_vec_env_cls=eval_vec_env_cls, + ) + + # Resume from checkpoint if specified + start_step = 0 + if args.resume: + start_step = trainer.load_checkpoint(args.resume) + + # Run training + trainer.train(resume_step=start_step) + + # Cleanup + vec_env.close() + + +if __name__ == '__main__': + args = parse_args() + train(args) + From b3abfae12bf6b930ad4e77efaa1ecaeaf886c08f Mon Sep 17 00:00:00 2001 From: yesen-chen <840419490@qq.com> Date: Wed, 4 Feb 2026 03:02:04 +0800 Subject: [PATCH 5/6] [feat] Add gymnasium benchmark and fix bug in action_utils.py --- benchmark/gymnasium/__init__.py | 15 + configs/env/gymnasium/ant.yaml | 26 ++ configs/env/gymnasium/halfcheetah.yaml | 26 ++ configs/env/gymnasium/hopper.yaml | 26 ++ configs/env/gymnasium/walker2d.yaml | 26 ++ rl/algorithms/td3/td3.py | 4 +- rl/utils/action_utils.py | 18 +- utils/exploration.py | 415 +++++++++++++++++++++++++ 8 files changed, 545 insertions(+), 11 deletions(-) create mode 100644 configs/env/gymnasium/ant.yaml create mode 100644 configs/env/gymnasium/halfcheetah.yaml create mode 100644 configs/env/gymnasium/hopper.yaml create mode 100644 configs/env/gymnasium/walker2d.yaml create mode 100644 utils/exploration.py diff --git a/benchmark/gymnasium/__init__.py b/benchmark/gymnasium/__init__.py index e4b9b26b..9b8df58d 100644 --- a/benchmark/gymnasium/__init__.py +++ b/benchmark/gymnasium/__init__.py @@ -16,6 +16,8 @@ import numpy as np import gymnasium as gym from ..base import MetaEnv, MetaObs, MetaAction +import torch + def create_env(config): @@ -144,6 +146,19 @@ def render(self): def close(self): """Close the environment.""" self.env.close() + + def ensure_action_reasonable(self, action): + """Ensure action is reasonable.""" + if self.action_low is not None or self.action_high is not None: + # Convert Tensor to numpy array if needed + if torch is not None and torch.is_tensor(action): + action = action.detach().cpu().numpy() + action = np.asarray(action) + + if np.any(action <= self.action_high) and np.any(action >= self.action_low): + return True + + return False # Common MuJoCo environment configurations diff --git a/configs/env/gymnasium/ant.yaml b/configs/env/gymnasium/ant.yaml new file mode 100644 index 00000000..3a48e810 --- /dev/null +++ b/configs/env/gymnasium/ant.yaml @@ -0,0 +1,26 @@ +# Ant-v5 environment configuration +# MuJoCo locomotion task: make a 3D ant robot run forward + +type: benchmark.gymnasium.GymnasiumEnv +name: ant + +# Environment settings +task: Ant-v5 +render_mode: null # Set to 'rgb_array' for video recording, 'human' for visualization +use_camera: false +max_episode_steps: 1000 # Default for Ant + +# Dimensions (explicit specification) +state_dim: 27 # qpos: 13, qvel: 14, excluding x,y position +action_dim: 8 # torques for 8 joints + +# Control settings +ctrl_space: joint +ctrl_type: abs + +# Task description +raw_lang: "Make the ant walk forward as fast as possible" + +# Notes: +# - Action range: [-1, 1] (normalized) +# - Reward: forward velocity - 0.5 * control_cost + healthy_reward diff --git a/configs/env/gymnasium/halfcheetah.yaml b/configs/env/gymnasium/halfcheetah.yaml new file mode 100644 index 00000000..8a600791 --- /dev/null +++ b/configs/env/gymnasium/halfcheetah.yaml @@ -0,0 +1,26 @@ +# HalfCheetah-v5 environment configuration +# MuJoCo locomotion task: make a 2D cheetah robot run as fast as possible + +type: benchmark.gymnasium.GymnasiumEnv +name: halfcheetah + +# Environment settings +task: HalfCheetah-v5 +render_mode: null # Set to 'rgb_array' for video recording, 'human' for visualization +use_camera: false +max_episode_steps: 1000 # Default for HalfCheetah + +# Dimensions (explicit specification) +state_dim: 17 # qpos: 9, qvel: 8, excluding x-position +action_dim: 6 # torques for 6 joints + +# Control settings +ctrl_space: joint +ctrl_type: abs + +# Task description +raw_lang: "Make the cheetah run as fast as possible" + +# Notes: +# - Action range: [-1, 1] (normalized) +# - Reward: forward velocity - 0.1 * control_cost diff --git a/configs/env/gymnasium/hopper.yaml b/configs/env/gymnasium/hopper.yaml new file mode 100644 index 00000000..7cdd06ee --- /dev/null +++ b/configs/env/gymnasium/hopper.yaml @@ -0,0 +1,26 @@ +# Hopper-v5 environment configuration +# MuJoCo locomotion task: make a 2D one-legged robot hop forward + +type: benchmark.gymnasium.GymnasiumEnv +name: hopper + +# Environment settings +task: Hopper-v5 +render_mode: null # Set to 'rgb_array' for video recording, 'human' for visualization +use_camera: false +max_episode_steps: 1000 # Default for Hopper + +# Dimensions (explicit specification) +state_dim: 11 # qpos: 5, qvel: 6, excluding x-position +action_dim: 3 # torques for 3 joints + +# Control settings +ctrl_space: joint +ctrl_type: abs + +# Task description +raw_lang: "Make the hopper hop forward as fast as possible" + +# Notes: +# - Action range: [-1, 1] (normalized) +# - Reward: forward velocity - 0.001 * control_cost + healthy_reward diff --git a/configs/env/gymnasium/walker2d.yaml b/configs/env/gymnasium/walker2d.yaml new file mode 100644 index 00000000..e4708dd3 --- /dev/null +++ b/configs/env/gymnasium/walker2d.yaml @@ -0,0 +1,26 @@ +# Walker2d-v5 environment configuration +# MuJoCo locomotion task: make a 2D bipedal robot walk forward + +type: benchmark.gymnasium.GymnasiumEnv +name: walker2d + +# Environment settings +task: Walker2d-v5 +render_mode: null # Set to 'rgb_array' for video recording, 'human' for visualization +use_camera: false +max_episode_steps: 1000 # Default for Walker2d + +# Dimensions (explicit specification) +state_dim: 17 # qpos: 8, qvel: 9, excluding x-position +action_dim: 6 # torques for 6 joints + +# Control settings +ctrl_space: joint +ctrl_type: abs + +# Task description +raw_lang: "Make the walker walk forward as fast as possible" + +# Notes: +# - Action range: [-1, 1] (normalized) +# - Reward: forward velocity - 0.001 * control_cost + healthy_reward diff --git a/rl/algorithms/td3/td3.py b/rl/algorithms/td3/td3.py index 6f10df02..9570e9fe 100644 --- a/rl/algorithms/td3/td3.py +++ b/rl/algorithms/td3/td3.py @@ -13,7 +13,7 @@ import torch import torch.nn as nn import torch.nn.functional as F - +import rl.utils.action_utils as action_utils from benchmark.base import MetaAction, MetaObs, MetaPolicy from policy.mlp.mlp import MLPPolicy, MLPPolicyConfig from rl.utils import polyak_update @@ -93,7 +93,7 @@ def __init__( config: TD3Config, actor_config: Optional[MLPPolicyConfig] = None, meta_policy: Optional[MetaPolicy] = None, - ensure_refine_fn=None, + ensure_refine_fn= action_utils.tanh_action_to_space, ctrl_space: str = "ee", ctrl_type: str = "delta", gripper_continuous: bool = False, diff --git a/rl/utils/action_utils.py b/rl/utils/action_utils.py index 7e47d92e..c9fe21af 100644 --- a/rl/utils/action_utils.py +++ b/rl/utils/action_utils.py @@ -86,24 +86,24 @@ def ensure_action( """ Ensure actions are valid for the environment. - - Applies tanh to bound outputs to [-1, 1] (as in TD3 policies). - - Applies an optional refine function for env-specific processing. + - Always applies the refine function (e.g., tanh_action_to_space) to map + network outputs to valid action space. - Clips to env action space bounds if available. Args: env: Environment (or wrapper) providing action_space. action: Tensor or numpy array action. refine_fn: Optional callable (env, action) -> action for custom refinement. - apply_tanh: Whether to apply tanh to the action. """ try: reasonable = env.envs[0].ensure_action_reasonable(action) except ValueError: logger.warning(f"Environment {env.__class__.__name__} does not support ensure_action_reasonable, using tanh_action_to_space") - if reasonable is False: - if refine_fn is not None: - action = refine_fn(env, action) - else: - raise ValueError(f"Action {action} is not reasonable") - + # Always apply refine_fn to map network outputs to action space + if refine_fn is not None: + action = refine_fn(env, action) + + # Final clip to ensure bounds (safety measure) + action = clip_action_to_space(env, action) + return action \ No newline at end of file diff --git a/utils/exploration.py b/utils/exploration.py new file mode 100644 index 00000000..11a7d612 --- /dev/null +++ b/utils/exploration.py @@ -0,0 +1,415 @@ +""" +Exploration Strategies for RL + +This module defines exploration strategies for action selection during training. +Supports: +- Random exploration (for initial exploration phase) +- Gaussian noise +- Ornstein-Uhlenbeck noise +- Custom noise functions (e.g., uncertainty-based) + +Usage: + # Gaussian noise with initial random exploration + exploration = ExplorationScheduler( + exploration_strategy=GaussianNoise(sigma=0.1), + random_steps=10000, + action_low=np.array([-1.0] * 7), + action_high=np.array([1.0] * 7), + ) + + # Apply exploration to action + explored_action = exploration(action, step=current_step) +""" + +import numpy as np +from abc import ABC, abstractmethod +from typing import Optional, Callable, Any + + +class BaseExploration(ABC): + """ + Base class for exploration strategies. + + Exploration strategies modify actions to encourage exploration during training. + """ + + @abstractmethod + def __call__( + self, + action: np.ndarray, + step: int = 0, + **kwargs + ) -> np.ndarray: + """ + Apply exploration to action. + + Args: + action: Original action from policy, shape (action_dim,) or (n_envs, action_dim) + step: Current training step (for annealing) + **kwargs: Additional arguments (e.g., obs for uncertainty-based exploration) + + Returns: + Explored action with same shape as input + """ + raise NotImplementedError + + def reset(self, env_idx: Optional[int] = None) -> None: + """Reset exploration state (e.g., for OU noise).""" + pass + + +class NoExploration(BaseExploration): + """No exploration - return action as-is.""" + + def __call__(self, action: np.ndarray, step: int = 0, **kwargs) -> np.ndarray: + return action + + +class GaussianNoise(BaseExploration): + """ + Gaussian noise exploration. + + Adds zero-mean Gaussian noise to actions, with optional annealing. + Used in TD3, SAC, etc. + + Example: + noise = GaussianNoise(sigma=0.1, sigma_min=0.01, decay_steps=100000) + noisy_action = noise(action, step=current_step) + """ + + def __init__( + self, + sigma: float = 0.1, + sigma_min: float = 0.01, + sigma_decay: float = 1.0, # No decay by default + decay_steps: int = 0, # 0 means no decay + ): + """ + Args: + sigma: Initial noise standard deviation + sigma_min: Minimum noise standard deviation (for annealing) + sigma_decay: Decay factor per step (exponential decay, < 1.0 to enable) + decay_steps: Total steps for linear decay (> 0 to enable, takes priority over sigma_decay) + """ + self.sigma_init = sigma + self.sigma = sigma + self.sigma_min = sigma_min + self.sigma_decay = sigma_decay + self.decay_steps = decay_steps + + def __call__(self, action: np.ndarray, step: int = 0, **kwargs) -> np.ndarray: + # Anneal sigma + if self.decay_steps > 0: + # Linear decay + decay_ratio = max(0.0, 1.0 - step / self.decay_steps) + self.sigma = self.sigma_min + (self.sigma_init - self.sigma_min) * decay_ratio + elif self.sigma_decay < 1.0: + # Exponential decay + self.sigma = max(self.sigma_min, self.sigma_init * (self.sigma_decay ** step)) + + # Add Gaussian noise + noise = np.random.normal(0, self.sigma, size=action.shape) + noisy_action = action + noise + + return noisy_action.astype(action.dtype) + + def reset(self, env_idx: Optional[int] = None) -> None: + pass # Gaussian noise is stateless + + def __repr__(self) -> str: + return (f"GaussianNoise(sigma={self.sigma:.4f}, sigma_init={self.sigma_init}, " + f"sigma_min={self.sigma_min}, decay_steps={self.decay_steps})") + + +class OUNoise(BaseExploration): + """ + Ornstein-Uhlenbeck noise exploration. + + Temporally correlated noise, often used in continuous control tasks. + Used in DDPG, etc. + + The OU process: dx = theta * (mu - x) * dt + sigma * dW + + Example: + noise = OUNoise(action_dim=7, n_envs=4, sigma=0.2, theta=0.15) + noisy_action = noise(action, step=current_step) + """ + + def __init__( + self, + action_dim: int, + n_envs: int = 1, + mu: float = 0.0, + theta: float = 0.15, + sigma: float = 0.2, + sigma_min: float = 0.01, + sigma_decay: float = 1.0, + decay_steps: int = 0, + ): + """ + Args: + action_dim: Dimension of action space + n_envs: Number of parallel environments + mu: Mean of the noise (typically 0) + theta: Rate of mean reversion (how fast noise returns to mu) + sigma: Volatility of the noise + sigma_min: Minimum sigma for annealing + sigma_decay: Exponential decay factor + decay_steps: Steps for linear decay + """ + self.action_dim = action_dim + self.n_envs = n_envs + self.mu = mu + self.theta = theta + self.sigma_init = sigma + self.sigma = sigma + self.sigma_min = sigma_min + self.sigma_decay = sigma_decay + self.decay_steps = decay_steps + + # Initialize state for each environment + self._state = np.ones((n_envs, action_dim)) * mu + + def __call__(self, action: np.ndarray, step: int = 0, **kwargs) -> np.ndarray: + # Anneal sigma + if self.decay_steps > 0: + decay_ratio = max(0.0, 1.0 - step / self.decay_steps) + self.sigma = self.sigma_min + (self.sigma_init - self.sigma_min) * decay_ratio + elif self.sigma_decay < 1.0: + self.sigma = max(self.sigma_min, self.sigma_init * (self.sigma_decay ** step)) + + # Handle both single and batched actions + is_batched = action.ndim == 2 + if not is_batched: + action = action[np.newaxis, :] # (1, action_dim) + + batch_size = action.shape[0] + + # Update OU state: dx = theta * (mu - x) + sigma * noise + dx = (self.theta * (self.mu - self._state[:batch_size]) + + self.sigma * np.random.randn(batch_size, self.action_dim)) + self._state[:batch_size] += dx + + # Add noise to action + noisy_action = action + self._state[:batch_size] + + if not is_batched: + noisy_action = noisy_action[0] + + return noisy_action.astype(action.dtype) + + def reset(self, env_idx: Optional[int] = None) -> None: + """Reset OU state.""" + if env_idx is None: + self._state = np.ones((self.n_envs, self.action_dim)) * self.mu + else: + self._state[env_idx] = self.mu + + def __repr__(self) -> str: + return (f"OUNoise(action_dim={self.action_dim}, n_envs={self.n_envs}, " + f"theta={self.theta}, sigma={self.sigma:.4f})") + + +class CustomNoise(BaseExploration): + """ + Custom noise exploration using a user-provided function. + + Allows implementing advanced exploration strategies like: + - Uncertainty-based exploration (using model epistemic uncertainty) + - Curiosity-driven exploration + - Parameter noise + - Any other custom exploration logic + + Example: + def uncertainty_noise_fn(action, step, obs=None, model=None, **kwargs): + if model is not None and obs is not None: + uncertainty = model.get_uncertainty(obs) + noise_scale = uncertainty * 0.5 + else: + noise_scale = 0.1 + noise = np.random.normal(0, noise_scale, size=action.shape) + return action + noise + + noise = CustomNoise(noise_fn=uncertainty_noise_fn) + noisy_action = noise(action, step=step, obs=obs, model=model) + """ + + def __init__( + self, + noise_fn: Callable[[np.ndarray, int, Any], np.ndarray], + reset_fn: Optional[Callable[[Optional[int]], None]] = None, + ): + """ + Args: + noise_fn: Custom noise function with signature: + noise_fn(action, step, **kwargs) -> noisy_action + - action: Original action from policy + - step: Current training step + - **kwargs: Additional info (obs, model uncertainty, etc.) + reset_fn: Optional reset function for stateful noise + """ + self.noise_fn = noise_fn + self.reset_fn = reset_fn + + def __call__(self, action: np.ndarray, step: int = 0, **kwargs) -> np.ndarray: + noisy_action = self.noise_fn(action, step, **kwargs) + return noisy_action.astype(action.dtype) + + def reset(self, env_idx: Optional[int] = None) -> None: + if self.reset_fn is not None: + self.reset_fn(env_idx) + + def __repr__(self) -> str: + return f"CustomNoise(noise_fn={self.noise_fn.__name__})" + + +class RandomExploration(BaseExploration): + """ + Random action exploration. + + Samples random actions from action space, completely ignoring policy output. + Used for initial exploration phase. + + Example: + random_explore = RandomExploration( + action_low=np.array([-1.0] * 7), + action_high=np.array([1.0] * 7), + ) + random_action = random_explore(policy_action) # policy_action is ignored + """ + + def __init__( + self, + action_low: np.ndarray, + action_high: np.ndarray, + ): + """ + Args: + action_low: Lower bound of action space + action_high: Upper bound of action space + """ + self.action_low = np.asarray(action_low) + self.action_high = np.asarray(action_high) + + def __call__(self, action: np.ndarray, step: int = 0, **kwargs) -> np.ndarray: + # Completely ignore input action, sample random action + random_action = np.random.uniform( + self.action_low, + self.action_high, + size=action.shape + ) + return random_action.astype(action.dtype) + + def __repr__(self) -> str: + return f"RandomExploration(action_low={self.action_low}, action_high={self.action_high})" + + +class ExplorationScheduler: + """ + Scheduler for switching between exploration strategies. + + Supports: + - Initial random exploration phase (before a certain step) + - Transition to noise-based exploration + - Annealing exploration over time + + Example: + # Create scheduler with 10000 random steps, then Gaussian noise + exploration = ExplorationScheduler( + exploration_strategy=GaussianNoise(sigma=0.1, decay_steps=100000), + random_steps=10000, + action_low=np.array([-1.0] * 7), + action_high=np.array([1.0] * 7), + ) + + # In training loop: + for step in range(total_steps): + action = policy(obs) + explored_action = exploration(action, step=step) + # or use internal counter: + explored_action = exploration(action) + """ + + def __init__( + self, + exploration_strategy: BaseExploration, + random_steps: int = 0, + action_low: Optional[np.ndarray] = None, + action_high: Optional[np.ndarray] = None, + ): + """ + Args: + exploration_strategy: Main exploration strategy (e.g., GaussianNoise, OUNoise) + random_steps: Number of initial steps to use pure random exploration + - Set to 0 to disable random exploration phase + action_low: Lower bound for random exploration (required if random_steps > 0) + action_high: Upper bound for random exploration (required if random_steps > 0) + """ + self.exploration_strategy = exploration_strategy + self.random_steps = random_steps + self.action_low = action_low + self.action_high = action_high + + # Create random exploration for initial phase + if random_steps > 0: + if action_low is None or action_high is None: + raise ValueError("action_low and action_high required when random_steps > 0") + self.random_exploration = RandomExploration(action_low, action_high) + else: + self.random_exploration = None + + self._current_step = 0 + + def __call__( + self, + action: np.ndarray, + step: Optional[int] = None, + **kwargs + ) -> np.ndarray: + """ + Apply exploration based on current step. + + Args: + action: Original action from policy + step: Optional step override (if None, uses internal counter) + **kwargs: Additional arguments for exploration strategy + + Returns: + Explored action + """ + if step is None: + step = self._current_step + self._current_step += 1 + + # Initial random exploration phase + if step < self.random_steps and self.random_exploration is not None: + return self.random_exploration(action, step, **kwargs) + + # Main exploration strategy (pass adjusted step for proper annealing) + return self.exploration_strategy(action, step - self.random_steps, **kwargs) + + def reset(self, env_idx: Optional[int] = None) -> None: + """Reset exploration state.""" + self.exploration_strategy.reset(env_idx) + if self.random_exploration is not None: + self.random_exploration.reset(env_idx) + + def reset_step_counter(self) -> None: + """Reset internal step counter.""" + self._current_step = 0 + + @property + def is_random_phase(self) -> bool: + """Check if still in random exploration phase.""" + return self._current_step < self.random_steps + + @property + def current_step(self) -> int: + """Get current step count.""" + return self._current_step + + def __repr__(self) -> str: + return (f"ExplorationScheduler(strategy={self.exploration_strategy}, " + f"random_steps={self.random_steps}, current_step={self._current_step})") + + From 5e5464ced5615789a68db228d452ff1916d9243c Mon Sep 17 00:00:00 2001 From: yesen-chen <840419490@qq.com> Date: Wed, 4 Feb 2026 03:06:52 +0800 Subject: [PATCH 6/6] [feat] Add algo td3.yaml --- configs/rl/td3.yaml | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 configs/rl/td3.yaml diff --git a/configs/rl/td3.yaml b/configs/rl/td3.yaml new file mode 100644 index 00000000..1d365c24 --- /dev/null +++ b/configs/rl/td3.yaml @@ -0,0 +1,32 @@ +# TD3 Algorithm Configuration +# Twin Delayed Deep Deterministic Policy Gradient + +type: td3 + +# Network architecture +hidden_dims: [256, 256] + +# Discount factor (gamma) +discount: 0.99 + +# Soft update coefficient for target networks +tau: 0.005 + +# Learning rates +actor_lr: 0.0003 +critic_lr: 0.0003 + +# Target policy smoothing +policy_noise: 0.2 # Noise added to target policy during critic update +noise_clip: 0.5 # Range to clip target policy noise + +# Delayed policy updates +policy_freq: 2 # Update policy every N critic updates + +# Exploration noise +expl_noise: 0.1 # Gaussian noise scale for exploration + +# Control settings +ctrl_space: joint # Control space: 'ee' (end-effector) or 'joint' +ctrl_type: absolute # Control type: 'delta' or 'absolute' +