perf(ggml): tall and skinny GEMM for LoRA: F32 mul_mat([16 X 5120], [16 X 5120]) takes 120ms - 24x slower than expected
#956
See https://arxiv.org/pdf/2208.08088.pdf, a specialized GEMM for tall and skinny matrices. It reports around a 4X speedup over BLAS, but that is not even close to the 24x we need to explain here.
To my understanding, the vectorization in BLAS and ggml happens along the row dimension, so tall and skinny matrices get essentially zero useful vectorization. Correct? @ggerganov I will try to unroll the loops to get the vectorization.
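For illustration, here is a minimal sketch of the dot-product formulation (not ggml's or BLAS's actual code; the function name and layout are hypothetical). The SIMD lanes run along the K dimension, so with K = 16 the vectorized inner loop does only two 8-wide FMAs per output element and then pays a horizontal reduction, which is where the time goes.

```c
// Hypothetical per-row dot-product kernel, for illustration only.
// C[i][j] = sum_k A[i][k] * B[j][k], with A: M x K and B: N x K, row-major.
// Assumes K is a multiple of 8. Compile with e.g. gcc -O2 -mavx2 -mfma.
#include <immintrin.h>

static void matmul_rowdot_f32(const float *A, const float *B, float *C,
                              int M, int N, int K) {
    for (int i = 0; i < M; ++i) {
        for (int j = 0; j < N; ++j) {
            __m256 acc = _mm256_setzero_ps();
            for (int k = 0; k < K; k += 8) {              // only 2 iterations when K == 16
                __m256 a = _mm256_loadu_ps(&A[i * K + k]);
                __m256 b = _mm256_loadu_ps(&B[j * K + k]);
                acc = _mm256_fmadd_ps(a, b, acc);
            }
            // horizontal reduction: paid once per output element
            __m128 lo = _mm256_castps256_ps128(acc);
            __m128 hi = _mm256_extractf128_ps(acc, 1);
            __m128 s  = _mm_add_ps(lo, hi);
            s = _mm_hadd_ps(s, s);
            s = _mm_hadd_ps(s, s);
            C[i * N + j] = _mm_cvtss_f32(s);
        }
    }
}
```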
Where are you seeing dimensions like that? I'm seeing outputs more like this:
The timings you show are from multiplying square-ish matrices with short and wide matrices (vectors). We are multiplying tall and skinny matrices to obtain the delta update from LoRA adapters (low-rank adapters from fine-tuning).
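To make the contrast concrete, a rough sketch of the two shapes (dimensions taken from the issue title and the timing quoted in the issue body; the exact operand order inside ggml's mul_mat may differ):

```c
// Regular inference matmul (the kind shown in typical perf logs):
//   [5120 x 5120] weight   x  [5120 x 8] activations  ->  [5120 x 8] output
// LoRA delta-update matmul (this issue):
//   [5120 x 16] adapter B  x  [16 x 5120] adapter A   ->  [5120 x 5120] delta
// In the LoRA case both operands are tall and skinny: one dimension is the
// adapter rank (16), the other is the model's hidden size (5120).
```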
LOLOL
Oh dang, I thought you were talking about general optimizations relating to the general mat_mul / q_f32. Didn't know it was about LoRA. I've been looking at it all day haha
It's about segfaults. Let me collapse it as it's off-topic.
Actually, we are not even using BLAS here, just the ggml implementation, since these shapes do not fit the BLAS criteria. Hopefully that's enough to account for the difference.
This manages to run about 4X faster.

```c
// A = K x M, BT = B^T stored as K x N (row-major)
void multiply_tall_skinny_matrices(const float * A, const float * BT, float * C, int M, int N, int K, int ir0, int ir1) {
    for (int i = ir0; i < ir1; ++i) {
        for (int j = 0; j < N; j += 8) { // process 8 elements of C's row at a time (256-bit register / 32-bit float)
            __m256 c_row = _mm256_setzero_ps();                 // initialize the result vector to all zeros
            for (int k = 0; k < K; ++k) {
                __m256 a = _mm256_broadcast_ss(&A[i + k * M]);  // broadcast the k-th element of the i-th column of A
                __m256 b = _mm256_load_ps(&BT[j + k * N]);      // load a segment of the k-th row of B^T (the k-th column of B)
                c_row = _mm256_fmadd_ps(a, b, c_row);           // FMA: c_row = a * b + c_row
            }
            // store the result in the corresponding row of C
            _mm256_store_ps(&C[i * N + j], c_row);
        }
    }
}
```

GPT4 almost delivered after all. Luckily I'm still smarter than it.
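For what it's worth, a rough correctness check against a scalar reference could look like the sketch below (my addition, not from the thread). It assumes the kernel above is visible in the same translation unit, that N is a multiple of 8, and that BT and C are 32-byte aligned, since the kernel uses aligned 256-bit loads/stores. Compile with something like gcc -O2 -std=c11 -mavx2 -mfma ... -lm.

```c
// Rough correctness check for multiply_tall_skinny_matrices().
#include <stdio.h>
#include <stdlib.h>
#include <math.h>

int main(void) {
    const int M = 5120, N = 5120, K = 16;
    float *A  = malloc((size_t)K * M * sizeof(float));               // A  stored as K x M
    float *BT = aligned_alloc(32, (size_t)K * N * sizeof(float));    // B^T stored as K x N
    float *C  = aligned_alloc(32, (size_t)M * N * sizeof(float));
    float *R  = malloc((size_t)M * N * sizeof(float));

    for (size_t i = 0; i < (size_t)K * M; ++i) A[i]  = (float)rand() / RAND_MAX;
    for (size_t i = 0; i < (size_t)K * N; ++i) BT[i] = (float)rand() / RAND_MAX;

    multiply_tall_skinny_matrices(A, BT, C, M, N, K, 0, M);

    // scalar reference with the same indexing: R[i][j] = sum_k A[i + k*M] * BT[j + k*N]
    for (int i = 0; i < M; ++i)
        for (int j = 0; j < N; ++j) {
            float s = 0.0f;
            for (int k = 0; k < K; ++k) s += A[i + k * M] * BT[j + k * N];
            R[i * N + j] = s;
        }

    double max_err = 0.0;
    for (size_t i = 0; i < (size_t)M * N; ++i)
        max_err = fmax(max_err, fabs((double)C[i] - R[i]));
    printf("max abs error: %g\n", max_err);

    free(A); free(BT); free(C); free(R);
    return 0;
}
```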
The AVX (128-bit) version is about 3X faster.

```c
void multiply_tall_skinny_matrices_avx(const float * A, const float * BT, float * C, int M, int N, int K, int ir0, int ir1) {
    for (int i = ir0; i < ir1; ++i) {
        for (int j = 0; j < N; j += 4) { // process 4 elements of C's row at a time (128-bit register / 32-bit float)
            __m128 c_row = _mm_setzero_ps();                 // initialize the result vector to all zeros
            for (int k = 0; k < K; ++k) {
                __m128 a = _mm_broadcast_ss(&A[i + k * M]);  // broadcast the k-th element of the i-th column of A
                __m128 b = _mm_loadu_ps(&BT[j + k * N]);     // load a segment of the k-th row of B^T (the k-th column of B)
                c_row = _mm_fmadd_ps(a, b, c_row);           // FMA: c_row = a * b + c_row
            }
            // store the result in the corresponding row of C
            _mm_store_ps(&C[i * N + j], c_row);
        }
    }
}
```
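The win in both kernels appears to be that broadcasting an element of A and streaming rows of B^T makes the SIMD lanes run along N = 5120 instead of K = 16, avoiding the per-element horizontal reduction of the dot-product formulation. The ir0/ir1 arguments look intended for splitting the rows of C across threads; a hedged sketch of such a split (the even partitioning and the ith/nth naming are my assumptions, not ggml's actual scheme):

```c
// Illustrative partitioning of C's rows across nth worker threads; ith is the
// index of the current thread. The even split is an assumption for this sketch.
void multiply_tall_skinny_split(const float * A, const float * BT, float * C,
                                int M, int N, int K, int ith, int nth) {
    const int rows_per_thread = (M + nth - 1) / nth;
    const int ir0 = ith * rows_per_thread;
    const int ir1 = (ir0 + rows_per_thread < M) ? ir0 + rows_per_thread : M;
    if (ir0 < ir1) {
        multiply_tall_skinny_matrices(A, BT, C, M, N, K, ir0, ir1);
    }
}
```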
I hope this can be optimized further to close the gap.
This issue was closed because it has been inactive for 14 days since being marked as stale.
Context
LoRA requires computing the delta update `B A`, where `B` and `A` are tall and skinny (the low-rank adapter matrices).
See #820 for the use-case.
Problem
The estimated FLOPs of each of these matmuls is
16 * 5120 * 5120 ~= 0.419 GFLOPs
My setup can do 360 GFLOPS for f16 mul_mat (OpenBLAS). This would imply a computing time of ~1.2 ms for float16, so float32 should be in the same ballpark.
The model's
mul_mat([5120 X 5120], [5120 X 8])
of FLOP complexity 8 * 5120 * 5120 ~= 0.210 GFLOPs
takes 2.647 ms. So at the very least we should be getting ~5 ms, not 120 ms: we are observing a 24x deviation from the expected compute time.
The above observed time is roughly the same with or without OpenBLAS.
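For completeness, a back-of-the-envelope recomputation of the figures above (the 360 GFLOPS rate and the 2.647 ms timing are the measurements quoted in this issue, not mine):

```c
// Sanity check of the expected-time estimates and the ~24x gap.
#include <stdio.h>

int main(void) {
    const double lora_gflop  = 16.0 * 5120 * 5120 / 1e9;   // ~0.419 GFLOPs per LoRA matmul
    const double model_gflop =  8.0 * 5120 * 5120 / 1e9;   // ~0.210 GFLOPs per model matmul
    const double blas_rate   = 360.0;                      // GFLOPS, f16 OpenBLAS (quoted above)
    const double model_ms    = 2.647;                      // observed time for the model matmul

    const double expected_ms = model_ms * lora_gflop / model_gflop;
    printf("BLAS-based estimate : %.2f ms\n", lora_gflop / blas_rate * 1e3); // ~1.2 ms
    printf("scaled ggml estimate: %.2f ms\n", expected_ms);                  // ~5.3 ms
    printf("observed / expected : %.0fx\n", 120.0 / expected_ms);            // roughly the 24x in the title
    return 0;
}
```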
EDIT: with the optimizations discussed in the comments above, we are now at about 30-35 ms. So we have about 6-7x left to go.