-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtrainer.py
78 lines (57 loc) · 2.25 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import tensorflow as tf
import gym
from A3C_Network import A3C_Network
from Worker import Worker
from config import *
from Summary import *
import threading
from time import sleep
# Build the A3C graph, start one Worker thread per args.num_thread, and train
# until every worker exits or the user presses Ctrl+C. Environments created
# here are always closed on the way out, whether training ends normally,
# is interrupted, or setup fails part-way.
env_list = []
try:
    tf.reset_default_graph()
    # Episode counters shared by all workers; excluded from optimization.
    global_episodes = tf.Variable(0, dtype=tf.int32, name='global_episodes', trainable=False)
    # NOTE(review): this variable is ALSO named 'global_episodes', so TF
    # uniquifies it (e.g. 'global_episodes_1') in the graph and checkpoints.
    # 'training_episodes' looks like the intended name, but renaming it would
    # break restoring checkpoints saved by earlier runs, so it is left as-is.
    training_episodes = tf.Variable(0, dtype=tf.int32, name='global_episodes', trainable=False)
    n_threads = args.num_thread

    # Probe the size of the discrete action space with a throwaway environment.
    Env = gym.make(args.environment)
    no_action = Env.action_space.n
    Env.close()

    # Learning rate decays from args.learning_rate down to 10% of it over
    # args.decay_steps steps of global_episodes.
    learning_rate = tf.train.polynomial_decay(args.learning_rate, global_episodes,
                                              args.decay_steps, args.learning_rate * 0.1)

    # One summary writer per worker, each in its own sub-directory.
    summary_writer = [tf.summary.FileWriter(args.summary_dir + '/worker_' + str(worker_id))
                      for worker_id in range(n_threads)]
    summary_parameters = Summary_Parameters()
    write_op = tf.summary.merge_all()
    master_network = A3C_Network(args, no_action, 'master_network')

    workers = []
    for worker_id in range(n_threads):
        env = gym.make(args.environment)
        if worker_id == 0:
            # Only the first worker's environment is wrapped in a gym Monitor
            # (output directory "monitors").
            env = gym.wrappers.Monitor(env, "monitors", force=True)
        # Register the env before building the Worker so the finally clause
        # closes it even if Worker construction fails.
        env_list.append(env)
        workers.append(Worker(global_episodes, training_episodes, master_network, worker_id,
                              learning_rate, env, summary_writer[worker_id],
                              summary_parameters, write_op, args))

    with tf.Session() as sess:
        saver = tf.train.Saver(max_to_keep=5)
        master_network.load_model(sess, saver)
        coord = tf.train.Coordinator()
        thread_list = []
        # Ctrl+C is handled *inside* the session context so the session stays
        # open until every worker thread has been joined.
        try:
            for worker in workers:
                t = threading.Thread(target=worker.process, args=(sess, coord, saver))
                t.start()
                # Track the thread immediately so it is joined even if the
                # interrupt lands during the sleep below.
                thread_list.append(t)
                # Stagger worker start-up (presumably so workers do not all
                # hit the master network at once — TODO confirm intent).
                sleep(0.5)
            print("Ctrl + C to close")
            # Blocks until all workers exit or a stop is requested.
            coord.join(thread_list)
        except KeyboardInterrupt:
            print("Closing threads")
            coord.request_stop()
            # NOTE(review): assumes Worker.process polls coord.should_stop();
            # if it does not, coord.join raises RuntimeError once its grace
            # period expires.
            coord.join(thread_list)
finally:
    print("Closing environments")
    for env in env_list:
        env.close()