batchnorm.py
import torch
import torch.nn as nn


class BatchNorm(nn.Module):
    """Batch-normalization bijection with a tractable log-determinant."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        # gamma is a log-scale, so zero init makes the affine part the identity
        self.gamma = nn.Parameter(torch.zeros(1, dim))
        self.beta = nn.Parameter(torch.zeros(1, dim))
        self.batch_mean = None
        self.batch_var = None

    def forward(self, x):
        if self.training:
            m = x.mean(dim=0)
            # biased variance, i.e. torch.mean((x - m) ** 2, dim=0), plus eps
            v = x.var(dim=0, unbiased=False) + self.eps
            # reset stored stats so the first eval batch recomputes them
            self.batch_mean = None
        else:
            if self.batch_mean is None:
                self.set_batch_stats_func(x)
            m = self.batch_mean.clone()
            v = self.batch_var.clone()
        x_hat = (x - m) / torch.sqrt(v)
        x_hat = x_hat * torch.exp(self.gamma) + self.beta
        log_det = torch.sum(self.gamma - 0.5 * torch.log(v))
        return x_hat, log_det

    def reverse(self, x):
        if self.training:
            # note: these stats come from the already-transformed batch, so
            # reverse() is only an exact inverse in eval mode, where the
            # stored statistics from forward() are reused
            m = x.mean(dim=0)
            v = x.var(dim=0, unbiased=False) + self.eps
            self.batch_mean = None
        else:
            if self.batch_mean is None:
                self.set_batch_stats_func(x)
            m = self.batch_mean
            v = self.batch_var
        x_hat = (x - self.beta) * torch.exp(-self.gamma) * torch.sqrt(v) + m
        log_det = torch.sum(-self.gamma + 0.5 * torch.log(v))
        return x_hat, log_det

    def set_batch_stats_func(self, x):
        print("setting batch stats for validation")
        self.batch_mean = x.mean(dim=0)
        self.batch_var = x.var(dim=0, unbiased=False) + self.eps
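
A minimal usage sketch (not part of the original file): it assumes only the BatchNorm class above, and checks that reverse() exactly inverts forward() in eval mode, where both directions reuse the stored batch statistics, and that the analytic log-determinant matches the one recovered from autograd's Jacobian.

import torch

torch.manual_seed(0)
layer = BatchNorm(dim=4)
x = torch.randn(32, 4)

layer.train()
_ = layer(x)                      # training mode: stats come from this batch

layer.eval()
y, log_det = layer(x)             # first eval call stores stats via set_batch_stats_func
x_rec, inv_log_det = layer.reverse(y)

print(torch.allclose(x, x_rec, atol=1e-5))      # True: exact inverse in eval mode
print(torch.allclose(log_det, -inv_log_det))    # True: the log-dets cancel

# cross-check log_det against autograd for one sample; in eval mode the
# per-sample map is elementwise affine, so the Jacobian is diagonal
J = torch.autograd.functional.jacobian(lambda s: layer(s)[0], x[:1])
print(torch.allclose(log_det, torch.logdet(J.reshape(4, 4))))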