evaluation.py
import itertools
from bert_score import score
from nltk.translate.bleu_score import sentence_bleu
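
# The functions below score a predicted set of up to five keyphrases against a
# ground-truth set. Because the metrics compare items position by position, the
# candidate set is first permuted (best_bleu_cand) into the ordering that
# maximises BLEU against the ground truth.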


def process_inputs(groundtruth, cand):
    # For evaluation, both sets are padded/truncated to exactly 5 elements.
    # Work on copies so the callers' lists are not mutated in place.
    groundtruth, cand = list(groundtruth), list(cand)
    while len(groundtruth) < 5:
        groundtruth.append('')
    while len(cand) < 5:
        cand.append('')
    return groundtruth[:5], cand[:5]


def best_bleu_cand(groundtruth, cand):
    # Try every ordering of the candidate set and keep the permutation whose
    # average sentence-level BLEU against the ground truth is highest.
    groundtruth, cand = process_inputs(groundtruth, cand)
    all_permutations = list(itertools.permutations(cand))
    max_bleu = 0.
    best_cand = all_permutations[0]
    for perm in all_permutations:
        bleu = 0.
        for i in range(min(len(groundtruth), len(perm))):
            bleu += sentence_bleu([groundtruth[i]], perm[i]) / len(groundtruth)
        if bleu > max_bleu:
            max_bleu = bleu
            best_cand = perm
    return list(best_cand)
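
# Note (illustrative, not from the original file): with
# groundtruth = ["red shoes", "blue shoes"] and cand = ["blue shoes", "red shoes"],
# the permutation search above pairs each candidate with its identical
# ground-truth phrase, so the position-wise comparisons below are not
# penalised for the ordering of the set.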


def eval_bleu(groundtruth, cand):
    # Calculates the Set BLEU metrics for 1-gram, 2-gram, 3-gram and 4-gram overlaps.
    groundtruth, cand = process_inputs(groundtruth, cand)
    best_cand = best_bleu_cand(groundtruth, cand)
    bleu = [0., 0., 0., 0.]
    bleu_weights = [[1, 0, 0, 0], [0.5, 0.5, 0, 0], [0.33, 0.33, 0.33, 0], [0.25, 0.25, 0.25, 0.25]]
    for j in range(4):
        for i in range(min(len(groundtruth), len(best_cand))):
            bleu[j] += sentence_bleu([groundtruth[i]], best_cand[i], weights=bleu_weights[j]) / len(groundtruth)
    return bleu


def bertscore(groundtruth, cand):
    # Calculates the Set BERT-Score metrics for Precision, Recall & F1.
    groundtruth, cand = process_inputs(groundtruth, cand)
    best_cand = best_bleu_cand(groundtruth, cand)
    (P, R, F), hashname = score(best_cand, groundtruth, lang="en", return_hash=True, device="cuda:0")
    return P.mean().item(), R.mean().item(), F.mean().item()


def exact_match(groundtruth, cand):
    # Calculates the exact match Precision, Recall & F1.
    groundtruth, cand = process_inputs(groundtruth, cand)
    c = 0.
    for x in cand:
        if x != '' and x in groundtruth:
            c += 1
    p = c / (len([x for x in cand if x != '']) + 1e-8)
    r = c / (len([x for x in groundtruth if x != '']) + 1e-8)
    f1 = 2 * p * r / (p + r) if p + r > 0 else 0.
    return [p, r, f1]


def term_match(groundtruth, cand):
    # Calculates the term (whitespace-token) overlap Precision, Recall & F1.
    groundtruth, cand = process_inputs(groundtruth, cand)
    gt_terms = set()
    for x in groundtruth:
        if x == '':
            continue
        for t in x.strip().split():
            gt_terms.add(t)
    cand_terms = set()
    for x in cand:
        if x == '':
            continue
        for t in x.strip().split():
            cand_terms.add(t)
    c = 0.
    for x in cand_terms:
        if x != '' and x in gt_terms:
            c += 1
    p = c / (len([x for x in cand_terms if x != '']) + 1e-8)
    r = c / (len([x for x in gt_terms if x != '']) + 1e-8)
    f1 = 2 * p * r / (p + r) if p + r > 0 else 0.
    return [p, r, f1]


if __name__ == "__main__":
    groundtruth = ["for sale", "used cars", "electric", "cheap"]
    cand = ["afforable cars", "cars for sale", "used", "electric"]
    term_overlap_metrics = term_match(groundtruth, cand)
    print("Term overlap metrics: P={},R={},F1={}".format(term_overlap_metrics[0],
                                                         term_overlap_metrics[1],
                                                         term_overlap_metrics[2]))
    exact_match_metrics = exact_match(groundtruth, cand)
    print("Exact match metrics: P={},R={},F1={}".format(exact_match_metrics[0],
                                                        exact_match_metrics[1],
                                                        exact_match_metrics[2]))
    bert_score_metrics = bertscore(groundtruth, cand)
    print("BERT score metrics: P={},R={},F1={}".format(bert_score_metrics[0],
                                                       bert_score_metrics[1],
                                                       bert_score_metrics[2]))
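    # Added illustration (not part of the original demo): Set BLEU for 1- to
    # 4-gram overlaps on the same example. Short phrases may trigger NLTK's
    # BLEU smoothing warnings; that is expected here.
    set_bleu_metrics = eval_bleu(groundtruth, cand)
    print("Set BLEU metrics: 1-gram={},2-gram={},3-gram={},4-gram={}".format(set_bleu_metrics[0],
                                                                             set_bleu_metrics[1],
                                                                             set_bleu_metrics[2],
                                                                             set_bleu_metrics[3]))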