-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathadapters.py
853 lines (673 loc) · 34.9 KB
/
adapters.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
import os
import os.path as osp
from re import template
import time
import os.path as osp
import datetime
import copy
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm
import numpy as np
from dassl.engine import TRAINER_REGISTRY, SimpleTrainer
from dassl.metrics import compute_accuracy
from dassl.utils import load_pretrained_weights, load_checkpoint, AverageMeter, MetricMeter
from dassl.optim import build_optimizer, build_lr_scheduler
from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
from datasets.imagenet_templates import IMAGENET_TEMPLATES, IMAGENET_TEMPLATES_SELECT
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.allow_tf32 = True
_tokenizer = _Tokenizer()
CUSTOM_TEMPLATES = {
"OxfordPets": "a photo of a {}, a type of pet.",
"OxfordFlowers": "a photo of a {}, a type of flower.",
"FGVCAircraft": "a photo of a {}, a type of aircraft.",
"DescribableTextures": "{} texture.",
"EuroSAT": "a centered satellite photo of {}.",
"StanfordCars": "a photo of a {}.",
"Food101": "a photo of {}, a type of food.",
"SUN397": "a photo of a {}.",
"Caltech101": "a photo of a {}.",
"UCF101": "a photo of a person doing {}.",
"ImageNet": "a photo of a {}.",
"ImageNetSketch": "a photo of a {}.",
"ImageNetV2": "a photo of a {}.",
"ImageNetA": "a photo of a {}.",
"ImageNetR": "a photo of a {}.",
}
def load_clip_to_cpu(cfg):
backbone_name = cfg.MODEL.BACKBONE.NAME
url = clip._MODELS[backbone_name]
model_path = clip._download(url)
try:
# loading JIT archive
model = torch.jit.load(model_path, map_location="cpu").eval()
state_dict = None
except RuntimeError:
state_dict = torch.load(model_path, map_location="cpu")
model = clip.build_model(state_dict or model.state_dict())
return model
class TextEncoder(nn.Module):
def __init__(self, clip_model):
super().__init__()
self.transformer = clip_model.transformer
self.positional_embedding = clip_model.positional_embedding
self.ln_final = clip_model.ln_final
self.text_projection = clip_model.text_projection
self.dtype = clip_model.dtype
def forward(self, prompts, tokenized_prompts):
x = prompts + self.positional_embedding.type(self.dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x).type(self.dtype)
# x.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
return x
def _get_base_text_features(cfg, classnames, clip_model, text_encoder, pretrained_projection=None):
device = next(text_encoder.parameters()).device
if clip_model.dtype == torch.float16:
text_encoder = text_encoder.cuda()
if pretrained_projection is not None:
# Load pretrained projection from TaskResidual Work
pretrained_text_projection = torch.load(pretrained_projection)
# Move weight to current CLIP model
state_dict = text_encoder.state_dict()
state_dict['text_projection'] = pretrained_text_projection['state_dict']['weight'].t()
text_encoder.load_state_dict(state_dict)
print(">> Pretrained text encoder loaded!")
params = pretrained_text_projection['state_dict']['weight'].size(0) * \
pretrained_text_projection['state_dict']['weight'].size(1)
print(">> Text projection parameters: ", params)
print(pretrained_text_projection['state_dict'].keys())
dataset = cfg.DATASET.NAME
if dataset == "ImageNet":
TEMPLATES = IMAGENET_TEMPLATES_SELECT
else:
TEMPLATES = []
TEMPLATES += [CUSTOM_TEMPLATES[dataset]]
with torch.no_grad():
text_embeddings = []
for text in classnames:
tokens = clip.tokenize([template.format(text) for template in TEMPLATES]) # tokenized prompts are indices
embeddings = clip_model.token_embedding(tokens).type(clip_model.dtype)
prototype = text_encoder(embeddings.cuda(), tokens.cuda())
text_embeddings.append(prototype)
text_embeddings = torch.stack(text_embeddings)
text_embeddings_avg = text_embeddings.mean(1)
return text_embeddings_avg.to(device), text_embeddings
class AdapterMethod(nn.Module):
def __init__(self, cfg, clip_model, base_text_features):
super().__init__()
self.device = clip_model.dtype
self.logit_scale = clip_model.logit_scale
self.initialization = cfg.TRAINER.ADAPTER.INIT
self.apply_constraint = cfg.TRAINER.ADAPTER.CONSTRAINT
self.distance = "l2"
self.register_buffer("base_text_features", base_text_features)
self.alpha_constraint = torch.zeros((base_text_features.shape[0])).to(self.device)
self.base_text_features = base_text_features
self.augmentations = True # True
self.epochs_aumentation = 20 # 20
if self.initialization == "RANDOM": # Randomly initialized Linear Probing
print("Using RANDOM initialization in Linear Probing", end="\n")
self.prototypes = nn.Parameter(torch.nn.init.kaiming_normal_(torch.empty(base_text_features.shape)))
elif "ZS" in self.initialization: # Linear Probe initialized with zero-shot weights
print("Using Zero-Shot initialization in Linear Probing", end="\n")
self.prototypes = nn.Parameter(base_text_features.clone())
elif "TR" in self.initialization: # Task Residual Adapter form Yu et al. (2023)
print("Using Task_residual approach for Linear Probing", end="\n")
self.init_TR(alpha=0.5)
elif "ClipA" in self.initialization: # CLIP-Adapter form Gao et al. (2023)
self.init_clipA()
elif "TipA" in self.initialization: # TIP-Adapter form Zhang et al. (2022)
self.init_tipA()
elif "CrossModal" in self.initialization: # Cross-Modal Linear Probing form Lin et al. (2023)
print("Using CrossModal for Linear Probing", end="\n")
self.init_MultiModal()
else:
print("Initialization for Linear Probing not implemented")
assert False
if self.apply_constraint != "none":
print("Applying constraint to the logistic regression weights: " + str(self.distance))
def init_MultiModal(self):
print("Using Zero-Shot initialization in Linear Probing", end="\n")
self.prototypes = nn.Parameter(self.base_text_features.clone())
def init_TR(self, alpha=0.5):
print("Using TaskRes-Adapter", end="\n")
self.alpha = alpha
self.grid_search_param = {"lr": [1e-1, 1e-2, 1e-3],
"alpha": list(np.arange(0.2, 1.2, 0.2))}
print("Using Task_residual approach for Linear Probing", end="\n")
self.prototypes = nn.Parameter(torch.zeros_like(self.base_text_features.clone()))
def init_clipA(self, ratio=0.2):
print("Using CLIP-Adapter", end="\n")
self.grid_search_param = {"lr": [1e-1, 1e-2, 1e-3],
"ratio": list(np.arange(0.2, 1, 0.2))}
self.ratio = ratio
self.prototypes = nn.Parameter(self.base_text_features.clone())
self.prototypes.requires_grad = False
self.mlp = nn.Sequential(
nn.Linear(self.base_text_features.shape[-1], self.base_text_features.shape[-1] // 4, bias=False),
nn.ReLU(inplace=True),
nn.Linear(self.base_text_features.shape[-1] // 4, self.base_text_features.shape[-1], bias=False),
nn.ReLU(inplace=True)
).to(self.device)
def init_tipA(self, beta=1, alpha=1):
if "-f-" in self.initialization:
self.grid_search_param = {"lr": [1e-1, 1e-2],
"alpha": list(np.arange(1, 50, 50/10)),
"beta": list(np.arange(1, 28, 28/10))}
else:
self.grid_search_param = {"alpha": list(np.arange(1, 50, 50/20)),
"beta": list(np.arange(1, 28, 28/20))}
print("Using Tip-Adapter")
self.beta = beta
self.alpha = alpha
self.prototypes = nn.Parameter(self.base_text_features.clone())
self.prototypes.requires_grad = False
self.cache_keys = None # Features
self.cache_values = None # labels
def init_tipadapter(self, features_train, labels_train):
self.cache_keys = nn.Parameter(features_train.clone().to(self.device))
self.cache_keys.requires_grad = True
self.cache_values = nn.Parameter(torch.nn.functional.one_hot(labels_train).clone().to(torch.float32).to(self.device))
self.cache_values.requires_grad = False
def zero_shot_constraint(self):
# Compute constraint
if "l2" in self.apply_constraint:
disimilitude = (self.prototypes - self.base_text_features.clone()).pow(2).sum(-1)
elif "cosine" in self.apply_constraint:
disimilitude = (1 - torch.nn.functional.cosine_similarity(self.prototypes, self.base_text_features.clone()))
else:
print("Dissimilitude metric for constraint not implemented")
assert False
return torch.mean(self.alpha_constraint * disimilitude)
def init_lagrangian_multipliers(self, labels_ds, logits_ds):
if "balanced" in self.apply_constraint:
performance = torch.ones(logits_ds.shape[-1]).to(torch.float)
else:
with torch.no_grad():
# Get one-hot encoding ground-truth
labels_one_hot = torch.nn.functional.one_hot(labels_ds).cpu()
# Get zero_shot performance
performance = torch.diag(torch.softmax(logits_ds, -1).t() @ labels_one_hot.to(torch.float32)) /\
labels_one_hot.sum(0)
if "corrected" in self.apply_constraint:
performance *= (logits_ds.shape[-1] / torch.sum(performance).item())
if "constant" in self.apply_constraint:
performance = torch.ones(logits_ds.shape[-1]).to(torch.float) * torch.mean(performance).item()
# set new alphas
self.alpha_constraint = torch.clone(performance).to(self.device)
self.penalty_parameter = torch.zeros_like(self.alpha_constraint).to(self.device)
def outer_step(self):
def phr(h, lambd, rho):
x = lambd + rho * h
y_sup = 1 / (2 * rho) * (x ** 2 - lambd ** 2)
y_inf = - 1 / (2 * rho) * (lambd ** 2)
grad_y_sup = x
grad_y_inf = torch.zeros_like(h)
sup = x >= 0
return (
torch.where(sup, y_sup, y_inf),
torch.where(sup, grad_y_sup, grad_y_inf)
)
print("Outer step on Augmented Lagrangian Multiplier")
# Cmpute current constraints
disimilitude = (self.prototypes - self.base_text_features.clone()).pow(2).sum(-1)
# Compute phr
phr_value, phr_grad = phr(disimilitude, self.alpha_constraint, self.penalty_parameter)
# Update lagrangian multipliers
self.alpha_constraint = phr_grad.detach().clone()
# Update penalty parameters rho
self.penalty_parameter = disimilitude.detach().clone()
print("New lagrangian multipliers:")
print(self.alpha_constraint[0:5].detach().cpu().numpy())
def forward(self):
return self.prototypes
class CustomCLIP(nn.Module):
def __init__(self, cfg, classnames, clip_model):
super().__init__()
self.image_encoder = clip_model.visual
self.logit_scale = clip_model.logit_scale
self.dtype = clip_model.dtype # float16
text_encoder = TextEncoder(clip_model)
# For TaskRes (Yu et al.) enhanced base - or regular CLIP base
if cfg.TRAINER.ADAPTER.ENHANCED_BASE == "none":
print(">> Use regular base!")
base_text_features, text_embeddings_all = _get_base_text_features(cfg, classnames, clip_model, text_encoder)
else:
print(">> Use enhanced base!")
base_text_features, text_embeddings_all = _get_base_text_features(
cfg, classnames, clip_model, text_encoder, cfg.TRAINER.TaskRes.ENHANCED_BASE)
self.text_embeddings_all = text_embeddings_all
self.adapter = AdapterMethod(cfg, clip_model, base_text_features)
def forward(self, image, return_features=False):
try:
image_features = self.image_encoder(image.type(self.dtype))
except:
image_features = self.image_encoder(image.float())
if "TR" in self.adapter.initialization:
logits = self.forward_task_residual(image_features)
elif "ClipA" in self.adapter.initialization:
logits = self.forward_clipadapter(image_features)
elif "TipA" in self.adapter.initialization:
logits = self.forward_tipadapter(image_features)
else:
logits = self.forward_lp(image_features)
if return_features:
return logits, image_features
else:
return logits
def forward_features(self, features):
if "TR" in self.adapter.initialization:
logits = self.forward_task_residual(features)
elif "ClipA" in self.adapter.initialization:
logits = self.forward_clipadapter(features)
elif "TipA" in self.adapter.initialization:
logits = self.forward_tipadapter(features)
else:
logits = self.forward_lp(features)
return logits
def forward_lp(self, features):
# Get trained prototype
prototypes = self.adapter()
image_features_norm = features / features.norm(dim=-1, keepdim=True)
prototypes_norm = prototypes / prototypes.norm(dim=-1, keepdim=True)
logit_scale = self.logit_scale.exp()
logits = image_features_norm @ prototypes_norm.t() * logit_scale
return logits
def forward_task_residual(self, features):
# Get trained prototype
prototypes = self.adapter()
# Sum residual features to base zero-shot prototypes
prototypes = self.adapter.base_text_features + self.adapter.alpha * prototypes
image_features_norm = features / features.norm(dim=-1, keepdim=True)
prototypes_norm = prototypes / prototypes.norm(dim=-1, keepdim=True)
logit_scale = self.logit_scale.exp()
logits = image_features_norm @ prototypes_norm.t() * logit_scale
return logits
def forward_clipadapter(self, features):
# Get zero-shot weights
prototypes = self.adapter()
# Produce residual features on vision features
x = self.adapter.mlp(features)
features = self.adapter.ratio * x + (1 - self.adapter.ratio) * features
# Normalize features
image_features_norm = features / features.norm(dim=-1, keepdim=True)
prototypes_norm = prototypes / prototypes.norm(dim=-1, keepdim=True)
# Obtain logits
logit_scale = self.logit_scale.exp()
logits = image_features_norm @ prototypes_norm.t() * logit_scale
return logits
def forward_tipadapter(self, features):
# Get zero-shot weights
prototypes = self.adapter()
# Normalize features
image_features_norm = features / features.norm(dim=-1, keepdim=True)
prototypes_norm = prototypes / prototypes.norm(dim=-1, keepdim=True)
# Obtain zero-shot logits
logit_scale = self.logit_scale.exp()
logits = image_features_norm @ prototypes_norm.t() * logit_scale
if self.adapter.cache_keys is not None:
# normalize cache keys
cache_keys = self.adapter.cache_keys / self.adapter.cache_keys.norm(dim=-1, keepdim=True)
# Get affinity betwen train features and test
affinity = features @ cache_keys.t().cuda().to(torch.float)
cache_logits = torch.exp(((-1) * (self.adapter.beta - self.adapter.beta * affinity))) @ self.adapter.cache_values.cuda().to(torch.float)
logits += self.adapter.alpha * cache_logits
return logits
class TrainerXCostume(SimpleTrainer):
"""A base trainer using labeled data only."""
def run_epoch(self):
# Eval mode - not updating batchnorm statistics
self.set_model_mode("eval")
# Init kpis tracker
losses = MetricMeter()
batch_time = AverageMeter()
data_time = AverageMeter()
# Set number of batches to sample
self.num_batches = len(self.train_loader_x)
self.batch_size = self.train_loader_x.batch_size
# Set features
features = self.features_train.clone().cpu().numpy()
labels = self.labels_train.clone()
# Sample half dataset - to tackle previous oversample with text prompts
if "CrossModal" in self.model.adapter.initialization:
idx = np.random.choice(list(np.arange(0, features.shape[0])), features.shape[0] // 2)
features = features[idx, :]
labels = labels[idx]
# Randomly shuffle
idx = np.random.rand(features.shape[0]).argsort(axis=0)
features = features[idx, :]
labels = labels[idx]
end = time.time()
for self.batch_idx in range(self.num_batches):
batch_init = self.batch_idx * self.batch_size
batch_end = (self.batch_idx + 1) * self.batch_size
data_time.update(time.time() - end)
loss_summary = self.forward_backward(features[batch_init:batch_end],
labels[batch_init:batch_end])
batch_time.update(time.time() - end)
losses.update(loss_summary)
meet_freq = (self.batch_idx + 1) % self.cfg.TRAIN.PRINT_FREQ == 0
only_few_batches = self.num_batches < self.cfg.TRAIN.PRINT_FREQ
if meet_freq or only_few_batches:
nb_remain = 0
nb_remain += self.num_batches - self.batch_idx - 1
nb_remain += (
self.max_epoch - self.epoch - 1
) * self.num_batches
eta_seconds = batch_time.avg * nb_remain
eta = str(datetime.timedelta(seconds=int(eta_seconds)))
info = []
info += [f"epoch [{self.epoch + 1}/{self.max_epoch}]"]
info += [f"batch [{self.batch_idx + 1}/{self.num_batches}]"]
info += [f"{losses}"]
info += [f"eta {eta}"]
print(" ".join(info))
n_iter = self.epoch * self.num_batches + self.batch_idx
for name, meter in losses.meters.items():
self.write_scalar("train/" + name, meter.avg, n_iter)
self.write_scalar("train/lr", self.get_current_lr(), n_iter)
end = time.time()
return loss_summary
@TRAINER_REGISTRY.register()
class ADAPTER(TrainerXCostume):
"""General Adapter
"""
def check_cfg(self, cfg):
assert cfg.TRAINER.ADAPTER.PREC in ["fp16", "fp32", "amp"]
def build_model(self):
cfg = self.cfg
classnames = self.dm.dataset.classnames
print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
clip_model = load_clip_to_cpu(cfg)
if cfg.TRAINER.ADAPTER.PREC == "fp32" or cfg.TRAINER.ADAPTER.PREC == "amp":
# CLIP's default precision is fp16
clip_model.float()
print("Building custom CLIP")
self.model = CustomCLIP(cfg, classnames, clip_model)
print("Turning off gradients in both the image and the text encoder")
for name, param in self.model.named_parameters():
if "adapter" not in name:
param.requires_grad_(False)
else:
print(name)
if cfg.MODEL.INIT_WEIGHTS:
load_pretrained_weights(self.model.adapter, cfg.MODEL.INIT_WEIGHTS)
self.model.to(self.device)
self.model = self.model.float()
# NOTE: only give adapter to the optimizer
self.optim = build_optimizer(self.model.adapter, cfg.OPTIM)
self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
self.register_model("adapter", self.model.adapter, self.optim, self.sched)
self.scaler = GradScaler() if cfg.TRAINER.ADAPTER.PREC == "amp" else None
def train(self):
self.set_model_mode("eval")
# Feature extraction on test set
self.labels_test, output_test, self.features_test = self.extract_features(partition="test")
print("Zero-Shot accuracy on test: " +
str(round(compute_accuracy(output_test.cuda(), self.labels_test.cuda())[0].item(), 2)))
# Feature extraction on training set
self.labels_train, self.logits_zs, self.features_train = self.extract_features(
partition="train", reps=self.model.adapter.epochs_aumentation, transforms=self.model.adapter.augmentations)
if "CrossModal" in self.model.adapter.initialization:
print("Preparing cross-modal dataset... resampling text prompts")
# Cross-Modal: Add zero-shot prototypes as samples
zs_prototypes = self.model.text_embeddings_all.cpu().numpy()
zs_labels = np.repeat(np.expand_dims(np.arange(0, zs_prototypes.shape[0]), (0)), zs_prototypes.shape[1], 0)
zs_prototypes = np.reshape(np.transpose(zs_prototypes, (2, 1, 0)),
(zs_prototypes.shape[-1], zs_prototypes.shape[0]*zs_prototypes.shape[1])).transpose()
zs_labels = np.transpose(zs_labels, (1, 0)).flatten()
# Resample for a balanced dataset between modalities
idx = np.random.choice(list(np.arange(0, len(zs_labels))), self.features_train.shape[0])
zs_labels = zs_labels[idx]
zs_prototypes = zs_prototypes[idx, :]
self.features_train = torch.cat([self.features_train, torch.tensor(zs_prototypes)], dim=0)
self.labels_train = torch.cat([self.labels_train, torch.tensor(zs_labels).cuda()])
# Init alphas in constraint formulation
if self.model.adapter.apply_constraint != "none":
print("Getting initial lagrangian multipliers for constraint formulation", end="\n")
self.model.adapter.device = self.device
self.model.adapter.init_lagrangian_multipliers(self.labels_train, self.logits_zs)
print("Lagrangian multipliers: ")
print(list(torch.round(self.model.adapter.alpha_constraint.detach(), decimals=3).cpu().numpy()))
# In the case of tip-adapter, register cache features
if "TipA" in self.model.adapter.initialization:
# Given the new key features, register again the weights to optimizer
self.model.adapter.init_tipadapter(self.features_train, self.labels_train)
self.optim = build_optimizer(self.model.adapter, self.cfg.OPTIM) # Update optimizer with new params
self.sched = build_lr_scheduler(self.optim, self.cfg.OPTIM)
self.register_model("adapter_tipa-f-", self.model.adapter, self.optim, self.sched)
if "-f-" not in self.model.adapter.initialization:
self.max_epoch = 1 # Not training, just one forward round for predicting test performance
summary_grid = []
if "grid_search" in self.model.adapter.initialization:
from sklearn.model_selection import ParameterGrid
import pandas as pd
best_acc = 0.0
best_setting = []
grid = ParameterGrid(self.model.adapter.grid_search_param)
for params in grid:
print("Iteration grid hyperparameters search: ")
print(params)
self.reset_hyperparams(params)
# Training of adapter
self.before_train()
for self.epoch in range(self.start_epoch, self.max_epoch):
# Train and update weights per epoch
self.before_epoch()
loss_summary = self.run_epoch()
if loss_summary["acc_test"] > best_acc:
best_acc = loss_summary["acc_test"]
best_setting = params
self.epoch = -1 # To avoid saving weights
self.after_epoch()
params["acc_test"] = loss_summary["acc_test"]
summary_grid.append(params)
# Print current configuration performance
print("Current configuration: ")
print(params)
print("A on test:")
print(loss_summary["acc_test"])
# Print best configuration performance:
print("Best configuration: ")
print(best_setting)
print("Best accuracy on test:")
print(best_acc)
df = pd.DataFrame(summary_grid)
df.to_csv(self.cfg.OUTPUT_DIR + "/grid_search.csv")
else:
# Training of adapter
self.before_train()
for self.epoch in range(self.start_epoch, self.max_epoch):
# Train and update weights per epoch
self.before_epoch()
self.run_epoch()
# Update lagrangian parameter and multiplier
if "adaptative" in self.model.adapter.apply_constraint:
self.model.adapter.outer_step()
self.after_epoch()
self.after_train()
def reset_hyperparams(self, params):
import random
if "ClipA" in self.model.adapter.initialization:
self.model.adapter.init_clipA(ratio=params["ratio"])
if "TipA" in self.model.adapter.initialization:
self.model.adapter.init_tipA(alpha=params["alpha"], beta=params["beta"])
if "TR" in self.model.adapter.initialization:
self.model.adapter.init_TR(alpha=params["alpha"])
# In the case of tip-adapter, register cache features
if "TipA" in self.model.adapter.initialization:
# Given the new key features, register again the weights to optimizer
self.model.adapter.init_tipadapter(self.features_train, self.labels_train)
if "-f-" in self.model.adapter.initialization:
# Put epochs as in the original paper
self.max_epoch = 20
self.model.to(self.device)
self.model = self.model.float()
if "lr" in list(params.keys()):
self.cfg.OPTIM["LR"] = params["lr"]
self.optim = build_optimizer(self.model.adapter, self.cfg.OPTIM)
self.sched = build_lr_scheduler(self.optim, self.cfg.OPTIM)
self._models.popitem(), self._optims.popitem(),self._scheds.popitem()
self.register_model("adapter" + str(random.random()), self.model.adapter, self.optim, self.sched)
return 1
def after_train(self):
print("Finish training")
# Show elapsed time
elapsed = round(time.time() - self.time_start)
elapsed = str(datetime.timedelta(seconds=elapsed))
print(f"Elapsed: {elapsed}")
# Close writer
self.close_writer()
def forward_backward(self, features, labels):
prec = self.cfg.TRAINER.ADAPTER.PREC
if prec == "amp":
with autocast():
# Cross-entropy loss
output = self.model.forward_features(torch.tensor(features).to(self.device))
# Softmax cross-entropy
loss_ce = F.cross_entropy(output, labels)
# Constraint to zero-shot (CLAP)
if self.model.adapter.apply_constraint != "none":
loss_constraint = self.model.adapter.zero_shot_constraint()
loss = loss_ce + loss_constraint
else:
loss = loss_ce
self.optim.zero_grad()
self.scaler.scale(loss).backward()
self.scaler.step(self.optim)
self.scaler.update()
else:
# Cross-entropy loss
output = self.model.forward_features(torch.tensor(features).to(self.device))
# Softmax cross-entropy
loss_ce = F.cross_entropy(output, labels)
# Constraint to zero-shot (CLAP)
if self.model.adapter.apply_constraint != "none":
loss_constraint = self.model.adapter.zero_shot_constraint()
loss = loss_ce + loss_constraint
else:
loss = loss_ce
self.model_backward_and_update(loss)
with torch.no_grad():
output_test = self.model.forward_features(self.features_test.clone().detach().to(self.device))
loss_summary = {
"loss": loss.item(),
"acc_train": compute_accuracy(output, labels)[0].item(),
"acc_test": compute_accuracy(output_test, self.labels_test)[0].item(),
}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
torch.cuda.empty_cache()
return loss_summary
def load_model(self, directory, cfg, epoch=None):
if not directory:
print("Note that load_model() is skipped as no pretrained model is given")
return
else:
print("Pretrained model given")
if self.model.adapter.initialization == "TipA":
epoch = 1
names = self.get_model_names()
# By default, the best model is loaded
model_file = "model-best.pth.tar"
if epoch is not None:
model_file = "model.pth.tar-" + str(epoch)
for name in names:
model_path = osp.join(directory, name, model_file)
if not osp.exists(model_path):
raise FileNotFoundError('Model not found at "{}"'.format(model_path))
else:
print('Model found at "{}"'.format(model_path))
checkpoint = load_checkpoint(model_path)
state_dict = checkpoint["state_dict"]
if "TipA" in self.model.adapter.initialization:
self.model.adapter.cache_keys = nn.Parameter(state_dict['cache_keys'].clone())
self.model.adapter.cache_values = nn.Parameter(state_dict['cache_values'].clone())
if self.cfg.DATASET.NAME == 'ImageNetA' or self.cfg.DATASET.NAME == 'ImageNetR':
if self.cfg.DATASET.NAME == 'ImageNetA':
from datasets.imagenet_a_r_indexes_v2 import find_imagenet_a_indexes as find_indexes
else:
from datasets.imagenet_a_r_indexes_v2 import find_imagenet_r_indexes as find_indexes
imageneta_indexes = find_indexes()
print("Parameters found: ")
print(state_dict.keys())
state_dict['base_text_features'] = state_dict['base_text_features'][imageneta_indexes]
state_dict['prototypes'] = state_dict['prototypes'][imageneta_indexes]
if "TipA" in self.model.adapter.initialization:
state_dict['cache_values'] = state_dict['cache_values'][:, imageneta_indexes]
self.model.adapter.cache_keys = nn.Parameter(state_dict['cache_keys'].clone())
self.model.adapter.cache_values = nn.Parameter(state_dict['cache_values'].clone())
epoch = checkpoint["epoch"]
# Ignore fixed token vectors
if "token_prefix" in state_dict:
del state_dict["token_prefix"]
if "token_suffix" in state_dict:
del state_dict["token_suffix"]
print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, epoch))
# set strict=False
self._models[name].load_state_dict(state_dict, strict=False)
self.model.float()
def extract_features(self, partition, reps=1, transforms=False):
print("Extracting features from: " + partition, end="\n")
self.set_model_mode("eval")
if partition == "train":
# Copy safe version of training dataloader
data_loader = copy.deepcopy(self.train_loader_x)
# Set data augmentation transforms
if not transforms:
data_loader.dataset.transform = self.val_loader.dataset.transform
# Set data loader with drop last to false for not losing samples
data_loader = torch.utils.data.DataLoader(
copy.deepcopy(self.train_loader_x.dataset), batch_size=self.train_loader_x.batch_size,
sampler=self.train_loader_x.sampler, num_workers=self.train_loader_x.num_workers,
drop_last=False, pin_memory=self.train_loader_x.pin_memory)
elif partition == "val":
data_loader = copy.deepcopy(self.val_loader)
elif partition == "test":
data_loader = copy.deepcopy(self.test_loader)
else:
assert False
if "TipA" not in self.model.adapter.initialization:
labels_ds, logits_ds, features_ds = [], [], []
for rep in range(reps):
for batch_idx, batch in enumerate(tqdm(data_loader)):
with torch.no_grad():
input, label = self.parse_batch_test(batch)
logits, features = self.model(input, return_features=True)
labels_ds.append(label), logits_ds.append(logits.cpu()), features_ds.append(features.cpu())
# Concatenate outputs
labels_ds = torch.cat(labels_ds, dim=0)
logits_ds = torch.cat(logits_ds, dim=0)
features_ds = torch.cat(features_ds, dim=0)
else:
labels_ds, logits_ds, features_ds = [], [], []
for rep in range(reps):
labels_ds_irep, logits_dsirep, features_ds_irep = [], [], []
for batch_idx, batch in enumerate(tqdm(data_loader)):
with torch.no_grad():
input, label = self.parse_batch_test(batch)
logits, features = self.model(input, return_features=True)
labels_ds_irep.append(label), logits_dsirep.append(logits.cpu()), features_ds_irep.append(features.cpu())
# Concatenate outputs for dataset
labels_ds_irep = torch.cat(labels_ds_irep, dim=0)
logits_dsirep = torch.cat(logits_dsirep, dim=0)
features_ds_irep = torch.cat(features_ds_irep, dim=0)
# Concatenate outputs for repetitions
labels_ds.append(labels_ds_irep.unsqueeze(0))
logits_ds.append(logits_dsirep.unsqueeze(0))
features_ds.append(features_ds_irep.unsqueeze(0))
# Concatenate outputs
labels_ds = torch.cat(labels_ds, dim=0)[0, :]
logits_ds = torch.cat(logits_ds, dim=0).mean(0)
features_ds = torch.cat(features_ds, dim=0).mean(0)
return labels_ds, logits_ds, features_ds