From 2e523f6b74acbf02edd992d4cb2919db31511473 Mon Sep 17 00:00:00 2001 From: Bart Tadych Date: Fri, 24 May 2024 21:36:34 +0200 Subject: [PATCH] feat: convert-hf.py (#62) --- .gitignore | 1 - README.md | 6 +- converter/.gitignore | 3 + converter/convert-grok-1.py | 8 +- converter/convert-hf.py | 210 +++++++++++++++++++ converter/convert-llama.py | 11 +- converter/convert-mixtral.py | 141 ------------- converter/convert-tokenizer-llama3.py | 2 +- converter/convert-tokenizer-sentencepiece.py | 20 +- converter/writer.py | 34 ++- src/app.cpp | 2 +- src/grok1-tasks.cpp | 2 + src/llama2-tasks-test.cpp | 6 +- src/llama2-tasks.cpp | 10 +- src/llama2-tasks.hpp | 2 +- src/transformer-test.cpp | 4 +- src/transformer.cpp | 14 +- src/transformer.hpp | 3 +- 18 files changed, 294 insertions(+), 185 deletions(-) create mode 100644 converter/convert-hf.py delete mode 100644 converter/convert-mixtral.py diff --git a/.gitignore b/.gitignore index b7d86ed9..418c59a7 100644 --- a/.gitignore +++ b/.gitignore @@ -3,7 +3,6 @@ *.o *.dSYM *.data -*.bin __pycache__ *-test diff --git a/README.md b/README.md index aaded0ea..4676315e 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,9 @@ Run LLMs on weak devices or make powerful devices even more powerful by distribu **Supported models:** * Llama 2 (7B, 13B, 70B) chat and non-chat versions, * Llama 3, -* Grok-1 (314B). +* Grok-1 (314B) +* Mistral, Mixtral +* TinyLlama **Known limitations:** * You can run Distributed Llama only on 1, 2, 4... 2^n devices. @@ -28,7 +30,7 @@ Run LLMs on weak devices or make powerful devices even more powerful by distribu * ❌ F32 × F32 * ❌ F16 × F32 * ❌ Q40 × F32 - * ⚠️ Q40 × Q80 (partial optimization) + * ✅ Q40 × Q80 **Architecture**
The project is split up into two parts: diff --git a/converter/.gitignore b/converter/.gitignore index a8d6b6c1..c2ea6ab5 100644 --- a/converter/.gitignore +++ b/converter/.gitignore @@ -1 +1,4 @@ *.t +*.m +*.bin +*/ diff --git a/converter/convert-grok-1.py b/converter/convert-grok-1.py index 160dd112..73482f0b 100644 --- a/converter/convert-grok-1.py +++ b/converter/convert-grok-1.py @@ -2,7 +2,7 @@ import torch import sys import os -from writer import isFloatTypeSupported, writeTensor, writeHeader +from writer import parseFloatType, writeTensor, writeHeader # Model: https://huggingface.co/keyfan/grok-1-hf/tree/main @@ -116,11 +116,7 @@ def usage(): usage() folderPath = sys.argv[1] - targetFloatType = sys.argv[2] + targetFloatType = parseFloatType(sys.argv[2]) outputFileName = f'dllama-grok-1-{targetFloatType}.bin' - if not isFloatTypeSupported(targetFloatType): - print('Float type is not supported') - exit(1) - convert(targetFloatType, outputFileName) diff --git a/converter/convert-hf.py b/converter/convert-hf.py new file mode 100644 index 00000000..bbbda774 --- /dev/null +++ b/converter/convert-hf.py @@ -0,0 +1,210 @@ +import gc +import json +import sys +import os +from writer import parseFloatType, writeTensor, writeHeader, FloatType +from safetensors import safe_open + +class ArchType: + LLAMA = 0xABCD00 + MIXTRAL = 0xABCD02 + +def permute(tensor, nHeads: int, nKvHeads: int): + if nHeads != nKvHeads: + nHeads = nKvHeads + return (tensor.reshape(nHeads, 2, tensor.shape[0] // nHeads // 2, *tensor.shape[1:]).swapaxes(1, 2).reshape(tensor.shape)) + +class Processor: + def __init__(self, config): + self.config = config + self.currentModelIndex = None + self.currentModel = None + self.currentModelKeys = None + self.layerMap = {} + self.plan = [] + + def __unloadModel(self): + if self.currentModel: + del self.currentModel + self.currentModel = None + gc.collect() + + def __loadModel(self, index: int): + if (self.currentModelIndex == index): + return + self.__unloadModel() + filePath = self.config['files'][index] + fileName = os.path.basename(filePath) + print(f'💿 Loading file {fileName}...') + self.currentModel = safe_open(filePath, framework='pt', device='cpu') + self.currentModelKeys = list(self.currentModel.keys()) + for key in self.currentModelKeys: + self.layerMap[key] = index + print(f'Found {len(self.currentModelKeys)} layers') + self.currentModelIndex = index + + def __permuteQ(self, tensor): + return permute(tensor, self.config['n_heads'], self.config['n_heads']) + + def __permuteK(self, tensor): + return permute(tensor, self.config['n_heads'], self.config['n_kv_heads']) + + def __preparePlan(self): + wt = self.config['weights_float_type'] + p = self.plan + p.append([FloatType.F32, + 'model.embed_tokens.weight']) + for l in range(0, self.config['n_layers']): + p.append([wt, self.__permuteQ, + f'model.layers.{l}.self_attn.q_proj.weight']) + p.append([wt, self.__permuteK, + f'model.layers.{l}.self_attn.k_proj.weight']) + p.append([wt, + f'model.layers.{l}.self_attn.v_proj.weight']) + p.append([wt, + f'model.layers.{l}.self_attn.o_proj.weight']) + + if (self.config['n_experts'] > 0): + for e in range(self.config['n_experts']): + p.append([wt, + f'model.layers.{l}.block_sparse_moe.experts.{e}.w3.weight']) # up + p.append([wt, + f'model.layers.{l}.block_sparse_moe.experts.{e}.w1.weight']) # gate + p.append([wt, + f'model.layers.{l}.block_sparse_moe.experts.{e}.w2.weight']) # down + else: + p.append([wt, + f'model.layers.{l}.mlp.gate_proj.weight']) # gate + p.append([wt, + 
f'model.layers.{l}.mlp.down_proj.weight']) # down + p.append([wt, + f'model.layers.{l}.mlp.up_proj.weight']) # up + + p.append([FloatType.F32, + f'model.layers.{l}.input_layernorm.weight']) + p.append([FloatType.F32, + f'model.layers.{l}.post_attention_layernorm.weight']) + p.append([FloatType.F32, + 'model.norm.weight']) + p.append([wt, + 'lm_head.weight']) + + def write(self, outputFile: str): + self.__preparePlan() + for planItem in self.plan: + lookup = planItem[1:] + transform = None + if (callable(lookup[0])): + transform = lookup[0] + lookup = lookup[1:] + + if (self.currentModelIndex == None): + modelIndex = 0 + else: + modelIndex = None + for layerName in lookup: + if (layerName in self.layerMap): + modelIndex = self.layerMap[layerName] + break + if (modelIndex is None): + modelIndex = self.currentModelIndex + 1 + self.__loadModel(modelIndex) + + tensor = None + for layerName in lookup: + if (layerName in self.currentModelKeys): + tensor = self.currentModel.get_tensor(layerName) + break + if tensor is None: + raise Exception(f'Layer {lookup[0]} not found') + print(f'🔶 Writing tensor {layerName} {tensor.shape}...') + + floatType = planItem[0] + if (transform): + tensor = transform(tensor) + writeTensor(outputFile, tensor, floatType) + +def parseArchType(type: str): + archType = { + 'llama': ArchType.LLAMA, + 'mistral': ArchType.LLAMA, + 'mixtral': ArchType.MIXTRAL, + }.get(type) + if (archType is None): + raise Exception(f'Unsupported arch type: {type}') + return archType + +def parseHiddenAct(act: str): + hiddenAct = { + 'gelu': 0, + 'silu': 1 + }.get(act) + if (hiddenAct is None): + raise Exception(f'Unsupported hidden act: {act}') + return hiddenAct + +def loadConfig(folderPath: str, weightsFloatType: int): + allFiles = os.listdir(folderPath) + allFiles.sort() + with open(os.path.join(folderPath, 'config.json')) as fc: + config = json.load(fc) + files = [] + for fileName in allFiles: + if fileName.endswith('.safetensors'): + files.append(os.path.join(folderPath, fileName)) + if (len(files) == 0): + raise Exception('Not found any model file') + + result = { + 'version': 0, + 'arch_type': parseArchType(config['model_type']), + 'hidden_act': parseHiddenAct(config['hidden_act']), + 'dim': config['hidden_size'], + 'hidden_dim': config['intermediate_size'], + 'n_layers': config['num_hidden_layers'], + 'n_heads': config['num_attention_heads'], + 'n_kv_heads': config['num_key_value_heads'], + 'weights_float_type': weightsFloatType, + 'max_seq_len': config['max_position_embeddings'], + 'vocab_size': config['vocab_size'], + 'files': files, + } + + nExperts = config.get('num_local_experts') + nActiveExperts = config.get('num_active_local_experts') or config.get('num_experts_per_tok') + result['n_experts'] = int(nExperts) if nExperts is not None else 0 + result['n_active_experts'] = int(nActiveExperts) if nActiveExperts is not None else 0 + + ropeTheta = config.get('rope_theta') + if (ropeTheta is not None): + result['rope_theta'] = int(ropeTheta) + return result + +def printUsage(): + print('Usage: python convert-hf.py ') + print() + print('Options:') + print(' The path to the folder containing the model files') + print(' The float type of the weights (e.g. "q40")') + print(' The name of the model (e.g. 
"llama3")') + +if __name__ == '__main__': + if (len(sys.argv) < 4): + printUsage() + exit(1) + + sourceFolderPath = sys.argv[1] + weightsFloatType = parseFloatType(sys.argv[2]) + name = sys.argv[3] + outputFileName = f'dllama_model_{name}_{sys.argv[2]}.m' + + print(f'Output file: {outputFileName}') + + config = loadConfig(sourceFolderPath, weightsFloatType) + + with open(outputFileName, 'wb') as outputFile: + writeHeader(outputFile, config) + processor = Processor(config) + processor.write(outputFile) + + print(f'✅ {outputFileName} created successfully') \ No newline at end of file diff --git a/converter/convert-llama.py b/converter/convert-llama.py index 30779054..75405837 100644 --- a/converter/convert-llama.py +++ b/converter/convert-llama.py @@ -4,7 +4,7 @@ import torch import math import numpy as np -from writer import writeTensor, writeHeader, isFloatTypeSupported +from writer import writeTensor, writeHeader, parseFloatType, FloatType from pathlib import Path LAYER_CHUNK_SIZE = 48 @@ -81,7 +81,7 @@ def convert(modelPath, outputPath, targetFloatType): layerName.endswith('.ffn_norm.weight') or layerName == 'norm.weight' ) - floatType = 'f32' if isAlwaysF32 else targetFloatType + floatType = FloatType.F32 if isAlwaysF32 else targetFloatType tensors = models[layerName] if len(tensors) == 1 or len(tensors[0].shape) == 1: @@ -105,13 +105,10 @@ def usage(): usage() modelPath = sys.argv[1] - targetFloatType = sys.argv[2] - - if (not modelPath or not isFloatTypeSupported(targetFloatType)): - usage() + targetFloatType = parseFloatType(sys.argv[2]) modelName = modelPath.split('/')[-1] - outputFileName = f'dllama_{modelName.lower()}_{targetFloatType}.bin' + outputFileName = f'dllama_model_{modelName.lower()}_{targetFloatType}.m' print(f'Model name: {modelName}') print(f'Target float type: {targetFloatType}') diff --git a/converter/convert-mixtral.py b/converter/convert-mixtral.py deleted file mode 100644 index fe529523..00000000 --- a/converter/convert-mixtral.py +++ /dev/null @@ -1,141 +0,0 @@ -import gc -import sys -import os -from writer import isFloatTypeSupported, writeTensor, writeHeader -from safetensors import safe_open - -# Model: https://huggingface.co/mistral-community/Mixtral-8x22B-v0.1 - -currentFileIndex = None -nFiles = None -model = None -layerMap = {} -folderPath = None - -def unloadModel(): - global model - if model: - del model - model = None - gc.collect() - -def loadModel(index): - global currentFileIndex - global nFiles - global model - global layerMap - global folderPath - if (currentFileIndex == index): - return - unloadModel() - fileName = f'model-000{str(index).zfill(2)}-of-000{nFiles}.safetensors'# + '?download=true' - filePath = os.path.join(folderPath, fileName) - print(f'💿 Loading file {fileName}...') - model = safe_open(filePath, framework='pt', device='cpu') - layerNames = list(model.keys()) - for layerName in layerNames: - layerMap[layerName] = index - print(f'Found layers: {layerNames}') - currentFileIndex = index - -def writeLayer(outFile, layerName, targetFloatType): - global currentFileIndex - global model - global layerMap - - if (not layerName in model.keys()): - if (layerName in layerMap): - loadModel(layerMap[layerName]) - else: - loadModel(currentFileIndex + 1) - if (not layerName in model.keys()): - raise Exception(f'Cannot load {layerName}') - - tensor = model.get_tensor(layerName) - print(f'🔶 Writing tensor {layerName} {tensor.shape}...') - writeTensor(outFile, tensor, targetFloatType) - -def getParams(modelName): - params = { - 'arch_type': 0xABCD02, 
- 'vocab_size': 32000, - 'n_experts': 8, - 'n_active_experts': 2, - 'hidden_act': 1, # silu - 'rope_theta': 1000000, - } - if (modelName == '8x7B'): - params['dim'] = 4096 - params['hidden_dim'] = 14336 - params['n_layers'] = 32 - params['n_heads'] = 32 - params['n_kv_heads'] = 8 - params['max_seq_len'] = 32768 - params['n_files'] = 19 - elif (modelName == '8x22B'): - params['dim'] = 6144 - params['hidden_dim'] = 16384 - params['n_layers'] = 56 - params['n_heads'] = 48 - params['n_kv_heads'] = 8 - params['max_seq_len'] = 65536 - params['n_files'] = 59 - else: - raise Exception(f'Unknown model {modelName}') - return params - -def convert(modelName, targetFloatType, outputFileName): - global nFiles - params = getParams(modelName) - nFiles = params['n_files'] - - outFile = open(outputFileName, 'wb') - writeHeader(outFile, params) - - loadModel(1) - - writeLayer(outFile, 'model.embed_tokens.weight', 'f32') - - for index in range(0, params['n_layers']): - writeLayer(outFile, f'model.layers.{index}.self_attn.q_proj.weight', targetFloatType) - writeLayer(outFile, f'model.layers.{index}.self_attn.k_proj.weight', targetFloatType) - writeLayer(outFile, f'model.layers.{index}.self_attn.v_proj.weight', targetFloatType) - writeLayer(outFile, f'model.layers.{index}.self_attn.o_proj.weight', targetFloatType) - - writeLayer(outFile, f'model.layers.{index}.block_sparse_moe.gate.weight', targetFloatType) - for e in range(params['n_experts']): - writeLayer(outFile, f'model.layers.{index}.block_sparse_moe.experts.{e}.w3.weight', targetFloatType) # up - writeLayer(outFile, f'model.layers.{index}.block_sparse_moe.experts.{e}.w1.weight', targetFloatType) # gate - writeLayer(outFile, f'model.layers.{index}.block_sparse_moe.experts.{e}.w2.weight', targetFloatType) # down - - writeLayer(outFile, f'model.layers.{index}.input_layernorm.weight', 'f32') - writeLayer(outFile, f'model.layers.{index}.post_attention_layernorm.weight', 'f32') - - loadModel(nFiles) - - writeLayer(outFile, 'model.norm.weight', 'f32') - writeLayer(outFile, 'lm_head.weight', targetFloatType) - - unloadModel() - - outFile.close() - print(f'Converted {outputFileName}') - -def usage(): - print('Usage: python convert-mixtral.py ') - exit(1) - -if __name__ == '__main__': - if (len(sys.argv) < 4): - usage() - - modelName = sys.argv[1] - folderPath = sys.argv[2] - targetFloatType = sys.argv[3] - outputFileName = f'dllama_mixtral_{modelName.lower()}-{targetFloatType}.bin' - - if not isFloatTypeSupported(targetFloatType): - print('Float type is not supported') - exit(1) - - convert(modelName, targetFloatType, outputFileName) diff --git a/converter/convert-tokenizer-llama3.py b/converter/convert-tokenizer-llama3.py index 04936a2b..c0284934 100644 --- a/converter/convert-tokenizer-llama3.py +++ b/converter/convert-tokenizer-llama3.py @@ -37,7 +37,7 @@ modelPath = sys.argv[1] with open(modelPath, 'r') as inputFile: - with open('dllama-llama3-tokenizer.t', 'wb') as outputFile: + with open('dllama_tokenizer_llama3.t', 'wb') as outputFile: inputLines = inputFile.readlines() nLines = len(inputLines) diff --git a/converter/convert-tokenizer-sentencepiece.py b/converter/convert-tokenizer-sentencepiece.py index a4167849..aaf5f6ec 100644 --- a/converter/convert-tokenizer-sentencepiece.py +++ b/converter/convert-tokenizer-sentencepiece.py @@ -5,10 +5,11 @@ from sentencepiece import SentencePieceProcessor class Tokenizer: - def __init__(self, model_path): + def __init__(self, model_path, model_name): assert os.path.isfile(model_path), model_path self.sp_model = 
SentencePieceProcessor(model_file=model_path) self.model_path = model_path + self.model_name = model_name # BOS / EOS token IDs self.n_words: int = self.sp_model.vocab_size() @@ -53,7 +54,7 @@ def export(self): # write to a binary file # the tokenizer.bin file is the same as .model file, but .bin - outputPath = 'dllama-' + os.path.basename(self.model_path).replace('.model', '.t') + outputPath = 'dllama_tokenizer_' + self.model_name + '.t' with open(outputPath, 'wb') as f: f.write(struct.pack('IIIiii', 0x567123, @@ -69,10 +70,17 @@ def export(self): f.write(bytes) print(f'Created {outputPath}') -if __name__ == "__main__": - if (len(sys.argv) < 2): - print('Invalid usage') +def printUsage(): + print('Usage: python convert-tokenizer-sentencepiece.py ') + print() + print('Options:') + print(' The path to the SentencePiece model file (.model)') + print(' The name of the tokenizer (e.g. "llama3")') + +if __name__ == '__main__': + if (len(sys.argv) < 3): + printUsage() exit(1) - t = Tokenizer(sys.argv[1]) + t = Tokenizer(sys.argv[1], sys.argv[2]) t.export() diff --git a/converter/writer.py b/converter/writer.py index 9dbc0b9d..256eef5f 100644 --- a/converter/writer.py +++ b/converter/writer.py @@ -3,8 +3,25 @@ import time import numpy as np -def isFloatTypeSupported(type): - return type in ['f16', 'f32', 'q40', 'q80'] +class FloatType: + F32 = 0 + F16 = 1 + Q40 = 2 + Q80 = 3 + +floatTypeMap = { + 'f32': FloatType.F32, + 'f16': FloatType.F16, + 'q40': FloatType.Q40, + 'q80': FloatType.Q80, +} +floatTypeNames = list(floatTypeMap.keys()) + +def parseFloatType(type): + floatType = floatTypeMap.get(type) + if floatType is not None: + return floatType + raise Exception(f'{type} is not supported') def writeQuantizedQ40Tensor(file, x): x = x.to(torch.float32).numpy().astype(np.float32) @@ -77,18 +94,18 @@ def writeTensor(file, tensor, floatType): d = tensor.detach().cpu().view(-1) t0 = time.time() nBytes = 0 - if (floatType == 'f16'): + if (floatType == FloatType.F16): nBytes = writeF16Tensor(file, d) - elif (floatType == 'f32'): + elif (floatType == FloatType.F32): nBytes = writeF32Tensor(file, d) - elif (floatType == 'q40'): + elif (floatType == FloatType.Q40): nBytes = writeQuantizedQ40Tensor(file, d) - elif (floatType == 'q80'): + elif (floatType == FloatType.Q80): nBytes = writeQuantizedQ80Tensor(file, d) else: - raise Exception('Unknown float type') + raise Exception(f'Unknown float type') t1 = time.time() - print(f'Saved {floatType} tensor in {t1 - t0:.2f}s, {nBytes} bytes') + print(f'Saved {floatTypeNames[floatType]} tensor in {t1 - t0:.2f}s, {nBytes} bytes') def writeHeader(file, params): headerKeys = { @@ -105,6 +122,7 @@ def writeHeader(file, params): 'max_seq_len': 10, 'hidden_act': 11, 'rope_theta': 12, + 'weights_float_type': 13 } header = struct.pack('i', 0xA00ABCD) diff --git a/src/app.cpp b/src/app.cpp index 7ea4cd58..ff1b59dc 100644 --- a/src/app.cpp +++ b/src/app.cpp @@ -92,7 +92,7 @@ AppArgs AppArgs::parse(int argc, char** argv, bool hasMode) { } TransformerArch TransformerArchFactory::create(TransformerSpec* spec) { - if (spec->archType == LLAMA2) return buildLlama2Arch(spec); + if (spec->archType == LLAMA) return buildLlamaArch(spec); if (spec->archType == GROK1) return buildGrok1Arch(spec); if (spec->archType == MIXTRAL) return buildMixtralArch(spec); printf("Unsupported arch type: %d\n", spec->archType); diff --git a/src/grok1-tasks.cpp b/src/grok1-tasks.cpp index 0c68b74e..8c57041c 100644 --- a/src/grok1-tasks.cpp +++ b/src/grok1-tasks.cpp @@ -154,6 +154,8 @@ void 
grokMoeBlock1(TASK_ARGS) {
         silu(expertGate, block->moeUpAndGate0Slice->d0, nThreads, threadIndex);
     } else if (spec->hiddenAct == GELU) {
         gelu(expertGate, block->moeUpAndGate0Slice->d0, nThreads, threadIndex);
+    } else {
+        assert(false);
     }
     mul(expertUp, expertGate, block->moeUpAndGate0Slice->d0, nThreads, threadIndex);
 }
diff --git a/src/llama2-tasks-test.cpp b/src/llama2-tasks-test.cpp
index c750254b..ae96be4e 100644
--- a/src/llama2-tasks-test.cpp
+++ b/src/llama2-tasks-test.cpp
@@ -527,7 +527,7 @@ float expectedOutput[4096] = {
 int main() {
     TransformerSpec spec;
     spec.headerSize = sizeof(TransformerFileOldHeader) + sizeof(int);
-    spec.archType = LLAMA2;
+    spec.archType = LLAMA;
     spec.dim = 4096;
     spec.nLayers = 1;
     spec.headSize = 128;
@@ -542,7 +542,7 @@ int main() {
     spec.weightsFloatType = F32;
     spec.bufferFloatType = F32;
     spec.nSlices = 1;
-    spec.hiddenAct = GELU;
+    spec.hiddenAct = SILU;
     spec.ropeTheta = 10000.0f;
 
     size_t beforeBlockBytes = /* embedding */ 524288000;
@@ -568,7 +568,7 @@ int main() {
     float* x = transformer.x;
     for (int i = 0; i < spec.dim; i++) x[i] = randomF32(&state) / 120.0;
 
-    TransformerArch arch = buildLlama2Arch(&spec);
+    TransformerArch arch = buildLlamaArch(&spec);
 
     int nThreads = 4;
     TransformerContext context;
diff --git a/src/llama2-tasks.cpp b/src/llama2-tasks.cpp
index f5e155ce..5d0c13dc 100644
--- a/src/llama2-tasks.cpp
+++ b/src/llama2-tasks.cpp
@@ -164,7 +164,13 @@ void llamaFfn(TASK_ARGS) {
     matmul(spec->weightsFloatType, spec->bufferFloatType, hb0, xb, block->w10, block->w10Slice->n, block->w10Slice->d0, nThreads, threadIndex);
     matmul(spec->weightsFloatType, spec->bufferFloatType, block->hb20, xb, block->w30, block->w30Slice->n, block->w30Slice->d0, nThreads, threadIndex);
 
-    silu(hb0, block->w10Slice->d0, nThreads, threadIndex);
+    if (spec->hiddenAct == SILU) {
+        silu(hb0, block->w10Slice->d0, nThreads, threadIndex);
+    } else if (spec->hiddenAct == GELU) {
+        gelu(hb0, block->w10Slice->d0, nThreads, threadIndex);
+    } else {
+        assert(false);
+    }
     mul(hb0, block->hb20, block->w10Slice->d0, nThreads, threadIndex);
 }
 
@@ -242,7 +248,7 @@ void llamaFinalize(TASK_ARGS) {
     matmul(spec->weightsFloatType, F32, transformer->logits, x, transformer->wcls, spec->dim, spec->vocabSize, nThreads, threadIndex);
 }
 
-TransformerArch buildLlama2Arch(TransformerSpec* spec) {
+TransformerArch buildLlamaArch(TransformerSpec* spec) {
     TransformerArch a;
 
     // inference
diff --git a/src/llama2-tasks.hpp b/src/llama2-tasks.hpp
index 8c06061a..2d451ea7 100644
--- a/src/llama2-tasks.hpp
+++ b/src/llama2-tasks.hpp
@@ -23,6 +23,6 @@ void llamaRmsFinal(TASK_ARGS);
 void llamaRmsFinalNorm(TASK_ARGS);
 void llamaFinalize(TASK_ARGS);
 
-TransformerArch buildLlama2Arch(TransformerSpec* spec);
+TransformerArch buildLlamaArch(TransformerSpec* spec);
 
 #endif
\ No newline at end of file
diff --git a/src/transformer-test.cpp b/src/transformer-test.cpp
index 7cf4653f..ef5dc5c9 100644
--- a/src/transformer-test.cpp
+++ b/src/transformer-test.cpp
@@ -30,7 +30,7 @@ void testRopeSlice(const TransformerArchType archType, const int nSliceTests, co
     for (uint8_t sliceIndex = 0; sliceIndex < spec.nSlices; sliceIndex++) {
         RopeSlice* slice;
-        if (archType == LLAMA2) {
+        if (archType == LLAMA) {
             slice = new LlamaRopeSlice(&spec, sliceIndex);
         } else if (archType == MIXTRAL) {
             slice = new FalconRopeSlice(&spec, sliceIndex);
@@ -80,6 +80,6 @@ void testRopeSlice(const TransformerArchType archType, const int nSliceTests, co
 int main() {
     testRopeSlice(MIXTRAL, 4, 6, 3);
-    testRopeSlice(LLAMA2, 6, 4, 3);
+    testRopeSlice(LLAMA, 6, 4, 
3); return 0; } diff --git a/src/transformer.cpp b/src/transformer.cpp index a93b1875..0619adf1 100644 --- a/src/transformer.cpp +++ b/src/transformer.cpp @@ -186,7 +186,7 @@ MultiHeadAttSlice::~MultiHeadAttSlice() { TransformerSpec Transformer::loadSpecFromFile(const char* path, const unsigned int nSlices, FloatType weightsFloatType, FloatType bufferFloatType) { TransformerSpec spec; memset(&spec, 0, sizeof(TransformerSpec)); - spec.hiddenAct = GELU; + spec.hiddenAct = SILU; spec.ropeTheta = 10000.0f; FILE* fd = fopen(path, "rb"); @@ -239,6 +239,7 @@ TransformerSpec Transformer::loadSpecFromFile(const char* path, const unsigned i else if (key == SEQ_LEN) spec.seqLen = value; else if (key == HIDDEN_ACT) spec.hiddenAct = (TransformerHiddenAct)value; else if (key == ROPE_THETA) spec.ropeTheta = (float)value; + else if (key == WEIGHTS_FLOAT_TYPE) { /* TODO */} else { throw std::runtime_error("Unsupported header key"); } @@ -253,8 +254,8 @@ TransformerSpec Transformer::loadSpecFromFile(const char* path, const unsigned i spec.bufferFloatType = bufferFloatType; spec.nSlices = nSlices; - if (spec.archType == LLAMA2) { - printf("💡 arch: llama2\n"); + if (spec.archType == LLAMA) { + printf("💡 arch: llama\n"); } else if (spec.archType == GROK1) { printf("💡 arch: grok1\n"); } else if (spec.archType == MIXTRAL) { @@ -262,6 +263,13 @@ TransformerSpec Transformer::loadSpecFromFile(const char* path, const unsigned i } else { throw std::runtime_error("Unsupported architecture"); } + if (spec.hiddenAct == GELU) { + printf("💡 hiddenAct: gelu\n"); + } else if (spec.hiddenAct == SILU) { + printf("💡 hiddenAct: silu\n"); + } else { + throw std::runtime_error("Unsupported hidden activation"); + } printf("💡 dim: %d\n", spec.dim); printf("💡 hiddenDim: %d\n", spec.hiddenDim); printf("💡 nLayers: %d\n", spec.nLayers); diff --git a/src/transformer.hpp b/src/transformer.hpp index 01e53823..2ab2cff2 100644 --- a/src/transformer.hpp +++ b/src/transformer.hpp @@ -53,6 +53,7 @@ enum TransformerHeaderKey { SEQ_LEN = 10, HIDDEN_ACT = 11, ROPE_THETA = 12, + WEIGHTS_FLOAT_TYPE = 13, }; struct TransformerFileOldHeader { @@ -68,7 +69,7 @@ struct TransformerFileOldHeader { }; enum TransformerArchType { - LLAMA2 = 0xABCD00, + LLAMA = 0xABCD00, GROK1 = 0xABCD01, MIXTRAL = 0xABCD02 };
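
Example usage (illustrative, not part of the diff): convert-hf.py takes the source model folder, the weights float type, and a model name as positional arguments and writes dllama_model_{name}_{float-type}.m, while convert-tokenizer-sentencepiece.py takes the tokenizer .model path and a name and writes dllama_tokenizer_{name}.t. The model paths below are placeholders.

    # sketch: convert safetensors weights (supported float types: f32, f16, q40, q80)
    python converter/convert-hf.py /path/to/Mistral-7B-v0.1 q40 mistral
    # -> dllama_model_mistral_q40.m

    # sketch: convert the matching SentencePiece tokenizer
    python converter/convert-tokenizer-sentencepiece.py /path/to/tokenizer.model mistral
    # -> dllama_tokenizer_mistral.t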