utils.py

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.autograd import Variable
from mpl_toolkits.axes_grid1 import ImageGrid
from torchvision.transforms import Compose, ToTensor

# compose a transform configuration
transform_config = Compose([ToTensor()])


def accumulate_group_evidence(class_mu, class_logvar, labels_batch, is_cuda):
    """
    Combine the class (group) latent evidence of all samples that share a label into a
    single Gaussian per group, using precision-weighted averaging.

    :param class_mu: mu values for class latent embeddings of each sample in the mini-batch
    :param class_logvar: logvar values for class latent embeddings of each sample in the mini-batch
    :param labels_batch: class labels of each sample (the accumulation of class evidence can also
        be performed using group labels instead of actual class labels)
    :param is_cuda: whether to allocate the grouped tensors on the GPU
    :return: per-sample group mu and group logvar tensors, where every sample of a group
        carries the same accumulated values
    """
    var_dict = {}
    mu_dict = {}

    # convert logvar to variance for the calculations below
    # (use exp() rather than exp_() so the caller's logvar tensor is not modified in place)
    class_var = class_logvar.exp()

    # accumulate the inverse variance (precision) of each group
    for i in range(len(labels_batch)):
        group_label = labels_batch[i].item()

        # replace 0 values in the variances to avoid division by zero
        class_var[i][class_var[i] == float(0)] = 1e-6

        if group_label in var_dict.keys():
            var_dict[group_label] += 1 / class_var[i]
        else:
            var_dict[group_label] = 1 / class_var[i]

    # invert the accumulated precisions to get the group variances
    for group_label in var_dict.keys():
        var_dict[group_label] = 1 / var_dict[group_label]

    # accumulate precision-weighted mu for each group
    for i in range(len(labels_batch)):
        group_label = labels_batch[i].item()

        if group_label in mu_dict.keys():
            mu_dict[group_label] += class_mu[i] * (1 / class_var[i])
        else:
            mu_dict[group_label] = class_mu[i] * (1 / class_var[i])

    # multiply group var with the sums calculated above to get mu for the group
    for group_label in mu_dict.keys():
        mu_dict[group_label] *= var_dict[group_label]

    # replace individual mu and logvar values for each sample with group mu and logvar
    group_mu = torch.FloatTensor(class_mu.size(0), class_mu.size(1))
    group_var = torch.FloatTensor(class_var.size(0), class_var.size(1))

    if is_cuda:
        group_mu = group_mu.cuda()
        group_var = group_var.cuda()

    for i in range(len(labels_batch)):
        group_label = labels_batch[i].item()

        group_mu[i] = mu_dict[group_label]
        group_var[i] = var_dict[group_label]

        # remove 0 from var before taking log
        group_var[i][group_var[i] == float(0)] = 1e-6

    # convert group vars into logvars before returning
    return Variable(group_mu, requires_grad=True), Variable(torch.log(group_var), requires_grad=True)
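

# Illustrative usage sketch (not part of the original module): a toy check that
# accumulate_group_evidence gives every sample of a group the same accumulated
# mu/logvar. The shapes and labels below are assumptions for illustration only.
def _example_accumulate_group_evidence():
    class_mu = torch.randn(3, 4)
    class_logvar = torch.randn(3, 4)
    labels_batch = torch.tensor([0, 1, 0])  # samples 0 and 2 belong to the same group

    group_mu, group_logvar = accumulate_group_evidence(class_mu, class_logvar, labels_batch, is_cuda=False)

    # samples sharing a label receive identical group evidence
    assert torch.allclose(group_mu[0], group_mu[2])
    assert torch.allclose(group_logvar[0], group_logvar[2])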


def mse_loss(input, target):
    # per-element mean squared error
    return torch.sum((input - target).pow(2)) / input.data.nelement()


def l1_loss(input, target):
    # per-element mean absolute error
    return torch.sum(torch.abs(input - target)) / input.data.nelement()


def reparameterize(training, mu, logvar):
    # standard VAE reparameterization trick; at test time simply return the mean
    if training:
        std = logvar.mul(0.5).exp_()
        eps = Variable(std.data.new(std.size()).normal_())
        return eps.mul(std).add_(mu)
    else:
        return mu


def group_wise_reparameterize(training, mu, logvar, labels_batch, cuda):
    eps_dict = {}

    # generate only 1 eps value per group label
    for label in torch.unique(labels_batch):
        if cuda:
            eps_dict[label.item()] = torch.cuda.FloatTensor(1, logvar.size(1)).normal_(0., 0.1)
        else:
            eps_dict[label.item()] = torch.FloatTensor(1, logvar.size(1)).normal_(0., 0.1)

    if training:
        std = logvar.mul(0.5).exp_()
        reparameterized_var = Variable(std.data.new(std.size()))

        # multiply std by the eps of the sample's group and add mu
        for i in range(logvar.size(0)):
            reparameterized_var[i] = std[i].mul(Variable(eps_dict[labels_batch[i].item()]))
            reparameterized_var[i].add_(mu[i])

        return reparameterized_var
    else:
        return mu
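

# Illustrative usage sketch (not part of the original module): with mu = 0 and
# logvar = 0, samples that share a label are reparameterized with the same eps,
# so they map to the same latent sample. Shapes and labels are assumptions.
def _example_group_wise_reparameterize():
    mu = torch.zeros(3, 4)
    logvar = torch.zeros(3, 4)
    labels_batch = torch.tensor([0, 0, 1])

    z = group_wise_reparameterize(True, mu, logvar, labels_batch, cuda=False)

    # same group -> same eps -> same sample
    assert torch.allclose(z[0], z[1])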


def weights_init(layer):
    # custom Gaussian initialization for conv, batch-norm and linear layers
    if isinstance(layer, nn.Conv2d):
        layer.weight.data.normal_(0.0, 0.05)
        layer.bias.data.zero_()
    elif isinstance(layer, nn.BatchNorm2d):
        layer.weight.data.normal_(1.0, 0.02)
        layer.bias.data.zero_()
    elif isinstance(layer, nn.Linear):
        layer.weight.data.normal_(0.0, 0.05)
        layer.bias.data.zero_()
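

# Illustrative usage sketch (not part of the original module): weights_init is
# intended to be passed to nn.Module.apply, which calls it on every submodule.
# The toy network below is an assumption for illustration only.
def _example_weights_init():
    net = nn.Sequential(
        nn.Conv2d(1, 8, kernel_size=3),
        nn.BatchNorm2d(8),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(8 * 26 * 26, 10),  # assumes 28x28 single-channel inputs
    )
    net.apply(weights_init)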


def imshow_grid(images, shape=[2, 8], name='default', save=False):
    """Plot images in a grid of a given shape."""
    fig = plt.figure(1)
    grid = ImageGrid(fig, 111, nrows_ncols=shape, axes_pad=0.05)

    size = shape[0] * shape[1]
    for i in range(size):
        grid[i].axis('off')
        grid[i].imshow(images[i])  # the AxesGrid object works as a list of axes

    if save:
        plt.savefig('reconstructed_images/' + str(name) + '.png')
        plt.clf()
    else:
        plt.show()
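

# Illustrative usage sketch (not part of the original module): displaying a batch
# of random 28x28 images. Note that save=True writes into a pre-existing
# 'reconstructed_images/' directory; numpy is imported locally because the module
# itself does not otherwise need it.
def _example_imshow_grid():
    import numpy as np
    images = np.random.rand(16, 28, 28)
    imshow_grid(images, shape=[2, 8], name='random_example', save=False)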