Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions checkpoint_manual_sotred/tdmpc/eval.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
step,episode_reward
0.0,-1224.275634765625
50144.0,-164.05038452148438
100118.0,-177.66744995117188
150262.0,-222.49234008789062
200166.0,-281.21649169921875
250232.0,-106.09677124023438
300024.0,-269.89263916015625
350121.0,-144.60325622558594
400414.0,-156.74029541015625
450037.0,-503.3945617675781
500062.0,-133.65843200683594
550449.0,-1192.0301513671875
600384.0,-283.03472900390625
650419.0,-94.0393295288086
700072.0,-740.7599487304688
750406.0,-301.64703369140625
800401.0,-194.54046630859375
850428.0,-246.45106506347656
900039.0,-113.0138168334961
950014.0,-159.62356567382812
1000101.0,-259.19091796875
1050088.0,-103.6041259765625
1100439.0,991.2276611328125
1150371.0,-67.97618103027344
1200044.0,-246.8380126953125
1250220.0,-117.49695587158203
1300039.0,-1349.541259765625
1350315.0,-130.53326416015625
1400318.0,-96.53905487060547
1450303.0,999.9164428710938
1500252.0,981.6426391601562
1550399.0,-109.76904296875
1600487.0,990.9343872070312
1650241.0,989.6995239257812
1700273.0,-266.88201904296875
1750170.0,-219.47901916503906
1800154.0,-168.58602905273438
2 changes: 1 addition & 1 deletion dreamerv3/embodied/agents/dreamerv3/configs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ defaults:
script: train
steps: 1e10
duration: 0
num_envs: 4
num_envs: 1
expl_until: 0
log_every: 120
save_every: 900
Expand Down
19 changes: 13 additions & 6 deletions humanoid_bench/mjx/flax_to_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,17 @@
class TorchModel(torch.nn.Module):
def __init__(self, inputs, num_classes=1):
super(TorchModel, self).__init__()
self.dense1 = torch.nn.Linear(inputs, 256)
self.dense2 = torch.nn.Linear(256, 256)
self.dense3 = torch.nn.Linear(256, num_classes)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.dense1 = torch.nn.Linear(inputs, 256).to(self.device)
self.dense2 = torch.nn.Linear(256, 256).to(self.device)
self.dense3 = torch.nn.Linear(256, num_classes).to(self.device)

def forward(self, x):
if not isinstance(x, torch.Tensor):
x = torch.tensor(x, device=self.device)
elif x.device != self.device:
x = x.to(self.device)

x = torch.nn.functional.tanh(self.dense1(x))
x = torch.nn.functional.tanh(self.dense2(x))
x = self.dense3(x)
Expand All @@ -18,15 +24,16 @@ def forward(self, x):
class TorchPolicy():

def __init__(self, model):
self.model = model
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model = model.to(self.device)
self.mean = None
self.var = None

def step(self, obs):
if self.mean is not None and self.var is not None:
obs = (obs - self.mean) / np.sqrt(self.var + 1e-8)
obs = torch.from_numpy(obs).float()
action = self.model(obs).detach().numpy()
action = self.model(obs).detach().cpu().numpy()
return action

def get_weights(self):
Expand All @@ -39,7 +46,7 @@ def save(self, path):
torch.save(self.model.state_dict(), path)

def load(self, path, mean=None, var=None):
self.model.load_state_dict(torch.load(path))
self.model.load_state_dict(torch.load(path, map_location=self.device))
if mean is not None and var is not None:
self.mean = np.load(mean)[0]
self.var = np.load(var)[0]
Expand Down
117 changes: 117 additions & 0 deletions inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import argparse
import pathlib
import os

import cv2
import gymnasium as gym
import torch
import numpy as np
from termcolor import colored

import humanoid_bench
from humanoid_bench.env import ROBOTS, TASKS
from tdmpc2.model_loader import get_agent, load_checkpoint

def _describe_spaces(env, ob, check=False):
    """Print the observation/action spaces of *env* next to the shapes of *ob*.

    Args:
        env: Environment exposing ``observation_space`` and ``action_space``.
        ob: Observation returned by ``env.reset()``; either an array or a
            dict of arrays.
        check: When true, additionally assert that *ob* matches the declared
            observation space (keys and per-entry shapes).
    """
    if isinstance(ob, dict):
        print(f"ob_space = {env.observation_space}")
        print("ob = ")
        for k, v in ob.items():
            print(f" {k}: {v.shape}")
            if check:
                expected = env.observation_space.spaces[k].shape
                assert v.shape == expected, f"{v.shape} != {expected}"
        if check:
            assert ob.keys() == env.observation_space.spaces.keys()
    else:
        print(f"ob_space = {env.observation_space}, ob = {ob.shape}")
        if check:
            assert env.observation_space.shape == ob.shape
    print(f"ac_space = {env.action_space.shape}")


