-
Notifications
You must be signed in to change notification settings - Fork 540
/
Copy pathevaluate-accuracy.py
111 lines (83 loc) · 3.16 KB
/
evaluate-accuracy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import argparse
from transformers import AutoTokenizer
import nltk
import evaluate
import numpy as np
import json
def get_args(argv=None):
    """Parse the command-line options of the accuracy-evaluation script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. The default ``None`` makes argparse read
            ``sys.argv[1:]``, which is what callers passing no argument get.

    Returns:
        argparse.Namespace with the attributes ``checkpoint_path``,
        ``mlperf_accuracy_file``, ``dataset_file``, ``verbose`` and ``dtype``.

    Raises:
        SystemExit: raised by argparse when a required option is missing or
            ``--dtype`` is not one of the allowed choices.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint-path", required=True,
                        help="Path to Llama2-70b-hf-chat checkpoint")
    parser.add_argument("--mlperf-accuracy-file", required=True,
                        help="path to mlperf_log_accuracy.json")
    parser.add_argument("--dataset-file", required=True,
                        help="path to processed openorca validation set")
    parser.add_argument("--verbose", action="store_true",
                        help="verbose messages")
    parser.add_argument("--dtype", default="int64",
                        help="dtype of the accuracy log",
                        choices=["int32", "int64", "float"])
    return parser.parse_args(argv)
def get_groundtruth(processed_dataset_file):
    """Load the reference outputs from the processed dataset pickle.

    Args:
        processed_dataset_file: Path to a pickle file readable by
            ``pandas.read_pickle`` whose loaded object has an ``'output'``
            entry.

    Returns:
        The ``'output'`` entry of the loaded object.
    """
    # Imported lazily so pandas is only needed once a dataset is loaded.
    import pandas as pd

    # SECURITY NOTE: read_pickle unpickles the file, and unpickling can run
    # arbitrary code. Only pass dataset files that come from a trusted source.
    data = pd.read_pickle(processed_dataset_file)
    return data['output']
def postprocess_text(preds, targets):
    """Put predictions and references into the layout used for ROUGE scoring.

    Every string is stripped of surrounding whitespace, split into sentences
    with ``nltk.sent_tokenize`` and re-joined with one sentence per line.

    Args:
        preds: Iterable of predicted strings.
        targets: Iterable of reference strings.

    Returns:
        A ``(preds, targets)`` tuple of new lists holding the reformatted
        strings, in the same order as the inputs.
    """
    def _one_sentence_per_line(text):
        # rougeLSum expects newline after each sentence
        return "\n".join(nltk.sent_tokenize(text.strip()))

    preds = [_one_sentence_per_line(pred) for pred in preds]
    targets = [_one_sentence_per_line(target) for target in targets]
    return preds, targets
def _collect_predictions(results, targets, eval_dtype):
    """Gather references and token-id arrays, one per unique ``qsl_idx``.

    Args:
        results: Entries loaded from the accuracy log; each entry is indexed
            by ``'qsl_idx'`` and ``'data'``.
        targets: Ground-truth container indexed by ``qsl_idx``.
        eval_dtype: NumPy dtype used to reinterpret the bytes in ``'data'``.

    Returns:
        ``(target_required, preds_token_ids, gen_tok_len)``: the references
        for the scored samples, the matching token-id arrays, and the total
        number of elements across those arrays.
    """
    seen = set()
    target_required = []
    preds_token_ids = []
    gen_tok_len = 0
    for entry in results:
        qsl_idx = entry['qsl_idx']
        # Only the first entry logged for a given sample is scored.
        if qsl_idx in seen:
            continue
        seen.add(qsl_idx)
        target_required.append(targets[qsl_idx])
        # 'data' is a hex string; its raw bytes are viewed as eval_dtype.
        token_ids = np.frombuffer(bytes.fromhex(entry['data']), eval_dtype)
        gen_tok_len += len(token_ids)
        preds_token_ids.append(token_ids)
    return target_required, preds_token_ids, gen_tok_len


def main():
    """Score the predictions in an accuracy log with ROUGE and print the result.

    Reads the command-line options, loads the tokenizer and ground truths,
    decodes the logged token ids to text, computes ROUGE against the
    references and prints a dict with the ROUGE scores (mean * 100, rounded
    to 4 decimals) plus ``gen_len`` (total characters of the post-processed
    predictions), ``gen_num`` (number of scored samples), ``gen_tok_len``
    (total logged tokens) and ``tokens_per_sample``.

    Raises:
        ValueError: if the accuracy log contains no entries, since no score
            or per-sample average can be computed.
    """
    args = get_args()

    metric = evaluate.load("rouge")
    nltk.download('punkt')
    # Newer NLTK releases (3.9+) load the sentence tokenizer from the
    # 'punkt_tab' resource instead of 'punkt'; fetch both so sent_tokenize
    # works on either.
    nltk.download('punkt_tab')

    tokenizer = AutoTokenizer.from_pretrained(
        args.checkpoint_path,
        model_max_length=2048,
        padding_side="left",
        use_fast=False,
    )

    targets = get_groundtruth(args.dataset_file)

    # int64 is the default; the other two choices override it.
    eval_dtype = {"int32": np.int32, "float": np.float32}.get(
        args.dtype, np.int64)

    with open(args.mlperf_accuracy_file, "r") as f:
        results = json.load(f)

    target_required, preds_token_ids, gen_tok_len = _collect_predictions(
        results, targets, eval_dtype)

    # Fail early with a clear message instead of a ZeroDivisionError when
    # computing tokens_per_sample further down.
    if not preds_token_ids:
        raise ValueError(
            f"no predictions found in {args.mlperf_accuracy_file}")

    preds_decoded_text = tokenizer.batch_decode(
        preds_token_ids, skip_special_tokens=True)

    preds, targets = postprocess_text(preds_decoded_text, target_required)

    # use_aggregator=False yields per-sample scores; average them here.
    result = metric.compute(
        predictions=preds, references=targets,
        use_stemmer=True, use_aggregator=False)
    result = {k: round(np.mean(v) * 100, 4) for k, v in result.items()}

    prediction_lens = [len(pred) for pred in preds]
    gen_num = len(preds)

    result = {
        **result,
        'gen_len': np.sum(prediction_lens),
        'gen_num': gen_num,
        'gen_tok_len': gen_tok_len,
        'tokens_per_sample': round(gen_tok_len / gen_num, 1),
    }

    print("\nResults\n")
    print(result)
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()