-
Notifications
You must be signed in to change notification settings - Fork 11
Expand file tree
/
Copy pathmain.py
More file actions
119 lines (105 loc) · 4.15 KB
/
main.py
File metadata and controls
119 lines (105 loc) · 4.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
import importlib
import env_adapt as env
from model_adapt import ACModel, Discriminator, AdaptNet, MapCNN
import torch
import numpy as np
import random
import argparse
# Command-line interface. Arguments are parsed at import time because the
# module-level seeding code and the __main__ entry point both read `settings`.
# Each entry is (positional names/flags, add_argument keyword options).
parser = argparse.ArgumentParser()
for _flags, _options in (
    (("config",), dict(
        type=str,
        help="Configure file used for training. Please refer to files in `config` folder.")),
    (("--meta",), dict(
        type=str, default=None,
        help="Pretrained meta checkpoint file.")),
    (("--ckpt",), dict(
        type=str, default=None,
        help="Checkpoint directory or file for training or evaluation.")),
    (("--test",), dict(
        action="store_true", default=False,
        help="Run visual evaluation.")),
    (("--seed",), dict(
        type=int, default=42,
        help="Random seed.")),
    (("--device",), dict(
        type=int, default=0,
        help="ID of the target GPU device for model running.")),
):
    parser.add_argument(*_flags, **_options)
del _flags, _options
settings = parser.parse_args()
# Reproducibility setup: seed every RNG the script touches and ask PyTorch
# for deterministic kernels.
os.environ.update({
    # Fixed cuBLAS workspace configuration, set before any CUDA work runs
    # (see the PyTorch reproducibility notes for why deterministic cuBLAS
    # needs it).
    "CUBLAS_WORKSPACE_CONFIG": ":4096:8",
    # NOTE(review): PYTHONHASHSEED is only read when an interpreter starts, so
    # this affects child processes, not str hashing in the current process.
    "PYTHONHASHSEED": str(settings.seed),
})
for _seed_fn in (
    np.random.seed,
    random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(settings.seed)
del _seed_fn
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)
def test(env, model):
    """Run ``model`` in ``env`` until ``env.request_quit`` becomes true.

    Each iteration fetches observations via ``env.reset_done()``, optionally
    augments them with map embeddings, queries the policy and steps the
    environment.

    Args:
        env: Environment exposing ``reset()``, ``reset_done()`` returning
            ``(obs, info)``, ``step(actions)`` and a ``request_quit`` flag.
        model: Actor-critic model. ``model.act(obs, seq_len - 1)`` produces
            the actions; when ``info`` contains ``"map"``,
            ``model.critic.map`` and ``model.actor.map`` are applied to it.
    """
    model.eval()
    env.reset()
    # Pure evaluation: disable autograd so the per-step forward passes
    # (map encoders and policy) do not build unused computation graphs.
    with torch.no_grad():
        while not env.request_quit:
            obs, info = env.reset_done()
            seq_len = info["ob_seq_lens"]
            if "map" in info:
                m = info["map"]
                # The critic's map embedding is appended to the observation
                # along the last dim; the actor's embedding is stored on
                # model.actor.g for use inside model.act.
                obs = torch.cat((obs, model.critic.map(m)), -1)
                model.actor.g = model.actor.map(m)
            actions = model.act(obs, seq_len - 1)
            env.step(actions)
if __name__ == "__main__":
spec = importlib.util.spec_from_file_location("config", settings.config)
config = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config)
if hasattr(config, "discriminators"):
discriminators = {
name: env.DiscriminatorConfig(**prop)
for name, prop in config.discriminators.items()
}
else:
discriminators = {"_/full": env.DiscriminatorConfig()}
env_cls = getattr(env, config.env_cls)
if not hasattr(config, "env_params"):
setattr(config, "env_params", {})
env = env_cls(1,
discriminators=discriminators,
compute_device=settings.device,
**config.env_params
)
env.episode_length = 500000
map_dim = 256 if hasattr(env, "info") and "map" in env.info else 0
value_dim = len(env.discriminators)+env.rew_dim
model = ACModel(env.state_dim, env.act_dim, env.goal_dim, value_dim, meta_goal_dim=map_dim)
discriminators = torch.nn.ModuleDict({
name: Discriminator(dim) for name, dim in env.disc_dim.items()
})
device = torch.device(settings.device)
model.to(device)
discriminators.to(device)
if settings.meta is not None and os.path.exists(settings.meta):
if os.path.isdir(settings.meta):
ckpt = os.path.join(settings.meta, "ckpt")
else:
ckpt = settings.meta
settings.meta = os.path.dirname(ckpt)
if os.path.exists(ckpt):
print("Load meta-model from {}".format(ckpt))
state_dict = torch.load(ckpt, map_location=device)
pretrained = dict()
for k, p in state_dict["model"].items():
if "actor" in k or "actor_ob_normalizer" in k:
pretrained[k] = p
model.load_state_dict(pretrained, strict=False)
model.discriminators = discriminators
model.actor = AdaptNet(model, g_dim=map_dim)
if map_dim:
model.critic.map = MapCNN()
model.actor.map = MapCNN()
model.to(device)
if settings.ckpt is not None and os.path.exists(settings.ckpt):
if os.path.isdir(settings.ckpt):
ckpt = os.path.join(settings.ckpt, "ckpt")
else:
ckpt = settings.ckpt
settings.ckpt = os.path.dirname(ckpt)
if os.path.exists(ckpt):
print("Load model from {}".format(ckpt))
state_dict = torch.load(ckpt, map_location=torch.device(settings.device))
model.load_state_dict(state_dict["model"])
env.render()
test(env, model)