moco.py
import copy
import math

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim

import vit
from utils import CosineAnnealingWithWarmupLR, concat_all_gather
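
# Note on the local helpers (assumptions based on how they are used below, not the
# repo's actual utils.py): `CosineAnnealingWithWarmupLR` is expected to be a per-step
# schedule with linear warmup followed by cosine decay, and `concat_all_gather` is
# expected to gather key embeddings from every DDP rank (without gradients), roughly
# like the standard MoCo-style sketch below:
#
#     @torch.no_grad()
#     def concat_all_gather(tensor):
#         gathered = [torch.zeros_like(tensor)
#                     for _ in range(torch.distributed.get_world_size())]
#         torch.distributed.all_gather(gathered, tensor)
#         return torch.cat(gathered, dim=0)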


class MoCo(pl.LightningModule):
    """MoCo v3-style self-supervised pretraining of a ViT backbone."""

    def __init__(self, encoder='vit_small', out_dim=256, mlp_dim=4096, tau=0.2, mu=0.99,
                 lr=1.5e-4, weight_decay=0.1, warmup_steps=1, max_steps=10):
        super().__init__()
        self.save_hyperparameters()
        # build backbone
        encoder = vit.__dict__[encoder]()
        hidden_dim = encoder.head.weight.shape[1]
        # build modules: base encoder with a 3-layer projection MLP, a momentum
        # encoder (EMA copy of it), and a 2-layer predictor on top of the base encoder
        self.encoder = copy.deepcopy(encoder)
        self.encoder.head = nn.Sequential(
            nn.Linear(hidden_dim, mlp_dim, bias=False), nn.BatchNorm1d(mlp_dim), nn.ReLU(inplace=True),
            nn.Linear(mlp_dim, mlp_dim, bias=False), nn.BatchNorm1d(mlp_dim), nn.ReLU(inplace=True),
            nn.Linear(mlp_dim, out_dim, bias=False), nn.BatchNorm1d(out_dim, affine=False)
        )
        self.momentum_encoder = copy.deepcopy(self.encoder)
        self.predictor = nn.Sequential(
            nn.Linear(out_dim, mlp_dim, bias=False), nn.BatchNorm1d(mlp_dim), nn.ReLU(inplace=True),
            nn.Linear(mlp_dim, out_dim, bias=False), nn.BatchNorm1d(out_dim, affine=False)
        )
        # the momentum encoder is updated only by EMA, never by gradients
        for param in self.momentum_encoder.parameters():
            param.requires_grad = False

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        scheduler = CosineAnnealingWithWarmupLR(optimizer, self.hparams.warmup_steps, self.hparams.max_steps)
        # step the LR scheduler every optimizer step rather than every epoch
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'step'
            }
        }

    def forward(self, x):
        # inference returns features from the momentum encoder
        return self.momentum_encoder(x)

    def contrastive_loss(self, q, k):
        # InfoNCE loss: for each query, the positive key is the other view of the
        # same image; every other key in the global batch acts as a negative.
        q = nn.functional.normalize(q, dim=1)
        k = nn.functional.normalize(k, dim=1)
        k = concat_all_gather(k)  # gather keys from all GPUs
        N = q.shape[0]  # batch size per GPU
        logits = q @ k.T  # (N, N * world_size) cosine similarities
        # the positive key for query i on rank r sits at column i + N * r of the
        # gathered keys (e.g. with N = 3 on rank 1, the positives are columns 3, 4, 5)
        labels = (torch.arange(N, dtype=torch.long) + N * torch.distributed.get_rank()).to(device=self.device)
        loss = nn.functional.cross_entropy(logits / self.hparams.tau, labels)
        # the 2 * tau scaling follows the MoCo v3 reference implementation
        return 2 * self.hparams.tau * loss

    @torch.no_grad()
    def _update_momentum_encoder(self, batch_idx):
        # increase mu from its base value towards 1 with a cosine schedule,
        # i.e. mu_t = 1 - (1 - mu) * (1 + cos(pi * t / max_steps)) / 2
        current_step = self.current_epoch * self.trainer.num_training_batches + batch_idx
        mu = (1 - (1 + math.cos(math.pi * current_step / self.hparams.max_steps)) / 2) * (1 - self.hparams.mu) + self.hparams.mu
        # EMA update of the momentum encoder's parameters
        for param, param_m in zip(self.encoder.parameters(), self.momentum_encoder.parameters()):
            param_m.data = param_m.data * mu + param.data * (1. - mu)

    def training_step(self, batch, batch_idx):
        (x1, x2), _ = batch  # two augmented views of each image; labels are unused
        self._update_momentum_encoder(batch_idx)
        # base encoder + predictor forward pass on both views
        q1 = self.predictor(self.encoder(x1))
        q2 = self.predictor(self.encoder(x2))
        # momentum encoder forward pass on both views
        k1 = self.momentum_encoder(x1)
        k2 = self.momentum_encoder(x2)
        # symmetrized MoCo contrastive loss: each view's queries against the other view's keys
        loss = self.contrastive_loss(q1, k2) + self.contrastive_loss(q2, k1)
        self.log('MoCo-v3 loss', loss)
        return loss
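

if __name__ == '__main__':
    # Minimal smoke-test sketch, not the repo's training recipe: it drives the module
    # with random two-crop batches on a single device. Real MoCo v3 pretraining feeds
    # two augmented views of each ImageNet image under multi-GPU DDP. The dataset
    # below is a stand-in assumption (it presumes the vit backbones take 224x224
    # inputs); `concat_all_gather` and the rank-offset labels need an initialized
    # process group, which Lightning's DDP strategy provides even with one device.
    from torch.utils.data import DataLoader, Dataset

    class _RandomTwoCropDataset(Dataset):
        # hypothetical helper: yields ((view1, view2), dummy_label) pairs
        def __len__(self):
            return 8

        def __getitem__(self, idx):
            return (torch.randn(3, 224, 224), torch.randn(3, 224, 224)), 0

    model = MoCo(encoder='vit_small', warmup_steps=1, max_steps=4)
    trainer = pl.Trainer(accelerator='cpu', devices=1, strategy='ddp', max_steps=4,
                         logger=False, enable_checkpointing=False)
    trainer.fit(model, DataLoader(_RandomTwoCropDataset(), batch_size=4))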