# utils.py
import csv
import os
import numpy as np
import torch
from PIL import Image
from sklearn.linear_model import SGDClassifier
from sklearn.neighbors import KNeighborsClassifier
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torchvision.datasets import EMNIST, MNIST, FashionMNIST, CIFAR10, CIFAR100
from tqdm import tqdm
# definitions of train and test transforms
train_transform = transforms.Compose([
transforms.RandomResizedCrop((28, 28)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
transforms.RandomGrayscale(p=0.2),
transforms.ToTensor()])
test_transform = transforms.Compose([
transforms.ToTensor()])
train_transform_cifar = transforms.Compose([
transforms.RandomResizedCrop((32, 32)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
transforms.RandomGrayscale(p=0.2),
transforms.ToTensor()])
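# Illustrative sketch (not part of the original training code): calling the train
# transform twice on the same 28x28 PIL image produces two independent stochastic
# views, exactly as the *Pair dataset classes below do. The function name is hypothetical.
def _example_transforms():
    img = Image.fromarray(np.random.randint(0, 255, (28, 28), dtype=np.uint8))
    view_1, view_2 = train_transform(img), train_transform(img)
    # each view is a (1, 28, 28) tensor; the two views generally differ
    return view_1.shape, view_2.shape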
def get_args(parser):
"""
This function adds the arguments to the parser
:param parser: # the parser to add the arguments to
:return: # the parser with the added arguments
"""
# the dataset to train on
parser.add_argument('-d', '--dataset', default="EMNIST", type=str)
# the size of the latent space representation
parser.add_argument('-zd', '--z-dim', default=64, type=int,
help='the dimension of the latent space representation')
# input image channel size
parser.add_argument('-c', '--channel-size', default=3, type=int,
help='the channel size of the input')
    parser.add_argument('-v', '--verbose', action='store_true',
                        help='print progress information')
parser.add_argument('-r', '--resume', action='store_true', help='resume training from path given')
# argument to use wandb
parser.add_argument('-w', '--wandb', action='store_true', help='use wandb')
# argument for model size
parser.add_argument('-ms', '--model-size', type=str, default='large', help='size of the model')
    # argument for data augmentation
parser.add_argument('-aug', '--augment', action='store_true', help='use data augmentation')
# argument for the type of model
parser.add_argument('-at', '--ae-type', type=str, default='crvae', help='type of autoencoder')
parser.add_argument('-bs', '--batch-size', default=512, type=int)
parser.add_argument('-e', '--epochs', default=10, type=int)
parser.add_argument('-lr', '--learning-rate', default=1e-5, type=float)
parser.add_argument('-nw', '--num-workers', default=16, type=int)
parser.add_argument('-si', '--save_integral', default=10, type=int)
# parser.add_argument('-lp', '--load-path', type=str)
# parser.add_argument('-ckp', '--checkpoint', action='store_true',
# help='load model checkpoint. If false model dictionary is loaded')
# parser.add_argument('-n', '--model-name', type=str)
# parser.add_argument('-rp', '--run-path', type=str)
# loss function to use
    parser.add_argument('-l', '--loss', default='bce',
                        help='reconstruction loss function to use; currently supports mse and bce')
parser.add_argument('-ni', '--norm-input', action='store_true',
help='normalize input')
parser.add_argument('-sh', '--shuffle', action='store_true',
help='shuffle train dataloader')
parser.add_argument('-opt', '--optimizer', default='sgd',
help='which optimizer to use')
parser.add_argument('-wd', '--weight-decay', default=1e-8, type=float, help='weight decay for optimizer')
parser.add_argument('-m', '--momentum', default=0.9, type=float, help='momentum value for sgd optimizer')
parser.add_argument('--pretrained', default='', type=str,
help='path to moco pretrained checkpoint')
parser.add_argument('-aut', '--active-units-threshold', default=0.01, type=float, help='active units threshold')
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--tsne_dim', default=2, type=int)
parser.add_argument('--knn_k', default=200, type=int, help='number of voters in the KNN algorithm')
parser.add_argument('--knn_t', default=0.1, type=float, help='temperature parameter for the weighting in KNN')
parser.add_argument('--alpha', default=1, type=float, help='weight of reconstruction term')
parser.add_argument('--beta', default=1.0, type=float, help='weight of KL term')
parser.add_argument('--gamma', default=1, type=float, help='weight of infonce term')
parser.add_argument('--delta', default=0, type=float, help='InfoMax temperature')
parser.add_argument('--d_lr', default=1e-3, type=float, help='learning rate for discriminator')
parser.add_argument('--K', default=4096, type=int, help='number of negative samples')
parser.add_argument('--representation_metrics', default=1, type=int, help='produce representation metrics')
args = parser.parse_args() # running in command line
return args
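# Hedged usage sketch for get_args(): the parser is normally built in the training
# script; the description string and function name here are illustrative only.
# Run e.g. `python train.py -d MNIST -c 1 -aug -e 50`.
def _example_get_args():
    import argparse
    parser = argparse.ArgumentParser(description='contrastive VAE training')
    args = get_args(parser)
    print(args.dataset, args.z_dim, args.batch_size)
    return args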
class EMNISTPair(EMNIST):
"""
This is a modified version of the EMNIST dataset class from torchvision.
It returns a pair of stochastic augmentations of an image.
"""
    def __getitem__(self, index):
        img = self.data[index]
        img = Image.fromarray(img.to("cpu").detach().numpy())
        if self.transform is not None:
            im_1 = self.transform(img)
            im_2 = self.transform(img)
        else:
            # no transform given: return the untransformed image twice
            im_1 = im_2 = img
        return im_1, im_2
class MNISTPair(MNIST):
"""
This is a modified version of the MNIST dataset class from torchvision.
It returns a pair of stochastic augmentations of an image.
"""
    def __getitem__(self, index):
        img = self.data[index]
        img = Image.fromarray(img.to("cpu").detach().numpy())
        if self.transform is not None:
            im_1 = self.transform(img)
            im_2 = self.transform(img)
        else:
            # no transform given: return the untransformed image twice
            im_1 = im_2 = img
        return im_1, im_2
class FashionMNISTPair:
"""
This is a modified version of the FashionMNIST dataset class from torchvision.
It returns a pair of stochastic augmentations of an image.
"""
    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        self.train = train
        self.transform = transform
        self.target_transform = target_transform
        self.download = download
        # build the underlying dataset once and reuse its data and targets
        base = FashionMNIST(root, train=train, transform=transform, target_transform=target_transform,
                            download=download)
        self.data = base.data
        self.targets = base.targets
    def __getitem__(self, index):
        img = self.data[index]
        img = Image.fromarray(img.to("cpu").detach().numpy())
        if self.transform is not None:
            im_1 = self.transform(img)
            im_2 = self.transform(img)
        else:
            # no transform given: return the untransformed image twice
            im_1 = im_2 = img
        return im_1, im_2
def __len__(self):
return len(self.data)
class CIFAR10Pair(CIFAR10):
"""
This is a modified version of the CIFAR10 dataset class from torchvision.
It returns a pair of stochastic augmentations of an image.
"""
    def __getitem__(self, index):
        img = self.data[index]
        img = Image.fromarray(img)
        if self.transform is not None:
            im_1 = self.transform(img)
            im_2 = self.transform(img)
        else:
            # no transform given: return the untransformed image twice
            im_1 = im_2 = img
        return im_1, im_2
class CIFAR100Pair(CIFAR100):
"""
This is a modified version of the CIFAR100 dataset class from torchvision.
It returns a pair of stochastic augmentations of an image.
"""
    def __getitem__(self, index):
        img = self.data[index]
        img = Image.fromarray(img)
        if self.transform is not None:
            im_1 = self.transform(img)
            im_2 = self.transform(img)
        else:
            # no transform given: return the untransformed image twice
            im_1 = im_2 = img
        return im_1, im_2
def get_optimizer(model_parameters, args):
    """
    Return the optimizer selected by the arguments.
    """
    if args.optimizer.lower() == "adam":
        return optim.Adam(model_parameters, lr=args.learning_rate, weight_decay=args.weight_decay)
    elif args.optimizer.lower() == "sgd":
        return optim.SGD(model_parameters, lr=args.learning_rate, weight_decay=args.weight_decay,
                         momentum=args.momentum)
    else:
        raise ValueError(f"unknown optimizer '{args.optimizer}' in get_optimizer()")
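# Hedged usage sketch for get_optimizer(): any object exposing the expected attributes
# works; a SimpleNamespace stands in for the parsed arguments, and the toy nn.Linear
# model is illustrative only.
def _example_get_optimizer():
    from types import SimpleNamespace
    model = nn.Linear(64, 10)
    args = SimpleNamespace(optimizer='sgd', learning_rate=1e-5, weight_decay=1e-8, momentum=0.9)
    optimizer = get_optimizer(model.parameters(), args)
    return optimizer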
def kl_divergence(mu, logvar):
    """KL divergence between N(mu, exp(logvar)) and the standard normal prior, averaged over the batch."""
    return (-0.5 * (1 + logvar - mu ** 2 - logvar.exp()).sum(1)).mean()
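# Quick sanity check for kl_divergence() (illustrative, not called anywhere): a posterior
# equal to the standard normal prior (mu = 0, logvar = 0) has zero KL divergence, while a
# shifted posterior has strictly positive KL.
def _example_kl_divergence():
    mu = torch.zeros(8, 64)
    logvar = torch.zeros(8, 64)
    assert kl_divergence(mu, logvar).item() == 0.0
    assert kl_divergence(mu + 1.0, logvar).item() > 0.0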
def get_lr(optimizer):
"""Get learning rate from optimizer."""
for param_group in optimizer.param_groups:
return param_group['lr']
class Discriminator(nn.Module):
"""
Discriminator network for computing the mutual information between the latent space and the data for InfoMax VAE.
"""
def __init__(self, args=None):
super(Discriminator, self).__init__()
self.channels = args.color_channels
self.height, self.width = args.size
self.z_dim = args.z_dim
self.net = nn.Sequential(
nn.Linear(self.channels * self.height * self.width + self.z_dim, 2000),
nn.ReLU(True),
nn.Linear(2000, 1000),
nn.ReLU(True),
nn.Linear(1000, 100),
nn.ReLU(True),
nn.Linear(100, 1))
def forward(self, x, z):
x = x.view(-1, self.channels * self.height * self.width)
x = torch.cat((x, z), 1)
pred = self.net(x).squeeze() if self.channels > 1 else self.net(x)
return pred
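# Hedged usage sketch for the InfoMax Discriminator. Note that it reads args.color_channels
# and args.size, which get_args() above does not define, so this example builds a
# hypothetical namespace with those fields; shapes are illustrative.
def _example_discriminator():
    from types import SimpleNamespace
    args = SimpleNamespace(color_channels=1, size=(28, 28), z_dim=64)
    disc = Discriminator(args=args)
    x = torch.rand(16, 1, 28, 28)
    z = torch.rand(16, 64)
    scores = disc(x, z)  # shape (16, 1) here, since channels == 1 skips the squeeze
    return scores.shape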
def get_loss(model, reconstructed, img, mu, logvar):
"""
This function computes the loss for the VAE.
"""
reconstruction_loss = model.reconstruction_loss(reconstructed, img).sum([1, 2, 3]).mean()
kld = kl_divergence(mu, logvar)
vae_loss = reconstruction_loss + model.beta * kld
return vae_loss, reconstruction_loss, kld
def calc_mi(model, test_loader, device):
"""
adjusted from (https://github.com/jxhe/vae-lagging-encoder)
compute the mutual information between the latent space and the data
"""
mi = 0
num_examples = 0
for datum in test_loader:
batch_data, _ = datum
batch_data = batch_data.to(device)
batch_size = batch_data.size(0)
num_examples += batch_size
mutual_info = model.encoder.calc_mi(batch_data)
mi += mutual_info * batch_size
return - mi / num_examples
def calc_au(model, test_data_batch, delta=0.01):
"""
adjusted from (https://github.com/jxhe/vae-lagging-encoder)
compute the number of active units
"""
means_sum, var_sum = None, None
cnt = 0
for batch_data in test_data_batch:
if isinstance(batch_data, list):
batch_data = batch_data[0]
_, mu, _, _ = model.predict_batch(batch_data)
if means_sum is None:
means_sum = mu.sum(dim=0, keepdim=True)
else:
means_sum = means_sum + mu.sum(dim=0, keepdim=True)
cnt += mu.size(0)
# (1, nz)
mean_mean = means_sum / cnt
cnt = 0
for batch_data in test_data_batch:
if isinstance(batch_data, list):
batch_data = batch_data[0]
_, mu, _, _ = model.predict_batch(batch_data)
if var_sum is None:
var_sum = ((mu - mean_mean) ** 2).sum(dim=0)
else:
var_sum = var_sum + ((mu - mean_mean) ** 2).sum(dim=0)
cnt += mu.size(0)
# (nz)
au_var = var_sum / (cnt - 1)
return (au_var >= delta).sum().item(), au_var
def save_model(checkpoint, model_name, args):
    """Save a training checkpoint under model/model_checkpoints/."""
    path = "model/model_checkpoints/"
    os.makedirs(path, exist_ok=True)
    save_ckp(checkpoint, f"{path}{model_name}")
def save_ckp(state, checkpoint_dir):
f_path = f'{checkpoint_dir}.ckp'
torch.save(state, f_path)
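# Hedged checkpointing sketch: the keys of the checkpoint dict below are illustrative
# and not mandated by save_model()/save_ckp(); the function name is hypothetical.
def _example_save_model(model, optimizer, epoch, args):
    checkpoint = {'epoch': epoch,
                  'state_dict': model.state_dict(),
                  'optimizer': optimizer.state_dict()}
    save_model(checkpoint, model_name=f"crvae_{args.dataset}_epoch{epoch}", args=args)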
def representation_metric_test(net, memory_data_loader, test_data_loader, knn_k, epoch, inference=True):
"""
This function is used to test the representation metric of the model. We use the trained model in two
semi-supervised learning tasks: 1) classification on the test set; 2) kNN classification on the memory set. The
performance of the model in these two tasks is used to evaluate the representation metric of the model.
"""
net.eval()
total_top1, total_num, feature_bank = 0.0, 0, []
with torch.no_grad():
# generate feature bank
feature_bar = tqdm(memory_data_loader, position=0, leave=True, desc='Feature extracting')
for data, target in feature_bar:
feature = encode_image(data, net, inference=inference)
feature_bank.append(feature)
# [D, N]
feature_bank = torch.cat(feature_bank, dim=0).t().contiguous()
# [N]
if isinstance(memory_data_loader.dataset.targets, list):
feature_labels = torch.tensor(memory_data_loader.dataset.targets, device=feature_bank.device)
elif isinstance(memory_data_loader.dataset.targets, torch.Tensor):
feature_labels = memory_data_loader.dataset.targets.clone().detach().to(feature_bank.device)
# convert to numpy
feature_bank_np = feature_bank.t().cpu().detach().numpy()
feature_labels_np = feature_labels.cpu().detach().numpy()[:len(feature_bank_np)]
# create linear classifier
classifier = SGDClassifier(loss='perceptron', n_jobs=-1)
print('\nFitting linear classifier')
# fit linear classifier
classifier.fit(feature_bank_np, feature_labels_np)
# create knn classifier
neigh = KNeighborsClassifier(n_neighbors=knn_k, n_jobs=-1)
print('\nFitting KNN classifier')
# fit knn classifier
neigh.fit(feature_bank_np, feature_labels_np)
        # loop over the test data and predict labels with the fitted kNN classifier
feature_bank_np_test, feature_labels_np_test = [], []
test_bar = tqdm(test_data_loader, position=0, leave=True)
for data, target in test_bar:
# send to device
target = target.cuda(non_blocking=True)
# get feature
feature = encode_image(data, net, inference=inference)
# append to list
feature_bank_np_test.append(feature.cpu().detach().numpy())
feature_labels_np_test.append(target.cpu().detach().numpy())
# prediction using knn
pred_labels = neigh.predict(feature.cpu().detach().numpy())
total_num += data.size(0)
total_top1 += (pred_labels == target.cpu().numpy()).sum().item()
knn = total_top1 / total_num * 100
test_bar.set_description("KNN classification test Epoch {}: Acc@1:{:.2f}%".format(epoch, knn),
refresh=True)
feature_labels_np_test = np.concatenate(feature_labels_np_test, axis=0)
feature_bank_np_test = np.concatenate(feature_bank_np_test, axis=0)
y_pred = classifier.predict(feature_bank_np_test)
result_linear = np.mean(y_pred == feature_labels_np_test) * 100
return knn, result_linear
class CNNClassifier:
    """Placeholder; encode_image() below only checks the class *name* of the model it receives."""
def encode_image(batch, model, inference=False):
"""
This function is used to encode the image into the latent space.
"""
# check if the model's name is CNNClassifier
if model.__class__.__name__ == "CNNClassifier":
feature = model.encoder(batch)
return feature
data = batch.cuda(non_blocking=True)
feature, feature_mu, feature_logvar = model.encoder(data)
if inference:
feature = feature_mu
return feature
def recon_loss(loss):
    """
    Return the element-wise reconstruction loss function (reduction='none').
    """
    if loss in ["BCE", 'bce']:
        return nn.BCELoss(reduction='none')
    elif loss in ["MSE", 'mse']:
        return nn.MSELoss(reduction='none')
    raise ValueError(f"unknown reconstruction loss: {loss}")
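# Hedged sketch of how the element-wise reconstruction loss is reduced: per-pixel losses
# are summed over channels and spatial dimensions, then averaged over the batch, mirroring
# get_loss() above. Tensors and the function name are illustrative.
def _example_recon_loss():
    loss_fn = recon_loss('bce')
    reconstructed = torch.rand(4, 1, 28, 28)
    target = torch.rand(4, 1, 28, 28)
    per_pixel = loss_fn(reconstructed, target)   # shape (4, 1, 28, 28)
    per_image = per_pixel.sum([1, 2, 3])         # shape (4,)
    return per_image.mean()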
def get_dataloaders(train_dataset, memory_data, test_dataset, num_workers=4):
    """
    Build the data loaders for the training, memory (validation) and test datasets.
    :param train_dataset: dataset of augmented training pairs
    :param memory_data: training dataset with the test transform, used as the memory/validation set
    :param test_dataset: test dataset
    :param num_workers: number of workers for the data loaders
    :return: train_loader, validation_loader, test_loader
    """
    train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=num_workers,
                              pin_memory=True, drop_last=True)
    validation_loader = DataLoader(memory_data, batch_size=512, shuffle=False, num_workers=num_workers,
                                   pin_memory=True, drop_last=True)
    test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=num_workers,
                             pin_memory=True, drop_last=True)
    return train_loader, validation_loader, test_loader
def get_datasets(dataset, augment=False):
"""
# get the train, validation and test datasets
:return: train_dataset, memory_data, test_dataset
"""
if dataset == 'EMNIST':
# root folder for the dataset
root = "./data/EMNIST"
# get the train dataset
if augment:
train_dataset = EMNISTPair(root=f"{root}/train", train=True, transform=train_transform, split='balanced',
download=True)
else:
train_dataset = EMNISTPair(root=f"{root}/train", train=True, transform=test_transform, split='balanced',
download=True)
        # get the train dataset with the test transform for validation in the semi-supervised setting
memory_data = EMNIST(root=f"{root}/train", train=True, transform=test_transform, split='balanced',
download=True)
# get the test dataset
test_dataset = datasets.EMNIST(root=f"{root}/test", train=False, transform=test_transform, split='balanced',
download=True)
elif dataset == 'MNIST':
# root folder for the dataset
root = "./data/MNIST"
# get the train dataset
if augment:
train_dataset = MNISTPair(root=f"{root}/train", train=True, transform=train_transform, download=True)
else:
train_dataset = MNISTPair(root=f"{root}/train", train=True, transform=test_transform, download=True)
        # get the train dataset with the test transform for validation in the semi-supervised setting
memory_data = datasets.MNIST(root=f"{root}/train", train=True, transform=test_transform, download=True)
# get the test dataset
test_dataset = datasets.MNIST(root=f"{root}/test", train=False, transform=test_transform, download=True)
elif dataset == 'FashionMNIST':
# root folder for the dataset
root = "./data/FashionMNIST"
# get the train dataset
if augment:
train_dataset = FashionMNISTPair(root=f"{root}/train", train=True, transform=train_transform,
download=True)
else:
train_dataset = FashionMNISTPair(root=f"{root}/train", train=True, transform=test_transform,
download=True)
        # get the train dataset with the test transform for validation in the semi-supervised setting
memory_data = datasets.FashionMNIST(root=f"{root}/train", train=True, transform=test_transform, download=True)
# get the test dataset
test_dataset = datasets.FashionMNIST(root=f"{root}/test", train=False, transform=test_transform,
download=True)
elif dataset == 'CIFAR10':
# root folder for the dataset
root = "./data/CIFAR10"
# get the train dataset
if augment:
train_dataset = CIFAR10Pair(root=f"{root}/train", train=True, transform=train_transform_cifar, download=True)
else:
train_dataset = CIFAR10Pair(root=f"{root}/train", train=True, transform=test_transform, download=True)
        # get the train dataset with the test transform for validation in the semi-supervised setting
memory_data = datasets.CIFAR10(root=f"{root}/train", train=True, transform=test_transform, download=True)
# get the test dataset
test_dataset = datasets.CIFAR10(root=f"{root}/test", train=False, transform=test_transform, download=True)
elif dataset == 'CIFAR100':
# root folder for the dataset
root = "./data/CIFAR100"
# get the train dataset
if augment:
train_dataset = CIFAR100Pair(root=f"{root}/train", train=True, transform=train_transform_cifar, download=True)
else:
train_dataset = CIFAR100Pair(root=f"{root}/train", train=True, transform=test_transform, download=True)
        # get the train dataset with the test transform for validation in the semi-supervised setting
memory_data = datasets.CIFAR100(root=f"{root}/train", train=True, transform=test_transform, download=True)
# get the test dataset
test_dataset = datasets.CIFAR100(root=f"{root}/test", train=False, transform=test_transform, download=True)
    else:
        raise ValueError(f"unknown dataset: {dataset}")
    return train_dataset, memory_data, test_dataset
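# Hedged end-to-end sketch combining get_datasets() and get_dataloaders(): build the
# MNIST pair datasets and their loaders and draw one batch of augmented views. Downloads
# the data on first use; the function name and worker count are illustrative.
def _example_data_pipeline():
    train_dataset, memory_data, test_dataset = get_datasets('MNIST', augment=True)
    train_loader, validation_loader, test_loader = get_dataloaders(
        train_dataset, memory_data, test_dataset, num_workers=2)
    im_1, im_2 = next(iter(train_loader))
    # two stochastically augmented views of the same images, each of shape (256, 1, 28, 28)
    return im_1.shape, im_2.shape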
def init_log_file(log_file, model_name):
    """
    Initialize the CSV log file and write its header. If a file with the same name
    already exists, a numbered suffix is appended to avoid overwriting; the (possibly
    renamed) path is returned.
    """
    header = ['epoch', 'lr', 'train_loss', 'train_reconstruction_loss', 'train_kl_loss',
              'train_contrastive_loss', 'KNN_acc',
              'linear_acc', 'alpha', 'beta', 'gamma', 'delta', 'mi', 'au', 'eval_loss', 'eval_recon', 'eval_kl']
    # create the logs directory if it does not exist
    if not os.path.exists("logs"):
        os.mkdir("logs")
    # if a file with this name already exists, add a running index to avoid overwriting it
    if os.path.exists(log_file):
        i = 1
        while os.path.exists(log_file):
            log_file = f"logs/{model_name}_logs_{i}.csv"
            i += 1
    # create the log file and write the header
    with open(log_file, 'w') as f:
        csv.writer(f).writerow(header)
    return log_file