-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathdsdh.py
179 lines (149 loc) · 5.12 KB
/
dsdh.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import torch
import torch.optim as optim
import os
import time
from torch.optim.lr_scheduler import CosineAnnealingLR
from models.model_loader import load_model
from loguru import logger
from models.dsdh_loss import DSDHLoss
from utils.evaluate import mean_average_precision
def train(
        train_dataloader,
        query_dataloader,
        retrieval_dataloader,
        arch,
        code_length,
        device,
        lr,
        max_iter,
        mu,
        nu,
        eta,
        topk,
        evaluate_interval,
):
    """
    Train a DSDH model by alternating CNN, W and B optimisation steps.

    Args
        train_dataloader, query_dataloader, retrieval_dataloader(torch.utils.data.DataLoader): Data loader.
            Each loader is iterated as ``(data, targets, index)`` batches, and its dataset must
            provide ``get_onehot_targets()``.
        arch(str): CNN model name.
        code_length(int): Hash code length.
        device(torch.device): GPU or CPU.
        lr(float): Learning rate.
        max_iter(int): Maximum number of outer iterations (one pass over train_dataloader each).
        mu, nu, eta(float): Hyper-parameters.
        topk(int): Compute mAP using top k retrieval result.
        evaluate_interval(int): Evaluate every ``evaluate_interval`` iterations.

    Returns
        checkpoint(dict or None): Best-mAP checkpoint with keys 'qB', 'qL', 'rB', 'rL', 'model', 'map',
            or None if no evaluation ran (max_iter < evaluate_interval).
    """
    # Construct network, optimizer, loss
    model = load_model(arch, code_length).to(device)
    criterion = DSDHLoss(eta)
    optimizer = optim.RMSprop(
        model.parameters(),
        lr=lr,
        weight_decay=1e-5,
    )
    # Stepped once per outer iteration, so the cosine schedule spans max_iter iterations.
    scheduler = CosineAnnealingLR(optimizer, max_iter, 1e-7)

    # Initialize
    N = len(train_dataloader.dataset)
    B = torch.randn(code_length, N).sign().to(device)  # codes, one column per training sample
    U = torch.zeros(code_length, N).to(device)          # cached network outputs, one column per sample
    train_targets = train_dataloader.dataset.get_onehot_targets().to(device)
    # S[i, j] = 1 when samples i and j share at least one label (assuming 0/1 targets).
    S = (train_targets @ train_targets.t() > 0).float()
    Y = train_targets.t()
    best_map = 0.
    # Bound up front so the final return never raises UnboundLocalError.
    checkpoint = None
    iter_time = time.time()

    for it in range(max_iter):
        model.train()

        # CNN-step
        for data, _, index in train_dataloader:
            data = data.to(device)
            optimizer.zero_grad()

            U_batch = model(data).t()
            # Cache a gradient-free copy of the batch outputs for the W- and B-steps.
            U[:, index] = U_batch.detach()
            loss = criterion(U_batch, U, S[:, index], B[:, index])

            loss.backward()
            optimizer.step()
        scheduler.step()

        # W-step: closed-form ridge regression of the labels Y on the codes B.
        W = torch.inverse(B @ B.t() + nu / mu * torch.eye(code_length, device=device)) @ B @ Y.t()

        # B-step
        B = solve_dcc(W, Y, U, B, eta, mu)

        # Evaluate
        if it % evaluate_interval == evaluate_interval - 1:
            iter_time = time.time() - iter_time
            epoch_loss = calc_loss(U, S, Y, W, B, mu, nu, eta)

            # Generate hash code
            query_code = generate_code(model, query_dataloader, code_length, device)
            retrieval_code = generate_code(model, retrieval_dataloader, code_length, device)
            query_targets = query_dataloader.dataset.get_onehot_targets()
            retrieval_targets = retrieval_dataloader.dataset.get_onehot_targets()

            # Compute map
            mAP = mean_average_precision(
                query_code.to(device),
                retrieval_code.to(device),
                query_targets.to(device),
                retrieval_targets.to(device),
                device,
                topk,
            )
            logger.info('[iter:{}/{}][loss:{:.2f}][map:{:.4f}][time:{:.2f}]'.format(it+1, max_iter, epoch_loss, mAP, iter_time))

            # Save checkpoint; always keep the first evaluation so a result exists even if mAP == 0.
            if checkpoint is None or best_map < mAP:
                best_map = mAP
                checkpoint = {
                    'qB': query_code,
                    'qL': query_targets,
                    'rB': retrieval_code,
                    'rL': retrieval_targets,
                    'model': model.state_dict(),
                    'map': best_map,
                }
            iter_time = time.time()

    return checkpoint
