-
Notifications
You must be signed in to change notification settings - Fork 44
/
Copy pathtrain.py
executable file
·115 lines (87 loc) · 2.69 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#!/usr/bin/env python3
""" Front-end script for training a Snake agent. """
import json
import sys
from keras.models import Sequential
from keras.layers import *
from keras.optimizers import *
from snakeai.agent import DeepQNetworkAgent
from snakeai.gameplay.environment import Environment
from snakeai.utils.cli import HelpOnFailArgumentParser
def parse_command_line_args(args):
    """
    Parse command-line arguments and organize them into a single structured object.

    Args:
        args: the list of argument strings to parse (``main`` passes ``sys.argv[1:]``).

    Returns:
        The object produced by the parser's ``parse_args``, exposing
        ``level`` (str) and ``num_episodes`` (int) attributes.
    """
    parser = HelpOnFailArgumentParser(
        description='Snake AI training client.',
        epilog='Example: train.py --level 10x10.json --num-episodes 30000'
    )
    parser.add_argument(
        '--level',
        required=True,
        type=str,
        help='JSON file containing a level definition.',
    )
    # A required option never falls back to its default, so combining
    # required=True with default=30000 left the default unreachable.
    # Keeping the option optional makes the default meaningful; any command
    # line that passes --num-episodes explicitly parses exactly as before.
    parser.add_argument(
        '--num-episodes',
        required=False,
        type=int,
        default=30000,
        help='The number of episodes to run consecutively.',
    )
    return parser.parse_args(args)
def create_snake_environment(level_filename):
    """
    Create a new Snake environment from a JSON level definition file.

    Args:
        level_filename: path to the JSON file describing the level.

    Returns:
        An ``Environment`` constructed from the parsed config with ``verbose=1``.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file contents are not valid JSON.
    """
    # JSON text is UTF-8 by specification; be explicit so the platform's
    # locale-dependent default encoding cannot change how the level is read.
    with open(level_filename, encoding='utf-8') as cfg:
        env_config = json.load(cfg)
    return Environment(config=env_config, verbose=1)
def create_dqn_model(env, num_last_frames):
    """
    Assemble and compile the convolutional network used by the DQN agent.

    The input is a channels-first stack of the most recent frames; the output
    layer has one unit per action available in the environment.

    Args:
        env: an instance of Snake environment; its ``observation_shape`` and
            ``num_actions`` determine the input and output sizes.
        num_last_frames: the number of last frames the agent considers as state;
            it becomes the channel dimension of the network input.

    Returns:
        A compiled Keras ``Sequential`` model (RMSprop optimizer, 'MSE' loss).
    """
    frame_stack_shape = (num_last_frames, ) + env.observation_shape

    stack = [
        # Two 3x3 convolution stages over the channels-first frame stack.
        Conv2D(
            16,
            kernel_size=(3, 3),
            strides=(1, 1),
            data_format='channels_first',
            input_shape=frame_stack_shape,
        ),
        Activation('relu'),
        Conv2D(
            32,
            kernel_size=(3, 3),
            strides=(1, 1),
            data_format='channels_first',
        ),
        Activation('relu'),
        # Fully-connected head ending in one output per action.
        Flatten(),
        Dense(256),
        Activation('relu'),
        Dense(env.num_actions),
    ]

    model = Sequential()
    for layer in stack:
        model.add(layer)

    model.summary()
    model.compile(RMSprop(), 'MSE')
    return model
def main():
    """Parse CLI arguments, build the environment and model, then train the agent."""
    cli = parse_command_line_args(sys.argv[1:])
    environment = create_snake_environment(cli.level)
    dqn_model = create_dqn_model(environment, num_last_frames=4)

    # Read the frame-history length back from the model's input shape
    # (index 1, right after the batch axis) so agent and network agree.
    # NOTE(review): memory_size=-1 presumably means an unbounded replay
    # memory; confirm against DeepQNetworkAgent.
    agent = DeepQNetworkAgent(
        model=dqn_model,
        memory_size=-1,
        num_last_frames=dqn_model.input_shape[1],
    )

    total_episodes = cli.num_episodes
    agent.train(
        environment,
        batch_size=64,
        num_episodes=total_episodes,
        # Checkpoint interval is a tenth of the total run length.
        checkpoint_freq=total_episodes // 10,
        discount_factor=0.95,
    )
# Start training only when run as a script, not when imported as a module.
if __name__ == "__main__":
    main()