diff --git a/scripts/reinforcement_learning/rsl_rl/play.py b/scripts/reinforcement_learning/rsl_rl/play.py
index 19b5689..459e09b 100644
--- a/scripts/reinforcement_learning/rsl_rl/play.py
+++ b/scripts/reinforcement_learning/rsl_rl/play.py
@@ -9,6 +9,7 @@
 import argparse
 import sys
+from collections.abc import Mapping
 
 from isaaclab.app import AppLauncher
 
@@ -34,6 +35,30 @@
     help="Use the pre-trained checkpoint from Nucleus.",
 )
 parser.add_argument("--real-time", action="store_true", default=False, help="Run in real-time, if possible.")
+parser.add_argument(
+    "--collect_dataset",
+    action="store_true",
+    default=False,
+    help="Collect state-action pairs during inference and save successful episodes for supervised learning.",
+)
+parser.add_argument(
+    "--dataset_output",
+    type=str,
+    default=None,
+    help="Output .pt path for the collected dataset. Defaults to <log_dir>/supervised_dataset.pt.",
+)
+parser.add_argument(
+    "--num_successful_episodes",
+    type=int,
+    default=100,
+    help="Stop data collection after this many successful episodes.",
+)
+parser.add_argument(
+    "--success_reward_threshold",
+    type=float,
+    default=0.1,
+    help="An episode is successful if its final reward is above this threshold.",
+)
 # append RSL-RL cli arguments
 cli_args.add_rsl_rl_args(parser)
 # append AppLauncher cli args
@@ -57,6 +82,7 @@
 import os
 import time
 import torch
+from tqdm import tqdm
 
 from rsl_rl.runners import DistillationRunner, OnPolicyRunner
 
@@ -81,6 +107,35 @@
 # PLACEHOLDER: Extension template (do not remove this comment)
 
 
+def _flatten_observation_value(value) -> torch.Tensor:
+    """Flatten a tensor-like observation structure into `[num_envs, flat_dim]`."""
+    if isinstance(value, torch.Tensor):
+        return value.reshape(value.shape[0], -1)
+
+    if isinstance(value, Mapping) or (hasattr(value, "keys") and hasattr(value, "__getitem__")):
+        flat_chunks = [_flatten_observation_value(value[key]) for key in value.keys()]
+        if not flat_chunks:
+            raise ValueError("Observation mapping is empty and cannot be flattened.")
+        return torch.cat(flat_chunks, dim=-1)
+
+    tensor_value = torch.as_tensor(value)
+    return tensor_value.reshape(tensor_value.shape[0], -1)
+
+
+def _extract_policy_state(obs) -> torch.Tensor:
+    """Extract and flatten policy observations into `[num_envs, state_dim]`."""
+    if isinstance(obs, Mapping) or (hasattr(obs, "keys") and hasattr(obs, "__getitem__")):
+        if "policy" in obs:
+            state = obs["policy"]
+        else:
+            # Fall back to the first observation entry if the policy group is unavailable.
+            first_key = next(iter(obs))
+            state = obs[first_key]
+    else:
+        state = obs
+    return _flatten_observation_value(state)
+
+
 @hydra_task_config(args_cli.task, args_cli.agent)
 def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: RslRlBaseRunnerCfg):
     """Play with RSL-RL agent."""
@@ -89,7 +144,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
     train_task_name = task_name.replace("-Play", "")
 
     # override configurations with non-hydra CLI arguments
-    agent_cfg: RslRlBaseRunnerCfg = cli_args.update_rsl_rl_cfg(agent_cfg, args_cli)
+    agent_cfg = cli_args.update_rsl_rl_cfg(agent_cfg, args_cli)
     env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs
 
     # set the environment seed
@@ -177,18 +232,97 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
 
     # reset environment
     obs = env.get_observations()
+    num_envs = env.num_envs
+    state_dim = _extract_policy_state(obs).shape[1]
+    action_dim = env.num_actions
+
+    successful_episodes = 0
+    total_finished_episodes = 0
+    data_pbar = None
+    collected_state_chunks: list[torch.Tensor] = []
+    collected_action_chunks: list[torch.Tensor] = []
+    collected_reward_chunks: list[torch.Tensor] = []
+    env_ids_device: torch.Tensor | None = None
+    episode_lengths: torch.Tensor | None = None
+    episode_states: torch.Tensor | None = None
+    episode_actions: torch.Tensor | None = None
+    episode_rewards: torch.Tensor | None = None
+    if args_cli.collect_dataset:
+        max_episode_steps = getattr(env.unwrapped, "max_episode_length", None)
+        if max_episode_steps is None:
+            raise RuntimeError("Dataset collection requires `env.unwrapped.max_episode_length` to be available.")
+        max_episode_steps = int(max_episode_steps)
+        collection_device = env.unwrapped.device
+        env_ids_device = torch.arange(num_envs, device=collection_device, dtype=torch.long)
+        episode_lengths = torch.zeros(num_envs, device=collection_device, dtype=torch.long)
+        episode_states = torch.empty((num_envs, max_episode_steps, state_dim), device=collection_device, dtype=torch.float32)
+        episode_actions = torch.empty((num_envs, max_episode_steps, action_dim), device=collection_device, dtype=torch.float32)
+        episode_rewards = torch.empty((num_envs, max_episode_steps), device=collection_device, dtype=torch.float32)
+        data_pbar = tqdm(total=args_cli.num_successful_episodes, desc="Collecting successful episodes", unit="episode")
     timestep = 0
     # simulate environment
     while simulation_app.is_running():
         start_time = time.time()
         # run everything in inference mode
         with torch.inference_mode():
+            state_t = _extract_policy_state(obs)
             # agent stepping
             actions = policy(obs)
             # env stepping
-            obs, _, dones, _ = env.step(actions)
+            obs, rewards, dones, _ = env.step(actions)
             # reset recurrent states for episodes that have terminated
             policy_nn.reset(dones)
+
+            if args_cli.collect_dataset:
+                assert env_ids_device is not None
+                assert episode_lengths is not None
+                assert episode_states is not None
+                assert episode_actions is not None
+                assert episode_rewards is not None
+                flat_actions = actions.reshape(actions.shape[0], -1)
+                rewards_device = rewards.detach()
+                dones_device = dones.detach()
+                states_device = state_t.detach()
+                actions_device = flat_actions.detach()
+                step_ids = episode_lengths.clone()
+                episode_states[env_ids_device, step_ids] = states_device
+                episode_actions[env_ids_device, step_ids] = actions_device
+                episode_rewards[env_ids_device, step_ids] = rewards_device
+                episode_lengths += 1
+
+                done_ids = torch.nonzero(dones_device, as_tuple=False).squeeze(-1)
+                if done_ids.numel() > 0:
+                    total_finished_episodes += int(done_ids.numel())
+
+                    successful_done_mask = rewards_device[done_ids] > args_cli.success_reward_threshold
+                    successful_done_ids = done_ids[successful_done_mask]
+                    remaining_successes = args_cli.num_successful_episodes - successful_episodes
+
+                    if remaining_successes > 0 and successful_done_ids.numel() > 0:
+                        successful_done_ids = successful_done_ids[:remaining_successes]
+                        successful_lengths = episode_lengths[successful_done_ids]
+                        successful_episodes += int(successful_done_ids.numel())
+                        if data_pbar is not None:
+                            data_pbar.update(int(successful_done_ids.numel()))
+
+                        max_success_len = int(successful_lengths.max().item())
+                        valid_steps = (
+                            torch.arange(max_success_len, device=successful_lengths.device, dtype=torch.long).unsqueeze(0)
+                            < successful_lengths.unsqueeze(1)
+                        )
+                        collected_state_chunks.append(episode_states[successful_done_ids, :max_success_len][valid_steps].cpu())
+                        collected_action_chunks.append(episode_actions[successful_done_ids, :max_success_len][valid_steps].cpu())
+                        collected_reward_chunks.append(episode_rewards[successful_done_ids, :max_success_len][valid_steps].cpu())
+
+                    episode_lengths[done_ids] = 0
+
+                    if successful_episodes >= args_cli.num_successful_episodes:
+                        print(
+                            f"[INFO] Collected {successful_episodes} successful episodes "
+                            f"(threshold={args_cli.success_reward_threshold})."
+                        )
+                        break
 
         if args_cli.video:
             timestep += 1
             # Exit the play loop after recording one video
@@ -202,6 +336,47 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
 
     # close the simulator
     env.close()
 
+    if data_pbar is not None:
+        data_pbar.close()
+
+    if args_cli.collect_dataset:
+        if collected_state_chunks:
+            collected_states = torch.cat(collected_state_chunks, dim=0)
+            collected_actions = torch.cat(collected_action_chunks, dim=0)
+            collected_rewards = torch.cat(collected_reward_chunks, dim=0)
+        else:
+            collected_states = torch.empty((0, state_dim), dtype=torch.float32)
+            collected_actions = torch.empty((0, action_dim), dtype=torch.float32)
+            collected_rewards = torch.empty((0,), dtype=torch.float32)
+
+        dataset_path = args_cli.dataset_output
+        if dataset_path is None:
+            dataset_path = os.path.join(log_dir, "supervised_dataset.pt")
+        dataset_path = os.path.abspath(dataset_path)
+        os.makedirs(os.path.dirname(dataset_path), exist_ok=True)
+
+        payload = {
+            "states": collected_states.float(),
+            "actions": collected_actions.float(),
+            "rewards": collected_rewards.float(),
+            "meta": {
+                "task": args_cli.task,
+                "checkpoint": resume_path,
+                "success_reward_threshold": args_cli.success_reward_threshold,
+                "target_successful_episodes": args_cli.num_successful_episodes,
+                "successful_episodes_collected": successful_episodes,
+                "finished_episodes_seen": total_finished_episodes,
+                "num_samples": int(collected_states.shape[0]),
+            },
+        }
+        torch.save(payload, dataset_path)
+        print(f"[INFO] Saved supervised dataset to: {dataset_path}")
+        print(
+            "[INFO] Dataset stats: "
+            f"samples={payload['meta']['num_samples']}, "
+            f"success_eps={successful_episodes}, "
+            f"finished_eps={total_finished_episodes}"
+        )
 
 
 if __name__ == "__main__":
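
For downstream use, here is a minimal sketch (not part of the diff) of consuming the saved payload for behavior cloning. The payload keys match the `torch.save` call above; the file path and training loop are placeholders:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Load the payload written by `play.py --collect_dataset`.
payload = torch.load("logs/supervised_dataset.pt")  # hypothetical path
states, actions = payload["states"], payload["actions"]
print(payload["meta"])  # task, checkpoint, success threshold, sample counts

# The tensors are flat [num_samples, dim] arrays, ready for supervised training.
loader = DataLoader(TensorDataset(states, actions), batch_size=256, shuffle=True)
for batch_states, batch_actions in loader:
    ...  # e.g. regress batch_actions from batch_states with an MSE loss
```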
diff --git a/source/uwlab/uwlab/utils/wandb_utils.py b/source/uwlab/uwlab/utils/wandb_utils.py
new file mode 100644
index 0000000..87656ad
--- /dev/null
+++ b/source/uwlab/uwlab/utils/wandb_utils.py
@@ -0,0 +1,119 @@
+import h5py
+import torch
+import wandb
+
+from typing import Optional
+
+from metalearning.isaac.plot_utils import get_figure
+
+
+def log_rl2_stats(
+    returns: torch.Tensor,
+    success: torch.Tensor,
+    num_steps: torch.Tensor,
+    log_step: int,
+    num_envs: int,
+    dataset_success_rate: Optional[float] = None,
+    save_path: Optional[str] = None,
+    log_once: bool = True,
+):
+    """Accumulate per-environment RL2 episode statistics and log them to wandb.
+
+    Results are buffered in function attributes until all ``num_envs`` environments
+    have reported, then everything is logged in a single wandb step.
+    """
+    if log_once and hasattr(log_rl2_stats, "logged"):
+        return
+
+    wandb.define_metric("RL2/step")
+    wandb.define_metric("RL2/*", step_metric="RL2/step")
+
+    if not hasattr(log_rl2_stats, "num_finished_envs"):
+        log_rl2_stats.num_finished_envs = 0
+        log_rl2_stats.returns = []
+        log_rl2_stats.success = []
+        log_rl2_stats.num_steps = []
+
+    log_rl2_stats.num_finished_envs += returns.shape[0]
+    log_rl2_stats.returns.append(returns.clone())
+    log_rl2_stats.success.append(success.clone())
+    log_rl2_stats.num_steps.append(num_steps.clone())
+    if log_rl2_stats.num_finished_envs < num_envs:
+        return
+
+    log_rl2_stats.logged = True
+
+    stats = {}
+
+    returns = torch.cat(log_rl2_stats.returns, dim=0)
+    success = torch.cat(log_rl2_stats.success, dim=0)
+    num_steps = torch.cat(log_rl2_stats.num_steps, dim=0)
+    cum_success = torch.clamp(torch.cumsum(success, dim=1), max=1).to(dtype=torch.bool)
+
+    # consecutive success rate: after a successful episode, what fraction of envs stay successful?
+    last_was_success = torch.roll(success, 1, dims=1)
+    consecutive_success_rate = torch.logical_and(success, last_was_success).sum(dim=0) / torch.clamp(
+        last_was_success.sum(dim=0), min=1
+    )
+    consecutive_success_rate[0] = 0.0  # the first episode has no predecessor
+    stats["RL2/consecutive_success_rate"] = get_figure(
+        consecutive_success_rate.unsqueeze(0), "Episodes", "Consecutive Success Rate"
+    )
+
+    # fail-to-success rate: after an unsuccessful episode, what fraction of envs become successful?
+    fail_to_success_rate = torch.logical_and(success, ~last_was_success).sum(dim=0) / torch.clamp(
+        (~last_was_success).sum(dim=0), min=1
+    )
+    fail_to_success_rate[0] = 0.0  # the first episode has no predecessor
+    stats["RL2/fail_to_success_rate"] = get_figure(
+        fail_to_success_rate.unsqueeze(0), "Episodes", "Fail to Success Rate"
+    )
+
+    stats["RL2/returns"] = get_figure(returns, "Episodes", "Return")
+
+    success_rate_ref = None
+    if dataset_success_rate is not None:
+        success_rate_ref = dataset_success_rate / 100.0
+
+    stats["RL2/success"] = get_figure(success, "Episodes", "Success", reference_line=success_rate_ref)
+
+    stats["RL2/num_steps"] = get_figure(num_steps, "Episodes", "Number of Steps")
+
+    stats["RL2/cum_success"] = get_figure(
+        cum_success, "Episodes", "Cumulative Success", reference_line=success_rate_ref
+    )
+
+    stats["RL2/max_return"] = returns.max(dim=1)[0].mean()
+    stats["RL2/improvement"] = (returns[:, -1] - returns[:, 0]).mean()
+
+    diffs = returns[:, 1:] - returns[:, :-1]
+    stats["RL2/return_diff_plot"] = get_figure(diffs, "Episodes", "Return Difference")
+    stats["RL2/num_envs"] = returns.shape[0]
+    stats["RL2/step"] = log_step
+
+    wandb.log(stats)
+
+    if save_path is not None:
+        # open in write mode so any stale file from a previous run is replaced
+        with h5py.File(save_path, "w") as f:
+            f.create_dataset("returns", data=returns.cpu().numpy(), compression="gzip", compression_opts=4)
+            f.create_dataset(
+                "consecutive_success_rate",
+                data=consecutive_success_rate.cpu().numpy(),
+                compression="gzip",
+                compression_opts=4,
+            )
+            f.create_dataset(
+                "fail_to_success_rate", data=fail_to_success_rate.cpu().numpy(), compression="gzip", compression_opts=4
+            )
+            f.create_dataset("success", data=success.cpu().numpy(), compression="gzip", compression_opts=4)
+            f.create_dataset("num_steps", data=num_steps.cpu().numpy(), compression="gzip", compression_opts=4)
+            f.create_dataset("cum_success", data=cum_success.cpu().numpy(), compression="gzip", compression_opts=4)
+            if dataset_success_rate is not None:
+                f.attrs["dataset_success_rate"] = dataset_success_rate
+
+    # reset accumulators for the next logging batch
+    log_rl2_stats.num_finished_envs = 0
+    log_rl2_stats.returns = []
+    log_rl2_stats.success = []
+    log_rl2_stats.num_steps = []
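
A small sketch (assumed file name; dataset names taken from the writer above) of reading the statistics back from the HDF5 file written by `log_rl2_stats`:

```python
import h5py

# Read back the per-episode statistics saved by log_rl2_stats(save_path=...).
with h5py.File("rl2_stats.h5", "r") as f:  # hypothetical path
    returns = f["returns"][:]      # [num_envs, num_episodes]
    success = f["success"][:]
    cum_success = f["cum_success"][:]
    if "dataset_success_rate" in f.attrs:
        print("reference success rate:", f.attrs["dataset_success_rate"])
print("mean final-episode return:", returns[:, -1].mean())
```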
diff --git a/source/uwlab_tasks/uwlab_tasks/manager_based/manipulation/reset_states/config/ur5e_robotiq_2f85/rl_state_cfg.py b/source/uwlab_tasks/uwlab_tasks/manager_based/manipulation/reset_states/config/ur5e_robotiq_2f85/rl_state_cfg.py
index 0818cab..15b5e53 100644
--- a/source/uwlab_tasks/uwlab_tasks/manager_based/manipulation/reset_states/config/ur5e_robotiq_2f85/rl_state_cfg.py
+++ b/source/uwlab_tasks/uwlab_tasks/manager_based/manipulation/reset_states/config/ur5e_robotiq_2f85/rl_state_cfg.py
@@ -273,10 +273,12 @@ class EvalEventCfg(BaseEventCfg):
         mode="reset",
         params={
             "base_paths": [
-                f"{UWLAB_CLOUD_ASSETS_DIR}/Datasets/Resets/ObjectPairs/ObjectAnywhereEEAnywhere",
+                "reset_state_datasets/ObjectAnywhereEEAnywhere",
+                # f"{UWLAB_CLOUD_ASSETS_DIR}/Datasets/Resets/ObjectPairs/ObjectAnywhereEEAnywhere",
             ],
             "probs": [1.0],
             "success": "env.reward_manager.get_term_cfg('progress_context').func.success",
+            "reset_to_same_state": True,
         },
     )
 
@@ -511,6 +513,8 @@ class TerminationsCfg:
 
     abnormal_robot = DoneTerm(func=task_mdp.abnormal_robot_state)
 
+    success = DoneTerm(func=task_mdp.success_term)
+
 
 def make_insertive_object(usd_path: str):
     return RigidObjectCfg(
@@ -588,7 +592,7 @@ class Ur5eRobotiq2f85RlStateCfg(ManagerBasedRLEnvCfg):
 
     def __post_init__(self):
         self.decimation = 12
-        self.episode_length_s = 16.0
+        self.episode_length_s = 10.0
 
         # simulation settings
         self.sim.dt = 1 / 120.0
diff --git a/source/uwlab_tasks/uwlab_tasks/manager_based/manipulation/reset_states/mdp/events.py b/source/uwlab_tasks/uwlab_tasks/manager_based/manipulation/reset_states/mdp/events.py
index ed754d1..25b75fe 100644
--- a/source/uwlab_tasks/uwlab_tasks/manager_based/manipulation/reset_states/mdp/events.py
+++ b/source/uwlab_tasks/uwlab_tasks/manager_based/manipulation/reset_states/mdp/events.py
@@ -998,6 +998,7 @@ def __init__(self, cfg: EventTermCfg, env: ManagerBasedEnv):
 
         base_paths: list[str] = cfg.params.get("base_paths", [])
         probabilities: list[float] = cfg.params.get("probs", [])
+        self.reset_to_same_state: bool = cfg.params.get("reset_to_same_state", False)
 
         if not base_paths:
             raise ValueError("No base paths provided")
@@ -1028,6 +1029,7 @@ def __init__(self, cfg: EventTermCfg, env: ManagerBasedEnv):
                 raise FileNotFoundError(f"Dataset file {dataset_file} could not be accessed or downloaded.")
 
             dataset = torch.load(local_file_path)
+            print(f"[INFO] Loaded dataset {dataset_file} with {len(dataset['initial_state']['articulation']['robot']['joint_position'])} states")
             num_states.append(len(dataset["initial_state"]["articulation"]["robot"]["joint_position"]))
             init_indices = torch.arange(num_states[-1], device=env.device)
             self.datasets.append(sample_state_data_set(dataset, init_indices, env.device))
@@ -1045,6 +1047,8 @@ def __init__(self, cfg: EventTermCfg, env: ManagerBasedEnv):
         self.success_monitor = success_monitor_cfg.class_type(success_monitor_cfg)
 
         self.task_id = torch.randint(0, self.num_tasks, (self.num_envs,), device=self.device)
+        self.state_id = torch.zeros((self.num_envs,), device=self.device, dtype=torch.long)  # cached per-env state indices
+        self.first_reset = True
 
     def __call__(
         self,
@@ -1074,8 +1078,9 @@ def __call__(
         })
 
         # Sample which dataset to use for each environment
-        dataset_indices = torch.multinomial(self.probs, len(env_ids), replacement=True)
-        self.task_id[env_ids] = dataset_indices
+        if not self.reset_to_same_state:
+            dataset_indices = torch.multinomial(self.probs, len(env_ids), replacement=True)
+            self.task_id[env_ids] = dataset_indices
 
         # Process each dataset's environments
         for dataset_idx in range(self.num_tasks):
@@ -1084,11 +1089,17 @@ def __call__(
                 continue
 
             current_env_ids = env_ids[mask]
-            state_indices = torch.randint(
-                0, self.num_states[dataset_idx], (len(current_env_ids),), device=self._env.device
-            )
+            if self.reset_to_same_state and not self.first_reset:
+                state_indices = self.state_id[current_env_ids]
+            else:
+                state_indices = torch.randint(
+                    0, self.num_states[dataset_idx], (len(current_env_ids),), device=self._env.device
+                )
+                self.state_id[current_env_ids] = state_indices
             states_to_reset_from = sample_from_nested_dict(self.datasets[dataset_idx], state_indices)
             self._env.scene.reset_to(states_to_reset_from["initial_state"], env_ids=current_env_ids, is_relative=True)
+
+        self.first_reset = False
 
         # Reset velocities
         robot: Articulation = self._env.scene["robot"]
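
To make the `reset_to_same_state` semantics above concrete, a self-contained toy sketch of the sample-once-then-replay pattern (names are illustrative, not the event-term API):

```python
import torch

# Indices are sampled once on the first reset, cached, and replayed on every
# later reset, so each environment always restarts from the same state.
num_envs, num_states = 4, 100
state_id = torch.zeros(num_envs, dtype=torch.long)
first_reset = True

def reset(env_ids: torch.Tensor, reset_to_same_state: bool = True) -> torch.Tensor:
    global first_reset
    if reset_to_same_state and not first_reset:
        state_indices = state_id[env_ids]  # replay the cached states
    else:
        state_indices = torch.randint(0, num_states, (len(env_ids),))
        state_id[env_ids] = state_indices  # cache for later resets
    first_reset = False
    return state_indices

first = reset(torch.arange(num_envs))
assert torch.equal(reset(torch.arange(num_envs)), first)  # identical thereafter
```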
diff --git a/source/uwlab_tasks/uwlab_tasks/manager_based/manipulation/reset_states/mdp/rewards.py b/source/uwlab_tasks/uwlab_tasks/manager_based/manipulation/reset_states/mdp/rewards.py
index 18cfbc2..fa3ad01 100644
--- a/source/uwlab_tasks/uwlab_tasks/manager_based/manipulation/reset_states/mdp/rewards.py
+++ b/source/uwlab_tasks/uwlab_tasks/manager_based/manipulation/reset_states/mdp/rewards.py
@@ -178,6 +178,14 @@ def success_reward(env: ManagerBasedRLEnv, context: str = "progress_context") ->
     position_aligned: torch.Tensor = getattr(context_term, "position_aligned")
     return torch.where(orientation_aligned & position_aligned, 1.0, 0.0)
 
+
+def success_term(env: ManagerBasedRLEnv, context: str = "progress_context") -> torch.Tensor:
+    """Terminate episodes whose success reward indicates the goal has been reached."""
+    succeeded = success_reward(env, context).to(dtype=torch.bool)
+    # cache the flags on the function object so other terms can read them
+    success_term.env_succeeded = succeeded
+    return succeeded
+
 
 def action_l2_clamped(env: ManagerBasedRLEnv) -> torch.Tensor:
     """Penalize the actions using L2 squared kernel."""
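
For context, a hedged sketch of how another term could consume the flags cached on the function object, mirroring the existing `...func.success` attribute lookup used in the event config. The helper and import path are assumptions, not part of this diff:

```python
import torch

from uwlab_tasks.manager_based.manipulation.reset_states.mdp.rewards import success_term

def success_bonus(env) -> torch.Tensor:
    """Hypothetical reward term that reuses the flags cached by success_term."""
    succeeded = getattr(success_term, "env_succeeded", None)
    if succeeded is None:  # termination manager has not evaluated success_term yet
        return torch.zeros(env.num_envs, device=env.device)
    return succeeded.to(dtype=torch.float32)
```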