Skip to content

Commit

Permalink
multithread rope.
Browse files Browse the repository at this point in the history
  • Loading branch information
b4rtaz committed Apr 28, 2024
1 parent 9dbfa9f commit 69b4d9e
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 21 deletions.
24 changes: 3 additions & 21 deletions src/llama2-tasks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,27 +78,9 @@ void llamaMultiheadAtt(TASK_ARGS) {

// Task step: applies rotary positional encoding (RoPE) to the sliced query
// buffer and to the key-cache row of the current position.
//
// The rotation is delegated to rope(), which receives nThreads/threadIndex and
// rotates only the range of (even, odd) pairs assigned to the calling thread.
// No thread may additionally rotate the whole vector here: doing so would
// rotate the same pairs twice and race with the slices of the other threads.
//
// NOTE(review): transformer, block, spec, nThreads and threadIndex are
// presumably provided by TASK_ARGS / TASK_VARIABLES - confirm against the
// macro definitions.
void llamaMultiheadAttRope(TASK_ARGS) {
    TASK_VARIABLES;
    float* q = (float*)transformer->buffer->getUnit(TB_SLICED_Q);
    // Key row of the current position inside this block's key cache.
    float* k = block->keyCache + transformer->pos * spec->kvDim;
    rope(transformer->ropeCache, q, k, spec, transformer->pos, nThreads, threadIndex);
}

void llamaMultiheadAttJoin(TASK_ARGS) {
Expand Down
44 changes: 44 additions & 0 deletions src/transformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,40 @@ long MatmulSlice::mergeOutputs(uint8_t sliceIndex, float* output, float* output0
return offset; // offset in floats
}

// Precomputes the RoPE rotation coefficients for every position.
//
// Cache layout: one row per position, spec->dim floats per row. Inside a row
// the element at an even index i holds cos(pos * freq(i)) and the element at
// i + 1 holds sin(pos * freq(i)), where
//   freq(i) = 1 / ropeTheta ^ ((i % headSize) / headSize).
//
// cache must have room for spec->seqLen * spec->dim floats.
// The row stride has to be spec->dim (not spec->seqLen): each row stores dim
// values, so a seqLen stride makes rows overwrite each other when seqLen < dim.
void initRope(float* cache, TransformerSpec* spec) {
    for (int pos = 0; pos < spec->seqLen; pos++) {
        float* row = cache + (size_t)pos * spec->dim;
        for (int i = 0; i < spec->dim; i += 2) {
            int headDim = i % spec->headSize;
            float freq = 1.0f / powf(spec->ropeTheta, headDim / (float)spec->headSize);
            float val = pos * freq;
            row[i] = cosf(val);     // real part of the rotation
            row[i + 1] = sinf(val); // imaginary part of the rotation
        }
    }
}

// Rotates q (and k where applicable) in place for position `pos`, using the
// coefficients prepared by initRope(). The work is split between threads.
//
// cache       - table filled by initRope() (row stride: spec->dim floats)
// q           - query vector, spec->dim floats
// k           - key vector; only indices below spec->kvDim are rotated
// pos         - position whose cache row is used
// nThreads    - number of cooperating threads (must be > 0)
// threadIndex - index of the calling thread, 0 .. nThreads - 1
//
// The spec->dim / 2 (even, odd) pairs are divided into contiguous,
// non-overlapping ranges, one per thread; the last thread also takes the
// remainder when the pair count is not divisible by nThreads. Ranges never
// exceed spec->dim, so no thread reads or writes past the end of q or of the
// cache row.
void rope(float* cache, float* q, float* k, TransformerSpec* spec, int pos, unsigned int nThreads, unsigned int threadIndex) {
    const int nPairs = spec->dim / 2;
    const int pairsPerThread = nPairs / (int)nThreads;
    const int pairStart = (int)threadIndex * pairsPerThread;
    const int pairEnd = (threadIndex == nThreads - 1) ? nPairs : (pairStart + pairsPerThread);
    const int iStart = pairStart * 2;
    const int iEnd = pairEnd * 2;
    const float* row = cache + (size_t)pos * spec->dim;

    // RoPE relative positional encoding: complex-valued rotate q and k in each head
    for (int i = iStart; i < iEnd; i += 2) {
        const float fcr = row[i];
        const float fci = row[i + 1];
        const int rotn = i < spec->kvDim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
        for (int v = 0; v < rotn; v++) {
            float* vec = v == 0 ? q : k; // the vector to rotate (query or key)
            const float v0 = vec[i];
            const float v1 = vec[i + 1];
            vec[i] = v0 * fcr - v1 * fci;
            vec[i + 1] = v0 * fci + v1 * fcr;
        }
    }
}

TransformerSpec Transformer::loadSpecFromFile(const char* path, const unsigned int nSlices, FloatType weightsFloatType, FloatType bufferFloatType) {
TransformerSpec spec;
memset(&spec, 0, sizeof(TransformerSpec));
Expand Down Expand Up @@ -252,6 +286,12 @@ Transformer::Transformer(TransformerSpec* spec, uint8_t sliceIndex) {
#endif
x = (float*)NEW_BUFFER(spec->dim * sizeof(float));
logits = (float*)NEW_BUFFER(spec->vocabSize * sizeof(float));

// TODO: cache should be for all architectures
if (spec->archType == LLAMA2 || spec->archType == MIXTRAL) {
ropeCache = (float*)NEW_BUFFER(spec->vocabSize * spec->dim);
initRope(ropeCache, spec);
}
}
}

Expand All @@ -270,6 +310,10 @@ Transformer::~Transformer() {
#endif
FREE_BUFFER(x);
FREE_BUFFER(logits);

if (spec->archType == LLAMA2 || spec->archType == MIXTRAL) {
FREE_BUFFER(ropeCache);
}
}
}

Expand Down
4 changes: 4 additions & 0 deletions src/transformer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ struct TransformerSpec {
uint8_t nSlices;
};

void initRope(float* cache, TransformerSpec* spec);
void rope(float* cache, float* q, float* k, TransformerSpec* spec, int pos, unsigned int nThreads, unsigned int threadIndex);

class TransformerBlock {
public:
uint8_t sliceIndex;
Expand Down Expand Up @@ -186,6 +189,7 @@ class Transformer {
int pos;
float* x;
float* logits;
float* ropeCache;

~Transformer();

Expand Down

0 comments on commit 69b4d9e

Please sign in to comment.