Skip to content

Commit

Permalink
feat: use avx2 to speedup matmulF32 (#56)
Browse files Browse the repository at this point in the history
  • Loading branch information
b4rtaz authored May 18, 2024
1 parent d1304c8 commit be9929b
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions src/funcs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,25 +152,37 @@ struct MatmulThreadInfo {
void matmulF32(MatmulThreadInfo* a) {
const float* input = (float*)a->input;
float* w = (float*)a->weights;
int d;
unsigned int d, j;

#if defined(__ARM_NEON)
float32x4_t q;
float32x4_t p;
float32x4_t z;
for (d = a->ds; d < a->de; d++) {
z = vmovq_n_f32(0);
for (int j = 0; j < a->n; j += 4) {
for (j = 0; j < a->n; j += 4) {
q = vld1q_f32(&input[j]);
p = vld1q_f32(&w[d * a->n + j]);
z = vfmaq_f32(z, q, p);
}
a->output[d] = vaddvq_f32(z);
}
#elif defined(__AVX2__)
assert(a->n % 8 == 0);
__m256 a0, b0, u;
for (d = a->ds; d < a->de; d++) {
u = _mm256_set1_ps(0.0f);
for (j = 0; j < a->n; j += 8) {
a0 = _mm256_loadu_ps(&input[j]);
b0 = _mm256_loadu_ps(&w[d * a->n + j]);
u = _mm256_fmadd_ps(a0, b0, u);
}
a->output[d] = hsum_float_8(u);
}
#else
for (d = a->ds; d < a->de; d++) {
float val = 0.0f;
for (int j = 0; j < a->n; j++) {
for (j = 0; j < a->n; j++) {
val += w[d * a->n + j] * input[j];
}
a->output[d] = val;
Expand Down

0 comments on commit be9929b

Please sign in to comment.