-
Notifications
You must be signed in to change notification settings - Fork 88
/
Copy pathevaluation.py
44 lines (32 loc) · 1.49 KB
/
evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
"""
Tool for calculating metrics from predictions and labels (both strings).
* Computes Optical Character Recognition (OCR) metrics with editdistance.
"""
import string
import unicodedata
import editdistance
import numpy as np
def _error_rate(pred_tokens, gt_tokens):
    """Return the edit distance between two token lists, normalized by the longer length.

    Identical lists (including two empty ones) score 0.0 without computing a
    distance; this also avoids dividing by zero when both lists are empty.
    """
    if pred_tokens == gt_tokens:
        return 0.0
    dist = editdistance.eval(pred_tokens, gt_tokens)
    return dist / max(len(pred_tokens), len(gt_tokens))


def ocr_metrics(predicts, ground_truth, norm_accentuation=False, norm_punctuation=False):
    """Calculate Character Error Rate (CER), Word Error Rate (WER) and Sequence Error Rate (SER).

    For every (prediction, ground truth) pair, each rate is the edit distance
    between the two token sequences (characters, whitespace-split words, or the
    whole string) divided by the length of the longer sequence. The per-pair
    rates are then averaged.

    Args:
        predicts: sequence of predicted strings.
        ground_truth: sequence of reference strings, paired with ``predicts`` by
            position. Pairing uses ``zip``, so extra items in the longer
            sequence are silently ignored.
        norm_accentuation: if True, NFKD-decompose both strings and drop every
            character that cannot be encoded as ASCII. This strips accents, but
            also removes any non-Latin characters entirely.
        norm_punctuation: if True, remove ASCII punctuation
            (``string.punctuation``) from both strings.

    Returns:
        A NumPy array ``[cer, wer, ser]`` of mean rates, or the tuple
        ``(1, 1, 1)`` when either input sequence is empty.
    """
    # len() rather than truthiness so NumPy arrays of strings are accepted too.
    if len(predicts) == 0 or len(ground_truth) == 0:
        return (1, 1, 1)

    # Built once here instead of twice per pair inside the loop.
    punct_table = str.maketrans("", "", string.punctuation)

    cer, wer, ser = [], [], []
    for pd, gt in zip(predicts, ground_truth):
        if norm_accentuation:
            pd = unicodedata.normalize("NFKD", pd).encode("ASCII", "ignore").decode("ASCII")
            gt = unicodedata.normalize("NFKD", gt).encode("ASCII", "ignore").decode("ASCII")
        if norm_punctuation:
            pd = pd.translate(punct_table)
            gt = gt.translate(punct_table)

        cer.append(_error_rate(list(pd), list(gt)))
        wer.append(_error_rate(pd.split(), gt.split()))
        # Edit distance between one-element sequences is 1 on mismatch, 0 on
        # match, and the normalizing length is always 1.
        ser.append(float(pd != gt))

    return np.mean([cer, wer, ser], axis=1)