-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathevaluation.py
218 lines (189 loc) · 7.76 KB
/
evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
import numpy as np
import os
import torch
import torchvision
from PIL import Image
from kornia import augmentation
from torchvision import transforms
import metrics.fid
import utils
from models import inception
from models.classifiers import *
def calc_fid(recovery_img_path, private_img_path, batch_size=64):
    """
    Calculate the FID between reconstructed images and the private data.

    :param recovery_img_path: dir of reconstructed images; each sub-directory
        (presumably one per target identity, e.g. "0", "1", ...) holds image files.
        Entries that are not directories are skipped.
    :param private_img_path: root dir of private data in torchvision ImageFolder layout
    :param batch_size: batch size used to load the private data and to run Inception
    :return: FID of reconstructed images, or -1000 when there are no reconstructed
        images to evaluate
    """
    # Gather reconstructed images first so we can bail out before building the
    # (expensive) Inception model when there is nothing to evaluate.
    recovery_list = []
    for idx in os.listdir(recovery_img_path):
        idx_dir = os.path.join(recovery_img_path, idx)
        if not os.path.isdir(idx_dir):
            continue
        for recovery_img in os.listdir(idx_dir):
            with Image.open(os.path.join(idx_dir, recovery_img)) as image:
                # ImageFolder's default loader converts to RGB; do the same so both
                # image sets have matching channel counts.
                image = image.convert('RGB')
            recovery_list.append(transforms.functional.to_tensor(image).unsqueeze(0).numpy())
    if not recovery_list:
        return -1000

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    inception_model = inception.InceptionV3().to(device)

    # Real images from the private data.
    private_loader = torch.utils.data.DataLoader(
        torchvision.datasets.ImageFolder(
            private_img_path,
            torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),
            ]),
        ),
        batch_size,
    )
    private_list = [imgs.numpy() for imgs, _ in private_loader]

    recovery_images = np.concatenate(recovery_list)
    private_images = np.concatenate(private_list)
    mu_fake, sigma_fake = metrics.fid.calculate_activation_statistics(
        recovery_images, inception_model, batch_size, device=device
    )
    mu_real, sigma_real = metrics.fid.calculate_activation_statistics(
        private_images, inception_model, batch_size, device=device
    )
    return metrics.fid.calculate_frechet_distance(
        mu_fake, sigma_fake, mu_real, sigma_real
    )
def get_private_feats(E, private_feats_path, private_loader):
    """
    Compute features of the private data with the evaluation model and save them.

    Images are resized to 112x112 before being passed to ``E``; the first element
    of ``E``'s output is used as the feature tensor. Results are written to
    ``private_feats.npy`` and ``private_targets.npy`` inside ``private_feats_path``.

    :param E: evaluation model (assumed to already live on the selected device)
    :param private_feats_path: output directory, created if missing
    :param private_loader: dataloader yielding ``(images, targets)`` batches
    :return: None
    :raises ValueError: if ``private_loader`` yields no batches
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    os.makedirs(private_feats_path, exist_ok=True)
    resize = augmentation.Resize((112, 112))

    # Collect per-batch results and concatenate once (avoids quadratic re-copying).
    feats_list, targets_list = [], []
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for images, targets in private_loader:
            images = resize(images.to(device))
            feats = E(images)[0]
            feats_list.append(feats.detach().cpu())
            targets_list.append(targets.reshape(-1).cpu())
    if not feats_list:
        raise ValueError("private_loader yielded no batches")

    private_feats = torch.cat(feats_list, dim=0)
    private_targets = torch.cat(targets_list, dim=0)
    print("private_feats: ", private_feats.shape)
    print("private_targets: ", private_targets.shape)
    np.save(os.path.join(private_feats_path, 'private_feats.npy'), private_feats.numpy())
    np.save(os.path.join(private_feats_path, 'private_targets.npy'), private_targets.numpy())
    print("Done!")
def calc_knn(feat, iden, path):
    """
    Get the KNN distance from reconstructed images to the private data.

    For each reconstructed sample, take the minimum squared L2 distance between its
    feature and the private features of the same class. Return the mean over all
    samples. A sample whose class has no private features contributes 1e8.

    :param feat: 2-D features of reconstructed images output by the evaluation model
    :param iden: target class per sample (any shape with ``feat.size(0)`` elements)
    :param path: directory holding ``private_feats.npy`` and ``private_targets.npy``
    :return: KNN distance as a Python float
    :raises ValueError: if ``feat`` contains no samples
    """
    iden = iden.cpu().long().view(-1)
    feat = feat.cpu()
    true_feat = torch.from_numpy(np.load(os.path.join(path, "private_feats.npy"))).float()
    info = torch.from_numpy(np.load(os.path.join(path, "private_targets.npy"))).view(-1).long()

    bs = feat.size(0)
    if bs == 0:
        raise ValueError("calc_knn received no reconstructed features")

    knn_dist = 0.0
    for i in range(bs):
        # Private features belonging to the same class as this sample.
        same_class = true_feat[info == iden[i]]
        if same_class.size(0) == 0:
            knn_dist += 1e8  # sentinel: no private sample of this class
            continue
        # Squared L2 distance from this sample to every same-class private feature.
        dists = ((same_class - feat[i]) ** 2).sum(dim=1)
        knn_dist += dists.min().item()
    return knn_dist / bs
def get_knn_dist(E, infered_image_path, private_feats_path, batch_size=10):
    """
    Get the KNN distance of reconstructed images to the private data.

    Each file under ``infered_image_path/<sub-dir>/`` is loaded. The last two
    '_'-separated fields of its file stem are read as ``<target>_<seed>``, and the
    target label is used as the image's class.

    :param E: evaluation model; the first element of its output is used as features
    :param infered_image_path: directory of reconstructed images (one level of sub-dirs)
    :param private_feats_path: directory holding the saved private features/targets
    :param batch_size: number of images per forward pass through ``E``
    :return: KNN distance as computed by ``calc_knn``
    :raises ValueError: if no reconstructed images are found
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    images_list = []
    targets_list = []
    # Load reconstructed images and their target labels.
    for idx in os.listdir(infered_image_path):
        for filename in os.listdir(os.path.join(infered_image_path, idx)):
            target, _seed = os.path.splitext(filename)[0].strip().split('_')[-2:]
            with Image.open(os.path.join(infered_image_path, idx, filename)) as image:
                images_list.append(transforms.functional.to_tensor(image))
            targets_list.append(int(target))
    if not images_list:
        raise ValueError("no reconstructed images found in {}".format(infered_image_path))

    all_images = torch.stack(images_list, dim=0)
    targets = torch.LongTensor(targets_list)

    # Extract features in fixed-size batches. The previous chunk(int(n / 10))
    # failed outright for fewer than 10 images.
    resize = augmentation.Resize((112, 112))
    feats_list = []
    with torch.no_grad():
        for batch in all_images.split(batch_size):
            batch = resize(batch).to(device)
            feats_list.append(E(batch)[0].detach().cpu())
    infered_feats = torch.cat(feats_list, dim=0)

    return calc_knn(infered_feats, targets, private_feats_path)
