diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml new file mode 100644 index 0000000..1cd8f60 --- /dev/null +++ b/.github/workflows/test.yaml @@ -0,0 +1,51 @@ +name: Test + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + test: + name: argminmax test + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: ['windows-latest', 'macOS-latest', 'ubuntu-latest'] + rust: ['stable', 'beta', 'nightly'] + + steps: + - name: Checkout + uses: actions/checkout@v2 + + - name: Install Rust toolchain + uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: ${{ matrix.rust }} + components: clippy, rustfmt + + - name: Rust toolchain info + run: | + cargo --version --verbose + rustc --version + cargo clippy --version + cargo fmt --version + + - name: Linting + run: | + cargo fmt -- --check + cargo clippy --features half -- -D warnings + + - name: Cache Dependencies + uses: Swatinem/rust-cache@v1 + + - name: Run cargo-tarpaulin + uses: actions-rs/tarpaulin@v0.1 + with: + args: '--features half -- --test-threads 1' + + - name: Upload to codecov.io + uses: codecov/codecov-action@v3 diff --git a/Cargo.toml b/Cargo.toml index 070d9c5..5fd0875 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "argminmax" -version = "0.1.1" +version = "0.2.0" authors = ["Jeroen Van Der Donckt"] edition = "2021" readme = "README.md" @@ -12,13 +12,18 @@ categories = ["algorithms", "mathematics", "science"] [dependencies] -ndarray = "0.15.6" +ndarray = { version = "0.15.6", default-features = false } +half = { version = "2.1.0", default-features = false, optional = true } [dev-dependencies] criterion = "0.3.0" dev_utils = { path = "dev_utils" } +[[bench]] +name = "bench_f16" +harness = false + [[bench]] name = "bench_f32" harness = false diff --git a/README.md b/README.md index 69979b5..bef6690 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,14 @@ # ArgMinMax -> Efficient argmin & argmax (in 1 function) with SIMD (avx2) for `f32`, `f64`, `i16`, `i32`, `i64` on `ndarray::ArrayView1` +> Efficient argmin & argmax (in 1 function) with SIMD (avx2) for `f16`, `f32`, `f64`, `i16`, `i32`, `i64` on `ndarray::ArrayView1` -🚀 The function is generic over the type of the array, so it can be used on an `ndarray::ArrayView1` where `T` can be `f32`, `f64`, `i16`, `i32`, `i64`. +🚀 The function is generic over the type of the array, so it can be used on an `ndarray::ArrayView1` where `T` can be `f16`*, `f32`, `f64`, `i16`, `i32`, `i64`. 👀 Note that this implementation contains no if checks, ensuring that the runtime of the function is independent of the input data its order (best-case = worst-case = average-case). +*for `f16` you should enable the 'half' feature. + ## Installing Add the following to your `Cargo.toml`: @@ -25,7 +27,7 @@ use ndarray::Array1; let arr: Vec = (0..200_000).collect(); let arr: Array1 = Array1::from(arr); -let (min, max) = arr.view().argminmax().unwrap(); // apply extension +let (min, max) = arr.view().argminmax(); // apply extension println!("min: {}, max: {}", min, max); println!("arr[min]: {}, arr[max]: {}", arr[min], arr[max]); @@ -41,7 +43,14 @@ See `/benches/results`. Run the benchmarks yourself with the following command: ```bash -cargo bench --quiet --message-format=short | grep "time:" +cargo bench --quiet --message-format=short --features half | grep "time:" +``` + +## Tests + +To run the tests use the following command: +```bash +cargo test --message-format=short --features half ``` --- diff --git a/benches/bench_f16.rs b/benches/bench_f16.rs new file mode 100644 index 0000000..16a7172 --- /dev/null +++ b/benches/bench_f16.rs @@ -0,0 +1,90 @@ +#[macro_use] +extern crate criterion; +extern crate dev_utils; + +#[cfg(feature = "half")] +use argminmax::ArgMinMax; +use criterion::{black_box, Criterion}; +use dev_utils::{config, utils}; + +#[cfg(feature = "half")] +use half::f16; +use ndarray::Array1; + +#[cfg(feature = "half")] +fn get_random_f16_array(n: usize) -> Array1 { + let data = utils::get_random_array::(n, u16::MIN, u16::MAX); + let data: Vec = data.iter().map(|&x| f16::from_bits(x)).collect(); + // Replace NaNs and Infs with 0 + let data: Vec = data + .iter() + .map(|&x| { + if x.is_nan() || x.is_infinite() { + f16::from_bits(0) + } else { + x + } + }) + .collect(); + let arr: Array1 = Array1::from(data); + arr +} + +#[cfg(feature = "half")] +fn minmax_f16_random_array_long(c: &mut Criterion) { + let n = config::ARRAY_LENGTH_LONG; + let data = get_random_f16_array(n); + c.bench_function("simp_random_long_f16", |b| { + b.iter(|| argminmax::scalar_argminmax_f16(black_box(data.view()))) + }); + c.bench_function("simd_random_long_f16", |b| { + b.iter(|| black_box(data.view().argminmax())) + }); +} + +#[cfg(feature = "half")] +fn minmax_f16_random_array_short(c: &mut Criterion) { + let n = config::ARRAY_LENGTH_SHORT; + let data = get_random_f16_array(n); + c.bench_function("simple_random_short_f16", |b| { + b.iter(|| argminmax::scalar_argminmax_f16(black_box(data.view()))) + }); + c.bench_function("simd_random_short_f16", |b| { + b.iter(|| black_box(data.view().argminmax())) + }); +} + +#[cfg(feature = "half")] +fn minmax_f16_worst_case_array_long(c: &mut Criterion) { + let n = config::ARRAY_LENGTH_LONG; + let data = utils::get_worst_case_array::(n, f16::from_f32(1.)); + c.bench_function("simple_worst_long_f16", |b| { + b.iter(|| argminmax::scalar_argminmax_f16(black_box(data.view()))) + }); + c.bench_function("simd_worst_long_f16", |b| { + b.iter(|| black_box(data.view().argminmax())) + }); +} + +#[cfg(feature = "half")] +fn minmax_f16_worst_case_array_short(c: &mut Criterion) { + let n = config::ARRAY_LENGTH_SHORT; + let data = utils::get_worst_case_array::(n, f16::from_f32(1.)); + c.bench_function("simple_worst_short_f16", |b| { + b.iter(|| argminmax::scalar_argminmax_f16(black_box(data.view()))) + }); + c.bench_function("simd_worst_short_f16", |b| { + b.iter(|| black_box(data.view().argminmax())) + }); +} + +#[cfg(feature = "half")] +criterion_group!( + benches, + minmax_f16_random_array_long, + minmax_f16_random_array_short, + minmax_f16_worst_case_array_long, + minmax_f16_worst_case_array_short +); +#[cfg(feature = "half")] +criterion_main!(benches); diff --git a/benches/bench_f32.rs b/benches/bench_f32.rs index 489cff9..eb4c14c 100644 --- a/benches/bench_f32.rs +++ b/benches/bench_f32.rs @@ -10,7 +10,7 @@ fn minmax_f32_random_array_long(c: &mut Criterion) { let n = config::ARRAY_LENGTH_LONG; let data = utils::get_random_array::(n, f32::MIN, f32::MAX); c.bench_function("simple_random_long_f32", |b| { - b.iter(|| argminmax::generic::simple_argminmax(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax(black_box(data.view()))) }); c.bench_function("simd_random_long_f32", |b| { b.iter(|| black_box(data.view().argminmax())) @@ -21,7 +21,7 @@ fn minmax_f32_random_array_short(c: &mut Criterion) { let n = config::ARRAY_LENGTH_SHORT; let data = utils::get_random_array::(n, f32::MIN, f32::MAX); c.bench_function("simple_random_short_f32", |b| { - b.iter(|| argminmax::generic::simple_argminmax(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax(black_box(data.view()))) }); c.bench_function("simd_random_short_f32", |b| { b.iter(|| black_box(data.view().argminmax())) @@ -32,7 +32,7 @@ fn minmax_f32_worst_case_array_long(c: &mut Criterion) { let n = config::ARRAY_LENGTH_LONG; let data = utils::get_worst_case_array::(n, 1.0); c.bench_function("simple_worst_long_f32", |b| { - b.iter(|| argminmax::generic::simple_argminmax(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax(black_box(data.view()))) }); c.bench_function("simd_worst_long_f32", |b| { b.iter(|| black_box(data.view().argminmax())) @@ -43,7 +43,7 @@ fn minmax_f32_worst_case_array_short(c: &mut Criterion) { let n = config::ARRAY_LENGTH_SHORT; let data = utils::get_worst_case_array::(n, 1.0); c.bench_function("simple_worst_short_f32", |b| { - b.iter(|| argminmax::generic::simple_argminmax(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax(black_box(data.view()))) }); c.bench_function("simd_worst_short_f32", |b| { b.iter(|| black_box(data.view().argminmax())) diff --git a/benches/bench_f64.rs b/benches/bench_f64.rs index 3dbac99..a00a158 100644 --- a/benches/bench_f64.rs +++ b/benches/bench_f64.rs @@ -10,7 +10,7 @@ fn minmax_f64_random_array_long(c: &mut Criterion) { let n = config::ARRAY_LENGTH_LONG; let data = utils::get_random_array::(n, f64::MIN, f64::MAX); c.bench_function("simple_random_long_f64", |b| { - b.iter(|| argminmax::generic::simple_argminmax(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax(black_box(data.view()))) }); c.bench_function("simd_random_long_f64", |b| { b.iter(|| black_box(data.view().argminmax())) @@ -21,7 +21,7 @@ fn minmax_f64_random_array_short(c: &mut Criterion) { let n = config::ARRAY_LENGTH_SHORT; let data = utils::get_random_array::(n, f64::MIN, f64::MAX); c.bench_function("simple_random_short_f64", |b| { - b.iter(|| argminmax::generic::simple_argminmax(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax(black_box(data.view()))) }); c.bench_function("simd_random_short_f64", |b| { b.iter(|| black_box(data.view().argminmax())) @@ -32,7 +32,7 @@ fn minmax_f64_worst_case_array_long(c: &mut Criterion) { let n = config::ARRAY_LENGTH_LONG; let data = utils::get_worst_case_array::(n, 1.0); c.bench_function("simple_worst_long_f64", |b| { - b.iter(|| argminmax::generic::simple_argminmax(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax(black_box(data.view()))) }); c.bench_function("simd_worst_long_f64", |b| { b.iter(|| black_box(data.view().argminmax())) @@ -43,7 +43,7 @@ fn minmax_f64_worst_case_array_short(c: &mut Criterion) { let n = config::ARRAY_LENGTH_SHORT; let data = utils::get_worst_case_array::(n, 1.0); c.bench_function("simple_worst_short_f64", |b| { - b.iter(|| argminmax::generic::simple_argminmax(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax(black_box(data.view()))) }); c.bench_function("simd_worst_short_f64", |b| { b.iter(|| black_box(data.view().argminmax())) diff --git a/benches/bench_i16.rs b/benches/bench_i16.rs index be9ee8f..6ffd236 100644 --- a/benches/bench_i16.rs +++ b/benches/bench_i16.rs @@ -10,7 +10,7 @@ fn minmax_i16_random_array_long(c: &mut Criterion) { let n = config::ARRAY_LENGTH_LONG; let data = utils::get_random_array::(n, i16::MIN, i16::MAX); c.bench_function("simple_random_long_i16", |b| { - b.iter(|| argminmax::generic::simple_argminmax(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax(black_box(data.view()))) }); c.bench_function("simd_random_long_i16", |b| { b.iter(|| black_box(data.view().argminmax())) @@ -21,7 +21,7 @@ fn minmax_i16_random_array_short(c: &mut Criterion) { let n = config::ARRAY_LENGTH_SHORT; let data = utils::get_random_array::(n, i16::MIN, i16::MAX); c.bench_function("simple_random_short_i16", |b| { - b.iter(|| argminmax::generic::simple_argminmax(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax(black_box(data.view()))) }); c.bench_function("simd_random_short_i16", |b| { b.iter(|| black_box(data.view().argminmax())) @@ -32,7 +32,7 @@ fn minmax_i16_worst_case_array_long(c: &mut Criterion) { let n = config::ARRAY_LENGTH_LONG; let data = utils::get_worst_case_array::(n, 1); c.bench_function("simple_worst_long_i16", |b| { - b.iter(|| argminmax::generic::simple_argminmax(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax(black_box(data.view()))) }); c.bench_function("simd_worst_long_i16", |b| { b.iter(|| black_box(data.view().argminmax())) @@ -43,7 +43,7 @@ fn minmax_i16_worst_case_array_short(c: &mut Criterion) { let n = config::ARRAY_LENGTH_SHORT; let data = utils::get_worst_case_array::(n, 1); c.bench_function("simple_worst_short_i16", |b| { - b.iter(|| argminmax::generic::simple_argminmax(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax(black_box(data.view()))) }); c.bench_function("simd_worst_short_i16", |b| { b.iter(|| black_box(data.view().argminmax())) diff --git a/benches/bench_i32.rs b/benches/bench_i32.rs index f284602..482e332 100644 --- a/benches/bench_i32.rs +++ b/benches/bench_i32.rs @@ -10,7 +10,7 @@ fn minmax_i32_random_array_long(c: &mut Criterion) { let n = config::ARRAY_LENGTH_LONG; let data = utils::get_random_array::(n, i32::MIN, i32::MAX); c.bench_function("simple_random_long_i32", |b| { - b.iter(|| argminmax::generic::simple_argminmax(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax(black_box(data.view()))) }); c.bench_function("simd_random_long_i32", |b| { b.iter(|| black_box(data.view().argminmax())) @@ -21,7 +21,7 @@ fn minmax_i32_random_array_short(c: &mut Criterion) { let n = config::ARRAY_LENGTH_SHORT; let data = utils::get_random_array::(n, i32::MIN, i32::MAX); c.bench_function("simple_random_short_i32", |b| { - b.iter(|| argminmax::generic::simple_argminmax(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax(black_box(data.view()))) }); c.bench_function("simd_random_short_i32", |b| { b.iter(|| black_box(data.view().argminmax())) @@ -32,7 +32,7 @@ fn minmax_i32_worst_case_array_long(c: &mut Criterion) { let n = config::ARRAY_LENGTH_LONG; let data = utils::get_worst_case_array::(n, 1); c.bench_function("simple_worst_long_i32", |b| { - b.iter(|| argminmax::generic::simple_argminmax(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax(black_box(data.view()))) }); c.bench_function("simd_worst_long_i32", |b| { b.iter(|| black_box(data.view().argminmax())) @@ -43,7 +43,7 @@ fn minmax_i32_worst_case_array_short(c: &mut Criterion) { let n = config::ARRAY_LENGTH_SHORT; let data = utils::get_worst_case_array::(n, 1); c.bench_function("simple_worst_short_i32", |b| { - b.iter(|| argminmax::generic::simple_argminmax(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax(black_box(data.view()))) }); c.bench_function("simd_worst_short_i32", |b| { b.iter(|| black_box(data.view().argminmax())) diff --git a/benches/bench_i64.rs b/benches/bench_i64.rs index 2548cdb..922644d 100644 --- a/benches/bench_i64.rs +++ b/benches/bench_i64.rs @@ -10,7 +10,7 @@ fn minmax_i64_random_array_long(c: &mut Criterion) { let n = config::ARRAY_LENGTH_LONG; let data = utils::get_random_array::(n, i64::MIN, i64::MAX); c.bench_function("simple_random_long_i64", |b| { - b.iter(|| argminmax::generic::simple_argminmax(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax(black_box(data.view()))) }); c.bench_function("simd_random_long_i64", |b| { b.iter(|| black_box(data.view().argminmax())) @@ -21,7 +21,7 @@ fn minmax_i64_random_array_short(c: &mut Criterion) { let n = config::ARRAY_LENGTH_SHORT; let data = utils::get_random_array::(n, i64::MIN, i64::MAX); c.bench_function("simple_random_short_i64", |b| { - b.iter(|| argminmax::generic::simple_argminmax(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax(black_box(data.view()))) }); c.bench_function("simd_random_short_i64", |b| { b.iter(|| black_box(data.view().argminmax())) @@ -32,7 +32,7 @@ fn minmax_i64_worst_case_array_long(c: &mut Criterion) { let n = config::ARRAY_LENGTH_LONG; let data = utils::get_worst_case_array::(n, 1); c.bench_function("simple_worst_long_i64", |b| { - b.iter(|| argminmax::generic::simple_argminmax(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax(black_box(data.view()))) }); c.bench_function("simd_worst_long_i64", |b| { b.iter(|| black_box(data.view().argminmax())) @@ -43,7 +43,7 @@ fn minmax_i64_worst_case_array_short(c: &mut Criterion) { let n = config::ARRAY_LENGTH_SHORT; let data = utils::get_worst_case_array::(n, 1); c.bench_function("simple_worst_short_i64", |b| { - b.iter(|| argminmax::generic::simple_argminmax(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax(black_box(data.view()))) }); c.bench_function("simd_worst_short_i64", |b| { b.iter(|| black_box(data.view().argminmax())) diff --git a/src/generic.rs b/src/generic.rs deleted file mode 100644 index 5948e60..0000000 --- a/src/generic.rs +++ /dev/null @@ -1,86 +0,0 @@ -use ndarray::ArrayView1; - -// ------ On ArrayView1 - -#[inline] -pub fn simple_argmin(arr: ArrayView1) -> usize { - let mut low_index = 0usize; - let mut low = arr[low_index]; - for (i, item) in arr.iter().enumerate() { - if *item < low { - low = *item; - low_index = i; - } - } - low_index -} - -#[inline] -pub fn simple_argmax(arr: ArrayView1) -> usize { - let mut high_index = 0usize; - let mut high = arr[high_index]; - for (i, item) in arr.iter().enumerate() { - if *item > high { - high = *item; - high_index = i; - } - } - high_index -} - -#[inline] -pub fn simple_argminmax(arr: ArrayView1) -> (usize, usize) { - let mut low_index: usize = 0; - let mut high_index: usize = 0; - let mut low = arr[low_index]; - let mut high = arr[high_index]; - for (i, item) in arr.iter().enumerate() { - if *item < low { - low = *item; - low_index = i; - } else if *item > high { - high = *item; - high_index = i; - } - } - (low_index, high_index) -} - -// ------ On &[T] - -// Note: these two functions are necessary because in the final SIMD registers the -// indexes are not in sorted order - this means that the first index (in the SIMD -// registers) is not necessarily the lowest min / max index when the min / max value -// occurs multiple times. - -#[inline] -pub fn min_index_value(index: &[T], values: &[T]) -> (T, T) { - assert_eq!(index.len(), values.len()); - let mut min_index: usize = 0; - let mut min_value = values[min_index]; - for (i, value) in values.iter().skip(1).enumerate() { - if *value < min_value { - min_value = *value; - min_index = i + 1; - } else if *value == min_value && index[i + 1] < index[min_index] { - min_index = i + 1; - } - } - (index[min_index], min_value) -} - -#[inline] -pub fn max_index_value(index: &[T], values: &[T]) -> (T, T) { - assert_eq!(index.len(), values.len()); - let mut max_index: usize = 0; - let mut max_value = values[max_index]; - for (i, value) in values.iter().skip(1).enumerate() { - if *value > max_value { - max_value = *value; - max_index = i + 1; - } else if *value == max_value && index[i + 1] < index[max_index] { - max_index = i + 1; - } - } - (index[max_index], max_value) -} diff --git a/src/lib.rs b/src/lib.rs index 4722e2d..5ec16ea 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,68 +1,51 @@ -pub mod generic; -// #[cfg(target_feature = "sse")] +mod scalar_f16; +mod scalar_generic; mod simd; -// #[cfg(target_feature = "sse")] mod task; +mod utils; -pub use generic::{simple_argminmax}; +pub use scalar_generic::*; pub use simd::{simd_f32, simd_f64, simd_i16, simd_i32, simd_i64}; use ndarray::ArrayView1; pub trait ArgMinMax { - fn argminmax(self) -> Option<(usize, usize)>; + // TODO: future work implement these other functions + // fn min(self) -> Self::Item; + // fn max(self) -> Self::Item; + // fn minmax(self) -> (T, T); + + // fn argmin(self) -> usize; + // fn argmax(self) -> usize; + fn argminmax(self) -> (usize, usize); } -impl ArgMinMax for ArrayView1<'_, f64> { - fn argminmax(self) -> Option<(usize, usize)> { - // #[cfg(not(target_feature = "sse"))] - // return Some(simple_argminmax(self)); - // #[cfg(target_feature = "sse")] - return simd_f64::argminmax_f64(self); - } +macro_rules! impl_argminmax { + ($t:ty, $scalar_func:ident, $simd_mod:ident, $simd_func:ident) => { + impl ArgMinMax for ArrayView1<'_, $t> { + fn argminmax(self) -> (usize, usize) { + // TODO: what to do with cfg target_feature? + #[cfg(not(target_feature = "sse"))] + return $scalar_func(self); + #[cfg(target_feature = "sse")] + return $simd_mod::$simd_func(self); + } + } + }; } -impl ArgMinMax for ArrayView1<'_, i64> { - fn argminmax(self) -> Option<(usize, usize)> { - // #[cfg(not(target_feature = "sse"))] - // return Some(simple_argminmax(self)); - // #[cfg(target_feature = "sse")] - return simd_i64::argminmax_i64(self); - } -} - -impl ArgMinMax for ArrayView1<'_, f32> { - fn argminmax(self) -> Option<(usize, usize)> { - // #[cfg(not(target_feature = "sse"))] - // return Some(simple_argminmax(self)); - // #[cfg(target_feature = "sse")] - return simd_f32::argminmax_f32(self); - } -} - -impl ArgMinMax for ArrayView1<'_, i32> { - fn argminmax(self) -> Option<(usize, usize)> { - // #[cfg(not(target_feature = "sse"))] - // return Some(simple_argminmax(self)); - // #[cfg(target_feature = "sse")] - return simd_i32::argminmax_i32(self); - } -} - -impl ArgMinMax for ArrayView1<'_, i16> { - fn argminmax(self) -> Option<(usize, usize)> { - // #[cfg(not(target_feature = "sse"))] - // return Some(simple_argminmax(self)); - // #[cfg(target_feature = "sse")] - return simd_i16::argminmax_i16(self); - } -} - -// impl ArgMinMax for ArrayView1<'_, i8> { -// fn argminmax(self) -> Option<(usize, usize)> { -// #[cfg(not(target_feature = "sse"))] -// return Some(simple_argminmax(self)); -// #[cfg(target_feature = "sse")] -// return simd_i8::argminmax_i8(self); -// } -// } +// Implement ArgMinMax for the rust primitive types +impl_argminmax!(f32, scalar_argminmax, simd_f32, argminmax_f32); +impl_argminmax!(f64, scalar_argminmax, simd_f64, argminmax_f64); +impl_argminmax!(i16, scalar_argminmax, simd_i16, argminmax_i16); +impl_argminmax!(i32, scalar_argminmax, simd_i32, argminmax_i32); +impl_argminmax!(i64, scalar_argminmax, simd_i64, argminmax_i64); +// Implement ArgMinMax for other data types +#[cfg(feature = "half")] +use half::f16; +#[cfg(feature = "half")] +pub use scalar_f16::scalar_argminmax_f16; +#[cfg(feature = "half")] +pub use simd::simd_f16; +#[cfg(feature = "half")] +impl_argminmax!(f16, scalar_argminmax_f16, simd_f16, argminmax_f16); diff --git a/src/scalar_f16.rs b/src/scalar_f16.rs new file mode 100644 index 0000000..ec506c6 --- /dev/null +++ b/src/scalar_f16.rs @@ -0,0 +1,66 @@ +#[cfg(feature = "half")] +use half::f16; +use ndarray::ArrayView1; + +// ------ On ArrayView1 + +#[cfg(feature = "half")] +#[inline] +fn f16_to_i16ord(x: f16) -> i16 { + let x = unsafe { std::mem::transmute::(x) }; + ((x >> 15) & 0x7FFF) ^ x +} + +#[cfg(feature = "half")] +#[inline] +pub fn scalar_argminmax_f16(arr: ArrayView1) -> (usize, usize) { + // f16 is transformed to i16ord + // benchmarks show: + // 1. this is 7-10x faster than using raw f16 + // 2. this is 3x faster than transforming to f32 or f64 + let mut low_index: usize = 0; + let mut high_index: usize = 0; + let mut low = f16_to_i16ord(arr[low_index]); + let mut high = f16_to_i16ord(arr[high_index]); + for (i, item) in arr.iter().enumerate() { + let item = f16_to_i16ord(*item); + if item < low { + low = item; + low_index = i; + } else if item > high { + high = item; + high_index = i; + } + } + (low_index, high_index) +} + +#[cfg(feature = "half")] +#[cfg(test)] +mod tests { + use super::scalar_argminmax_f16; + use crate::scalar_generic::scalar_argminmax; + + use half::f16; + use ndarray::Array1; + + extern crate dev_utils; + use dev_utils::utils; + + fn get_array_f16(n: usize) -> Array1 { + let arr = utils::get_random_array(n, i16::MIN, i16::MAX); + let arr = arr.mapv(|x| f16::from_f32(x as f32)); + Array1::from(arr) + } + + #[test] + fn test_generic_and_specific_impl_return_the_same_results() { + for _ in 0..100 { + let data = get_array_f16(1025); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); + let (argmin_simd_index, argmax_simd_index) = scalar_argminmax_f16(data.view()); + assert_eq!(argmin_index, argmin_simd_index); + assert_eq!(argmax_index, argmax_simd_index); + } + } +} diff --git a/src/scalar_generic.rs b/src/scalar_generic.rs new file mode 100644 index 0000000..1d82a23 --- /dev/null +++ b/src/scalar_generic.rs @@ -0,0 +1,65 @@ +use ndarray::ArrayView1; + +// ------ On ArrayView1 + +#[inline] +pub fn scalar_argmin(arr: ArrayView1) -> usize { + let mut low_index = 0usize; + let mut low = arr[low_index]; + for (i, item) in arr.iter().enumerate() { + if *item < low { + low = *item; + low_index = i; + } + } + low_index +} + +#[inline] +pub fn scalar_argmax(arr: ArrayView1) -> usize { + let mut high_index = 0usize; + let mut high = arr[high_index]; + for (i, item) in arr.iter().enumerate() { + if *item > high { + high = *item; + high_index = i; + } + } + high_index +} + +#[inline] +pub fn scalar_argminmax(arr: ArrayView1) -> (usize, usize) { + let mut low_index: usize = 0; + let mut high_index: usize = 0; + let mut low = arr[low_index]; + let mut high = arr[high_index]; + for (i, item) in arr.iter().enumerate() { + if *item < low { + low = *item; + low_index = i; + } else if *item > high { + high = *item; + high_index = i; + } + } + (low_index, high_index) +} + +// Note: 5-7% faster than the above implementation (for floats) +// #[inline] +// pub fn scalar_argminmax_fold(arr: ArrayView1) -> (usize, usize) { +// let minmax_tuple: (usize, T, usize, T) = arr.iter().enumerate().fold( +// (0usize, arr[0], 0usize, arr[0]), +// |(min_idx, min, max_idx, max), (idx, item)| { +// if *item < min { +// (idx, *item, max_idx, max) +// } else if *item > max { +// (min_idx, min, idx, *item) +// } else { +// (min_idx, min, max_idx, max) +// } +// }, +// ); +// (minmax_tuple.0, minmax_tuple.2) +// } diff --git a/src/simd/mod.rs b/src/simd/mod.rs index 9508e4f..ca3e2d1 100644 --- a/src/simd/mod.rs +++ b/src/simd/mod.rs @@ -3,6 +3,8 @@ pub mod simd_f64; pub use simd_f64::*; pub mod simd_f32; pub use simd_f32::*; +pub mod simd_f16; +pub use simd_f16::*; // SIGNED INT pub mod simd_i64; pub use simd_i64::*; diff --git a/src/simd/simd_f16.rs b/src/simd/simd_f16.rs new file mode 100644 index 0000000..9bec83d --- /dev/null +++ b/src/simd/simd_f16.rs @@ -0,0 +1,167 @@ +use crate::task::argminmax_generic; +use crate::utils::{max_index_value, min_index_value}; +use ndarray::ArrayView1; +use std::arch::x86_64::*; + +#[cfg(feature = "half")] +use half::f16; + +const LANE_SIZE: usize = 16; + +// ------------------------------------ ARGMINMAX -------------------------------------- + +#[cfg(feature = "half")] +pub fn argminmax_f16(arr: ArrayView1) -> (usize, usize) { + argminmax_generic(arr, LANE_SIZE, core_argminmax_256) +} + +#[inline] +fn reg_to_i16_arr(reg: __m256i) -> [i16; 16] { + unsafe { std::mem::transmute::<__m256i, [i16; 16]>(reg) } +} + +#[inline] +#[target_feature(enable = "avx2")] +unsafe fn f16_as_m256i_to_ord_i16(f16_as_m256i: __m256i) -> __m256i { + // on a scalar: ((v >> 15) & 0x7FFF) ^ v + let sign_bit_shifted = _mm256_srai_epi16(f16_as_m256i, 15); + let sign_bit_masked = _mm256_and_si256(sign_bit_shifted, _mm256_set1_epi16(0x7FFF)); + _mm256_xor_si256(sign_bit_masked, f16_as_m256i) +} + +#[cfg(feature = "half")] +#[inline] +fn ord_i16_to_f16(ord_i16: i16) -> f16 { + let v = ((ord_i16 >> 15) & 0x7FFF) ^ ord_i16; + unsafe { std::mem::transmute::(v) } +} + +#[cfg(feature = "half")] +#[inline] +#[target_feature(enable = "avx2")] +unsafe fn core_argminmax_256(sim_arr: ArrayView1, offset: usize) -> (f16, usize, f16, usize) { + // Efficient calculation of argmin and argmax together + let offset = _mm256_set1_epi16(offset as i16); + let mut new_index = _mm256_add_epi16( + _mm256_set_epi16(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), + offset, + ); + let mut index_low = new_index; + let mut index_high = new_index; + + let increment = _mm256_set1_epi16(16); + + // println!("raw new values: {:?}", sim_arr.slice(s![0..16])); + let new_values = _mm256_loadu_si256(sim_arr.as_ptr() as *const __m256i); + // println!("new_values: {:?}", reg_to_i16_arr(new_values)); + let new_values = f16_as_m256i_to_ord_i16(new_values); + // println!("new_values: {:?}", reg_to_i16_arr(new_values)); + // println!(); + let mut values_low = new_values; + let mut values_high = new_values; + + sim_arr + .exact_chunks(16) + .into_iter() + .skip(1) + .for_each(|step| { + new_index = _mm256_add_epi16(new_index, increment); + + let new_values = _mm256_loadu_si256(step.as_ptr() as *const __m256i); + let new_values = f16_as_m256i_to_ord_i16(new_values); + let gt_mask = _mm256_cmpgt_epi16(new_values, values_high); + // Below does not work (bc instruction is not available) + // let lt_mask = _mm256_cmplt_epi16(new_values, values_low); + // Solution: swap parameters and use gt instead + let lt_mask = _mm256_cmpgt_epi16(values_low, new_values); + + index_low = _mm256_blendv_epi8(index_low, new_index, lt_mask); + index_high = _mm256_blendv_epi8(index_high, new_index, gt_mask); + + values_low = _mm256_blendv_epi8(values_low, new_values, lt_mask); + values_high = _mm256_blendv_epi8(values_high, new_values, gt_mask); + }); + + // Select max_index and max_value + let value_array = reg_to_i16_arr(values_high); + let index_array = reg_to_i16_arr(index_high); + let (index_max, value_max) = max_index_value(&index_array, &value_array); + + // Select min_index and min_value + let value_array = reg_to_i16_arr(values_low); + let index_array = reg_to_i16_arr(index_low); + let (index_min, value_min) = min_index_value(&index_array, &value_array); + + ( + ord_i16_to_f16(value_min), + index_min as usize, + ord_i16_to_f16(value_max), + index_max as usize, + ) +} + +//----- TESTS ----- + +#[cfg(feature = "half")] +#[cfg(test)] +mod tests { + use super::argminmax_f16; + use crate::scalar_generic::scalar_argminmax; + + use half::f16; + use ndarray::Array1; + + extern crate dev_utils; + use dev_utils::utils; + + fn get_array_f16(n: usize) -> Array1 { + let arr = utils::get_random_array(n, i16::MIN, i16::MAX); + let arr = arr.mapv(|x| f16::from_f32(x as f32)); + Array1::from(arr) + } + + #[test] + fn test_both_versions_return_the_same_results() { + let data = get_array_f16(1025); + assert_eq!(data.len() % 8, 1); + + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); + let (argmin_simd_index, argmax_simd_index) = argminmax_f16(data.view()); + assert_eq!(argmin_index, argmin_simd_index); + assert_eq!(argmax_index, argmax_simd_index); + } + + #[test] + fn test_first_index_is_returned_when_identical_values_found() { + let data = [ + f16::from_f32(10.), + f16::MAX, + f16::from_f32(6.), + f16::NEG_INFINITY, + f16::NEG_INFINITY, + f16::MAX, + f16::from_f32(5_000.0), + ]; + let data: Vec = data.iter().map(|x| *x).collect(); + let data = Array1::from(data); + + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); + assert_eq!(argmin_index, 3); + assert_eq!(argmax_index, 1); + + let (argmin_simd_index, argmax_simd_index) = argminmax_f16(data.view()); + assert_eq!(argmin_simd_index, 3); + assert_eq!(argmax_simd_index, 1); + } + + #[test] + fn test_many_random_runs() { + for _ in 0..10_000 { + let data = get_array_f16(32 * 8 + 1); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); + let (argmin_simd_index, argmax_simd_index) = argminmax_f16(data.view()); + assert_eq!(argmin_index, argmin_simd_index); + assert_eq!(argmax_index, argmax_simd_index); + } + } +} diff --git a/src/simd/simd_f32.rs b/src/simd/simd_f32.rs index 46f5592..ceb1dac 100644 --- a/src/simd/simd_f32.rs +++ b/src/simd/simd_f32.rs @@ -1,33 +1,14 @@ -use crate::generic::{max_index_value, min_index_value, simple_argminmax}; -use crate::task::{find_final_index_minmax, split_array}; +use crate::task::argminmax_generic; +use crate::utils::{max_index_value, min_index_value}; use ndarray::ArrayView1; use std::arch::x86_64::*; +const LANE_SIZE: usize = 8; + // ------------------------------------ ARGMINMAX -------------------------------------- -pub fn argminmax_f32(arr: ArrayView1) -> Option<(usize, usize)> { - match split_array(arr, 8) { - (Some(rem), Some(sim)) => { - let (rem_min_index, rem_max_index) = simple_argminmax(rem); - let rem_result = ( - rem[rem_min_index], - rem_min_index, - rem[rem_max_index], - rem_max_index, - ); - let sim_result = unsafe { core_argminmax_256(sim, rem.len()) }; - find_final_index_minmax(rem_result, sim_result) - } - (Some(rem), None) => { - let (rem_min_index, rem_max_index) = simple_argminmax(rem); - Some((rem_min_index, rem_max_index)) - } - (None, Some(sim)) => { - let sim_result = unsafe { core_argminmax_256(sim, 0) }; - Some((sim_result.1, sim_result.3)) - } - (None, None) => None, - } +pub fn argminmax_f32(arr: ArrayView1) -> (usize, usize) { + argminmax_generic(arr, LANE_SIZE, core_argminmax_256) } #[inline] @@ -35,6 +16,7 @@ fn reg_to_f32_arr(reg: __m256) -> [f32; 8] { unsafe { std::mem::transmute::<__m256, [f32; 8]>(reg) } } +#[inline] #[target_feature(enable = "avx2")] unsafe fn core_argminmax_256(sim_arr: ArrayView1, offset: usize) -> (f32, usize, f32, usize) { // Efficient calculation of argmin and argmax together @@ -92,7 +74,9 @@ unsafe fn core_argminmax_256(sim_arr: ArrayView1, offset: usize) -> (f32, u #[cfg(test)] mod tests { - use super::{argminmax_f32, simple_argminmax}; + use super::argminmax_f32; + use crate::scalar_generic::scalar_argminmax; + use ndarray::Array1; extern crate dev_utils; @@ -107,8 +91,8 @@ mod tests { let data = get_array_f32(1025); assert_eq!(data.len() % 8, 1); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); - let (argmin_simd_index, argmax_simd_index) = argminmax_f32(data.view()).unwrap(); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); + let (argmin_simd_index, argmax_simd_index) = argminmax_f32(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); } @@ -127,11 +111,11 @@ mod tests { let data: Vec = data.iter().map(|x| *x).collect(); let data = Array1::from(data); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); assert_eq!(argmin_index, 3); assert_eq!(argmax_index, 1); - let (argmin_simd_index, argmax_simd_index) = argminmax_f32(data.view()).unwrap(); + let (argmin_simd_index, argmax_simd_index) = argminmax_f32(data.view()); assert_eq!(argmin_simd_index, 3); assert_eq!(argmax_simd_index, 1); } @@ -140,8 +124,8 @@ mod tests { fn test_many_random_runs() { for _ in 0..10_000 { let data = get_array_f32(32 * 8 + 1); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); - let (argmin_simd_index, argmax_simd_index) = argminmax_f32(data.view()).unwrap(); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); + let (argmin_simd_index, argmax_simd_index) = argminmax_f32(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); } diff --git a/src/simd/simd_f64.rs b/src/simd/simd_f64.rs index 6672785..427c17b 100644 --- a/src/simd/simd_f64.rs +++ b/src/simd/simd_f64.rs @@ -1,33 +1,14 @@ -use crate::generic::{max_index_value, min_index_value, simple_argminmax}; -use crate::task::{find_final_index_minmax, split_array}; +use crate::task::argminmax_generic; +use crate::utils::{max_index_value, min_index_value}; use ndarray::ArrayView1; use std::arch::x86_64::*; +const LANE_SIZE: usize = 4; + // ------------------------------------ ARGMINMAX -------------------------------------- -pub fn argminmax_f64(arr: ArrayView1) -> Option<(usize, usize)> { - match split_array(arr, 4) { - (Some(rem), Some(sim)) => { - let (rem_min_index, rem_max_index) = simple_argminmax(rem); - let rem_result = ( - rem[rem_min_index], - rem_min_index, - rem[rem_max_index], - rem_max_index, - ); - let sim_result = unsafe { core_argminmax_256(sim, rem.len()) }; - find_final_index_minmax(rem_result, sim_result) - } - (Some(rem), None) => { - let (rem_min_index, rem_max_index) = simple_argminmax(rem); - Some((rem_min_index, rem_max_index)) - } - (None, Some(sim)) => { - let sim_result = unsafe { core_argminmax_256(sim, 0) }; - Some((sim_result.1, sim_result.3)) - } - (None, None) => None, - } +pub fn argminmax_f64(arr: ArrayView1) -> (usize, usize) { + argminmax_generic(arr, LANE_SIZE, core_argminmax_256) } #[inline] @@ -35,6 +16,7 @@ fn reg_to_f64_arr(reg: __m256d) -> [f64; 4] { unsafe { std::mem::transmute::<__m256d, [f64; 4]>(reg) } } +#[inline] #[target_feature(enable = "avx2")] unsafe fn core_argminmax_256(sim_arr: ArrayView1, offset: usize) -> (f64, usize, f64, usize) { // Efficient calculation of argmin and argmax together @@ -84,7 +66,9 @@ unsafe fn core_argminmax_256(sim_arr: ArrayView1, offset: usize) -> (f64, u #[cfg(test)] mod tests { - use super::{argminmax_f64, simple_argminmax}; + use super::argminmax_f64; + use crate::scalar_generic::scalar_argminmax; + use ndarray::Array1; extern crate dev_utils; @@ -99,8 +83,8 @@ mod tests { let data = get_array_f64(1025); assert_eq!(data.len() % 4, 1); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); - let (argmin_simd_index, argmax_simd_index) = argminmax_f64(data.view()).unwrap(); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); + let (argmin_simd_index, argmax_simd_index) = argminmax_f64(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); } @@ -119,11 +103,11 @@ mod tests { let data: Vec = data.iter().map(|x| *x).collect(); let data = Array1::from(data); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); assert_eq!(argmin_index, 3); assert_eq!(argmax_index, 1); - let (argmin_simd_index, argmax_simd_index) = argminmax_f64(data.view()).unwrap(); + let (argmin_simd_index, argmax_simd_index) = argminmax_f64(data.view()); assert_eq!(argmin_simd_index, 3); assert_eq!(argmax_simd_index, 1); } @@ -132,8 +116,8 @@ mod tests { fn test_many_random_runs() { for _ in 0..10_000 { let data = get_array_f64(32 * 8 + 1); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); - let (argmin_simd_index, argmax_simd_index) = argminmax_f64(data.view()).unwrap(); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); + let (argmin_simd_index, argmax_simd_index) = argminmax_f64(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); } diff --git a/src/simd/simd_i16.rs b/src/simd/simd_i16.rs index e4561fb..f4ea34e 100644 --- a/src/simd/simd_i16.rs +++ b/src/simd/simd_i16.rs @@ -1,33 +1,14 @@ -use crate::generic::{max_index_value, min_index_value, simple_argminmax}; -use crate::task::{find_final_index_minmax, split_array}; +use crate::task::argminmax_generic; +use crate::utils::{max_index_value, min_index_value}; use ndarray::ArrayView1; use std::arch::x86_64::*; +const LANE_SIZE: usize = 16; + // ------------------------------------ ARGMINMAX -------------------------------------- -pub fn argminmax_i16(arr: ArrayView1) -> Option<(usize, usize)> { - match split_array(arr, 16) { - (Some(rem), Some(sim)) => { - let (rem_min_index, rem_max_index) = simple_argminmax(rem); - let rem_result = ( - rem[rem_min_index], - rem_min_index, - rem[rem_max_index], - rem_max_index, - ); - let sim_result = unsafe { core_argminmax_256(sim, rem.len()) }; - find_final_index_minmax(rem_result, sim_result) - } - (Some(rem), None) => { - let (rem_min_index, rem_max_index) = simple_argminmax(rem); - Some((rem_min_index, rem_max_index)) - } - (None, Some(sim)) => { - let sim_result = unsafe { core_argminmax_256(sim, 0) }; - Some((sim_result.1, sim_result.3)) - } - (None, None) => None, - } +pub fn argminmax_i16(arr: ArrayView1) -> (usize, usize) { + argminmax_generic(arr, LANE_SIZE, core_argminmax_256) } #[inline] @@ -35,6 +16,7 @@ fn reg_to_i16_arr(reg: __m256i) -> [i16; 16] { unsafe { std::mem::transmute::<__m256i, [i16; 16]>(reg) } } +#[inline] #[target_feature(enable = "avx2")] unsafe fn core_argminmax_256(sim_arr: ArrayView1, offset: usize) -> (i16, usize, i16, usize) { // Efficient calculation of argmin and argmax together @@ -90,7 +72,9 @@ unsafe fn core_argminmax_256(sim_arr: ArrayView1, offset: usize) -> (i16, u #[cfg(test)] mod tests { - use super::{argminmax_i16, simple_argminmax}; + use super::argminmax_i16; + use crate::scalar_generic::scalar_argminmax; + use ndarray::Array1; extern crate dev_utils; @@ -105,8 +89,8 @@ mod tests { let data = get_array_i16(513); assert_eq!(data.len() % 16, 1); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); - let (argmin_simd_index, argmax_simd_index) = argminmax_i16(data.view()).unwrap(); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); + let (argmin_simd_index, argmax_simd_index) = argminmax_i16(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); } @@ -127,11 +111,11 @@ mod tests { let data: Vec = data.iter().map(|x| *x).collect(); let data = Array1::from(data); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); assert_eq!(argmin_index, 1); assert_eq!(argmax_index, 6); - let (argmin_simd_index, argmax_simd_index) = argminmax_i16(data.view()).unwrap(); + let (argmin_simd_index, argmax_simd_index) = argminmax_i16(data.view()); assert_eq!(argmin_simd_index, 1); assert_eq!(argmax_simd_index, 6); } @@ -140,8 +124,8 @@ mod tests { fn test_many_random_runs() { for _ in 0..10_000 { let data = get_array_i16(32 * 2 + 1); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); - let (argmin_simd_index, argmax_simd_index) = argminmax_i16(data.view()).unwrap(); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); + let (argmin_simd_index, argmax_simd_index) = argminmax_i16(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); } diff --git a/src/simd/simd_i32.rs b/src/simd/simd_i32.rs index 15563ba..31c7ccf 100644 --- a/src/simd/simd_i32.rs +++ b/src/simd/simd_i32.rs @@ -1,33 +1,14 @@ -use crate::generic::{max_index_value, min_index_value, simple_argminmax}; -use crate::task::{find_final_index_minmax, split_array}; +use crate::task::argminmax_generic; +use crate::utils::{max_index_value, min_index_value}; use ndarray::ArrayView1; use std::arch::x86_64::*; +const LANE_SIZE: usize = 8; + // ------------------------------------ ARGMINMAX -------------------------------------- -pub fn argminmax_i32(arr: ArrayView1) -> Option<(usize, usize)> { - match split_array(arr, 8) { - (Some(rem), Some(sim)) => { - let (rem_min_index, rem_max_index) = simple_argminmax(rem); - let rem_result = ( - rem[rem_min_index], - rem_min_index, - rem[rem_max_index], - rem_max_index, - ); - let sim_result = unsafe { core_argminmax_256(sim, rem.len()) }; - find_final_index_minmax(rem_result, sim_result) - } - (Some(rem), None) => { - let (rem_min_index, rem_max_index) = simple_argminmax(rem); - Some((rem_min_index, rem_max_index)) - } - (None, Some(sim)) => { - let sim_result = unsafe { core_argminmax_256(sim, 0) }; - Some((sim_result.1, sim_result.3)) - } - (None, None) => None, - } +pub fn argminmax_i32(arr: ArrayView1) -> (usize, usize) { + argminmax_generic(arr, LANE_SIZE, core_argminmax_256) } #[inline] @@ -35,6 +16,7 @@ fn reg_to_i32_arr(reg: __m256i) -> [i32; 8] { unsafe { std::mem::transmute::<__m256i, [i32; 8]>(reg) } } +#[inline] #[target_feature(enable = "avx2")] unsafe fn core_argminmax_256(sim_arr: ArrayView1, offset: usize) -> (i32, usize, i32, usize) { // Efficient calculation of argmin and argmax together @@ -87,7 +69,9 @@ unsafe fn core_argminmax_256(sim_arr: ArrayView1, offset: usize) -> (i32, u #[cfg(test)] mod tests { - use super::{argminmax_i32, simple_argminmax}; + use super::argminmax_i32; + use crate::scalar_generic::scalar_argminmax; + use ndarray::Array1; extern crate dev_utils; @@ -102,8 +86,8 @@ mod tests { let data = get_array_i32(1025); assert_eq!(data.len() % 8, 1); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); - let (argmin_simd_index, argmax_simd_index) = argminmax_i32(data.view()).unwrap(); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); + let (argmin_simd_index, argmax_simd_index) = argminmax_i32(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); } @@ -123,11 +107,11 @@ mod tests { let data: Vec = data.iter().map(|x| *x).collect(); let data = Array1::from(data); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); assert_eq!(argmin_index, 0); assert_eq!(argmax_index, 5); - let (argmin_simd_index, argmax_simd_index) = argminmax_i32(data.view()).unwrap(); + let (argmin_simd_index, argmax_simd_index) = argminmax_i32(data.view()); assert_eq!(argmin_simd_index, 0); assert_eq!(argmax_simd_index, 5); } @@ -136,8 +120,8 @@ mod tests { fn test_many_random_runs() { for _ in 0..10_000 { let data = get_array_i32(32 * 8 + 1); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); - let (argmin_simd_index, argmax_simd_index) = argminmax_i32(data.view()).unwrap(); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); + let (argmin_simd_index, argmax_simd_index) = argminmax_i32(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); } diff --git a/src/simd/simd_i64.rs b/src/simd/simd_i64.rs index a4d6115..edee939 100644 --- a/src/simd/simd_i64.rs +++ b/src/simd/simd_i64.rs @@ -1,33 +1,14 @@ -use crate::generic::{max_index_value, min_index_value, simple_argminmax}; -use crate::task::{find_final_index_minmax, split_array}; +use crate::task::argminmax_generic; +use crate::utils::{max_index_value, min_index_value}; use ndarray::ArrayView1; use std::arch::x86_64::*; +const LANE_SIZE: usize = 4; + // ------------------------------------ ARGMINMAX -------------------------------------- -pub fn argminmax_i64(arr: ArrayView1) -> Option<(usize, usize)> { - match split_array(arr, 4) { - (Some(rem), Some(sim)) => { - let (rem_min_index, rem_max_index) = simple_argminmax(rem); - let rem_result = ( - rem[rem_min_index], - rem_min_index, - rem[rem_max_index], - rem_max_index, - ); - let sim_result = unsafe { core_argminmax_256(sim, rem.len()) }; - find_final_index_minmax(rem_result, sim_result) - } - (Some(rem), None) => { - let (rem_min_index, rem_max_index) = simple_argminmax(rem); - Some((rem_min_index, rem_max_index)) - } - (None, Some(sim)) => { - let sim_result = unsafe { core_argminmax_256(sim, 0) }; - Some((sim_result.1, sim_result.3)) - } - (None, None) => None, - } +pub fn argminmax_i64(arr: ArrayView1) -> (usize, usize) { + argminmax_generic(arr, LANE_SIZE, core_argminmax_256) } #[inline] @@ -35,6 +16,7 @@ fn reg_to_i64_arr(reg: __m256i) -> [i64; 4] { unsafe { std::mem::transmute::<__m256i, [i64; 4]>(reg) } } +#[inline] #[target_feature(enable = "avx2")] unsafe fn core_argminmax_256(sim_arr: ArrayView1, offset: usize) -> (i64, usize, i64, usize) { // Efficient calculation of argmin and argmax together @@ -107,7 +89,9 @@ unsafe fn core_argminmax_256(sim_arr: ArrayView1, offset: usize) -> (i64, u #[cfg(test)] mod tests { - use super::{argminmax_i64, simple_argminmax}; + use super::argminmax_i64; + use crate::scalar_generic::scalar_argminmax; + use ndarray::Array1; extern crate dev_utils; @@ -122,8 +106,8 @@ mod tests { let data = get_array_i64(1025); assert_eq!(data.len() % 4, 1); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); - let (argmin_simd_index, argmax_simd_index) = argminmax_i64(data.view()).unwrap(); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); + let (argmin_simd_index, argmax_simd_index) = argminmax_i64(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); } @@ -143,11 +127,11 @@ mod tests { let data: Vec = data.iter().map(|x| *x).collect(); let data = Array1::from(data); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); assert_eq!(argmin_index, 0); assert_eq!(argmax_index, 5); - let (argmin_simd_index, argmax_simd_index) = argminmax_i64(data.view()).unwrap(); + let (argmin_simd_index, argmax_simd_index) = argminmax_i64(data.view()); assert_eq!(argmin_simd_index, 0); assert_eq!(argmax_simd_index, 5); } @@ -156,8 +140,8 @@ mod tests { fn test_many_random_runs() { for _ in 0..10_000 { let data = get_array_i64(32 * 8 + 1); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); - let (argmin_simd_index, argmax_simd_index) = argminmax_i64(data.view()).unwrap(); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); + let (argmin_simd_index, argmax_simd_index) = argminmax_i64(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); } diff --git a/src/simd/simd_i8.rs b/src/simd/simd_i8.rs index c2d2f76..a6319b0 100644 --- a/src/simd/simd_i8.rs +++ b/src/simd/simd_i8.rs @@ -1,33 +1,14 @@ -use crate::generic::{max_index_value, min_index_value, simple_argminmax}; -use crate::task::{find_final_index_minmax, split_array}; +use crate::utils::{max_index_value, min_index_value}; +use crate::task::argminmax_generic; use ndarray::ArrayView1; use std::arch::x86_64::*; +const LANE_SIZE: usize = 32; + // ------------------------------------ ARGMINMAX -------------------------------------- -pub fn argminmax_i8(arr: ArrayView1) -> Option<(usize, usize)> { - match split_array(arr, 32) { - (Some(rem), Some(sim)) => { - let (rem_min_index, rem_max_index) = simple_argminmax(rem); - let rem_result = ( - rem[rem_min_index], - rem_min_index, - rem[rem_max_index], - rem_max_index, - ); - let sim_result = unsafe { core_argminmax_256(sim, rem.len()) }; - find_final_index_minmax(rem_result, sim_result) - } - (Some(rem), None) => { - let (rem_min_index, rem_max_index) = simple_argminmax(rem); - Some((rem_min_index, rem_max_index)) - } - (None, Some(sim)) => { - let sim_result = unsafe { core_argminmax_256(sim, 0) }; - Some((sim_result.1, sim_result.3)) - } - (None, None) => None, - } +pub fn argminmax_i8(arr: ArrayView1) -> (usize, usize) { + argminmax_generic(arr, LANE_SIZE, core_argminmax_256) } #[inline] @@ -35,6 +16,7 @@ fn reg_to_i8_arr(reg: __m256i) -> [i8; 32] { unsafe { std::mem::transmute::<__m256i, [i8; 32]>(reg) } } +#[inline] #[target_feature(enable = "avx2")] unsafe fn core_argminmax_256(sim_arr: ArrayView1, offset: usize) -> (i8, usize, i8, usize) { // Efficient calculation of argmin and argmax together @@ -106,7 +88,9 @@ unsafe fn core_argminmax_256(sim_arr: ArrayView1, offset: usize) -> (i8, usi #[cfg(test)] mod tests { - use super::{argminmax_i8, simple_argminmax}; + use super::argminmax_i8; + use crate::scalar_generic::scalar_argminmax; + use ndarray::Array1; extern crate dev_utils; @@ -121,8 +105,8 @@ mod tests { let data = get_array_i8(32 * 6 + 1); // TODO: lengte mag niet > 2^8 zijn... assert_eq!(data.len() % 32, 1); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); - let (argmin_simd_index, argmax_simd_index) = argminmax_i8(data.view()).unwrap(); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); + let (argmin_simd_index, argmax_simd_index) = argminmax_i8(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); } @@ -133,11 +117,11 @@ mod tests { let data: Vec = data.iter().map(|x| *x).collect(); let data = Array1::from(data); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); assert_eq!(argmin_index, 1); assert_eq!(argmax_index, 6); - let (argmin_simd_index, argmax_simd_index) = argminmax_i8(data.view()).unwrap(); + let (argmin_simd_index, argmax_simd_index) = argminmax_i8(data.view()); assert_eq!(argmin_simd_index, 1); assert_eq!(argmax_simd_index, 6); } @@ -146,8 +130,8 @@ mod tests { fn test_many_random_runs() { for _ in 0..10_000 { let data = get_array_i8(32 * 2 + 1); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); - let (argmin_simd_index, argmax_simd_index) = argminmax_i8(data.view()).unwrap(); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); + let (argmin_simd_index, argmax_simd_index) = argminmax_i8(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); } diff --git a/src/simd/simd_u32.rs b/src/simd/simd_u32.rs index 29d672f..2fc5af0 100644 --- a/src/simd/simd_u32.rs +++ b/src/simd/simd_u32.rs @@ -1,33 +1,14 @@ -use crate::generic::{max_index_value, min_index_value, simple_argminmax}; -use crate::task::{find_final_index_minmax, split_array}; +use crate::generic::{max_index_value, min_index_value}; +use crate::task::argminmax_generic; use ndarray::ArrayView1; use std::arch::x86_64::*; +const LANE_SIZE: usize = 8; + // ------------------------------------ ARGMINMAX -------------------------------------- pub fn argminmax_u16(arr: ArrayView1) -> Option<(usize, usize)> { - match split_array(arr, 8) { - (Some(rem), Some(sim)) => { - let (rem_min_index, rem_max_index) = simple_argminmax(rem); - let rem_result = ( - rem[rem_min_index], - rem_min_index, - rem[rem_max_index], - rem_max_index, - ); - let sim_result = unsafe { core_argminmax_256(sim, rem.len()) }; - find_final_index_minmax(rem_result, sim_result) - } - (Some(rem), None) => { - let (rem_min_index, rem_max_index) = simple_argminmax(rem); - Some((rem_min_index, rem_max_index)) - } - (None, Some(sim)) => { - let sim_result = unsafe { core_argminmax_256(sim, 0) }; - Some((sim_result.1, sim_result.3)) - } - (None, None) => None, - } + argminmax_generic(arr, LANE_SIZE, core_argminmax_256) } #[inline] @@ -35,6 +16,7 @@ fn reg_to_u16_arr(reg: __m256i) -> [u16; 8] { unsafe { std::mem::transmute::<__m256i, [u16; 8]>(reg) } } +#[inline] #[target_feature(enable = "avx2")] unsafe fn core_argminmax_256(sim_arr: ArrayView1, offset: usize) -> (u16, usize, u16, usize) { // Efficient calculation of argmin and argmax together @@ -87,10 +69,14 @@ unsafe fn core_argminmax_256(sim_arr: ArrayView1, offset: usize) -> (u16, u #[cfg(test)] mod tests { - use super::{argminmax_u16, simple_argminmax}; + use super::argminmax_u32; + use crate::generic; + use generic::scalar_argminmax; + use ndarray::Array1; - use rand::{thread_rng, Rng}; - use rand_distr::Uniform; + + extern crate dev_utils; + use dev_utils::utils; // TODO: duplicate code in bench config fn get_array_u16(n: usize) -> Array1 { @@ -105,8 +91,8 @@ mod tests { let data = get_array_u16(1025); assert_eq!(data.len() % 8, 1); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); - let (argmin_simd_index, argmax_simd_index) = argminmax_u16(data.view()).unwrap(); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); + let (argmin_simd_index, argmax_simd_index) = argminmax_u16(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); } @@ -123,11 +109,11 @@ mod tests { let data: Vec = data.iter().map(|x| *x).collect(); let data = Array1::from(data); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); assert_eq!(argmin_index, 3); assert_eq!(argmax_index, 1); - let (argmin_simd_index, argmax_simd_index) = argminmax_u16(data.view()).unwrap(); + let (argmin_simd_index, argmax_simd_index) = argminmax_u16(data.view()); assert_eq!(argmin_simd_index, 3); assert_eq!(argmax_simd_index, 1); } @@ -136,8 +122,8 @@ mod tests { fn test_many_random_runs() { for _ in 0..10_000 { let data = get_array_u16(32 * 8 + 1); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); - let (argmin_simd_index, argmax_simd_index) = argminmax_u16(data.view()).unwrap(); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); + let (argmin_simd_index, argmax_simd_index) = argminmax_u16(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); } diff --git a/src/task.rs b/src/task.rs index 825c214..573e280 100644 --- a/src/task.rs +++ b/src/task.rs @@ -1,8 +1,41 @@ +use crate::scalar_generic::scalar_argminmax; // TODO: dit in macro doorgeven + use ndarray::{ArrayView1, Axis}; use std::cmp::Ordering; #[inline] -pub(crate) fn split_array( +pub(crate) fn argminmax_generic( + arr: ArrayView1, + lane_size: usize, + core_argminmax: unsafe fn(ArrayView1, usize) -> (T, usize, T, usize), +) -> (usize, usize) { + assert!(!arr.is_empty()); // split_array should never return (None, None) + match split_array(arr, lane_size) { + (Some(rem), Some(sim)) => { + let (rem_min_index, rem_max_index) = scalar_argminmax(rem); + let rem_result = ( + rem[rem_min_index], + rem_min_index, + rem[rem_max_index], + rem_max_index, + ); + let sim_result = unsafe { core_argminmax(sim, rem.len()) }; + find_final_index_minmax(rem_result, sim_result) + } + (Some(rem), None) => { + let (rem_min_index, rem_max_index) = scalar_argminmax(rem); + (rem_min_index, rem_max_index) + } + (None, Some(sim)) => { + let sim_result = unsafe { core_argminmax(sim, 0) }; + (sim_result.1, sim_result.3) + } + (None, None) => panic!("Array is empty"), // Should never occur because of assert + } +} + +#[inline] +fn split_array( arr: ArrayView1, lane_size: usize, ) -> (Option>, Option>) { @@ -23,10 +56,10 @@ pub(crate) fn split_array( } #[inline] -pub fn find_final_index_minmax( +fn find_final_index_minmax( remainder_result: (T, usize, T, usize), simd_result: (T, usize, T, usize), -) -> Option<(usize, usize)> { +) -> (usize, usize) { let min_result = match remainder_result.0.partial_cmp(&simd_result.0).unwrap() { Ordering::Less => remainder_result.1, Ordering::Equal => std::cmp::min(remainder_result.1, simd_result.1), @@ -39,5 +72,5 @@ pub fn find_final_index_minmax( Ordering::Greater => simd_result.3, }; - Some((min_result, max_result)) + (min_result, max_result) } diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..0038b11 --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,38 @@ +// ------ On &[T] + +// Note: these two functions are necessary because in the final SIMD registers the +// indexes are not in sorted order - this means that the first index (in the SIMD +// registers) is not necessarily the lowest min / max index when the min / max value +// occurs multiple times. + +#[inline] +pub(crate) fn min_index_value(index: &[T], values: &[T]) -> (T, T) { + assert_eq!(index.len(), values.len()); + let mut min_index: usize = 0; + let mut min_value = values[min_index]; + for (i, value) in values.iter().skip(1).enumerate() { + if *value < min_value { + min_value = *value; + min_index = i + 1; + } else if *value == min_value && index[i + 1] < index[min_index] { + min_index = i + 1; + } + } + (index[min_index], min_value) +} + +#[inline] +pub(crate) fn max_index_value(index: &[T], values: &[T]) -> (T, T) { + assert_eq!(index.len(), values.len()); + let mut max_index: usize = 0; + let mut max_value = values[max_index]; + for (i, value) in values.iter().skip(1).enumerate() { + if *value > max_value { + max_value = *value; + max_index = i + 1; + } else if *value == max_value && index[i + 1] < index[max_index] { + max_index = i + 1; + } + } + (index[max_index], max_value) +}