From b4b384250084c3f177509c7b27fd6657de830692 Mon Sep 17 00:00:00 2001 From: Bart Tadych Date: Sat, 25 May 2024 11:24:40 +0200 Subject: [PATCH] feat: speed up synchronization of mlp. (#64) --- src/llama2-tasks.cpp | 50 +++++++++++++++++--------------------------- src/transformer.cpp | 2 +- src/transformer.hpp | 2 +- 3 files changed, 21 insertions(+), 33 deletions(-) diff --git a/src/llama2-tasks.cpp b/src/llama2-tasks.cpp index 5d0c13dc..235552fd 100644 --- a/src/llama2-tasks.cpp +++ b/src/llama2-tasks.cpp @@ -150,12 +150,12 @@ void llamaQuantizeRmfFfn(TASK_ARGS) { quantizeUnitBuffer(nThreads, threadIndex, ctx, TB_UNIT_XB, TB_UNIT_XB_QUANTIZED); } -void llamaSyncRmfFfn(TASK_ARGS) { +void llamaSyncFfn(TASK_ARGS) { TASK_VARIABLES; syncUnitBuffer(nThreads, threadIndex, ctx, TB_UNIT_XB_QUANTIZED); } -void llamaFfn(TASK_ARGS) { +void llamaFfn0(TASK_ARGS) { TASK_VARIABLES; float* xb = (float*)transformer->buffer->getUnit(TB_UNIT_XB_QUANTIZED); @@ -174,49 +174,41 @@ void llamaFfn(TASK_ARGS) { mul(hb0, block->hb20, block->w10Slice->d0, nThreads, threadIndex); } -void llamaQuantizeFfnA(TASK_ARGS) { +void llamaFfn1(TASK_ARGS) { TASK_VARIABLES; quantizeSlicedBuffer(nThreads, threadIndex, ctx, true, TB_SLICED_HB, TB_SLICED_HB_QUANTIZED); } -void llamaSyncFfnA(TASK_ARGS) { - TASK_VARIABLES; - syncSliceOfSlicedBuffer(nThreads, threadIndex, ctx, TB_SLICED_HB_QUANTIZED); -} - -void llamaSyncFfnB(TASK_ARGS) { - TASK_VARIABLES; - syncMissingSlicesOfSlicedBuffer(nThreads, threadIndex, ctx, TB_SLICED_HB_QUANTIZED); -} - void llamaFfn2(TASK_ARGS) { TASK_VARIABLES; - float *hb = (float*)transformer->buffer->getUnit(TB_SLICED_HB_QUANTIZED); - float *xb2 = (float*)transformer->buffer->getSliced(TB_SLICED_XB2, transformer->sliceIndex); + float *hb = (float*)transformer->buffer->getSliced(TB_SLICED_HB_QUANTIZED, transformer->sliceIndex); + float *xbv = (float*)transformer->buffer->getSliced(TB_SLICED_XBV, transformer->sliceIndex); - matmul(spec->weightsFloatType, spec->bufferFloatType, xb2, hb, block->w20, block->w20Slice->n, block->w20Slice->d0, nThreads, threadIndex); + matmul(spec->weightsFloatType, spec->bufferFloatType, xbv, hb, block->w20, block->w20Slice->n0, block->w20Slice->d, nThreads, threadIndex); } void llamaQuantizeFfn2(TASK_ARGS) { TASK_VARIABLES; - quantizeSlicedBuffer(nThreads, threadIndex, ctx, false, TB_SLICED_XB2, TB_SLICED_XB2_QUANTIZED); + quantizeSlicedBuffer(nThreads, threadIndex, ctx, false, TB_SLICED_XBV, TB_SLICED_XBV_QUANTIZED); } void llamaSyncFfn2(TASK_ARGS) { TASK_VARIABLES; - syncSliceOfSlicedBuffer(nThreads, threadIndex, ctx, TB_SLICED_XB2_QUANTIZED); + syncSliceOfSlicedBuffer(nThreads, threadIndex, ctx, TB_SLICED_XBV_QUANTIZED); } void llamaDequantizeFfn2(TASK_ARGS) { TASK_VARIABLES; - dequantizeSlicedBuffer(nThreads, threadIndex, ctx, false, TB_SLICED_XB2_QUANTIZED, TB_SLICED_XB2); + dequantizeSlicedBuffer(nThreads, threadIndex, ctx, false, TB_SLICED_XBV_QUANTIZED, TB_SLICED_XBV); } void llamaMergeFfn2(TASK_ARGS) { TASK_VARIABLES; - float* xb2 = (float*)transformer->buffer->getUnit(TB_SLICED_XB2); - add(transformer->x, xb2, spec->dim, nThreads, threadIndex); + for (uint8_t sliceIndex = 0; sliceIndex < spec->nSlices; sliceIndex++) { + float* xbv = (float*)transformer->buffer->getSliced(TB_SLICED_XBV, sliceIndex); + add(transformer->x, xbv, spec->dim, nThreads, threadIndex); + } } void llamaNextBlock(TASK_ARGS) { @@ -271,11 +263,9 @@ TransformerArch buildLlamaArch(TransformerSpec* spec) { a.I(llamaRmfFfn, TASK_TYPE_INFERENCE); a.I(llamaRmfFfnNorm, TASK_TYPE_INFERENCE); a.I(llamaQuantizeRmfFfn, TASK_TYPE_INFERENCE); - a.I(llamaSyncRmfFfn, TASK_TYPE_TRANSFER); - a.I(llamaFfn, TASK_TYPE_INFERENCE); - a.I(llamaQuantizeFfnA, TASK_TYPE_INFERENCE); - a.I(llamaSyncFfnA, TASK_TYPE_TRANSFER); - a.I(llamaSyncFfnB, TASK_TYPE_TRANSFER); + a.I(llamaSyncFfn, TASK_TYPE_TRANSFER); + a.I(llamaFfn0, TASK_TYPE_INFERENCE); + a.I(llamaFfn1, TASK_TYPE_INFERENCE); a.I(llamaFfn2, TASK_TYPE_INFERENCE); a.I(llamaQuantizeFfn2, TASK_TYPE_INFERENCE); a.I(llamaSyncFfn2, TASK_TYPE_TRANSFER); @@ -298,11 +288,9 @@ TransformerArch buildLlamaArch(TransformerSpec* spec) { a.W(llamaAtt, TASK_TYPE_INFERENCE); a.W(llamaQuantizeAtt, TASK_TYPE_INFERENCE); a.W(llamaSyncAtt, TASK_TYPE_TRANSFER); - a.W(llamaSyncRmfFfn, TASK_TYPE_TRANSFER); - a.W(llamaFfn, TASK_TYPE_INFERENCE); - a.W(llamaQuantizeFfnA, TASK_TYPE_INFERENCE); - a.W(llamaSyncFfnA, TASK_TYPE_TRANSFER); - a.W(llamaSyncFfnB, TASK_TYPE_TRANSFER); + a.W(llamaSyncFfn, TASK_TYPE_TRANSFER); + a.W(llamaFfn0, TASK_TYPE_INFERENCE); + a.W(llamaFfn1, TASK_TYPE_INFERENCE); a.W(llamaFfn2, TASK_TYPE_INFERENCE); a.W(llamaQuantizeFfn2, TASK_TYPE_INFERENCE); a.W(llamaSyncFfn2, TASK_TYPE_TRANSFER); diff --git a/src/transformer.cpp b/src/transformer.cpp index 0619adf1..7233deaf 100644 --- a/src/transformer.cpp +++ b/src/transformer.cpp @@ -484,7 +484,7 @@ TransformerBlock::TransformerBlock(TransformerSpec* spec, uint8_t sliceIndex) { expertDown = (float*)NEW_BUFFER(moeDown0Slice->d0 * (spec->nExperts - 1) * sizeof(float)); } else { w10Slice = new RowMatmulSlice(spec->weightsFloatType, spec->nSlices, spec->dim, spec->hiddenDim); - w20Slice = new RowMatmulSlice(spec->weightsFloatType, spec->nSlices, spec->hiddenDim, spec->dim); + w20Slice = new ColMatmulSlice(spec->weightsFloatType, spec->nSlices, spec->hiddenDim, spec->dim); w30Slice = new RowMatmulSlice(spec->weightsFloatType, spec->nSlices, spec->dim, spec->hiddenDim); #if ALLOC_WEIGHTS diff --git a/src/transformer.hpp b/src/transformer.hpp index 2ab2cff2..f195754d 100644 --- a/src/transformer.hpp +++ b/src/transformer.hpp @@ -176,7 +176,7 @@ class TransformerBlock { char* w10; RowMatmulSlice* w10Slice; char* w20; - RowMatmulSlice* w20Slice; + ColMatmulSlice* w20Slice; char* w30; RowMatmulSlice* w30Slice;