-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
202 lines (173 loc) · 7.74 KB
/
train.py
File metadata and controls
202 lines (173 loc) · 7.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
"""使用收集到数据进行训练"""
import random
from collections import defaultdict, deque
import numpy as np
import pickle
import time
import os
from go import Game, Board
from mcts import MCTSPlayer
# from mcts_pure import MCTS_Pure
from pytorch_net import PVnet
from random import sample
import copy
from go import width,height
def main():
    """Build a pipeline that resumes from the current policy file and run it."""
    TrainPipeline(init_model='model/current_policy.pkl').run()
# Defines the whole training pipeline.
class TrainPipeline:
    """Load collected samples, train the policy-value net, then evaluate it.

    NOTE(review): the copy of this file that was reviewed had lost all of its
    indentation. The nesting used here (what sits inside the loops of ``run``
    and ``policy_updata``) is the most plausible reading; confirm it against
    the original.
    """

    def __init__(self, init_model=None):
        """Set the hyper-parameters and load the starting network.

        Args:
            init_model: Path of a saved model to resume from. If it cannot be
                loaded, or if no path is given, the network is loaded from
                ``model/iters_0.pkl`` instead.
        """
        # Training hyper-parameters.
        self.n_playout = 100
        self.c_puct = 5
        self.learn_rate = 1e-3
        self.temp = 1.0
        self.batch_size = 512
        self.epochs = 3
        # The next three values are kept for compatibility; nothing in this
        # class reads them.
        self.check_freq = 100  # how often to save the model
        self.game_batch_num = 30000  # number of training updates
        self.best_win_ratio = 0.0
        # self.pure_mcts_playout_num = 500
        self.buffer_size = 1000000
        self.data_buffer = deque(maxlen=self.buffer_size)

        self.policy_value_net = None
        if init_model:
            try:
                self.policy_value_net = PVnet(model_file=init_model)
                print('已加载上次最终模型')
            except Exception:
                # Could not resume; fall through to the baseline snapshot.
                print('模型不存在,从零开始训练')
        if self.policy_value_net is None:
            # Previously this attribute stayed undefined when no init_model
            # was passed, which only surfaced later as an AttributeError.
            self.policy_value_net = PVnet(model_file=self._snapshot_path(0))

    @staticmethod
    def _snapshot_path(index):
        """Return the path of the numbered model snapshot ``index``."""
        return "model/" + 'iters' + '_' + str(index) + ".pkl"

    @staticmethod
    def _explained_variance(targets, predictions):
        """Return ``1 - Var(targets - predictions) / Var(targets)``.

        The result is ``nan`` or ``-inf`` when every target is identical
        (zero variance); it is only used for logging.
        """
        return 1 - np.var(targets - predictions.flatten()) / np.var(targets)

    @staticmethod
    def _rotated_samples(planes, prob_grid, pass_prob, winner):
        """Return the 90/180/270-degree rotations of one training sample."""
        return [
            (
                np.rot90(planes, turns, axes=(1, 2)),
                np.append(np.rot90(prob_grid, turns).flatten(), pass_prob),
                winner,
            )
            for turns in (1, 2, 3)
        ]

    def policy_evaluate(self, n_games=10, win_threshold=0.8):
        """Play the current net against the newest numbered snapshot.

        The newest snapshot is the highest ``k`` for which the consecutive
        files ``model/iters_0.pkl`` .. ``model/iters_k.pkl`` all exist. If the
        current net reaches ``win_threshold``, it is saved both as
        ``model/current_policy.pkl`` and as snapshot ``k + 1``.

        Args:
            n_games: Number of evaluation games to play (at least 1).
            win_threshold: Minimum win ratio required to promote the net.

        Returns:
            The fraction of games for which ``play_against`` returned 1.

        Raises:
            ValueError: If ``n_games`` is smaller than 1.
            FileNotFoundError: If ``model/iters_0.pkl`` does not exist.
        """
        if n_games < 1:
            raise ValueError("n_games must be at least 1")

        # Find the first missing snapshot index, then step back one.
        self.iters = 0
        while os.path.exists(self._snapshot_path(self.iters)):
            self.iters += 1
        self.iters -= 1
        if self.iters < 0:
            raise FileNotFoundError(
                "no model snapshot found at " + self._snapshot_path(0))
        self.best_net = PVnet(model_file=self._snapshot_path(self.iters))
        print('已加载历史最优模型', str(self.iters))

        current_mcts_player = MCTSPlayer(self.policy_value_net.pvnet_fn,
                                         c_puct=self.c_puct,
                                         n_playout=self.n_playout)
        best_mcts_player = MCTSPlayer(self.best_net.pvnet_fn,
                                      c_puct=self.c_puct,
                                      n_playout=self.n_playout)
        # pure_mcts_player = MCTS_Pure(c_puct=5,
        #                              n_playout=self.pure_mcts_playout_num)
        win_c = 0
        for _ in range(n_games):
            board = Board()
            game = Game()
            # presumably a return value of 1 means the first player (the
            # current net) won - verify against Game.play_against
            winner = game.play_against(current_mcts_player, best_mcts_player,
                                       board, self.iters + 1, self.iters)
            if winner == 1:
                win_c += 1
        win_ratio = win_c / n_games
        print("num_playouts:{}, win: {}".format(
            n_games,
            win_c))
        if win_ratio >= win_threshold:
            self.policy_value_net.save_model('model/current_policy.pkl')
            self.policy_value_net.save_model(self._snapshot_path(self.iters + 1))
            print('模型更新成功!')
        else:
            print('失败,继续训练')
        return win_ratio

    def policy_updata(self):
        """Train on the oldest ``batch_size`` samples and drop them.

        The batch is removed from the left of ``self.data_buffer`` once and
        then used for ``self.epochs`` training steps. This gives the same
        result as the earlier approach of deep-copying the whole buffer each
        epoch, without the copies.

        Returns:
            The loss reported by the last training step, or ``None`` when
            ``self.epochs`` is 0.

        Raises:
            ValueError: If the buffer holds fewer than ``batch_size`` samples;
                the buffer is left untouched in that case.
        """
        if len(self.data_buffer) < self.batch_size:
            raise ValueError(
                "data buffer holds fewer samples than batch_size")
        mini_batch = [self.data_buffer.popleft()
                      for _ in range(self.batch_size)]

        loss = None
        for _ in range(self.epochs):
            # Reorder the batch for this epoch.
            shuffled = sample(mini_batch, self.batch_size)
            state_batch = np.array(
                [item[0] for item in shuffled]).astype('float32')
            mcts_probs_batch = np.array(
                [item[1] for item in shuffled]).astype('float32')
            winner_batch = np.array(
                [item[2] for item in shuffled]).astype('float32')

            # Value predictions before the update.
            _, old_v = self.policy_value_net.policy_value(state_batch)
            loss = self.policy_value_net.train_step(
                state_batch,
                mcts_probs_batch,
                winner_batch,
                self.learn_rate
            )
            # Value predictions after the update.
            _, new_v = self.policy_value_net.policy_value(state_batch)

            explained_var_old = self._explained_variance(winner_batch, old_v)
            explained_var_new = self._explained_variance(winner_batch, new_v)
            print((
                "loss:{},"
                "explained_var_old:{:.9f},"
                "explained_var_new:{:.9f}"
            ).format(loss,
                     explained_var_old,
                     explained_var_new))
        return loss

    def get_equi_data(self, play_data):
        """Augment samples eightfold with rotations and a transpose.

        Args:
            play_data: Iterable of ``(state, mcts_prob, winner)``.
                ``mcts_prob`` holds ``width * height`` board entries followed
                by one extra entry, which is appended unchanged to every
                variant.
                assumes state is an array shaped (planes, rows, cols) - TODO
                confirm

        Returns:
            A list with eight samples per input sample, in this order: the
            original, its three rotations, the transposed sample and its
            three rotations.
        """
        board_cells = width * height
        extend_data = []
        for state, mcts_prob, winner in play_data:
            pass_prob = mcts_prob[board_cells]
            # TODO: verify this reshape order. The rotations only keep board
            # and probabilities aligned if both use the same layout
            # (presumably a square board).
            prob_grid = np.array(mcts_prob[0:board_cells]).reshape(
                (width, height))

            extend_data.append((state, mcts_prob, winner))
            extend_data.extend(
                self._rotated_samples(state, prob_grid, pass_prob, winner))

            flipped_state = state.transpose([0, 2, 1])
            flipped_grid = prob_grid.transpose()
            extend_data.append(
                (flipped_state,
                 np.append(flipped_grid.flatten(), pass_prob),
                 winner))
            extend_data.extend(
                self._rotated_samples(
                    flipped_state, flipped_grid, pass_prob, winner))
        return extend_data

    def run(self):
        """Wait for a data file, train on it, then evaluate the result.

        Polls ``data/train_data_buffer.pkl`` every 30 seconds until it can be
        read, deletes it, augments the samples, trains until at most
        ``batch_size`` samples remain (saving the model after every update),
        and finally calls ``policy_evaluate``.
        """
        data_path = 'data/train_data_buffer.pkl'
        while True:
            try:
                # SECURITY: pickle can execute arbitrary code while loading;
                # this file must only ever come from a trusted producer.
                with open(data_path, 'rb') as data_dict:
                    data_file = pickle.load(data_dict)
                # Read both keys before touching any state so that a failed
                # attempt cannot leave a half-applied update behind.
                new_samples = data_file['data_buffer']
                new_iters = data_file['iters']
            except Exception:
                # Only reading is retried; the file may be missing or still
                # being written.
                print('读取训练数据失败,等待')
                time.sleep(30)
                continue
            break

        self.data_buffer.extend(new_samples)  # merge
        self.iters = new_iters
        del data_file
        print('已载入数据')
        os.remove(data_path)

        # Augment the data. Shuffle a list copy: shuffling the deque in place
        # is quadratic because indexing into its middle is O(n).
        # NOTE(review): the eight variants of one position stay adjacent in
        # the buffer because the shuffle happens before augmentation.
        play_data = list(self.data_buffer)
        random.shuffle(play_data)
        augmented = self.get_equi_data(play_data)
        self.data_buffer = deque(augmented, maxlen=self.buffer_size)
        print('数据扩充完毕')

        print('step i {}: '.format(self.iters))
        while len(self.data_buffer) > self.batch_size:
            print(len(self.data_buffer))
            self.policy_updata()
            # Save the model.
            self.policy_value_net.save_model('model/current_policy.pkl')
            print('模型保存')
        print('开始进行模型检验')
        self.policy_evaluate()
# Script entry point: run the training pipeline only when executed directly.
if __name__ == '__main__':
    main()