def _obs_to_tensor(ob, device):
    """Return *ob* as a flat float32 tensor on *device*.

    Dict observations are flattened entry by entry and concatenated in the
    dict's iteration order; array observations are converted as they are.
    """
    if isinstance(ob, dict):
        parts = [
            torch.as_tensor(v, dtype=torch.float32).flatten() for v in ob.values()
        ]
        tensor = torch.cat(parts)
    else:
        tensor = torch.as_tensor(ob, dtype=torch.float32)
    return tensor.to(device)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(prog="HumanoidBench environment test")
    parser.add_argument("--env", help="e.g. h1-walk-v0")
    parser.add_argument("--keyframe", default=None)
    # Checkpoint of the high-level agent that replaces random action sampling.
    parser.add_argument("--high_level_policy_path", default=None)
    parser.add_argument("--policy_path", default=None)
    parser.add_argument("--mean_path", default=None)
    parser.add_argument("--var_path", default=None)
    parser.add_argument("--policy_type", default=None)
    parser.add_argument("--blocked_hands", default="False")
    parser.add_argument("--small_obs", default="False")
    parser.add_argument("--obs_wrapper", default="False")
    parser.add_argument("--sensors", default="")
    parser.add_argument("--render_mode", default="rgb_array")  # "human" or "rgb_array".
    # NOTE: to get (nicer) 'human' rendering to work, you need to fix the compatibility issue between mujoco>3.0 and gymnasium: https://github.com/Farama-Foundation/Gymnasium/issues/749
    args = parser.parse_args()

    # Everything except the options consumed by this script is forwarded to
    # gym.make() as environment keyword arguments.
    kwargs = vars(args).copy()
    for consumed in ("env", "render_mode", "high_level_policy_path"):
        kwargs.pop(consumed)
    if kwargs["keyframe"] is None:
        kwargs.pop("keyframe")
    print(f"arguments: {kwargs}")

    # Test offscreen rendering
    print("Test offscreen mode...")
    offscreen_env = gym.make(args.env, render_mode="rgb_array", **kwargs)
    ob, _ = offscreen_env.reset()
    _describe_spaces(offscreen_env, ob)

    # The frame is swapped to BGR because that is the channel order
    # cv2.imwrite expects (assumes render() returns RGB — the display code
    # below relies on the same assumption).
    img = offscreen_env.render()
    cv2.imwrite("test_env_img.png", cv2.cvtColor(img, cv2.COLOR_RGB2BGR))

    # Test online rendering with interactive viewer
    print("Test onscreen mode...")
    env = gym.make(args.env, render_mode=args.render_mode, **kwargs)
    ob, _ = env.reset()

    # Load the agent in two steps using the separated loader: build it, then
    # restore the high-level policy checkpoint into it.
    agent = get_agent(
        args.policy_path, args.mean_path, args.var_path, args.policy_type, args.env
    )
    agent = load_checkpoint(agent, args.high_level_policy_path)

    _describe_spaces(env, ob, check=True)
    env.render()

    # Hoisted out of the loop: the device does not change between steps.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ret = 0
    step = 0
    try:
        while True:
            ob_tensor = _obs_to_tensor(ob, device)

            with torch.no_grad():
                action = agent.act(ob_tensor, t0=step == 0, eval_mode=True)
            if isinstance(action, torch.Tensor):
                # .cpu() is required before .numpy() when the agent returns a
                # CUDA tensor; it is a no-op for tensors already on the CPU.
                action = action.squeeze().detach().cpu().numpy()

            ob, rew, terminated, truncated, info = env.step(action)
            img = env.render()
            ret += rew
            step += 1

            if args.render_mode == "rgb_array":
                cv2.imshow("test_env", img[:, :, ::-1])
                cv2.waitKey(1)

            if terminated or truncated:
                print(f"episode finished: steps = {step}, return = {ret}")
                ret = 0
                step = 0
                # Keep the fresh observation; otherwise the next action would
                # be planned from the terminal observation of the old episode.
                ob, _ = env.reset()
    except KeyboardInterrupt:
        print("Interrupted, shutting down...")
    finally:
        # The rollout loop only ends via an exception, so cleanup lives here.
        env.close()
        offscreen_env.close()
        if args.render_mode == "rgb_array":
            cv2.destroyAllWindows()
16 changes: 16 additions & 0 deletions inference.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#!/bin/bash
# Run tdmpc2.evaluate on a saved checkpoint, passing the low-level policy
# files (model, mean, var) as overrides.
# All paths are resolved relative to the directory the script is launched from.

# Fail fast on errors, unset variables and failed pipeline stages.
set -euo pipefail

export WORK_DIR="$(pwd)"
export BASE_DIR="$WORK_DIR"
export TASK="humanoid_h1hand-push-v0"
export POLICY_PATH="${BASE_DIR}/data/reach_one_hand/torch_model.pt"
export MEAN_PATH="${BASE_DIR}/data/reach_one_hand/mean.npy"
export VAR_PATH="${BASE_DIR}/data/reach_one_hand/var.npy"
export CHECKPOINT="${BASE_DIR}/logs/humanoid_h1hand-push-v0/0/tdmpc/models/1800154.pt"

