ac.py
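# Actor-critic MLP modules: a policy network (categorical for discrete actions,
# diagonal Gaussian for continuous ones) and a state-value critic, wrapped
# together in MLPAC for use in policy-gradient training loops.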
import torch
import torch.nn as nn
import numpy as np
import torch.distributions as dist
class Actor(nn.Module):
    """Policy network: categorical head for discrete actions, diagonal Gaussian otherwise."""

    def __init__(self, state_dim, action_dim, categorical, hidden_dim=128):
        super(Actor, self).__init__()
        # Shared feature extractor for the policy head.
        self.actor_mlp = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        self.categorical = categorical
        if categorical:
            # Discrete actions: output a probability vector over actions.
            self.head = nn.Sequential(
                nn.Linear(hidden_dim, action_dim),
                nn.Softmax(dim=-1),
            )
        else:
            # Continuous actions: state-dependent mean, with the log std kept
            # as a free parameter rather than predicted from the input.
            self.mean = nn.Linear(hidden_dim, action_dim)
            self.log_std = torch.nn.Parameter(
                torch.as_tensor(-.5 * np.ones(action_dim, dtype=np.float32))
            )
            # self.sqrt_std = torch.nn.Parameter(torch.as_tensor(-np.ones(action_dim, dtype=np.float32)))

    def _get_distr(self, obs):
        # Build the action distribution for the given observation(s).
        scores = self.actor_mlp(obs)
        if self.categorical:
            xx = dist.Categorical(self.head(scores))
        else:
            xx = dist.Normal(self.mean(scores), self.log_std.exp())
            # xx = dist.Normal(self.mean(scores), .5 * self.sqrt_std.square())
        return xx

    def pi(self, obs):
        # Sample an action and its log-probability, detached for rollout collection.
        xx = self._get_distr(obs)
        act = xx.sample().detach()
        logp = xx.log_prob(act).detach()
        return act, logp

    def logprob(self, obs, act):
        # Log-probability of given actions under the current policy (keeps gradients).
        xx = self._get_distr(obs)
        logp = xx.log_prob(act)
        return logp
class Critic(nn.Module):
    """State-value network V(s); action_dim and categorical are unused here."""

    def __init__(self, state_dim, action_dim, categorical, hidden_dim=128):
        super(Critic, self).__init__()
        self.critic_mlp = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            # nn.Linear(hidden_dim, hidden_dim),
            # nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def predict(self, obs, detach=False):
        # Scalar value estimate; optionally detached and converted to numpy for rollouts.
        v_t = self.critic_mlp(obs).squeeze()
        if detach:
            v_t = v_t.detach().cpu().numpy()
        return v_t
class MLPAC(nn.Module):
    """Actor-critic container handling device placement and numpy <-> torch conversion."""

    def __init__(self, state_dim, action_dim, categorical, hidden_dim=128, device="cuda"):
        super(MLPAC, self).__init__()
        self.actor = Actor(state_dim, action_dim, categorical, hidden_dim=hidden_dim)
        self.critic = Critic(state_dim, action_dim, categorical, hidden_dim=hidden_dim)
        self.device = torch.device(device)
        self.to(self.device)
        self.to(torch.float32)

    def to_torch(self, obs):
        # Convert a numpy observation to a float32 tensor on the model's device.
        return torch.from_numpy(obs).to(self.device).to(torch.float32)

    def step(self, obs):
        # One interaction step: sample an action, its log-prob, and the value estimate.
        obs = self.to_torch(obs)
        action, logp = self.actor.pi(obs)
        v = self.critic.predict(obs, detach=True)
        return action, v, logp

    def logprob(self, obs, act):
        # Recompute log-probabilities for stored actions (keeps gradients for the update).
        obs = self.to_torch(obs)
        act = torch.stack(act).to(self.device)  # act is a list of action tensors
        return self.actor.logprob(obs, act)

    def predict(self, obs, detach=False):
        obs = self.to_torch(obs)
        return self.critic.predict(obs, detach=detach)
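

# --- Minimal usage sketch (illustrative, not part of the original module) ---
# Assumes a continuous-control setup with a 4-dimensional observation and a
# 2-dimensional action; the dimensions and the "cpu" device are arbitrary choices.
if __name__ == "__main__":
    obs_dim, act_dim = 4, 2
    ac = MLPAC(obs_dim, act_dim, categorical=False, device="cpu")

    obs = np.random.randn(obs_dim).astype(np.float32)
    action, value, logp = ac.step(obs)  # sample an action and get V(s)
    print("action:", action, "value:", value, "logp:", logp)

    # Re-evaluate the log-probability of stored actions, e.g. for a policy-gradient update.
    logp_new = ac.logprob(obs, [action])
    print("recomputed logp:", logp_new)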