-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy patheval_func.py
131 lines (93 loc) · 4.42 KB
/
eval_func.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import torch
def map_label(label, classes):
    """Remap raw class ids to contiguous indices given by their position in ``classes``.

    Every element of ``label`` equal to ``classes[i]`` becomes ``i``; elements
    that match no entry of ``classes`` become -1.

    Args:
        label: integer tensor of raw class ids.
        classes: 1-D tensor of class ids defining the new index order.

    Returns:
        A ``torch.long`` tensor with the same shape as ``label``, allocated on
        ``label``'s device.
    """
    # Allocate on label's device: indexing a CPU tensor with a boolean mask
    # that lives on another device (e.g. CUDA labels) raises an error.
    mapped_label = torch.full(label.size(), -1, dtype=torch.long, device=label.device)
    for i, cls in enumerate(classes):
        mapped_label[label == cls] = i
    return mapped_label
def compute_per_class_acc(test_label, predicted_label, nclass):
    """Return the class-balanced accuracy over classes ``0 .. nclass-1``.

    For each class ``i``, accuracy is the fraction of samples with
    ``test_label == i`` whose prediction equals the label; the result is the
    mean of those per-class accuracies. Labels outside ``[0, nclass)`` (e.g.
    the -1 that ``map_label`` emits for unknown ids) are not attributed to any
    class. Classes with no samples are skipped rather than contributing 0/0.

    Args:
        test_label: integer tensor of ground-truth class indices.
        predicted_label: tensor of predicted class indices, same shape as
            ``test_label``.
        nclass: number of classes to average over.

    Returns:
        Mean per-class accuracy as a Python float (NaN if no class in
        ``0 .. nclass-1`` has any sample).
    """
    acc_per_class = torch.zeros(nclass, dtype=torch.float)
    has_samples = torch.zeros(nclass, dtype=torch.bool)
    for i in range(nclass):
        idx = test_label == i
        n_samples = idx.sum()
        if n_samples == 0:
            # An absent class would yield 0/0 = NaN and poison the mean.
            continue
        n_correct = (test_label[idx] == predicted_label[idx]).sum()
        acc_per_class[i] = n_correct.float() / n_samples.float()
        has_samples[i] = True
    return acc_per_class[has_samples].mean().item()
def compute_per_class_acc_gzsl(test_label, predicted_label, target_classes, in_package):
    """Return the mean per-class accuracy over the raw class ids in ``target_classes``.

    Unlike ``compute_per_class_acc``, labels and predictions are compared in
    the original (unmapped) id space, so this is suitable when predictions are
    an argmax over all classes (the GZSL setting). Target classes that have no
    samples in ``test_label`` are skipped rather than contributing 0/0.

    Args:
        test_label: integer tensor of ground-truth class ids.
        predicted_label: tensor of predicted class ids, same shape as
            ``test_label``.
        target_classes: 1-D tensor of class ids to average over.
        in_package: dict providing at least ``'device'``.

    Returns:
        Mean per-class accuracy as a Python float (NaN if none of
        ``target_classes`` occurs in ``test_label``).
    """
    device = in_package['device']
    # Put both tensors on the same device so the element-wise comparison below
    # never mixes devices.
    test_label = test_label.to(device)
    predicted_label = predicted_label.to(device)
    n_classes = target_classes.size(0)
    per_class_accuracies = torch.zeros(n_classes, dtype=torch.float, device=device)
    has_samples = torch.zeros(n_classes, dtype=torch.bool, device=device)
    for i, cls in enumerate(target_classes):
        is_class = test_label == cls
        n_samples = is_class.sum()
        if n_samples == 0:
            # An absent class would yield 0/0 = NaN and poison the mean.
            continue
        n_correct = (predicted_label[is_class] == test_label[is_class]).sum()
        per_class_accuracies[i] = n_correct.float() / n_samples.float()
        has_samples[i] = True
    return per_class_accuracies[has_samples].mean().item()
def val_gzsl(test_seen_loader, target_classes, in_package, bias=0):
    """Compute GZSL per-class accuracy of the model over ``target_classes``.

    For each batch the model's ``'embed'`` scores are taken, ``bias`` is added
    to the columns listed in ``target_classes``, and the argmax over all
    columns is the prediction. Accuracy is then averaged per class over
    ``target_classes`` via ``compute_per_class_acc_gzsl``.

    Args:
        test_seen_loader: iterable of ``(imgs, labels)`` batches.
        target_classes: 1-D tensor of class ids used both as the biased score
            columns and as the classes to average accuracy over.
        in_package: dict providing at least ``'model'`` and ``'device'``.
        bias: constant added to the ``target_classes`` score columns.

    Returns:
        Mean per-class accuracy as a Python float.
    """
    model = in_package['model']
    device = in_package['device']
    test_label = []
    predicted_label = []
    with torch.no_grad():
        for imgs, labels in test_seen_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            output = model(imgs)['embed']
            # Shift the target-class scores before the argmax over all classes.
            output[:, target_classes] += bias
            predicted_label.append(torch.argmax(output, 1))
            test_label.append(labels)
    test_label = torch.cat(test_label, dim=0)
    predicted_label = torch.cat(predicted_label, dim=0)
    return compute_per_class_acc_gzsl(
        test_label, predicted_label, target_classes, in_package)
