-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_net_with_model.py
More file actions
130 lines (114 loc) · 3.85 KB
/
Copy pathtest_net_with_model.py
File metadata and controls
130 lines (114 loc) · 3.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from wrapper import RelativePosition, FlattenDict, SerializeAction
from gymnasium.envs.registration import register
import gymnasium as gym
import train_params_with_model as params
import modified_DDPG
import OurDDPG
import numpy as np
from loguru import logger
# Environment configuration pulled from the training parameters module.
size, relay_config, client_config, init_config = (
    params.size,
    params.relay_config,
    params.client_config,
    params.init_config,
)
is_polar = False
# Register GridWorld with gymnasium so gym.make("GridWorld-v0") can build it.
# max_episode_steps makes gymnasium wrap the env in a time limit that
# truncates episodes after 500 steps; `kwargs` are forwarded to the
# GridWorldEnv constructor named by entry_point.
register(
    id="GridWorld-v0",
    entry_point="grid_world:GridWorldEnv",
    max_episode_steps=500,
    kwargs=dict(
        size=size,
        relay_config=relay_config,
        client_config=client_config,
        init_config=init_config,
        is_polar=is_polar,
        is_plot=False,
        is_log=False,
        use_model=True,
    ),
)
def get_env():
    """Create the registered GridWorld env wrapped for the DDPG agents.

    Wrappers are applied innermost-first: RelativePosition, then FlattenDict,
    then SerializeAction (configured with the module-level ``is_polar``).

    Returns:
        The outermost ``SerializeAction`` wrapper.
    """
    wrapped = gym.make("GridWorld-v0")
    for wrapper_cls in (RelativePosition, FlattenDict):
        wrapped = wrapper_cls(wrapped)
    return SerializeAction(wrapped, is_polar=is_polar)
# create the environment
# Single module-level env shared by the policy loaders, set_seed, eval and
# RandomPolicy defined in this script.
env = get_env()
def load_modified_DDPG(
    model_path="models/modified_DDPG_GridWorld-v0_with_model_0_2024-10-13_22-50-42",
):
    """Build a ``modified_DDPG.DDPG`` agent sized for ``env`` and load weights.

    State/action sizes and ``max_action`` are read from the module-level
    ``env``. The agent-specific extras (position range, relay/client dims,
    speed) come from the ``size``, ``relay_config`` and ``client_config``
    globals.

    Args:
        model_path: passed unchanged to ``policy.load``. The default is the
            checkpoint this script was originally hard-wired to, so existing
            ``load_modified_DDPG()`` calls behave exactly as before.

    Returns:
        The ``modified_DDPG.DDPG`` instance with the checkpoint loaded.
    """
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    # Assumes a symmetric Box action space; only the first bound is used.
    max_action = float(env.action_space.high[0])
    kwargs = {
        "state_dim": state_dim,
        "action_dim": action_dim,
        "max_action": max_action,
        # NOTE(review): presumably must match the training hyper-parameters.
        "discount": 0.5,
        "tau": 0.005,
        "position_range": {
            "position": [-size / 2, size / 2],
            "height": [relay_config.min_height, relay_config.max_height],
        },
        # Presumably 3 values per relay (x, y, height) and 2 per client
        # (x, y) -- confirm against modified_DDPG.
        "relay_dim": relay_config.num * 3,
        "client_dim": client_config.num * 2,
        "speed": relay_config.speed,
    }
    policy = modified_DDPG.DDPG(**kwargs)
    policy.load(model_path)
    return policy
def load_OurDDPG(
    model_path="models/OurDDPG_GridWorld-v0_with_model_0_2024-11-06_20-35-13",
):
    """Build an ``OurDDPG.DDPG`` agent sized for ``env`` and load weights.

    Args:
        model_path: passed unchanged to ``policy.load``. The default is the
            checkpoint this script was originally hard-wired to, so existing
            ``load_OurDDPG()`` calls behave exactly as before.

    Returns:
        The ``OurDDPG.DDPG`` instance with the checkpoint loaded.
    """
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    # Assumes a symmetric Box action space; only the first bound is used.
    max_action = float(env.action_space.high[0])
    kwargs = {
        "state_dim": state_dim,
        "action_dim": action_dim,
        "max_action": max_action,
        # NOTE(review): presumably must match the training hyper-parameters.
        "discount": 0.5,
        "tau": 0.005,
    }
    policy = OurDDPG.DDPG(**kwargs)
    policy.load(model_path)
    return policy
