Skip to content

Make loading weights 10-100x faster #613

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ models/*
/result
/perplexity
/embedding
/Pipfile

arm_neon.h
compile_commands.json
Expand Down
5 changes: 5 additions & 0 deletions convert-ggml-to-pth.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ def read_variables(fin):
shape = shape[::-1]
name = fin.read(name_length).decode("utf-8")

# ensure tensor data is aligned
tensor_data_offset = fin.tell()
tensor_data_offset = (tensor_data_offset + 31) & -32
fin.seek(tensor_data_offset)

if ftype_cur == 2:
# 4-bit quantized weights
dtype = np.uint8
Expand Down
5 changes: 5 additions & 0 deletions convert-gptq-to-ggml.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ def write_header(shape, dst_name, ftype_cur):
fout.write(struct.pack("i" * len(shape), *shape[::-1]))
fout.write(sname)

# ensure tensor data is aligned
tensor_data_offset = fout.tell()
tensor_data_offset = (tensor_data_offset + 31) & -32
fout.seek(tensor_data_offset)

def convert_non_q4(src_name, dst_name):
v = model[src_name]
shape = v.shape
Expand Down
201 changes: 148 additions & 53 deletions convert-pth-to-ggml.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Convert a LLaMA model checkpoint to a ggml compatible file
# Convert a LLaMA model checkpoint to a ggjt compatible file
#
# Load the model using Torch
# Iterate over all variables and write them to a binary file.
Expand All @@ -24,16 +24,64 @@

from sentencepiece import SentencePieceProcessor

def parse_args():
# Elements per block for the Q4 types (see GGML_BLCK_SIZE below).
QK = 32

# Numeric ids of the ggml element types, numbered consecutively from 0.
(
    GGML_TYPE_Q4_0,
    GGML_TYPE_Q4_1,
    GGML_TYPE_I8,
    GGML_TYPE_I16,
    GGML_TYPE_I32,
    GGML_TYPE_F16,
    GGML_TYPE_F32,
) = range(7)

# Maps the `ftype` code (0, 1, 2, 3) to the ggml element type id.
WTYPES = dict(enumerate((
    GGML_TYPE_F32,
    GGML_TYPE_F16,
    GGML_TYPE_Q4_0,
    GGML_TYPE_Q4_1,
)))

# Number of elements stored per block, for each ggml type.
GGML_BLCK_SIZE = {
    GGML_TYPE_Q4_0: QK,
    GGML_TYPE_Q4_1: QK,
    GGML_TYPE_I8: 1,
    GGML_TYPE_I16: 1,
    GGML_TYPE_I32: 1,
    GGML_TYPE_F16: 1,
    GGML_TYPE_F32: 1,
}

# Number of bytes occupied by one block, for each ggml type.
GGML_TYPE_SIZE = {
    GGML_TYPE_Q4_0: QK // 2 + 4,
    GGML_TYPE_Q4_1: QK // 2 + 4 * 2,
    GGML_TYPE_I8: 1,
    GGML_TYPE_I16: 2,
    GGML_TYPE_I32: 4,
    GGML_TYPE_F16: 2,
    GGML_TYPE_F32: 4,
}

def ggml_nelements(shape):
    """Return the total number of elements in a tensor of the given shape.

    `shape` is any iterable of dimension sizes; an empty shape yields 1.
    """
    count = 1
    for extent in shape:
        count *= extent
    return count

def ggml_nbytes(shape, ftype):
    """Return the number of bytes a tensor of `shape` occupies on disk.

    `ftype` is the file-type code used as a key of WTYPES. The size is
    the element count times the bytes per block, floor-divided by the
    elements per block.

    Raises KeyError if `ftype` is not a key of WTYPES.
    """
    ggml_type = WTYPES[ftype]
    nbytes = ggml_nelements(shape)
    # multiply before dividing so the floor division is applied only once
    nbytes *= GGML_TYPE_SIZE[ggml_type]
    nbytes //= GGML_BLCK_SIZE[ggml_type]
    return nbytes

def parse_args():
    """Parse the command line of the conversion script.

    Returns an argparse.Namespace with:
      dir_model  -- directory containing the model checkpoint
      ftype      -- 0 (float32) or 1 (float16); defaults to 1 when omitted
      vocab_only -- non-zero to write only the vocabulary; defaults to 0
    """
    parser = argparse.ArgumentParser(description='Convert a LLaMA model checkpoint to a ggml compatible file')
    parser.add_argument('dir_model', help='directory containing the model checkpoint')
    # nargs='?' is required for `default` to take effect: argparse ignores
    # the default of a required positional argument.
    parser.add_argument('ftype', help='file type (0: float32, 1: float16)', type=int, choices=[0, 1], default=1, nargs='?')
    parser.add_argument('vocab_only', help='only write vocab to file', type=int, default=0, nargs='?')
    return parser.parse_args()

def get_n_parts(dim):

mappings = {4096: 1, 5120: 2, 6656: 4, 8192: 8}
n_parts = mappings.get(dim)
if n_parts is None:
Expand All @@ -44,30 +92,24 @@ def get_n_parts(dim):
return n_parts

def load_hparams_and_tokenizer(dir_model):
    """Load the hyper-parameters and the tokenizer for a model directory.

    Reads `params.json` from `dir_model` and `tokenizer.model` from the
    parent directory of `dir_model`. The tokenizer's vocabulary size is
    stored in the hyper-parameters under the key "vocab_size".

    Returns a `(hparams, tokenizer)` tuple.
    """
    # `dir_model` is something like `models/7B` or `models/7B/`.
    # "tokenizer.model" is expected under model's parent dir.
    # When `dir_model` is a symlink, f"{dir_model}/../tokenizer.model" would not be found.
    # Let's use the model's parent dir directly.
    model_parent_dir = os.path.dirname(os.path.normpath(dir_model))

    fname_hparams = f"{dir_model}/params.json"
    fname_tokenizer = f"{model_parent_dir}/tokenizer.model"

    # explicit encoding so the read does not depend on the platform locale
    with open(fname_hparams, "r", encoding="utf-8") as f:
        hparams = json.load(f)
        print(hparams)

    tokenizer = SentencePieceProcessor(fname_tokenizer)
    hparams.update({"vocab_size": tokenizer.vocab_size()})

    return hparams, tokenizer

def write_header(fout, hparams, ftype):

keys = ["vocab_size", "dim", "multiple_of", "n_heads", "n_layers"]
values = [
0x67676d66, # magic: ggmf in hex
0x67676a74, # magic: ggjt in hex
1, # file version
*[hparams[key] for key in keys],
hparams["dim"] // hparams["n_heads"], # rot (obsolete)
Expand All @@ -76,7 +118,6 @@ def write_header(fout, hparams, ftype):
fout.write(struct.pack("i" * len(values), *values))

def write_tokens(fout, tokenizer):

for i in range(tokenizer.vocab_size()):
if tokenizer.is_unknown(i):
text = " \u2047 ".encode("utf-8")
Expand All @@ -95,85 +136,139 @@ def write_tokens(fout, tokenizer):
fout.write(text)
fout.write(struct.pack("f", tokenizer.get_score(i)))

def process_and_write_variables(fout, model, ftype, part_id, n_parts):
    """Write the tensors of one checkpoint part into the single output file.

    fout     -- binary file object; must support write(), seek() and tell()
    model    -- mapping of tensor name to tensor (each offers .numpy() and .dtype)
    ftype    -- 0 converts every tensor to float32; otherwise only
                one-dimensional tensors are converted to float32
    part_id  -- zero-based index of the checkpoint part being written
    n_parts  -- total number of checkpoint parts

    For every tensor the header is written with the full (re-assembled)
    shape, the file position is padded up to a multiple of QK, and this
    part's slice is written at its place inside the full tensor. The file
    position is then moved past the full tensor.

    NOTE(review): this assumes every part lists the same tensors in the
    same order, so headers written by later parts land on the same
    offsets as those of part 0 -- confirm against the caller.
    """
    for name, datao in model.items():
        if name.endswith("freqs"):
            continue

        # remove dimensions with a single element
        data = datao.numpy().squeeze()
        partshape = data.shape
        n_dims = len(partshape)
        assert n_dims in (1, 2)

        print(f"Processing variable: {name} with shape: {partshape} and type: {datao.dtype}")

        # coerce single-dimensional tensors from float16 to float32
        ftype_cur = 1
        if ftype == 0 or n_dims == 1:
            print(" Converting to float32")
            data = data.astype(np.float32)
            ftype_cur = 0
        blck_size = GGML_BLCK_SIZE[WTYPES[ftype_cur]]
        type_size = GGML_TYPE_SIZE[WTYPES[ftype_cur]]

        # determine dimension along which multipart tensor is sharded
        #
        # split_dim 0 regex:
        #   - output.*
        #   - layers.*.attention.wq.weight
        #   - layers.*.attention.wk.weight
        #   - layers.*.attention.wv.weight
        #   - layers.*.feed_forward.w1.weight
        #   - layers.*.feed_forward.w3.weight
        #
        # split_dim 1 regex:
        #   - tok_embeddings.*
        #   - layers.*.attention.wo.weight
        #   - layers.*.feed_forward.w2.weight
        #
        if n_dims > 1:
            split_dim = 1
            if "tok_embeddings" in name:
                split_dim = 1
            elif "layers" in name:
                if "attention.wo.weight" in name:
                    split_dim = 1
                elif "feed_forward.w2.weight" in name:
                    split_dim = 1
                else:
                    split_dim = 0
            elif "output" in name:
                split_dim = 0

        # output tensor header
        fullshape = list(partshape)
        if n_dims > 1:
            fullshape[split_dim] *= n_parts
        sname = name.encode('utf-8')
        fout.write(struct.pack("iii", n_dims, len(sname), ftype_cur))
        for dim in reversed(fullshape):
            fout.write(struct.pack("i", dim))
        fout.write(sname)

        # ensure tensor data is aligned: pad with zero bytes up to the
        # next multiple of QK (a single write instead of one per byte)
        tensor_data_offset = fout.tell()
        padding = -tensor_data_offset % QK
        fout.write(b"\x00" * padding)
        tensor_data_offset += padding

        # output unified mappable tensor data
        if n_dims == 1 or n_parts == 1:
            # copy tensor which we thankfully received in one piece
            if part_id == 0:
                data.tofile(fout)
        elif split_dim == 0:
            # reassemble multifile tensor containing some of the rows
            rows_per_chunk = partshape[0]
            current_row = part_id * rows_per_chunk
            bytes_per_row = fullshape[1] // blck_size * type_size
            offset = current_row * bytes_per_row
            fout.seek(tensor_data_offset + offset)
            data.tofile(fout)
        elif split_dim == 1:
            # reassemble multifile tensor containing some of the cols
            cols_per_chunk = partshape[1]
            current_col = part_id * cols_per_chunk
            bytes_per_row = fullshape[1] // blck_size * type_size
            offset_current_col = current_col // blck_size * type_size
            for row in range(partshape[0]):
                offset_row = row * bytes_per_row
                offset = offset_row + offset_current_col
                fout.seek(tensor_data_offset + offset)
                data[row].tofile(fout)

        # advance file position to next tensor
        fout.seek(tensor_data_offset + ggml_nbytes(fullshape, ftype_cur))

def main():
    """Convert the checkpoint named on the command line into one output file.

    With `vocab_only` set, only the header and the vocabulary are written
    to `ggml-vocab.bin`. Otherwise the header, the vocabulary and the
    tensors of every checkpoint part are written to a single
    `ggml-model-<f32|f16>.bin` inside the model directory.
    """
    args = parse_args()
    dir_model = args.dir_model
    ftype = args.ftype
    ftype_str = ["f32", "f16"]
    hparams, tokenizer = load_hparams_and_tokenizer(dir_model)

    print(args)

    # if only writing vocab to file
    if args.vocab_only:
        fname_model = f"{dir_model}/consolidated.00.pth"
        fname_out = f"{dir_model}/ggml-vocab.bin"
        print(f"Extracting only the vocab from '{fname_model}'\n")
        with open(fname_out, "wb") as fout:
            write_header(fout, hparams, ftype)
            write_tokens(fout, tokenizer)
        print(f"Done. Output file: {fname_out}\n")
        return

    n_parts = get_n_parts(hparams["dim"])
    fname_out = f"{dir_model}/ggml-model-{ftype_str[ftype]}.bin"

    # we output a single file for ggml
    with open(fname_out, "wb") as fout:
        write_header(fout, hparams, ftype)
        write_tokens(fout, tokenizer)
        offset_of_tensors = fout.tell()
        # the tensors we load could be split across multiple files
        for part_id in range(n_parts):
            # every part starts writing at the first tensor again
            fout.seek(offset_of_tensors)
            print(f"Processing part {part_id+1} of {n_parts}\n")
            fname_model = f"{dir_model}/consolidated.0{part_id}.pth"
            model = torch.load(fname_model, map_location="cpu")
            process_and_write_variables(fout, model, ftype, part_id, n_parts)
            # drop the reference before the next part is loaded
            del model

    print(f"Done. Output file: {fname_out}\n")

# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
2 changes: 1 addition & 1 deletion examples/quantize/quantize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ int main(int argc, char ** argv) {

// needed to initialize f16 tables
{
struct ggml_init_params params = { 0, NULL };
struct ggml_init_params params = { 0, NULL, false };
struct ggml_context * ctx = ggml_init(params);
ggml_free(ctx);
}
Expand Down
Loading