-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
25 lines (19 loc) · 852 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from seqeval.metrics import recall_score, precision_score
import numpy as np
def compute_metrics(p, all_labels, beta=5):
    """Compute sequence-labelling recall, precision and an F-beta score.

    Args:
        p: A pair ``(predictions, labels)``. ``predictions`` holds per-class
            scores and is reduced with ``argmax`` over axis 2, so it is
            presumably shaped (batch, sequence, num_labels) -- confirm
            against the caller. ``labels`` holds the gold label ids, with
            ``-100`` marking positions to ignore.
        all_labels: Lookup from integer label id to the tag string that is
            handed to seqeval (indexed as ``all_labels[label_id]``).
        beta: Weight of recall relative to precision in the F-beta score.
            The default of 5 reproduces the previously hard-coded behavior.

    Returns:
        A dict with keys ``"recall"``, ``"precision"`` and ``"f1"``.
        NOTE: despite its name, ``"f1"`` holds the F-beta score (beta=5 by
        default, i.e. recall-weighted); the key is kept for compatibility
        with existing callers.
    """
    scores, label_ids = p
    predicted_ids = np.argmax(scores, axis=2)

    # Drop ignored positions (label id -100, e.g. special tokens) and map
    # the remaining ids to tag strings, keeping predictions and gold
    # labels aligned position by position.
    true_predictions = []
    true_labels = []
    for predicted_row, label_row in zip(predicted_ids, label_ids):
        kept = [
            (predicted, gold)
            for predicted, gold in zip(predicted_row, label_row)
            if gold != -100
        ]
        true_predictions.append([all_labels[predicted] for predicted, _ in kept])
        true_labels.append([all_labels[gold] for _, gold in kept])

    recall = recall_score(true_labels, true_predictions)
    precision = precision_score(true_labels, true_predictions)

    beta_squared = beta * beta
    denominator = beta_squared * precision + recall
    if denominator == 0:
        # Precision and recall are both zero: define the score as 0
        # instead of dividing by zero (ZeroDivisionError / NaN).
        f_beta = 0.0
    else:
        f_beta = (1 + beta_squared) * recall * precision / denominator

    return {"recall": recall, "precision": precision, "f1": f_beta}