-
Notifications
You must be signed in to change notification settings - Fork 11.5k
/
Copy pathparser.py
65 lines (52 loc) · 1.95 KB
/
parser.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import struct
import mmap
import numpy as np
def open_trace(fn):
    """Open a llama.cpp trace file and return a parser for its version.

    The file starts with two native-endian C ints: a magic number and a
    format version.  The version selects the parser class from
    ``TraceParserBase._parsers``; the still-open file object is handed to
    that parser, which takes over responsibility for closing it.

    Args:
        fn: Path of the trace file to open.

    Returns:
        An instance of the parser class registered for the file's version.

    Raises:
        ValueError: If the file is too short to hold the header, the magic
            number does not match, or the version has no registered parser.
        OSError: If the file cannot be opened.
    """
    base_header = struct.Struct("i" * 2)  # magic, version
    handle = open(fn, "rb")
    try:
        raw = handle.read(base_header.size)
        if len(raw) != base_header.size:
            # A short read would otherwise surface as an opaque struct.error.
            raise ValueError('File too short to contain a trace header')
        magic, version = base_header.unpack(raw)
        if magic != 0x67676d74:
            raise ValueError('Invalid file magic. Must be a llama.cpp trace file')
        parser_cls = TraceParserBase._parsers.get(version)
        if parser_cls is None:
            raise ValueError(f'Unknown version {version}')
        return parser_cls(handle)
    except BaseException:
        # Ownership was never transferred to a parser, so do not leak the
        # descriptor on any failure path.
        handle.close()
        raise
class TraceParserBase:
    """Shared plumbing for versioned trace-file parsers.

    An instance memory-maps the whole file read-only and keeps a byte cursor
    (``pos``) into that mapping.  It is both a context manager, which closes
    the mapping and the file on exit, and an iterator that yields whatever
    the subclass's ``parse_record()`` returns until the cursor reaches the
    end of the file.

    Attributes set by ``__init__``:
        file: The open binary file object passed in.
        mmap: Read-only memory map over the entire file.
        pos:  Current byte offset; starts at the file's read position.
        size: Size in bytes reported by the memory map.
    """

    def __init__(self, file):
        """Map ``file`` into memory and start the cursor at its position.

        Args:
            file: An open binary file object.  Bytes the caller has already
                consumed from it are skipped, because the cursor is
                initialised from ``file.tell()``.
        """
        self.file = file
        self.mmap = mmap.mmap(file.fileno(), 0, access=mmap.ACCESS_READ)
        self.pos = file.tell()  # Skip magic and version header
        self.size = self.mmap.size()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        # mmap.close() raises BufferError while exported buffers are still
        # alive; the file must be closed regardless so the descriptor is not
        # leaked.  The error itself still propagates to the caller.
        try:
            self.mmap.close()
        finally:
            self.file.close()

    def __iter__(self):
        return self

    def __next__(self):
        if self.pos >= self.size:
            raise StopIteration
        return self.parse_record()

    def parse_record(self):
        """Decode the record at ``self.pos`` and advance the cursor.

        Subclasses must override this; the base class has no record format.
        """
        raise NotImplementedError(
            f'{self.__class__.__name__} must implement parse_record()')
class TraceParserV0(TraceParserBase):
    """Parser for version 0 trace files.

    Layout consumed by this class, starting at the cursor position left by
    the base class: one native C int ``n_vocab``, followed by records.  Each
    record is one native C int ``n_tokens``, then ``n_tokens`` int32 values
    (tokens), then ``n_tokens * n_vocab`` float32 values (logits).
    """

    _INT = struct.Struct('i')

    def __init__(self, file):
        """Map the file and read the ``n_vocab`` header field.

        Raises:
            ValueError: If the header is truncated or ``n_vocab`` is negative.
        """
        super().__init__(file)
        try:
            self.n_vocab = self._read_count(self.pos, 'n_vocab')
        except ValueError:
            # Do not leave the mapping open on a half-constructed parser.
            self.mmap.close()
            raise
        self.pos += self._INT.size

    def _read_count(self, pos, what):
        """Read one non-negative native int at ``pos``, with bounds checks."""
        if pos + self._INT.size > self.size:
            raise ValueError(f'Truncated trace file: missing {what} at offset {pos}')
        value, = self._INT.unpack_from(self.mmap, pos)
        if value < 0:
            # A negative count would make np.frombuffer read "all remaining
            # data" instead of failing, silently yielding garbage.
            raise ValueError(f'Corrupt trace file: negative {what} ({value}) at offset {pos}')
        return value

    def parse_record(self):
        """Decode the record at the cursor and advance past it.

        Returns:
            A ``(tokens, logits)`` tuple: ``tokens`` is an int32 array of
            length ``n_tokens`` and ``logits`` is a float32 array reshaped to
            ``(n_tokens, n_vocab)``.  Both are built with ``np.frombuffer``
            over the memory map, so they reference the mapped file rather
            than copying it.
            NOTE(review): live views presumably block ``mmap.close()`` with
            BufferError; callers should drop them before closing - confirm.

        Raises:
            ValueError: If the record header or payload extends past the end
                of the file, or ``n_tokens`` is negative.  The cursor is left
                unchanged in that case.
        """
        n_vocab = self.n_vocab
        n_tokens = self._read_count(self.pos, 'n_tokens')

        # Work out the full extent first so a truncated record is rejected
        # explicitly (an assert would vanish under ``python -O``).
        tokens_pos = self.pos + self._INT.size
        logits_pos = tokens_pos + n_tokens * np.dtype(np.int32).itemsize
        n_logits = n_tokens * n_vocab
        end = logits_pos + n_logits * np.dtype(np.float32).itemsize
        if end > self.size:
            raise ValueError(
                f'Truncated trace file: record at offset {self.pos} needs '
                f'{end - self.pos} bytes but only {self.size - self.pos} remain')

        tokens = np.frombuffer(self.mmap, dtype=np.int32, count=n_tokens, offset=tokens_pos)
        logits = np.frombuffer(self.mmap, dtype=np.float32, count=n_logits, offset=logits_pos)
        self.pos = end
        return tokens, logits.reshape((n_tokens, n_vocab))
# Registry mapping the on-disk format version to the parser class that
# understands it; open_trace() looks the version up in this table.
TraceParserBase._parsers = {0: TraceParserV0}