Ttt #25 (Closed)

179 changes: 177 additions & 2 deletions scripts/reinforcement_learning/rsl_rl/play.py
@@ -9,6 +9,7 @@

import argparse
import sys
from collections.abc import Mapping

from isaaclab.app import AppLauncher

@@ -34,6 +35,30 @@
help="Use the pre-trained checkpoint from Nucleus.",
)
parser.add_argument("--real-time", action="store_true", default=False, help="Run in real-time, if possible.")
parser.add_argument(
"--collect_dataset",
action="store_true",
default=False,
help="Collect state-action pairs during inference and save successful episodes for supervised learning.",
)
parser.add_argument(
"--dataset_output",
type=str,
default=None,
help="Output .pt path for collected dataset. Defaults to <run_log_dir>/supervised_dataset.pt",
)
parser.add_argument(
"--num_successful_episodes",
type=int,
default=100,
help="Stop data collection after this many successful episodes.",
)
parser.add_argument(
"--success_reward_threshold",
type=float,
default=0.1,
help="Episode is successful if its final reward is above this threshold.",
)
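# Example invocation (illustrative; the launcher script and task name are
# assumptions, the flags are the ones defined above):
#   ./isaaclab.sh -p scripts/reinforcement_learning/rsl_rl/play.py \
#       --task Isaac-Lift-Cube-Franka-v0 --collect_dataset \
#       --num_successful_episodes 200 --dataset_output /tmp/supervised_dataset.pt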
# append RSL-RL cli arguments
cli_args.add_rsl_rl_args(parser)
# append AppLauncher cli args
@@ -57,6 +82,7 @@
import os
import time
import torch
from tqdm import tqdm

from rsl_rl.runners import DistillationRunner, OnPolicyRunner

@@ -81,6 +107,35 @@
# PLACEHOLDER: Extension template (do not remove this comment)


def _flatten_observation_value(value) -> torch.Tensor:
"""Flatten a tensor-like observation structure into `[num_envs, flat_dim]`."""
if isinstance(value, torch.Tensor):
return value.reshape(value.shape[0], -1)

if isinstance(value, Mapping) or (hasattr(value, "keys") and hasattr(value, "__getitem__")):
flat_chunks = [_flatten_observation_value(value[key]) for key in value.keys()]
if not flat_chunks:
raise ValueError("Observation mapping is empty and cannot be flattened.")
return torch.cat(flat_chunks, dim=-1)

tensor_value = torch.as_tensor(value)
return tensor_value.reshape(tensor_value.shape[0], -1)


def _extract_policy_state(obs) -> torch.Tensor:
"""Extract and flatten policy observations into `[num_envs, state_dim]`."""
if isinstance(obs, Mapping) or (hasattr(obs, "keys") and hasattr(obs, "__getitem__")):
if "policy" in obs:
state = obs["policy"]
else:
# Fall back to first observation entry if policy group is unavailable.
first_key = next(iter(obs))
state = obs[first_key]
else:
state = obs
return _flatten_observation_value(state)

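# For example, an observation dict like
#   {"policy": {"joint_pos": Tensor[N, 7], "ee_pose": Tensor[N, 7]}}
# yields a [N, 14] state: the "policy" group is selected and its leaves are
# flattened and concatenated in key order.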

@hydra_task_config(args_cli.task, args_cli.agent)
def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: RslRlBaseRunnerCfg):
"""Play with RSL-RL agent."""
@@ -89,7 +144,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: RslRlBaseRunnerCfg):
train_task_name = task_name.replace("-Play", "")

# override configurations with non-hydra CLI arguments
agent_cfg: RslRlBaseRunnerCfg = cli_args.update_rsl_rl_cfg(agent_cfg, args_cli)
agent_cfg = cli_args.update_rsl_rl_cfg(agent_cfg, args_cli)
env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs

# set the environment seed
@@ -177,18 +232,97 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: RslRlBaseRunnerCfg):

# reset environment
obs = env.get_observations()
num_envs = env.num_envs
state_dim = _extract_policy_state(obs).shape[1]
action_dim = env.num_actions

successful_episodes = 0
total_finished_episodes = 0
data_pbar = None
collected_state_chunks: list[torch.Tensor] = []
collected_action_chunks: list[torch.Tensor] = []
collected_reward_chunks: list[torch.Tensor] = []
env_ids_device: torch.Tensor | None = None
episode_lengths: torch.Tensor | None = None
episode_states: torch.Tensor | None = None
episode_actions: torch.Tensor | None = None
episode_rewards: torch.Tensor | None = None
if args_cli.collect_dataset:
max_episode_steps = getattr(env.unwrapped, "max_episode_length", None)
if max_episode_steps is None:
raise RuntimeError("Dataset collection requires `env.unwrapped.max_episode_length` to be available.")
max_episode_steps = int(max_episode_steps)
collection_device = env.unwrapped.device
env_ids_device = torch.arange(num_envs, device=collection_device, dtype=torch.long)
episode_lengths = torch.zeros(num_envs, device=collection_device, dtype=torch.long)
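# Pre-allocated per-env episode buffers sized for the worst-case episode
# length; each row is overwritten from step 0 after its environment resets.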
episode_states = torch.empty((num_envs, max_episode_steps, state_dim), device=collection_device, dtype=torch.float32)
episode_actions = torch.empty((num_envs, max_episode_steps, action_dim), device=collection_device, dtype=torch.float32)
episode_rewards = torch.empty((num_envs, max_episode_steps), device=collection_device, dtype=torch.float32)
data_pbar = tqdm(total=args_cli.num_successful_episodes, desc="Collecting successful episodes", unit="episode")

timestep = 0
# simulate environment
while simulation_app.is_running():
start_time = time.time()
# run everything in inference mode
with torch.inference_mode():
state_t = _extract_policy_state(obs)
# agent stepping
actions = policy(obs)
# env stepping
obs, _, dones, _ = env.step(actions)
obs, rewards, dones, _ = env.step(actions)
# reset recurrent states for episodes that have terminated
policy_nn.reset(dones)

if args_cli.collect_dataset:
assert env_ids_device is not None
assert episode_lengths is not None
assert episode_states is not None
assert episode_actions is not None
assert episode_rewards is not None
flat_actions = actions.reshape(actions.shape[0], -1)
rewards_device = rewards.detach()
dones_device = dones.detach()
states_device = state_t.detach()
actions_device = flat_actions.detach()
step_ids = episode_lengths.clone()
episode_states[env_ids_device, step_ids] = states_device
episode_actions[env_ids_device, step_ids] = actions_device
episode_rewards[env_ids_device, step_ids] = rewards_device
episode_lengths += 1

done_ids = torch.nonzero(dones_device, as_tuple=False).squeeze(-1)
if done_ids.numel() > 0:
total_finished_episodes += int(done_ids.numel())

successful_done_mask = rewards_device[done_ids] > args_cli.success_reward_threshold
successful_done_ids = done_ids[successful_done_mask]
remaining_successes = args_cli.num_successful_episodes - successful_episodes

if remaining_successes > 0 and successful_done_ids.numel() > 0:
successful_done_ids = successful_done_ids[:remaining_successes]
successful_lengths = episode_lengths[successful_done_ids]
successful_episodes += int(successful_done_ids.numel())
if data_pbar is not None:
data_pbar.update(int(successful_done_ids.numel()))

max_success_len = int(successful_lengths.max().item())
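# Mask out padding: valid_steps[i, t] is True only for steps actually taken
# in episode i, so the indexing below keeps real transitions only.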
valid_steps = (
torch.arange(max_success_len, device=successful_lengths.device, dtype=torch.long).unsqueeze(0)
< successful_lengths.unsqueeze(1)
)
collected_state_chunks.append(episode_states[successful_done_ids, :max_success_len][valid_steps].cpu())
collected_action_chunks.append(episode_actions[successful_done_ids, :max_success_len][valid_steps].cpu())
collected_reward_chunks.append(episode_rewards[successful_done_ids, :max_success_len][valid_steps].cpu())

episode_lengths[done_ids] = 0

if successful_episodes >= args_cli.num_successful_episodes:
print(
f"[INFO] Collected {successful_episodes} successful episodes "
f"(threshold={args_cli.success_reward_threshold})."
)
break
if args_cli.video:
timestep += 1
# Exit the play loop after recording one video
@@ -202,6 +336,47 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: RslRlBaseRunnerCfg):

# close the simulator
env.close()
if data_pbar is not None:
data_pbar.close()

if args_cli.collect_dataset:
if collected_state_chunks:
collected_states = torch.cat(collected_state_chunks, dim=0)
collected_actions = torch.cat(collected_action_chunks, dim=0)
collected_rewards = torch.cat(collected_reward_chunks, dim=0)
else:
collected_states = torch.empty((0, state_dim), dtype=torch.float32)
collected_actions = torch.empty((0, action_dim), dtype=torch.float32)
collected_rewards = torch.empty((0,), dtype=torch.float32)

dataset_path = args_cli.dataset_output
if dataset_path is None:
dataset_path = os.path.join(log_dir, "supervised_dataset.pt")
dataset_path = os.path.abspath(dataset_path)
os.makedirs(os.path.dirname(dataset_path), exist_ok=True)

payload = {
"states": collected_states.float(),
"actions": collected_actions.float(),
"rewards": collected_rewards.float(),
"meta": {
"task": args_cli.task,
"checkpoint": resume_path,
"success_reward_threshold": args_cli.success_reward_threshold,
"target_successful_episodes": args_cli.num_successful_episodes,
"successful_episodes_collected": successful_episodes,
"finished_episodes_seen": total_finished_episodes,
"num_samples": int(collected_states.shape[0]),
},
}
torch.save(payload, dataset_path)
print(f"[INFO] Saved supervised dataset to: {dataset_path}")
print(
"[INFO] Dataset stats: "
f"samples={payload['meta']['num_samples']}, "
f"success_eps={successful_episodes}, "
f"finished_eps={total_finished_episodes}"
)


if __name__ == "__main__":
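The saved payload is ready for supervised training. A minimal loading sketch, assuming only the payload layout written above (the file path, batch size, and training loop are illustrative):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Keys and flat [num_samples, dim] shapes follow the payload saved by play.py.
payload = torch.load("supervised_dataset.pt", map_location="cpu")
states, actions = payload["states"], payload["actions"]
meta = payload["meta"]
print(f"{meta['num_samples']} samples from {meta['successful_episodes_collected']} successful episodes")

# (state -> action) pairs for behavior cloning.
loader = DataLoader(TensorDataset(states, actions), batch_size=256, shuffle=True)
for batch_states, batch_actions in loader:
    pass  # e.g. regress batch_actions from batch_states with an MSE loss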
115 changes: 115 additions & 0 deletions source/uwlab/uwlab/utils/wandb_utils.py
@@ -0,0 +1,115 @@
import os
from typing import Optional

import h5py
import torch
import wandb

from metalearning.isaac.plot_utils import get_figure

def log_rl2_stats(
    returns: torch.Tensor,
    success: torch.Tensor,
    num_steps: torch.Tensor,
    log_step: int,
    num_envs: int,
    dataset_success_rate: Optional[float] = None,
    save_path: Optional[str] = None,
    log_once: bool = True,
):
if log_once and hasattr(log_rl2_stats, 'logged'):
return

wandb.define_metric("RL2/step")
wandb.define_metric("RL2/*", step_metric="RL2/step")

if not hasattr(log_rl2_stats, 'num_finished_envs'):
log_rl2_stats.num_finished_envs = 0
log_rl2_stats.returns = []
log_rl2_stats.success = []
log_rl2_stats.num_steps = []

log_rl2_stats.num_finished_envs += returns.shape[0]
log_rl2_stats.returns.append(returns.clone())
log_rl2_stats.success.append(success.clone())
log_rl2_stats.num_steps.append(num_steps.clone())
if log_rl2_stats.num_finished_envs < num_envs:
return

log_rl2_stats.logged = True

stats = {}

returns = torch.cat(log_rl2_stats.returns, dim=0)
success = torch.cat(log_rl2_stats.success, dim=0)
num_steps = torch.cat(log_rl2_stats.num_steps, dim=0)
cum_success = torch.clamp(torch.cumsum(success, dim=1), max=1.0).to(dtype=torch.bool)

# compute consecutive success rate (after a successful episode, what % of environments stay successful?)
last_was_success = torch.roll(success, 1, dims=1)
consecutive_success_rate = torch.logical_and(success, last_was_success).sum(dim=0) / torch.clamp(last_was_success.sum(dim=0), min=1)
consecutive_success_rate[0] = 0.0 # first episode is not consecutive
stats["RL2/consecutive_success_rate"] = get_figure(
consecutive_success_rate.unsqueeze(0), "Episodes", "Consecutive Success Rate"
)

# compute fail to success rate (after an unsuccessful episode, what % of environments become successful?)
fail_to_success_rate = torch.logical_and(success, ~last_was_success).sum(dim=0) / torch.clamp((~last_was_success).sum(dim=0), min=1)
fail_to_success_rate[0] = 0.0 # first episode has no previous episode
stats["RL2/fail_to_success_rate"] = get_figure(
fail_to_success_rate.unsqueeze(0), "Episodes", "Fail to Success Rate"
)

stats["RL2/returns"] = get_figure(
returns, "Episodes", "Return"
)

success_rate_ref = None
if dataset_success_rate is not None:
success_rate_ref = dataset_success_rate / 100.0

stats["RL2/success"] = get_figure(
success, "Episodes", "Success", reference_line=success_rate_ref
)

stats["RL2/num_steps"] = get_figure(
num_steps, "Episodes", "Number of Steps"
)

cum_success_rate_ref = None
if dataset_success_rate is not None:
cum_success_rate_ref = dataset_success_rate / 100.0

stats["RL2/cum_success"] = get_figure(
cum_success, "Episodes", "Cumulative Success", reference_line=cum_success_rate_ref
)

stats["RL2/max_return"] = returns.max(dim=1)[0].mean()
stats["RL2/improvement"] = (returns[:, -1] - returns[:, 0]).mean()

diffs = returns[:, 1:] - returns[:, :-1]
stats["RL2/return_diff_plot"] = get_figure(
diffs, "Episodes", "Return Difference"
)
stats["RL2/num_envs"] = returns.shape[0]
stats["RL2/step"] = log_step

wandb.log(stats)

if save_path is not None:
# This is written before any other access to the file from the main
# training script, so truncate any stale file on open.
# TODO: this is a hack; the file lifecycle should be made more robust.
with h5py.File(save_path, "w") as f:
    f.create_dataset("returns", data=returns.cpu().numpy(), compression="gzip", compression_opts=4)
    f.create_dataset("consecutive_success_rate", data=consecutive_success_rate.cpu().numpy(), compression="gzip", compression_opts=4)
    f.create_dataset("fail_to_success_rate", data=fail_to_success_rate.cpu().numpy(), compression="gzip", compression_opts=4)
    f.create_dataset("success", data=success.cpu().numpy(), compression="gzip", compression_opts=4)
    f.create_dataset("num_steps", data=num_steps.cpu().numpy(), compression="gzip", compression_opts=4)
    f.create_dataset("cum_success", data=cum_success.cpu().numpy(), compression="gzip", compression_opts=4)
    # h5py attributes cannot store None, so only write the rate when given.
    if dataset_success_rate is not None:
        f.attrs["dataset_success_rate"] = dataset_success_rate

log_rl2_stats.num_finished_envs = 0 # reset for next batch
log_rl2_stats.returns = []
log_rl2_stats.success = []
log_rl2_stats.num_steps = []

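When save_path is given, the logged statistics can be reloaded for offline analysis. A short sketch, assuming only the dataset names written above (the file path is illustrative):

import h5py

with h5py.File("rl2_stats.h5", "r") as f:
    returns = f["returns"][:]      # shape: [num_envs, num_episodes]
    success = f["success"][:]
    cum_success = f["cum_success"][:]
    ref_rate = f.attrs.get("dataset_success_rate")  # may be absent
print(returns.mean(axis=0))  # mean return per episode index, across envs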
@@ -273,10 +273,12 @@ class EvalEventCfg(BaseEventCfg):
mode="reset",
params={
"base_paths": [
f"{UWLAB_CLOUD_ASSETS_DIR}/Datasets/Resets/ObjectPairs/ObjectAnywhereEEAnywhere",
f"reset_state_datasets/ObjectAnywhereEEAnywhere",
# f"{UWLAB_CLOUD_ASSETS_DIR}/Datasets/Resets/ObjectPairs/ObjectAnywhereEEAnywhere",
],
"probs": [1.0],
"success": "env.reward_manager.get_term_cfg('progress_context').func.success",
"reset_to_same_state": True,
},
)

@@ -511,6 +513,8 @@ class TerminationsCfg:

abnormal_robot = DoneTerm(func=task_mdp.abnormal_robot_state)

success = DoneTerm(func=task_mdp.success_term)


def make_insertive_object(usd_path: str):
return RigidObjectCfg(
@@ -588,7 +592,7 @@ class Ur5eRobotiq2f85RlStateCfg(ManagerBasedRLEnvCfg):

def __post_init__(self):
self.decimation = 12
self.episode_length_s = 16.0
self.episode_length_s = 10.0
# simulation settings
self.sim.dt = 1 / 120.0