def evaluate(args, current_iter, gen, device, inception_model=None, eval_iter=None):
    """Evaluate the generator in the training process.

    Generates ``args.n_eval_batches`` batches with ``gen``, saves each batch as a
    PNG grid under ``args.eval_image_root`` (one ``class_id_XXXX`` sub-directory
    per class), and optionally computes FID on the first ``args.n_fid_batches``
    batches.

    :param args: config; reads n_eval_batches, num_classes, batch_size, gen_dim_z,
        gen_distribution, n_fid_batches and eval_image_root
    :param current_iter: training iteration, embedded in the saved file names
    :param gen: generator; put in eval mode for this call and always switched back
        to train mode afterwards, even if an exception is raised
    :param device: device passed to image generation and FID statistics
    :param inception_model: Inception network for FID; FID is skipped if None
    :param eval_iter: iterator of real batches (images at index 0); FID is skipped if None
    :return: FID score, or -1000 when FID is not computed
    """
    do_fid = (inception_model is not None) and (eval_iter is not None)
    conditional = True  # args.cGAN
    fake_list, real_list = [], []
    gen.eval()
    try:
        for i in range(args.n_eval_batches):
            class_id = i % args.num_classes if conditional else None
            fake = utils.generate_images(
                gen, device, args.batch_size, args.gen_dim_z,
                args.gen_distribution, class_id=class_id
            )
            # Collect exactly n_fid_batches batches (previously `<=` took one extra).
            # (x + 1) / 2 presumably maps [-1, 1] outputs to [0, 1].
            if do_fid and i < args.n_fid_batches:
                fake_list.append((fake.cpu().numpy() + 1.0) / 2.0)
                real_list.append((next(eval_iter)[0].cpu().numpy() + 1.0) / 2.0)

            # Save generated images, grouped by the class they were generated for.
            root = args.eval_image_root
            if conditional:
                root = os.path.join(root, "class_id_{:04d}".format(class_id))
            os.makedirs(root, exist_ok=True)
            fn = "image_iter_{:07d}_batch_{:04d}.png".format(current_iter, i)
            torchvision.utils.save_image(
                fake, os.path.join(root, fn), nrow=4, normalize=True, scale_each=True
            )

        # Calculate FID scores (skipped if no batches were collected).
        if do_fid and fake_list:
            fake_images = np.concatenate(fake_list)
            real_images = np.concatenate(real_list)
            mu_fake, sigma_fake = metrics.fid.calculate_activation_statistics(
                fake_images, inception_model, args.batch_size, device=device
            )
            mu_real, sigma_real = metrics.fid.calculate_activation_statistics(
                real_images, inception_model, args.batch_size, device=device
            )
            fid_score = metrics.fid.calculate_frechet_distance(
                mu_fake, sigma_fake, mu_real, sigma_real
            )
        else:
            fid_score = -1000
    finally:
        gen.train()
    return fid_score