def val_zs_gzsl(test_unseen_loader, unseen_classes, in_package, bias=0):
    """Evaluate on unseen-class test data in both the GZSL and ZSL settings.

    For each batch the model's ``'embed'`` scores are used twice:

    * ZSL: argmax over the ``unseen_classes`` columns only (before ``bias`` is
      applied); the prediction is a position in ``unseen_classes``.
    * GZSL: ``bias`` is added to the ``unseen_classes`` columns and the argmax
      is taken over all columns; the prediction is a raw class id.

    Args:
        test_unseen_loader: iterable of ``(imgs, labels)`` batches.
        unseen_classes: 1-D tensor of unseen class ids (score column indices).
        in_package: dict providing at least ``'model'`` and ``'device'``.
        bias: constant added to the unseen-class scores for the GZSL argmax.

    Returns:
        ``(acc_gzsl, acc_zs)``: mean per-class accuracies (Python floats) in the
        GZSL and ZSL settings respectively.
    """
    model = in_package['model']
    device = in_package['device']
    test_label = []
    predicted_label_gzsl = []
    predicted_label_zsl = []
    with torch.no_grad():
        for imgs, labels in test_unseen_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            output = model(imgs)['embed']
            # ZSL: restrict to unseen columns; advanced indexing copies, so the
            # in-place bias below does not affect this prediction.
            predicted_label_zsl.append(torch.argmax(output[:, unseen_classes], 1))
            # GZSL: all classes compete, with unseen scores shifted by `bias`.
            output[:, unseen_classes] += bias
            predicted_label_gzsl.append(torch.argmax(output, 1))
            test_label.append(labels)
    test_label = torch.cat(test_label, dim=0)
    predicted_label_gzsl = torch.cat(predicted_label_gzsl, dim=0)
    predicted_label_zsl = torch.cat(predicted_label_zsl, dim=0)
    acc_gzsl = compute_per_class_acc_gzsl(
        test_label, predicted_label_gzsl, unseen_classes, in_package)
    # ZSL predictions are positions in unseen_classes, so map the raw labels
    # into the same index space before scoring.
    acc_zs = compute_per_class_acc(
        map_label(test_label, unseen_classes).to(device),
        predicted_label_zsl, unseen_classes.size(0))
    return acc_gzsl, acc_zs
def eval_zs_gzsl(config, dataloader, model, bias_seen=0, bias_unseen=0):
    """Run seen and unseen evaluation and combine them with the harmonic mean.

    Puts ``model`` into eval mode (it is not switched back afterwards).

    Args:
        config: object exposing ``batch_size`` and ``device``.
        dataloader: object exposing ``test_seen_loader``, ``test_unseen_loader``,
            ``seenclasses`` and ``unseenclasses``.
        model: the model to evaluate; passed through to ``val_gzsl`` and
            ``val_zs_gzsl``.
        bias_seen: score offset passed as ``bias`` when evaluating seen data.
        bias_unseen: score offset passed as ``bias`` when evaluating unseen data.

    Returns:
        ``(acc_seen, acc_novel, H, acc_zs)`` where ``H`` is the harmonic mean of
        ``acc_seen`` and ``acc_novel`` (0 when their sum is not positive) and
        ``acc_zs`` is the second value returned by ``val_zs_gzsl``.
    """
    model.eval()
    seen_loader, unseen_loader = dataloader.test_seen_loader, dataloader.test_unseen_loader
    seen_ids, unseen_ids = dataloader.seenclasses, dataloader.unseenclasses
    batch_size = config.batch_size
    package = dict(model=model, device=config.device, batch_size=batch_size)
    with torch.no_grad():
        acc_seen = val_gzsl(seen_loader, seen_ids, package, bias=bias_seen)
        acc_novel, acc_zs = val_zs_gzsl(
            unseen_loader, unseen_ids, package, bias=bias_unseen)
    total = acc_seen + acc_novel
    H = (2 * acc_seen * acc_novel) / total if total > 0 else 0
    return acc_seen, acc_novel, H, acc_zs