common.py
import os
import torch
import json
import glob
import collections
import random
import numpy as np
from tqdm import tqdm
import torchvision.datasets as datasets
from torch.utils.data import Dataset, DataLoader, Sampler


class SubsetSampler(Sampler):
    """Sampler that yields a fixed sequence of indices, in order."""

    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        return (i for i in self.indices)

    def __len__(self):
        return len(self.indices)


class ImageFolderWithPaths(datasets.ImageFolder):
    """ImageFolder that also returns each image's path and, optionally,
    flips labels uniformly at random with probability `flip_label_prob`."""

    def __init__(self, path, transform, flip_label_prob=0.0):
        super().__init__(path, transform)
        self.flip_label_prob = flip_label_prob
        if self.flip_label_prob > 0:
            print(f'Flipping labels with probability {self.flip_label_prob}')
            num_classes = len(self.classes)
            for i in range(len(self.samples)):
                if random.random() < self.flip_label_prob:
                    new_label = random.randint(0, num_classes - 1)
                    self.samples[i] = (
                        self.samples[i][0],
                        new_label
                    )

    def __getitem__(self, index):
        image, label = super().__getitem__(index)
        return {
            'images': image,
            'labels': label,
            'image_paths': self.samples[index][0]
        }


def maybe_dictionarize(batch):
    """Normalize a batch into a dict with 'images'/'labels' (and optional 'metadata') keys."""
    if isinstance(batch, dict):
        return batch

    if len(batch) == 2:
        batch = {'images': batch[0], 'labels': batch[1]}
    elif len(batch) == 3:
        batch = {'images': batch[0], 'labels': batch[1], 'metadata': batch[2]}
    else:
        raise ValueError(f'Unexpected number of elements: {len(batch)}')

    return batch


def get_features_helper(image_encoder, dataloader, device, noscale):
    """Run the image encoder over every batch and collect features together
    with all non-image batch fields (labels, paths, metadata, ...)."""
    all_data = collections.defaultdict(list)

    image_encoder = image_encoder.to(device)
    image_encoder = torch.nn.DataParallel(
        image_encoder, device_ids=list(range(torch.cuda.device_count())))
    image_encoder.eval()

    with torch.no_grad():
        for batch in tqdm(dataloader):
            batch = maybe_dictionarize(batch)
            inputs = batch['images'].cuda()
            image_encoder = image_encoder.to(inputs.device)
            features = image_encoder(inputs)
            # if noscale:
            #     features = features / features.norm(dim=-1, keepdim=True)
            # else:
            #     logit_scale = image_encoder.module.model.logit_scale
            #     features = logit_scale.exp() * features
            all_data['features'].append(features.cpu())

            # Keep every non-image field alongside the features.
            for key, val in batch.items():
                if key == 'images':
                    continue
                if hasattr(val, 'cpu'):
                    val = val.cpu()
                    all_data[key].append(val)
                else:
                    all_data[key].extend(val)

    # Concatenate per-batch tensors into single numpy arrays.
    for key, val in all_data.items():
        if torch.is_tensor(val[0]):
            all_data[key] = torch.cat(val).numpy()

    return all_data


def get_features(is_train, image_encoder, dataset, device, cache_dir, noscale):
    """Return features for the requested split, reading them from `cache_dir`
    when available and computing (and optionally caching) them otherwise."""
    split = 'train' if is_train else 'val'
    dname = type(dataset).__name__
    if cache_dir is not None:
        cache_dir = f'{cache_dir}/{dname}/{split}'
        cached_files = glob.glob(f'{cache_dir}/*')
    if cache_dir is not None and len(cached_files) > 0:
        print(f'Getting features from {cache_dir}')
        data = {}
        for cached_file in cached_files:
            name = os.path.splitext(os.path.basename(cached_file))[0]
            data[name] = torch.load(cached_file)
    else:
        print(f'Did not find cached features at {cache_dir}. Building from scratch.')
        loader = dataset.train_loader if is_train else dataset.test_loader
        data = get_features_helper(image_encoder, loader, device, noscale)
        if cache_dir is None:
            print('Not caching because no cache directory was passed.')
        else:
            os.makedirs(cache_dir, exist_ok=True)
            print(f'Caching data at {cache_dir}')
            for name, val in data.items():
                torch.save(val, f'{cache_dir}/{name}.pt')
    return data


class FeatureDataset(Dataset):
    """Dataset over features precomputed by `get_features`."""

    def __init__(self, is_train, image_encoder, dataset, device, cache_dir=None, noscale=True):
        self.data = get_features(is_train, image_encoder, dataset, device, cache_dir, noscale)

    def __len__(self):
        return len(self.data['features'])

    def __getitem__(self, idx):
        data = {k: v[idx] for k, v in self.data.items()}
        data['features'] = torch.from_numpy(data['features']).float()
        return data


def get_dataloader(dataset, is_train, args, image_encoder=None):
    """Return the dataset's own image loader, or a loader over precomputed
    features when an `image_encoder` is provided."""
    if image_encoder is not None:
        feature_dataset = FeatureDataset(
            is_train, image_encoder, dataset, args.device, args.cache_dir, args.noscale)
        dataloader = DataLoader(
            feature_dataset, batch_size=args.batch_size, shuffle=is_train)
    else:
        dataloader = dataset.train_loader if is_train else dataset.test_loader
    return dataloader