Skip to content

Commit

Permalink
feat: speed up synchronization of mlp. (#64)
Browse files (browse the repository at this point in the history)
  • Loading branch information
b4rtaz authored May 25, 2024
1 parent 2e523f6 commit b4b3842
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 33 deletions.
50 changes: 19 additions & 31 deletions src/llama2-tasks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,12 @@ void llamaQuantizeRmfFfn(TASK_ARGS) {
quantizeUnitBuffer(nThreads, threadIndex, ctx, TB_UNIT_XB, TB_UNIT_XB_QUANTIZED);
}

- void llamaSyncRmfFfn(TASK_ARGS) {
+ void llamaSyncFfn(TASK_ARGS) {
TASK_VARIABLES;
syncUnitBuffer(nThreads, threadIndex, ctx, TB_UNIT_XB_QUANTIZED);
}

- void llamaFfn(TASK_ARGS) {
+ void llamaFfn0(TASK_ARGS) {
TASK_VARIABLES;

float* xb = (float*)transformer->buffer->getUnit(TB_UNIT_XB_QUANTIZED);
Expand All @@ -174,49 +174,41 @@ void llamaFfn(TASK_ARGS) {
mul(hb0, block->hb20, block->w10Slice->d0, nThreads, threadIndex);
}

- void llamaQuantizeFfnA(TASK_ARGS) {
+ void llamaFfn1(TASK_ARGS) {
TASK_VARIABLES;
quantizeSlicedBuffer(nThreads, threadIndex, ctx, true, TB_SLICED_HB, TB_SLICED_HB_QUANTIZED);
}

- void llamaSyncFfnA(TASK_ARGS) {
- TASK_VARIABLES;
- syncSliceOfSlicedBuffer(nThreads, threadIndex, ctx, TB_SLICED_HB_QUANTIZED);
- }
-
- void llamaSyncFfnB(TASK_ARGS) {
- TASK_VARIABLES;
- syncMissingSlicesOfSlicedBuffer(nThreads, threadIndex, ctx, TB_SLICED_HB_QUANTIZED);
- }
-
void llamaFfn2(TASK_ARGS) {
TASK_VARIABLES;

- float *hb = (float*)transformer->buffer->getUnit(TB_SLICED_HB_QUANTIZED);
- float *xb2 = (float*)transformer->buffer->getSliced(TB_SLICED_XB2, transformer->sliceIndex);
+ float *hb = (float*)transformer->buffer->getSliced(TB_SLICED_HB_QUANTIZED, transformer->sliceIndex);
+ float *xbv = (float*)transformer->buffer->getSliced(TB_SLICED_XBV, transformer->sliceIndex);

- matmul(spec->weightsFloatType, spec->bufferFloatType, xb2, hb, block->w20, block->w20Slice->n, block->w20Slice->d0, nThreads, threadIndex);
+ matmul(spec->weightsFloatType, spec->bufferFloatType, xbv, hb, block->w20, block->w20Slice->n0, block->w20Slice->d, nThreads, threadIndex);
}

void llamaQuantizeFfn2(TASK_ARGS) {
TASK_VARIABLES;
- quantizeSlicedBuffer(nThreads, threadIndex, ctx, false, TB_SLICED_XB2, TB_SLICED_XB2_QUANTIZED);
+ quantizeSlicedBuffer(nThreads, threadIndex, ctx, false, TB_SLICED_XBV, TB_SLICED_XBV_QUANTIZED);
}

void llamaSyncFfn2(TASK_ARGS) {
TASK_VARIABLES;
- syncSliceOfSlicedBuffer(nThreads, threadIndex, ctx, TB_SLICED_XB2_QUANTIZED);
+ syncSliceOfSlicedBuffer(nThreads, threadIndex, ctx, TB_SLICED_XBV_QUANTIZED);
}

void llamaDequantizeFfn2(TASK_ARGS) {
TASK_VARIABLES;
- dequantizeSlicedBuffer(nThreads, threadIndex, ctx, false, TB_SLICED_XB2_QUANTIZED, TB_SLICED_XB2);
+ dequantizeSlicedBuffer(nThreads, threadIndex, ctx, false, TB_SLICED_XBV_QUANTIZED, TB_SLICED_XBV);
}

void llamaMergeFfn2(TASK_ARGS) {
TASK_VARIABLES;
- float* xb2 = (float*)transformer->buffer->getUnit(TB_SLICED_XB2);
- add(transformer->x, xb2, spec->dim, nThreads, threadIndex);
+ for (uint8_t sliceIndex = 0; sliceIndex < spec->nSlices; sliceIndex++) {
+ float* xbv = (float*)transformer->buffer->getSliced(TB_SLICED_XBV, sliceIndex);
+ add(transformer->x, xbv, spec->dim, nThreads, threadIndex);
+ }
}

void llamaNextBlock(TASK_ARGS) {
Expand Down Expand Up @@ -271,11 +263,9 @@ TransformerArch buildLlamaArch(TransformerSpec* spec) {
a.I(llamaRmfFfn, TASK_TYPE_INFERENCE);
a.I(llamaRmfFfnNorm, TASK_TYPE_INFERENCE);
a.I(llamaQuantizeRmfFfn, TASK_TYPE_INFERENCE);
- a.I(llamaSyncRmfFfn, TASK_TYPE_TRANSFER);
- a.I(llamaFfn, TASK_TYPE_INFERENCE);
- a.I(llamaQuantizeFfnA, TASK_TYPE_INFERENCE);
- a.I(llamaSyncFfnA, TASK_TYPE_TRANSFER);
- a.I(llamaSyncFfnB, TASK_TYPE_TRANSFER);
+ a.I(llamaSyncFfn, TASK_TYPE_TRANSFER);
+ a.I(llamaFfn0, TASK_TYPE_INFERENCE);
+ a.I(llamaFfn1, TASK_TYPE_INFERENCE);
a.I(llamaFfn2, TASK_TYPE_INFERENCE);
a.I(llamaQuantizeFfn2, TASK_TYPE_INFERENCE);
a.I(llamaSyncFfn2, TASK_TYPE_TRANSFER);
Expand All @@ -298,11 +288,9 @@ TransformerArch buildLlamaArch(TransformerSpec* spec) {
a.W(llamaAtt, TASK_TYPE_INFERENCE);
a.W(llamaQuantizeAtt, TASK_TYPE_INFERENCE);
a.W(llamaSyncAtt, TASK_TYPE_TRANSFER);
- a.W(llamaSyncRmfFfn, TASK_TYPE_TRANSFER);
- a.W(llamaFfn, TASK_TYPE_INFERENCE);
- a.W(llamaQuantizeFfnA, TASK_TYPE_INFERENCE);
- a.W(llamaSyncFfnA, TASK_TYPE_TRANSFER);
- a.W(llamaSyncFfnB, TASK_TYPE_TRANSFER);
+ a.W(llamaSyncFfn, TASK_TYPE_TRANSFER);
+ a.W(llamaFfn0, TASK_TYPE_INFERENCE);
+ a.W(llamaFfn1, TASK_TYPE_INFERENCE);
a.W(llamaFfn2, TASK_TYPE_INFERENCE);
a.W(llamaQuantizeFfn2, TASK_TYPE_INFERENCE);
a.W(llamaSyncFfn2, TASK_TYPE_TRANSFER);
Expand Down
2 changes: 1 addition & 1 deletion src/transformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ TransformerBlock::TransformerBlock(TransformerSpec* spec, uint8_t sliceIndex) {
expertDown = (float*)NEW_BUFFER(moeDown0Slice->d0 * (spec->nExperts - 1) * sizeof(float));
} else {
w10Slice = new RowMatmulSlice(spec->weightsFloatType, spec->nSlices, spec->dim, spec->hiddenDim);
- w20Slice = new RowMatmulSlice(spec->weightsFloatType, spec->nSlices, spec->hiddenDim, spec->dim);
+ w20Slice = new ColMatmulSlice(spec->weightsFloatType, spec->nSlices, spec->hiddenDim, spec->dim);
w30Slice = new RowMatmulSlice(spec->weightsFloatType, spec->nSlices, spec->dim, spec->hiddenDim);

#if ALLOC_WEIGHTS
Expand Down
2 changes: 1 addition & 1 deletion src/transformer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ class TransformerBlock {
char* w10;
RowMatmulSlice* w10Slice;
char* w20;
- RowMatmulSlice* w20Slice;
+ ColMatmulSlice* w20Slice;
char* w30;
RowMatmulSlice* w30Slice;

Expand Down

0 comments on commit b4b3842

Please sign in to comment.