ActorCriticCartPole.py
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
# Hyperparameters
learning_rate = 0.0002
gamma = 0.98
n_rollout = 10
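# A minimal n-step TD actor-critic on CartPole-v1: a shared hidden layer feeds a
# softmax policy head and a scalar value head, and every n_rollout environment
# steps the collected transitions are turned into one gradient update.
# Note: env.reset()/env.step() below follow the classic gym API (gym < 0.26);
# newer gym/gymnasium versions return (obs, info) from reset() and a 5-tuple from step().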
class ActorCritic(nn.Module):
    def __init__(self):
        super(ActorCritic, self).__init__()
        self.data = []                              # rollout buffer of (s, a, r, s', done) transitions
        self.fc1 = nn.Linear(4, 256)                # shared feature layer (CartPole state dim = 4)
        self.fc_pi = nn.Linear(256, 2)              # policy head (2 discrete actions)
        self.fc_v = nn.Linear(256, 1)               # value head
        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)
    def pi(self, x, softmax_dim=0):
        x = F.relu(self.fc1(x))
        x = self.fc_pi(x)
        prob = F.softmax(x, dim=softmax_dim)
        return prob

    def v(self, x):
        x = F.relu(self.fc1(x))
        v = self.fc_v(x)
        return v
    def put_data(self, item):
        self.data.append(item)
    def make_batch(self):
        s_lst, a_lst, r_lst, s_prime_lst, done_lst = [], [], [], [], []
        for transition in self.data:
            s, a, r, s_prime, done = transition
            s_lst.append(s)
            a_lst.append([a])
            r_lst.append([r / 100.0])               # scale rewards to keep value targets small
            s_prime_lst.append(s_prime)
            done_mask = 0.0 if done else 1.0        # 0 at terminal states so bootstrapping stops
            done_lst.append([done_mask])
        s_batch, a_batch, r_batch, s_prime_batch, done_batch = torch.tensor(s_lst, dtype=torch.float), \
                                                               torch.tensor(a_lst), \
                                                               torch.tensor(r_lst, dtype=torch.float), \
                                                               torch.tensor(s_prime_lst, dtype=torch.float), \
                                                               torch.tensor(done_lst, dtype=torch.float)
        self.data = []
        return s_batch, a_batch, r_batch, s_prime_batch, done_batch
    def train_net(self):
        s, a, r, s_prime, done = self.make_batch()
        td_target = r + gamma * self.v(s_prime) * done   # self.v(s_prime): batch_size x 1
        delta = td_target - self.v(s)                    # TD error, used as the advantage estimate
        pi = self.pi(s, softmax_dim=1)                   # batch_size x 4 (state) -> batch_size x 2 (actions); softmax over the action axis
        pi_a = pi.gather(1, a)                           # probability of the action actually taken (column index = a)
        loss = -torch.log(pi_a) * delta.detach() + F.smooth_l1_loss(self.v(s), td_target.detach())

        self.optimizer.zero_grad()
        loss.mean().backward()
        self.optimizer.step()
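# Reading of train_net above, per transition:
#   td_target = r + gamma * V(s') * done_mask
#   delta     = td_target - V(s)                      (TD error, used as the advantage)
#   loss      = -log pi(a|s) * delta.detach() + SmoothL1(V(s), td_target.detach())
# Detaching delta and td_target keeps critic gradients out of the actor term and
# target gradients out of the critic regression.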
def main():
    env = gym.make('CartPole-v1')
    model = ActorCritic()
    print_interval = 20
    score = 0.0

    for n_epi in range(10000):
        s = env.reset()
        done = False
        while not done:
            for t in range(n_rollout):
                prob = model.pi(torch.from_numpy(s).float())
                m = Categorical(prob)
                a = m.sample().item()
                s_prime, r, done, info = env.step(a)
                model.put_data((s, a, r, s_prime, done))   # store the action index, not its probability; pi() recomputes the softmax probability at training time
                s = s_prime
                score += r
                if done:
                    break
            model.train_net()

        if n_epi % print_interval == 0 and n_epi != 0:
            print("# of episode: {}, avg score: {:.1f}".format(n_epi, score / print_interval))
            score = 0.0

    env.close()

if __name__ == '__main__':
    main()