def solve_dcc(W, Y, U, B, eta, mu):
    """
    Discrete cyclic coordinate descent (DCC) update of the codes B.

    Each row i of B (one bit across all samples) is re-solved while the other rows are held fixed:
        B[i, :] = sign(P[i, :] - B'^T (W' W[i, :])),   P = W @ Y + (eta / mu) * U,
    where B' and W' are B and W with row i removed. Rows are updated in order, so later rows
    see the already-updated earlier rows.

    Args (shapes as used by train())
        W(torch.Tensor, code_length*num_classes): Classifier weights.
        Y(torch.Tensor, num_classes*N): Label matrix.
        U(torch.Tensor, code_length*N): Network outputs.
        B(torch.Tensor, code_length*N): Current codes; MODIFIED IN PLACE.
        eta, mu(float): Hyper-parameters.

    Returns
        B(torch.Tensor): The same tensor object that was passed in, after the update.
            Entries come from torch.sign, so an exactly-zero argument yields 0.
    """
    # P depends only on W, Y and U, none of which change inside the loop, so compute it once.
    P = W @ Y + eta / mu * U
    for i in range(B.shape[0]):
        w = W[i, :]
        W_prime = torch.cat((W[:i, :], W[i + 1:, :]))
        B_prime = torch.cat((B[:i, :], B[i + 1:, :]))
        # Compute W_prime @ w first: this avoids materialising the N x num_classes
        # intermediate B_prime.t() @ W_prime.
        B[i, :] = (P[i, :] - B_prime.t() @ (W_prime @ w)).sign()
    return B
def calc_loss(U, S, Y, W, B, mu, nu, eta):
    """
    Evaluate the full DSDH objective for logging and return it as a Python float.

    The value is
        sum(log(1 + exp(theta)) - S * theta)    pairwise similarity term
        + mu  * ||Y - W^T B||^2                 classification term
        + nu  * ||W||^2                         regularisation term
        + eta * ||B - U||^2                     quantisation term
    divided by S.shape[0]. Here theta = U^T U / 2, clamped to [-100, 50], which keeps exp() finite.
    """
    inner = (U.t() @ U / 2).clamp(min=-100, max=50)
    pairwise = torch.sum(torch.log(1 + torch.exp(inner)) - S * inner)

    residual = Y - W.t() @ B
    classification = residual.pow(2).sum()
    regularization = W.pow(2).sum()
    quantization = (B - U).pow(2).sum()

    total = pairwise + mu * classification + nu * regularization + eta * quantization
    return (total / S.shape[0]).item()
def generate_code(model, dataloader, code_length, device):
    """
    Generate hash code

    Args
        model(torch.nn.Module): Hashing network; run in eval mode with gradients disabled.
        dataloader(torch.utils.data.DataLoader): Data loader yielding (data, _, index) batches.
        code_length(int): Hash code length.
        device(torch.device): Using gpu or cpu.

    Returns
        code(torch.Tensor, n*code_length): Hash code on the CPU. Row ``index`` holds the sign of the
            network output for that sample.

    The model's previous train/eval mode is restored on exit, even if an exception is raised.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            N = len(dataloader.dataset)
            code = torch.zeros([N, code_length])
            for data, _, index in dataloader:
                data = data.to(device)
                hash_code = model(data)
                code[index, :] = hash_code.sign().cpu()
    finally:
        # Restore the caller's mode rather than unconditionally switching to train mode.
        model.train(was_training)
    return code