def set_seed(seed):
    """Seed the evaluation RNGs and reset ``env`` to a reproducible start.

    Seeds, in order, ``env.action_space`` (which drives
    ``action_space.sample()``) and NumPy's global RNG. It then calls
    ``env.reset(seed=seed)``.

    NOTE(review): no torch/other-framework RNG is seeded here; confirm the
    loaded policies' ``select_action`` is deterministic.

    Args:
        seed: seed value applied to each RNG above and to ``env.reset``.

    Returns:
        The ``(observation, info)`` pair produced by ``env.reset``.
    """
    for seed_rng in (env.action_space.seed, np.random.seed):
        seed_rng(seed)
    observation, reset_info = env.reset(seed=seed)
    return observation, reset_info
# Base seed (eval() seeds episode i with seed + 100 + i) and the number of
# evaluation episodes run per policy.
seed, eval_episodes = 0, 100
def eval(policy_name, policy):
    """Evaluate ``policy`` over ``eval_episodes`` seeded episodes and save rewards.

    Note: this module-level function shadows the builtin ``eval``. The name is
    kept because the script calls it as ``eval(name, policy)``.

    Episode ``i`` is started via ``set_seed(seed + 100 + i)``, so every policy
    is evaluated on the same sequence of seeded resets.

    Args:
        policy_name: label used in the log header and in the output file names.
        policy: object exposing ``select_action(state)``; it receives the
            observation wrapped in ``np.array``.

    Side effects:
        Logs the per-episode returns and their count, sum and average.
        Writes ``eval/{policy_name}_reward.txt`` (one episode return per
        line) and ``eval/{policy_name}_step_reward.txt`` (one comma-separated
        row of per-step rewards per episode). Creates the ``eval`` directory
        if it does not exist.
    """
    import os

    reward_list = []
    step_reward_list = []
    for eval_index in range(eval_episodes):
        state, _ = set_seed(seed + 100 + eval_index)
        done = False
        current_reward = 0
        current_step_reward = []
        while not done:
            action = policy.select_action(np.array(state))
            next_state, reward, terminated, truncated, _ = env.step(action)
            # Gymnasium's step() returns (obs, reward, terminated, truncated,
            # info). The episode is over on EITHER flag; checking only
            # ``truncated`` would keep stepping past a terminal state.
            done = terminated or truncated
            current_reward += reward
            current_step_reward.append(reward)
            state = next_state
        step_reward_list.append(current_step_reward)
        reward_list.append(current_reward)

    reward_sum = sum(reward_list)
    # Guard against eval_episodes == 0 instead of raising ZeroDivisionError.
    reward_average = reward_sum / len(reward_list) if reward_list else float("nan")
    logger.info(f"-- {policy_name} --")
    logger.info(f"reward: {reward_list}")
    logger.info(f"reward length: {len(reward_list)}")
    logger.info(f"reward sum: {reward_sum}")
    logger.info(f"reward average: {reward_average}")

    # np.savetxt does not create missing directories.
    os.makedirs("eval", exist_ok=True)
    np.savetxt(f"eval/{policy_name}_reward.txt", reward_list, fmt="%.3f")
    # Episodes can end at different lengths, and np.savetxt rejects ragged
    # input. Write one row per episode instead, using the same "%.5f" values
    # and "," delimiter that np.savetxt produced for equal-length rows.
    with open(f"eval/{policy_name}_step_reward.txt", "w", encoding="utf-8") as f:
        for episode_rewards in step_reward_list:
            f.write(",".join(f"{r:.5f}" for r in episode_rewards) + "\n")
# use modified_DDPG to see the performance
# These evaluations run at import time. `eval` here is the function defined
# above (it shadows the builtin), which logs results and writes under eval/.
policy = load_modified_DDPG()
eval("modified_DDPG", policy)
# use OurDDPG to see the performance
policy = load_OurDDPG()
eval("OurDDPG", policy)
# use random to see the performance
class RandomPolicy:
    """Baseline policy that ignores the observation and samples a random action.

    Args:
        action_space: object with a ``sample()`` method to draw actions from.
            The default ``None`` means the module-level ``env.action_space``
            is looked up at call time (the original behavior). ``set_seed``
            only seeds ``env.action_space``, so a custom space must be seeded
            by the caller for reproducible runs.
    """

    def __init__(self, action_space=None):
        self.action_space = action_space

    def select_action(self, state):
        """Return one sampled action.

        ``state`` is accepted only so this class matches the DDPG policies'
        ``select_action(state)`` interface; it is ignored.
        """
        space = env.action_space if self.action_space is None else self.action_space
        return space.sample()
# Random baseline: actions come from env.action_space.sample(), which
# set_seed() re-seeds before every episode, so this run is reproducible.
policy = RandomPolicy()
eval("Random", policy)