main.py
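"""Entry point for training or evaluating an agent on one of the bundled environments.

Parses command-line flags, builds the environment via envs.make and the agent via
agents.make, then runs an episode loop that reports a running average every
--save_rate episodes, saves the best-scoring model, and appends per-episode stats
to logs/<env>/<agent>.csv.
"""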
import os
import sys
sys.path.append(".")  # ensure the local 'agents' and 'envs' packages are importable

import csv
import time
import random
from argparse import ArgumentParser

import numpy as np

import agents
import envs

if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--env', type=str, choices=envs.__all__, default='snake_v2')
    parser.add_argument('-a', '--agent', choices=agents.AGENT_MAP.keys(), default='randomly')
    parser.add_argument('-l', '--load_model', action='store_true')
    parser.add_argument('--test', action='store_true')
    parser.add_argument('-r', '--render', action='store_true')
    parser.add_argument('-v', '--verbose', action='store_true')
    # --dense uses store_false: sparse_reward defaults to True, and passing
    # --dense switches the environment to dense rewards.
    parser.add_argument('--dense', action='store_false')
    parser.add_argument('--feature', action='store_true')
    parser.add_argument('--difficulty', type=int, default=1)
    parser.add_argument('--delay', type=float, default=0.)
    parser.add_argument('--episode', type=int, default=int(1e8))
    parser.add_argument('--resize', type=int, default=84)
    parser.add_argument('--horizon', type=int, default=64)
    parser.add_argument('--update_rate', type=int, default=1024)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--seqlen', type=int, default=4)
    parser.add_argument('--actor_units', type=int, nargs='*', default=[100] * 2)
    parser.add_argument('--critic_units', type=int, nargs='*', default=[100] * 2)
    parser.add_argument('--save_rate', type=int, default=100)
    args = parser.parse_args()
    print(args)
    # Output locations keyed by environment and agent name.
    weight_path = 'weights/%s/%s' % (args.env, args.agent)
    log_path = 'logs/%s' % args.env
    log_file = os.path.join(log_path, '%s.csv' % args.agent)
    if not os.path.exists(weight_path):
        os.makedirs(weight_path)
    if not os.path.exists(log_path):
        os.makedirs(log_path)
    # Replace the environment name in args with the environment instance;
    # the agent receives it via **vars(args).
    args.env = envs.make(
        args.env,
        sparse_reward=args.dense,
        use_feature=args.feature,
        difficulty=args.difficulty
    )
    agent = agents.make(args.agent, **vars(args))
    if args.load_model:
        agent.load_model(weight_path)
    # Accumulators for the running averages reported every --save_rate episodes.
    best_score = 1.
    stats = []
    score, true_score = 0., 0.
    step = 0.
    for episode in range(1, args.episode + 1):
        stat = agent.play(args.render, args.verbose, args.delay, episode, args.test, False)
        stats.append(stat)
        score += stat['score']
        step += stat['step']
        true_score += stat['true_score']
        print('[E%dT%d] Score: %.2f (%.2f)\t\t' % (episode, stat['step'], stat['score'], stat['true_score']), end='\r')
        if episode % args.save_rate == 0 and not args.test:
            # Average the accumulated stats over the last save window.
            score /= args.save_rate
            step /= args.save_rate
            true_score /= args.save_rate
            print('Ep%d' % episode, 'Score:', score, '(%.2f)' % true_score, 'Step:', step, '\t\t', flush=True)
            # Save the model whenever the windowed average score improves.
            if best_score < score:
                best_score = score
                print('New Best Score...! ', best_score)
                agent.save_model(weight_path)
            # Append the buffered per-episode stats to the CSV log.
            with open(log_file, 'a', newline='') as f:
                writer = csv.writer(f)
                for row in stats:
                    writer.writerow(list(row.values()))
            # Reset the window accumulators.
            score = 0.
            step = 0.
            true_score = 0.
            stats.clear()
    args.env.close()
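
# Example invocations (flag names come from the parser above; 'snake_v2' and
# 'randomly' are the parser defaults, other env/agent names depend on what
# envs.__all__ and agents.AGENT_MAP register in this repo):
#   python main.py --env snake_v2 -a randomly --render -v
#   python main.py --env snake_v2 -a randomly --load_model --test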