-
Notifications
You must be signed in to change notification settings - Fork 0
/
01_a3c_data.py
executable file
·129 lines (101 loc) · 4.34 KB
/
01_a3c_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#!/usr/bin/env python3
import gym
import ptan
import numpy as np
import argparse
import collections
from tensorboardX import SummaryWriter
import torch
import torch.nn.utils as nn_utils
import torch.nn.functional as F
import torch.optim as optim
import torch.multiprocessing as mp
from lib import common
# --- Training hyper-parameters ---
GAMMA = 0.99            # discount factor (also raised to REWARD_STEPS for bootstrapping)
LEARNING_RATE = 0.001   # Adam learning rate
ENTROPY_BETA = 0.01     # weight of the entropy term in the total loss
BATCH_SIZE = 128        # experience entries collected per optimisation step
REWARD_STEPS = 4        # steps_count passed to ExperienceSourceFirstLast
CLIP_GRAD = 0.1         # max gradient norm for clip_grad_norm_
PROCESSES_COUNT = 4     # number of experience-collecting worker processes
NUM_ENVS = 15           # environments created inside each worker process

# Games this script can train on:
#   short name -> (gym environment id, reward bound passed to RewardTracker
#                  as ``stop_reward``).
_GAME_CONFIGS = {
    "pong": ("PongNoFrameskip-v4", 18),
    "breakout": ("BreakoutNoFrameskip-v4", 400),
}

# Select the game here; replaces a hard-coded ``if True: ... else: ...``
# toggle whose ``else`` branch was dead code.
NAME = "pong"
ENV_NAME, REWARD_BOUND = _GAME_CONFIGS[NAME]
def make_env():
    """Create one ``ENV_NAME`` environment wrapped by ``ptan.common.wrappers.wrap_dqn``."""
    raw_env = gym.make(ENV_NAME)
    return ptan.common.wrappers.wrap_dqn(raw_env)
# Message a worker puts on ``train_queue`` to report episode reward; the main
# loop tells it apart from experience entries with ``isinstance``.
TotalReward = collections.namedtuple(
    "TotalReward",
    ["reward"],
)
def data_func(net, device, train_queue):
    """Worker-process loop: play environments with ``net`` and feed the trainer.

    Creates ``NUM_ENVS`` environments and drives them with a ptan
    ``PolicyAgent`` that samples from the softmax of the network's first
    output. Every experience entry is put on ``train_queue``. Whenever
    episodes have finished, their mean total reward is put on the queue as
    a ``TotalReward`` *before* the current experience entry.

    The loop runs for as long as the experience source yields (presumably
    indefinitely); the parent process stops workers with ``terminate()``.

    Args:
        net: policy/value network; calling it returns a pair whose first
            element is used as the policy logits.
        device: device string handed to the agent (e.g. ``"cpu"``/``"cuda"``).
        train_queue: multiprocessing queue shared with the trainer.
    """
    env_pool = [make_env() for _ in range(NUM_ENVS)]

    def policy_logits(states):
        # Keep only the policy head; the value head is not needed for acting.
        return net(states)[0]

    agent = ptan.agent.PolicyAgent(policy_logits, device=device, apply_softmax=True)
    exp_source = ptan.experience.ExperienceSourceFirstLast(
        env_pool, agent, gamma=GAMMA, steps_count=REWARD_STEPS
    )

    for transition in exp_source:
        finished = exp_source.pop_total_rewards()
        if finished:
            train_queue.put(TotalReward(reward=np.mean(finished)))
        train_queue.put(transition)
