From be9929bd1d3aa79f8644d5a24edd2554a306140c Mon Sep 17 00:00:00 2001
From: Bart Tadych
Date: Sat, 18 May 2024 09:34:53 +0200
Subject: [PATCH] feat: use avx2 to speedup matmulF32 (#56)

---
 src/funcs.cpp | 18 +++++++++++++++---
 1 file changed, 15 insertions(+), 3 deletions(-)

diff --git a/src/funcs.cpp b/src/funcs.cpp
index 3e15e4ac..e2785526 100644
--- a/src/funcs.cpp
+++ b/src/funcs.cpp
@@ -152,7 +152,7 @@ struct MatmulThreadInfo {
 void matmulF32(MatmulThreadInfo* a) {
     const float* input = (float*)a->input;
     float* w = (float*)a->weights;
-    int d;
+    unsigned int d, j;
 
 #if defined(__ARM_NEON)
     float32x4_t q;
@@ -160,17 +160,29 @@ void matmulF32(MatmulThreadInfo* a) {
     float32x4_t z;
     for (d = a->ds; d < a->de; d++) {
         z = vmovq_n_f32(0);
-        for (int j = 0; j < a->n; j += 4) {
+        for (j = 0; j < a->n; j += 4) {
             q = vld1q_f32(&input[j]);
             p = vld1q_f32(&w[d * a->n + j]);
             z = vfmaq_f32(z, q, p);
         }
         a->output[d] = vaddvq_f32(z);
     }
+#elif defined(__AVX2__)
+    assert(a->n % 8 == 0);
+    __m256 a0, b0, u;
+    for (d = a->ds; d < a->de; d++) {
+        u = _mm256_set1_ps(0.0f);
+        for (j = 0; j < a->n; j += 8) {
+            a0 = _mm256_loadu_ps(&input[j]);
+            b0 = _mm256_loadu_ps(&w[d * a->n + j]);
+            u = _mm256_fmadd_ps(a0, b0, u);
+        }
+        a->output[d] = hsum_float_8(u);
+    }
 #else
     for (d = a->ds; d < a->de; d++) {
         float val = 0.0f;
-        for (int j = 0; j < a->n; j++) {
+        for (j = 0; j < a->n; j++) {
             val += w[d * a->n + j] * input[j];
         }
         a->output[d] = val;
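
Note: the new #elif branch calls hsum_float_8(), which the patch assumes is
already defined elsewhere in src/funcs.cpp; its job is to reduce the eight
float lanes of the __m256 accumulator to a single scalar. For readers without
the rest of the file, below is a minimal, self-contained sketch of the same
per-row dot-product kernel, using one common formulation of that horizontal
sum (the helper's exact body in the repo may differ). The file name in the
build line is hypothetical.

// build: g++ -O2 -mavx2 -mfma avx2_dot_sketch.cpp
#include <immintrin.h>
#include <cassert>
#include <cstdio>

// One possible hsum_float_8: reduce the 8 lanes of a __m256 to one float.
// (Assumption: the helper the patch calls has this shape.)
static inline float hsum_float_8(const __m256 x) {
    __m128 lo = _mm256_castps256_ps128(x);    // lanes 0..3
    __m128 hi = _mm256_extractf128_ps(x, 1);  // lanes 4..7
    __m128 s  = _mm_add_ps(lo, hi);           // 4 partial sums
    s = _mm_add_ps(s, _mm_movehl_ps(s, s));   // 2 partial sums
    s = _mm_add_ss(s, _mm_movehdup_ps(s));    // final sum in lane 0
    return _mm_cvtss_f32(s);
}

// Same inner loop as the patch: one row of the weight matrix dotted with
// the input, 8 floats per iteration via fused multiply-add.
static float dotF32_avx2(const float* x, const float* y, unsigned int n) {
    assert(n % 8 == 0);
    __m256 acc = _mm256_set1_ps(0.0f);
    for (unsigned int j = 0; j < n; j += 8) {
        __m256 a0 = _mm256_loadu_ps(&x[j]);
        __m256 b0 = _mm256_loadu_ps(&y[j]);
        acc = _mm256_fmadd_ps(a0, b0, acc);   // acc += a0 * b0, per lane
    }
    return hsum_float_8(acc);
}

int main() {
    float x[8] = {1, 2, 3, 4, 5, 6, 7, 8};
    float y[8] = {1, 1, 1, 1, 1, 1, 1, 1};
    printf("%f\n", dotF32_avx2(x, y, 8));     // prints 36.000000
}

One detail worth keeping in mind: _mm256_fmadd_ps is an FMA intrinsic rather
than a plain AVX2 one, so with GCC/Clang the translation unit must be built
with -mavx2 -mfma (or -march=native) for this branch to compile, even though
the code is guarded only by __AVX2__.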