# Each override is quoted so a working directory containing spaces is passed
# as a single argument instead of being word-split.
python -m tdmpc2.evaluate \
    "task=${TASK}" \
    policy_type=reach_single \
    "policy_path=${POLICY_PATH}" \
    "mean_path=${MEAN_PATH}" \
    "var_path=${VAR_PATH}" \
    "checkpoint=${CHECKPOINT}"
16 changes: 16 additions & 0 deletions old_inference.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#!/bin/bash
# Run the standalone inference.py viewer with a low-level policy (model,
# mean, var) and a high-level policy checkpoint.
# All paths are resolved relative to the directory the script is launched from.

# Fail fast on errors, unset variables and failed pipeline stages.
set -euo pipefail

export WORK_DIR="$(pwd)"
export BASE_DIR="$WORK_DIR"
# NOTE(review): the checkpoint below lives under humanoid_h1hand-push-v0 while
# the environment is h1-push-v0 — confirm the two are meant to be paired.
export TASK="h1-push-v0"
export POLICY_PATH="${BASE_DIR}/data/reach_one_hand/torch_model.pt"
export MEAN_PATH="${BASE_DIR}/data/reach_one_hand/mean.npy"
export VAR_PATH="${BASE_DIR}/data/reach_one_hand/var.npy"
export HIGH_LEVEL_POLICY="${BASE_DIR}/logs/humanoid_h1hand-push-v0/0/tdmpc/models/1800154.pt"

# Each value is quoted so a working directory containing spaces is passed as
# a single argument instead of being word-split.
python -m inference \
    --env "${TASK}" \
    --policy_type reach_single \
    --policy_path "${POLICY_PATH}" \
    --mean_path "${MEAN_PATH}" \
    --var_path "${VAR_PATH}" \
    --high_level_policy_path "${HIGH_LEVEL_POLICY}"
Binary file added tdmpc2/tdmpc2/checkpoint.pt
Binary file not shown.
10 changes: 5 additions & 5 deletions tdmpc2/tdmpc2/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ obs: state

# evaluation
checkpoint: ???
eval_episodes: 1
eval_episodes: 20
eval_freq: 50000

# training
Expand All @@ -25,7 +25,7 @@ discount_denom: 5
discount_min: 0.95
discount_max: 0.995
buffer_size: 3_000_000
exp_name: default
exp_name: tdmpc
data_dir: ???

# planning
Expand All @@ -50,7 +50,7 @@ vmin: -10
vmax: +10

# architecture
model_size: ???
model_size: 5
num_enc_layers: 2
enc_dim: 256
num_channels: 32
Expand All @@ -63,15 +63,15 @@ simnorm_dim: 8

# logging
wandb_project: humanoid-bench
wandb_entity: robot-learning
wandb_entity: albert-yw-lin
wandb_silent: false
disable_wandb: true
save_csv: true

# misc
save_video: true
save_agent: true
seed: 1
seed: 0

# convenience
work_dir: ???
Expand Down
7 changes: 4 additions & 3 deletions tdmpc2/tdmpc2/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def evaluate(cfg: dict):
$ python evaluate.py task=dog-run checkpoint=/path/to/dog-1.pt save_video=true
```
"""
assert torch.cuda.is_available()
# assert torch.cuda.is_available()
assert cfg.eval_episodes > 0, "Must evaluate at least 1 episode."
cfg = parse_cfg(cfg)
set_seed(cfg.seed)
Expand Down Expand Up @@ -100,12 +100,13 @@ def evaluate(cfg: dict):
task_idx = None
ep_rewards, ep_successes = [], []
for i in range(cfg.eval_episodes):
obs, done, ep_reward, t = env.reset(task_idx=task_idx), False, 0, 0
obs, done, ep_reward, t = env.reset(task_idx=task_idx)[0], False, 0, 0
if cfg.save_video:
frames = [env.render()]
while not done:
action = agent.act(obs, t0=t == 0, task=task_idx)
obs, reward, done, info = env.step(action)
obs, reward, done, truncated, info = env.step(action)
done = done or truncated
ep_reward += reward
t += 1
if cfg.save_video:
Expand Down
2 changes: 1 addition & 1 deletion tdmpc2/tdmpc2/tdmpc2.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def load(self, fp):
Args:
fp (str or dict): Filepath or state dict to load.
"""
state_dict = fp if isinstance(fp, dict) else torch.load(fp)
state_dict = fp if isinstance(fp, dict) else torch.load(fp, map_location=self.device)
self.model.load_state_dict(state_dict["model"])

@torch.no_grad()
Expand Down
2 changes: 2 additions & 0 deletions tdmpc2/tdmpc2/trainer/online_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ def train(self):
eval_metrics = self.eval()
eval_metrics.update(self.common_metrics())
self.logger.log(eval_metrics, "eval")
# Save the agent at every evaluation, using the current step as the identifier.
self.logger.save_agent(self.agent, identifier=f"{self._step}")
eval_next = False

if self._step > 0:
Expand Down
Binary file modified test_env_img.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading