
Trace model outputs to a binary file #477

Closed · wants to merge 1 commit
1 change: 1 addition & 0 deletions .gitignore
@@ -4,6 +4,7 @@
.vs/
.vscode/
.DS_Store
__pycache__

build/
build-em/
47 changes: 47 additions & 0 deletions examples/common.cpp
@@ -169,6 +169,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break;
}
params.input_prefix = argv[i];
} else if (arg == "--trace") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.trace_fn = argv[i];
} else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
gpt_print_usage(argc, argv, params);
@@ -224,6 +230,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " --verbose-prompt print prompt before generation\n");
fprintf(stderr, " -m FNAME, --model FNAME\n");
fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
fprintf(stderr, " --trace FNAME save the the model logits during evaluation to a binary file\n");
fprintf(stderr, "\n");
}

@@ -256,3 +263,43 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos) {

return res;
}

// Open the trace file and write the header in the binary format: magic:int version:int n_vocab:int
std::ofstream trace_open(const gpt_params & params, struct llama_context * ctx) {
std::ofstream trace_ofs;

const uint32_t n_vocab = llama_n_vocab(ctx);
if (n_vocab == 0) {
return trace_ofs;
}
const auto& trace_fn = params.trace_fn;
trace_ofs.open(trace_fn, std::ios::binary);
if (trace_ofs.is_open() && trace_ofs.good()) {
fprintf(stderr, "Tracing evaluation to: '%s'\n", trace_fn.c_str());
trace_ofs.write(reinterpret_cast<const char*>(&LLAMA_TRACE_MAGIC), sizeof(uint32_t));
trace_ofs.write(reinterpret_cast<const char*>(&LLAMA_TRACE_VERSION), sizeof(uint32_t));
trace_ofs.write(reinterpret_cast<const char*>(&n_vocab), sizeof(uint32_t));
} else {
fprintf(stderr, "Could not open trace file: '%s'\n", trace_fn.c_str());
trace_ofs.close();
}
return trace_ofs;
}

// Write a record using the binary format: N:int {N}token_id:int {N*n_vocab}logits:float
void trace_write_record(
std::ofstream & out,
const std::vector<llama_token> & embd,
struct llama_context * ctx) {

const uint32_t N = embd.size();
const int n_vocab = llama_n_vocab(ctx);
const float * logits = llama_get_logits(ctx);
if (!out.is_open() || out.bad() || N == 0 || n_vocab <= 0) {
return;
}

out.write(reinterpret_cast<const char*>(&N), sizeof(uint32_t));
out.write(reinterpret_cast<const char*>(embd.data()), sizeof(llama_token)*N);
out.write(reinterpret_cast<const char*>(logits), sizeof(float)*N*n_vocab);
}
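
For illustration, here is a minimal Python sketch of reading this format back with only the standard library (the file name trace.bin is a placeholder; the bundled examples/traceparser package below is the full reader):

import struct

with open("trace.bin", "rb") as f:
    # header: magic:int version:int n_vocab:int
    magic, version, n_vocab = struct.unpack("iii", f.read(12))
    assert magic == 0x67676d74  # ASCII 'ggmt'
    while True:
        head = f.read(4)
        if not head:
            break
        # record: N:int {N}token_id:int {N*n_vocab}logits:float
        n_tokens, = struct.unpack("i", head)
        token_ids = struct.unpack(f"{n_tokens}i", f.read(4 * n_tokens))
        logits = struct.unpack(f"{n_tokens * n_vocab}f", f.read(4 * n_tokens * n_vocab))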
17 changes: 17 additions & 0 deletions examples/common.h
@@ -32,6 +32,7 @@ struct gpt_params {
std::string model = "models/lamma-7B/ggml-model.bin"; // model path
std::string prompt = "";
std::string input_prefix = ""; // string to prefix user inputs with
std::string trace_fn = "";


std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
@@ -63,3 +64,19 @@ std::string gpt_random_prompt(std::mt19937 & rng);
//

std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos);

//
// Trace utils
//

static constexpr uint32_t LLAMA_TRACE_VERSION = 0;
static constexpr uint32_t LLAMA_TRACE_MAGIC = 0x67676d74; // ASCII 'ggmt'

// Open the trace file and write the binary header: magic:int version:int n_vocab:int
std::ofstream trace_open(const gpt_params & params, struct llama_context * ctx);

// Write a record using the binary format: N:int {N}token_id:int {N*n_vocab}logits:float
void trace_write_record(
std::ofstream & out,
const std::vector<llama_token> & embd,
struct llama_context * ctx);
5 changes: 5 additions & 0 deletions examples/main/main.cpp
@@ -169,6 +169,7 @@ int main(int argc, char ** argv) {
lparams.n_parts = params.n_parts;
lparams.seed = params.seed;
lparams.f16_kv = params.memory_f16;
lparams.logits_all = !params.trace_fn.empty();
lparams.use_mlock = params.use_mlock;

ctx = llama_init_from_file(params.model.c_str(), lparams);
@@ -205,6 +206,8 @@ int main(int argc, char ** argv) {
return 0;
}

std::ofstream trace_ofs = trace_open(params, ctx);

// Add a space in front of the first character to match OG llama tokenizer behavior
params.prompt.insert(0, 1, ' ');

@@ -339,6 +342,7 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return 1;
}
trace_write_record(trace_ofs, embd, ctx);
}

n_past += embd.size();
@@ -502,6 +506,7 @@ int main(int argc, char ** argv) {

llama_print_timings(ctx);
llama_free(ctx);
trace_ofs.close();

set_console_state(CONSOLE_STATE_DEFAULT);

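With this change a run of ./main can be traced straight from the command line. A hypothetical invocation (model path, prompt, and output file name are placeholders):

./main -m models/7B/ggml-model.bin -p "Hello" -n 64 --trace trace.bin

Since a non-empty --trace sets lparams.logits_all, llama_eval keeps the logits for every position of each evaluated batch rather than only the last token, which is what makes the per-token records written above complete.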
6 changes: 6 additions & 0 deletions examples/perplexity/perplexity.cpp
@@ -1,6 +1,8 @@
#include "common.h"
#include "llama.h"

#include <fstream>

std::vector<double> softmax(const std::vector<float>& logits) {
std::vector<double> probs(logits.size());
float max_logit = logits[0];
@@ -27,6 +29,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
double nll = 0.0;
int seq_count = tokens.size() / params.n_ctx;

std::ofstream trace_ofs = trace_open(params, ctx);
fprintf(stderr, "%s : calculating perplexity over %d chunks\n", __func__, seq_count);

for (int i = 0; i < seq_count; ++i) {
@@ -57,6 +60,8 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
// process the entire prompt.

auto logits = llama_get_logits(ctx);
trace_write_record(trace_ofs, embd, ctx);

for (int j = params.n_ctx / 2; j < params.n_ctx - 1; ++j) {
// Calculate probability of next token, given the previous ones.
int n_vocab = llama_n_vocab(ctx);
@@ -72,6 +77,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
fflush(stdout);
}
printf("\n");
trace_ofs.close();
}

int main(int argc, char ** argv) {
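Because the trace stores the full logits for every evaluated chunk, the perplexity number can also be recomputed offline from the file alone. A rough sketch, under the assumption that each record holds one full context window and mirroring the j = n_ctx/2 .. n_ctx-1 loop above (file name and import path are placeholders):

import numpy as np
from traceparser import open_trace

nll, count = 0.0, 0
with open_trace("trace.bin") as f:
    for tokens, logits in f:  # tokens: (N,), logits: (N, n_vocab)
        n = len(tokens)
        for j in range(n // 2, n - 1):
            row = logits[j].astype(np.float64)
            row -= row.max()  # numerically stable softmax
            p = np.exp(row)
            p /= p.sum()
            nll -= np.log(p[tokens[j + 1]])  # NLL of the observed next token
            count += 1
print(f"perplexity: {np.exp(nll / count):.4f}")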
1 change: 1 addition & 0 deletions examples/traceparser/__init__.py
@@ -0,0 +1 @@
from .parser import open_trace
70 changes: 70 additions & 0 deletions examples/traceparser/__main__.py
@@ -0,0 +1,70 @@
import argparse

import numpy as np
from sentencepiece import SentencePieceProcessor

from . import open_trace

def parse_args():
parser = argparse.ArgumentParser(description='Read a llama.cpp trace file and print per-token sampling statistics')
parser.add_argument('trace_file', help='tracefile to read')
parser.add_argument('--tokenizer', help='path to LLaMA tokenizer.model file',
dest='tokenizer_model_file', default='models/tokenizer.model')
parser.add_argument('--temp', help='Sampling temperature',
dest='temperature', default=0.8, type=float)
parser.add_argument('--top_k', help='top k tokens to sample', type=int)
parser.add_argument('--top_p', help='nucleus probability', type=float, default=1.0)
return parser.parse_args()


def top_k_indices(logits, k):
idxs = np.argpartition(logits, -k)[-k:]
idxs = idxs[np.argsort(logits[idxs])][::-1]
return idxs

def process_logits(logits, temp):
logits = logits / temp
logp = logits - logits.max()
p = np.exp(logp)
sum_p = p.sum()
entropy = -(p * logp).sum() / sum_p + np.log(sum_p)
p /= sum_p
#entropy = -(p * np.log(p)).sum()
return p, entropy

def top_p(p, top_p):
if top_p < 1:
cumsum = 0.
for i in range(len(p)):
cumsum += p[i]
if cumsum >= top_p:
return i + 1
return len(p)

def replicate_sampler(f, args, max_print=10):
log2 = np.log(2)
tokenizer = SentencePieceProcessor(args.tokenizer_model_file)
piece_repr = lambda tokid: repr(tokenizer.id_to_piece(int(tokid)))
for tokens, logits_arrs in f:
for tokid, logits in zip(tokens, logits_arrs):
idxs = None
if args.top_k is not None:
idxs = top_k_indices(logits, args.top_k)
else:
idxs = np.argsort(logits)[::-1]
logits = logits[idxs]
p, entropy = process_logits(logits, args.temperature)

n_top_p = top_p(p, args.top_p)
logits = logits[:n_top_p]
idxs = idxs[:n_top_p]

print(f'in:{piece_repr(tokid):10} logits: mean={logits.mean():5.2f} max={logits[0]:5.2f} entropy={entropy/log2:.2f} bits n={len(idxs)}')
print(' '*13, ' '.join(f'{piece_repr(candtok)}:{prob:.2f}' for candtok, prob in zip(idxs[:max_print], p)))

if __name__ == "__main__":
args = parse_args()

with open_trace(args.trace_file) as f:
print(f'n_vocab={f.n_vocab}')
replicate_sampler(f, args)
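
The entropy expression in process_logits avoids overflow by shifting the log-probabilities before exponentiation: with S = sum(exp(logp)), it computes H = -(sum(p * logp)) / S + log(S), which is algebraically equal to the direct -sum(q * log q) for the normalized q = p / S. A quick numerical check (the logit values are arbitrary):

import numpy as np

logits = np.array([3.0, 1.0, -2.0])
logp = logits - logits.max()
p = np.exp(logp)
S = p.sum()
h_shifted = -(p * logp).sum() / S + np.log(S)
q = p / S
h_direct = -(q * np.log(q)).sum()
assert np.isclose(h_shifted, h_direct)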
65 changes: 65 additions & 0 deletions examples/traceparser/parser.py
@@ -0,0 +1,65 @@
import struct
import mmap

import numpy as np


def open_trace(fn):
base_header_fmt = "i" * 2  # magic:int version:int
file = open(fn, "rb")
magic, version = struct.unpack(base_header_fmt, file.read(struct.calcsize(base_header_fmt)))
if magic != 0x67676d74:
raise ValueError('Invalid file magic. Must be a llama.cpp trace file')
parser_cls = TraceParserBase._parsers.get(version)
if parser_cls is None:
raise ValueError(f'Unknown version {version}')
return parser_cls(file)

class TraceParserBase:
def __init__(self, file):
self.file = file
self.mmap = mmap.mmap(file.fileno(), 0, access=mmap.ACCESS_READ)
self.pos = file.tell() # Skip magic and version header
self.size = self.mmap.size()

def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, traceback):
self.mmap.close()
self.file.close()

def __iter__(self):
return self

def __next__(self):
if self.pos >= self.size:
raise StopIteration
return self.parse_record()

class TraceParserV0(TraceParserBase):
def __init__(self, file):
super().__init__(file)
header_fmt = 'i' # n_vocab
self.n_vocab, = struct.unpack_from(header_fmt, self.mmap, self.pos)
self.pos += struct.calcsize(header_fmt)

def parse_record(self):
pos = self.pos
n_vocab = self.n_vocab

header_fmt = 'i' # n_tokens
n_tokens, = struct.unpack_from(header_fmt, self.mmap, pos)
pos += struct.calcsize(header_fmt)
tokens = np.frombuffer(self.mmap, dtype=np.int32, count=n_tokens, offset=pos)
pos += tokens.itemsize * tokens.size
logits = np.frombuffer(self.mmap, dtype=np.float32, count=n_tokens * n_vocab, offset=pos)
pos += logits.itemsize * logits.size

assert pos <= self.size
self.pos = pos
return tokens, logits.reshape((n_tokens, n_vocab))

TraceParserBase._parsers = {
0: TraceParserV0
}
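
A minimal usage sketch for the parser, assuming a trace file named trace.bin produced with --trace and the examples directory on the import path (for the full per-token report, run the package CLI instead, e.g. python -m traceparser trace.bin):

from traceparser import open_trace

with open_trace("trace.bin") as f:
    print(f"n_vocab = {f.n_vocab}")
    for tokens, logits in f:
        # tokens: int32 array of shape (N,); logits: float32 array of shape (N, n_vocab)
        print(tokens[:5], logits.argmax(axis=-1)[:5])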