agent.py
import json

import numpy as np

from dqn import DQN
from replay_buffer import ReplayBuffer


class Agent:
    """Epsilon-greedy DQN agent with experience replay and a target network."""

    def __init__(
        self,
        alpha,
        gamma,
        num_actions,
        epsilon,
        batch_size,
        input_shape,
        model,
        model_file,
        metric_file,
        epsilon_decay=0.999,
        epsilon_min=0.01,
        copy_period=300,
        mem_size=250,
    ):
        self.action_space = np.arange(num_actions)
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min
        self.batch_size = batch_size
        self.model_file = model_file
        self.metric_file = metric_file
        self.copy_period = copy_period
        self.memory = ReplayBuffer(mem_size, input_shape, num_actions, discrete=True)
        self.dqn = DQN(model, learning_rate=alpha)
        self.target_dqn = self.dqn.create_target_network()
        self.learn_counter = 0
        self.metrics = {}
        self.temp = {}

    def remember(self, state, action, reward, next_state, done):
        self.memory.store_transition(state, action, reward, next_state, done)

    def choose_action(self, state):
        # Epsilon-greedy: explore with probability epsilon, otherwise act greedily.
        state = state[np.newaxis, :]
        if np.random.random() < self.epsilon:
            return np.random.choice(self.action_space)
        return np.argmax(self.dqn.get_model().predict(state))

    def learn(self):
        self.learn_counter += 1
        # Wait until the replay buffer holds at least one full batch.
        if self.memory.mem_counter < self.batch_size:
            return
        states, actions, rewards, next_states, terminals = self.memory.sample_buffer(
            self.batch_size
        )
        qs = self.dqn.get_model().predict(states)
        qs_next = self.target_dqn.get_model().predict(next_states)
        batch_index = np.arange(self.batch_size)
        # Bellman target for the actions actually taken. `terminals` is assumed
        # to be the non-terminal mask (the ReplayBuffer storing 1 - done), so
        # terminal transitions keep only the immediate reward.
        qs_target = qs.copy()
        qs_target[batch_index, actions] = (
            rewards + self.gamma * np.max(qs_next, axis=1) * terminals
        )
        self.dqn.get_model().fit(states, qs_target, verbose=0)
        # Periodically sync the target network with the online network.
        if self.learn_counter % self.copy_period == 0:
            self.target_dqn.copy_from(self.dqn)
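
    # Worked example of the target above (illustrative numbers, an assumption
    # for exposition): with gamma = 0.99, reward = 1.0, and
    # max_a Q_target(s', a) = 2.0, a non-terminal step (mask = 1) yields a
    # target of 1.0 + 0.99 * 2.0 = 2.98, while a terminal step (mask = 0)
    # yields just the reward, 1.0.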

    def finished_episode(self):
        # Decay exploration at the end of each episode, never below epsilon_min.
        self.epsilon = max(self.epsilon * self.epsilon_decay, self.epsilon_min)

    def save_model(self, model_file=None):
        if not model_file:
            model_file = self.model_file
        self.dqn.save_model(model_file)

    def load_model(self, model_file=None):
        if not model_file:
            model_file = self.model_file
        self.dqn.load_model(model_file)
        # Keep the target network consistent with the freshly loaded weights.
        self.target_dqn.copy_from(self.dqn)

    def save_data(self, metric_file=None):
        if not metric_file:
            metric_file = self.metric_file
        with open(metric_file, "w") as out_file:
            json.dump(self.metrics, out_file)

    def load_data(self, metric_file=None):
        if not metric_file:
            metric_file = self.metric_file
        with open(metric_file) as json_file:
            self.metrics = json.load(json_file)
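

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original file: it assumes a
    # classic Gym API (env.step returning a 4-tuple, gym < 0.26), a Keras
    # model compatible with the DQN wrapper above, and illustrative file
    # names. The project's real entry point may differ.
    import gym
    from tensorflow import keras

    env = gym.make("CartPole-v1")
    num_actions = env.action_space.n
    input_shape = env.observation_space.shape
    model = keras.Sequential(
        [
            keras.layers.Dense(64, activation="relu", input_shape=input_shape),
            keras.layers.Dense(64, activation="relu"),
            keras.layers.Dense(num_actions, activation="linear"),
        ]
    )
    agent = Agent(
        alpha=0.001,
        gamma=0.99,
        num_actions=num_actions,
        epsilon=1.0,
        batch_size=32,
        input_shape=input_shape,
        model=model,
        model_file="dqn_model.h5",   # hypothetical path
        metric_file="metrics.json",  # hypothetical path
    )
    for episode in range(10):
        state, done = env.reset(), False
        while not done:
            action = agent.choose_action(state)
            next_state, reward, done, _ = env.step(action)
            agent.remember(state, action, reward, next_state, done)
            agent.learn()
            state = next_state
        agent.finished_episode()
    agent.save_model()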