def calc_a2c_loss(logits_v, value_v, actions_t, vals_ref_v, entropy_beta):
    """Compute the actor-critic loss terms for one batch.

    Args:
        logits_v: policy-head output, one row of action logits per sample
            (assumed shape ``(B, n_actions)``).
        value_v: value-head output; assumed shape ``(B, 1)``. The trailing
            dimension is squeezed away before use.
        actions_t: index of the action taken for each of the ``B`` samples.
        vals_ref_v: reference values for the batch, shape ``(B,)`` (as
            produced by ``common.unpack_batch`` in this script).
        entropy_beta: weight of the entropy term.

    Returns:
        Tuple ``(loss_v, loss_policy_v, loss_value_v, entropy_loss_v, adv_v)``.
        ``loss_v`` is the total loss to back-propagate, followed by its
        three components. ``adv_v`` is the per-sample advantage, shape
        ``(B,)``, detached from the graph.
    """
    # Flatten (B, 1) -> (B,) once so every term below is per-sample.
    # Without this, ``vals_ref_v - value_v`` broadcasts (B,) against (B, 1)
    # into a (B, B) matrix. The policy gradient would then use the
    # batch-mean value as its baseline instead of each state's own value.
    values_v = value_v.squeeze(-1)
    loss_value_v = F.mse_loss(values_v, vals_ref_v)

    log_prob_v = F.log_softmax(logits_v, dim=1)
    # The critic acts only as a baseline here: no gradient flows into it
    # through the policy term.
    adv_v = vals_ref_v - values_v.detach()
    log_prob_actions_v = adv_v * log_prob_v[range(len(actions_t)), actions_t]
    loss_policy_v = -log_prob_actions_v.mean()

    # sum(p * log p) is the negative entropy, so minimising this term
    # pushes the policy towards higher entropy.
    prob_v = F.softmax(logits_v, dim=1)
    entropy_loss_v = entropy_beta * (prob_v * log_prob_v).sum(dim=1).mean()

    loss_v = entropy_loss_v + loss_value_v + loss_policy_v
    return loss_v, loss_policy_v, loss_value_v, entropy_loss_v, adv_v


if __name__ == "__main__":
    # Child processes are started with 'spawn'; PyTorch requires this when
    # CUDA tensors are shared with subprocesses.
    mp.set_start_method('spawn')
    parser = argparse.ArgumentParser()
    parser.add_argument("--cuda", default=False, action="store_true", help="Enable cuda")
    parser.add_argument("-n", "--name", required=True, help="Name of the run")
    args = parser.parse_args()
    device = "cuda" if args.cuda else "cpu"

    writer = SummaryWriter(comment="-a3c-data_" + NAME + "_" + args.name)

    # Throw-away environment, used only to read the observation/action
    # space sizes; close it so it does not linger for the whole run.
    env = make_env()
    net = common.AtariA2C(env.observation_space.shape, env.action_space.n).to(device)
    env.close()
    # Move parameters to shared memory so the worker processes act with the
    # weights the optimizer below keeps updating.
    net.share_memory()

    optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE, eps=1e-3)

    # Bounded queue: workers block on put() while the trainer is behind.
    train_queue = mp.Queue(maxsize=PROCESSES_COUNT)
    data_proc_list = []
    for _ in range(PROCESSES_COUNT):
        data_proc = mp.Process(target=data_func, args=(net, device, train_queue))
        data_proc.start()
        data_proc_list.append(data_proc)

    batch = []
    step_idx = 0

    try:
        with common.RewardTracker(writer, stop_reward=REWARD_BOUND) as tracker:
            with ptan.common.utils.TBMeanTracker(writer, batch_size=100) as tb_tracker:
                while True:
                    train_entry = train_queue.get()
                    # Workers interleave reward reports with experience entries.
                    if isinstance(train_entry, TotalReward):
                        # A truthy result from the tracker ends training
                        # (presumably once REWARD_BOUND is reached).
                        if tracker.reward(train_entry.reward, step_idx):
                            break
                        continue

                    step_idx += 1
                    batch.append(train_entry)
                    if len(batch) < BATCH_SIZE:
                        continue

                    states_v, actions_t, vals_ref_v = \
                        common.unpack_batch(batch, net, last_val_gamma=GAMMA**REWARD_STEPS, device=device)
                    batch.clear()

                    optimizer.zero_grad()
                    logits_v, value_v = net(states_v)
                    loss_v, loss_policy_v, loss_value_v, entropy_loss_v, adv_v = \
                        calc_a2c_loss(logits_v, value_v, actions_t, vals_ref_v, ENTROPY_BETA)
                    loss_v.backward()
                    nn_utils.clip_grad_norm_(net.parameters(), CLIP_GRAD)
                    optimizer.step()

                    tb_tracker.track("advantage", adv_v, step_idx)
                    tb_tracker.track("values", value_v, step_idx)
                    tb_tracker.track("batch_rewards", vals_ref_v, step_idx)
                    tb_tracker.track("loss_entropy", entropy_loss_v, step_idx)
                    tb_tracker.track("loss_policy", loss_policy_v, step_idx)
                    tb_tracker.track("loss_value", loss_value_v, step_idx)
                    tb_tracker.track("loss_total", loss_v, step_idx)
    finally:
        # Workers loop forever, so they must be stopped explicitly.
        for p in data_proc_list:
            p.terminate()
            p.join()