-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy pathmodel.py
68 lines (51 loc) · 1.85 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.autograd import Variable
from utils import cuda
import time
from numbers import Number
class ToyNet(nn.Module):
    """MLP classifier with a stochastic Gaussian bottleneck.

    The encoder maps a flattened 784-feature input to 2*K statistics: the
    first K are the mean ``mu``; the last K go through a shifted softplus to
    give a positive standard deviation ``std``. Codes are sampled with the
    reparameterization trick, and a single linear layer maps each K-dim code
    to 10 outputs.

    Args:
        K: size of the latent code (default 256).
    """

    def __init__(self, K=256):
        super(ToyNet, self).__init__()
        self.K = K
        self.encode = nn.Sequential(
            nn.Linear(784, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 2 * self.K))
        self.decode = nn.Sequential(
            nn.Linear(self.K, 10))

    def forward(self, x, num_sample=1):
        """Encode ``x``, sample ``num_sample`` codes per example, and classify.

        Args:
            x: input batch. Dims after the first are flattened, and the result
                must have 784 features per example.
            num_sample: number of codes drawn per example; must be >= 1.

        Returns:
            ``((mu, std), out)`` where ``mu`` and ``std`` are ``(batch, K)``.
            When ``num_sample == 1``, ``out`` holds raw logits of shape
            ``(batch, 10)``. When ``num_sample > 1``, ``out`` is the softmax
            probabilities averaged over the samples, also ``(batch, 10)``.
            In that case ``out`` contains probabilities, not logits.

        Raises:
            ValueError: if ``num_sample`` is less than 1.
        """
        if num_sample < 1:
            raise ValueError(
                "num_sample must be >= 1, got {}".format(num_sample))
        if x.dim() > 2:
            # reshape (unlike view) also accepts non-contiguous inputs.
            x = x.reshape(x.size(0), -1)
        statistics = self.encode(x)
        mu = statistics[:, :self.K]
        # The -5 shift makes std start small (softplus(-5) ~= 0.0067) while
        # the raw encoder output is near zero; softplus keeps it positive.
        std = F.softplus(statistics[:, self.K:] - 5, beta=1)
        encoding = self.reparametrize_n(mu, std, num_sample)
        logit = self.decode(encoding)
        if num_sample > 1:
            # logit is (num_sample, batch, 10): average the class
            # probabilities over the leading sample axis.
            logit = F.softmax(logit, dim=2).mean(0)
        return (mu, std), logit

    def reparametrize_n(self, mu, std, n=1):
        """Draw ``n`` reparameterized samples ``mu + eps * std``, eps ~ N(0, 1).

        Args:
            mu: mean tensor.
            std: standard-deviation tensor; expected to match ``mu``'s shape.
            n: number of samples. When ``n != 1``, both inputs gain a leading
                sample axis of size ``n``. The axis is added via ``expand``,
                so no copy is made. Python numbers are turned into ``(n, 1)``
                tensors in that case.

        Returns:
            A tensor shaped like ``mu`` when ``n == 1``, otherwise
            ``(n, *mu.shape)``. Gradients flow to ``mu`` and ``std``; the
            noise ``eps`` itself does not require grad.
        """
        # Mirrors Distribution.sample_n from the PyTorch 0.3.1 sources:
        # http://pytorch.org/docs/0.3.1/_modules/torch/distributions.html#Distribution.sample_n
        def expand(v):
            if isinstance(v, Number):
                return torch.Tensor([v]).expand(n, 1)
            return v.expand(n, *v.size())

        if n != 1:
            mu = expand(mu)
            std = expand(std)
        # randn_like allocates fresh N(0, 1) noise with std's dtype and
        # device. No explicit CUDA transfer or Variable wrapping is needed.
        eps = torch.randn_like(std)
        return mu + eps * std

    def weight_init(self):
        """Apply the module-level ``xavier_init`` to each direct child module."""
        for module in self.children():
            xavier_init(module)
def xavier_init(ms):
    """Xavier-uniform initialize every ``nn.Linear`` / ``nn.Conv2d`` layer in ``ms``.

    Weights are drawn with ``nn.init.xavier_uniform_`` using the ReLU gain.
    Biases are zeroed when the layer has one. Other module types are left
    untouched.

    Args:
        ms: either an ``nn.Module``, whose submodules (itself included) are
            visited recursively via ``.modules()``, or a plain iterable of
            modules, whose elements are checked directly.
    """
    gain = nn.init.calculate_gain('relu')
    # A bare layer or nested containers are handled by walking .modules();
    # direct iteration of an nn.Linear would raise TypeError.
    modules = ms.modules() if isinstance(ms, nn.Module) else ms
    for m in modules:
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            nn.init.xavier_uniform_(m.weight, gain=gain)
            # Layers built with bias=False have m.bias set to None.
            if m.bias is not None:
                m.bias.data.zero_()