Skip to content

Commit

Permalink
slight speedup
Browse files Browse the repository at this point in the history
  • Loading branch information
magnetophon committed Dec 31, 2024
1 parent 1e1d188 commit 866073a
Showing 1 changed file with 61 additions and 51 deletions.
112 changes: 61 additions & 51 deletions src/svf_simper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -401,65 +401,75 @@ where
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn filter_avx2(&mut self, input: __m256) -> __m256 {
let v0 = input;
let v3 = _mm256_loadu_ps(self.ic2eq.to_array().as_ptr());

let ic1 = _mm256_loadu_ps(self.ic1eq.to_array().as_ptr());
let a1_ic1 = _mm256_mul_ps(_mm256_loadu_ps(self.a1.to_array().as_ptr()), ic1);
let a2_v3 = _mm256_mul_ps(_mm256_loadu_ps(self.a2.to_array().as_ptr()), v3);
let sum = _mm256_add_ps(a1_ic1, a2_v3);
let v1 = _mm256_fmadd_ps(sum, _mm256_loadu_ps(self.k.to_array().as_ptr()), v0);

let v1k = _mm256_mul_ps(v1, _mm256_loadu_ps(self.k.to_array().as_ptr()));
let v2 = _mm256_fmadd_ps(ic1, _mm256_loadu_ps(self.a2.to_array().as_ptr()), v1k);
#[inline(always)]
unsafe fn filter_sse2(&self, v0: __m128) -> __m128 {
use std::arch::x86_64::*;

// Load constants once (they're the same for all lanes)
let k = _mm_set1_ps(self.k.as_array()[0]);
let a1 = _mm_set1_ps(self.a1.as_array()[0]);
let a2 = _mm_set1_ps(self.a2.as_array()[0]);
let a3 = _mm_set1_ps(self.a3.as_array()[0]);

// Load state
let ic1eq = _mm_loadu_ps(self.ic1eq.as_array().as_ptr());
let ic2eq = _mm_loadu_ps(self.ic2eq.as_array().as_ptr());

// v1 = ic1eq + k * (v0 - ic2eq)
let v1 = _mm_add_ps(ic1eq, _mm_mul_ps(k, _mm_sub_ps(v0, ic2eq)));

// v2 = ic2eq + k * v1
let v2 = _mm_add_ps(ic2eq, _mm_mul_ps(k, v1));

// Update state variables
// ic1eq = a1 * v1 + a2 * v0
_mm_storeu_ps(
self.ic1eq.as_array().as_ptr() as *mut f32,
_mm_add_ps(_mm_mul_ps(a1, v1), _mm_mul_ps(a2, v0)),
);

let two = _mm256_set1_ps(2.0);
self.ic1eq = Simd::from_array({
let mut arr = [0.0; LANES];
_mm256_storeu_ps(arr.as_mut_ptr(), _mm256_sub_ps(_mm256_mul_ps(v1, two), ic1));
arr
});
self.ic2eq = Simd::from_array({
let mut arr = [0.0; LANES];
_mm256_storeu_ps(arr.as_mut_ptr(), _mm256_sub_ps(_mm256_mul_ps(v2, two), v3));
arr
});
// ic2eq = a2 * v1 + a3 * v0
_mm_storeu_ps(
self.ic2eq.as_array().as_ptr() as *mut f32,
_mm_add_ps(_mm_mul_ps(a2, v1), _mm_mul_ps(a3, v0)),
);

v2
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn filter_sse2(&mut self, input: __m128) -> __m128 {
let v0 = input;
let v3 = _mm_loadu_ps(self.ic2eq.to_array().as_ptr());

let ic1 = _mm_loadu_ps(self.ic1eq.to_array().as_ptr());
let a1_ic1 = _mm_mul_ps(_mm_loadu_ps(self.a1.to_array().as_ptr()), ic1);
let a2_v3 = _mm_mul_ps(_mm_loadu_ps(self.a2.to_array().as_ptr()), v3);
let sum = _mm_add_ps(a1_ic1, a2_v3);
let k = _mm_loadu_ps(self.k.to_array().as_ptr());
let v1 = _mm_add_ps(_mm_mul_ps(sum, k), v0);

let v1k = _mm_mul_ps(v1, k);
let v2 = _mm_add_ps(
_mm_mul_ps(ic1, _mm_loadu_ps(self.a2.to_array().as_ptr())),
v1k,
#[inline(always)]
unsafe fn filter_avx2(&self, v0: __m256) -> __m256 {
use std::arch::x86_64::*;

// Load constants once
let k = _mm256_set1_ps(self.k.as_array()[0]);
let a1 = _mm256_set1_ps(self.a1.as_array()[0]);
let a2 = _mm256_set1_ps(self.a2.as_array()[0]);
let a3 = _mm256_set1_ps(self.a3.as_array()[0]);

// Load state
let ic1eq = _mm256_loadu_ps(self.ic1eq.as_array().as_ptr());
let ic2eq = _mm256_loadu_ps(self.ic2eq.as_array().as_ptr());

// v1 = ic1eq + k * (v0 - ic2eq)
let v1 = _mm256_add_ps(ic1eq, _mm256_mul_ps(k, _mm256_sub_ps(v0, ic2eq)));

// v2 = ic2eq + k * v1
let v2 = _mm256_add_ps(ic2eq, _mm256_mul_ps(k, v1));

// Update state variables
// ic1eq = a1 * v1 + a2 * v0
_mm256_storeu_ps(
self.ic1eq.as_array().as_ptr() as *mut f32,
_mm256_add_ps(_mm256_mul_ps(a1, v1), _mm256_mul_ps(a2, v0)),
);

let two = _mm_set1_ps(2.0);
self.ic1eq = Simd::from_array({
let mut arr = [0.0; LANES];
_mm_storeu_ps(arr.as_mut_ptr(), _mm_sub_ps(_mm_mul_ps(v1, two), ic1));
arr
});
self.ic2eq = Simd::from_array({
let mut arr = [0.0; LANES];
_mm_storeu_ps(arr.as_mut_ptr(), _mm_sub_ps(_mm_mul_ps(v2, two), v3));
arr
});
// ic2eq = a2 * v1 + a3 * v0
_mm256_storeu_ps(
self.ic2eq.as_array().as_ptr() as *mut f32,
_mm256_add_ps(_mm256_mul_ps(a2, v1), _mm256_mul_ps(a3, v0)),
);

v2
}
Expand Down

0 comments on commit 866073a

Please # to comment.