-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathevaluate_lfw.py
348 lines (250 loc) · 15.4 KB
/
evaluate_lfw.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
import os
from itertools import product
import matplotlib.pyplot as plt
from typing import Any
import math
import hydra
import torch
import torchvision
from pytorch_lightning.lite import LightningLite
from PIL import Image
from omegaconf import OmegaConf, DictConfig
import numpy as np
from utils.helpers import ensure_path_join, normalize_to_neg_one_to_one
from pyeer.eer_info import get_eer_stats
from pyeer.report import generate_eer_report
import sys
from utils.iresnet import iresnet100
sys.path.insert(0, 'IDiff-Face/')
class EvaluatorLite(LightningLite):
    """Score the LFW verification pairs with pre-computed face embeddings.

    For every configured model / face-recognition-model (FRM) combination the
    evaluator loads "real" and "variation" embeddings from disk, computes a
    similarity score for every genuine and imposter LFW pair, and writes a
    histogram plus plain-text score files into the evaluation directory.
    """

    # Layout of the LFW pairs file as consumed here: NUM_FOLDS consecutive
    # folds, each holding PAIRS_PER_KIND genuine rows followed by
    # PAIRS_PER_KIND imposter rows.
    NUM_FOLDS = 10
    PAIRS_PER_KIND = 300

    def run(self, cfg) -> Any:
        """Run every enabled comparison for all configured models and FRMs.

        Reads from ``cfg.evaluation``: ``seed``, ``lfw_pairs_path``,
        ``model_names``, ``frm_names``, ``aligned``,
        ``variation_contexts_name``, ``real_contexts`` and the five
        ``*_comparison`` switches.

        Raises:
            ValueError: if no real-context entry is configured for an FRM name.
        """
        self.seed_everything(cfg.evaluation.seed)

        # load pre-defined lfw comparison pairs and prepare them
        genuine_pairs, imposter_pairs = self._load_lfw_pairs(cfg.evaluation.lfw_pairs_path)

        # sanity checks
        for first_name, second_name in genuine_pairs:
            assert first_name.split("_")[0] == second_name.split("_")[0], \
                f"genuine pair mixes identities: {(first_name, second_name)}"

        for model_name in cfg.evaluation.model_names:
            for frm_name in cfg.evaluation.frm_names:
                if cfg.evaluation.aligned:
                    eval_dir = ensure_path_join("evaluation", "lfw_aligned", model_name, frm_name)
                    variation_preencoded_data_dir = os.path.join(
                        "samples", "aligned", "embeddings", model_name,
                        cfg.evaluation.variation_contexts_name, frm_name)
                else:
                    eval_dir = ensure_path_join("evaluation", "lfw", model_name, frm_name)
                    variation_preencoded_data_dir = os.path.join(
                        "samples", "embeddings", model_name,
                        cfg.evaluation.variation_contexts_name, frm_name)

                # NOTE(review): torch.load unpickles its input; only point this
                # at embedding files from a trusted source.
                variation_embeddings = torch.load(
                    os.path.join(variation_preencoded_data_dir, "embeddings.npy"))
                variation_labels = torch.load(
                    os.path.join(variation_preencoded_data_dir, "labels.npy"))

                # load real data embeddings
                real_contexts = cfg.evaluation.real_contexts.get(frm_name)
                print("Real Contexts:", real_contexts)
                if real_contexts is None:
                    raise ValueError(
                        f"cfg.evaluation.real_contexts has no entry for FRM '{frm_name}'")
                real_embeddings_dict = torch.load(
                    real_contexts.real_contexts_aligned_path
                    if cfg.evaluation.aligned
                    else real_contexts.real_contexts_path)

                # Image name -> embedding for the variation samples.
                variation_embeddings_dict = dict(zip(variation_labels, variation_embeddings))

                if cfg.evaluation.real_vs_real_comparison:
                    print("Starting REAL vs. REAL comparison ...")
                    # Low-scoring genuine pairs are printed for manual inspection.
                    genuine_scores = self._score_pairs(
                        genuine_pairs, real_embeddings_dict, real_embeddings_dict,
                        report_below=0.2)
                    imposter_scores = self._score_pairs(
                        imposter_pairs, real_embeddings_dict, real_embeddings_dict)
                    self._save_results(eval_dir, "real_vs_real", genuine_scores, imposter_scores)

                if cfg.evaluation.variation_vs_variation_comparison:
                    print("Starting VARIATION vs. VARIATION comparison ...")
                    genuine_scores = self._score_pairs(
                        genuine_pairs, variation_embeddings_dict, variation_embeddings_dict,
                        skip_missing_first=True, skip_missing_second=True)
                    imposter_scores = self._score_pairs(
                        imposter_pairs, variation_embeddings_dict, variation_embeddings_dict,
                        skip_missing_first=True, skip_missing_second=True)
                    self._save_results(
                        eval_dir, "variation_vs_variation", genuine_scores, imposter_scores)

                if cfg.evaluation.real_vs_variation_comparison:
                    print("Starting REAL vs. VARIATION comparisons ...")
                    genuine_scores = self._score_pairs(
                        genuine_pairs, real_embeddings_dict, variation_embeddings_dict,
                        skip_missing_second=True)
                    imposter_scores = self._score_pairs(
                        imposter_pairs, real_embeddings_dict, variation_embeddings_dict,
                        skip_missing_second=True)
                    self._save_results(
                        eval_dir, "real_vs_variation", genuine_scores, imposter_scores)
                    real_vs_variation_stats = get_eer_stats(genuine_scores, imposter_scores)
                    # NOTE(review): the report path is relative to the working
                    # directory, so every model/FRM combination overwrites the
                    # previous report; consider writing it under eval_dir.
                    generate_eer_report(
                        [real_vs_variation_stats], ["real_vs_variation"], "pyeer_report.html")

                if cfg.evaluation.variation_vs_real_comparison:
                    print("Starting VARIATION vs. REAL comparisons ...")
                    genuine_scores = self._score_pairs(
                        genuine_pairs, variation_embeddings_dict, real_embeddings_dict,
                        skip_missing_first=True)
                    imposter_scores = self._score_pairs(
                        imposter_pairs, variation_embeddings_dict, real_embeddings_dict,
                        skip_missing_first=True)
                    self._save_results(
                        eval_dir, "variation_vs_real", genuine_scores, imposter_scores)

                if cfg.evaluation.real_vs_reference_variation_comparison:
                    print("Starting REAL VS. REFERENCE VARIATION comparisons ...")
                    # Genuine score: a real image against the variation stored
                    # under the *same* image name (second pair member ignored).
                    genuine_scores = self._score_pairs(
                        genuine_pairs, real_embeddings_dict, variation_embeddings_dict,
                        skip_missing_second=True, second_from_first=True)
                    imposter_scores = self._score_pairs(
                        imposter_pairs, real_embeddings_dict, variation_embeddings_dict,
                        skip_missing_second=True)
                    self._save_results(
                        eval_dir, "real_vs_reference_variation", genuine_scores, imposter_scores)

    @classmethod
    def _load_lfw_pairs(cls, pairs_path):
        """Parse the LFW pairs file into genuine and imposter image-name pairs.

        The first line of the file is skipped. Remaining rows are tab-separated
        and grouped into ``NUM_FOLDS`` folds; within a fold the first
        ``PAIRS_PER_KIND`` rows are ``name, idx1, idx2`` (genuine) and the rest
        are ``name1, idx1, name2, idx2`` (imposter).

        Returns:
            ``(genuine_pairs, imposter_pairs)``: two lists of 2-tuples of image
            names formatted as ``<name>_<4-digit index>``.
        """
        with open(pairs_path) as f:
            rows = [line.rstrip('\n').split('\t') for line in f][1:]

        fold_size = 2 * cls.PAIRS_PER_KIND
        genuine_pairs, imposter_pairs = [], []
        for fold_idx in range(cls.NUM_FOLDS):
            fold = rows[fold_idx * fold_size:(fold_idx + 1) * fold_size]
            for id_name, i1, i2 in fold[:cls.PAIRS_PER_KIND]:
                genuine_pairs.append(
                    (cls._lfw_image_name(id_name, i1), cls._lfw_image_name(id_name, i2)))
            for id_name1, i1, id_name2, i2 in fold[cls.PAIRS_PER_KIND:]:
                imposter_pairs.append(
                    (cls._lfw_image_name(id_name1, i1), cls._lfw_image_name(id_name2, i2)))
        return genuine_pairs, imposter_pairs

    @staticmethod
    def _lfw_image_name(id_name, image_index):
        """Return ``<id_name>_<image_index zero-padded to 4 digits>``."""
        return f"{id_name}_{image_index.zfill(4)}"

    @staticmethod
    def _score_pairs(pairs, first_embeddings, second_embeddings, *,
                     skip_missing_first=False, skip_missing_second=False,
                     second_from_first=False, report_below=None):
        """Compute one dot-product score per pair of image names.

        Args:
            pairs: iterable of ``(name1, name2)`` tuples.
            first_embeddings: mapping used to look up ``name1``.
            second_embeddings: mapping used to look up ``name2``.
            skip_missing_first: skip (and print) pairs whose ``name1`` is not in
                ``first_embeddings`` instead of raising ``KeyError``.
            skip_missing_second: same for ``name2`` / ``second_embeddings``.
            second_from_first: look up ``name1`` in both mappings, ignoring
                ``name2``.
            report_below: if given, print every pair scoring below this value.

        Returns:
            List of scores in pair order, omitting skipped pairs.
        """
        scores = []
        for pair in pairs:
            name1, name2 = pair
            if second_from_first:
                name2 = name1
            if ((skip_missing_first and name1 not in first_embeddings)
                    or (skip_missing_second and name2 not in second_embeddings)):
                print(f"Skipping {pair}")
                continue
            # presumably embeddings are L2-normalised so this is a cosine
            # similarity — TODO confirm against the encoding step
            cos_sim = np.dot(first_embeddings[name1], second_embeddings[name2])
            if report_below is not None and cos_sim < report_below:
                print(pair, cos_sim)
            scores.append(cos_sim)
        return scores

    @staticmethod
    def _save_results(eval_dir, prefix, genuine_scores, imposter_scores):
        """Write the score histogram and the two score text files for one comparison.

        Produces ``<prefix>_distributions.png``, ``<prefix>_genuine_scores.txt``
        and ``<prefix>_imposter_scores.txt`` inside ``eval_dir``.
        """
        # 21 edges -> 20 bins of width 0.1 covering the full [-1, 1] range.
        # (np.arange(-1, 1, 0.1) stops at 0.9 and drops scores above it.)
        bins = np.linspace(-1, 1, 21)
        plt.clf()
        plt.hist(genuine_scores, bins=bins, label="genuine", color="green", alpha=0.5)
        plt.hist(imposter_scores, bins=bins, label="imposter", color="red", alpha=0.5)
        plt.xlim(-1, 1)
        plt.legend()
        plt.savefig(ensure_path_join(eval_dir, f"{prefix}_distributions.png"), dpi=512)

        for kind, scores in (("genuine", genuine_scores), ("imposter", imposter_scores)):
            with open(os.path.join(eval_dir, f"{prefix}_{kind}_scores.txt"), "w") as f:
                f.writelines(f"{score}\n" for score in scores)

    @staticmethod
    def get_indices_for_lfw_pair(lfw_pair, labels_1, labels_2):
        """Pick one random index per pair member among the matching labels.

        Args:
            lfw_pair: ``(img_id1, img_id2)``.
            labels_1: array compared element-wise against ``img_id1``.
            labels_2: array compared element-wise against ``img_id2``.

        Returns:
            ``(i, j)``: an index into ``labels_1`` and one into ``labels_2``,
            each drawn with ``np.random.choice`` from the matching positions.
        """
        img_id1, img_id2 = lfw_pair
        i = np.random.choice(np.where(labels_1 == img_id1)[0])
        j = np.random.choice(np.where(labels_2 == img_id2)[0])
        return i, j
@hydra.main(version_base=None, config_path='configs', config_name='evaluate_lfw_config')
def evaluate(cfg: DictConfig):
    """Print the received configuration as YAML, then run the LFW evaluation."""
    cfg_yaml = OmegaConf.to_yaml(cfg)
    print(cfg_yaml)
    EvaluatorLite(devices="auto", accelerator="auto").run(cfg)
# Script entry point: `evaluate` is wrapped by `hydra.main`, so it is called
# without arguments and receives its `cfg` from the decorator.
if __name__ == "__main__":
    evaluate()