diff --git a/Cargo.toml b/Cargo.toml index 5581ec4..cbd0854 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,21 +22,38 @@ arrow = { version = "*", default-features = false, optional = true} default = ["half"] # TODO: remove this as default feature as soon as https://github.com/CodSpeedHQ/codspeed-rust/issues/1 is fixed [dev-dependencies] +# rstest = { version = "0.16", default-features = false} +# rstest_reuse = "0.5" codspeed-criterion-compat = "1.0.1" criterion = "0.3.1" dev_utils = { path = "dev_utils" } [[bench]] -name = "bench_f16" +name = "bench_f16_return_nan" +harness = false +required-features = ["half"] + +# TODO: support this +# [[bench]] +# name = "bench_f16_ignore_nan" +# harness = false +# required-features = ["half"] + +[[bench]] +name = "bench_f32_return_nan" +harness = false + +[[bench]] +name = "bench_f32_ignore_nan" harness = false [[bench]] -name = "bench_f32" +name = "bench_f64_return_nan" harness = false [[bench]] -name = "bench_f64" +name = "bench_f64_ignore_nan" harness = false [[bench]] diff --git a/benches/bench_f16.rs b/benches/bench_f16.rs deleted file mode 100644 index 064b5ae..0000000 --- a/benches/bench_f16.rs +++ /dev/null @@ -1,214 +0,0 @@ -#![feature(stdsimd)] - -extern crate dev_utils; - -#[cfg(feature = "half")] -use argminmax::ArgMinMax; -use codspeed_criterion_compat::*; -use dev_utils::{config, utils}; - -use argminmax::{ScalarArgMinMax, SCALAR}; -#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -use argminmax::{AVX2, AVX512, SIMD, SSE}; -#[cfg(any(target_arch = "arm", target_arch = "aarch64"))] -use argminmax::{NEON, SIMD}; - -#[cfg(feature = "half")] -use half::f16; - -#[cfg(feature = "half")] -fn get_random_f16_array(n: usize) -> Vec { - let data = utils::get_random_array::(n, u16::MIN, u16::MAX); - let data: Vec = data.iter().map(|&x| f16::from_bits(x)).collect(); - // Replace NaNs and Infs with 0 - let data: Vec = data - .iter() - .map(|&x| { - if x.is_nan() || x.is_infinite() { - f16::from_bits(0) - } else { - x - } - }) - .collect(); - data -} - -#[cfg(feature = "half")] -fn minmax_f16_random_array_long(c: &mut Criterion) { - let n = config::ARRAY_LENGTH_LONG; - let data: &[f16] = &get_random_f16_array(n); - c.bench_function("scalar_random_long_f16", |b| { - b.iter(|| SCALAR::argminmax(black_box(data))) - }); - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("sse4.1") { - c.bench_function("sse_random_long_f16", |b| { - b.iter(|| unsafe { SSE::argminmax(black_box(data)) }) - }); - } - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("avx2") { - c.bench_function("avx2_random_long_f16", |b| { - b.iter(|| unsafe { AVX2::argminmax(black_box(data)) }) - }); - } - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("avx512bw") { - c.bench_function("avx512_random_long_f16", |b| { - b.iter(|| unsafe { AVX512::argminmax(black_box(data)) }) - }); - } - #[cfg(target_arch = "arm")] - if std::arch::is_arm_feature_detected!("neon") { - c.bench_function("neon_random_long_f16", |b| { - b.iter(|| unsafe { NEON::argminmax(black_box(data)) }) - }); - } - #[cfg(target_arch = "aarch64")] - if std::arch::is_aarch64_feature_detected!("neon") { - c.bench_function("neon_random_long_f16", |b| { - b.iter(|| unsafe { NEON::argminmax(black_box(data)) }) - }); - } - c.bench_function("impl_random_long_f16", |b| { - b.iter(|| black_box(data.argminmax())) - }); -} - -#[cfg(feature = "half")] -fn minmax_f16_random_array_short(c: &mut Criterion) { - let n = 
config::ARRAY_LENGTH_SHORT; - let data: &[f16] = &get_random_f16_array(n); - c.bench_function("scalar_random_short_f16", |b| { - b.iter(|| SCALAR::argminmax(black_box(data))) - }); - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("sse4.1") { - c.bench_function("sse_random_short_f16", |b| { - b.iter(|| unsafe { SSE::argminmax(black_box(data)) }) - }); - } - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("avx2") { - c.bench_function("avx2_random_short_f16", |b| { - b.iter(|| unsafe { AVX2::argminmax(black_box(data)) }) - }); - } - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("avx512bw") { - c.bench_function("avx512_random_short_f16", |b| { - b.iter(|| unsafe { AVX512::argminmax(black_box(data)) }) - }); - } - #[cfg(target_arch = "arm")] - if std::arch::is_arm_feature_detected!("neon") { - c.bench_function("neon_random_short_f16", |b| { - b.iter(|| unsafe { NEON::argminmax(black_box(data)) }) - }); - } - #[cfg(target_arch = "aarch64")] - if std::arch::is_aarch64_feature_detected!("neon") { - c.bench_function("neon_random_short_f16", |b| { - b.iter(|| unsafe { NEON::argminmax(black_box(data)) }) - }); - } - c.bench_function("impl_random_short_f16", |b| { - b.iter(|| black_box(data.argminmax())) - }); -} - -#[cfg(feature = "half")] -fn minmax_f16_worst_case_array_long(c: &mut Criterion) { - let n = config::ARRAY_LENGTH_LONG; - let data: &[f16] = &utils::get_worst_case_array::(n, f16::from_f32(1.)); - c.bench_function("scalar_worst_long_f16", |b| { - b.iter(|| SCALAR::argminmax(black_box(data))) - }); - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("sse4.1") { - c.bench_function("sse_worst_long_f16", |b| { - b.iter(|| unsafe { SSE::argminmax(black_box(data)) }) - }); - } - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("avx2") { - c.bench_function("avx2_worst_long_f16", |b| { - b.iter(|| unsafe { AVX2::argminmax(black_box(data)) }) - }); - } - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("avx512bw") { - c.bench_function("avx512_worst_long_f16", |b| { - b.iter(|| unsafe { AVX512::argminmax(black_box(data)) }) - }); - } - #[cfg(target_arch = "arm")] - if std::arch::is_arm_feature_detected!("neon") { - c.bench_function("neon_worst_long_f16", |b| { - b.iter(|| unsafe { NEON::argminmax(black_box(data)) }) - }); - } - #[cfg(target_arch = "aarch64")] - if std::arch::is_aarch64_feature_detected!("neon") { - c.bench_function("neon_worst_long_f16", |b| { - b.iter(|| unsafe { NEON::argminmax(black_box(data)) }) - }); - } - c.bench_function("impl_worst_long_f16", |b| { - b.iter(|| black_box(data.argminmax())) - }); -} - -#[cfg(feature = "half")] -fn minmax_f16_worst_case_array_short(c: &mut Criterion) { - let n = config::ARRAY_LENGTH_SHORT; - let data: &[f16] = &utils::get_worst_case_array::(n, f16::from_f32(1.)); - c.bench_function("scalar_worst_short_f16", |b| { - b.iter(|| SCALAR::argminmax(black_box(data))) - }); - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("sse4.1") { - c.bench_function("sse_worst_short_f16", |b| { - b.iter(|| unsafe { SSE::argminmax(black_box(data)) }) - }); - } - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("avx2") { - c.bench_function("avx2_worst_short_f16", |b| { - b.iter(|| unsafe { AVX2::argminmax(black_box(data)) }) - }); - } - #[cfg(any(target_arch = 
"x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("avx512bw") { - c.bench_function("avx512_worst_short_f16", |b| { - b.iter(|| unsafe { AVX512::argminmax(black_box(data)) }) - }); - } - #[cfg(target_arch = "arm")] - if std::arch::is_arm_feature_detected!("neon") { - c.bench_function("neon_worst_short_f16", |b| { - b.iter(|| unsafe { NEON::argminmax(black_box(data)) }) - }); - } - #[cfg(target_arch = "aarch64")] - if std::arch::is_aarch64_feature_detected!("neon") { - c.bench_function("neon_worst_short_f16", |b| { - b.iter(|| unsafe { NEON::argminmax(black_box(data)) }) - }); - } - c.bench_function("impl_worst_short_f16", |b| { - b.iter(|| black_box(data.argminmax())) - }); -} - -#[cfg(feature = "half")] -criterion_group!( - benches, - minmax_f16_random_array_long, - // minmax_f16_random_array_short, - // minmax_f16_worst_case_array_long, - // minmax_f16_worst_case_array_short -); -#[cfg(feature = "half")] -criterion_main!(benches); diff --git a/benches/bench_f16_return_nan.rs b/benches/bench_f16_return_nan.rs new file mode 100644 index 0000000..2fa492f --- /dev/null +++ b/benches/bench_f16_return_nan.rs @@ -0,0 +1,83 @@ +#![feature(stdsimd)] + +extern crate dev_utils; + +#[cfg(feature = "half")] +use argminmax::ArgMinMax; +use codspeed_criterion_compat::*; +use dev_utils::{config, utils}; + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +use argminmax::{SIMDArgMinMax, AVX2, AVX512, SSE}; +#[cfg(any(target_arch = "arm", target_arch = "aarch64"))] +use argminmax::{SIMDArgMinMax, NEON}; +use argminmax::{ScalarArgMinMax, SCALAR}; + +#[cfg(feature = "half")] +use half::f16; + +#[cfg(feature = "half")] +fn get_random_f16_array(n: usize) -> Vec { + let data = utils::get_random_array::(n, u16::MIN, u16::MAX); + let data: Vec = data.iter().map(|&x| f16::from_bits(x)).collect(); + // Replace NaNs and Infs with 0 + let data: Vec = data + .iter() + .map(|&x| { + if x.is_nan() || x.is_infinite() { + f16::from_bits(0) + } else { + x + } + }) + .collect(); + data +} + +// TODO: rename _random_long_ to _nanargminmax_ +#[cfg(feature = "half")] +fn nanargminmax_f16_random_array_long(c: &mut Criterion) { + let n = config::ARRAY_LENGTH_LONG; + let data: &[f16] = &get_random_f16_array(n); + c.bench_function("scalar_random_long_f16", |b| { + b.iter(|| SCALAR::argminmax(black_box(data))) + }); + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + if is_x86_feature_detected!("sse4.1") { + c.bench_function("sse_random_long_f16", |b| { + b.iter(|| unsafe { SSE::argminmax(black_box(data)) }) + }); + } + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + if is_x86_feature_detected!("avx2") { + c.bench_function("avx2_random_long_f16", |b| { + b.iter(|| unsafe { AVX2::argminmax(black_box(data)) }) + }); + } + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + if is_x86_feature_detected!("avx512bw") { + c.bench_function("avx512_random_long_f16", |b| { + b.iter(|| unsafe { AVX512::argminmax(black_box(data)) }) + }); + } + #[cfg(target_arch = "arm")] + if std::arch::is_arm_feature_detected!("neon") { + c.bench_function("neon_random_long_f16", |b| { + b.iter(|| unsafe { NEON::argminmax(black_box(data)) }) + }); + } + #[cfg(target_arch = "aarch64")] + if std::arch::is_aarch64_feature_detected!("neon") { + c.bench_function("neon_random_long_f16", |b| { + b.iter(|| unsafe { NEON::argminmax(black_box(data)) }) + }); + } + c.bench_function("impl_random_long_f16", |b| { + b.iter(|| black_box(data.nanargminmax())) + }); +} + +#[cfg(feature = "half")] +criterion_group!(benches, 
nanargminmax_f16_random_array_long,); +#[cfg(feature = "half")] +criterion_main!(benches); diff --git a/benches/bench_f32.rs b/benches/bench_f32.rs deleted file mode 100644 index f71b815..0000000 --- a/benches/bench_f32.rs +++ /dev/null @@ -1,187 +0,0 @@ -#![feature(stdsimd)] - -extern crate dev_utils; - -#[cfg(feature = "half")] -use argminmax::ArgMinMax; -use codspeed_criterion_compat::*; -use dev_utils::{config, utils}; - -use argminmax::{ScalarArgMinMax, SCALAR}; -#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -use argminmax::{AVX2, AVX512, SIMD, SSE}; -#[cfg(any(target_arch = "arm", target_arch = "aarch64"))] -use argminmax::{NEON, SIMD}; - -fn minmax_f32_random_array_long(c: &mut Criterion) { - let n = config::ARRAY_LENGTH_LONG; - let data: &[f32] = &utils::get_random_array::(n, f32::MIN, f32::MAX); - c.bench_function("scalar_random_long_f32", |b| { - b.iter(|| SCALAR::argminmax(black_box(data))) - }); - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("sse4.1") { - c.bench_function("sse_random_long_f32", |b| { - b.iter(|| unsafe { SSE::argminmax(black_box(data)) }) - }); - } - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("avx") { - c.bench_function("avx_random_long_f32", |b| { - b.iter(|| unsafe { AVX2::argminmax(black_box(data)) }) - }); - } - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("avx512f") { - c.bench_function("avx512_random_long_f32", |b| { - b.iter(|| unsafe { AVX512::argminmax(black_box(data)) }) - }); - } - #[cfg(target_arch = "arm")] - if std::arch::is_arm_feature_detected!("neon") { - c.bench_function("neon_random_long_f32", |b| { - b.iter(|| unsafe { NEON::argminmax(black_box(data)) }) - }); - } - #[cfg(target_arch = "aarch64")] - if std::arch::is_aarch64_feature_detected!("neon") { - c.bench_function("neon_random_long_f32", |b| { - b.iter(|| unsafe { NEON::argminmax(black_box(data)) }) - }); - } - c.bench_function("impl_random_long_f32", |b| { - b.iter(|| black_box(data.argminmax())) - }); -} - -fn minmax_f32_random_array_short(c: &mut Criterion) { - let n = config::ARRAY_LENGTH_SHORT; - let data: &[f32] = &utils::get_random_array::(n, f32::MIN, f32::MAX); - c.bench_function("scalar_random_short_f32", |b| { - b.iter(|| SCALAR::argminmax(black_box(data))) - }); - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("sse4.1") { - c.bench_function("sse_random_short_f32", |b| { - b.iter(|| unsafe { SSE::argminmax(black_box(data)) }) - }); - } - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("avx") { - c.bench_function("avx_random_short_f32", |b| { - b.iter(|| unsafe { AVX2::argminmax(black_box(data)) }) - }); - } - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("avx512f") { - c.bench_function("avx512_random_short_f32", |b| { - b.iter(|| unsafe { AVX512::argminmax(black_box(data)) }) - }); - } - #[cfg(target_arch = "arm")] - if std::arch::is_arm_feature_detected!("neon") { - c.bench_function("neon_random_short_f32", |b| { - b.iter(|| unsafe { NEON::argminmax(black_box(data)) }) - }); - } - #[cfg(target_arch = "aarch64")] - if std::arch::is_aarch64_feature_detected!("neon") { - c.bench_function("neon_random_short_f32", |b| { - b.iter(|| unsafe { NEON::argminmax(black_box(data)) }) - }); - } - c.bench_function("impl_random_short_f32", |b| { - b.iter(|| black_box(data.argminmax())) - }); -} - -fn minmax_f32_worst_case_array_long(c: 
&mut Criterion) { - let n = config::ARRAY_LENGTH_LONG; - let data: &[f32] = &utils::get_worst_case_array::(n, 1.0); - c.bench_function("scalar_worst_long_f32", |b| { - b.iter(|| SCALAR::argminmax(black_box(data))) - }); - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("sse4.1") { - c.bench_function("sse_worst_long_f32", |b| { - b.iter(|| unsafe { SSE::argminmax(black_box(data)) }) - }); - } - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("avx") { - c.bench_function("avx_worst_long_f32", |b| { - b.iter(|| unsafe { AVX2::argminmax(black_box(data)) }) - }); - } - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("avx512f") { - c.bench_function("avx512_worst_long_f32", |b| { - b.iter(|| unsafe { AVX512::argminmax(black_box(data)) }) - }); - } - #[cfg(target_arch = "arm")] - if std::arch::is_arm_feature_detected!("neon") { - c.bench_function("neon_worst_long_f32", |b| { - b.iter(|| unsafe { NEON::argminmax(black_box(data)) }) - }); - } - #[cfg(target_arch = "aarch64")] - if std::arch::is_aarch64_feature_detected!("neon") { - c.bench_function("neon_worst_long_f32", |b| { - b.iter(|| unsafe { NEON::argminmax(black_box(data)) }) - }); - } - c.bench_function("impl_worst_long_f32", |b| { - b.iter(|| black_box(data.argminmax())) - }); -} - -fn minmax_f32_worst_case_array_short(c: &mut Criterion) { - let n = config::ARRAY_LENGTH_SHORT; - let data: &[f32] = &utils::get_worst_case_array::(n, 1.0); - c.bench_function("scalar_worst_short_f32", |b| { - b.iter(|| SCALAR::argminmax(black_box(data))) - }); - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("sse4.1") { - c.bench_function("sse_worst_short_f32", |b| { - b.iter(|| unsafe { SSE::argminmax(black_box(data)) }) - }); - } - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("avx") { - c.bench_function("avx_worst_short_f32", |b| { - b.iter(|| unsafe { AVX2::argminmax(black_box(data)) }) - }); - } - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("avx512f") { - c.bench_function("avx512_worst_short_f32", |b| { - b.iter(|| unsafe { AVX512::argminmax(black_box(data)) }) - }); - } - #[cfg(target_arch = "arm")] - if std::arch::is_arm_feature_detected!("neon") { - c.bench_function("neon_worst_short_f32", |b| { - b.iter(|| unsafe { NEON::argminmax(black_box(data)) }) - }); - } - #[cfg(target_arch = "aarch64")] - if std::arch::is_aarch64_feature_detected!("neon") { - c.bench_function("neon_worst_short_f32", |b| { - b.iter(|| unsafe { NEON::argminmax(black_box(data)) }) - }); - } - c.bench_function("impl_worst_short_f32", |b| { - b.iter(|| black_box(data.argminmax())) - }); -} - -criterion_group!( - benches, - minmax_f32_random_array_long, - // minmax_f32_random_array_short, - // minmax_f32_worst_case_array_long, - // minmax_f32_worst_case_array_short -); -criterion_main!(benches); diff --git a/benches/bench_f32_ignore_nan.rs b/benches/bench_f32_ignore_nan.rs new file mode 100644 index 0000000..411ab68 --- /dev/null +++ b/benches/bench_f32_ignore_nan.rs @@ -0,0 +1,57 @@ +#![feature(stdsimd)] + +extern crate dev_utils; + +use argminmax::ArgMinMax; +use codspeed_criterion_compat::*; +use dev_utils::{config, utils}; + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +use argminmax::{AVX2IgnoreNaN, AVX512IgnoreNaN, SIMDArgMinMaxIgnoreNaN, SSEIgnoreNaN}; +#[cfg(any(target_arch = "arm", target_arch = "aarch64"))] +use 
argminmax::{NEONIgnoreNaN, SIMDArgMinMaxIgnoreNaN}; +use argminmax::{SCALARIgnoreNaN, ScalarArgMinMax}; + +fn argminmax_f32_random_array_long(c: &mut Criterion) { + let n = config::ARRAY_LENGTH_LONG; + let data: &[f32] = &utils::get_random_array::(n, f32::MIN, f32::MAX); + c.bench_function("scalar_random_long_f32", |b| { + b.iter(|| SCALARIgnoreNaN::argminmax(black_box(data))) + }); + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + if is_x86_feature_detected!("sse4.1") { + c.bench_function("sse_random_long_f32", |b| { + b.iter(|| unsafe { SSEIgnoreNaN::argminmax(black_box(data)) }) + }); + } + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + if is_x86_feature_detected!("avx") { + c.bench_function("avx_random_long_f32", |b| { + b.iter(|| unsafe { AVX2IgnoreNaN::argminmax(black_box(data)) }) + }); + } + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + if is_x86_feature_detected!("avx512f") { + c.bench_function("avx512_random_long_f32", |b| { + b.iter(|| unsafe { AVX512IgnoreNaN::argminmax(black_box(data)) }) + }); + } + #[cfg(target_arch = "arm")] + if std::arch::is_arm_feature_detected!("neon") { + c.bench_function("neon_random_long_f32", |b| { + b.iter(|| unsafe { NEONIgnoreNaN::argminmax(black_box(data)) }) + }); + } + #[cfg(target_arch = "aarch64")] + if std::arch::is_aarch64_feature_detected!("neon") { + c.bench_function("neon_random_long_f32", |b| { + b.iter(|| unsafe { NEONIgnoreNaN::argminmax(black_box(data)) }) + }); + } + c.bench_function("impl_random_long_f32", |b| { + b.iter(|| black_box(data.argminmax())) + }); +} + +criterion_group!(benches, argminmax_f32_random_array_long,); +criterion_main!(benches); diff --git a/benches/bench_f32_return_nan.rs b/benches/bench_f32_return_nan.rs new file mode 100644 index 0000000..30c864d --- /dev/null +++ b/benches/bench_f32_return_nan.rs @@ -0,0 +1,57 @@ +#![feature(stdsimd)] + +extern crate dev_utils; + +use argminmax::ArgMinMax; +use codspeed_criterion_compat::*; +use dev_utils::{config, utils}; + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +use argminmax::{SIMDArgMinMax, AVX2, AVX512, SSE}; +#[cfg(any(target_arch = "arm", target_arch = "aarch64"))] +use argminmax::{SIMDArgMinMax, NEON}; +use argminmax::{ScalarArgMinMax, SCALAR}; + +fn nanargminmax_f32_random_array_long(c: &mut Criterion) { + let n = config::ARRAY_LENGTH_LONG; + let data: &[f32] = &utils::get_random_array::(n, f32::MIN, f32::MAX); + c.bench_function("scalar_nanargminmax_f32", |b| { + b.iter(|| SCALAR::argminmax(black_box(data))) + }); + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + if is_x86_feature_detected!("sse4.1") { + c.bench_function("sse_nanargminmax_f32", |b| { + b.iter(|| unsafe { SSE::argminmax(black_box(data)) }) + }); + } + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + if is_x86_feature_detected!("avx2") { + c.bench_function("avx2_nanargminmax_f32", |b| { + b.iter(|| unsafe { AVX2::argminmax(black_box(data)) }) + }); + } + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + if is_x86_feature_detected!("avx512f") { + c.bench_function("avx512_nanargminmax_f32", |b| { + b.iter(|| unsafe { AVX512::argminmax(black_box(data)) }) + }); + } + #[cfg(target_arch = "arm")] + if std::arch::is_arm_feature_detected!("neon") { + c.bench_function("neon_nanargminmax_f32", |b| { + b.iter(|| unsafe { NEON::argminmax(black_box(data)) }) + }); + } + #[cfg(target_arch = "aarch64")] + if std::arch::is_aarch64_feature_detected!("neon") { + c.bench_function("neon_nanargminmax_f32", |b| { + b.iter(|| unsafe { 
NEON::argminmax(black_box(data)) }) + }); + } + c.bench_function("impl_nanargminmax_f32", |b| { + b.iter(|| black_box(data.nanargminmax())) + }); +} + +criterion_group!(benches, nanargminmax_f32_random_array_long,); +criterion_main!(benches); diff --git a/benches/bench_f64.rs b/benches/bench_f64.rs deleted file mode 100644 index cb895c2..0000000 --- a/benches/bench_f64.rs +++ /dev/null @@ -1,137 +0,0 @@ -#![feature(stdsimd)] - -extern crate dev_utils; - -#[cfg(feature = "half")] -use argminmax::ArgMinMax; -use codspeed_criterion_compat::*; -use dev_utils::{config, utils}; - -use argminmax::{ScalarArgMinMax, SCALAR}; -#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -use argminmax::{AVX2, AVX512, SIMD, SSE}; - -fn minmax_f64_random_array_long(c: &mut Criterion) { - let n = config::ARRAY_LENGTH_LONG; - let data: &[f64] = &utils::get_random_array::(n, f64::MIN, f64::MAX); - c.bench_function("scalar_random_long_f64", |b| { - b.iter(|| SCALAR::argminmax(black_box(data))) - }); - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("sse4.1") { - c.bench_function("sse_random_long_f64", |b| { - b.iter(|| unsafe { SSE::argminmax(black_box(data)) }) - }); - } - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("avx") { - c.bench_function("avx_random_long_f64", |b| { - b.iter(|| unsafe { AVX2::argminmax(black_box(data)) }) - }); - } - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("avx512f") { - c.bench_function("avx512_random_long_f64", |b| { - b.iter(|| unsafe { AVX512::argminmax(black_box(data)) }) - }); - } - c.bench_function("impl_random_long_f64", |b| { - b.iter(|| black_box(data.argminmax())) - }); -} - -fn minmax_f64_random_array_short(c: &mut Criterion) { - let n = config::ARRAY_LENGTH_SHORT; - let data: &[f64] = &utils::get_random_array::(n, f64::MIN, f64::MAX); - c.bench_function("scalar_random_short_f64", |b| { - b.iter(|| SCALAR::argminmax(black_box(data))) - }); - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("sse4.1") { - c.bench_function("sse_random_short_f64", |b| { - b.iter(|| unsafe { SSE::argminmax(black_box(data)) }) - }); - } - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("avx") { - c.bench_function("avx_random_short_f64", |b| { - b.iter(|| unsafe { AVX2::argminmax(black_box(data)) }) - }); - } - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("avx512f") { - c.bench_function("avx512_random_short_f64", |b| { - b.iter(|| unsafe { AVX512::argminmax(black_box(data)) }) - }); - } - c.bench_function("impl_random_short_f64", |b| { - b.iter(|| black_box(data.argminmax())) - }); -} - -fn minmax_f64_worst_case_array_long(c: &mut Criterion) { - let n = config::ARRAY_LENGTH_LONG; - let data: &[f64] = &utils::get_worst_case_array::(n, 1.0); - c.bench_function("scalar_worst_long_f64", |b| { - b.iter(|| SCALAR::argminmax(black_box(data))) - }); - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("sse4.1") { - c.bench_function("sse_worst_long_f64", |b| { - b.iter(|| unsafe { SSE::argminmax(black_box(data)) }) - }); - } - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("avx") { - c.bench_function("avx_worst_long_f64", |b| { - b.iter(|| unsafe { AVX2::argminmax(black_box(data)) }) - }); - } - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if 
is_x86_feature_detected!("avx512f") { - c.bench_function("avx512_worst_long_f64", |b| { - b.iter(|| unsafe { AVX512::argminmax(black_box(data)) }) - }); - } - c.bench_function("impl_worst_long_f64", |b| { - b.iter(|| black_box(data.argminmax())) - }); -} - -fn minmax_f64_worst_case_array_short(c: &mut Criterion) { - let n = config::ARRAY_LENGTH_SHORT; - let data: &[f64] = &utils::get_worst_case_array::(n, 1.0); - c.bench_function("scalar_worst_short_f64", |b| { - b.iter(|| SCALAR::argminmax(black_box(data))) - }); - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("sse4.1") { - c.bench_function("sse_worst_short_f64", |b| { - b.iter(|| unsafe { SSE::argminmax(black_box(data)) }) - }); - } - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("avx") { - c.bench_function("avx_worst_short_f64", |b| { - b.iter(|| unsafe { AVX2::argminmax(black_box(data)) }) - }); - } - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if is_x86_feature_detected!("avx512f") { - c.bench_function("avx512_worst_short_f64", |b| { - b.iter(|| unsafe { AVX512::argminmax(black_box(data)) }) - }); - } - c.bench_function("impl_worst_short_f64", |b| { - b.iter(|| black_box(data.argminmax())) - }); -} - -criterion_group!( - benches, - minmax_f64_random_array_long, - // minmax_f64_random_array_short, - // minmax_f64_worst_case_array_long, - // minmax_f64_worst_case_array_short -); -criterion_main!(benches); diff --git a/benches/bench_f64_ignore_nan.rs b/benches/bench_f64_ignore_nan.rs new file mode 100644 index 0000000..64b0202 --- /dev/null +++ b/benches/bench_f64_ignore_nan.rs @@ -0,0 +1,43 @@ +#![feature(stdsimd)] + +extern crate dev_utils; + +use argminmax::ArgMinMax; +use codspeed_criterion_compat::*; +use dev_utils::{config, utils}; + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +use argminmax::{AVX2IgnoreNaN, AVX512IgnoreNaN, SIMDArgMinMaxIgnoreNaN, SSEIgnoreNaN}; +use argminmax::{SCALARIgnoreNaN, ScalarArgMinMax}; + +fn argminmax_f64_random_array_long(c: &mut Criterion) { + let n = config::ARRAY_LENGTH_LONG; + let data: &[f64] = &utils::get_random_array::(n, f64::MIN, f64::MAX); + c.bench_function("scalar_random_long_f64", |b| { + b.iter(|| SCALARIgnoreNaN::argminmax(black_box(data))) + }); + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + if is_x86_feature_detected!("sse4.1") { + c.bench_function("sse_random_long_f64", |b| { + b.iter(|| unsafe { SSEIgnoreNaN::argminmax(black_box(data)) }) + }); + } + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + if is_x86_feature_detected!("avx") { + c.bench_function("avx_random_long_f64", |b| { + b.iter(|| unsafe { AVX2IgnoreNaN::argminmax(black_box(data)) }) + }); + } + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + if is_x86_feature_detected!("avx512f") { + c.bench_function("avx512_random_long_f64", |b| { + b.iter(|| unsafe { AVX512IgnoreNaN::argminmax(black_box(data)) }) + }); + } + c.bench_function("impl_random_long_f64", |b| { + b.iter(|| black_box(data.argminmax())) + }); +} + +criterion_group!(benches, argminmax_f64_random_array_long,); +criterion_main!(benches); diff --git a/benches/bench_f64_return_nan.rs b/benches/bench_f64_return_nan.rs new file mode 100644 index 0000000..5611869 --- /dev/null +++ b/benches/bench_f64_return_nan.rs @@ -0,0 +1,43 @@ +#![feature(stdsimd)] + +extern crate dev_utils; + +use argminmax::ArgMinMax; +use codspeed_criterion_compat::*; +use dev_utils::{config, utils}; + +#[cfg(any(target_arch = "x86", target_arch = 
"x86_64"))] +use argminmax::{SIMDArgMinMax, AVX2, AVX512, SSE}; +use argminmax::{ScalarArgMinMax, SCALAR}; + +fn nanargminmax_f64_random_array_long(c: &mut Criterion) { + let n = config::ARRAY_LENGTH_LONG; + let data: &[f64] = &utils::get_random_array::(n, f64::MIN, f64::MAX); + c.bench_function("scalar_nanargminmax_f64", |b| { + b.iter(|| SCALAR::argminmax(black_box(data))) + }); + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + if is_x86_feature_detected!("sse4.2") { + c.bench_function("sse_nanargminmax_f64", |b| { + b.iter(|| unsafe { SSE::argminmax(black_box(data)) }) + }); + } + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + if is_x86_feature_detected!("avx2") { + c.bench_function("avx2_nanargminmax_f64", |b| { + b.iter(|| unsafe { AVX2::argminmax(black_box(data)) }) + }); + } + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + if is_x86_feature_detected!("avx512f") { + c.bench_function("avx512_nanargminmax_f64", |b| { + b.iter(|| unsafe { AVX512::argminmax(black_box(data)) }) + }); + } + c.bench_function("impl_nanargminmax_f64", |b| { + b.iter(|| black_box(data.nanargminmax())) + }); +} + +criterion_group!(benches, nanargminmax_f64_random_array_long,); +criterion_main!(benches); diff --git a/benches/bench_i16.rs b/benches/bench_i16.rs index 9f27874..9abee64 100644 --- a/benches/bench_i16.rs +++ b/benches/bench_i16.rs @@ -2,16 +2,15 @@ extern crate dev_utils; -#[cfg(feature = "half")] use argminmax::ArgMinMax; use codspeed_criterion_compat::*; use dev_utils::{config, utils}; -use argminmax::{ScalarArgMinMax, SCALAR}; #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -use argminmax::{AVX2, AVX512, SIMD, SSE}; +use argminmax::{SIMDArgMinMax, AVX2, AVX512, SSE}; #[cfg(any(target_arch = "arm", target_arch = "aarch64"))] -use argminmax::{NEON, SIMD}; +use argminmax::{SIMDArgMinMax, NEON}; +use argminmax::{ScalarArgMinMax, SCALAR}; fn minmax_i16_random_array_long(c: &mut Criterion) { let n = config::ARRAY_LENGTH_LONG; diff --git a/benches/bench_i32.rs b/benches/bench_i32.rs index b347dbc..67e35d7 100644 --- a/benches/bench_i32.rs +++ b/benches/bench_i32.rs @@ -2,16 +2,15 @@ extern crate dev_utils; -#[cfg(feature = "half")] use argminmax::ArgMinMax; use codspeed_criterion_compat::*; use dev_utils::{config, utils}; -use argminmax::{ScalarArgMinMax, SCALAR}; #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -use argminmax::{AVX2, AVX512, SIMD, SSE}; +use argminmax::{SIMDArgMinMax, AVX2, AVX512, SSE}; #[cfg(any(target_arch = "arm", target_arch = "aarch64"))] -use argminmax::{NEON, SIMD}; +use argminmax::{SIMDArgMinMax, NEON}; +use argminmax::{ScalarArgMinMax, SCALAR}; fn minmax_i32_random_array_long(c: &mut Criterion) { let n = config::ARRAY_LENGTH_LONG; diff --git a/benches/bench_i64.rs b/benches/bench_i64.rs index 7ff24aa..3e71a2f 100644 --- a/benches/bench_i64.rs +++ b/benches/bench_i64.rs @@ -2,14 +2,13 @@ extern crate dev_utils; -#[cfg(feature = "half")] use argminmax::ArgMinMax; use codspeed_criterion_compat::*; use dev_utils::{config, utils}; -use argminmax::{ScalarArgMinMax, SCALAR}; #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -use argminmax::{AVX2, AVX512, SIMD, SSE}; +use argminmax::{SIMDArgMinMax, AVX2, AVX512, SSE}; +use argminmax::{ScalarArgMinMax, SCALAR}; fn minmax_i64_random_array_long(c: &mut Criterion) { let n = config::ARRAY_LENGTH_LONG; diff --git a/benches/bench_i8.rs b/benches/bench_i8.rs index c252fe4..e091a51 100644 --- a/benches/bench_i8.rs +++ b/benches/bench_i8.rs @@ -2,16 +2,15 @@ extern crate dev_utils; 
-#[cfg(feature = "half")] use argminmax::ArgMinMax; use codspeed_criterion_compat::*; use dev_utils::{config, utils}; -use argminmax::{ScalarArgMinMax, SCALAR}; #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -use argminmax::{AVX2, AVX512, SIMD, SSE}; +use argminmax::{SIMDArgMinMax, AVX2, AVX512, SSE}; #[cfg(any(target_arch = "arm", target_arch = "aarch64"))] -use argminmax::{NEON, SIMD}; +use argminmax::{SIMDArgMinMax, NEON}; +use argminmax::{ScalarArgMinMax, SCALAR}; fn minmax_i8_random_array_long(c: &mut Criterion) { let n = config::ARRAY_LENGTH_LONG; diff --git a/benches/bench_u16.rs b/benches/bench_u16.rs index c8c82ea..2262be8 100644 --- a/benches/bench_u16.rs +++ b/benches/bench_u16.rs @@ -2,16 +2,15 @@ extern crate dev_utils; -#[cfg(feature = "half")] use argminmax::ArgMinMax; use codspeed_criterion_compat::*; use dev_utils::{config, utils}; -use argminmax::{ScalarArgMinMax, SCALAR}; #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -use argminmax::{AVX2, AVX512, SIMD, SSE}; +use argminmax::{SIMDArgMinMax, AVX2, AVX512, SSE}; #[cfg(any(target_arch = "arm", target_arch = "aarch64"))] -use argminmax::{NEON, SIMD}; +use argminmax::{SIMDArgMinMax, NEON}; +use argminmax::{ScalarArgMinMax, SCALAR}; fn minmax_u16_random_array_long(c: &mut Criterion) { let n = config::ARRAY_LENGTH_LONG; diff --git a/benches/bench_u32.rs b/benches/bench_u32.rs index 76ab8d7..8602f87 100644 --- a/benches/bench_u32.rs +++ b/benches/bench_u32.rs @@ -2,16 +2,15 @@ extern crate dev_utils; -#[cfg(feature = "half")] use argminmax::ArgMinMax; use codspeed_criterion_compat::*; use dev_utils::{config, utils}; -use argminmax::{ScalarArgMinMax, SCALAR}; #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -use argminmax::{AVX2, AVX512, SIMD, SSE}; +use argminmax::{SIMDArgMinMax, AVX2, AVX512, SSE}; #[cfg(any(target_arch = "arm", target_arch = "aarch64"))] -use argminmax::{NEON, SIMD}; +use argminmax::{SIMDArgMinMax, NEON}; +use argminmax::{ScalarArgMinMax, SCALAR}; fn minmax_u32_random_array_long(c: &mut Criterion) { let n = config::ARRAY_LENGTH_LONG; diff --git a/benches/bench_u64.rs b/benches/bench_u64.rs index 0817359..91b14a3 100644 --- a/benches/bench_u64.rs +++ b/benches/bench_u64.rs @@ -2,14 +2,13 @@ extern crate dev_utils; -#[cfg(feature = "half")] use argminmax::ArgMinMax; use codspeed_criterion_compat::*; use dev_utils::{config, utils}; -use argminmax::{ScalarArgMinMax, SCALAR}; #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -use argminmax::{AVX2, AVX512, SIMD, SSE}; +use argminmax::{SIMDArgMinMax, AVX2, AVX512, SSE}; +use argminmax::{ScalarArgMinMax, SCALAR}; fn minmax_u64_random_array_long(c: &mut Criterion) { let n = config::ARRAY_LENGTH_LONG; diff --git a/benches/bench_u8.rs b/benches/bench_u8.rs index 2c74896..07bb610 100644 --- a/benches/bench_u8.rs +++ b/benches/bench_u8.rs @@ -2,16 +2,15 @@ extern crate dev_utils; -#[cfg(feature = "half")] use argminmax::ArgMinMax; use codspeed_criterion_compat::*; use dev_utils::{config, utils}; -use argminmax::{ScalarArgMinMax, SCALAR}; #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -use argminmax::{AVX2, AVX512, SIMD, SSE}; +use argminmax::{SIMDArgMinMax, AVX2, AVX512, SSE}; #[cfg(any(target_arch = "arm", target_arch = "aarch64"))] -use argminmax::{NEON, SIMD}; +use argminmax::{SIMDArgMinMax, NEON}; +use argminmax::{ScalarArgMinMax, SCALAR}; fn minmax_u8_random_array_long(c: &mut Criterion) { let n = config::ARRAY_LENGTH_LONG; diff --git a/src/lib.rs b/src/lib.rs index ee78409..d621419 100644 --- a/src/lib.rs +++ b/src/lib.rs 
@@ -8,8 +8,11 @@ mod scalar; mod simd; -pub use scalar::{ScalarArgMinMax, SCALAR}; -pub use simd::{AVX2, AVX512, NEON, SIMD, SSE}; +pub use scalar::{SCALARIgnoreNaN, ScalarArgMinMax, SCALAR}; +pub use simd::{ + AVX2IgnoreNaN, AVX512IgnoreNaN, NEONIgnoreNaN, SIMDArgMinMaxIgnoreNaN, SSEIgnoreNaN, +}; +pub use simd::{SIMDArgMinMax, AVX2, AVX512, NEON, SSE}; #[cfg(feature = "half")] use half::f16; @@ -22,131 +25,90 @@ pub trait ArgMinMax { // fn argmin(self) -> usize; // fn argmax(self) -> usize; + + /// Get the index of the minimum and maximum values in the array, ignoring NaNs. + /// This will only result in unexpected behavior if the array contains *only* NaNs + /// and infinities (in which case index 0 is returned for both). + /// Note that this differs from numpy, where the `argmin` and `argmax` functions + /// return the index of the first NaN (which is the behavior of our nanargminmax + /// function). fn argminmax(&self) -> (usize, usize); + + /// Get the index of the minimum and maximum values in the array. + /// If the array contains NaNs, the index of the first NaN is returned. + /// Note that this differs from numpy, where the `nanargmin` and `nanargmax` + /// functions ignore NaNs (which is the behavior of our argminmax function). + fn nanargminmax(&self) -> (usize, usize); } +// TODO: split this up +// pub trait NaNArgMinMax { +// fn nanargminmax(&self) -> (usize, usize); +// } + // ---- Helper macros ---- trait DTypeInfo { const NB_BITS: usize; - const IS_FLOAT: bool; } +/// Macro for implementing DTypeInfo for the passed data types (uints, ints, floats) macro_rules! impl_nb_bits { - ($is_float:expr, $($t:ty)*) => ($( - impl DTypeInfo for $t { - const NB_BITS: usize = std::mem::size_of::<$t>() * 8; - const IS_FLOAT: bool = $is_float; + // $data_type is the data type (e.g. i32) + // you can pass multiple types (separated by commas) to this macro + ($($data_type:ty)*) => ($( + impl DTypeInfo for $data_type { + const NB_BITS: usize = std::mem::size_of::<$data_type>() * 8; } )*) } -impl_nb_bits!(false, i8 i16 i32 i64 u8 u16 u32 u64); -impl_nb_bits!(true, f32 f64); +impl_nb_bits!(i8 i16 i32 i64 u8 u16 u32 u64); +impl_nb_bits!(f32 f64); #[cfg(feature = "half")] -impl_nb_bits!(true, f16); - -// use once_cell::sync::Lazy; - -// #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -// static AVX512BW_DETECTED: Lazy = Lazy::new(|| is_x86_feature_detected!("avx512bw")); -// #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -// static AVX512F_DETECTED: Lazy = Lazy::new(|| is_x86_feature_detected!("avx512f")); -// #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -// static AVX2_DETECTED: Lazy = Lazy::new(|| is_x86_feature_detected!("avx2")); -// #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -// static AVX_DETECTED: Lazy = Lazy::new(|| is_x86_feature_detected!("avx")); -// #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -// static SSE_DETECTED: Lazy = Lazy::new(|| is_x86_feature_detected!("sse4.1")); -// #[cfg(target_arch = "arm")] -// static NEON_DETECTED: Lazy = Lazy::new(|| std::arch::is_arm_feature_detected!("neon")); - -// macro_rules! 
impl_argminmax { -// ($($t:ty),*) => { -// $( -// impl ArgMinMax for ArrayView1<'_, $t> { -// fn argminmax(self) -> (usize, usize) { -// #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -// { -// if *AVX512BW_DETECTED & (<$t>::NB_BITS <= 16) { -// // BW (ByteWord) instructions are needed for 16-bit avx512 -// return unsafe { AVX512::argminmax(self) } -// } else if *AVX512F_DETECTED { // TODO: check if avx512bw is included in avx512f -// return unsafe { AVX512::argminmax(self) } -// } else if *AVX2_DETECTED { -// return unsafe { AVX2::argminmax(self) } -// } else if *AVX_DETECTED & (<$t>::NB_BITS >= 32) & (<$t>::IS_FLOAT == true) { -// // f32 and f64 do not require avx2 -// return unsafe { AVX2::argminmax(self) } -// // SKIP SSE4.2 bc scalar is faster or equivalent for 64 bit numbers -// // // } else if is_x86_feature_detected!("sse4.2") & (<$t>::NB_BITS == 64) & (<$t>::IS_FLOAT == false) { -// // // SSE4.2 is needed for comparing 64-bit integers -// // return unsafe { SSE::argminmax(self) } -// } else if *SSE_DETECTED & (<$t>::NB_BITS < 64) { -// // Scalar is faster for 64-bit numbers -// return unsafe { SSE::argminmax(self) } -// } -// } -// #[cfg(target_arch = "aarch64")] -// { -// // TODO: support aarch64 -// } -// #[cfg(target_arch = "arm")] -// { -// if *NEON_DETECTED & (<$t>::NB_BITS < 32) { -// // TODO: requires v7? -// // We miss some NEON instructions for 64-bit numbers -// return unsafe { NEON::argminmax(self) } -// } -// } -// SCALAR::argminmax(self) -// } -// } -// )* -// }; -// } +impl_nb_bits!(f16); // ------------------------------ &[T] ------------------------------ -macro_rules! impl_argminmax { - ($($t:ty),*) => { +/// Macro for implementing ArgMinMax for signed and unsigned integers +macro_rules! impl_argminmax_non_float { + // $int_type is the integer data type of the array (e.g. 
i32) + // you can pass multiple types (separated by commas) to this macro + ($($int_type:ty),*) => { $( - impl ArgMinMax for &[$t] { + impl ArgMinMax for &[$int_type] { fn argminmax(&self) -> (usize, usize) { #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] { - if is_x86_feature_detected!("sse4.1") & (<$t>::NB_BITS == 8) { + if is_x86_feature_detected!("sse4.1") & (<$int_type>::NB_BITS == 8) { // 8-bit numbers are best handled by SSE4.1 return unsafe { SSE::argminmax(self) } - } else if is_x86_feature_detected!("avx512bw") & (<$t>::NB_BITS <= 16) { + } else if is_x86_feature_detected!("avx512bw") & (<$int_type>::NB_BITS <= 16) { // BW (ByteWord) instructions are needed for 8 or 16-bit avx512 return unsafe { AVX512::argminmax(self) } } else if is_x86_feature_detected!("avx512f") { // TODO: check if avx512bw is included in avx512f return unsafe { AVX512::argminmax(self) } } else if is_x86_feature_detected!("avx2") { return unsafe { AVX2::argminmax(self) } - } else if is_x86_feature_detected!("avx") & (<$t>::NB_BITS >= 32) & (<$t>::IS_FLOAT == true) { - // f32 and f64 do not require avx2 - return unsafe { AVX2::argminmax(self) } // SKIP SSE4.2 bc scalar is faster or equivalent for 64 bit numbers - // // } else if is_x86_feature_detected!("sse4.2") & (<$t>::NB_BITS == 64) & (<$t>::IS_FLOAT == false) { + // // } else if is_x86_feature_detected!("sse4.2") & (<$int_type>::NB_BITS == 64) & (<$int_type>::IS_FLOAT == false) { // // SSE4.2 is needed for comparing 64-bit integers // return unsafe { SSE::argminmax(self) } - } else if is_x86_feature_detected!("sse4.1") & (<$t>::NB_BITS < 64) { + } else if is_x86_feature_detected!("sse4.1") & (<$int_type>::NB_BITS < 64) { // Scalar is faster for 64-bit numbers return unsafe { SSE::argminmax(self) } } } #[cfg(target_arch = "aarch64")] { - if std::arch::is_aarch64_feature_detected!("neon") & (<$t>::NB_BITS < 64) { + if std::arch::is_aarch64_feature_detected!("neon") & (<$int_type>::NB_BITS < 64) { // We miss some NEON instructions for 64-bit numbers return unsafe { NEON::argminmax(self) } } } #[cfg(target_arch = "arm")] { - if std::arch::is_arm_feature_detected!("neon") & (<$t>::NB_BITS < 64) { + if std::arch::is_arm_feature_detected!("neon") & (<$int_type>::NB_BITS < 64) { // TODO: requires v7? // We miss some NEON instructions for 64-bit numbers return unsafe { NEON::argminmax(self) } @@ -154,16 +116,104 @@ macro_rules! impl_argminmax { } SCALAR::argminmax(self) } + + // As there are no NaNs when NOT using floats -> just use argminmax + fn nanargminmax(&self) -> (usize, usize) { + self.argminmax() + } + } + )* + }; +} + +/// Macro for implementing ArgMinMax for floats +macro_rules! impl_argminmax_float { + // $float_type is the float data type of the array (e.g. 
f32) + // you can pass multiple types (separated by commas) to this macro + ($($float_type:ty),*) => { + $( + impl ArgMinMax for &[$float_type] { + fn nanargminmax(&self) -> (usize, usize) { + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + { + if is_x86_feature_detected!("sse4.1") & (<$float_type>::NB_BITS == 8) { + // 8-bit numbers are best handled by SSE4.1 + return unsafe { SSE::argminmax(self) } + } else if is_x86_feature_detected!("avx512bw") & (<$float_type>::NB_BITS <= 16) { + // BW (ByteWord) instructions are needed for 8 or 16-bit avx512 + return unsafe { AVX512::argminmax(self) } + } else if is_x86_feature_detected!("avx512f") { // TODO: check if avx512bw is included in avx512f + return unsafe { AVX512::argminmax(self) } + } else if is_x86_feature_detected!("avx2") { + return unsafe { AVX2::argminmax(self) } + // SKIP SSE4.2 bc scalar is faster or equivalent for 64 bit numbers + } else if is_x86_feature_detected!("sse4.1") & (<$float_type>::NB_BITS < 64) { + // Scalar is faster for 64-bit numbers + // TODO: double check this (observed different things for new float implementation) + return unsafe { SSE::argminmax(self) } + } + } + #[cfg(target_arch = "aarch64")] + { + if std::arch::is_aarch64_feature_detected!("neon") & (<$float_type>::NB_BITS < 64) { + // We miss some NEON instructions for 64-bit numbers + return unsafe { NEON::argminmax(self) } + } + } + #[cfg(target_arch = "arm")] + { + if std::arch::is_arm_feature_detected!("neon") & (<$float_type>::NB_BITS < 64) { + // TODO: requires v7? + // We miss some NEON instructions for 64-bit numbers + return unsafe { NEON::argminmax(self) } + } + } + SCALAR::argminmax(self) + } + fn argminmax(&self) -> (usize, usize) { + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + { + if <$float_type>::NB_BITS <= 16 { + // TODO: f16 IgnoreNaN is not yet SIMD-optimized + // do nothing (defaults to scalar) + } else if is_x86_feature_detected!("avx512f") { + return unsafe { AVX512IgnoreNaN::argminmax(self) } + } else if is_x86_feature_detected!("avx") { + // f32 and f64 do not require avx2 + return unsafe { AVX2IgnoreNaN::argminmax(self) } + } else if is_x86_feature_detected!("sse4.1") & (<$float_type>::NB_BITS < 64) { + // Scalar is faster for 64-bit numbers + return unsafe { SSEIgnoreNaN::argminmax(self) } + } + } + #[cfg(target_arch = "aarch64")] + { + if std::arch::is_aarch64_feature_detected!("neon") & (<$float_type>::NB_BITS < 64) { + // We miss some NEON instructions for 64-bit numbers + return unsafe { NEONIgnoreNaN::argminmax(self) } + } + } + #[cfg(target_arch = "arm")] + { + if std::arch::is_arm_feature_detected!("neon") & (<$float_type>::NB_BITS < 64) { + // TODO: requires v7? 
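
Behaviorally, the two methods this macro implements differ only in NaN handling (the surrounding branches merely pick the fastest backend at runtime, as documented on the `ArgMinMax` trait earlier in this file). A minimal usage sketch, assuming the crate is consumed as the `argminmax` dependency exactly as the benchmarks above do:

```rust
use argminmax::ArgMinMax;

fn main() {
    let data: &[f32] = &[3.0, f32::NAN, -1.5, 8.25, 0.0];

    // `argminmax` ignores the NaN at index 1: min is -1.5 (index 2), max is 8.25 (index 3).
    assert_eq!(data.argminmax(), (2, 3));

    // `nanargminmax` returns the index of the first NaN for both positions.
    assert_eq!(data.nanargminmax(), (1, 1));
}
```
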
+ // We miss some NEON instructions for 64-bit numbers + return unsafe { NEONIgnoreNaN::argminmax(self) } + } + } + SCALARIgnoreNaN::argminmax(self) + } } )* }; } // Implement ArgMinMax for the rust primitive types -impl_argminmax!(i8, i16, i32, i64, f32, f64, u8, u16, u32, u64); +impl_argminmax_non_float!(i8, i16, i32, i64, u8, u16, u32, u64); +impl_argminmax_float!(f32, f64); // Implement ArgMinMax for other data types #[cfg(feature = "half")] -impl_argminmax!(f16); +impl_argminmax_float!(f16); // ------------------------------ [T] ------------------------------ @@ -185,6 +235,9 @@ where fn argminmax(&self) -> (usize, usize) { self.as_slice().argminmax() } + fn nanargminmax(&self) -> (usize, usize) { + self.as_slice().nanargminmax() + } } // ----------------------- (optional) ndarray ---------------------- @@ -205,6 +258,9 @@ mod ndarray_impl { fn argminmax(&self) -> (usize, usize) { self.as_slice().unwrap().argminmax() } + fn nanargminmax(&self) -> (usize, usize) { + self.as_slice().unwrap().nanargminmax() + } } } @@ -225,5 +281,8 @@ mod arrow_impl { fn argminmax(&self) -> (usize, usize) { self.values().argminmax() } + fn nanargminmax(&self) -> (usize, usize) { + self.values().nanargminmax() + } } } diff --git a/src/scalar/generic.rs b/src/scalar/generic.rs index f16963c..8a090ff 100644 --- a/src/scalar/generic.rs +++ b/src/scalar/generic.rs @@ -1,3 +1,5 @@ +use num_traits::float::FloatCore; + #[cfg(feature = "half")] use super::scalar_f16::scalar_argminmax_f16; #[cfg(feature = "half")] @@ -9,8 +11,13 @@ pub trait ScalarArgMinMax { pub struct SCALAR; +pub struct SCALARIgnoreNaN; + // #[inline(always)] leads to poor performance on aarch64 +/// Default scalar implementation of the argminmax function. +/// This implementation returns the index of the first NaN value if any are present, +/// otherwise it returns the index of the minimum and maximum values. // #[inline(never)] pub fn scalar_argminmax(arr: &[T]) -> (usize, usize) { assert!(!arr.is_empty()); @@ -20,6 +27,40 @@ pub fn scalar_argminmax(arr: &[T]) -> (usize, usize) { // than using .iter().enumerate() (with a fold). let mut low: T = unsafe { *arr.get_unchecked(low_index) }; let mut high: T = unsafe { *arr.get_unchecked(high_index) }; + for i in 0..arr.len() { + let v: T = unsafe { *arr.get_unchecked(i) }; + if v != v { + // Because NaN != NaN - compiled identically to v.is_nan(): https://godbolt.org/z/Y6xh51ePb + // Return the index of the first NaN value + return (i, i); + } + if v < low { + low = v; + low_index = i; + } else if v > high { + high = v; + high_index = i; + } + } + (low_index, high_index) +} + +/// Scalar implementation of the argminmax function that ignores NaN values. +/// This implementation returns the index of the minimum and maximum values. +/// Note that this function only works for floating point types. +pub fn scalar_argminmax_ignore_nans(arr: &[T]) -> (usize, usize) { + assert!(!arr.is_empty()); + let mut low_index: usize = 0; + let mut high_index: usize = 0; + // It is remarkably faster to iterate over the index and use get_unchecked + // than using .iter().enumerate() (with a fold). + let start_value: T = unsafe { *arr.get_unchecked(0) }; + let mut low: T = start_value; + let mut high: T = start_value; + if start_value.is_nan() { + low = T::infinity(); + high = T::neg_infinity(); + } for i in 0..arr.len() { let v: T = unsafe { *arr.get_unchecked(i) }; if v < low { @@ -63,6 +104,18 @@ macro_rules! impl_scalar { )* }; } +macro_rules! 
impl_scalar_ignore_nans { + ($($t:ty),*) => // ty can only be float types + { + $( + impl ScalarArgMinMax<$t> for SCALARIgnoreNaN { + fn argminmax(data: &[$t]) -> (usize, usize) { + scalar_argminmax_ignore_nans(data) + } + } + )* + }; +} #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] mod scalar_x86 { @@ -115,6 +168,9 @@ mod scalar_generic { f64 ); } +impl_scalar_ignore_nans!(f32, f64); #[cfg(feature = "half")] impl_scalar!(scalar_argminmax_f16, f16); +#[cfg(feature = "half")] +impl_scalar_ignore_nans!(f16); // TODO: use correct implementation (not sure if this is correct atm) diff --git a/src/scalar/scalar_f16.rs b/src/scalar/scalar_f16.rs index 54082b2..a854506 100644 --- a/src/scalar/scalar_f16.rs +++ b/src/scalar/scalar_f16.rs @@ -8,7 +8,8 @@ fn f16_to_i16ord(x: f16) -> i16 { ((x >> 15) & 0x7FFF) ^ x } -#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +// TODO: commented this (see the TODO below) +// #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] #[cfg(feature = "half")] // #[inline(never)] pub(crate) fn scalar_argminmax_f16(arr: &[f16]) -> (usize, usize) { @@ -24,7 +25,12 @@ pub(crate) fn scalar_argminmax_f16(arr: &[f16]) -> (usize, usize) { let mut low: i16 = f16_to_i16ord(unsafe { *arr.get_unchecked(low_index) }); let mut high: i16 = f16_to_i16ord(unsafe { *arr.get_unchecked(high_index) }); for i in 0..arr.len() { - let v: i16 = f16_to_i16ord(unsafe { *arr.get_unchecked(i) }); + let v: f16 = unsafe { *arr.get_unchecked(i) }; + if v.is_nan() { + // Return the index of the first NaN value + return (i, i); + } + let v: i16 = f16_to_i16ord(v); if v < low { low = v; low_index = i; @@ -36,31 +42,37 @@ pub(crate) fn scalar_argminmax_f16(arr: &[f16]) -> (usize, usize) { (low_index, high_index) } -#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))] -#[cfg(feature = "half")] -// #[inline(never)] -pub(crate) fn scalar_argminmax_f16(arr: &[f16]) -> (usize, usize) { - // f16 is transformed to i16ord - // benchmarks show: - // 1. this is 7-10x faster than using raw f16 - // 2. this is 3x faster than transforming to f32 or f64 - assert!(!arr.is_empty()); - // This is 3% slower on x86_64, but 12% faster on aarch64. - let minmax_tuple: (usize, i16, usize, i16) = arr.iter().enumerate().fold( - (0, f16_to_i16ord(arr[0]), 0, f16_to_i16ord(arr[0])), - |(low_index, low, high_index, high), (i, item)| { - let item = f16_to_i16ord(*item); - if item < low { - (i, item, high_index, high) - } else if item > high { - (low_index, low, i, item) - } else { - (low_index, low, high_index, high) - } - }, - ); - (minmax_tuple.0, minmax_tuple.2) -} +// TODO: previously we had dedicated non x86_64 code for f16 (see below) + +// #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))] +// #[cfg(feature = "half")] +// // #[inline(never)] +// pub(crate) fn scalar_argminmax_f16(arr: &[f16]) -> (usize, usize) { +// // f16 is transformed to i16ord +// // benchmarks show: +// // 1. this is 7-10x faster than using raw f16 +// // 2. this is 3x faster than transforming to f32 or f64 +// assert!(!arr.is_empty()); +// // This is 3% slower on x86_64, but 12% faster on aarch64. 
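
The `f16_to_i16ord` helper above is what lets the f16 kernel compare half-precision values as plain signed integers. A small self-contained sketch of that sign-magnitude-to-ordinal mapping (re-implemented here with `to_bits` purely for illustration, so it is not a verbatim copy of the crate's helper):

```rust
use half::f16;

// Reinterpret the f16 bits as i16 and, for negative values (sign bit set), flip the
// lower 15 bits. Plain i16 comparisons then order the results like f16 comparisons
// (NaNs aside).
fn f16_to_i16ord_sketch(x: f16) -> i16 {
    let x = x.to_bits() as i16;
    ((x >> 15) & 0x7FFF) ^ x
}

fn main() {
    let vals = [
        f16::from_f32(-2.0),
        f16::from_f32(-0.5),
        f16::from_f32(0.0),
        f16::from_f32(1.5),
    ];
    let ords: Vec<i16> = vals.iter().map(|&v| f16_to_i16ord_sketch(v)).collect();
    // The ordinal view preserves the ordering of the original f16 values.
    assert!(ords.windows(2).all(|w| w[0] < w[1]));
}
```
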
+// let minmax_tuple: (usize, i16, usize, i16) = arr.iter().enumerate().fold( +// (0, f16_to_i16ord(arr[0]), 0, f16_to_i16ord(arr[0])), +// |(low_index, low, high_index, high), (i, item)| { +// if item.is_nan() { +// // Return the index of the first NaN value +// return (i, i); +// } +// let item = f16_to_i16ord(*item); +// if item < low { +// (i, item, high_index, high) +// } else if item > high { +// (low_index, low, i, item) +// } else { +// (low_index, low, high_index, high) +// } +// }, +// ); +// (minmax_tuple.0, minmax_tuple.2) +// } #[cfg(feature = "half")] #[cfg(test)] @@ -83,9 +95,39 @@ mod tests { for _ in 0..100 { let data: &[f16] = &get_array_f16(1025); let (argmin_index, argmax_index) = scalar_argminmax(data); - let (argmin_simd_index, argmax_simd_index) = scalar_argminmax_f16(data); - assert_eq!(argmin_index, argmin_simd_index); - assert_eq!(argmax_index, argmax_simd_index); + let (argmin_index_f16, argmax_index_f16) = scalar_argminmax_f16(data); + assert_eq!(argmin_index, argmin_index_f16); + assert_eq!(argmax_index, argmax_index_f16); + } + } + + #[test] + fn test_generic_and_specific_impl_return_nans() { + let arr_len: usize = 1025; + + // first, middle, last element + let nan_pos: [usize; 3] = [0, arr_len / 2, arr_len - 1]; + for pos in nan_pos.iter() { + let mut data: Vec = get_array_f16(arr_len); + data[*pos] = f16::NAN; + let (argmin_index, argmax_index) = scalar_argminmax(&data); + let (argmin_index_f16, argmax_index_f16) = scalar_argminmax_f16(&data); + assert_eq!(argmin_index, argmin_index_f16); + assert_eq!(argmax_index, argmax_index_f16); + assert_eq!(argmin_index, *pos); + assert_eq!(argmax_index, *pos); + } + + // All elements are NaN + let mut data: Vec = get_array_f16(arr_len); + for i in 0..arr_len { + data[i] = f16::NAN; } + let (argmin_index, argmax_index) = scalar_argminmax(&data); + let (argmin_index_f16, argmax_index_f16) = scalar_argminmax_f16(&data); + assert_eq!(argmin_index, argmin_index_f16); + assert_eq!(argmax_index, argmax_index_f16); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); } } diff --git a/src/simd/config.rs b/src/simd/config.rs index 4c0ec06..76bef38 100644 --- a/src/simd/config.rs +++ b/src/simd/config.rs @@ -1,6 +1,9 @@ // https://github.com/rust-lang/portable-simd/blob/master/beginners-guide.md#target-features +/// SIMD instruction set trait - used to store the register size and get the lane size +/// for a given datatype pub trait SIMDInstructionSet { + /// The size of the register in bits const REGISTER_SIZE: usize; // Set the const lanesize for each datatype @@ -14,34 +17,85 @@ pub trait SIMDInstructionSet { } } -// ----------------------------- x86_64 / x86 ----------------------------- +// ----------------------------------- x86_64 / x86 ------------------------------------ +/// SSE instruction set - this will be implemented for all: +/// - ints (see, the simd_i*.rs files) +/// - uints (see, the simd_u*.rs files) +/// - floats: returning NaNs (see, the simd_f*_return_nan.rs files) pub struct SSE; +/// SSE instruction set - this will be implemented for all: +/// - floats: ignoring NaNs (see, the `simd_f*_ignore_nan.rs` files) +pub struct SSEIgnoreNaN; impl SIMDInstructionSet for SSE { + /// SSE register size is 128 bits + /// https://en.wikipedia.org/wiki/Streaming_SIMD_Extensions#Registers const REGISTER_SIZE: usize = 128; } -pub struct AVX2; // for f32 and f64 AVX is enough +/// AVX2 instruction set - this will be implemented for all: +/// - ints (see, the simd_i*.rs files) +/// - uints (see, the simd_u*.rs 
files) +/// - floats: returning NaNs (see, the simd_f*_return_nan.rs files) +pub struct AVX2; + +/// AVX(2) instruction set - this will be implemented for all: +/// - floats: ignoring NaNs (see, the `simd_f*_ignore_nan.rs` files) +/// +/// Important remark: AVX is enough for f32 and f64! +/// -> for f16 we need AVX2 - but this is currently not yet implemented (TODO) +/// +/// Note: this struct does not implement the `SIMDInstructionSet` trait +pub struct AVX2IgnoreNaN; impl SIMDInstructionSet for AVX2 { + /// AVX(2) register size is 256 bits + /// AVX: https://en.wikipedia.org/wiki/Advanced_Vector_Extensions#Advanced_Vector_Extensions + /// AVX2: https://en.wikipedia.org/wiki/Advanced_Vector_Extensions#AVX2 const REGISTER_SIZE: usize = 256; } +/// AVX512 instruction set - this will be implemented for all: +/// - ints (see, the simd_i*.rs files) +/// - uints (see, the simd_u*.rs files) +/// - floats: returning NaNs (see, the simd_f*_return_nan.rs files) pub struct AVX512; +/// AVX512 instruction set - this will be implemented for all: +/// - floats: ignoring NaNs (see, the `simd_f*_ignore_nan.rs` files) +/// +/// Note: this struct does not implement the `SIMDInstructionSet` trait +pub struct AVX512IgnoreNaN; + impl SIMDInstructionSet for AVX512 { + /// AVX512 register size is 512 bits + /// https://en.wikipedia.org/wiki/Advanced_Vector_Extensions#AVX-512 const REGISTER_SIZE: usize = 512; } -// ----------------------------- aarch64 / arm ----------------------------- +// ----------------------------------- aarch64 / arm ----------------------------------- +/// NEON instruction set - this will be implemented for all: +/// - ints (see, the simd_i*.rs files) +/// - uints (see, the simd_u*.rs files) +/// - floats: returning NaNs (see, the simd_f*_return_nan.rs files) pub struct NEON; +/// NEON instruction set - this will be implemented for all: +/// - floats: ignoring NaNs (see, the `simd_f*_ignore_nan.rs` files) +/// +/// Note: this struct does not implement the `SIMDInstructionSet` trait +pub struct NEONIgnoreNaN; + impl SIMDInstructionSet for NEON { + /// NEON register size is 128 bits + /// https://en.wikipedia.org/wiki/ARM_architecture#Advanced_SIMD_(Neon) const REGISTER_SIZE: usize = 128; } +// --------------------------------------- Tests --------------------------------------- + #[cfg(test)] mod tests { use super::*; diff --git a/src/simd/generic.rs b/src/simd/generic.rs index c5d708e..fda06b7 100644 --- a/src/simd/generic.rs +++ b/src/simd/generic.rs @@ -1,41 +1,57 @@ +use num_traits::float::FloatCore; use num_traits::AsPrimitive; +use super::config::SIMDInstructionSet; use super::task::*; -use crate::scalar::{ScalarArgMinMax, SCALAR}; - -// TODO: other potential generic SIMDIndexDtype: Copy -#[allow(clippy::missing_safety_doc)] // TODO: add safety docs? -pub trait SIMD< +use crate::scalar::{SCALARIgnoreNaN, ScalarArgMinMax, SCALAR}; + +// ---------------------------------- SIMD operations ---------------------------------- + +/// Core SIMD operations +/// These operations are used by the SIMD algorithm and have to be implemented for each +/// data type - SIMD instruction set combination. +/// The operations are implemented in the `simd_*.rs` files. 
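
As a quick illustration of the register sizes documented in `config.rs` above: the lane count for a data type is simply the register width divided by the element width in bits. The `lanes` helper below is hypothetical, for this sketch only:

```rust
// Hypothetical helper (not part of the crate): lanes = register bits / element bits.
fn lanes<T>(register_size_bits: usize) -> usize {
    register_size_bits / (std::mem::size_of::<T>() * 8)
}

fn main() {
    assert_eq!(lanes::<f32>(128), 4); // SSE / NEON
    assert_eq!(lanes::<i16>(128), 8); // SSE / NEON
    assert_eq!(lanes::<f32>(256), 8); // AVX2
    assert_eq!(lanes::<f64>(256), 4); // AVX2
    assert_eq!(lanes::<u8>(512), 64); // AVX512
}
```
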
+/// +/// Note that for floating point dataypes two implementations are required: +/// - one for the ignore NaN case (uses a floating point SIMDVecDtype) +/// (see the `simd_f*_ignore_nan.rs` files) +/// - one for the return NaN case (uses an integer SIMDVecDtype - as we use the +/// ord_transform to view the floating point data as ordinal integer data). +/// (see the `simd_f*_return_nan.rs` files) +pub trait SIMDOps +where ScalarDType: Copy + PartialOrd + AsPrimitive, SIMDVecDtype: Copy, SIMDMaskDtype: Copy, - const LANE_SIZE: usize, -> { + /// Integers > this value **cannot** be accurately represented in SIMDVecDtype + const MAX_INDEX: usize; + /// Initial index value for the SIMD vector const INITIAL_INDEX: SIMDVecDtype; - const MAX_INDEX: usize; // Integers > this value **cannot** be accurately represented in SIMDVecDtype - - #[inline(always)] - fn _find_largest_lower_multiple_of_lane_size(n: usize) -> usize { - n - n % LANE_SIZE - } - - // ------------------------------------ SIMD HELPERS -------------------------------------- + /// Increment value for the SIMD vector + const INDEX_INCREMENT: SIMDVecDtype; + /// Convert a SIMD register to array unsafe fn _reg_to_arr(reg: SIMDVecDtype) -> [ScalarDType; LANE_SIZE]; + /// Load a SIMD register from memory unsafe fn _mm_loadu(data: *const ScalarDType) -> SIMDVecDtype; - unsafe fn _mm_set1(a: usize) -> SIMDVecDtype; - + /// Add two SIMD registers unsafe fn _mm_add(a: SIMDVecDtype, b: SIMDVecDtype) -> SIMDVecDtype; + /// Compare two SIMD registers for greater-than (gt): a > b + /// Returns a SIMD mask unsafe fn _mm_cmpgt(a: SIMDVecDtype, b: SIMDVecDtype) -> SIMDMaskDtype; + /// Compare two SIMD registers for less-than (lt): a < b unsafe fn _mm_cmplt(a: SIMDVecDtype, b: SIMDVecDtype) -> SIMDMaskDtype; + /// Blend two SIMD registers using a SIMD mask (selects elements from a or b) unsafe fn _mm_blendv(a: SIMDVecDtype, b: SIMDVecDtype, mask: SIMDMaskDtype) -> SIMDVecDtype; + /// Horizontal min: get the minimum value from the value SIMD register and its + /// corresponding index from the index SIMD register #[inline(always)] unsafe fn _horiz_min(index: SIMDVecDtype, value: SIMDVecDtype) -> (usize, ScalarDType) { // This becomes the bottleneck when using 8-bit data types, as for every 2**7 @@ -52,6 +68,8 @@ pub trait SIMD< (min_index.as_(), min_value) } + /// Horizontal max: get the maximum value from the value SIMD register and its + /// corresponding index from the index SIMD register #[inline(always)] unsafe fn _horiz_max(index: SIMDVecDtype, value: SIMDVecDtype) -> (usize, ScalarDType) { // This becomes the bottleneck when using 8-bit data types, as for every 2**7 @@ -68,44 +86,115 @@ pub trait SIMD< (max_index.as_(), max_value) } + /// Get the largest multiple of LANE_SIZE that is <= MAX_INDEX #[inline(always)] - unsafe fn _mm_prefetch(data: *const ScalarDType) { - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - { - #[cfg(target_arch = "x86")] - use std::arch::x86::_mm_prefetch; - #[cfg(target_arch = "x86_64")] - use std::arch::x86_64::_mm_prefetch; - - _mm_prefetch(data as *const i8, 0); // 0=NTA - } - #[cfg(target_arch = "aarch64")] - { - use std::arch::aarch64::_prefetch; - - _prefetch(data as *const i8, 0, 0); // 0=READ, 0=NTA - } + fn _get_overflow_lane_size_limit() -> usize { + Self::MAX_INDEX - Self::MAX_INDEX % LANE_SIZE } +} - // ------------------------------------ ARGMINMAX -------------------------------------- +// ---------------------------------- SIMD algorithm ----------------------------------- + +// 
--------------- Default + +/// The default SIMDCore trait (for all data types) +/// +/// This trait is auto-implemented below for: +/// - ints +/// - uints +/// - floats: returning NaN +/// => this corresponds to structs that implement SIMDInstructionSet (see `config.rs`) +/// (thus also for example `SSE` for float returning NaN) +pub trait SIMDCore: + SIMDOps +where + ScalarDType: Copy + PartialOrd + AsPrimitive, + SIMDVecDtype: Copy, + SIMDMaskDtype: Copy, +{ + /// Core argminmax algorithm - returns (argmin, min, argmax, max) + /// + /// This method asserts: + /// - the array length is a multiple of LANE_SIZE + /// This method assumes: + /// - the array length is <= MAX_INDEX + /// + /// Note that this method is not overflow safe, as it assumes that the array length + /// is <= MAX_INDEX. The `_overflow_safe_core_argminmax method` is overflow safe. + /// + /// Note that this method is leveraged by the return NaN implementation (as the + /// float values - including NaNs - are mapped to ordinal integers). + #[inline(always)] + unsafe fn _core_argminmax(arr: &[ScalarDType]) -> (usize, ScalarDType, usize, ScalarDType) { + assert_eq!(arr.len() % LANE_SIZE, 0); + // Efficient calculation of argmin and argmax together + let mut new_index = Self::INITIAL_INDEX; + let mut index_low = Self::INITIAL_INDEX; + let mut index_high = Self::INITIAL_INDEX; - unsafe fn argminmax(data: &[ScalarDType]) -> (usize, usize); + let mut arr_ptr = arr.as_ptr(); // Array pointer we will increment in the loop + let mut values_low = Self::_mm_loadu(arr_ptr); + let mut values_high = Self::_mm_loadu(arr_ptr); - #[inline(always)] - unsafe fn _argminmax(data: &[ScalarDType]) -> (usize, usize) - where - SCALAR: ScalarArgMinMax, - { - argminmax_generic(data, LANE_SIZE, Self::_overflow_safe_core_argminmax) + // This is (40%-5%) slower than the loop below (depending on the data type) + // arr.chunks_exact(LANE_SIZE) + // .into_iter() + // .skip(1) + // .for_each(|step| { + // new_index = Self::_mm_add(new_index, increment); + + // let new_values = Self::_mm_loadu(step.as_ptr()); + + // let lt_mask = Self::_mm_cmplt(new_values, values_low); + // let gt_mask = Self::_mm_cmpgt(new_values, values_high); + + // index_low = Self::_mm_blendv(index_low, new_index, lt_mask); + // index_high = Self::_mm_blendv(index_high, new_index, gt_mask); + + // values_low = Self::_mm_blendv(values_low, new_values, lt_mask); + // values_high = Self::_mm_blendv(values_high, new_values, gt_mask); + // }); + + for _ in 0..arr.len() / LANE_SIZE - 1 { + // Increment the index + new_index = Self::_mm_add(new_index, Self::INDEX_INCREMENT); + // Load the next chunk of data + arr_ptr = arr_ptr.add(LANE_SIZE); + let new_values = Self::_mm_loadu(arr_ptr); + + // Update the lowest values and index + let mask_low = Self::_mm_cmplt(new_values, values_low); + values_low = Self::_mm_blendv(values_low, new_values, mask_low); + index_low = Self::_mm_blendv(index_low, new_index, mask_low); + + // Update the highest values and index + let mask_high = Self::_mm_cmpgt(new_values, values_high); + values_high = Self::_mm_blendv(values_high, new_values, mask_high); + index_high = Self::_mm_blendv(index_high, new_index, mask_high); + } + + // Get the min/max index and corresponding value from the SIMD vectors and return + let (min_index, min_value) = Self::_horiz_min(index_low, values_low); + let (max_index, max_value) = Self::_horiz_max(index_high, values_high); + (min_index, min_value, max_index, max_value) } + /// Overflow-safe core argminmax algorithm - returns 
(argmin, min, argmax, max) + /// + /// This method asserts: + /// - the array is not empty + /// - the array length is a multiple of LANE_SIZE + /// + /// Note that this method checks for nans by comparing v != v (is true for nans) + /// -> returns once `_core_argminmax` returns a NaN value #[inline(always)] unsafe fn _overflow_safe_core_argminmax( arr: &[ScalarDType], ) -> (usize, ScalarDType, usize, ScalarDType) { assert!(!arr.is_empty()); + assert_eq!(arr.len() % LANE_SIZE, 0); // 0. Get the max value of the data type - which needs to be divided by LANE_SIZE - let dtype_max = Self::_find_largest_lower_multiple_of_lane_size(Self::MAX_INDEX); + let dtype_max = Self::_get_overflow_lane_size_limit(); // 1. Determine the number of loops needed // let n_loops = (arr.len() + dtype_max - 1) / dtype_max; // ceil division @@ -119,14 +208,17 @@ pub trait SIMD< let mut start: usize = 0; // 2.0 Perform the full loops for _ in 0..n_loops { - // Self::_mm_prefetch(arr.as_ptr().add(start)); + if min_value != min_value || max_value != max_value { + // If min_value or max_value is NaN, we can return immediately + return (min_index, min_value, max_index, max_value); + } let (min_index_, min_value_, max_index_, max_value_) = Self::_core_argminmax(&arr[start..start + dtype_max]); - if min_value_ < min_value { + if min_value_ < min_value || min_value_ != min_value_ { min_index = start + min_index_; min_value = min_value_; } - if max_value_ > max_value { + if max_value_ > max_value || max_value_ != max_value_ { max_index = start + max_index_; max_value = max_value_; } @@ -134,14 +226,17 @@ pub trait SIMD< } // 2.1 Handle the remainder if start < arr.len() { - // Self::_mm_prefetch(arr.as_ptr().add(start)); + if min_value != min_value || max_value != max_value { + // If min_value or max_value is NaN, we can return immediately + return (min_index, min_value, max_index, max_value); + } let (min_index_, min_value_, max_index_, max_value_) = Self::_core_argminmax(&arr[start..]); - if min_value_ < min_value { + if min_value_ < min_value || min_value_ != min_value_ { min_index = start + min_index_; min_value = min_value_; } - if max_value_ > max_value { + if max_value_ > max_value || max_value_ != max_value_ { max_index = start + max_index_; max_value = max_value_; } @@ -150,85 +245,270 @@ pub trait SIMD< // 3. Return the min/max index and corresponding value (min_index, min_value, max_index, max_value) } +} - // TODO: can be cleaner (perhaps?) 
- #[inline(always)] - unsafe fn _get_min_max_index_value( - index_low: SIMDVecDtype, - values_low: SIMDVecDtype, - index_high: SIMDVecDtype, - values_high: SIMDVecDtype, - ) -> (usize, ScalarDType, usize, ScalarDType) { - let (min_index, min_value) = Self::_horiz_min(index_low, values_low); - let (max_index, max_value) = Self::_horiz_max(index_high, values_high); - (min_index, min_value, max_index, max_value) - } +// Implement SIMDCore where SIMDOps is implemented (for the SIMDIstructionSet structs) +impl + SIMDCore for T +where + ScalarDType: Copy + PartialOrd + AsPrimitive, + SIMDVecDtype: Copy, + SIMDMaskDtype: Copy, + T: SIMDOps + SIMDInstructionSet, +{ + // Use the implementation +} + +// --------------- Float Ignore NaNs +/// SIMD operations for setting a SIMD vector to a scalar value (only required for floats) +pub trait SIMDSetOps +where + ScalarDType: FloatCore, +{ + /// Set a SIMD vector to a scalar value (each lane is set to the scalar value) + unsafe fn _mm_set1(a: ScalarDType) -> SIMDVecDtype; +} + +/// SIMDCore trait that ignore NaNs (for float types) +/// +/// This trait is auto-implemented below for: +/// - floats: ignoring NaN +/// => this corresponds to the IgnoreNan structs (see `config.rs`) +/// (for example `SSEIgnoreNaN`) +pub trait SIMDCoreIgnoreNaN: + SIMDOps + SIMDSetOps +where + ScalarDType: FloatCore + AsPrimitive, + SIMDVecDtype: Copy, + SIMDMaskDtype: Copy, +{ + /// Core argminmax algorithm - returns (argmin, min, argmax, max) + /// + /// This method asserts: + /// - the array length is a multiple of LANE_SIZE + /// This method assumes: + /// - the array length is <= MAX_INDEX + /// + /// Note that this method is not overflow safe, as it assumes that the array length + /// is <= MAX_INDEX. The `_overflow_safe_core_argminmax method` is overflow safe. 
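+    ///
+    /// For intuition only (illustrative, not part of the implementation): since every
+    /// comparison against NaN evaluates to false, a NaN lane never replaces the running
+    /// min/max accumulators, which are seeded with +/- infinity below - this is what
+    /// makes the NaN-ignoring behaviour fall out of plain lt/gt comparisons:
+    ///
+    /// ```ignore
+    /// assert!(!(f32::NAN < f32::INFINITY));     // NaN never beats the +inf seed
+    /// assert!(!(f32::NAN > f32::NEG_INFINITY)); // NaN never beats the -inf seed
+    /// ```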
#[inline(always)] unsafe fn _core_argminmax(arr: &[ScalarDType]) -> (usize, ScalarDType, usize, ScalarDType) { assert_eq!(arr.len() % LANE_SIZE, 0); // Efficient calculation of argmin and argmax together let mut new_index = Self::INITIAL_INDEX; - let mut index_low = Self::INITIAL_INDEX; - let mut index_high = Self::INITIAL_INDEX; - - let increment = Self::_mm_set1(LANE_SIZE); let mut arr_ptr = arr.as_ptr(); // Array pointer we will increment in the loop - let mut values_low = Self::_mm_loadu(arr_ptr); - let mut values_high = Self::_mm_loadu(arr_ptr); - - // This is (40%-5%) slower than the loop below (depending on the data type) - // arr.chunks_exact(LANE_SIZE) - // .into_iter() - // .skip(1) - // .for_each(|step| { - // new_index = Self::_mm_add(new_index, increment); - - // let new_values = Self::_mm_loadu(step.as_ptr()); - - // let lt_mask = Self::_mm_cmplt(new_values, values_low); - // let gt_mask = Self::_mm_cmpgt(new_values, values_high); - - // index_low = Self::_mm_blendv(index_low, new_index, lt_mask); - // index_high = Self::_mm_blendv(index_high, new_index, gt_mask); - - // values_low = Self::_mm_blendv(values_low, new_values, lt_mask); - // values_high = Self::_mm_blendv(values_high, new_values, gt_mask); - // }); + let new_values = Self::_mm_loadu(arr_ptr); + + // Update the lowest values and index + let mask_low = Self::_mm_cmplt(new_values, Self::_mm_set1(ScalarDType::infinity())); + let mut values_low = Self::_mm_blendv( + Self::_mm_set1(ScalarDType::infinity()), + new_values, + mask_low, + ); + let mut index_low = + Self::_mm_blendv(Self::_mm_set1(ScalarDType::zero()), new_index, mask_low); + + // Update the highest values and index + let mask_high = Self::_mm_cmpgt(new_values, Self::_mm_set1(ScalarDType::neg_infinity())); + let mut values_high = Self::_mm_blendv( + Self::_mm_set1(ScalarDType::neg_infinity()), + new_values, + mask_high, + ); + let mut index_high = + Self::_mm_blendv(Self::_mm_set1(ScalarDType::zero()), new_index, mask_high); for _ in 0..arr.len() / LANE_SIZE - 1 { // Increment the index - new_index = Self::_mm_add(new_index, increment); + new_index = Self::_mm_add(new_index, Self::INDEX_INCREMENT); // Load the next chunk of data arr_ptr = arr_ptr.add(LANE_SIZE); - // Self::_mm_prefetch(arr_ptr); // Hint to the CPU to prefetch the next chunk of data let new_values = Self::_mm_loadu(arr_ptr); // Update the lowest values and index - let mask = Self::_mm_cmplt(new_values, values_low); - values_low = Self::_mm_blendv(values_low, new_values, mask); - index_low = Self::_mm_blendv(index_low, new_index, mask); + let mask_low = Self::_mm_cmplt(new_values, values_low); + values_low = Self::_mm_blendv(values_low, new_values, mask_low); + index_low = Self::_mm_blendv(index_low, new_index, mask_low); // Update the highest values and index - let mask = Self::_mm_cmpgt(new_values, values_high); - values_high = Self::_mm_blendv(values_high, new_values, mask); - index_high = Self::_mm_blendv(index_high, new_index, mask); + let mask_high = Self::_mm_cmpgt(new_values, values_high); + values_high = Self::_mm_blendv(values_high, new_values, mask_high); + index_high = Self::_mm_blendv(index_high, new_index, mask_high); + } + + // Get the min/max index and corresponding value from the SIMD vectors and return + let (min_index, min_value) = Self::_horiz_min(index_low, values_low); + let (max_index, max_value) = Self::_horiz_max(index_high, values_high); + (min_index, min_value, max_index, max_value) + } + + /// Overflow-safe core argminmax algorithm - returns (argmin, min, argmax, 
max) + /// + /// This method asserts: + /// - the array is not empty + /// - the array length is a multiple of LANE_SIZE + /// + /// Note that this method ignores nans by assuring that no NaN values are inserted + /// in the initial min / max SIMD vectors. Since comparing a value to NaN always + /// returns false, the NaN values will never be selected as the min / max values. + #[inline(always)] + unsafe fn _overflow_safe_core_argminmax( + arr: &[ScalarDType], + ) -> (usize, ScalarDType, usize, ScalarDType) { + assert!(!arr.is_empty()); + // 0. Get the max value of the data type - which needs to be divided by LANE_SIZE + let dtype_max = Self::_get_overflow_lane_size_limit(); - // 25 is a non-scientific number, but seems to work overall - // => TODO: probably this should be in function of the data type - // Self::_mm_prefetch(arr_ptr.add(LANE_SIZE * 25)); // Hint to the CPU to prefetch upcoming data + // 1. Determine the number of loops needed + // let n_loops = (arr.len() + dtype_max - 1) / dtype_max; // ceil division + let n_loops = arr.len() / dtype_max; // floor division + + // 2. Perform overflow-safe _core_argminmax + let mut min_index: usize = 0; + let mut min_value: ScalarDType = ScalarDType::infinity(); + let mut max_index: usize = 0; + let mut max_value: ScalarDType = ScalarDType::neg_infinity(); + let mut start: usize = 0; + // 2.0 Perform the full loops + for _ in 0..n_loops { + let (min_index_, min_value_, max_index_, max_value_) = + Self::_core_argminmax(&arr[start..start + dtype_max]); + if min_value_ < min_value { + min_index = start + min_index_; + min_value = min_value_; + } + if max_value_ > max_value { + max_index = start + max_index_; + max_value = max_value_; + } + start += dtype_max; + } + // 2.1 Handle the remainder + if start < arr.len() { + let (min_index_, min_value_, max_index_, max_value_) = + Self::_core_argminmax(&arr[start..]); + if min_value_ < min_value { + min_index = start + min_index_; + min_value = min_value_; + } + if max_value_ > max_value { + max_index = start + max_index_; + max_value = max_value_; + } } - Self::_get_min_max_index_value(index_low, values_low, index_high, values_high) + // 3. Return the min/max index and corresponding value + (min_index, min_value, max_index, max_value) } } -#[cfg(any(target_arch = "arm", target_arch = "aarch64"))] -macro_rules! unimplement_simd { - ($scalar_type:ty, $reg:ty, $simd_type:ident) => { - impl SIMD<$scalar_type, $reg, $reg, 0> for $simd_type { +// Implement SIMDCoreIgnoreNaNs where SIMDOps + SIMDSetOps is implemented for floats +impl + SIMDCoreIgnoreNaN for T +where + ScalarDType: FloatCore + AsPrimitive, + SIMDVecDtype: Copy, + SIMDMaskDtype: Copy, + T: SIMDOps + + SIMDSetOps, +{ + // Use the implementation +} + +// -------------------------------- ArgMinMax SIMD TRAIT ------------------------------- + +// --------------- Default + +/// Trait for SIMD argminmax operations +/// +/// This trait its `argminmax` method should be implemented for all structs that +/// implement `SIMDOps` for the same generics. +/// This trait is implemented for: +/// - ints (see, the simd_i*.rs files) +/// - uints (see, the simd_u*.rs files) +/// - floats: returning NaNs (see, the simd_f*_return_nan.rs files) +#[allow(clippy::missing_safety_doc)] // TODO: add safety docs? 
+pub trait SIMDArgMinMax: + SIMDCore +where + ScalarDType: Copy + PartialOrd + AsPrimitive, + SIMDVecDtype: Copy, + SIMDMaskDtype: Copy, +{ + /// Returns the index of the minimum and maximum value in the array + unsafe fn argminmax(data: &[ScalarDType]) -> (usize, usize); + + // Is necessary to have a separate function for this so we can call it in the + // argminmax function when we add the target feature to the function. + #[inline(always)] + unsafe fn _argminmax(data: &[ScalarDType]) -> (usize, usize) + where + SCALAR: ScalarArgMinMax, + { + argminmax_generic( + data, + LANE_SIZE, + Self::_overflow_safe_core_argminmax, + false, + SCALAR::argminmax, + ) + } +} + +// --------------- Float Return NaN + +// This is the same code as the default trait - thus we can just use the default trait. + +// --------------- Float Ignore NaN + +/// Trait for SIMD argminmax operations that ignore NaNs +/// +/// This trait its `argminmax` method should be implemented for all structs that +/// implement `SIMDOps` and `SIMDSetOps` for the same generics. +/// This trait is implemented for: +/// - floats: ignoring NaNs (see, the simd_f*_ignore_nan.rs files) +#[allow(clippy::missing_safety_doc)] // TODO: add safety docs? +pub trait SIMDArgMinMaxIgnoreNaN: + SIMDCoreIgnoreNaN +where + ScalarDType: FloatCore + AsPrimitive, + SIMDVecDtype: Copy, + SIMDMaskDtype: Copy, +{ + /// Returns the index of the minimum and maximum value in the array + unsafe fn argminmax(data: &[ScalarDType]) -> (usize, usize); + + // Is necessary to have a separate function for this so we can call it in the + // argminmax function when we add the target feature to the function. + #[inline(always)] + unsafe fn _argminmax(data: &[ScalarDType]) -> (usize, usize) + where + SCALARIgnoreNaN: ScalarArgMinMax, + { + argminmax_generic( + data, + LANE_SIZE, + Self::_overflow_safe_core_argminmax, + true, + SCALARIgnoreNaN::argminmax, + ) + } +} + +// --------------------------------- Unimplement Macros -------------------------------- + +// TODO: temporarily removed the target_arch specification bc we currently do not +// ArgMinMaxIgnoreNan for f16 ignore nan + +// #[cfg(any(target_arch = "arm", target_arch = "aarch64"))] +macro_rules! unimpl_SIMDOps { + ($scalar_type:ty, $reg:ty, $simd_instructionset:ident) => { + impl SIMDOps<$scalar_type, $reg, $reg, 0> for $simd_instructionset { const INITIAL_INDEX: $reg = 0; + const INDEX_INCREMENT: $reg = 0; const MAX_INDEX: usize = 0; unsafe fn _reg_to_arr(_reg: $reg) -> [$scalar_type; 0] { @@ -239,10 +519,6 @@ macro_rules! unimplement_simd { unimplemented!() } - unsafe fn _mm_set1(_a: usize) -> $reg { - unimplemented!() - } - unsafe fn _mm_add(_a: $reg, _b: $reg) -> $reg { unimplemented!() } @@ -258,12 +534,41 @@ macro_rules! unimplement_simd { unsafe fn _mm_blendv(_a: $reg, _b: $reg, _mask: $reg) -> $reg { unimplemented!() } + } + }; +} +#[cfg(any(target_arch = "arm", target_arch = "aarch64"))] +macro_rules! unimpl_SIMDArgMinMax { + ($scalar_type:ty, $reg:ty, $simd_instructionset:ident) => { + impl SIMDArgMinMax<$scalar_type, $reg, $reg, 0> for $simd_instructionset { unsafe fn argminmax(_data: &[$scalar_type]) -> (usize, usize) { unimplemented!() } } }; } + +// #[cfg(any(target_arch = "arm", target_arch = "aarch64"))] +macro_rules! 
unimpl_SIMDArgMinMaxIgnoreNaN { + ($scalar_type:ty, $reg:ty, $simd_instructionset:ident) => { + impl SIMDSetOps<$scalar_type, $reg> for $simd_instructionset { + unsafe fn _mm_set1(_a: $scalar_type) -> $reg { + unimplemented!() + } + } + impl SIMDArgMinMaxIgnoreNaN<$scalar_type, $reg, $reg, 0> for $simd_instructionset { + unsafe fn argminmax(_data: &[$scalar_type]) -> (usize, usize) { + unimplemented!() + } + } + }; +} + +// TODO: temporarily removed the target_arch until we implement f16_ignore_nans #[cfg(any(target_arch = "arm", target_arch = "aarch64"))] -pub(crate) use unimplement_simd; // Now classic paths Just Workâ„¢ +pub(crate) use unimpl_SIMDArgMinMax; // Now classic paths Just Workâ„¢ + // #[cfg(any(target_arch = "arm", target_arch = "aarch64"))] +pub(crate) use unimpl_SIMDArgMinMaxIgnoreNaN; // Now classic paths Just Workâ„¢ + // #[cfg(any(target_arch = "arm", target_arch = "aarch64"))] +pub(crate) use unimpl_SIMDOps; // Now classic paths Just Workâ„¢ diff --git a/src/simd/mod.rs b/src/simd/mod.rs index 3d73650..2699259 100644 --- a/src/simd/mod.rs +++ b/src/simd/mod.rs @@ -6,9 +6,12 @@ pub use config::*; mod generic; pub use generic::*; // FLOAT -mod simd_f16; -mod simd_f32; -mod simd_f64; +mod simd_f16_ignore_nan; // TODO: not supported yet +mod simd_f16_return_nan; +mod simd_f32_ignore_nan; +mod simd_f32_return_nan; +mod simd_f64_ignore_nan; +mod simd_f64_return_nan; // SIGNED INT mod simd_i16; mod simd_i32; diff --git a/src/simd/simd_f16_ignore_nan.rs b/src/simd/simd_f16_ignore_nan.rs new file mode 100644 index 0000000..6580df2 --- /dev/null +++ b/src/simd/simd_f16_ignore_nan.rs @@ -0,0 +1,60 @@ +/// Currently not supported. Should give this some more thought. +/// + +// #[cfg(feature = "half")] +// use super::config::SIMDInstructionSet; +#[cfg(feature = "half")] +use super::generic::{unimpl_SIMDArgMinMaxIgnoreNaN, unimpl_SIMDOps}; +#[cfg(feature = "half")] +use super::generic::{SIMDArgMinMaxIgnoreNaN, SIMDOps, SIMDSetOps}; + +#[cfg(feature = "half")] +use half::f16; + +// ------------------------------------------ AVX2 ------------------------------------------ + +#[cfg(feature = "half")] +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +mod avx2_ignore_nan { + use super::super::config::AVX2IgnoreNaN; + use super::*; + + unimpl_SIMDOps!(f16, usize, AVX2IgnoreNaN); + unimpl_SIMDArgMinMaxIgnoreNaN!(f16, usize, AVX2IgnoreNaN); +} + +// ----------------------------------------- SSE ----------------------------------------- + +#[cfg(feature = "half")] +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +mod sse_ignore_nan { + use super::super::config::SSEIgnoreNaN; + use super::*; + + unimpl_SIMDOps!(f16, usize, SSEIgnoreNaN); + unimpl_SIMDArgMinMaxIgnoreNaN!(f16, usize, SSEIgnoreNaN); +} + +// --------------------------------------- AVX512 ---------------------------------------- + +#[cfg(feature = "half")] +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +mod avx512_ignore_nan { + use super::super::config::AVX512IgnoreNaN; + use super::*; + + unimpl_SIMDOps!(f16, usize, AVX512IgnoreNaN); + unimpl_SIMDArgMinMaxIgnoreNaN!(f16, usize, AVX512IgnoreNaN); +} + +// ---------------------------------------- NEON ----------------------------------------- + +#[cfg(feature = "half")] +#[cfg(any(target_arch = "arm", target_arch = "aarch64"))] +mod neon_ignore_nan { + use super::super::config::NEONIgnoreNaN; + use super::*; + + unimpl_SIMDOps!(f16, usize, NEONIgnoreNaN); + unimpl_SIMDArgMinMaxIgnoreNaN!(f16, usize, NEONIgnoreNaN); +} diff --git a/src/simd/simd_f16.rs 
b/src/simd/simd_f16_return_nan.rs
similarity index 51%
rename from src/simd/simd_f16.rs
rename to src/simd/simd_f16_return_nan.rs
index 2bec89e..a27320e 100644
--- a/src/simd/simd_f16.rs
+++ b/src/simd/simd_f16_return_nan.rs
@@ -1,7 +1,37 @@
+/// Implementation of the argminmax operations for f16 where NaN values take precedence.
+/// This implementation returns the index of the first* NaN value if any are present,
+/// otherwise it returns the index of the minimum and maximum values.
+///
+/// To serve this functionality we transform the f16 values to ordinal i16 values:
+/// ord_i16 = ((v >> 15) & 0x7FFF) ^ v
+///
+/// This transformation is a bijection, i.e. it is reversible:
+/// v = ((ord_i16 >> 15) & 0x7FFF) ^ ord_i16
+///
+/// Through this transformation we can perform the argminmax operations on the ordinal
+/// integer values and then transform the result back to the original f16 values.
+/// This transformation is necessary because comparisons with NaN values always return
+/// false. So unless we implement gt as !(a <= b) and lt as !(a >= b), the argminmax
+/// operations will never propagate NaN values into the accumulating SIMD registers.
+/// And as le and ge are significantly more expensive than lt and gt, we use this
+/// efficient bitwise transformation instead.
+///
+/// Note that most x86 CPUs do not support f16 instructions, which makes this
+/// implementation orders of magnitude (up to 300x) faster than a vanilla scalar
+/// implementation.
+///
+///
+/// ---
+///
+/// *Note: the first NaN value is only returned iff all NaN values have the same bit
+/// representation. When NaN values have different bit representations, the index of
+/// the highest / lowest ord_i16 is returned for the
+/// SIMDOps::_get_overflow_lane_size_limit() chunk of the data - which is not
+/// necessarily the index of the first NaN value.
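+///
+/// As an illustrative, non-normative sketch, the scalar form of this ordinal
+/// transformation and its inverse (the transformation is its own inverse) could look
+/// as follows - helper names are for illustration only:
+///
+/// ```ignore
+/// fn f16_to_i16ord(v: f16) -> i16 {
+///     let v = v.to_bits() as i16;
+///     ((v >> 15) & 0x7FFF) ^ v
+/// }
+///
+/// fn i16ord_to_f16(ord: i16) -> f16 {
+///     f16::from_bits((((ord >> 15) & 0x7FFF) ^ ord) as u16)
+/// }
+/// ```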
+/// + #[cfg(feature = "half")] use super::config::SIMDInstructionSet; #[cfg(feature = "half")] -use super::generic::SIMD; +use super::generic::{SIMDArgMinMax, SIMDOps}; #[cfg(feature = "half")] #[cfg(target_arch = "aarch64")] @@ -20,16 +50,20 @@ use std::arch::x86_64::*; use half::f16; #[cfg(feature = "half")] -const XOR_VALUE: i16 = 0x7FFF; +const BIT_SHIFT: i32 = 15; +#[cfg(feature = "half")] +const MASK_VALUE: i16 = 0x7FFF; // i16::MAX - masks everything but the sign bit #[cfg(feature = "half")] #[inline(always)] -fn _ord_i16_to_f16(ord_i16: i16) -> f16 { - // TODO: more efficient transformation -> can be decreasing order as well - let v = ((ord_i16 >> 15) & XOR_VALUE) ^ ord_i16; - unsafe { std::mem::transmute::(v) } +fn _i16ord_to_f16(ord_i16: i16) -> f16 { + let v = ((ord_i16 >> BIT_SHIFT) & MASK_VALUE) ^ ord_i16; + f16::from_bits(v as u16) } +#[cfg(feature = "half")] +const MAX_INDEX: usize = i16::MAX as usize; + // ------------------------------------------ AVX2 ------------------------------------------ #[cfg(feature = "half")] @@ -39,16 +73,19 @@ mod avx2 { use super::*; const LANE_SIZE: usize = AVX2::LANE_SIZE_16; - const XOR_MASK: __m256i = unsafe { std::mem::transmute([XOR_VALUE; LANE_SIZE]) }; - - // ------------------------------------ ARGMINMAX -------------------------------------- + const LOWER_15_MASK: __m256i = unsafe { std::mem::transmute([MASK_VALUE; LANE_SIZE]) }; #[inline(always)] unsafe fn _f16_as_m256i_to_i16ord(f16_as_m256i: __m256i) -> __m256i { // on a scalar: ((v >> 15) & 0x7FFF) ^ v - let sign_bit_shifted = _mm256_srai_epi16(f16_as_m256i, 15); - let sign_bit_masked = _mm256_and_si256(sign_bit_shifted, XOR_MASK); + let sign_bit_shifted = _mm256_srai_epi16(f16_as_m256i, BIT_SHIFT); + let sign_bit_masked = _mm256_and_si256(sign_bit_shifted, LOWER_15_MASK); _mm256_xor_si256(sign_bit_masked, f16_as_m256i) + // TODO: investigate if this is faster + // _mm256_xor_si256( + // _mm256_srai_epi16(f16_as_m256i, 15), + // _mm256_and_si256(f16_as_m256i, LOWER_15_MASK), + // ) } #[inline(always)] @@ -56,18 +93,22 @@ mod avx2 { std::mem::transmute::<__m256i, [i16; LANE_SIZE]>(reg) } - impl SIMD for AVX2 { + impl SIMDOps for AVX2 { const INITIAL_INDEX: __m256i = unsafe { std::mem::transmute([ 0i16, 1i16, 2i16, 3i16, 4i16, 5i16, 6i16, 7i16, 8i16, 9i16, 10i16, 11i16, 12i16, 13i16, 14i16, 15i16, ]) }; - const MAX_INDEX: usize = i16::MAX as usize; + const INDEX_INCREMENT: __m256i = + unsafe { std::mem::transmute([LANE_SIZE as i16; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; #[inline(always)] unsafe fn _reg_to_arr(_: __m256i) -> [f16; LANE_SIZE] { - // Not used because we work with i16ord and override _get_min_index_value and _get_max_index_value + // Not implemented because we will perform the horizontal operations on the + // signed integer values instead of trying to retransform **only** the values + // (and thus not the indices) to floats. 
unimplemented!() } @@ -76,11 +117,6 @@ mod avx2 { _f16_as_m256i_to_i16ord(_mm256_loadu_si256(data as *const __m256i)) } - #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m256i { - _mm256_set1_epi16(a as i16) - } - #[inline(always)] unsafe fn _mm_add(a: __m256i, b: __m256i) -> __m256i { _mm256_add_epi16(a, b) @@ -101,13 +137,6 @@ mod avx2 { _mm256_blendv_epi8(a, b, mask) } - // ------------------------------------ ARGMINMAX -------------------------------------- - - #[target_feature(enable = "avx2")] - unsafe fn argminmax(data: &[f16]) -> (usize, usize) { - Self::_argminmax(data) - } - #[inline(always)] unsafe fn _horiz_min(index: __m256i, value: __m256i) -> (usize, f16) { // 0. Find the minimum value @@ -135,7 +164,7 @@ mod avx2 { imin = _mm256_min_epi16(imin, _mm256_alignr_epi8(imin, imin, 2)); let min_index: usize = _mm256_extract_epi16(imin, 0) as usize; - (min_index, _ord_i16_to_f16(min_value)) + (min_index, _i16ord_to_f16(min_value)) } #[inline(always)] @@ -165,7 +194,14 @@ mod avx2 { imin = _mm256_min_epi16(imin, _mm256_alignr_epi8(imin, imin, 2)); let max_index: usize = _mm256_extract_epi16(imin, 0) as usize; - (max_index, _ord_i16_to_f16(max_value)) + (max_index, _i16ord_to_f16(max_value)) + } + } + + impl SIMDArgMinMax for AVX2 { + #[target_feature(enable = "avx2")] + unsafe fn argminmax(data: &[f16]) -> (usize, usize) { + Self::_argminmax(data) } } @@ -173,7 +209,8 @@ mod avx2 { #[cfg(test)] mod tests { - use super::{AVX2, SIMD}; + use super::SIMDArgMinMax; + use super::AVX2; use crate::scalar::generic::scalar_argminmax; use half::f16; @@ -228,6 +265,165 @@ mod avx2 { assert_eq!(argmax_simd_index, 1); } + #[test] + fn test_return_infs() { + if !is_x86_feature_detected!("avx2") { + return; + } + + let arr_len: usize = 1027; + let mut data: Vec = get_array_f16(arr_len); + + // Case 1: all elements are +inf + for i in 0..data.len() { + data[i] = f16::INFINITY; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 2: all elements are -inf + for i in 0..data.len() { + data[i] = f16::NEG_INFINITY; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 3: add some +inf and -inf in the middle + let mut data: Vec = get_array_f16(arr_len); + data[100] = f16::INFINITY; + data[200] = f16::NEG_INFINITY; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 200); + assert_eq!(argmax_index, 100); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert_eq!(argmin_simd_index, 200); + assert_eq!(argmax_simd_index, 100); + } + + #[test] + fn test_return_nans() { + if !is_x86_feature_detected!("avx2") { + return; + } + + let arr_len: usize = 1027; + + // Case 1: NaN is the first element + let mut data: Vec = get_array_f16(arr_len); + data[0] = f16::NAN; + println!("{:?}", data); + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + 
assert_eq!(argmax_simd_index, 0); + + // Case 2: first 100 elements are NaN + for i in 0..100 { + data[i] = f16::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 3: NaN is the last element + let mut data: Vec = get_array_f16(arr_len); + data[arr_len - 1] = f16::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 1026); + assert_eq!(argmax_index, 1026); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert_eq!(argmin_simd_index, 1026); + assert_eq!(argmax_simd_index, 1026); + + // Case 4: last 100 elements are NaN + for i in 0..100 { + data[arr_len - 1 - i] = f16::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, arr_len - 100); + assert_eq!(argmax_index, arr_len - 100); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert_eq!(argmin_simd_index, arr_len - 100); + assert_eq!(argmax_simd_index, arr_len - 100); + + // Case 5: NaN is somewhere in the middle element + let mut data: Vec = get_array_f16(arr_len); + data[123] = f16::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 123); + assert_eq!(argmax_index, 123); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert_eq!(argmin_simd_index, 123); + assert_eq!(argmax_simd_index, 123); + + // Case 6: NaN in the middle of the array and last 100 elements are NaN + for i in 0..100 { + data[arr_len - 1 - i] = f16::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 123); + assert_eq!(argmax_index, 123); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert_eq!(argmin_simd_index, 123); + assert_eq!(argmax_simd_index, 123); + + // Case 7: all elements are NaN + for i in 0..data.len() { + data[i] = f16::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 8: array exact multiple of LANE_SIZE and only 1 element is NaN + let mut data: Vec = get_array_f16(128); + data[17] = f16::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 17); + assert_eq!(argmax_index, 17); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert_eq!(argmin_simd_index, 17); + assert_eq!(argmax_simd_index, 17); + } + #[test] fn test_no_overflow() { if !is_x86_feature_detected!("avx2") { @@ -269,13 +465,13 @@ mod sse { use super::*; const LANE_SIZE: usize = SSE::LANE_SIZE_16; - const XOR_MASK: __m128i = unsafe { std::mem::transmute([XOR_VALUE; LANE_SIZE]) }; + const LOWER_15_MASK: __m128i = unsafe { std::mem::transmute([MASK_VALUE; LANE_SIZE]) }; #[inline(always)] unsafe fn _f16_as_m128i_to_i16ord(f16_as_m128i: __m128i) -> __m128i { // on a scalar: ((v >> 15) & 0x7FFF) ^ v - let sign_bit_shifted = _mm_srai_epi16(f16_as_m128i, 15); - let sign_bit_masked = _mm_and_si128(sign_bit_shifted, XOR_MASK); + let sign_bit_shifted = _mm_srai_epi16(f16_as_m128i, BIT_SHIFT); + let 
sign_bit_masked = _mm_and_si128(sign_bit_shifted, LOWER_15_MASK); _mm_xor_si128(sign_bit_masked, f16_as_m128i) } @@ -284,14 +480,18 @@ mod sse { std::mem::transmute::<__m128i, [i16; LANE_SIZE]>(reg) } - impl SIMD for SSE { + impl SIMDOps for SSE { const INITIAL_INDEX: __m128i = unsafe { std::mem::transmute([0i16, 1i16, 2i16, 3i16, 4i16, 5i16, 6i16, 7i16]) }; - const MAX_INDEX: usize = i16::MAX as usize; + const INDEX_INCREMENT: __m128i = + unsafe { std::mem::transmute([LANE_SIZE as i16; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; #[inline(always)] unsafe fn _reg_to_arr(_: __m128i) -> [f16; LANE_SIZE] { - // Not used because we work with i16ord and override _get_min_index_value and _get_max_index_value + // Not implemented because we will perform the horizontal operations on the + // signed integer values instead of trying to retransform **only** the values + // (and thus not the indices) to floats. unimplemented!() } @@ -300,11 +500,6 @@ mod sse { _f16_as_m128i_to_i16ord(_mm_loadu_si128(data as *const __m128i)) } - #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m128i { - _mm_set1_epi16(a as i16) - } - #[inline(always)] unsafe fn _mm_add(a: __m128i, b: __m128i) -> __m128i { _mm_add_epi16(a, b) @@ -325,13 +520,6 @@ mod sse { _mm_blendv_epi8(a, b, mask) } - // ------------------------------------ ARGMINMAX -------------------------------------- - - #[target_feature(enable = "sse4.1")] - unsafe fn argminmax(data: &[f16]) -> (usize, usize) { - Self::_argminmax(data) - } - #[inline(always)] unsafe fn _horiz_min(index: __m128i, value: __m128i) -> (usize, f16) { // 0. Find the minimum value @@ -357,7 +545,7 @@ mod sse { imin = _mm_min_epi16(imin, _mm_alignr_epi8(imin, imin, 2)); let min_index: usize = _mm_extract_epi16(imin, 0) as usize; - (min_index, _ord_i16_to_f16(min_value)) + (min_index, _i16ord_to_f16(min_value)) } #[inline(always)] @@ -385,7 +573,14 @@ mod sse { imin = _mm_min_epi16(imin, _mm_alignr_epi8(imin, imin, 2)); let max_index: usize = _mm_extract_epi16(imin, 0) as usize; - (max_index, _ord_i16_to_f16(max_value)) + (max_index, _i16ord_to_f16(max_value)) + } + } + + impl SIMDArgMinMax for SSE { + #[target_feature(enable = "sse4.1")] + unsafe fn argminmax(data: &[f16]) -> (usize, usize) { + Self::_argminmax(data) } } @@ -393,7 +588,8 @@ mod sse { #[cfg(test)] mod tests { - use super::{SIMD, SSE}; + use super::SIMDArgMinMax; + use super::SSE; use crate::scalar::generic::scalar_argminmax; use half::f16; @@ -440,6 +636,157 @@ mod sse { assert_eq!(argmax_simd_index, 1); } + #[test] + fn test_return_infs() { + let arr_len: usize = 1027; + let mut data: Vec = get_array_f16(arr_len); + + // Case 1: all elements are +inf + for i in 0..data.len() { + data[i] = f16::INFINITY; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 2: all elements are -inf + for i in 0..data.len() { + data[i] = f16::NEG_INFINITY; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 3: add some +inf and -inf in the middle + let mut data: Vec = get_array_f16(arr_len); + data[100] = f16::INFINITY; + data[200] = 
f16::NEG_INFINITY; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 200); + assert_eq!(argmax_index, 100); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert_eq!(argmin_simd_index, 200); + assert_eq!(argmax_simd_index, 100); + } + + #[test] + fn test_return_nans() { + let arr_len: usize = 1027; + + // Case 1: NaN is the first element + let mut data: Vec = get_array_f16(arr_len); + data[0] = f16::NAN; + println!("{:?}", data); + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 2: first 100 elements are NaN + for i in 0..100 { + data[i] = f16::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 3: NaN is the last element + let mut data: Vec = get_array_f16(arr_len); + data[arr_len - 1] = f16::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 1026); + assert_eq!(argmax_index, 1026); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert_eq!(argmin_simd_index, 1026); + assert_eq!(argmax_simd_index, 1026); + + // Case 4: last 100 elements are NaN + for i in 0..100 { + data[arr_len - 1 - i] = f16::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, arr_len - 100); + assert_eq!(argmax_index, arr_len - 100); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert_eq!(argmin_simd_index, arr_len - 100); + assert_eq!(argmax_simd_index, arr_len - 100); + + // Case 5: NaN is somewhere in the middle element + let mut data: Vec = get_array_f16(arr_len); + data[123] = f16::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 123); + assert_eq!(argmax_index, 123); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert_eq!(argmin_simd_index, 123); + assert_eq!(argmax_simd_index, 123); + + // Case 6: NaN in the middle of the array and last 100 elements are NaN + for i in 0..100 { + data[arr_len - 1 - i] = f16::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 123); + assert_eq!(argmax_index, 123); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert_eq!(argmin_simd_index, 123); + assert_eq!(argmax_simd_index, 123); + + // Case 7: all elements are NaN + for i in 0..data.len() { + data[i] = f16::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 8: array exact multiple of LANE_SIZE and only 1 element is NaN + let mut data: Vec = get_array_f16(128); + data[17] = f16::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 17); + assert_eq!(argmax_index, 17); + + let (argmin_simd_index, argmax_simd_index) = unsafe { 
SSE::argminmax(&data) }; + assert_eq!(argmin_simd_index, 17); + assert_eq!(argmax_simd_index, 17); + } + #[test] fn test_no_overflow() { let n: usize = 1 << 18; @@ -473,13 +820,13 @@ mod avx512 { use super::*; const LANE_SIZE: usize = AVX512::LANE_SIZE_16; - const XOR_MASK: __m512i = unsafe { std::mem::transmute([XOR_VALUE; LANE_SIZE]) }; + const LOWER_15_MASK: __m512i = unsafe { std::mem::transmute([MASK_VALUE; LANE_SIZE]) }; #[inline(always)] unsafe fn _f16_as_m521i_to_i16ord(f16_as_m512i: __m512i) -> __m512i { // on a scalar: ((v >> 15) & 0x7FFF) ^ v - let sign_bit_shifted = _mm512_srai_epi16(f16_as_m512i, 15); - let sign_bit_masked = _mm512_and_si512(sign_bit_shifted, XOR_MASK); + let sign_bit_shifted = _mm512_srai_epi16(f16_as_m512i, BIT_SHIFT as u32); + let sign_bit_masked = _mm512_and_si512(sign_bit_shifted, LOWER_15_MASK); _mm512_xor_si512(f16_as_m512i, sign_bit_masked) } @@ -488,7 +835,7 @@ mod avx512 { std::mem::transmute::<__m512i, [i16; LANE_SIZE]>(reg) } - impl SIMD for AVX512 { + impl SIMDOps for AVX512 { const INITIAL_INDEX: __m512i = unsafe { std::mem::transmute([ 0i16, 1i16, 2i16, 3i16, 4i16, 5i16, 6i16, 7i16, 8i16, 9i16, 10i16, 11i16, 12i16, @@ -496,11 +843,15 @@ mod avx512 { 25i16, 26i16, 27i16, 28i16, 29i16, 30i16, 31i16, ]) }; - const MAX_INDEX: usize = i16::MAX as usize; + const INDEX_INCREMENT: __m512i = + unsafe { std::mem::transmute([LANE_SIZE as i16; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; #[inline(always)] unsafe fn _reg_to_arr(_: __m512i) -> [f16; LANE_SIZE] { - // Not used because we work with i16ord and override _get_min_index_value and _get_max_index_value + // Not implemented because we will perform the horizontal operations on the + // signed integer values instead of trying to retransform **only** the values + // (and thus not the indices) to floats. unimplemented!() } @@ -509,11 +860,6 @@ mod avx512 { _f16_as_m521i_to_i16ord(_mm512_loadu_epi16(data as *const i16)) } - #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m512i { - _mm512_set1_epi16(a as i16) - } - #[inline(always)] unsafe fn _mm_add(a: __m512i, b: __m512i) -> __m512i { _mm512_add_epi16(a, b) @@ -534,13 +880,6 @@ mod avx512 { _mm512_mask_blend_epi16(mask, a, b) } - // ------------------------------------ ARGMINMAX -------------------------------------- - - #[target_feature(enable = "avx512bw")] - unsafe fn argminmax(data: &[f16]) -> (usize, usize) { - Self::_argminmax(data) - } - #[inline(always)] unsafe fn _horiz_min(index: __m512i, value: __m512i) -> (usize, f16) { // 0. 
Find the minimum value @@ -570,7 +909,7 @@ mod avx512 { imin = _mm512_min_epi16(imin, _mm512_alignr_epi8(imin, imin, 2)); let min_index: usize = _mm_extract_epi16(_mm512_castsi512_si128(imin), 0) as usize; - (min_index, _ord_i16_to_f16(min_value)) + (min_index, _i16ord_to_f16(min_value)) } #[inline(always)] @@ -602,7 +941,14 @@ mod avx512 { imin = _mm512_min_epi16(imin, _mm512_alignr_epi8(imin, imin, 2)); let max_index: usize = _mm_extract_epi16(_mm512_castsi512_si128(imin), 0) as usize; - (max_index, _ord_i16_to_f16(max_value)) + (max_index, _i16ord_to_f16(max_value)) + } + } + + impl SIMDArgMinMax for AVX512 { + #[target_feature(enable = "avx512bw")] + unsafe fn argminmax(data: &[f16]) -> (usize, usize) { + Self::_argminmax(data) } } @@ -610,7 +956,8 @@ mod avx512 { #[cfg(test)] mod tests { - use super::{AVX512, SIMD}; + use super::SIMDArgMinMax; + use super::AVX512; use crate::scalar::generic::scalar_argminmax; use half::f16; @@ -665,6 +1012,165 @@ mod avx512 { assert_eq!(argmax_simd_index, 1); } + #[test] + fn test_return_infs() { + if !is_x86_feature_detected!("avx512f") { + return; + } + + let arr_len: usize = 1027; + let mut data: Vec = get_array_f16(arr_len); + + // Case 1: all elements are +inf + for i in 0..data.len() { + data[i] = f16::INFINITY; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 2: all elements are -inf + for i in 0..data.len() { + data[i] = f16::NEG_INFINITY; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 3: add some +inf and -inf in the middle + let mut data: Vec = get_array_f16(arr_len); + data[100] = f16::INFINITY; + data[200] = f16::NEG_INFINITY; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 200); + assert_eq!(argmax_index, 100); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert_eq!(argmin_simd_index, 200); + assert_eq!(argmax_simd_index, 100); + } + + #[test] + fn test_return_nans() { + if !is_x86_feature_detected!("avx512f") { + return; + } + + let arr_len: usize = 1027; + + // Case 1: NaN is the first element + let mut data: Vec = get_array_f16(arr_len); + data[0] = f16::NAN; + println!("{:?}", data); + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 2: first 100 elements are NaN + for i in 0..100 { + data[i] = f16::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 3: NaN is the last element + let mut data: Vec = get_array_f16(arr_len); + data[arr_len - 1] = f16::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 1026); + 
assert_eq!(argmax_index, 1026); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert_eq!(argmin_simd_index, 1026); + assert_eq!(argmax_simd_index, 1026); + + // Case 4: last 100 elements are NaN + for i in 0..100 { + data[arr_len - 1 - i] = f16::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, arr_len - 100); + assert_eq!(argmax_index, arr_len - 100); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert_eq!(argmin_simd_index, arr_len - 100); + assert_eq!(argmax_simd_index, arr_len - 100); + + // Case 5: NaN is somewhere in the middle element + let mut data: Vec = get_array_f16(arr_len); + data[123] = f16::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 123); + assert_eq!(argmax_index, 123); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert_eq!(argmin_simd_index, 123); + assert_eq!(argmax_simd_index, 123); + + // Case 6: NaN in the middle of the array and last 100 elements are NaN + for i in 0..100 { + data[arr_len - 1 - i] = f16::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 123); + assert_eq!(argmax_index, 123); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert_eq!(argmin_simd_index, 123); + assert_eq!(argmax_simd_index, 123); + + // Case 7: all elements are NaN + for i in 0..data.len() { + data[i] = f16::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 8: array exact multiple of LANE_SIZE and only 1 element is NaN + let mut data: Vec = get_array_f16(128); + data[17] = f16::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 17); + assert_eq!(argmax_index, 17); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert_eq!(argmin_simd_index, 17); + assert_eq!(argmax_simd_index, 17); + } + #[test] fn test_no_overflow() { if !is_x86_feature_detected!("avx512bw") { @@ -706,13 +1212,13 @@ mod neon { use super::*; const LANE_SIZE: usize = NEON::LANE_SIZE_16; - const XOR_MASK: int16x8_t = unsafe { std::mem::transmute([XOR_VALUE; LANE_SIZE]) }; + const LOWER_15_MASK: int16x8_t = unsafe { std::mem::transmute([MASK_VALUE; LANE_SIZE]) }; #[inline(always)] unsafe fn _f16_as_int16x8_to_i16ord(f16_as_int16x8: int16x8_t) -> int16x8_t { // on a scalar: ((v >> 15) & 0x7FFF) ^ v - let sign_bit_shifted = vshrq_n_s16(f16_as_int16x8, 15); - let sign_bit_masked = vandq_s16(sign_bit_shifted, XOR_MASK); + let sign_bit_shifted = vshrq_n_s16(f16_as_int16x8, BIT_SHIFT); + let sign_bit_masked = vandq_s16(sign_bit_shifted, LOWER_15_MASK); veorq_s16(f16_as_int16x8, sign_bit_masked) } @@ -721,14 +1227,18 @@ mod neon { std::mem::transmute::(reg) } - impl SIMD for NEON { + impl SIMDOps for NEON { const INITIAL_INDEX: int16x8_t = unsafe { std::mem::transmute([0i16, 1i16, 2i16, 3i16, 4i16, 5i16, 6i16, 7i16]) }; - const MAX_INDEX: usize = i16::MAX as usize; + const INDEX_INCREMENT: int16x8_t = + unsafe { std::mem::transmute([LANE_SIZE as i16; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; #[inline(always)] unsafe fn _reg_to_arr(_: int16x8_t) -> [f16; LANE_SIZE] { - 
// Not used because we work with i16ord and override _get_min_index_value and _get_max_index_value + // Not implemented because we will perform the horizontal operations on the + // signed integer values instead of trying to retransform **only** the values + // (and thus not the indices) to floats. unimplemented!() } @@ -739,11 +1249,6 @@ mod neon { })) } - #[inline(always)] - unsafe fn _mm_set1(a: usize) -> int16x8_t { - vdupq_n_s16(a as i16) - } - #[inline(always)] unsafe fn _mm_add(a: int16x8_t, b: int16x8_t) -> int16x8_t { vaddq_s16(a, b) @@ -764,13 +1269,6 @@ mod neon { vbslq_s16(mask, b, a) } - // ------------------------------------ ARGMINMAX -------------------------------------- - - #[target_feature(enable = "neon")] - unsafe fn argminmax(data: &[f16]) -> (usize, usize) { - Self::_argminmax(data) - } - #[inline(always)] unsafe fn _horiz_min(index: int16x8_t, value: int16x8_t) -> (usize, f16) { // 0. Find the minimum value @@ -796,7 +1294,7 @@ mod neon { imin = vminq_s16(imin, vextq_s16(imin, imin, 1)); let min_index: usize = vgetq_lane_s16(imin, 0) as usize; - (min_index, _ord_i16_to_f16(min_value)) + (min_index, _i16ord_to_f16(min_value)) } #[inline(always)] @@ -824,7 +1322,14 @@ mod neon { imin = vminq_s16(imin, vextq_s16(imin, imin, 1)); let max_index: usize = vgetq_lane_s16(imin, 0) as usize; - (max_index, _ord_i16_to_f16(max_value)) + (max_index, _i16ord_to_f16(max_value)) + } + } + + impl SIMDArgMinMax for NEON { + #[target_feature(enable = "neon")] + unsafe fn argminmax(data: &[f16]) -> (usize, usize) { + Self::_argminmax(data) } } @@ -832,7 +1337,8 @@ mod neon { #[cfg(test)] mod tests { - use super::{NEON, SIMD}; + use super::SIMDArgMinMax; + use super::NEON; use crate::scalar::generic::scalar_argminmax; use half::f16; @@ -879,6 +1385,157 @@ mod neon { assert_eq!(argmax_simd_index, 1); } + #[test] + fn test_return_infs() { + let arr_len: usize = 1027; + let mut data: Vec = get_array_f16(arr_len); + + // Case 1: all elements are +inf + for i in 0..data.len() { + data[i] = f16::INFINITY; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { NEON::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 2: all elements are -inf + for i in 0..data.len() { + data[i] = f16::NEG_INFINITY; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { NEON::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 3: add some +inf and -inf in the middle + let mut data: Vec = get_array_f16(arr_len); + data[100] = f16::INFINITY; + data[200] = f16::NEG_INFINITY; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 200); + assert_eq!(argmax_index, 100); + + let (argmin_simd_index, argmax_simd_index) = unsafe { NEON::argminmax(&data) }; + assert_eq!(argmin_simd_index, 200); + assert_eq!(argmax_simd_index, 100); + } + + #[test] + fn test_return_nans() { + let arr_len: usize = 1027; + + // Case 1: NaN is the first element + let mut data: Vec = get_array_f16(arr_len); + data[0] = f16::NAN; + println!("{:?}", data); + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { 
NEON::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 2: first 100 elements are NaN + for i in 0..100 { + data[i] = f16::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { NEON::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 3: NaN is the last element + let mut data: Vec = get_array_f16(arr_len); + data[arr_len - 1] = f16::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 1026); + assert_eq!(argmax_index, 1026); + + let (argmin_simd_index, argmax_simd_index) = unsafe { NEON::argminmax(&data) }; + assert_eq!(argmin_simd_index, 1026); + assert_eq!(argmax_simd_index, 1026); + + // Case 4: last 100 elements are NaN + for i in 0..100 { + data[arr_len - 1 - i] = f16::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, arr_len - 100); + assert_eq!(argmax_index, arr_len - 100); + + let (argmin_simd_index, argmax_simd_index) = unsafe { NEON::argminmax(&data) }; + assert_eq!(argmin_simd_index, arr_len - 100); + assert_eq!(argmax_simd_index, arr_len - 100); + + // Case 5: NaN is somewhere in the middle element + let mut data: Vec = get_array_f16(arr_len); + data[123] = f16::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 123); + assert_eq!(argmax_index, 123); + + let (argmin_simd_index, argmax_simd_index) = unsafe { NEON::argminmax(&data) }; + assert_eq!(argmin_simd_index, 123); + assert_eq!(argmax_simd_index, 123); + + // Case 6: NaN in the middle of the array and last 100 elements are NaN + for i in 0..100 { + data[arr_len - 1 - i] = f16::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 123); + assert_eq!(argmax_index, 123); + + let (argmin_simd_index, argmax_simd_index) = unsafe { NEON::argminmax(&data) }; + assert_eq!(argmin_simd_index, 123); + assert_eq!(argmax_simd_index, 123); + + // Case 7: all elements are NaN + for i in 0..data.len() { + data[i] = f16::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { NEON::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 8: array exact multiple of LANE_SIZE and only 1 element is NaN + let mut data: Vec = get_array_f16(128); + data[17] = f16::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 17); + assert_eq!(argmax_index, 17); + + let (argmin_simd_index, argmax_simd_index) = unsafe { NEON::argminmax(&data) }; + assert_eq!(argmin_simd_index, 17); + assert_eq!(argmax_simd_index, 17); + } + #[test] fn test_no_overflow() { let n: usize = 1 << 18; diff --git a/src/simd/simd_f32.rs b/src/simd/simd_f32.rs deleted file mode 100644 index c56b3ad..0000000 --- a/src/simd/simd_f32.rs +++ /dev/null @@ -1,586 +0,0 @@ -use super::config::SIMDInstructionSet; -use super::generic::SIMD; -#[cfg(target_arch = "aarch64")] -use std::arch::aarch64::*; -#[cfg(target_arch = "arm")] -use std::arch::arm::*; -#[cfg(target_arch = "x86")] -use std::arch::x86::*; -#[cfg(target_arch = "x86_64")] -use std::arch::x86_64::*; - -// ------------------------------------------ AVX2 
------------------------------------------ - -#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -mod avx2 { - use super::super::config::AVX2; - use super::*; - - const LANE_SIZE: usize = AVX2::LANE_SIZE_32; - - impl SIMD for AVX2 { - const INITIAL_INDEX: __m256 = unsafe { - std::mem::transmute([ - 0.0f32, 1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32, 6.0f32, 7.0f32, - ]) - }; - // https://stackoverflow.com/a/3793950 - const MAX_INDEX: usize = 1 << f32::MANTISSA_DIGITS; - - #[inline(always)] - unsafe fn _reg_to_arr(reg: __m256) -> [f32; LANE_SIZE] { - std::mem::transmute::<__m256, [f32; LANE_SIZE]>(reg) - } - - #[inline(always)] - unsafe fn _mm_loadu(data: *const f32) -> __m256 { - _mm256_loadu_ps(data as *const f32) - } - - #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m256 { - _mm256_set1_ps(a as f32) - } - - #[inline(always)] - unsafe fn _mm_add(a: __m256, b: __m256) -> __m256 { - _mm256_add_ps(a, b) - } - - #[inline(always)] - unsafe fn _mm_cmpgt(a: __m256, b: __m256) -> __m256 { - _mm256_cmp_ps(a, b, _CMP_GT_OQ) - } - - #[inline(always)] - unsafe fn _mm_cmplt(a: __m256, b: __m256) -> __m256 { - _mm256_cmp_ps(b, a, _CMP_GT_OQ) - } - - #[inline(always)] - unsafe fn _mm_blendv(a: __m256, b: __m256, mask: __m256) -> __m256 { - _mm256_blendv_ps(a, b, mask) - } - - // ------------------------------------ ARGMINMAX -------------------------------------- - - #[target_feature(enable = "avx")] - unsafe fn argminmax(data: &[f32]) -> (usize, usize) { - Self::_argminmax(data) - } - } - - // ------------------------------------ TESTS -------------------------------------- - - #[cfg(test)] - mod tests { - use super::{AVX2, SIMD}; - use crate::scalar::generic::scalar_argminmax; - - extern crate dev_utils; - use dev_utils::utils; - - fn get_array_f32(n: usize) -> Vec { - utils::get_random_array(n, f32::MIN, f32::MAX) - } - - #[test] - fn test_both_versions_return_the_same_results() { - if !is_x86_feature_detected!("avx") { - return; - } - - let data: &[f32] = &get_array_f32(1025); - assert_eq!(data.len() % 8, 1); - - let (argmin_index, argmax_index) = scalar_argminmax(data); - let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(data) }; - assert_eq!(argmin_index, argmin_simd_index); - assert_eq!(argmax_index, argmax_simd_index); - } - - #[test] - fn test_first_index_is_returned_when_identical_values_found() { - if !is_x86_feature_detected!("avx") { - return; - } - - let data = [ - 10., - std::f32::MAX, - 6., - std::f32::NEG_INFINITY, - std::f32::NEG_INFINITY, - std::f32::MAX, - 10_000.0, - ]; - let data: Vec = data.iter().map(|x| *x).collect(); - let data: &[f32] = &data; - - let (argmin_index, argmax_index) = scalar_argminmax(data); - assert_eq!(argmin_index, 3); - assert_eq!(argmax_index, 1); - - let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(data) }; - assert_eq!(argmin_simd_index, 3); - assert_eq!(argmax_simd_index, 1); - } - - #[test] - fn test_no_overflow() { - if !is_x86_feature_detected!("avx") { - return; - } - - let n: usize = 1 << 25; - let data: &[f32] = &get_array_f32(n); - - let (argmin_index, argmax_index) = scalar_argminmax(data); - let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(data) }; - assert_eq!(argmin_index, argmin_simd_index); - assert_eq!(argmax_index, argmax_simd_index); - } - - #[test] - fn test_many_random_runs() { - if !is_x86_feature_detected!("avx") { - return; - } - - for _ in 0..10_000 { - let data: &[f32] = &get_array_f32(32 * 8 + 1); - let (argmin_index, argmax_index) = scalar_argminmax(data); - 
let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(data) }; - assert_eq!(argmin_index, argmin_simd_index); - assert_eq!(argmax_index, argmax_simd_index); - } - } - } -} - -// ----------------------------------------- SSE ----------------------------------------- - -#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -mod sse { - use super::super::config::SSE; - use super::*; - - const LANE_SIZE: usize = SSE::LANE_SIZE_32; - - impl SIMD for SSE { - const INITIAL_INDEX: __m128 = - unsafe { std::mem::transmute([0.0f32, 1.0f32, 2.0f32, 3.0f32]) }; - // https://stackoverflow.com/a/3793950 - const MAX_INDEX: usize = 1 << f32::MANTISSA_DIGITS; - - #[inline(always)] - unsafe fn _reg_to_arr(reg: __m128) -> [f32; LANE_SIZE] { - std::mem::transmute::<__m128, [f32; LANE_SIZE]>(reg) - } - - #[inline(always)] - unsafe fn _mm_loadu(data: *const f32) -> __m128 { - _mm_loadu_ps(data as *const f32) - } - - #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m128 { - _mm_set1_ps(a as f32) - } - - #[inline(always)] - unsafe fn _mm_add(a: __m128, b: __m128) -> __m128 { - _mm_add_ps(a, b) - } - - #[inline(always)] - unsafe fn _mm_cmpgt(a: __m128, b: __m128) -> __m128 { - _mm_cmpgt_ps(a, b) - } - - #[inline(always)] - unsafe fn _mm_cmplt(a: __m128, b: __m128) -> __m128 { - _mm_cmplt_ps(a, b) - } - - #[inline(always)] - unsafe fn _mm_blendv(a: __m128, b: __m128, mask: __m128) -> __m128 { - _mm_blendv_ps(a, b, mask) - } - - // ------------------------------------ ARGMINMAX -------------------------------------- - - #[target_feature(enable = "sse4.1")] - unsafe fn argminmax(data: &[f32]) -> (usize, usize) { - Self::_argminmax(data) - } - } - - // ------------------------------------ TESTS -------------------------------------- - - #[cfg(test)] - mod tests { - use super::{SIMD, SSE}; - use crate::scalar::generic::scalar_argminmax; - - extern crate dev_utils; - use dev_utils::utils; - - fn get_array_f32(n: usize) -> Vec { - utils::get_random_array(n, f32::MIN, f32::MAX) - } - - #[test] - fn test_both_versions_return_the_same_results() { - let data: &[f32] = &get_array_f32(1025); - assert_eq!(data.len() % 4, 1); - - let (argmin_index, argmax_index) = scalar_argminmax(data); - let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(data) }; - assert_eq!(argmin_index, argmin_simd_index); - assert_eq!(argmax_index, argmax_simd_index); - } - - #[test] - fn test_first_index_is_returned_when_identical_values_found() { - let data = [ - 10., - std::f32::MAX, - 6., - std::f32::NEG_INFINITY, - std::f32::NEG_INFINITY, - std::f32::MAX, - 10_000.0, - ]; - let data: Vec = data.iter().map(|x| *x).collect(); - let data: &[f32] = &data; - - let (argmin_index, argmax_index) = scalar_argminmax(data); - assert_eq!(argmin_index, 3); - assert_eq!(argmax_index, 1); - - let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(data) }; - assert_eq!(argmin_simd_index, 3); - assert_eq!(argmax_simd_index, 1); - } - - #[test] - fn test_no_overflow() { - let n: usize = 1 << 25; - let data: &[f32] = &get_array_f32(n); - - let (argmin_index, argmax_index) = scalar_argminmax(data); - let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(data) }; - assert_eq!(argmin_index, argmin_simd_index); - assert_eq!(argmax_index, argmax_simd_index); - } - - #[test] - fn test_many_random_runs() { - for _ in 0..10_000 { - let data: &[f32] = &get_array_f32(32 * 4 + 1); - let (argmin_index, argmax_index) = scalar_argminmax(data); - let (argmin_simd_index, argmax_simd_index) = unsafe { 
SSE::argminmax(data) }; - assert_eq!(argmin_index, argmin_simd_index); - assert_eq!(argmax_index, argmax_simd_index); - } - } - } -} - -// --------------------------------------- AVX512 ---------------------------------------- - -#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -mod avx512 { - use super::super::config::AVX512; - use super::*; - - const LANE_SIZE: usize = AVX512::LANE_SIZE_32; - - impl SIMD for AVX512 { - const INITIAL_INDEX: __m512 = unsafe { - std::mem::transmute([ - 0.0f32, 1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32, 6.0f32, 7.0f32, 8.0f32, 9.0f32, - 10.0f32, 11.0f32, 12.0f32, 13.0f32, 14.0f32, 15.0f32, - ]) - }; - // https://stackoverflow.com/a/3793950 - const MAX_INDEX: usize = 1 << f32::MANTISSA_DIGITS; - - #[inline(always)] - unsafe fn _reg_to_arr(reg: __m512) -> [f32; LANE_SIZE] { - std::mem::transmute::<__m512, [f32; LANE_SIZE]>(reg) - } - - #[inline(always)] - unsafe fn _mm_loadu(data: *const f32) -> __m512 { - _mm512_loadu_ps(data as *const f32) - } - - #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m512 { - _mm512_set1_ps(a as f32) - } - - #[inline(always)] - unsafe fn _mm_add(a: __m512, b: __m512) -> __m512 { - _mm512_add_ps(a, b) - } - - #[inline(always)] - unsafe fn _mm_cmpgt(a: __m512, b: __m512) -> u16 { - _mm512_cmp_ps_mask(a, b, _CMP_GT_OQ) - } - // unimplemented!("AVX512 comparison instructions for ps output a u16 mask.") - // let u16_mask = _mm512_cmp_ps_mask(a, b, _CMP_GT_OQ); - // _mm512_mask_mov_ps(_mm512_setzero_ps(), u16_mask, _mm512_set1_ps(1.0)) - // } - // { _mm512_cmp_ps_mask(a, b, _CMP_GT_OQ) } - - #[inline(always)] - unsafe fn _mm_cmplt(a: __m512, b: __m512) -> u16 { - _mm512_cmp_ps_mask(a, b, _CMP_LT_OQ) - } - // unimplemented!("AVX512 comparison instructions for ps output a u16 mask.") - // let u16_mask = _mm512_cmp_ps_mask(a, b, _CMP_LT_OQ); - // _mm512_mask_mov_ps(_mm512_setzero_ps(), u16_mask, _mm512_set1_ps(1.0)) - // } - // { _mm512_cmp_ps_mask(a, b, _CMP_LT_OQ) } - - #[inline(always)] - unsafe fn _mm_blendv(a: __m512, b: __m512, mask: u16) -> __m512 { - _mm512_mask_blend_ps(mask, a, b) - } - // unimplemented!("AVX512 blendv instructions for ps require a u16 mask.") - // convert the mask to u16 by extracting the sign bit of each lane - // let u16_mask = _mm512_castps_si512(mask); - // _mm512_mask_mov_ps(a, u16_mask, b) - // _mm512_mask_blend_ps(u16_mask, a, b) - // _mm512_mask_mov_ps(a, _mm512_castps_si512(mask), b) - // } - - // ------------------------------------ ARGMINMAX -------------------------------------- - - #[target_feature(enable = "avx512f")] - unsafe fn argminmax(data: &[f32]) -> (usize, usize) { - Self::_argminmax(data) - } - } - - // ------------------------------------ TESTS -------------------------------------- - - #[cfg(test)] - mod tests { - use super::{AVX512, SIMD}; - use crate::scalar::generic::scalar_argminmax; - - extern crate dev_utils; - use dev_utils::utils; - - fn get_array_f32(n: usize) -> Vec { - utils::get_random_array(n, f32::MIN, f32::MAX) - } - - #[test] - fn test_both_versions_return_the_same_results() { - if !is_x86_feature_detected!("avx512f") { - return; - } - - let data: &[f32] = &get_array_f32(1025); - assert_eq!(data.len() % 16, 1); - - let (argmin_index, argmax_index) = scalar_argminmax(data); - let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(data) }; - assert_eq!(argmin_index, argmin_simd_index); - assert_eq!(argmax_index, argmax_simd_index); - } - - #[test] - fn test_first_index_is_returned_when_identical_values_found() { - if 
!is_x86_feature_detected!("avx512f") { - return; - } - - let data = [ - 10., - std::f32::MAX, - 6., - std::f32::NEG_INFINITY, - std::f32::NEG_INFINITY, - std::f32::MAX, - 10_000.0, - ]; - let data: Vec = data.iter().map(|x| *x).collect(); - let data: &[f32] = &data; - - let (argmin_index, argmax_index) = scalar_argminmax(data); - assert_eq!(argmin_index, 3); - assert_eq!(argmax_index, 1); - - let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(data) }; - assert_eq!(argmin_simd_index, 3); - assert_eq!(argmax_simd_index, 1); - } - - #[test] - fn test_no_overflow() { - if !is_x86_feature_detected!("avx512f") { - return; - } - - let n: usize = 1 << 25; - let data: &[f32] = &get_array_f32(n); - - let (argmin_index, argmax_index) = scalar_argminmax(data); - let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(data) }; - assert_eq!(argmin_index, argmin_simd_index); - assert_eq!(argmax_index, argmax_simd_index); - } - - #[test] - fn test_many_random_runs() { - if !is_x86_feature_detected!("avx512f") { - return; - } - - for _ in 0..10_000 { - let data: &[f32] = &get_array_f32(32 * 16 + 1); - let (argmin_index, argmax_index) = scalar_argminmax(data); - let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(data) }; - assert_eq!(argmin_index, argmin_simd_index); - assert_eq!(argmax_index, argmax_simd_index); - } - } - } -} - -// ---------------------------------------- NEON ----------------------------------------- - -#[cfg(any(target_arch = "arm", target_arch = "aarch64"))] -mod neon { - use super::super::config::NEON; - use super::*; - - const LANE_SIZE: usize = NEON::LANE_SIZE_32; - - impl SIMD for NEON { - const INITIAL_INDEX: float32x4_t = - unsafe { std::mem::transmute([0.0f32, 1.0f32, 2.0f32, 3.0f32]) }; - - #[inline(always)] - unsafe fn _reg_to_arr(reg: float32x4_t) -> [f32; LANE_SIZE] { - std::mem::transmute::(reg) - } - // https://stackoverflow.com/a/3793950 - const MAX_INDEX: usize = 1 << f32::MANTISSA_DIGITS; - - #[inline(always)] - unsafe fn _mm_loadu(data: *const f32) -> float32x4_t { - vld1q_f32(data as *const f32) - } - - #[inline(always)] - unsafe fn _mm_set1(a: usize) -> float32x4_t { - vdupq_n_f32(a as f32) - } - - #[inline(always)] - unsafe fn _mm_add(a: float32x4_t, b: float32x4_t) -> float32x4_t { - vaddq_f32(a, b) - } - - #[inline(always)] - unsafe fn _mm_cmpgt(a: float32x4_t, b: float32x4_t) -> uint32x4_t { - vcgtq_f32(a, b) - } - - #[inline(always)] - unsafe fn _mm_cmplt(a: float32x4_t, b: float32x4_t) -> uint32x4_t { - vcltq_f32(a, b) - } - - #[inline(always)] - unsafe fn _mm_blendv(a: float32x4_t, b: float32x4_t, mask: uint32x4_t) -> float32x4_t { - vbslq_f32(mask, b, a) - } - - // ------------------------------------ ARGMINMAX -------------------------------------- - - #[target_feature(enable = "neon")] - unsafe fn argminmax(data: &[f32]) -> (usize, usize) { - Self::_argminmax(data) - } - } - - // ------------------------------------ TESTS -------------------------------------- - - #[cfg(test)] - mod tests { - use super::{NEON, SIMD}; - use crate::scalar::generic::scalar_argminmax; - - extern crate dev_utils; - use dev_utils::utils; - - fn get_array_f32(n: usize) -> Vec { - utils::get_random_array(n, f32::MIN, f32::MAX) - } - - #[test] - fn test_both_versions_return_the_same_results() { - let data: &[f32] = &get_array_f32(1025); - assert_eq!(data.len() % 4, 1); - - let (argmin_index, argmax_index) = scalar_argminmax(data); - let (argmin_simd_index, argmax_simd_index) = unsafe { NEON::argminmax(data) }; - 
assert_eq!(argmin_index, argmin_simd_index); - assert_eq!(argmax_index, argmax_simd_index); - } - - #[test] - fn test_first_index_is_returned_when_identical_values_found() { - let data = [ - 10., - std::f32::MAX, - 6., - std::f32::NEG_INFINITY, - std::f32::NEG_INFINITY, - std::f32::MAX, - 10_000.0, - ]; - let data: Vec = data.iter().map(|x| *x).collect(); - let data: &[f32] = &data; - - let (argmin_index, argmax_index) = scalar_argminmax(data); - assert_eq!(argmin_index, 3); - assert_eq!(argmax_index, 1); - - let (argmin_simd_index, argmax_simd_index) = unsafe { NEON::argminmax(data) }; - assert_eq!(argmin_simd_index, 3); - assert_eq!(argmax_simd_index, 1); - } - - #[test] - fn test_no_overflow() { - let n: usize = 1 << 25; - let data: &[f32] = &get_array_f32(n); - - let (argmin_index, argmax_index) = scalar_argminmax(data); - let (argmin_simd_index, argmax_simd_index) = unsafe { NEON::argminmax(data) }; - assert_eq!(argmin_index, argmin_simd_index); - assert_eq!(argmax_index, argmax_simd_index); - } - - #[test] - fn test_many_random_runs() { - for _ in 0..10_000 { - let data: &[f32] = &get_array_f32(32 * 4 + 1); - let (argmin_index, argmax_index) = scalar_argminmax(data); - let (argmin_simd_index, argmax_simd_index) = unsafe { NEON::argminmax(data) }; - assert_eq!(argmin_index, argmin_simd_index); - assert_eq!(argmax_index, argmax_simd_index); - } - } - } -} diff --git a/src/simd/simd_f32_ignore_nan.rs b/src/simd/simd_f32_ignore_nan.rs new file mode 100644 index 0000000..098c4ea --- /dev/null +++ b/src/simd/simd_f32_ignore_nan.rs @@ -0,0 +1,1117 @@ +/// Implementation of the argminmax operations for f32 that ignores NaN values. +/// This implementation returns the index of the minimum and maximum values. +/// However, unexpected behavior may occur when there are +/// - *only* NaN values in the array +/// - *only* +/- infinity values in the array +/// - *only* NaN and +/- infinity values in the array +/// In these cases, index 0 is returned. +/// +/// NaN values are ignored and treated as if they are not present in the array. +/// To realize this we create an initial SIMD register with values +/- infinity. +/// As comparisons with NaN always return false, it is guaranteed that no NaN values +/// are added to the accumulating SIMD register. 
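(Editorial aside, not part of the diff: a minimal scalar sketch of the ignore-NaN strategy described in the module comment above. The function name is hypothetical; it only illustrates why accumulators seeded with +/- infinity never pick up NaN values, since every strict comparison with NaN evaluates to false.)

// A minimal scalar analogue of the ignore-NaN accumulation strategy.
fn scalar_argminmax_ignore_nan_sketch(data: &[f32]) -> (usize, usize) {
    let (mut min_idx, mut min_val) = (0usize, f32::INFINITY);
    let (mut max_idx, mut max_val) = (0usize, f32::NEG_INFINITY);
    for (i, &v) in data.iter().enumerate() {
        // Both comparisons are false when `v` is NaN, so a NaN element
        // never replaces the running minimum or maximum.
        if v < min_val {
            min_idx = i;
            min_val = v;
        }
        if v > max_val {
            max_idx = i;
            max_val = v;
        }
    }
    // All-NaN (or all +/- infinity) input falls back to index 0, matching
    // the edge cases listed in the module documentation.
    (min_idx, max_idx)
}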
+/// +use super::config::SIMDInstructionSet; +use super::generic::{SIMDArgMinMaxIgnoreNaN, SIMDOps, SIMDSetOps}; +#[cfg(target_arch = "aarch64")] +use std::arch::aarch64::*; +#[cfg(target_arch = "arm")] +use std::arch::arm::*; +#[cfg(target_arch = "x86")] +use std::arch::x86::*; +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +// https://stackoverflow.com/a/3793950 +const MAX_INDEX: usize = 1 << f32::MANTISSA_DIGITS; + +// ------------------------------------------ AVX2 ------------------------------------------ + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +mod avx2_ignore_nan { + use super::super::config::{AVX2IgnoreNaN, AVX2}; + use super::*; + + const LANE_SIZE: usize = AVX2::LANE_SIZE_32; + + impl SIMDOps for AVX2IgnoreNaN { + const INITIAL_INDEX: __m256 = unsafe { + std::mem::transmute([ + 0.0f32, 1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32, 6.0f32, 7.0f32, + ]) + }; + const INDEX_INCREMENT: __m256 = + unsafe { std::mem::transmute([LANE_SIZE as f32; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; + + #[inline(always)] + unsafe fn _reg_to_arr(reg: __m256) -> [f32; LANE_SIZE] { + std::mem::transmute::<__m256, [f32; LANE_SIZE]>(reg) + } + + #[inline(always)] + unsafe fn _mm_loadu(data: *const f32) -> __m256 { + _mm256_loadu_ps(data as *const f32) + } + + #[inline(always)] + unsafe fn _mm_add(a: __m256, b: __m256) -> __m256 { + _mm256_add_ps(a, b) + } + + #[inline(always)] + unsafe fn _mm_cmpgt(a: __m256, b: __m256) -> __m256 { + _mm256_cmp_ps(a, b, _CMP_GT_OQ) + } + + #[inline(always)] + unsafe fn _mm_cmplt(a: __m256, b: __m256) -> __m256 { + _mm256_cmp_ps(b, a, _CMP_GT_OQ) + } + + #[inline(always)] + unsafe fn _mm_blendv(a: __m256, b: __m256, mask: __m256) -> __m256 { + _mm256_blendv_ps(a, b, mask) + } + } + + impl SIMDSetOps for AVX2IgnoreNaN { + #[inline(always)] + unsafe fn _mm_set1(a: f32) -> __m256 { + _mm256_set1_ps(a) + } + } + + impl SIMDArgMinMaxIgnoreNaN for AVX2IgnoreNaN { + #[target_feature(enable = "avx")] + unsafe fn argminmax(data: &[f32]) -> (usize, usize) { + Self::_argminmax(data) + } + } + + // ------------------------------------ TESTS -------------------------------------- + + #[cfg(test)] + mod tests { + use super::AVX2IgnoreNaN as AVX2; + use super::SIMDArgMinMaxIgnoreNaN; + use crate::scalar::generic::scalar_argminmax_ignore_nans as scalar_argminmax; + + extern crate dev_utils; + use dev_utils::utils; + + fn get_array_f32(n: usize) -> Vec { + utils::get_random_array(n, f32::MIN, f32::MAX) + } + + #[test] + fn test_both_versions_return_the_same_results() { + if !is_x86_feature_detected!("avx") { + return; + } + + let data: &[f32] = &get_array_f32(1025); + assert_eq!(data.len() % 8, 1); + + let (argmin_index, argmax_index) = scalar_argminmax(data); + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(data) }; + assert_eq!(argmin_index, argmin_simd_index); + assert_eq!(argmax_index, argmax_simd_index); + } + + #[test] + fn test_first_index_is_returned_when_identical_values_found() { + if !is_x86_feature_detected!("avx") { + return; + } + + let data = [ + 10., + f32::MAX, + 6., + f32::NEG_INFINITY, + f32::NEG_INFINITY, + f32::MAX, + 10_000.0, + ]; + let data: Vec = data.iter().map(|x| *x).collect(); + let data: &[f32] = &data; + + let (argmin_index, argmax_index) = scalar_argminmax(data); + assert_eq!(argmin_index, 3); + assert_eq!(argmax_index, 1); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(data) }; + assert_eq!(argmin_simd_index, 3); + assert_eq!(argmax_simd_index, 1); + } + + #[test] + 
fn test_return_infs() { + if !is_x86_feature_detected!("avx") { + return; + } + + let arr_len: usize = 1027; + let mut data: Vec = get_array_f32(arr_len); + + // Case 1: all elements are +inf + for i in 0..data.len() { + data[i] = f32::INFINITY; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 2: all elements are -inf + for i in 0..data.len() { + data[i] = f32::NEG_INFINITY; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 3: add some +inf and -inf in the middle + let mut data: Vec = get_array_f32(arr_len); + data[100] = f32::INFINITY; + data[200] = f32::NEG_INFINITY; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 200); + assert_eq!(argmax_index, 100); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert_eq!(argmin_simd_index, 200); + assert_eq!(argmax_simd_index, 100); + } + + #[test] + fn test_ignore_nans() { + if !is_x86_feature_detected!("avx") { + return; + } + + let arr_len: usize = 1027; + + // Case 1: NaN is the first element + let mut data: Vec = get_array_f32(arr_len); + data[0] = f32::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert!(argmin_index != 0); + assert!(argmax_index != 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert!(argmin_simd_index != 0); + assert!(argmax_simd_index != 0); + + // Case 2: first 100 elements are NaN + for i in 0..100 { + data[i] = f32::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert!(argmin_index > 99); + assert!(argmax_index > 99); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert!(argmin_simd_index > 99); + assert!(argmax_simd_index > 99); + + // Case 3: NaN is the last element + let mut data: Vec = get_array_f32(arr_len); + data[arr_len - 1] = f32::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert!(argmin_index != 1026); + assert!(argmax_index != 1026); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert!(argmin_simd_index != 1026); + assert!(argmax_simd_index != 1026); + + // Case 4: last 100 elements are NaN + for i in 0..100 { + data[arr_len - 1 - i] = f32::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert!(argmin_index < arr_len - 100); + assert!(argmax_index < arr_len - 100); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert!(argmin_simd_index < arr_len - 100); + assert!(argmax_simd_index < arr_len - 100); + + // Case 5: NaN is somewhere in the middle element + let mut data: Vec = get_array_f32(arr_len); + data[123] = f32::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert!(argmin_index != 123); + assert!(argmax_index != 123); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert!(argmin_simd_index != 123); + assert!(argmax_simd_index != 123); + + // Case 6: all elements are NaN + for i in 0..data.len() { + data[i] 
= f32::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + } + + #[test] + fn test_no_overflow() { + if !is_x86_feature_detected!("avx") { + return; + } + + let n: usize = 1 << 25; + let data: &[f32] = &get_array_f32(n); + + let (argmin_index, argmax_index) = scalar_argminmax(data); + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(data) }; + assert_eq!(argmin_index, argmin_simd_index); + assert_eq!(argmax_index, argmax_simd_index); + } + + #[test] + fn test_many_random_runs() { + if !is_x86_feature_detected!("avx") { + return; + } + + for _ in 0..10_000 { + let data: &[f32] = &get_array_f32(32 * 8 + 1); + let (argmin_index, argmax_index) = scalar_argminmax(data); + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(data) }; + assert_eq!(argmin_index, argmin_simd_index); + assert_eq!(argmax_index, argmax_simd_index); + } + } + } +} + +// ----------------------------------------- SSE ----------------------------------------- + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +mod sse_ignore_nan { + use super::super::config::{SSEIgnoreNaN, SSE}; + use super::*; + + const LANE_SIZE: usize = SSE::LANE_SIZE_32; + + impl SIMDOps for SSEIgnoreNaN { + const INITIAL_INDEX: __m128 = + unsafe { std::mem::transmute([0.0f32, 1.0f32, 2.0f32, 3.0f32]) }; + const INDEX_INCREMENT: __m128 = + unsafe { std::mem::transmute([LANE_SIZE as f32; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; + + #[inline(always)] + unsafe fn _reg_to_arr(reg: __m128) -> [f32; LANE_SIZE] { + std::mem::transmute::<__m128, [f32; LANE_SIZE]>(reg) + } + + #[inline(always)] + unsafe fn _mm_loadu(data: *const f32) -> __m128 { + _mm_loadu_ps(data as *const f32) + } + + #[inline(always)] + unsafe fn _mm_add(a: __m128, b: __m128) -> __m128 { + _mm_add_ps(a, b) + } + + #[inline(always)] + unsafe fn _mm_cmpgt(a: __m128, b: __m128) -> __m128 { + _mm_cmpgt_ps(a, b) + } + + #[inline(always)] + unsafe fn _mm_cmplt(a: __m128, b: __m128) -> __m128 { + _mm_cmplt_ps(a, b) + } + + #[inline(always)] + unsafe fn _mm_blendv(a: __m128, b: __m128, mask: __m128) -> __m128 { + _mm_blendv_ps(a, b, mask) + } + } + + impl SIMDSetOps for SSEIgnoreNaN { + #[inline(always)] + unsafe fn _mm_set1(a: f32) -> __m128 { + _mm_set1_ps(a) + } + } + + impl SIMDArgMinMaxIgnoreNaN for SSEIgnoreNaN { + #[target_feature(enable = "sse4.1")] + unsafe fn argminmax(data: &[f32]) -> (usize, usize) { + Self::_argminmax(data) + } + } + + // ------------------------------------ TESTS -------------------------------------- + + #[cfg(test)] + mod tests { + use super::SIMDArgMinMaxIgnoreNaN; + use super::SSEIgnoreNaN as SSE; + use crate::scalar::generic::scalar_argminmax_ignore_nans as scalar_argminmax; + + extern crate dev_utils; + use dev_utils::utils; + + fn get_array_f32(n: usize) -> Vec { + utils::get_random_array(n, f32::MIN, f32::MAX) + } + + #[test] + fn test_both_versions_return_the_same_results() { + let data: &[f32] = &get_array_f32(1025); + assert_eq!(data.len() % 4, 1); + + let (argmin_index, argmax_index) = scalar_argminmax(data); + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(data) }; + assert_eq!(argmin_index, argmin_simd_index); + assert_eq!(argmax_index, argmax_simd_index); + } + + #[test] + fn test_first_index_is_returned_when_identical_values_found() { 
+ let data = [ + 10., + f32::MAX, + 6., + f32::NEG_INFINITY, + f32::NEG_INFINITY, + f32::MAX, + 10_000.0, + ]; + let data: Vec = data.iter().map(|x| *x).collect(); + let data: &[f32] = &data; + + let (argmin_index, argmax_index) = scalar_argminmax(data); + assert_eq!(argmin_index, 3); + assert_eq!(argmax_index, 1); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(data) }; + assert_eq!(argmin_simd_index, 3); + assert_eq!(argmax_simd_index, 1); + } + + #[test] + fn test_return_infs() { + let arr_len: usize = 1027; + let mut data: Vec = get_array_f32(arr_len); + + // Case 1: all elements are +inf + for i in 0..data.len() { + data[i] = f32::INFINITY; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 2: all elements are -inf + for i in 0..data.len() { + data[i] = f32::NEG_INFINITY; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 3: add some +inf and -inf in the middle + let mut data: Vec = get_array_f32(arr_len); + data[100] = f32::INFINITY; + data[200] = f32::NEG_INFINITY; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 200); + assert_eq!(argmax_index, 100); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert_eq!(argmin_simd_index, 200); + assert_eq!(argmax_simd_index, 100); + } + + #[test] + fn test_ignore_nans() { + let arr_len: usize = 1027; + + // Case 1: NaN is the first element + let mut data: Vec = get_array_f32(arr_len); + data[0] = f32::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert!(argmin_index != 0); + assert!(argmax_index != 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert!(argmin_simd_index != 0); + assert!(argmax_simd_index != 0); + + // Case 2: first 100 elements are NaN + for i in 0..100 { + data[i] = f32::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert!(argmin_index > 99); + assert!(argmax_index > 99); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert!(argmin_simd_index > 99); + assert!(argmax_simd_index > 99); + + // Case 3: NaN is the last element + let mut data: Vec = get_array_f32(arr_len); + data[arr_len - 1] = f32::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert!(argmin_index != 1026); + assert!(argmax_index != 1026); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert!(argmin_simd_index != 1026); + assert!(argmax_simd_index != 1026); + + // Case 4: last 100 elements are NaN + for i in 0..100 { + data[arr_len - 1 - i] = f32::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert!(argmin_index < arr_len - 100); + assert!(argmax_index < arr_len - 100); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert!(argmin_simd_index < arr_len - 100); + assert!(argmax_simd_index < arr_len - 100); + + // Case 5: NaN is somewhere in the middle element + let mut data: Vec = get_array_f32(arr_len); + data[123] 
= f32::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert!(argmin_index != 123); + assert!(argmax_index != 123); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert!(argmin_simd_index != 123); + assert!(argmax_simd_index != 123); + + // Case 6: all elements are NaN + for i in 0..data.len() { + data[i] = f32::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + } + + #[test] + fn test_no_overflow() { + let n: usize = 1 << 25; + let data: &[f32] = &get_array_f32(n); + + let (argmin_index, argmax_index) = scalar_argminmax(data); + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(data) }; + assert_eq!(argmin_index, argmin_simd_index); + assert_eq!(argmax_index, argmax_simd_index); + } + + #[test] + fn test_many_random_runs() { + for _ in 0..10_000 { + let data: &[f32] = &get_array_f32(32 * 4 + 1); + let (argmin_index, argmax_index) = scalar_argminmax(data); + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(data) }; + assert_eq!(argmin_index, argmin_simd_index); + assert_eq!(argmax_index, argmax_simd_index); + } + } + } +} + +// --------------------------------------- AVX512 ---------------------------------------- + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +mod avx512_ignore_nan { + use super::super::config::{AVX512IgnoreNaN, AVX512}; + use super::*; + + const LANE_SIZE: usize = AVX512::LANE_SIZE_32; + + impl SIMDOps for AVX512IgnoreNaN { + const INITIAL_INDEX: __m512 = unsafe { + std::mem::transmute([ + 0.0f32, 1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32, 6.0f32, 7.0f32, 8.0f32, 9.0f32, + 10.0f32, 11.0f32, 12.0f32, 13.0f32, 14.0f32, 15.0f32, + ]) + }; + const INDEX_INCREMENT: __m512 = + unsafe { std::mem::transmute([LANE_SIZE as f32; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; + + #[inline(always)] + unsafe fn _reg_to_arr(reg: __m512) -> [f32; LANE_SIZE] { + std::mem::transmute::<__m512, [f32; LANE_SIZE]>(reg) + } + + #[inline(always)] + unsafe fn _mm_loadu(data: *const f32) -> __m512 { + _mm512_loadu_ps(data as *const f32) + } + + #[inline(always)] + unsafe fn _mm_add(a: __m512, b: __m512) -> __m512 { + _mm512_add_ps(a, b) + } + + #[inline(always)] + unsafe fn _mm_cmpgt(a: __m512, b: __m512) -> u16 { + _mm512_cmp_ps_mask(a, b, _CMP_GT_OQ) + } + + #[inline(always)] + unsafe fn _mm_cmplt(a: __m512, b: __m512) -> u16 { + _mm512_cmp_ps_mask(a, b, _CMP_LT_OQ) + } + + #[inline(always)] + unsafe fn _mm_blendv(a: __m512, b: __m512, mask: u16) -> __m512 { + _mm512_mask_blend_ps(mask, a, b) + } + } + + impl SIMDSetOps for AVX512IgnoreNaN { + #[inline(always)] + unsafe fn _mm_set1(a: f32) -> __m512 { + _mm512_set1_ps(a) + } + } + + impl SIMDArgMinMaxIgnoreNaN for AVX512IgnoreNaN { + #[target_feature(enable = "avx512f")] + unsafe fn argminmax(data: &[f32]) -> (usize, usize) { + Self::_argminmax(data) + } + } + + // ------------------------------------ TESTS -------------------------------------- + + #[cfg(test)] + mod tests { + use super::AVX512IgnoreNaN as AVX512; + use super::SIMDArgMinMaxIgnoreNaN; + use crate::scalar::generic::scalar_argminmax_ignore_nans as scalar_argminmax; + + extern crate dev_utils; + use dev_utils::utils; + + fn get_array_f32(n: usize) -> Vec { + utils::get_random_array(n, f32::MIN, f32::MAX) + } + + #[test] + fn 
test_both_versions_return_the_same_results() { + if !is_x86_feature_detected!("avx512f") { + return; + } + + let data: &[f32] = &get_array_f32(1025); + assert_eq!(data.len() % 16, 1); + + let (argmin_index, argmax_index) = scalar_argminmax(data); + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(data) }; + assert_eq!(argmin_index, argmin_simd_index); + assert_eq!(argmax_index, argmax_simd_index); + } + + #[test] + fn test_first_index_is_returned_when_identical_values_found() { + if !is_x86_feature_detected!("avx512f") { + return; + } + + let data = [ + 10., + f32::MAX, + 6., + f32::NEG_INFINITY, + f32::NEG_INFINITY, + f32::MAX, + 10_000.0, + ]; + let data: Vec = data.iter().map(|x| *x).collect(); + let data: &[f32] = &data; + + let (argmin_index, argmax_index) = scalar_argminmax(data); + assert_eq!(argmin_index, 3); + assert_eq!(argmax_index, 1); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(data) }; + assert_eq!(argmin_simd_index, 3); + assert_eq!(argmax_simd_index, 1); + } + + #[test] + fn test_return_infs() { + if !is_x86_feature_detected!("avx512f") { + return; + } + + let arr_len: usize = 1027; + let mut data: Vec = get_array_f32(arr_len); + + // Case 1: all elements are +inf + for i in 0..data.len() { + data[i] = f32::INFINITY; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 2: all elements are -inf + for i in 0..data.len() { + data[i] = f32::NEG_INFINITY; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 3: add some +inf and -inf in the middle + let mut data: Vec = get_array_f32(arr_len); + data[100] = f32::INFINITY; + data[200] = f32::NEG_INFINITY; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 200); + assert_eq!(argmax_index, 100); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert_eq!(argmin_simd_index, 200); + assert_eq!(argmax_simd_index, 100); + } + + #[test] + fn test_ignore_nans() { + if !is_x86_feature_detected!("avx512f") { + return; + } + + let arr_len: usize = 1027; + + // Case 1: NaN is the first element + let mut data: Vec = get_array_f32(arr_len); + data[0] = f32::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert!(argmin_index != 0); + assert!(argmax_index != 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert!(argmin_simd_index != 0); + assert!(argmax_simd_index != 0); + + // Case 2: first 100 elements are NaN + for i in 0..100 { + data[i] = f32::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert!(argmin_index > 99); + assert!(argmax_index > 99); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert!(argmin_simd_index > 99); + assert!(argmax_simd_index > 99); + + // Case 3: NaN is the last element + let mut data: Vec = get_array_f32(arr_len); + data[arr_len - 1] = f32::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert!(argmin_index != 1026); + assert!(argmax_index 
!= 1026); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert!(argmin_simd_index != 1026); + assert!(argmax_simd_index != 1026); + + // Case 4: last 100 elements are NaN + for i in 0..100 { + data[arr_len - 1 - i] = f32::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert!(argmin_index < arr_len - 100); + assert!(argmax_index < arr_len - 100); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert!(argmin_simd_index < arr_len - 100); + assert!(argmax_simd_index < arr_len - 100); + + // Case 5: NaN is somewhere in the middle element + let mut data: Vec = get_array_f32(arr_len); + data[123] = f32::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert!(argmin_index != 123); + assert!(argmax_index != 123); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert!(argmin_simd_index != 123); + assert!(argmax_simd_index != 123); + + // Case 6: all elements are NaN + for i in 0..data.len() { + data[i] = f32::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + } + + #[test] + fn test_no_overflow() { + if !is_x86_feature_detected!("avx512f") { + return; + } + + let n: usize = 1 << 25; + let data: &[f32] = &get_array_f32(n); + + let (argmin_index, argmax_index) = scalar_argminmax(data); + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(data) }; + assert_eq!(argmin_index, argmin_simd_index); + assert_eq!(argmax_index, argmax_simd_index); + } + + #[test] + fn test_many_random_runs() { + if !is_x86_feature_detected!("avx512f") { + return; + } + + for _ in 0..10_000 { + let data: &[f32] = &get_array_f32(32 * 16 + 1); + let (argmin_index, argmax_index) = scalar_argminmax(data); + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(data) }; + assert_eq!(argmin_index, argmin_simd_index); + assert_eq!(argmax_index, argmax_simd_index); + } + } + } +} + +// ---------------------------------------- NEON ----------------------------------------- + +#[cfg(any(target_arch = "arm", target_arch = "aarch64"))] +mod neon_ignore_nan { + use super::super::config::{NEONIgnoreNaN, NEON}; + use super::*; + + const LANE_SIZE: usize = NEON::LANE_SIZE_32; + + impl SIMDOps for NEONIgnoreNaN { + const INITIAL_INDEX: float32x4_t = + unsafe { std::mem::transmute([0.0f32, 1.0f32, 2.0f32, 3.0f32]) }; + const INDEX_INCREMENT: float32x4_t = + unsafe { std::mem::transmute([LANE_SIZE as f32; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; + + #[inline(always)] + unsafe fn _reg_to_arr(reg: float32x4_t) -> [f32; LANE_SIZE] { + std::mem::transmute::(reg) + } + + #[inline(always)] + unsafe fn _mm_loadu(data: *const f32) -> float32x4_t { + vld1q_f32(data as *const f32) + } + + #[inline(always)] + unsafe fn _mm_add(a: float32x4_t, b: float32x4_t) -> float32x4_t { + vaddq_f32(a, b) + } + + #[inline(always)] + unsafe fn _mm_cmpgt(a: float32x4_t, b: float32x4_t) -> uint32x4_t { + vcgtq_f32(a, b) + } + + #[inline(always)] + unsafe fn _mm_cmplt(a: float32x4_t, b: float32x4_t) -> uint32x4_t { + vcltq_f32(a, b) + } + + #[inline(always)] + unsafe fn _mm_blendv(a: float32x4_t, b: float32x4_t, mask: uint32x4_t) -> float32x4_t { + vbslq_f32(mask, b, a) + } + } + + impl SIMDSetOps for 
NEONIgnoreNaN { + #[inline(always)] + unsafe fn _mm_set1(a: f32) -> float32x4_t { + vdupq_n_f32(a) + } + } + + impl SIMDArgMinMaxIgnoreNaN for NEONIgnoreNaN { + #[target_feature(enable = "neon")] + unsafe fn argminmax(data: &[f32]) -> (usize, usize) { + Self::_argminmax(data) + } + } + + // ------------------------------------ TESTS -------------------------------------- + + #[cfg(test)] + mod tests { + use super::NEONIgnoreNaN as NEON; + use super::SIMDArgMinMaxIgnoreNaN; + use crate::scalar::generic::scalar_argminmax_ignore_nans as scalar_argminmax; + + extern crate dev_utils; + use dev_utils::utils; + + fn get_array_f32(n: usize) -> Vec { + utils::get_random_array(n, f32::MIN, f32::MAX) + } + + #[test] + fn test_both_versions_return_the_same_results() { + let data: &[f32] = &get_array_f32(1025); + assert_eq!(data.len() % 4, 1); + + let (argmin_index, argmax_index) = scalar_argminmax(data); + let (argmin_simd_index, argmax_simd_index) = unsafe { NEON::argminmax(data) }; + assert_eq!(argmin_index, argmin_simd_index); + assert_eq!(argmax_index, argmax_simd_index); + } + + #[test] + fn test_first_index_is_returned_when_identical_values_found() { + let data = [ + 10., + f32::MAX, + 6., + f32::NEG_INFINITY, + f32::NEG_INFINITY, + f32::MAX, + 10_000.0, + ]; + let data: Vec = data.iter().map(|x| *x).collect(); + let data: &[f32] = &data; + + let (argmin_index, argmax_index) = scalar_argminmax(data); + assert_eq!(argmin_index, 3); + assert_eq!(argmax_index, 1); + + let (argmin_simd_index, argmax_simd_index) = unsafe { NEON::argminmax(data) }; + assert_eq!(argmin_simd_index, 3); + assert_eq!(argmax_simd_index, 1); + } + + #[test] + fn test_return_infs() { + let arr_len: usize = 1027; + let mut data: Vec = get_array_f32(arr_len); + + // Case 1: all elements are +inf + for i in 0..data.len() { + data[i] = f32::INFINITY; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { NEON::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 2: all elements are -inf + for i in 0..data.len() { + data[i] = f32::NEG_INFINITY; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { NEON::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 3: add some +inf and -inf in the middle + let mut data: Vec = get_array_f32(arr_len); + data[100] = f32::INFINITY; + data[200] = f32::NEG_INFINITY; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 200); + assert_eq!(argmax_index, 100); + + let (argmin_simd_index, argmax_simd_index) = unsafe { NEON::argminmax(&data) }; + assert_eq!(argmin_simd_index, 200); + assert_eq!(argmax_simd_index, 100); + } + + #[test] + fn test_ignore_nans() { + let arr_len: usize = 1027; + + // Case 1: NaN is the first element + let mut data: Vec = get_array_f32(arr_len); + data[0] = f32::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert!(argmin_index != 0); + assert!(argmax_index != 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { NEON::argminmax(&data) }; + assert!(argmin_simd_index != 0); + assert!(argmax_simd_index != 0); + + // Case 2: first 100 elements are NaN + for i in 0..100 { + data[i] = f32::NAN; + } + + let (argmin_index, argmax_index) = 
scalar_argminmax(&data); + assert!(argmin_index > 99); + assert!(argmax_index > 99); + + let (argmin_simd_index, argmax_simd_index) = unsafe { NEON::argminmax(&data) }; + assert!(argmin_simd_index > 99); + assert!(argmax_simd_index > 99); + + // Case 3: NaN is the last element + let mut data: Vec = get_array_f32(arr_len); + data[arr_len - 1] = f32::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert!(argmin_index != 1026); + assert!(argmax_index != 1026); + + let (argmin_simd_index, argmax_simd_index) = unsafe { NEON::argminmax(&data) }; + assert!(argmin_simd_index != 1026); + assert!(argmax_simd_index != 1026); + + // Case 4: last 100 elements are NaN + for i in 0..100 { + data[arr_len - 1 - i] = f32::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert!(argmin_index < arr_len - 100); + assert!(argmax_index < arr_len - 100); + + let (argmin_simd_index, argmax_simd_index) = unsafe { NEON::argminmax(&data) }; + assert!(argmin_simd_index < arr_len - 100); + assert!(argmax_simd_index < arr_len - 100); + + // Case 5: NaN is somewhere in the middle element + let mut data: Vec = get_array_f32(arr_len); + data[123] = f32::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert!(argmin_index != 123); + assert!(argmax_index != 123); + + let (argmin_simd_index, argmax_simd_index) = unsafe { NEON::argminmax(&data) }; + assert!(argmin_simd_index != 123); + assert!(argmax_simd_index != 123); + + // Case 6: all elements are NaN + for i in 0..data.len() { + data[i] = f32::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { NEON::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + } + + #[test] + fn test_no_overflow() { + let n: usize = 1 << 25; + let data: &[f32] = &get_array_f32(n); + + let (argmin_index, argmax_index) = scalar_argminmax(data); + let (argmin_simd_index, argmax_simd_index) = unsafe { NEON::argminmax(data) }; + assert_eq!(argmin_index, argmin_simd_index); + assert_eq!(argmax_index, argmax_simd_index); + } + + #[test] + fn test_many_random_runs() { + for _ in 0..10_000 { + let data: &[f32] = &get_array_f32(32 * 4 + 1); + let (argmin_index, argmax_index) = scalar_argminmax(data); + let (argmin_simd_index, argmax_simd_index) = unsafe { NEON::argminmax(data) }; + assert_eq!(argmin_index, argmin_simd_index); + assert_eq!(argmax_index, argmax_simd_index); + } + } + } +} diff --git a/src/simd/simd_f32_return_nan.rs b/src/simd/simd_f32_return_nan.rs new file mode 100644 index 0000000..c83b318 --- /dev/null +++ b/src/simd/simd_f32_return_nan.rs @@ -0,0 +1,1339 @@ +/// Implementation of the argminmax operations for f32 where NaN values take precedence. +/// This implementation returns the index of the first* NaN value if any are present, +/// otherwise it returns the index of the minimum and maximum values. +/// +/// To serve this functionality we transform the f32 values to ordinal i32 values: +/// ord_i32 = ((v >> 31) & 0x7FFFFFFF) ^ v +/// +/// This transformation is a bijection, i.e. it is reversible: +/// v = ((ord_i32 >> 31) & 0x7FFFFFFF) ^ ord_i32 +/// +/// Through this transformation we can perform the argminmax operations on the ordinal +/// integer values and then transform the result back to the original f32 values. +/// This transformation is necessary because comparisons with NaN values are always false. 
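(Editorial aside, not part of the diff: a scalar round trip of the ordinal transformation described above, using hypothetical helper names. It shows that the mapping preserves the ordering of finite floats, maps the default positive NaN above +infinity, and is reversible.)

fn f32_to_i32ord(v: f32) -> i32 {
    let v = v.to_bits() as i32;
    ((v >> 31) & 0x7FFFFFFF) ^ v
}

fn i32ord_to_f32(ord: i32) -> f32 {
    f32::from_bits((((ord >> 31) & 0x7FFFFFFF) ^ ord) as u32)
}

fn main() {
    // Ordering of the original floats is preserved on the i32 ordinals.
    assert!(f32_to_i32ord(-1.5) < f32_to_i32ord(0.0));
    assert!(f32_to_i32ord(0.0) < f32_to_i32ord(2.5));
    // The default (positive) NaN lands above +infinity, so it wins every
    // `gt` comparison on the ordinal values.
    assert!(f32_to_i32ord(f32::NAN) > f32_to_i32ord(f32::INFINITY));
    // The transformation is reversible.
    assert_eq!(i32ord_to_f32(f32_to_i32ord(-1.5)), -1.5);
}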
+/// So unless we perform ! <= as gt and ! >= as lt the argminmax operations will not +/// add NaN values to the accumulating SIMD register. And as le and ge are significantly +/// more expensive than lt and gt we use this efficient bitwise transformation. +/// +/// Also comparing integers is faster than comparing floats: +/// - https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_cmp_ps&ig_expand=902 +/// - https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_cmpgt_epi32&ig_expand=1084 +/// +/// +/// --- +/// +/// *Note: the first NaN value is only returned iff all NaN values have the same bit +/// representation. When NaN values have different bit representations then the index of +/// the highest / lowest ord_i32 is returned for the +/// SIMDOps::_get_overflow_lane_size_limit() chunk of the data - which is not +/// necessarily the index of the first NaN value. +/// +use super::config::SIMDInstructionSet; +use super::generic::{SIMDArgMinMax, SIMDOps}; +#[cfg(target_arch = "aarch64")] +use std::arch::aarch64::*; +#[cfg(target_arch = "arm")] +use std::arch::arm::*; +#[cfg(target_arch = "x86")] +use std::arch::x86::*; +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +use super::task::{max_index_value, min_index_value}; + +const BIT_SHIFT: i32 = 31; +const MASK_VALUE: i32 = 0x7FFFFFFF; // i32::MAX - masks everything but the sign bit + +#[inline(always)] +fn _i32ord_to_f32(ord_i32: i32) -> f32 { + let v = ((ord_i32 >> BIT_SHIFT) & MASK_VALUE) ^ ord_i32; + f32::from_bits(v as u32) +} + +const MAX_INDEX: usize = i32::MAX as usize; + +// ------------------------------------------ AVX2 ------------------------------------------ + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +mod avx2_return_nan { + use super::super::config::AVX2; + use super::*; + + const LANE_SIZE: usize = AVX2::LANE_SIZE_32; + const LOWER_31_MASK: __m256i = unsafe { std::mem::transmute([MASK_VALUE; LANE_SIZE]) }; + + #[inline(always)] + unsafe fn _f32_as_m256i_to_i32ord(f32_as_m256i: __m256i) -> __m256i { + // on a scalar: ((v >> 31) & 0x7FFFFFFF) ^ v + let sign_bit_shifted = _mm256_srai_epi32(f32_as_m256i, BIT_SHIFT); + let sign_bit_masked = _mm256_and_si256(sign_bit_shifted, LOWER_31_MASK); + _mm256_xor_si256(sign_bit_masked, f32_as_m256i) + } + + #[inline(always)] + unsafe fn _reg_to_i32_arr(reg: __m256i) -> [i32; LANE_SIZE] { + std::mem::transmute::<__m256i, [i32; LANE_SIZE]>(reg) + } + + impl SIMDOps for AVX2 { + const INITIAL_INDEX: __m256i = + unsafe { std::mem::transmute([0i32, 1i32, 2i32, 3i32, 4i32, 5i32, 6i32, 7i32]) }; + const INDEX_INCREMENT: __m256i = + unsafe { std::mem::transmute([LANE_SIZE as i32; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; + + #[inline(always)] + unsafe fn _reg_to_arr(_: __m256i) -> [f32; LANE_SIZE] { + // Not implemented because we will perform the horizontal operations on the + // signed integer values instead of trying to retransform **only** the values + // (and thus not the indices) to floats. 
+ unimplemented!() + } + + #[inline(always)] + unsafe fn _mm_loadu(data: *const f32) -> __m256i { + _f32_as_m256i_to_i32ord(_mm256_loadu_si256(data as *const __m256i)) + } + + #[inline(always)] + unsafe fn _mm_add(a: __m256i, b: __m256i) -> __m256i { + _mm256_add_epi32(a, b) + } + + #[inline(always)] + unsafe fn _mm_cmpgt(a: __m256i, b: __m256i) -> __m256i { + _mm256_cmpgt_epi32(a, b) + } + + #[inline(always)] + unsafe fn _mm_cmplt(a: __m256i, b: __m256i) -> __m256i { + _mm256_cmpgt_epi32(b, a) + } + + #[inline(always)] + unsafe fn _mm_blendv(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { + _mm256_blendv_epi8(a, b, mask) + } + + #[inline(always)] + unsafe fn _horiz_min(index: __m256i, value: __m256i) -> (usize, f32) { + let index_arr: [i32; LANE_SIZE] = _reg_to_i32_arr(index); + let value_arr: [i32; LANE_SIZE] = _reg_to_i32_arr(value); + let (min_index, min_value) = min_index_value(&index_arr, &value_arr); + (min_index as usize, _i32ord_to_f32(min_value)) + } + + #[inline(always)] + unsafe fn _horiz_max(index: __m256i, value: __m256i) -> (usize, f32) { + let index_arr: [i32; LANE_SIZE] = _reg_to_i32_arr(index); + let value_arr: [i32; LANE_SIZE] = _reg_to_i32_arr(value); + let (max_index, max_value) = max_index_value(&index_arr, &value_arr); + (max_index as usize, _i32ord_to_f32(max_value)) + } + } + + impl SIMDArgMinMax for AVX2 { + #[target_feature(enable = "avx2")] + unsafe fn argminmax(data: &[f32]) -> (usize, usize) { + Self::_argminmax(data) + } + } + + // ------------------------------------ TESTS -------------------------------------- + + #[cfg(test)] + mod tests { + use super::{SIMDArgMinMax, AVX2}; + use crate::scalar::generic::scalar_argminmax; + + extern crate dev_utils; + use dev_utils::utils; + + fn get_array_f32(n: usize) -> Vec { + utils::get_random_array(n, f32::MIN, f32::MAX) + } + + #[test] + fn test_both_versions_return_the_same_results() { + if !is_x86_feature_detected!("avx2") { + return; + } + + let data: &[f32] = &get_array_f32(1025); + assert_eq!(data.len() % 8, 1); + + let (argmin_index, argmax_index) = scalar_argminmax(data); + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(data) }; + assert_eq!(argmin_index, argmin_simd_index); + assert_eq!(argmax_index, argmax_simd_index); + } + + #[test] + fn test_first_index_is_returned_when_identical_values_found() { + if !is_x86_feature_detected!("avx2") { + return; + } + + let data = [ + 10., + f32::MAX, + 6., + f32::NEG_INFINITY, + f32::NEG_INFINITY, + f32::MAX, + 10_000.0, + ]; + let data: Vec = data.iter().map(|x| *x).collect(); + let data: &[f32] = &data; + + let (argmin_index, argmax_index) = scalar_argminmax(data); + assert_eq!(argmin_index, 3); + assert_eq!(argmax_index, 1); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(data) }; + assert_eq!(argmin_simd_index, 3); + assert_eq!(argmax_simd_index, 1); + } + + #[test] + fn test_return_infs() { + if !is_x86_feature_detected!("avx2") { + return; + } + + let arr_len: usize = 1027; + let mut data: Vec = get_array_f32(arr_len); + + // Case 1: all elements are +inf + for i in 0..data.len() { + data[i] = f32::INFINITY; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 2: all elements are -inf + for i in 0..data.len() { + data[i] = f32::NEG_INFINITY; + } + + let (argmin_index, 
argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 3: add some +inf and -inf in the middle + let mut data: Vec = get_array_f32(arr_len); + data[100] = f32::INFINITY; + data[200] = f32::NEG_INFINITY; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 200); + assert_eq!(argmax_index, 100); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert_eq!(argmin_simd_index, 200); + assert_eq!(argmax_simd_index, 100); + } + + #[test] + fn test_return_nans() { + if !is_x86_feature_detected!("avx2") { + return; + } + + let arr_len: usize = 1027; + + // Case 1: NaN is the first element + let mut data: Vec = get_array_f32(arr_len); + data[0] = f32::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 2: first 100 elements are NaN + for i in 0..100 { + data[i] = f32::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 3: NaN is the last element + let mut data: Vec = get_array_f32(arr_len); + data[arr_len - 1] = f32::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 1026); + assert_eq!(argmax_index, 1026); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert_eq!(argmin_simd_index, 1026); + assert_eq!(argmax_simd_index, 1026); + + // Case 4: last 100 elements are NaN + for i in 0..100 { + data[arr_len - 1 - i] = f32::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, arr_len - 100); + assert_eq!(argmax_index, arr_len - 100); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert_eq!(argmin_simd_index, arr_len - 100); + assert_eq!(argmax_simd_index, arr_len - 100); + + // Case 5: NaN is somewhere in the middle element + let mut data: Vec = get_array_f32(arr_len); + data[123] = f32::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 123); + assert_eq!(argmax_index, 123); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert_eq!(argmin_simd_index, 123); + assert_eq!(argmax_simd_index, 123); + + // Case 6: NaN in the middle of the array and last 100 elements are NaN + for i in 0..100 { + data[arr_len - 1 - i] = f32::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 123); + assert_eq!(argmax_index, 123); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert_eq!(argmin_simd_index, 123); + assert_eq!(argmax_simd_index, 123); + + // Case 7: all elements are NaN + for i in 0..data.len() { + data[i] = f32::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { 
AVX2::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 8: array exact multiple of LANE_SIZE and only 1 element is NaN + let mut data: Vec = get_array_f32(128); + data[17] = f32::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 17); + assert_eq!(argmax_index, 17); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert_eq!(argmin_simd_index, 17); + assert_eq!(argmax_simd_index, 17); + } + + #[test] + fn test_no_overflow() { + if !is_x86_feature_detected!("avx2") { + return; + } + + let n: usize = 1 << 25; + let data: &[f32] = &get_array_f32(n); + + let (argmin_index, argmax_index) = scalar_argminmax(data); + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(data) }; + assert_eq!(argmin_index, argmin_simd_index); + assert_eq!(argmax_index, argmax_simd_index); + } + + #[test] + fn test_many_random_runs() { + if !is_x86_feature_detected!("avx2") { + return; + } + + for _ in 0..10_000 { + let data: &[f32] = &get_array_f32(32 * 8 + 1); + let (argmin_index, argmax_index) = scalar_argminmax(data); + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(data) }; + assert_eq!(argmin_index, argmin_simd_index); + assert_eq!(argmax_index, argmax_simd_index); + } + } + } +} + +// ----------------------------------------- SSE ----------------------------------------- + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +mod sse { + use super::super::config::SSE; + use super::*; + + const LANE_SIZE: usize = SSE::LANE_SIZE_32; + const LOWER_31_MASK: __m128i = unsafe { std::mem::transmute([MASK_VALUE; LANE_SIZE]) }; + + #[inline(always)] + unsafe fn _f32_as_m128i_to_i32ord(f32_as_m128i: __m128i) -> __m128i { + // on a scalar: ((v >> 31) & 0x7FFFFFFF) ^ v + let sign_bit_shifted = _mm_srai_epi32(f32_as_m128i, BIT_SHIFT); + let sign_bit_masked = _mm_and_si128(sign_bit_shifted, LOWER_31_MASK); + _mm_xor_si128(sign_bit_masked, f32_as_m128i) + } + + #[inline(always)] + unsafe fn _reg_to_i32_arr(reg: __m128i) -> [i32; LANE_SIZE] { + std::mem::transmute::<__m128i, [i32; LANE_SIZE]>(reg) + } + + impl SIMDOps for SSE { + const INITIAL_INDEX: __m128i = unsafe { std::mem::transmute([0i32, 1i32, 2i32, 3i32]) }; + const INDEX_INCREMENT: __m128i = + unsafe { std::mem::transmute([LANE_SIZE as i32; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; + + #[inline(always)] + unsafe fn _reg_to_arr(_: __m128i) -> [f32; LANE_SIZE] { + // Not implemented because we will perform the horizontal operations on the + // signed integer values instead of trying to retransform **only** the values + // (and thus not the indices) to floats. 
+ unimplemented!() + } + + #[inline(always)] + unsafe fn _mm_loadu(data: *const f32) -> __m128i { + _f32_as_m128i_to_i32ord(_mm_loadu_si128(data as *const __m128i)) + } + + #[inline(always)] + unsafe fn _mm_add(a: __m128i, b: __m128i) -> __m128i { + _mm_add_epi32(a, b) + } + + #[inline(always)] + unsafe fn _mm_cmpgt(a: __m128i, b: __m128i) -> __m128i { + _mm_cmpgt_epi32(a, b) + } + + #[inline(always)] + unsafe fn _mm_cmplt(a: __m128i, b: __m128i) -> __m128i { + _mm_cmplt_epi32(a, b) + } + + #[inline(always)] + unsafe fn _mm_blendv(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { + _mm_blendv_epi8(a, b, mask) + } + + #[inline(always)] + unsafe fn _horiz_min(index: __m128i, value: __m128i) -> (usize, f32) { + let index_arr: [i32; LANE_SIZE] = _reg_to_i32_arr(index); + let value_arr: [i32; LANE_SIZE] = _reg_to_i32_arr(value); + let (min_index, min_value) = min_index_value(&index_arr, &value_arr); + (min_index as usize, _i32ord_to_f32(min_value)) + } + + #[inline(always)] + unsafe fn _horiz_max(index: __m128i, value: __m128i) -> (usize, f32) { + let index_arr: [i32; LANE_SIZE] = _reg_to_i32_arr(index); + let value_arr: [i32; LANE_SIZE] = _reg_to_i32_arr(value); + let (max_index, max_value) = max_index_value(&index_arr, &value_arr); + (max_index as usize, _i32ord_to_f32(max_value)) + } + } + + impl SIMDArgMinMax for SSE { + #[target_feature(enable = "sse4.1")] + unsafe fn argminmax(data: &[f32]) -> (usize, usize) { + Self::_argminmax(data) + } + } + + // ------------------------------------ TESTS -------------------------------------- + + #[cfg(test)] + mod tests { + use super::{SIMDArgMinMax, SSE}; + use crate::scalar::generic::scalar_argminmax; + + extern crate dev_utils; + use dev_utils::utils; + + fn get_array_f32(n: usize) -> Vec { + utils::get_random_array(n, f32::MIN, f32::MAX) + } + + #[test] + fn test_both_versions_return_the_same_results() { + let data: &[f32] = &get_array_f32(1025); + assert_eq!(data.len() % 4, 1); + + let (argmin_index, argmax_index) = scalar_argminmax(data); + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(data) }; + assert_eq!(argmin_index, argmin_simd_index); + assert_eq!(argmax_index, argmax_simd_index); + } + + #[test] + fn test_first_index_is_returned_when_identical_values_found() { + let data = [ + 10., + f32::MAX, + 6., + f32::NEG_INFINITY, + f32::NEG_INFINITY, + f32::MAX, + 10_000.0, + ]; + let data: Vec = data.iter().map(|x| *x).collect(); + let data: &[f32] = &data; + + let (argmin_index, argmax_index) = scalar_argminmax(data); + assert_eq!(argmin_index, 3); + assert_eq!(argmax_index, 1); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(data) }; + assert_eq!(argmin_simd_index, 3); + assert_eq!(argmax_simd_index, 1); + } + + #[test] + fn test_return_infs() { + let arr_len: usize = 1027; + let mut data: Vec = get_array_f32(arr_len); + + // Case 1: all elements are +inf + for i in 0..data.len() { + data[i] = f32::INFINITY; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 2: all elements are -inf + for i in 0..data.len() { + data[i] = f32::NEG_INFINITY; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + 
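// [Editorial note -- illustrative sketch, not part of this patch.]
// The NaN-related tests in this module pin down the "return NaN" contract: as soon as
// the data contains a NaN, the index of the first NaN is reported for both argmin and
// argmax (subject to the bit-representation caveat noted at the top of this file).
// A scalar mental model of that contract (hypothetical helper, not the crate's
// scalar_argminmax):
fn _argminmax_return_nan_model(data: &[f32]) -> (usize, usize) {
    let (mut min_i, mut max_i) = (0usize, 0usize);
    for (i, &x) in data.iter().enumerate() {
        if x.is_nan() {
            return (i, i); // the first NaN wins both results
        }
        // strict comparisons keep the first occurrence on ties
        if x < data[min_i] { min_i = i; }
        if x > data[max_i] { max_i = i; }
    }
    (min_i, max_i)
}
// [End of editorial note.]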
assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 3: add some +inf and -inf in the middle + let mut data: Vec = get_array_f32(arr_len); + data[100] = f32::INFINITY; + data[200] = f32::NEG_INFINITY; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 200); + assert_eq!(argmax_index, 100); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert_eq!(argmin_simd_index, 200); + assert_eq!(argmax_simd_index, 100); + } + + #[test] + fn test_return_nans() { + let arr_len: usize = 1027; + + // Case 1: NaN is the first element + let mut data: Vec = get_array_f32(arr_len); + data[0] = f32::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 2: first 100 elements are NaN + for i in 0..100 { + data[i] = f32::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 3: NaN is the last element + let mut data: Vec = get_array_f32(arr_len); + data[arr_len - 1] = f32::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 1026); + assert_eq!(argmax_index, 1026); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert_eq!(argmin_simd_index, 1026); + assert_eq!(argmax_simd_index, 1026); + + // Case 4: last 100 elements are NaN + for i in 0..100 { + data[arr_len - 1 - i] = f32::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, arr_len - 100); + assert_eq!(argmax_index, arr_len - 100); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert_eq!(argmin_simd_index, arr_len - 100); + assert_eq!(argmax_simd_index, arr_len - 100); + + // Case 5: NaN is somewhere in the middle element + let mut data: Vec = get_array_f32(arr_len); + data[123] = f32::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 123); + assert_eq!(argmax_index, 123); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert_eq!(argmin_simd_index, 123); + assert_eq!(argmax_simd_index, 123); + + // Case 6: NaN in the middle of the array and last 100 elements are NaN + for i in 0..100 { + data[arr_len - 1 - i] = f32::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 123); + assert_eq!(argmax_index, 123); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert_eq!(argmin_simd_index, 123); + assert_eq!(argmax_simd_index, 123); + + // Case 7: all elements are NaN + for i in 0..data.len() { + data[i] = f32::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 8: array exact multiple of LANE_SIZE and only 1 element is NaN + let mut data: Vec = get_array_f32(128); + data[17] = f32::NAN; + + let 
(argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 17); + assert_eq!(argmax_index, 17); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert_eq!(argmin_simd_index, 17); + assert_eq!(argmax_simd_index, 17); + } + + #[test] + fn test_no_overflow() { + let n: usize = 1 << 25; + let data: &[f32] = &get_array_f32(n); + + let (argmin_index, argmax_index) = scalar_argminmax(data); + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(data) }; + assert_eq!(argmin_index, argmin_simd_index); + assert_eq!(argmax_index, argmax_simd_index); + } + + #[test] + fn test_many_random_runs() { + for _ in 0..10_000 { + let data: &[f32] = &get_array_f32(32 * 4 + 1); + let (argmin_index, argmax_index) = scalar_argminmax(data); + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(data) }; + assert_eq!(argmin_index, argmin_simd_index); + assert_eq!(argmax_index, argmax_simd_index); + } + } + } +} + +// --------------------------------------- AVX512 ---------------------------------------- + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +mod avx512 { + use super::super::config::AVX512; + use super::*; + + const LANE_SIZE: usize = AVX512::LANE_SIZE_32; + const LOWER_31_MASK: __m512i = unsafe { std::mem::transmute([MASK_VALUE; LANE_SIZE]) }; + + #[inline(always)] + unsafe fn _f32_as_m512i_to_i32ord(f32_as_m512i: __m512i) -> __m512i { + // on a scalar: ((v >> 31) & 0x7FFFFFFF) ^ v + let sign_bit_shifted = _mm512_srai_epi32(f32_as_m512i, BIT_SHIFT as u32); + let sign_bit_masked = _mm512_and_si512(sign_bit_shifted, LOWER_31_MASK); + _mm512_xor_si512(sign_bit_masked, f32_as_m512i) + } + + #[inline(always)] + unsafe fn _reg_to_i32_arr(reg: __m512i) -> [i32; LANE_SIZE] { + std::mem::transmute::<__m512i, [i32; LANE_SIZE]>(reg) + } + + impl SIMDOps for AVX512 { + const INITIAL_INDEX: __m512i = unsafe { + std::mem::transmute([ + 0i32, 1i32, 2i32, 3i32, 4i32, 5i32, 6i32, 7i32, 8i32, 9i32, 10i32, 11i32, 12i32, + 13i32, 14i32, 15i32, + ]) + }; + const INDEX_INCREMENT: __m512i = + unsafe { std::mem::transmute([LANE_SIZE as i32; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; + + #[inline(always)] + unsafe fn _reg_to_arr(_: __m512i) -> [f32; LANE_SIZE] { + // Not implemented because we will perform the horizontal operations on the + // signed integer values instead of trying to retransform **only** the values + // (and thus not the indices) to floats. 
+ unimplemented!() + } + + #[inline(always)] + unsafe fn _mm_loadu(data: *const f32) -> __m512i { + _f32_as_m512i_to_i32ord(_mm512_loadu_si512(data as *const i32)) + } + + #[inline(always)] + unsafe fn _mm_add(a: __m512i, b: __m512i) -> __m512i { + _mm512_add_epi32(a, b) + } + + #[inline(always)] + unsafe fn _mm_cmpgt(a: __m512i, b: __m512i) -> u16 { + _mm512_cmpgt_epi32_mask(a, b) + } + + #[inline(always)] + unsafe fn _mm_cmplt(a: __m512i, b: __m512i) -> u16 { + _mm512_cmplt_epi32_mask(a, b) + } + + #[inline(always)] + unsafe fn _mm_blendv(a: __m512i, b: __m512i, mask: u16) -> __m512i { + _mm512_mask_blend_epi32(mask, a, b) + } + + #[inline(always)] + unsafe fn _horiz_min(index: __m512i, value: __m512i) -> (usize, f32) { + let index_arr: [i32; LANE_SIZE] = _reg_to_i32_arr(index); + let value_arr: [i32; LANE_SIZE] = _reg_to_i32_arr(value); + let (min_index, min_value) = min_index_value(&index_arr, &value_arr); + (min_index as usize, _i32ord_to_f32(min_value)) + } + + #[inline(always)] + unsafe fn _horiz_max(index: __m512i, value: __m512i) -> (usize, f32) { + let index_arr: [i32; LANE_SIZE] = _reg_to_i32_arr(index); + let value_arr: [i32; LANE_SIZE] = _reg_to_i32_arr(value); + let (max_index, max_value) = max_index_value(&index_arr, &value_arr); + (max_index as usize, _i32ord_to_f32(max_value)) + } + } + + impl SIMDArgMinMax for AVX512 { + #[target_feature(enable = "avx512f")] + unsafe fn argminmax(data: &[f32]) -> (usize, usize) { + Self::_argminmax(data) + } + } + + // ------------------------------------ TESTS -------------------------------------- + + #[cfg(test)] + mod tests { + use super::{SIMDArgMinMax, AVX512}; + use crate::scalar::generic::scalar_argminmax; + + extern crate dev_utils; + use dev_utils::utils; + + fn get_array_f32(n: usize) -> Vec { + utils::get_random_array(n, f32::MIN, f32::MAX) + } + + #[test] + fn test_both_versions_return_the_same_results() { + if !is_x86_feature_detected!("avx512f") { + return; + } + + let data: &[f32] = &get_array_f32(1025); + assert_eq!(data.len() % 16, 1); + + let (argmin_index, argmax_index) = scalar_argminmax(data); + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(data) }; + assert_eq!(argmin_index, argmin_simd_index); + assert_eq!(argmax_index, argmax_simd_index); + } + + #[test] + fn test_first_index_is_returned_when_identical_values_found() { + if !is_x86_feature_detected!("avx512f") { + return; + } + + let data = [ + 10., + f32::MAX, + 6., + f32::NEG_INFINITY, + f32::NEG_INFINITY, + f32::MAX, + 10_000.0, + ]; + let data: Vec = data.iter().map(|x| *x).collect(); + let data: &[f32] = &data; + + let (argmin_index, argmax_index) = scalar_argminmax(data); + assert_eq!(argmin_index, 3); + assert_eq!(argmax_index, 1); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(data) }; + assert_eq!(argmin_simd_index, 3); + assert_eq!(argmax_simd_index, 1); + } + + #[test] + fn test_return_infs() { + if !is_x86_feature_detected!("avx512f") { + return; + } + + let arr_len: usize = 1027; + let mut data: Vec = get_array_f32(arr_len); + + // Case 1: all elements are +inf + for i in 0..data.len() { + data[i] = f32::INFINITY; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 2: all elements are -inf + for i in 0..data.len() { + data[i] = f32::NEG_INFINITY; + } + + 
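// [Editorial note -- illustrative sketch, not part of this patch.]
// The all-+inf / all--inf cases below expect (0, 0) because every element ties and a
// strict comparison never replaces an earlier winner; in scalar form:
fn _first_index_on_ties(data: &[f32]) -> usize {
    let mut best = 0;
    for i in 1..data.len() {
        if data[i] < data[best] {
            best = i; // never taken when all values are equal
        }
    }
    best
}
// [End of editorial note.]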
let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 3: add some +inf and -inf in the middle + let mut data: Vec = get_array_f32(arr_len); + data[100] = f32::INFINITY; + data[200] = f32::NEG_INFINITY; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 200); + assert_eq!(argmax_index, 100); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert_eq!(argmin_simd_index, 200); + assert_eq!(argmax_simd_index, 100); + } + + #[test] + fn test_return_nans() { + if !is_x86_feature_detected!("avx512f") { + return; + } + + let arr_len: usize = 1027; + + // Case 1: NaN is the first element + let mut data: Vec = get_array_f32(arr_len); + data[0] = f32::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 2: first 100 elements are NaN + for i in 0..100 { + data[i] = f32::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 3: NaN is the last element + let mut data: Vec = get_array_f32(arr_len); + data[arr_len - 1] = f32::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 1026); + assert_eq!(argmax_index, 1026); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert_eq!(argmin_simd_index, 1026); + assert_eq!(argmax_simd_index, 1026); + + // Case 4: last 100 elements are NaN + for i in 0..100 { + data[arr_len - 1 - i] = f32::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, arr_len - 100); + assert_eq!(argmax_index, arr_len - 100); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert_eq!(argmin_simd_index, arr_len - 100); + assert_eq!(argmax_simd_index, arr_len - 100); + + // Case 5: NaN is somewhere in the middle element + let mut data: Vec = get_array_f32(arr_len); + data[123] = f32::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 123); + assert_eq!(argmax_index, 123); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert_eq!(argmin_simd_index, 123); + assert_eq!(argmax_simd_index, 123); + + // Case 6: NaN in the middle of the array and last 100 elements are NaN + for i in 0..100 { + data[arr_len - 1 - i] = f32::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 123); + assert_eq!(argmax_index, 123); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert_eq!(argmin_simd_index, 123); + assert_eq!(argmax_simd_index, 123); + + // Case 7: all elements are NaN + for i in 0..data.len() { + data[i] = f32::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let 
(argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 8: array exact multiple of LANE_SIZE and only 1 element is NaN + let mut data: Vec = get_array_f32(128); + data[17] = f32::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 17); + assert_eq!(argmax_index, 17); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert_eq!(argmin_simd_index, 17); + assert_eq!(argmax_simd_index, 17); + } + + #[test] + fn test_no_overflow() { + if !is_x86_feature_detected!("avx512f") { + return; + } + + let n: usize = 1 << 25; + let data: &[f32] = &get_array_f32(n); + + let (argmin_index, argmax_index) = scalar_argminmax(data); + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(data) }; + assert_eq!(argmin_index, argmin_simd_index); + assert_eq!(argmax_index, argmax_simd_index); + } + + #[test] + fn test_many_random_runs() { + if !is_x86_feature_detected!("avx512f") { + return; + } + + for _ in 0..10_000 { + let data: &[f32] = &get_array_f32(32 * 16 + 1); + let (argmin_index, argmax_index) = scalar_argminmax(data); + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(data) }; + assert_eq!(argmin_index, argmin_simd_index); + assert_eq!(argmax_index, argmax_simd_index); + } + } + } +} + +// ---------------------------------------- NEON ----------------------------------------- + +#[cfg(any(target_arch = "arm", target_arch = "aarch64"))] +mod neon_float_return_nan { + use super::super::config::NEON; + use super::*; + + const LANE_SIZE: usize = NEON::LANE_SIZE_32; + const LOWER_31_MASK: int32x4_t = unsafe { std::mem::transmute([MASK_VALUE; LANE_SIZE]) }; + + #[inline(always)] + unsafe fn _f32_as_int32x4_to_i32ord(f32_as_int32x4: int32x4_t) -> int32x4_t { + // on a scalar: ((v >> 31) & 0x7FFFFFFF) ^ v + let sign_bit_shifted = vshrq_n_s32(f32_as_int32x4, BIT_SHIFT); + let sign_bit_masked = vandq_s32(sign_bit_shifted, LOWER_31_MASK); + veorq_s32(sign_bit_masked, f32_as_int32x4) + } + + #[inline(always)] + unsafe fn _reg_to_i32_arr(reg: int32x4_t) -> [i32; LANE_SIZE] { + std::mem::transmute::(reg) + } + + impl SIMDOps for NEON { + const INITIAL_INDEX: int32x4_t = unsafe { std::mem::transmute([0i32, 1i32, 2i32, 3i32]) }; + const INDEX_INCREMENT: int32x4_t = + unsafe { std::mem::transmute([LANE_SIZE as i32; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; + + #[inline(always)] + unsafe fn _reg_to_arr(_: int32x4_t) -> [f32; LANE_SIZE] { + // Not implemented because we will perform the horizontal operations on the + // signed integer values instead of trying to retransform **only** the values + // (and thus not the indices) to floats. 
+ unimplemented!() + } + + #[inline(always)] + unsafe fn _mm_loadu(data: *const f32) -> int32x4_t { + _f32_as_int32x4_to_i32ord(vld1q_s32(data as *const i32)) + } + + #[inline(always)] + unsafe fn _mm_add(a: int32x4_t, b: int32x4_t) -> int32x4_t { + vaddq_s32(a, b) + } + + #[inline(always)] + unsafe fn _mm_cmpgt(a: int32x4_t, b: int32x4_t) -> uint32x4_t { + vcgtq_s32(a, b) + } + + #[inline(always)] + unsafe fn _mm_cmplt(a: int32x4_t, b: int32x4_t) -> uint32x4_t { + vcltq_s32(a, b) + } + + #[inline(always)] + unsafe fn _mm_blendv(a: int32x4_t, b: int32x4_t, mask: uint32x4_t) -> int32x4_t { + vbslq_s32(mask, b, a) + } + + #[inline(always)] + unsafe fn _horiz_min(index: int32x4_t, value: int32x4_t) -> (usize, f32) { + let index_arr: [i32; LANE_SIZE] = _reg_to_i32_arr(index); + let value_arr: [i32; LANE_SIZE] = _reg_to_i32_arr(value); + let (min_index, min_value) = min_index_value(&index_arr, &value_arr); + (min_index as usize, _i32ord_to_f32(min_value)) + } + + #[inline(always)] + unsafe fn _horiz_max(index: int32x4_t, value: int32x4_t) -> (usize, f32) { + let index_arr: [i32; LANE_SIZE] = _reg_to_i32_arr(index); + let value_arr: [i32; LANE_SIZE] = _reg_to_i32_arr(value); + let (max_index, max_value) = max_index_value(&index_arr, &value_arr); + (max_index as usize, _i32ord_to_f32(max_value)) + } + } + + impl SIMDArgMinMax for NEON { + #[target_feature(enable = "neon")] + unsafe fn argminmax(data: &[f32]) -> (usize, usize) { + Self::_argminmax(data) + } + } + + // ------------------------------------ TESTS -------------------------------------- + + #[cfg(test)] + mod tests { + use super::{SIMDArgMinMax, NEON}; + use crate::scalar::generic::scalar_argminmax; + + extern crate dev_utils; + use dev_utils::utils; + + fn get_array_f32(n: usize) -> Vec { + utils::get_random_array(n, f32::MIN, f32::MAX) + } + + #[test] + fn test_both_versions_return_the_same_results() { + let data: &[f32] = &get_array_f32(1025); + assert_eq!(data.len() % 4, 1); + + let (argmin_index, argmax_index) = scalar_argminmax(data); + let (argmin_simd_index, argmax_simd_index) = unsafe { NEON::argminmax(data) }; + assert_eq!(argmin_index, argmin_simd_index); + assert_eq!(argmax_index, argmax_simd_index); + } + + #[test] + fn test_first_index_is_returned_when_identical_values_found() { + let data = [ + 10., + f32::MAX, + 6., + f32::NEG_INFINITY, + f32::NEG_INFINITY, + f32::MAX, + 10_000.0, + ]; + let data: Vec = data.iter().map(|x| *x).collect(); + let data: &[f32] = &data; + + let (argmin_index, argmax_index) = scalar_argminmax(data); + assert_eq!(argmin_index, 3); + assert_eq!(argmax_index, 1); + + let (argmin_simd_index, argmax_simd_index) = unsafe { NEON::argminmax(data) }; + assert_eq!(argmin_simd_index, 3); + assert_eq!(argmax_simd_index, 1); + } + + #[test] + fn test_return_infs() { + let arr_len: usize = 1027; + let mut data: Vec = get_array_f32(arr_len); + + // Case 1: all elements are +inf + for i in 0..data.len() { + data[i] = f32::INFINITY; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { NEON::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 2: all elements are -inf + for i in 0..data.len() { + data[i] = f32::NEG_INFINITY; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { 
NEON::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 3: add some +inf and -inf in the middle + let mut data: Vec = get_array_f32(arr_len); + data[100] = f32::INFINITY; + data[200] = f32::NEG_INFINITY; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 200); + assert_eq!(argmax_index, 100); + + let (argmin_simd_index, argmax_simd_index) = unsafe { NEON::argminmax(&data) }; + assert_eq!(argmin_simd_index, 200); + assert_eq!(argmax_simd_index, 100); + } + + #[test] + fn test_return_nans() { + let arr_len: usize = 1027; + + // Case 1: NaN is the first element + let mut data: Vec = get_array_f32(arr_len); + data[0] = f32::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { NEON::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 2: first 100 elements are NaN + for i in 0..100 { + data[i] = f32::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { NEON::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 3: NaN is the last element + let mut data: Vec = get_array_f32(arr_len); + data[arr_len - 1] = f32::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 1026); + assert_eq!(argmax_index, 1026); + + let (argmin_simd_index, argmax_simd_index) = unsafe { NEON::argminmax(&data) }; + assert_eq!(argmin_simd_index, 1026); + assert_eq!(argmax_simd_index, 1026); + + // Case 4: last 100 elements are NaN + for i in 0..100 { + data[arr_len - 1 - i] = f32::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, arr_len - 100); + assert_eq!(argmax_index, arr_len - 100); + + let (argmin_simd_index, argmax_simd_index) = unsafe { NEON::argminmax(&data) }; + assert_eq!(argmin_simd_index, arr_len - 100); + assert_eq!(argmax_simd_index, arr_len - 100); + + // Case 5: NaN is somewhere in the middle element + let mut data: Vec = get_array_f32(arr_len); + data[123] = f32::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 123); + assert_eq!(argmax_index, 123); + + let (argmin_simd_index, argmax_simd_index) = unsafe { NEON::argminmax(&data) }; + assert_eq!(argmin_simd_index, 123); + assert_eq!(argmax_simd_index, 123); + + // Case 6: NaN in the middle of the array and last 100 elements are NaN + for i in 0..100 { + data[arr_len - 1 - i] = f32::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 123); + assert_eq!(argmax_index, 123); + + let (argmin_simd_index, argmax_simd_index) = unsafe { NEON::argminmax(&data) }; + assert_eq!(argmin_simd_index, 123); + assert_eq!(argmax_simd_index, 123); + + // Case 7: all elements are NaN + for i in 0..data.len() { + data[i] = f32::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { NEON::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 8: array exact multiple of LANE_SIZE and only 1 element is NaN + let mut data: Vec = get_array_f32(128); + 
data[17] = f32::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 17); + assert_eq!(argmax_index, 17); + + let (argmin_simd_index, argmax_simd_index) = unsafe { NEON::argminmax(&data) }; + assert_eq!(argmin_simd_index, 17); + assert_eq!(argmax_simd_index, 17); + } + + #[test] + fn test_no_overflow() { + let n: usize = 1 << 25; + let data: &[f32] = &get_array_f32(n); + + let (argmin_index, argmax_index) = scalar_argminmax(data); + let (argmin_simd_index, argmax_simd_index) = unsafe { NEON::argminmax(data) }; + assert_eq!(argmin_index, argmin_simd_index); + assert_eq!(argmax_index, argmax_simd_index); + } + + #[test] + fn test_many_random_runs() { + for _ in 0..10_000 { + let data: &[f32] = &get_array_f32(32 * 4 + 1); + let (argmin_index, argmax_index) = scalar_argminmax(data); + let (argmin_simd_index, argmax_simd_index) = unsafe { NEON::argminmax(data) }; + assert_eq!(argmin_index, argmin_simd_index); + assert_eq!(argmax_index, argmax_simd_index); + } + } + } +} diff --git a/src/simd/simd_f64.rs b/src/simd/simd_f64.rs deleted file mode 100644 index 0ea4a13..0000000 --- a/src/simd/simd_f64.rs +++ /dev/null @@ -1,419 +0,0 @@ -#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -use super::config::SIMDInstructionSet; -use super::generic::SIMD; -#[cfg(target_arch = "x86")] -use std::arch::x86::*; -#[cfg(target_arch = "x86_64")] -use std::arch::x86_64::*; - -// ------------------------------------------ AVX2 ------------------------------------------ - -#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -mod avx2 { - use super::super::config::AVX2; - use super::*; - - const LANE_SIZE: usize = AVX2::LANE_SIZE_64; - - impl SIMD for AVX2 { - const INITIAL_INDEX: __m256d = - unsafe { std::mem::transmute([0.0f64, 1.0f64, 2.0f64, 3.0f64]) }; - // https://stackoverflow.com/a/3793950 - #[cfg(target_arch = "x86_64")] - const MAX_INDEX: usize = 1 << f64::MANTISSA_DIGITS; - #[cfg(target_arch = "x86")] // https://stackoverflow.com/a/29592369 - const MAX_INDEX: usize = u32::MAX as usize; - - #[inline(always)] - unsafe fn _reg_to_arr(reg: __m256d) -> [f64; LANE_SIZE] { - std::mem::transmute::<__m256d, [f64; LANE_SIZE]>(reg) - } - - #[inline(always)] - unsafe fn _mm_loadu(data: *const f64) -> __m256d { - _mm256_loadu_pd(data as *const f64) - } - - #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m256d { - _mm256_set1_pd(a as f64) - } - - #[inline(always)] - unsafe fn _mm_add(a: __m256d, b: __m256d) -> __m256d { - _mm256_add_pd(a, b) - } - - #[inline(always)] - unsafe fn _mm_cmpgt(a: __m256d, b: __m256d) -> __m256d { - _mm256_cmp_pd(a, b, _CMP_GT_OQ) - } - - #[inline(always)] - unsafe fn _mm_cmplt(a: __m256d, b: __m256d) -> __m256d { - _mm256_cmp_pd(b, a, _CMP_GT_OQ) - } - - #[inline(always)] - unsafe fn _mm_blendv(a: __m256d, b: __m256d, mask: __m256d) -> __m256d { - _mm256_blendv_pd(a, b, mask) - } - - // ------------------------------------ ARGMINMAX -------------------------------------- - - #[target_feature(enable = "avx")] - unsafe fn argminmax(data: &[f64]) -> (usize, usize) { - Self::_argminmax(data) - } - } - - // ------------------------------------ TESTS -------------------------------------- - - #[cfg(test)] - mod tests { - use super::{AVX2, SIMD}; - use crate::scalar::generic::scalar_argminmax; - - extern crate dev_utils; - use dev_utils::utils; - - fn get_array_f64(n: usize) -> Vec { - utils::get_random_array(n, f64::MIN, f64::MAX) - } - - #[test] - fn test_both_versions_return_the_same_results() { - if !is_x86_feature_detected!("avx") 
{ - return; - } - - let data: &[f64] = &get_array_f64(1025); - assert_eq!(data.len() % 4, 1); - - let (argmin_index, argmax_index) = scalar_argminmax(data); - let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(data) }; - assert_eq!(argmin_index, argmin_simd_index); - assert_eq!(argmax_index, argmax_simd_index); - } - - #[test] - fn test_first_index_is_returned_when_identical_values_found() { - if !is_x86_feature_detected!("avx") { - return; - } - - let data = [ - 10., - std::f64::MAX, - 6., - std::f64::NEG_INFINITY, - std::f64::NEG_INFINITY, - std::f64::MAX, - 10_000.0, - ]; - let data: Vec = data.iter().map(|x| *x).collect(); - let data: &[f64] = &data; - - let (argmin_index, argmax_index) = scalar_argminmax(data); - assert_eq!(argmin_index, 3); - assert_eq!(argmax_index, 1); - - let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(data) }; - assert_eq!(argmin_simd_index, 3); - assert_eq!(argmax_simd_index, 1); - } - - #[test] - fn test_many_random_runs() { - if !is_x86_feature_detected!("avx") { - return; - } - - for _ in 0..10_000 { - let data: &[f64] = &get_array_f64(32 * 8 + 1); - let (argmin_index, argmax_index) = scalar_argminmax(data); - let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(data) }; - assert_eq!(argmin_index, argmin_simd_index); - assert_eq!(argmax_index, argmax_simd_index); - } - } - } -} - -// ----------------------------------------- SSE ----------------------------------------- - -#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -mod sse { - use super::super::config::SSE; - use super::*; - - const LANE_SIZE: usize = SSE::LANE_SIZE_64; - - impl SIMD for SSE { - const INITIAL_INDEX: __m128d = unsafe { std::mem::transmute([0.0f64, 1.0f64]) }; - // https://stackoverflow.com/a/3793950 - #[cfg(target_arch = "x86_64")] - const MAX_INDEX: usize = 1 << f64::MANTISSA_DIGITS; - #[cfg(target_arch = "x86")] // https://stackoverflow.com/a/29592369 - const MAX_INDEX: usize = u32::MAX as usize; - - #[inline(always)] - unsafe fn _reg_to_arr(reg: __m128d) -> [f64; LANE_SIZE] { - std::mem::transmute::<__m128d, [f64; LANE_SIZE]>(reg) - } - - #[inline(always)] - unsafe fn _mm_loadu(data: *const f64) -> __m128d { - _mm_loadu_pd(data as *const f64) - } - - #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m128d { - _mm_set1_pd(a as f64) - } - - #[inline(always)] - unsafe fn _mm_add(a: __m128d, b: __m128d) -> __m128d { - _mm_add_pd(a, b) - } - - #[inline(always)] - unsafe fn _mm_cmpgt(a: __m128d, b: __m128d) -> __m128d { - _mm_cmpgt_pd(a, b) - } - - #[inline(always)] - unsafe fn _mm_cmplt(a: __m128d, b: __m128d) -> __m128d { - _mm_cmplt_pd(a, b) - } - - #[inline(always)] - unsafe fn _mm_blendv(a: __m128d, b: __m128d, mask: __m128d) -> __m128d { - _mm_blendv_pd(a, b, mask) - } - - // ------------------------------------ ARGMINMAX -------------------------------------- - - #[target_feature(enable = "sse4.1")] - unsafe fn argminmax(data: &[f64]) -> (usize, usize) { - Self::_argminmax(data) - } - } - - // ------------------------------------ TESTS -------------------------------------- - - #[cfg(test)] - mod tests { - use super::{SIMD, SSE}; - use crate::scalar::generic::scalar_argminmax; - - extern crate dev_utils; - use dev_utils::utils; - - fn get_array_f64(n: usize) -> Vec { - utils::get_random_array(n, f64::MIN, f64::MAX) - } - - #[test] - fn test_both_versions_return_the_same_results() { - let data: &[f64] = &get_array_f64(1025); - assert_eq!(data.len() % 2, 1); - - let (argmin_index, argmax_index) = 
scalar_argminmax(data); - let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(data) }; - assert_eq!(argmin_index, argmin_simd_index); - assert_eq!(argmax_index, argmax_simd_index); - } - - #[test] - fn test_first_index_is_returned_when_identical_values_found() { - let data = [ - 10., - std::f64::MAX, - 6., - std::f64::NEG_INFINITY, - std::f64::NEG_INFINITY, - std::f64::MAX, - 10_000.0, - ]; - let data: Vec = data.iter().map(|x| *x).collect(); - let data: &[f64] = &data; - - let (argmin_index, argmax_index) = scalar_argminmax(data); - assert_eq!(argmin_index, 3); - assert_eq!(argmax_index, 1); - - let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(data) }; - assert_eq!(argmin_simd_index, 3); - assert_eq!(argmax_simd_index, 1); - } - - #[test] - fn test_many_random_runs() { - for _ in 0..10_000 { - let data: &[f64] = &get_array_f64(32 * 2 + 1); - let (argmin_index, argmax_index) = scalar_argminmax(data); - let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(data) }; - assert_eq!(argmin_index, argmin_simd_index); - assert_eq!(argmax_index, argmax_simd_index); - } - } - } -} - -// --------------------------------------- AVX512 ---------------------------------------- - -#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -mod avx512 { - use super::super::config::AVX512; - use super::*; - - const LANE_SIZE: usize = AVX512::LANE_SIZE_64; - - impl SIMD for AVX512 { - const INITIAL_INDEX: __m512d = unsafe { - std::mem::transmute([ - 0.0f64, 1.0f64, 2.0f64, 3.0f64, 4.0f64, 5.0f64, 6.0f64, 7.0f64, - ]) - }; - // https://stackoverflow.com/a/3793950 - #[cfg(target_arch = "x86_64")] - const MAX_INDEX: usize = 1 << f64::MANTISSA_DIGITS; - #[cfg(target_arch = "x86")] // https://stackoverflow.com/a/29592369 - const MAX_INDEX: usize = u32::MAX as usize; - - #[inline(always)] - unsafe fn _reg_to_arr(reg: __m512d) -> [f64; LANE_SIZE] { - std::mem::transmute::<__m512d, [f64; LANE_SIZE]>(reg) - } - - #[inline(always)] - unsafe fn _mm_loadu(data: *const f64) -> __m512d { - _mm512_loadu_pd(data as *const f64) - } - - #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m512d { - _mm512_set1_pd(a as f64) - } - - #[inline(always)] - unsafe fn _mm_add(a: __m512d, b: __m512d) -> __m512d { - _mm512_add_pd(a, b) - } - - #[inline(always)] - unsafe fn _mm_cmpgt(a: __m512d, b: __m512d) -> u8 { - _mm512_cmp_pd_mask(a, b, _CMP_GT_OQ) - } - - #[inline(always)] - unsafe fn _mm_cmplt(a: __m512d, b: __m512d) -> u8 { - _mm512_cmp_pd_mask(a, b, _CMP_LT_OQ) - } - - #[inline(always)] - unsafe fn _mm_blendv(a: __m512d, b: __m512d, mask: u8) -> __m512d { - _mm512_mask_blend_pd(mask, a, b) - } - - // ------------------------------------ ARGMINMAX -------------------------------------- - - #[target_feature(enable = "avx512f")] - unsafe fn argminmax(data: &[f64]) -> (usize, usize) { - Self::_argminmax(data) - } - } - - // ------------------------------------ TESTS -------------------------------------- - - #[cfg(test)] - mod tests { - use super::{AVX512, SIMD}; - use crate::scalar::generic::scalar_argminmax; - - extern crate dev_utils; - use dev_utils::utils; - - fn get_array_f64(n: usize) -> Vec { - utils::get_random_array(n, f64::MIN, f64::MAX) - } - - #[test] - fn test_both_versions_return_the_same_results() { - if !is_x86_feature_detected!("avx512f") { - return; - } - - let data: &[f64] = &get_array_f64(1025); - assert_eq!(data.len() % 2, 1); - - let (argmin_index, argmax_index) = scalar_argminmax(data); - let (argmin_simd_index, argmax_simd_index) = unsafe { 
AVX512::argminmax(data) }; - assert_eq!(argmin_index, argmin_simd_index); - assert_eq!(argmax_index, argmax_simd_index); - } - - #[test] - fn test_first_index_is_returned_when_identical_values_found() { - if !is_x86_feature_detected!("avx512f") { - return; - } - - let data = [ - 10., - std::f64::MAX, - 6., - std::f64::NEG_INFINITY, - std::f64::NEG_INFINITY, - std::f64::MAX, - 10_000.0, - ]; - let data: Vec = data.iter().map(|x| *x).collect(); - let data: &[f64] = &data; - - let (argmin_index, argmax_index) = scalar_argminmax(data); - assert_eq!(argmin_index, 3); - assert_eq!(argmax_index, 1); - - let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(data) }; - assert_eq!(argmin_simd_index, 3); - assert_eq!(argmax_simd_index, 1); - } - - #[test] - fn test_many_random_runs() { - if !is_x86_feature_detected!("avx512f") { - return; - } - - for _ in 0..10_000 { - let data: &[f64] = &get_array_f64(32 * 2 + 1); - let (argmin_index, argmax_index) = scalar_argminmax(data); - let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(data) }; - assert_eq!(argmin_index, argmin_simd_index); - assert_eq!(argmax_index, argmax_simd_index); - } - } - } -} - -// ---------------------------------------- NEON ----------------------------------------- - -// There are no NEON intrinsics for f64, so we need to use the scalar version. -// although NEON intrinsics exist for i64 and u64, we cannot use them as -// they there is no 64-bit variant (of any data type) for the following three -// intrinsics: vadd_, vcgt_, vclt_ - -#[cfg(any(target_arch = "arm", target_arch = "aarch64"))] -mod neon { - use super::super::config::NEON; - use super::super::generic::unimplement_simd; - use super::*; - - // We need to (un)implement the SIMD trait for the NEON struct as otherwise the - // compiler will complain that the trait is not implemented for the struct - - // even though we are not using the trait for the NEON struct when dealing with - // > 64 bit data types. - unimplement_simd!(f64, usize, NEON); -} diff --git a/src/simd/simd_f64_ignore_nan.rs b/src/simd/simd_f64_ignore_nan.rs new file mode 100644 index 0000000..de38073 --- /dev/null +++ b/src/simd/simd_f64_ignore_nan.rs @@ -0,0 +1,834 @@ +/// Implementation of the argminmax operations for f64 that ignores NaN values. +/// This implementation returns the index of the minimum and maximum values. +/// However, unexpected behavior may occur when there are +/// - *only* NaN values in the array +/// - *only* +/- infinity values in the array +/// - *only* NaN and +/- infinity values in the array +/// In these cases, index 0 is returned. +/// +/// NaN values are ignored and treated as if they are not present in the array. +/// To realize this we create an initial SIMD register with values +/- infinity. +/// As comparisons with NaN always return false, it is guaranteed that no NaN values +/// are added to the accumulating SIMD register. 
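// [Editorial note -- illustrative sketch, not part of this patch.]
// A scalar analogue of the strategy described in the doc comment above: start the
// accumulators at +/- infinity and update only on a strict comparison, so NaN entries
// (for which every comparison is false) simply fall through. The helper name is
// hypothetical and exists only for this example.
fn _argminmax_ignore_nan_model(data: &[f64]) -> (usize, usize) {
    let (mut min_i, mut max_i) = (0usize, 0usize);
    let (mut min_v, mut max_v) = (f64::INFINITY, f64::NEG_INFINITY);
    for (i, &x) in data.iter().enumerate() {
        if x < min_v {
            min_v = x;
            min_i = i;
        }
        if x > max_v {
            max_v = x;
            max_i = i;
        }
    }
    // Stays (0, 0) when the slice holds only NaN and/or +/- infinity values -
    // matching the "unexpected behavior" cases listed above.
    (min_i, max_i)
}
// [End of editorial note.]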
+/// + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +use super::config::SIMDInstructionSet; +use super::generic::{SIMDArgMinMaxIgnoreNaN, SIMDOps, SIMDSetOps}; +#[cfg(target_arch = "x86")] +use std::arch::x86::*; +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +// https://stackoverflow.com/a/3793950 +#[cfg(target_arch = "x86_64")] +const MAX_INDEX: usize = 1 << f64::MANTISSA_DIGITS; +#[cfg(target_arch = "x86")] // https://stackoverflow.com/a/29592369 +const MAX_INDEX: usize = u32::MAX as usize; + +// ------------------------------------------ AVX2 ------------------------------------------ + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +mod avx2_ignore_nan { + use super::super::config::{AVX2IgnoreNaN, AVX2}; + use super::*; + + const LANE_SIZE: usize = AVX2::LANE_SIZE_64; + + impl SIMDOps for AVX2IgnoreNaN { + const INITIAL_INDEX: __m256d = + unsafe { std::mem::transmute([0.0f64, 1.0f64, 2.0f64, 3.0f64]) }; + const INDEX_INCREMENT: __m256d = + unsafe { std::mem::transmute([LANE_SIZE as f64; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; + + #[inline(always)] + unsafe fn _reg_to_arr(reg: __m256d) -> [f64; LANE_SIZE] { + std::mem::transmute::<__m256d, [f64; LANE_SIZE]>(reg) + } + + #[inline(always)] + unsafe fn _mm_loadu(data: *const f64) -> __m256d { + _mm256_loadu_pd(data as *const f64) + } + + #[inline(always)] + unsafe fn _mm_add(a: __m256d, b: __m256d) -> __m256d { + _mm256_add_pd(a, b) + } + + #[inline(always)] + unsafe fn _mm_cmpgt(a: __m256d, b: __m256d) -> __m256d { + _mm256_cmp_pd(a, b, _CMP_GT_OQ) + } + + #[inline(always)] + unsafe fn _mm_cmplt(a: __m256d, b: __m256d) -> __m256d { + _mm256_cmp_pd(b, a, _CMP_GT_OQ) + } + + #[inline(always)] + unsafe fn _mm_blendv(a: __m256d, b: __m256d, mask: __m256d) -> __m256d { + _mm256_blendv_pd(a, b, mask) + } + } + + impl SIMDSetOps for AVX2IgnoreNaN { + #[inline(always)] + unsafe fn _mm_set1(a: f64) -> __m256d { + _mm256_set1_pd(a) + } + } + + impl SIMDArgMinMaxIgnoreNaN for AVX2IgnoreNaN { + #[target_feature(enable = "avx")] + unsafe fn argminmax(data: &[f64]) -> (usize, usize) { + Self::_argminmax(data) + } + } + + // ------------------------------------ TESTS -------------------------------------- + + #[cfg(test)] + mod tests { + use super::AVX2IgnoreNaN as AVX2; + use super::SIMDArgMinMaxIgnoreNaN; + use crate::scalar::generic::scalar_argminmax_ignore_nans as scalar_argminmax; + + extern crate dev_utils; + use dev_utils::utils; + + fn get_array_f64(n: usize) -> Vec { + utils::get_random_array(n, f64::MIN, f64::MAX) + } + + #[test] + fn test_both_versions_return_the_same_results() { + if !is_x86_feature_detected!("avx") { + return; + } + + let data: &[f64] = &get_array_f64(1025); + assert_eq!(data.len() % 4, 1); + + let (argmin_index, argmax_index) = scalar_argminmax(data); + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(data) }; + assert_eq!(argmin_index, argmin_simd_index); + assert_eq!(argmax_index, argmax_simd_index); + } + + #[test] + fn test_first_index_is_returned_when_identical_values_found() { + if !is_x86_feature_detected!("avx") { + return; + } + + let data = [ + 10., + f64::MAX, + 6., + f64::NEG_INFINITY, + f64::NEG_INFINITY, + f64::MAX, + 10_000.0, + ]; + let data: Vec = data.iter().map(|x| *x).collect(); + let data: &[f64] = &data; + + let (argmin_index, argmax_index) = scalar_argminmax(data); + assert_eq!(argmin_index, 3); + assert_eq!(argmax_index, 1); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(data) }; + 
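// [Editorial note -- illustrative check, not part of this patch.]
// Why MAX_INDEX is capped at 1 << f64::MANTISSA_DIGITS near the top of this file: the
// lane indices are tracked as f64 values here (see INITIAL_INDEX / INDEX_INCREMENT),
// and f64 can only represent integers exactly up to 2^53.
fn _f64_index_precision_demo() {
    let limit: u64 = 1 << f64::MANTISSA_DIGITS; // 2^53
    assert_eq!((limit as f64) as u64, limit); // 2^53 is still exact
    assert_eq!(((limit + 1) as f64) as u64, limit); // 2^53 + 1 rounds back down to 2^53
}
// [End of editorial note.]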
assert_eq!(argmin_simd_index, 3); + assert_eq!(argmax_simd_index, 1); + } + + #[test] + fn test_return_infs() { + if !is_x86_feature_detected!("avx") { + return; + } + + let arr_len: usize = 1027; + let mut data: Vec = get_array_f64(arr_len); + + // Case 1: all elements are +inf + for i in 0..data.len() { + data[i] = f64::INFINITY; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 2: all elements are -inf + for i in 0..data.len() { + data[i] = f64::NEG_INFINITY; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 3: add some +inf and -inf in the middle + let mut data: Vec = get_array_f64(arr_len); + data[100] = f64::INFINITY; + data[200] = f64::NEG_INFINITY; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 200); + assert_eq!(argmax_index, 100); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert_eq!(argmin_simd_index, 200); + assert_eq!(argmax_simd_index, 100); + } + + #[test] + fn test_ignore_nans() { + if !is_x86_feature_detected!("avx") { + return; + } + + let arr_len: usize = 1027; + + // Case 1: NaN is the first element + let mut data: Vec = get_array_f64(arr_len); + data[0] = f64::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert!(argmin_index != 0); + assert!(argmax_index != 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert!(argmin_simd_index != 0); + assert!(argmax_simd_index != 0); + + // Case 2: first 100 elements are NaN + for i in 0..100 { + data[i] = f64::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert!(argmin_index > 99); + assert!(argmax_index > 99); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert!(argmin_simd_index > 99); + assert!(argmax_simd_index > 99); + + // Case 3: NaN is the last element + let mut data: Vec = get_array_f64(arr_len); + data[arr_len - 1] = f64::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert!(argmin_index != 1026); + assert!(argmax_index != 1026); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert!(argmin_simd_index != 1026); + assert!(argmax_simd_index != 1026); + + // Case 4: last 100 elements are NaN + for i in 0..100 { + data[arr_len - 1 - i] = f64::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert!(argmin_index < arr_len - 100); + assert!(argmax_index < arr_len - 100); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert!(argmin_simd_index < arr_len - 100); + assert!(argmax_simd_index < arr_len - 100); + + // Case 5: NaN is somewhere in the middle element + let mut data: Vec = get_array_f64(arr_len); + data[123] = f64::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert!(argmin_index != 123); + assert!(argmax_index != 123); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert!(argmin_simd_index != 123); + 
assert!(argmax_simd_index != 123); + + // Case 6: all elements are NaN + for i in 0..data.len() { + data[i] = f64::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + } + + #[test] + fn test_many_random_runs() { + if !is_x86_feature_detected!("avx") { + return; + } + + for _ in 0..10_000 { + let data: &[f64] = &get_array_f64(32 * 8 + 1); + let (argmin_index, argmax_index) = scalar_argminmax(data); + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(data) }; + assert_eq!(argmin_index, argmin_simd_index); + assert_eq!(argmax_index, argmax_simd_index); + } + } + } +} + +// ----------------------------------------- SSE ----------------------------------------- + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +mod sse_ignore_nan { + use super::super::config::{SSEIgnoreNaN, SSE}; + use super::*; + + const LANE_SIZE: usize = SSE::LANE_SIZE_64; + + impl SIMDOps for SSEIgnoreNaN { + const INITIAL_INDEX: __m128d = unsafe { std::mem::transmute([0.0f64, 1.0f64]) }; + const INDEX_INCREMENT: __m128d = + unsafe { std::mem::transmute([LANE_SIZE as f64; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; + + #[inline(always)] + unsafe fn _reg_to_arr(reg: __m128d) -> [f64; LANE_SIZE] { + std::mem::transmute::<__m128d, [f64; LANE_SIZE]>(reg) + } + + #[inline(always)] + unsafe fn _mm_loadu(data: *const f64) -> __m128d { + _mm_loadu_pd(data as *const f64) + } + + #[inline(always)] + unsafe fn _mm_add(a: __m128d, b: __m128d) -> __m128d { + _mm_add_pd(a, b) + } + + #[inline(always)] + unsafe fn _mm_cmpgt(a: __m128d, b: __m128d) -> __m128d { + _mm_cmpgt_pd(a, b) + } + + #[inline(always)] + unsafe fn _mm_cmplt(a: __m128d, b: __m128d) -> __m128d { + _mm_cmplt_pd(a, b) + } + + #[inline(always)] + unsafe fn _mm_blendv(a: __m128d, b: __m128d, mask: __m128d) -> __m128d { + _mm_blendv_pd(a, b, mask) + } + } + + impl SIMDSetOps for SSEIgnoreNaN { + #[inline(always)] + unsafe fn _mm_set1(a: f64) -> __m128d { + _mm_set1_pd(a) + } + } + + impl SIMDArgMinMaxIgnoreNaN for SSEIgnoreNaN { + #[target_feature(enable = "sse4.1")] // TODO: check if this is correct + unsafe fn argminmax(data: &[f64]) -> (usize, usize) { + Self::_argminmax(data) + } + } + + // ------------------------------------ TESTS -------------------------------------- + + #[cfg(test)] + mod tests { + use super::SIMDArgMinMaxIgnoreNaN; + use super::SSEIgnoreNaN as SSE; + use crate::scalar::generic::scalar_argminmax_ignore_nans as scalar_argminmax; + + extern crate dev_utils; + use dev_utils::utils; + + fn get_array_f64(n: usize) -> Vec { + utils::get_random_array(n, f64::MIN, f64::MAX) + } + + #[test] + fn test_both_versions_return_the_same_results() { + let data: &[f64] = &get_array_f64(1025); + assert_eq!(data.len() % 2, 1); + + let (argmin_index, argmax_index) = scalar_argminmax(data); + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(data) }; + assert_eq!(argmin_index, argmin_simd_index); + assert_eq!(argmax_index, argmax_simd_index); + } + + #[test] + fn test_first_index_is_returned_when_identical_values_found() { + let data = [ + 10., + f64::MAX, + 6., + f64::NEG_INFINITY, + f64::NEG_INFINITY, + f64::MAX, + 10_000.0, + ]; + let data: Vec = data.iter().map(|x| *x).collect(); + let data: &[f64] = &data; + + let (argmin_index, argmax_index) = scalar_argminmax(data); + 
assert_eq!(argmin_index, 3); + assert_eq!(argmax_index, 1); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(data) }; + assert_eq!(argmin_simd_index, 3); + assert_eq!(argmax_simd_index, 1); + } + + #[test] + fn test_return_infs() { + let arr_len: usize = 1027; + let mut data: Vec = get_array_f64(arr_len); + + // Case 1: all elements are +inf + for i in 0..data.len() { + data[i] = f64::INFINITY; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 2: all elements are -inf + for i in 0..data.len() { + data[i] = f64::NEG_INFINITY; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 3: add some +inf and -inf in the middle + let mut data: Vec = get_array_f64(arr_len); + data[100] = f64::INFINITY; + data[200] = f64::NEG_INFINITY; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 200); + assert_eq!(argmax_index, 100); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert_eq!(argmin_simd_index, 200); + assert_eq!(argmax_simd_index, 100); + } + + #[test] + fn test_ignore_nans() { + let arr_len: usize = 1027; + + // Case 1: NaN is the first element + let mut data: Vec = get_array_f64(arr_len); + data[0] = f64::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert!(argmin_index != 0); + assert!(argmax_index != 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert!(argmin_simd_index != 0); + assert!(argmax_simd_index != 0); + + // Case 2: first 100 elements are NaN + for i in 0..100 { + data[i] = f64::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert!(argmin_index > 99); + assert!(argmax_index > 99); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert!(argmin_simd_index > 99); + assert!(argmax_simd_index > 99); + + // Case 3: NaN is the last element + let mut data: Vec = get_array_f64(arr_len); + data[arr_len - 1] = f64::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert!(argmin_index != 1026); + assert!(argmax_index != 1026); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert!(argmin_simd_index != 1026); + assert!(argmax_simd_index != 1026); + + // Case 4: last 100 elements are NaN + for i in 0..100 { + data[arr_len - 1 - i] = f64::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert!(argmin_index < arr_len - 100); + assert!(argmax_index < arr_len - 100); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert!(argmin_simd_index < arr_len - 100); + assert!(argmax_simd_index < arr_len - 100); + + // Case 5: NaN is somewhere in the middle element + let mut data: Vec = get_array_f64(arr_len); + data[123] = f64::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert!(argmin_index != 123); + assert!(argmax_index != 123); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert!(argmin_simd_index != 
123); + assert!(argmax_simd_index != 123); + + // Case 6: all elements are NaN + for i in 0..data.len() { + data[i] = f64::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + } + + #[test] + fn test_many_random_runs() { + for _ in 0..10_000 { + let data: &[f64] = &get_array_f64(32 * 2 + 1); + let (argmin_index, argmax_index) = scalar_argminmax(data); + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(data) }; + assert_eq!(argmin_index, argmin_simd_index); + assert_eq!(argmax_index, argmax_simd_index); + } + } + } +} + +// --------------------------------------- AVX512 ---------------------------------------- + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +mod avx512_ignore_nan { + use super::super::config::{AVX512IgnoreNaN, AVX512}; + use super::*; + + const LANE_SIZE: usize = AVX512::LANE_SIZE_64; + + impl SIMDOps for AVX512IgnoreNaN { + const INITIAL_INDEX: __m512d = unsafe { + std::mem::transmute([ + 0.0f64, 1.0f64, 2.0f64, 3.0f64, 4.0f64, 5.0f64, 6.0f64, 7.0f64, + ]) + }; + const INDEX_INCREMENT: __m512d = + unsafe { std::mem::transmute([LANE_SIZE as f64; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; + + #[inline(always)] + unsafe fn _reg_to_arr(reg: __m512d) -> [f64; LANE_SIZE] { + std::mem::transmute::<__m512d, [f64; LANE_SIZE]>(reg) + } + + #[inline(always)] + unsafe fn _mm_loadu(data: *const f64) -> __m512d { + _mm512_loadu_pd(data as *const f64) + } + + #[inline(always)] + unsafe fn _mm_add(a: __m512d, b: __m512d) -> __m512d { + _mm512_add_pd(a, b) + } + + #[inline(always)] + unsafe fn _mm_cmpgt(a: __m512d, b: __m512d) -> u8 { + _mm512_cmp_pd_mask(a, b, _CMP_GT_OQ) + } + + #[inline(always)] + unsafe fn _mm_cmplt(a: __m512d, b: __m512d) -> u8 { + _mm512_cmp_pd_mask(a, b, _CMP_LT_OQ) + } + + #[inline(always)] + unsafe fn _mm_blendv(a: __m512d, b: __m512d, mask: u8) -> __m512d { + _mm512_mask_blend_pd(mask, a, b) + } + } + + impl SIMDSetOps for AVX512IgnoreNaN { + #[inline(always)] + unsafe fn _mm_set1(a: f64) -> __m512d { + _mm512_set1_pd(a) + } + } + + impl SIMDArgMinMaxIgnoreNaN for AVX512IgnoreNaN { + #[target_feature(enable = "avx512f")] + unsafe fn argminmax(data: &[f64]) -> (usize, usize) { + Self::_argminmax(data) + } + } + + // ------------------------------------ TESTS -------------------------------------- + + #[cfg(test)] + mod tests { + use super::AVX512IgnoreNaN as AVX512; + use super::SIMDArgMinMaxIgnoreNaN; + use crate::scalar::generic::scalar_argminmax_ignore_nans as scalar_argminmax; + + extern crate dev_utils; + use dev_utils::utils; + + fn get_array_f64(n: usize) -> Vec { + utils::get_random_array(n, f64::MIN, f64::MAX) + } + + #[test] + fn test_both_versions_return_the_same_results() { + if !is_x86_feature_detected!("avx512f") { + return; + } + + let data: &[f64] = &get_array_f64(1025); + assert_eq!(data.len() % 2, 1); + + let (argmin_index, argmax_index) = scalar_argminmax(data); + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(data) }; + assert_eq!(argmin_index, argmin_simd_index); + assert_eq!(argmax_index, argmax_simd_index); + } + + #[test] + fn test_first_index_is_returned_when_identical_values_found() { + if !is_x86_feature_detected!("avx512f") { + return; + } + + let data = [ + 10., + f64::MAX, + 6., + f64::NEG_INFINITY, + f64::NEG_INFINITY, + f64::MAX, + 
10_000.0, + ]; + let data: Vec = data.iter().map(|x| *x).collect(); + let data: &[f64] = &data; + + let (argmin_index, argmax_index) = scalar_argminmax(data); + assert_eq!(argmin_index, 3); + assert_eq!(argmax_index, 1); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(data) }; + assert_eq!(argmin_simd_index, 3); + assert_eq!(argmax_simd_index, 1); + } + + #[test] + fn test_return_infs() { + if !is_x86_feature_detected!("avx512f") { + return; + } + + let arr_len: usize = 1027; + let mut data: Vec = get_array_f64(arr_len); + + // Case 1: all elements are +inf + for i in 0..data.len() { + data[i] = f64::INFINITY; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 2: all elements are -inf + for i in 0..data.len() { + data[i] = f64::NEG_INFINITY; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 3: add some +inf and -inf in the middle + let mut data: Vec = get_array_f64(arr_len); + data[100] = f64::INFINITY; + data[200] = f64::NEG_INFINITY; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 200); + assert_eq!(argmax_index, 100); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert_eq!(argmin_simd_index, 200); + assert_eq!(argmax_simd_index, 100); + } + + #[test] + fn test_ignore_nans() { + if !is_x86_feature_detected!("avx512f") { + return; + } + + let arr_len: usize = 1027; + + // Case 1: NaN is the first element + let mut data: Vec = get_array_f64(arr_len); + data[0] = f64::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert!(argmin_index != 0); + assert!(argmax_index != 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert!(argmin_simd_index != 0); + assert!(argmax_simd_index != 0); + + // Case 2: first 100 elements are NaN + for i in 0..100 { + data[i] = f64::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert!(argmin_index > 99); + assert!(argmax_index > 99); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert!(argmin_simd_index > 99); + assert!(argmax_simd_index > 99); + + // Case 3: NaN is the last element + let mut data: Vec = get_array_f64(arr_len); + data[arr_len - 1] = f64::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert!(argmin_index != 1026); + assert!(argmax_index != 1026); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert!(argmin_simd_index != 1026); + assert!(argmax_simd_index != 1026); + + // Case 4: last 100 elements are NaN + for i in 0..100 { + data[arr_len - 1 - i] = f64::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert!(argmin_index < arr_len - 100); + assert!(argmax_index < arr_len - 100); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert!(argmin_simd_index < arr_len - 100); + assert!(argmax_simd_index < arr_len - 100); + + // Case 5: NaN is somewhere in the middle element + let mut 
data: Vec<f64> = get_array_f64(arr_len);
+            data[123] = f64::NAN;
+
+            let (argmin_index, argmax_index) = scalar_argminmax(&data);
+            assert!(argmin_index != 123);
+            assert!(argmax_index != 123);
+
+            let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) };
+            assert!(argmin_simd_index != 123);
+            assert!(argmax_simd_index != 123);
+
+            // Case 6: all elements are NaN
+            for i in 0..data.len() {
+                data[i] = f64::NAN;
+            }
+
+            let (argmin_index, argmax_index) = scalar_argminmax(&data);
+            assert_eq!(argmin_index, 0);
+            assert_eq!(argmax_index, 0);
+
+            let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) };
+            assert_eq!(argmin_simd_index, 0);
+            assert_eq!(argmax_simd_index, 0);
+        }
+
+        #[test]
+        fn test_many_random_runs() {
+            if !is_x86_feature_detected!("avx512f") {
+                return;
+            }
+
+            for _ in 0..10_000 {
+                let data: &[f64] = &get_array_f64(32 * 2 + 1);
+                let (argmin_index, argmax_index) = scalar_argminmax(data);
+                let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(data) };
+                assert_eq!(argmin_index, argmin_simd_index);
+                assert_eq!(argmax_index, argmax_simd_index);
+            }
+        }
+    }
+}
+
+// ---------------------------------------- NEON -----------------------------------------
+
+// There are no NEON intrinsics for f64, so we need to use the scalar version.
+// Although NEON intrinsics exist for i64 and u64, we cannot use them, as there is no
+// 64-bit variant (of any data type) for the following three intrinsics:
+// vadd_, vcgt_, vclt_
+
+#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
+mod neon_ignore_nan {
+    use super::super::config::NEONIgnoreNaN;
+    use super::super::generic::{unimpl_SIMDArgMinMaxIgnoreNaN, unimpl_SIMDOps};
+    use super::*;
+
+    // We need to (un)implement the SIMD trait for the NEON struct, as otherwise the
+    // compiler will complain that the trait is not implemented for the struct -
+    // even though we are not using the trait for the NEON struct when dealing with
+    // 64-bit data types.
+    unimpl_SIMDOps!(f64, usize, NEONIgnoreNaN);
+    unimpl_SIMDArgMinMaxIgnoreNaN!(f64, usize, NEONIgnoreNaN);
+}
diff --git a/src/simd/simd_f64_return_nan.rs b/src/simd/simd_f64_return_nan.rs
new file mode 100644
index 0000000..cd7f9e5
--- /dev/null
+++ b/src/simd/simd_f64_return_nan.rs
@@ -0,0 +1,1019 @@
+/// Implementation of the argminmax operations for f64 where NaN values take precedence.
+/// This implementation returns the index of the first* NaN value if any are present,
+/// otherwise it returns the index of the minimum and maximum values.
+///
+/// To serve this functionality we transform the f64 values to ordinal i64 values:
+/// ord_i64 = ((v >> 63) & 0x7FFFFFFFFFFFFFFF) ^ v
+///
+/// This transformation is a bijection, i.e. it is reversible:
+/// v = ((ord_i64 >> 63) & 0x7FFFFFFFFFFFFFFF) ^ ord_i64
+///
+/// Through this transformation we can perform the argminmax operations on the ordinal
+/// integer values and then transform the result back to the original f64 values.
+/// This transformation is necessary because comparisons with NaN values always return
+/// false. So unless we perform !(<=) as gt and !(>=) as lt, the argminmax operations
+/// will not add NaN values to the accumulating SIMD register. And as le and ge are
+/// significantly more expensive than lt and gt, we use this efficient bitwise
+/// transformation instead.
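+///
+/// As a purely illustrative sketch (not part of this file's SIMD code path; the helper
+/// names below are hypothetical), the scalar round trip looks as follows:
+///
+/// ```ignore
+/// fn f64_to_i64ord(v: f64) -> i64 {
+///     let v = v.to_bits() as i64;
+///     ((v >> 63) & 0x7FFFFFFFFFFFFFFF) ^ v // arithmetic shift smears the sign bit
+/// }
+///
+/// fn i64ord_to_f64(ord: i64) -> f64 {
+///     // the transformation is its own inverse
+///     f64::from_bits((((ord >> 63) & 0x7FFFFFFFFFFFFFFF) ^ ord) as u64)
+/// }
+///
+/// assert_eq!(i64ord_to_f64(f64_to_i64ord(-1.5)), -1.5);            // round trip
+/// assert!(f64_to_i64ord(-1.5) < f64_to_i64ord(0.0));               // order preserved
+/// assert!(f64_to_i64ord(f64::INFINITY) < f64_to_i64ord(f64::NAN)); // NaN sorts last
+/// ```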
+/// +/// Also comparing integers is faster than comparing floats: +/// - https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_cmp_pd&ig_expand=886 +/// - https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_cmpgt_epi64&ig_expand=1094 +/// +/// +/// --- +/// +/// *Note: the first NaN value is only returned iff all NaN values have the same bit +/// representation. When NaN values have different bit representations then the index of +/// the highest / lowest ord_i64 is returned for the +/// SIMDOps::_get_overflow_lane_size_limit() chunk of the data - which is not +/// necessarily the index of the first NaN value. +/// + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +use super::config::SIMDInstructionSet; +use super::generic::{SIMDArgMinMax, SIMDOps}; +#[cfg(target_arch = "x86")] +use std::arch::x86::*; +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +use super::task::{max_index_value, min_index_value}; + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +const BIT_SHIFT: i32 = 63; +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +const MASK_VALUE: i64 = 0x7FFFFFFFFFFFFFFF; // i64::MAX - masks everything but the sign bit + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +#[inline(always)] +fn _i64ord_to_f64(ord_i64: i64) -> f64 { + let v = ((ord_i64 >> BIT_SHIFT) & MASK_VALUE) ^ ord_i64; + f64::from_bits(v as u64) +} + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +const MAX_INDEX: usize = i64::MAX as usize; + +// ------------------------------------------ AVX2 ------------------------------------------ + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +mod avx2 { + use super::super::config::AVX2; + use super::*; + + const LANE_SIZE: usize = AVX2::LANE_SIZE_64; + const LOWER_63_MASK: __m256i = unsafe { std::mem::transmute([MASK_VALUE; LANE_SIZE]) }; + + #[inline(always)] + unsafe fn _f64_as_m256i_to_i64ord(f64_as_m256i: __m256i) -> __m256i { + // on a scalar: ((v >> 63) & 0x7FFFFFFFFFFFFFFF) ^ v + // Note: _mm256_srai_epi64 is not available on AVX2.. (only AVX512F) + // -> As we only want to shift the sign bit to the first position, we can use + // _mm256_srai_epi32 instead, which is available on AVX2, and then copy the + // sign bit to the next 32 bits (per 64 bit lane). + let sign_bit_shifted = + _mm256_shuffle_epi32(_mm256_srai_epi32(f64_as_m256i, BIT_SHIFT), 0b11110101); + let sign_bit_masked = _mm256_and_si256(sign_bit_shifted, LOWER_63_MASK); + _mm256_xor_si256(sign_bit_masked, f64_as_m256i) + } + + #[inline(always)] + unsafe fn _reg_to_i64_arr(reg: __m256i) -> [i64; LANE_SIZE] { + std::mem::transmute::<__m256i, [i64; LANE_SIZE]>(reg) + } + + impl SIMDOps for AVX2 { + const INITIAL_INDEX: __m256i = unsafe { std::mem::transmute([0i64, 1i64, 2i64, 3i64]) }; + const INDEX_INCREMENT: __m256i = + unsafe { std::mem::transmute([LANE_SIZE as i64; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; + + #[inline(always)] + unsafe fn _reg_to_arr(_: __m256i) -> [f64; LANE_SIZE] { + // Not implemented because we will perform the horizontal operations on the + // signed integer values instead of trying to retransform **only** the values + // (and thus not the indices) to floats. 
+ unimplemented!() + } + + #[inline(always)] + unsafe fn _mm_loadu(data: *const f64) -> __m256i { + _f64_as_m256i_to_i64ord(_mm256_loadu_si256(data as *const __m256i)) + } + + #[inline(always)] + unsafe fn _mm_add(a: __m256i, b: __m256i) -> __m256i { + _mm256_add_epi64(a, b) + } + + #[inline(always)] + unsafe fn _mm_cmpgt(a: __m256i, b: __m256i) -> __m256i { + _mm256_cmpgt_epi64(a, b) + } + + #[inline(always)] + unsafe fn _mm_cmplt(a: __m256i, b: __m256i) -> __m256i { + _mm256_cmpgt_epi64(b, a) + } + + #[inline(always)] + unsafe fn _mm_blendv(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { + _mm256_blendv_epi8(a, b, mask) + } + + #[inline(always)] + unsafe fn _horiz_min(index: __m256i, value: __m256i) -> (usize, f64) { + let index_arr: [i64; LANE_SIZE] = _reg_to_i64_arr(index); + let value_arr: [i64; LANE_SIZE] = _reg_to_i64_arr(value); + let (min_index, min_value) = min_index_value(&index_arr, &value_arr); + (min_index as usize, _i64ord_to_f64(min_value)) + } + + #[inline(always)] + unsafe fn _horiz_max(index: __m256i, value: __m256i) -> (usize, f64) { + let index_arr: [i64; LANE_SIZE] = _reg_to_i64_arr(index); + let value_arr: [i64; LANE_SIZE] = _reg_to_i64_arr(value); + let (max_index, max_value) = max_index_value(&index_arr, &value_arr); + (max_index as usize, _i64ord_to_f64(max_value)) + } + } + + impl SIMDArgMinMax for AVX2 { + #[target_feature(enable = "avx2")] + unsafe fn argminmax(data: &[f64]) -> (usize, usize) { + Self::_argminmax(data) + } + } + + // ------------------------------------ TESTS -------------------------------------- + + #[cfg(test)] + mod tests { + use super::{SIMDArgMinMax, AVX2}; + use crate::scalar::generic::scalar_argminmax; + + extern crate dev_utils; + use dev_utils::utils; + + fn get_array_f64(n: usize) -> Vec { + utils::get_random_array(n, f64::MIN, f64::MAX) + } + + #[test] + fn test_both_versions_return_the_same_results() { + if !is_x86_feature_detected!("avx2") { + return; + } + + let data: &[f64] = &get_array_f64(1025); + assert_eq!(data.len() % 4, 1); + + let (argmin_index, argmax_index) = scalar_argminmax(data); + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(data) }; + assert_eq!(argmin_index, argmin_simd_index); + assert_eq!(argmax_index, argmax_simd_index); + } + + #[test] + fn test_first_index_is_returned_when_identical_values_found() { + if !is_x86_feature_detected!("avx2") { + return; + } + + let data = [ + 10., + f64::MAX, + 6., + f64::NEG_INFINITY, + f64::NEG_INFINITY, + f64::MAX, + 10_000.0, + ]; + let data: Vec = data.iter().map(|x| *x).collect(); + let data: &[f64] = &data; + + let (argmin_index, argmax_index) = scalar_argminmax(data); + assert_eq!(argmin_index, 3); + assert_eq!(argmax_index, 1); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(data) }; + assert_eq!(argmin_simd_index, 3); + assert_eq!(argmax_simd_index, 1); + } + + #[test] + fn test_return_infs() { + if !is_x86_feature_detected!("avx2") { + return; + } + + let arr_len: usize = 1027; + let mut data: Vec = get_array_f64(arr_len); + + // Case 1: all elements are +inf + for i in 0..data.len() { + data[i] = f64::INFINITY; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 2: all elements are -inf + for i in 0..data.len() { + data[i] = f64::NEG_INFINITY; + } + + let (argmin_index, 
argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 3: add some +inf and -inf in the middle + let mut data: Vec = get_array_f64(arr_len); + data[100] = f64::INFINITY; + data[200] = f64::NEG_INFINITY; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 200); + assert_eq!(argmax_index, 100); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert_eq!(argmin_simd_index, 200); + assert_eq!(argmax_simd_index, 100); + } + + #[test] + fn test_return_nans() { + if !is_x86_feature_detected!("avx2") { + return; + } + + let arr_len: usize = 1027; + + // Case 1: NaN is the first element + let mut data: Vec = get_array_f64(arr_len); + data[0] = f64::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 2: first 100 elements are NaN + for i in 0..100 { + data[i] = f64::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 3: NaN is the last element + let mut data: Vec = get_array_f64(arr_len); + data[arr_len - 1] = f64::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 1026); + assert_eq!(argmax_index, 1026); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert_eq!(argmin_simd_index, 1026); + assert_eq!(argmax_simd_index, 1026); + + // Case 4: last 100 elements are NaN + for i in 0..100 { + data[arr_len - 1 - i] = f64::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, arr_len - 100); + assert_eq!(argmax_index, arr_len - 100); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert_eq!(argmin_simd_index, arr_len - 100); + assert_eq!(argmax_simd_index, arr_len - 100); + + // Case 5: NaN is somewhere in the middle element + let mut data: Vec = get_array_f64(arr_len); + data[123] = f64::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 123); + assert_eq!(argmax_index, 123); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert_eq!(argmin_simd_index, 123); + assert_eq!(argmax_simd_index, 123); + + // Case 6: NaN in the middle of the array and last 100 elements are NaN + for i in 0..100 { + data[arr_len - 1 - i] = f64::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 123); + assert_eq!(argmax_index, 123); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) }; + assert_eq!(argmin_simd_index, 123); + assert_eq!(argmax_simd_index, 123); + + // Case 7: all elements are NaN + for i in 0..data.len() { + data[i] = f64::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { 
AVX2::argminmax(&data) };
+            assert_eq!(argmin_simd_index, 0);
+            assert_eq!(argmax_simd_index, 0);
+
+            // Case 8: array exact multiple of LANE_SIZE and only 1 element is NaN
+            let mut data: Vec<f64> = get_array_f64(128);
+            data[17] = f64::NAN;
+
+            let (argmin_index, argmax_index) = scalar_argminmax(&data);
+            assert_eq!(argmin_index, 17);
+            assert_eq!(argmax_index, 17);
+
+            let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(&data) };
+            assert_eq!(argmin_simd_index, 17);
+            assert_eq!(argmax_simd_index, 17);
+        }
+
+        #[test]
+        fn test_many_random_runs() {
+            if !is_x86_feature_detected!("avx2") {
+                return;
+            }
+
+            for _ in 0..10_000 {
+                let data: &[f64] = &get_array_f64(32 * 8 + 1);
+                let (argmin_index, argmax_index) = scalar_argminmax(data);
+                let (argmin_simd_index, argmax_simd_index) = unsafe { AVX2::argminmax(data) };
+                assert_eq!(argmin_index, argmin_simd_index);
+                assert_eq!(argmax_index, argmax_simd_index);
+            }
+        }
+    }
+}
+
+// ----------------------------------------- SSE -----------------------------------------
+
+#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
+mod sse {
+    use super::super::config::SSE;
+    use super::*;
+
+    const LANE_SIZE: usize = SSE::LANE_SIZE_64;
+    const LOWER_63_MASK: __m128i = unsafe { std::mem::transmute([MASK_VALUE; LANE_SIZE]) };
+
+    #[inline(always)]
+    unsafe fn _f64_as_m128i_to_i64ord(f64_as_m128i: __m128i) -> __m128i {
+        // on a scalar: ((v >> 63) & 0x7FFFFFFFFFFFFFFF) ^ v
+        // Note: _mm_srai_epi64 is not available in SSE (it only exists under AVX-512)
+        // -> As we only want to shift the sign bit to the first position, we can use
+        // _mm_srai_epi32 instead, which is available since SSE2, and then copy the
+        // sign bit to the next 32 bits (per 64 bit lane).
+        let sign_bit_shifted =
+            _mm_shuffle_epi32(_mm_srai_epi32(f64_as_m128i, BIT_SHIFT), 0b11110101);
+        let sign_bit_masked = _mm_and_si128(sign_bit_shifted, LOWER_63_MASK);
+        _mm_xor_si128(sign_bit_masked, f64_as_m128i)
+    }
+
+    #[inline(always)]
+    unsafe fn _reg_to_i64_arr(reg: __m128i) -> [i64; LANE_SIZE] {
+        std::mem::transmute::<__m128i, [i64; LANE_SIZE]>(reg)
+    }
+
+    impl SIMDOps for SSE {
+        const INITIAL_INDEX: __m128i = unsafe { std::mem::transmute([0i64, 1i64]) };
+        const INDEX_INCREMENT: __m128i =
+            unsafe { std::mem::transmute([LANE_SIZE as i64; LANE_SIZE]) };
+        const MAX_INDEX: usize = MAX_INDEX;
+
+        #[inline(always)]
+        unsafe fn _reg_to_arr(_: __m128i) -> [f64; LANE_SIZE] {
+            // Not implemented because we will perform the horizontal operations on the
+            // signed integer values instead of trying to retransform **only** the values
+            // (and thus not the indices) to floats.
+ unimplemented!() + } + + #[inline(always)] + unsafe fn _mm_loadu(data: *const f64) -> __m128i { + _f64_as_m128i_to_i64ord(_mm_loadu_si128(data as *const __m128i)) + } + + #[inline(always)] + unsafe fn _mm_add(a: __m128i, b: __m128i) -> __m128i { + _mm_add_epi64(a, b) + } + + #[inline(always)] + unsafe fn _mm_cmpgt(a: __m128i, b: __m128i) -> __m128i { + _mm_cmpgt_epi64(a, b) + } + + #[inline(always)] + unsafe fn _mm_cmplt(a: __m128i, b: __m128i) -> __m128i { + _mm_cmpgt_epi64(b, a) + } + + #[inline(always)] + unsafe fn _mm_blendv(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { + _mm_blendv_epi8(a, b, mask) + } + + #[inline(always)] + unsafe fn _horiz_min(index: __m128i, value: __m128i) -> (usize, f64) { + let index_arr: [i64; LANE_SIZE] = _reg_to_i64_arr(index); + let value_arr: [i64; LANE_SIZE] = _reg_to_i64_arr(value); + let (min_index, min_value) = min_index_value(&index_arr, &value_arr); + (min_index as usize, _i64ord_to_f64(min_value)) + } + + #[inline(always)] + unsafe fn _horiz_max(index: __m128i, value: __m128i) -> (usize, f64) { + let index_arr: [i64; LANE_SIZE] = _reg_to_i64_arr(index); + let value_arr: [i64; LANE_SIZE] = _reg_to_i64_arr(value); + let (max_index, max_value) = max_index_value(&index_arr, &value_arr); + (max_index as usize, _i64ord_to_f64(max_value)) + } + } + + impl SIMDArgMinMax for SSE { + #[target_feature(enable = "sse4.2")] + unsafe fn argminmax(data: &[f64]) -> (usize, usize) { + Self::_argminmax(data) + } + } + + // ------------------------------------ TESTS -------------------------------------- + + #[cfg(test)] + mod tests { + use super::{SIMDArgMinMax, SSE}; + use crate::scalar::generic::scalar_argminmax; + + extern crate dev_utils; + use dev_utils::utils; + + fn get_array_f64(n: usize) -> Vec { + utils::get_random_array(n, f64::MIN, f64::MAX) + } + + #[test] + fn test_both_versions_return_the_same_results() { + let data: &[f64] = &get_array_f64(1025); + assert_eq!(data.len() % 2, 1); + + let (argmin_index, argmax_index) = scalar_argminmax(data); + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(data) }; + assert_eq!(argmin_index, argmin_simd_index); + assert_eq!(argmax_index, argmax_simd_index); + } + + #[test] + fn test_first_index_is_returned_when_identical_values_found() { + let data = [ + 10., + f64::MAX, + 6., + f64::NEG_INFINITY, + f64::NEG_INFINITY, + f64::MAX, + 10_000.0, + ]; + let data: Vec = data.iter().map(|x| *x).collect(); + let data: &[f64] = &data; + + let (argmin_index, argmax_index) = scalar_argminmax(data); + assert_eq!(argmin_index, 3); + assert_eq!(argmax_index, 1); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(data) }; + assert_eq!(argmin_simd_index, 3); + assert_eq!(argmax_simd_index, 1); + } + + #[test] + fn test_return_infs() { + let arr_len: usize = 1027; + let mut data: Vec = get_array_f64(arr_len); + + // Case 1: all elements are +inf + for i in 0..data.len() { + data[i] = f64::INFINITY; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 2: all elements are -inf + for i in 0..data.len() { + data[i] = f64::NEG_INFINITY; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + 
assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 3: add some +inf and -inf in the middle + let mut data: Vec = get_array_f64(arr_len); + data[100] = f64::INFINITY; + data[200] = f64::NEG_INFINITY; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 200); + assert_eq!(argmax_index, 100); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert_eq!(argmin_simd_index, 200); + assert_eq!(argmax_simd_index, 100); + } + + #[test] + fn test_return_nans() { + let arr_len: usize = 1027; + + // Case 1: NaN is the first element + let mut data: Vec = get_array_f64(arr_len); + data[0] = f64::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 2: first 100 elements are NaN + for i in 0..100 { + data[i] = f64::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 3: NaN is the last element + let mut data: Vec = get_array_f64(arr_len); + data[arr_len - 1] = f64::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 1026); + assert_eq!(argmax_index, 1026); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert_eq!(argmin_simd_index, 1026); + assert_eq!(argmax_simd_index, 1026); + + // Case 4: last 100 elements are NaN + for i in 0..100 { + data[arr_len - 1 - i] = f64::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, arr_len - 100); + assert_eq!(argmax_index, arr_len - 100); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert_eq!(argmin_simd_index, arr_len - 100); + assert_eq!(argmax_simd_index, arr_len - 100); + + // Case 5: NaN is somewhere in the middle element + let mut data: Vec = get_array_f64(arr_len); + data[123] = f64::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 123); + assert_eq!(argmax_index, 123); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert_eq!(argmin_simd_index, 123); + assert_eq!(argmax_simd_index, 123); + + // Case 6: NaN in the middle of the array and last 100 elements are NaN + for i in 0..100 { + data[arr_len - 1 - i] = f64::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 123); + assert_eq!(argmax_index, 123); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert_eq!(argmin_simd_index, 123); + assert_eq!(argmax_simd_index, 123); + + // Case 7: all elements are NaN + for i in 0..data.len() { + data[i] = f64::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 8: array exact multiple of LANE_SIZE and only 1 element is NaN + let mut data: Vec = get_array_f64(128); + data[17] = f64::NAN; + + let 
(argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 17); + assert_eq!(argmax_index, 17); + + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(&data) }; + assert_eq!(argmin_simd_index, 17); + assert_eq!(argmax_simd_index, 17); + } + + #[test] + fn test_many_random_runs() { + for _ in 0..10_000 { + let data: &[f64] = &get_array_f64(32 * 2 + 1); + let (argmin_index, argmax_index) = scalar_argminmax(data); + let (argmin_simd_index, argmax_simd_index) = unsafe { SSE::argminmax(data) }; + assert_eq!(argmin_index, argmin_simd_index); + assert_eq!(argmax_index, argmax_simd_index); + } + } + } +} + +// --------------------------------------- AVX512 ---------------------------------------- + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +mod avx512 { + use super::super::config::AVX512; + use super::*; + + const LANE_SIZE: usize = AVX512::LANE_SIZE_64; + const LOWER_63_MASK: __m512i = unsafe { std::mem::transmute([MASK_VALUE; LANE_SIZE]) }; + + #[inline(always)] + unsafe fn _f64_as_m512i_to_i64ord(f64_as_m512i: __m512i) -> __m512i { + // on a scalar: ((v >> 63) & 0x7FFFFFFFFFFFFFFF) ^ v + let sign_bit_shifted = _mm512_srai_epi64(f64_as_m512i, BIT_SHIFT as u32); + let sign_bit_masked = _mm512_and_si512(sign_bit_shifted, LOWER_63_MASK); + _mm512_xor_si512(sign_bit_masked, f64_as_m512i) + } + + #[inline(always)] + unsafe fn _reg_to_i64_arr(reg: __m512i) -> [i64; LANE_SIZE] { + std::mem::transmute::<__m512i, [i64; LANE_SIZE]>(reg) + } + + impl SIMDOps for AVX512 { + const INITIAL_INDEX: __m512i = + unsafe { std::mem::transmute([0i64, 1i64, 2i64, 3i64, 4i64, 5i64, 6i64, 7i64]) }; + const INDEX_INCREMENT: __m512i = + unsafe { std::mem::transmute([LANE_SIZE as i64; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; + + #[inline(always)] + unsafe fn _reg_to_arr(_: __m512i) -> [f64; LANE_SIZE] { + // Not implemented because we will perform the horizontal operations on the + // signed integer values instead of trying to retransform **only** the values + // (and thus not the indices) to floats. 
+ unimplemented!() + } + + #[inline(always)] + unsafe fn _mm_loadu(data: *const f64) -> __m512i { + _f64_as_m512i_to_i64ord(_mm512_loadu_epi64(data as *const i64)) + } + + #[inline(always)] + unsafe fn _mm_add(a: __m512i, b: __m512i) -> __m512i { + _mm512_add_epi64(a, b) + } + + #[inline(always)] + unsafe fn _mm_cmpgt(a: __m512i, b: __m512i) -> u8 { + _mm512_cmpgt_epi64_mask(a, b) + } + + #[inline(always)] + unsafe fn _mm_cmplt(a: __m512i, b: __m512i) -> u8 { + _mm512_cmpgt_epi64_mask(b, a) + } + + #[inline(always)] + unsafe fn _mm_blendv(a: __m512i, b: __m512i, mask: u8) -> __m512i { + _mm512_mask_blend_epi64(mask, a, b) + } + + #[inline(always)] + unsafe fn _horiz_min(index: __m512i, value: __m512i) -> (usize, f64) { + let index_arr: [i64; LANE_SIZE] = _reg_to_i64_arr(index); + let value_arr: [i64; LANE_SIZE] = _reg_to_i64_arr(value); + let (min_index, min_value) = min_index_value(&index_arr, &value_arr); + (min_index as usize, _i64ord_to_f64(min_value)) + } + + #[inline(always)] + unsafe fn _horiz_max(index: __m512i, value: __m512i) -> (usize, f64) { + let index_arr: [i64; LANE_SIZE] = _reg_to_i64_arr(index); + let value_arr: [i64; LANE_SIZE] = _reg_to_i64_arr(value); + let (max_index, max_value) = max_index_value(&index_arr, &value_arr); + (max_index as usize, _i64ord_to_f64(max_value)) + } + } + + impl SIMDArgMinMax for AVX512 { + #[target_feature(enable = "avx512f")] + unsafe fn argminmax(data: &[f64]) -> (usize, usize) { + Self::_argminmax(data) + } + } + + // ------------------------------------ TESTS -------------------------------------- + + #[cfg(test)] + mod tests { + use super::{SIMDArgMinMax, AVX512}; + use crate::scalar::generic::scalar_argminmax; + + extern crate dev_utils; + use dev_utils::utils; + + fn get_array_f64(n: usize) -> Vec { + utils::get_random_array(n, f64::MIN, f64::MAX) + } + + #[test] + fn test_both_versions_return_the_same_results() { + if !is_x86_feature_detected!("avx512f") { + return; + } + + let data: &[f64] = &get_array_f64(1025); + assert_eq!(data.len() % 2, 1); + + let (argmin_index, argmax_index) = scalar_argminmax(data); + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(data) }; + assert_eq!(argmin_index, argmin_simd_index); + assert_eq!(argmax_index, argmax_simd_index); + } + + #[test] + fn test_first_index_is_returned_when_identical_values_found() { + if !is_x86_feature_detected!("avx512f") { + return; + } + + let data = [ + 10., + f64::MAX, + 6., + f64::NEG_INFINITY, + f64::NEG_INFINITY, + f64::MAX, + 10_000.0, + ]; + let data: Vec = data.iter().map(|x| *x).collect(); + let data: &[f64] = &data; + + let (argmin_index, argmax_index) = scalar_argminmax(data); + assert_eq!(argmin_index, 3); + assert_eq!(argmax_index, 1); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(data) }; + assert_eq!(argmin_simd_index, 3); + assert_eq!(argmax_simd_index, 1); + } + + #[test] + fn test_return_infs() { + if !is_x86_feature_detected!("avx512f") { + return; + } + + let arr_len: usize = 1027; + let mut data: Vec = get_array_f64(arr_len); + + // Case 1: all elements are +inf + for i in 0..data.len() { + data[i] = f64::INFINITY; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 2: all elements are -inf + for i in 0..data.len() { + data[i] = f64::NEG_INFINITY; + } + + let 
(argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 3: add some +inf and -inf in the middle + let mut data: Vec = get_array_f64(arr_len); + data[100] = f64::INFINITY; + data[200] = f64::NEG_INFINITY; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 200); + assert_eq!(argmax_index, 100); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert_eq!(argmin_simd_index, 200); + assert_eq!(argmax_simd_index, 100); + } + + #[test] + fn test_return_nans() { + if !is_x86_feature_detected!("avx512f") { + return; + } + + let arr_len: usize = 1027; + + // Case 1: NaN is the first element + let mut data: Vec = get_array_f64(arr_len); + data[0] = f64::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 2: first 100 elements are NaN + for i in 0..100 { + data[i] = f64::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert_eq!(argmin_simd_index, 0); + assert_eq!(argmax_simd_index, 0); + + // Case 3: NaN is the last element + let mut data: Vec = get_array_f64(arr_len); + data[arr_len - 1] = f64::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 1026); + assert_eq!(argmax_index, 1026); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert_eq!(argmin_simd_index, 1026); + assert_eq!(argmax_simd_index, 1026); + + // Case 4: last 100 elements are NaN + for i in 0..100 { + data[arr_len - 1 - i] = f64::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, arr_len - 100); + assert_eq!(argmax_index, arr_len - 100); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert_eq!(argmin_simd_index, arr_len - 100); + assert_eq!(argmax_simd_index, arr_len - 100); + + // Case 5: NaN is somewhere in the middle element + let mut data: Vec = get_array_f64(arr_len); + data[123] = f64::NAN; + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 123); + assert_eq!(argmax_index, 123); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert_eq!(argmin_simd_index, 123); + assert_eq!(argmax_simd_index, 123); + + // Case 6: NaN in the middle of the array and last 100 elements are NaN + for i in 0..100 { + data[arr_len - 1 - i] = f64::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 123); + assert_eq!(argmax_index, 123); + + let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) }; + assert_eq!(argmin_simd_index, 123); + assert_eq!(argmax_simd_index, 123); + + // Case 7: all elements are NaN + for i in 0..data.len() { + data[i] = f64::NAN; + } + + let (argmin_index, argmax_index) = scalar_argminmax(&data); + assert_eq!(argmin_index, 0); + assert_eq!(argmax_index, 0); + + let 
(argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) };
+            assert_eq!(argmin_simd_index, 0);
+            assert_eq!(argmax_simd_index, 0);
+
+            // Case 8: array exact multiple of LANE_SIZE and only 1 element is NaN
+            let mut data: Vec<f64> = get_array_f64(128);
+            data[17] = f64::NAN;
+
+            let (argmin_index, argmax_index) = scalar_argminmax(&data);
+            assert_eq!(argmin_index, 17);
+            assert_eq!(argmax_index, 17);
+
+            let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(&data) };
+            assert_eq!(argmin_simd_index, 17);
+            assert_eq!(argmax_simd_index, 17);
+        }
+
+        #[test]
+        fn test_many_random_runs() {
+            if !is_x86_feature_detected!("avx512f") {
+                return;
+            }
+
+            for _ in 0..10_000 {
+                let data: &[f64] = &get_array_f64(32 * 2 + 1);
+                let (argmin_index, argmax_index) = scalar_argminmax(data);
+                let (argmin_simd_index, argmax_simd_index) = unsafe { AVX512::argminmax(data) };
+                assert_eq!(argmin_index, argmin_simd_index);
+                assert_eq!(argmax_index, argmax_simd_index);
+            }
+        }
+    }
+}
+
+// ---------------------------------------- NEON -----------------------------------------
+
+// There are no NEON intrinsics for f64, so we need to use the scalar version.
+// Although NEON intrinsics exist for i64 and u64, we cannot use them, as there is no
+// 64-bit variant (of any data type) for the following three intrinsics:
+// vadd_, vcgt_, vclt_
+
+#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
+mod neon {
+    use super::super::config::NEON;
+    use super::super::generic::{unimpl_SIMDArgMinMax, unimpl_SIMDOps};
+    use super::*;
+
+    // We need to (un)implement the SIMD trait for the NEON struct, as otherwise the
+    // compiler will complain that the trait is not implemented for the struct -
+    // even though we are not using the trait for the NEON struct when dealing with
+    // 64-bit data types.
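+    // A hedged sketch (our assumption, not the actual macro expansion): an "unimpl"
+    // macro like the two invoked below would typically emit a trait impl whose methods
+    // simply panic, e.g.
+    //
+    //     impl SIMDArgMinMax for NEON {
+    //         unsafe fn argminmax(_data: &[f64]) -> (usize, usize) {
+    //             unimplemented!()
+    //         }
+    //     }
+    //
+    // so that the trait bound exists even though f64 argminmax never dispatches to NEON.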
+ unimpl_SIMDOps!(f64, usize, NEON); + unimpl_SIMDArgMinMax!(f64, usize, NEON); +} diff --git a/src/simd/simd_i16.rs b/src/simd/simd_i16.rs index f6ae197..e5c45ed 100644 --- a/src/simd/simd_i16.rs +++ b/src/simd/simd_i16.rs @@ -1,5 +1,5 @@ use super::config::SIMDInstructionSet; -use super::generic::SIMD; +use super::generic::{SIMDArgMinMax, SIMDOps}; #[cfg(target_arch = "aarch64")] use std::arch::aarch64::*; #[cfg(target_arch = "arm")] @@ -9,6 +9,8 @@ use std::arch::x86::*; #[cfg(target_arch = "x86_64")] use std::arch::x86_64::*; +const MAX_INDEX: usize = i16::MAX as usize; + // ------------------------------------------ AVX2 ------------------------------------------ #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] @@ -18,14 +20,16 @@ mod avx2 { const LANE_SIZE: usize = AVX2::LANE_SIZE_16; - impl SIMD for AVX2 { + impl SIMDOps for AVX2 { const INITIAL_INDEX: __m256i = unsafe { std::mem::transmute([ 0i16, 1i16, 2i16, 3i16, 4i16, 5i16, 6i16, 7i16, 8i16, 9i16, 10i16, 11i16, 12i16, 13i16, 14i16, 15i16, ]) }; - const MAX_INDEX: usize = i16::MAX as usize; + const INDEX_INCREMENT: __m256i = + unsafe { std::mem::transmute([LANE_SIZE as i16; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; #[inline(always)] unsafe fn _reg_to_arr(reg: __m256i) -> [i16; LANE_SIZE] { @@ -37,11 +41,6 @@ mod avx2 { _mm256_loadu_si256(data as *const __m256i) } - #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m256i { - _mm256_set1_epi16(a as i16) - } - #[inline(always)] unsafe fn _mm_add(a: __m256i, b: __m256i) -> __m256i { _mm256_add_epi16(a, b) @@ -62,13 +61,6 @@ mod avx2 { _mm256_blendv_epi8(a, b, mask) } - // ------------------------------------ ARGMINMAX -------------------------------------- - - #[target_feature(enable = "avx2")] - unsafe fn argminmax(data: &[i16]) -> (usize, usize) { - Self::_argminmax(data) - } - #[inline(always)] unsafe fn _horiz_min(index: __m256i, value: __m256i) -> (usize, i16) { // 0. 
Find the minimum value @@ -130,11 +122,18 @@ mod avx2 { } } + impl SIMDArgMinMax for AVX2 { + #[target_feature(enable = "avx2")] + unsafe fn argminmax(data: &[i16]) -> (usize, usize) { + Self::_argminmax(data) + } + } + // ------------------------------------ TESTS -------------------------------------- #[cfg(test)] mod tests { - use super::{AVX2, SIMD}; + use super::{SIMDArgMinMax, AVX2}; use crate::scalar::generic::scalar_argminmax; extern crate dev_utils; @@ -165,17 +164,7 @@ mod avx2 { return; } - let data = [ - 10, - std::i16::MIN, - 6, - 9, - 9, - 22, - std::i16::MAX, - 4, - std::i16::MAX, - ]; + let data = [10, i16::MIN, 6, 9, 9, 22, i16::MAX, 4, i16::MAX]; let data: Vec = data.iter().map(|x| *x).collect(); let data: &[i16] = &data; @@ -229,10 +218,12 @@ mod sse { const LANE_SIZE: usize = SSE::LANE_SIZE_16; - impl SIMD for SSE { + impl SIMDOps for SSE { const INITIAL_INDEX: __m128i = unsafe { std::mem::transmute([0i16, 1i16, 2i16, 3i16, 4i16, 5i16, 6i16, 7i16]) }; - const MAX_INDEX: usize = i16::MAX as usize; + const INDEX_INCREMENT: __m128i = + unsafe { std::mem::transmute([LANE_SIZE as i16; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; #[inline(always)] unsafe fn _reg_to_arr(reg: __m128i) -> [i16; LANE_SIZE] { @@ -244,11 +235,6 @@ mod sse { _mm_loadu_si128(data as *const __m128i) } - #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m128i { - _mm_set1_epi16(a as i16) - } - #[inline(always)] unsafe fn _mm_add(a: __m128i, b: __m128i) -> __m128i { _mm_add_epi16(a, b) @@ -269,13 +255,6 @@ mod sse { _mm_blendv_epi8(a, b, mask) } - // ------------------------------------ ARGMINMAX -------------------------------------- - - #[target_feature(enable = "sse4.1")] - unsafe fn argminmax(data: &[i16]) -> (usize, usize) { - Self::_argminmax(data) - } - #[inline(always)] unsafe fn _horiz_min(index: __m128i, value: __m128i) -> (usize, i16) { // 0. 
Find the minimum value @@ -333,11 +312,18 @@ mod sse { } } + impl SIMDArgMinMax for SSE { + #[target_feature(enable = "sse4.1")] + unsafe fn argminmax(data: &[i16]) -> (usize, usize) { + Self::_argminmax(data) + } + } + // ------------------------------------ TESTS -------------------------------------- #[cfg(test)] mod tests { - use super::{SIMD, SSE}; + use super::{SIMDArgMinMax, SSE}; use crate::scalar::generic::scalar_argminmax; extern crate dev_utils; @@ -360,17 +346,7 @@ mod sse { #[test] fn test_first_index_is_returned_when_identical_values_found() { - let data = [ - 10, - std::i16::MIN, - 6, - 9, - 9, - 22, - std::i16::MAX, - 4, - std::i16::MAX, - ]; + let data = [10, i16::MIN, 6, 9, 9, 22, i16::MAX, 4, i16::MAX]; let data: Vec = data.iter().map(|x| *x).collect(); let data: &[i16] = &data; @@ -416,7 +392,7 @@ mod avx512 { const LANE_SIZE: usize = AVX512::LANE_SIZE_16; - impl SIMD for AVX512 { + impl SIMDOps for AVX512 { const INITIAL_INDEX: __m512i = unsafe { std::mem::transmute([ 0i16, 1i16, 2i16, 3i16, 4i16, 5i16, 6i16, 7i16, 8i16, 9i16, 10i16, 11i16, 12i16, @@ -424,7 +400,9 @@ mod avx512 { 25i16, 26i16, 27i16, 28i16, 29i16, 30i16, 31i16, ]) }; - const MAX_INDEX: usize = i16::MAX as usize; + const INDEX_INCREMENT: __m512i = + unsafe { std::mem::transmute([LANE_SIZE as i16; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; #[inline(always)] unsafe fn _reg_to_arr(reg: __m512i) -> [i16; LANE_SIZE] { @@ -436,11 +414,6 @@ mod avx512 { _mm512_loadu_epi16(data as *const i16) } - #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m512i { - _mm512_set1_epi16(a as i16) - } - #[inline(always)] unsafe fn _mm_add(a: __m512i, b: __m512i) -> __m512i { _mm512_add_epi16(a, b) @@ -461,13 +434,6 @@ mod avx512 { _mm512_mask_blend_epi16(mask, a, b) } - // ------------------------------------ ARGMINMAX -------------------------------------- - - #[target_feature(enable = "avx512bw")] - unsafe fn argminmax(data: &[i16]) -> (usize, usize) { - Self::_argminmax(data) - } - #[inline(always)] unsafe fn _horiz_min(index: __m512i, value: __m512i) -> (usize, i16) { // 0. 
Find the minimum value @@ -533,11 +499,18 @@ mod avx512 { } } + impl SIMDArgMinMax for AVX512 { + #[target_feature(enable = "avx512bw")] + unsafe fn argminmax(data: &[i16]) -> (usize, usize) { + Self::_argminmax(data) + } + } + // ------------------------------------ TESTS -------------------------------------- #[cfg(test)] mod tests { - use super::{AVX512, SIMD}; + use super::{SIMDArgMinMax, AVX512}; use crate::scalar::generic::scalar_argminmax; extern crate dev_utils; @@ -568,17 +541,7 @@ mod avx512 { return; } - let data = [ - 10, - std::i16::MIN, - 6, - 9, - 9, - 22, - std::i16::MAX, - 4, - std::i16::MAX, - ]; + let data = [10, i16::MIN, 6, 9, 9, 22, i16::MAX, 4, i16::MAX]; let data: Vec = data.iter().map(|x| *x).collect(); let data: &[i16] = &data; @@ -632,10 +595,12 @@ mod neon { const LANE_SIZE: usize = NEON::LANE_SIZE_16; - impl SIMD for NEON { + impl SIMDOps for NEON { const INITIAL_INDEX: int16x8_t = unsafe { std::mem::transmute([0i16, 1i16, 2i16, 3i16, 4i16, 5i16, 6i16, 7i16]) }; - const MAX_INDEX: usize = i16::MAX as usize; + const INDEX_INCREMENT: int16x8_t = + unsafe { std::mem::transmute([LANE_SIZE as i16; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; #[inline(always)] unsafe fn _reg_to_arr(reg: int16x8_t) -> [i16; LANE_SIZE] { @@ -647,11 +612,6 @@ mod neon { vld1q_s16(data as *const i16) } - #[inline(always)] - unsafe fn _mm_set1(a: usize) -> int16x8_t { - vdupq_n_s16(a as i16) - } - #[inline(always)] unsafe fn _mm_add(a: int16x8_t, b: int16x8_t) -> int16x8_t { vaddq_s16(a, b) @@ -672,13 +632,6 @@ mod neon { vbslq_s16(mask, b, a) } - // ------------------------------------ ARGMINMAX -------------------------------------- - - #[target_feature(enable = "neon")] - unsafe fn argminmax(data: &[i16]) -> (usize, usize) { - Self::_argminmax(data) - } - #[inline(always)] unsafe fn _horiz_min(index: int16x8_t, value: int16x8_t) -> (usize, i16) { // 0. 
Find the minimum value @@ -736,11 +689,18 @@ mod neon { } } + impl SIMDArgMinMax for NEON { + #[target_feature(enable = "neon")] + unsafe fn argminmax(data: &[i16]) -> (usize, usize) { + Self::_argminmax(data) + } + } + // ------------------------------------ TESTS -------------------------------------- #[cfg(test)] mod tests { - use super::{NEON, SIMD}; + use super::{SIMDArgMinMax, NEON}; use crate::scalar::generic::scalar_argminmax; extern crate dev_utils; @@ -763,17 +723,7 @@ mod neon { #[test] fn test_first_index_is_returned_when_identical_values_found() { - let data = [ - 10, - std::i16::MIN, - 6, - 9, - 9, - 22, - std::i16::MAX, - 4, - std::i16::MAX, - ]; + let data = [10, i16::MIN, 6, 9, 9, 22, i16::MAX, 4, i16::MAX]; let data: Vec = data.iter().map(|x| *x).collect(); let data: &[i16] = &data; diff --git a/src/simd/simd_i32.rs b/src/simd/simd_i32.rs index 6ec4512..812e776 100644 --- a/src/simd/simd_i32.rs +++ b/src/simd/simd_i32.rs @@ -1,5 +1,5 @@ use super::config::SIMDInstructionSet; -use super::generic::SIMD; +use super::generic::{SIMDArgMinMax, SIMDOps}; #[cfg(target_arch = "aarch64")] use std::arch::aarch64::*; #[cfg(target_arch = "arm")] @@ -9,6 +9,8 @@ use std::arch::x86::*; #[cfg(target_arch = "x86_64")] use std::arch::x86_64::*; +const MAX_INDEX: usize = i32::MAX as usize; + // ------------------------------------------ AVX2 ------------------------------------------ #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] @@ -18,10 +20,12 @@ mod avx2 { const LANE_SIZE: usize = AVX2::LANE_SIZE_32; - impl SIMD for AVX2 { + impl SIMDOps for AVX2 { const INITIAL_INDEX: __m256i = unsafe { std::mem::transmute([0i32, 1i32, 2i32, 3i32, 4i32, 5i32, 6i32, 7i32]) }; - const MAX_INDEX: usize = i32::MAX as usize; + const INDEX_INCREMENT: __m256i = + unsafe { std::mem::transmute([LANE_SIZE as i32; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; #[inline(always)] unsafe fn _reg_to_arr(reg: __m256i) -> [i32; LANE_SIZE] { @@ -33,11 +37,6 @@ mod avx2 { _mm256_loadu_si256(data as *const __m256i) } - #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m256i { - _mm256_set1_epi32(a as i32) - } - #[inline(always)] unsafe fn _mm_add(a: __m256i, b: __m256i) -> __m256i { _mm256_add_epi32(a, b) @@ -57,9 +56,9 @@ mod avx2 { unsafe fn _mm_blendv(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { _mm256_blendv_epi8(a, b, mask) } + } - // ------------------------------------ ARGMINMAX -------------------------------------- - + impl SIMDArgMinMax for AVX2 { #[target_feature(enable = "avx2")] unsafe fn argminmax(data: &[i32]) -> (usize, usize) { Self::_argminmax(data) @@ -70,7 +69,7 @@ mod avx2 { #[cfg(test)] mod tests { - use super::{AVX2, SIMD}; + use super::{SIMDArgMinMax, AVX2}; use crate::scalar::generic::scalar_argminmax; extern crate dev_utils; @@ -101,16 +100,7 @@ mod avx2 { return; } - let data = [ - std::i32::MIN, - std::i32::MIN, - 4, - 6, - 9, - std::i32::MAX, - 22, - std::i32::MAX, - ]; + let data = [i32::MIN, i32::MIN, 4, 6, 9, i32::MAX, 22, i32::MAX]; let data: Vec = data.iter().map(|x| *x).collect(); let data: &[i32] = &data; @@ -149,9 +139,11 @@ mod sse { const LANE_SIZE: usize = SSE::LANE_SIZE_32; - impl SIMD for SSE { + impl SIMDOps for SSE { const INITIAL_INDEX: __m128i = unsafe { std::mem::transmute([0i32, 1i32, 2i32, 3i32]) }; - const MAX_INDEX: usize = i32::MAX as usize; + const INDEX_INCREMENT: __m128i = + unsafe { std::mem::transmute([LANE_SIZE as i32; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; #[inline(always)] unsafe fn _reg_to_arr(reg: __m128i) -> [i32; 
LANE_SIZE] { @@ -163,11 +155,6 @@ mod sse { _mm_loadu_si128(data as *const __m128i) } - #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m128i { - _mm_set1_epi32(a as i32) - } - #[inline(always)] unsafe fn _mm_add(a: __m128i, b: __m128i) -> __m128i { _mm_add_epi32(a, b) @@ -187,9 +174,9 @@ mod sse { unsafe fn _mm_blendv(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { _mm_blendv_epi8(a, b, mask) } + } - // ------------------------------------ ARGMINMAX -------------------------------------- - + impl SIMDArgMinMax for SSE { #[target_feature(enable = "sse4.1")] unsafe fn argminmax(data: &[i32]) -> (usize, usize) { Self::_argminmax(data) @@ -200,7 +187,7 @@ mod sse { #[cfg(test)] mod tests { - use super::{SIMD, SSE}; + use super::{SIMDArgMinMax, SSE}; use crate::scalar::generic::scalar_argminmax; extern crate dev_utils; @@ -223,16 +210,7 @@ mod sse { #[test] fn test_first_index_is_returned_when_identical_values_found() { - let data = [ - std::i32::MIN, - std::i32::MIN, - 4, - 6, - 9, - std::i32::MAX, - 22, - std::i32::MAX, - ]; + let data = [i32::MIN, i32::MIN, 4, 6, 9, i32::MAX, 22, i32::MAX]; let data: Vec = data.iter().map(|x| *x).collect(); let data: &[i32] = &data; @@ -267,14 +245,16 @@ mod avx512 { const LANE_SIZE: usize = AVX512::LANE_SIZE_32; - impl SIMD for AVX512 { + impl SIMDOps for AVX512 { const INITIAL_INDEX: __m512i = unsafe { std::mem::transmute([ 0i32, 1i32, 2i32, 3i32, 4i32, 5i32, 6i32, 7i32, 8i32, 9i32, 10i32, 11i32, 12i32, 13i32, 14i32, 15i32, ]) }; - const MAX_INDEX: usize = i32::MAX as usize; + const INDEX_INCREMENT: __m512i = + unsafe { std::mem::transmute([LANE_SIZE as i32; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; #[inline(always)] unsafe fn _reg_to_arr(reg: __m512i) -> [i32; LANE_SIZE] { @@ -286,11 +266,6 @@ mod avx512 { _mm512_loadu_si512(data as *const i32) } - #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m512i { - _mm512_set1_epi32(a as i32) - } - #[inline(always)] unsafe fn _mm_add(a: __m512i, b: __m512i) -> __m512i { _mm512_add_epi32(a, b) @@ -310,9 +285,9 @@ mod avx512 { unsafe fn _mm_blendv(a: __m512i, b: __m512i, mask: u16) -> __m512i { _mm512_mask_blend_epi32(mask, a, b) } + } - // ------------------------------------ ARGMINMAX -------------------------------------- - + impl SIMDArgMinMax for AVX512 { #[target_feature(enable = "avx512f")] unsafe fn argminmax(data: &[i32]) -> (usize, usize) { Self::_argminmax(data) @@ -323,7 +298,7 @@ mod avx512 { #[cfg(test)] mod tests { - use super::{AVX512, SIMD}; + use super::{SIMDArgMinMax, AVX512}; use crate::scalar::generic::scalar_argminmax; extern crate dev_utils; @@ -354,16 +329,7 @@ mod avx512 { return; } - let data = [ - std::i32::MIN, - std::i32::MIN, - 4, - 6, - 9, - std::i32::MAX, - 22, - std::i32::MAX, - ]; + let data = [i32::MIN, i32::MIN, 4, 6, 9, i32::MAX, 22, i32::MAX]; let data: Vec = data.iter().map(|x| *x).collect(); let data: &[i32] = &data; @@ -402,9 +368,11 @@ mod neon { const LANE_SIZE: usize = NEON::LANE_SIZE_32; - impl SIMD for NEON { + impl SIMDOps for NEON { const INITIAL_INDEX: int32x4_t = unsafe { std::mem::transmute([0i32, 1i32, 2i32, 3i32]) }; - const MAX_INDEX: usize = i32::MAX as usize; + const INDEX_INCREMENT: int32x4_t = + unsafe { std::mem::transmute([LANE_SIZE as i32; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; #[inline(always)] unsafe fn _reg_to_arr(reg: int32x4_t) -> [i32; LANE_SIZE] { @@ -416,11 +384,6 @@ mod neon { vld1q_s32(data) } - #[inline(always)] - unsafe fn _mm_set1(a: usize) -> int32x4_t { - vdupq_n_s32(a as i32) - } - #[inline(always)] 
unsafe fn _mm_add(a: int32x4_t, b: int32x4_t) -> int32x4_t { vaddq_s32(a, b) @@ -440,9 +403,9 @@ mod neon { unsafe fn _mm_blendv(a: int32x4_t, b: int32x4_t, mask: uint32x4_t) -> int32x4_t { vbslq_s32(mask, b, a) } + } - // ------------------------------------ ARGMINMAX -------------------------------------- - + impl SIMDArgMinMax for NEON { #[target_feature(enable = "neon")] unsafe fn argminmax(data: &[i32]) -> (usize, usize) { Self::_argminmax(data) @@ -453,7 +416,7 @@ mod neon { #[cfg(test)] mod tests { - use super::{NEON, SIMD}; + use super::{SIMDArgMinMax, NEON}; use crate::scalar::generic::scalar_argminmax; extern crate dev_utils; @@ -476,16 +439,7 @@ mod neon { #[test] fn test_first_index_is_returned_when_identical_values_found() { - let data = [ - std::i32::MIN, - std::i32::MIN, - 4, - 6, - 9, - std::i32::MAX, - 22, - std::i32::MAX, - ]; + let data = [i32::MIN, i32::MIN, 4, 6, 9, i32::MAX, 22, i32::MAX]; let data: Vec = data.iter().map(|x| *x).collect(); let data: &[i32] = &data; diff --git a/src/simd/simd_i64.rs b/src/simd/simd_i64.rs index ff563a2..c47a708 100644 --- a/src/simd/simd_i64.rs +++ b/src/simd/simd_i64.rs @@ -1,11 +1,14 @@ #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] use super::config::SIMDInstructionSet; -use super::generic::SIMD; +use super::generic::{SIMDArgMinMax, SIMDOps}; #[cfg(target_arch = "x86")] use std::arch::x86::*; #[cfg(target_arch = "x86_64")] use std::arch::x86_64::*; +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +const MAX_INDEX: usize = i64::MAX as usize; + // ------------------------------------------ AVX2 ------------------------------------------ #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] @@ -15,9 +18,11 @@ mod avx2 { const LANE_SIZE: usize = AVX2::LANE_SIZE_64; - impl SIMD for AVX2 { + impl SIMDOps for AVX2 { const INITIAL_INDEX: __m256i = unsafe { std::mem::transmute([0i64, 1i64, 2i64, 3i64]) }; - const MAX_INDEX: usize = i64::MAX as usize; + const INDEX_INCREMENT: __m256i = + unsafe { std::mem::transmute([LANE_SIZE as i64; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; #[inline(always)] unsafe fn _reg_to_arr(reg: __m256i) -> [i64; LANE_SIZE] { @@ -29,11 +34,6 @@ mod avx2 { _mm256_loadu_si256(data as *const __m256i) } - #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m256i { - _mm256_set1_epi64x(a as i64) - } - #[inline(always)] unsafe fn _mm_add(a: __m256i, b: __m256i) -> __m256i { _mm256_add_epi64(a, b) @@ -53,9 +53,9 @@ mod avx2 { unsafe fn _mm_blendv(a: __m256i, b: __m256i, mask: __m256i) -> __m256i { _mm256_blendv_epi8(a, b, mask) } + } - // ------------------------------------ ARGMINMAX -------------------------------------- - + impl SIMDArgMinMax for AVX2 { #[target_feature(enable = "avx2")] unsafe fn argminmax(data: &[i64]) -> (usize, usize) { Self::_argminmax(data) @@ -66,7 +66,7 @@ mod avx2 { #[cfg(test)] mod tests { - use super::{AVX2, SIMD}; + use super::{SIMDArgMinMax, AVX2}; use crate::scalar::generic::scalar_argminmax; extern crate dev_utils; @@ -97,16 +97,7 @@ mod avx2 { return; } - let data = [ - std::i64::MIN, - std::i64::MIN, - 4, - 6, - 9, - std::i64::MAX, - 22, - std::i64::MAX, - ]; + let data = [i64::MIN, i64::MIN, 4, 6, 9, i64::MAX, 22, i64::MAX]; let data: Vec = data.iter().map(|x| *x).collect(); let data: &[i64] = &data; @@ -145,9 +136,11 @@ mod sse { const LANE_SIZE: usize = SSE::LANE_SIZE_64; - impl SIMD for SSE { + impl SIMDOps for SSE { const INITIAL_INDEX: __m128i = unsafe { std::mem::transmute([0i64, 1i64]) }; - const MAX_INDEX: usize = i64::MAX as usize; + const 
INDEX_INCREMENT: __m128i = + unsafe { std::mem::transmute([LANE_SIZE as i64; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; #[inline(always)] unsafe fn _reg_to_arr(reg: __m128i) -> [i64; LANE_SIZE] { @@ -159,11 +152,6 @@ mod sse { _mm_loadu_si128(data as *const __m128i) } - #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m128i { - _mm_set1_epi64x(a as i64) - } - #[inline(always)] unsafe fn _mm_add(a: __m128i, b: __m128i) -> __m128i { _mm_add_epi64(a, b) @@ -183,9 +171,9 @@ mod sse { unsafe fn _mm_blendv(a: __m128i, b: __m128i, mask: __m128i) -> __m128i { _mm_blendv_epi8(a, b, mask) } + } - // ------------------------------------ ARGMINMAX -------------------------------------- - + impl SIMDArgMinMax for SSE { #[target_feature(enable = "sse4.2")] unsafe fn argminmax(data: &[i64]) -> (usize, usize) { Self::_argminmax(data) @@ -196,7 +184,7 @@ mod sse { #[cfg(test)] mod tests { - use super::{SIMD, SSE}; + use super::{SIMDArgMinMax, SSE}; use crate::scalar::generic::scalar_argminmax; extern crate dev_utils; @@ -219,16 +207,7 @@ mod sse { #[test] fn test_first_index_is_returned_when_identical_values_found() { - let data = [ - std::i64::MIN, - std::i64::MIN, - 4, - 6, - 9, - std::i64::MAX, - 22, - std::i64::MAX, - ]; + let data = [i64::MIN, i64::MIN, 4, 6, 9, i64::MAX, 22, i64::MAX]; let data: Vec = data.iter().map(|x| *x).collect(); let data: &[i64] = &data; @@ -263,10 +242,12 @@ mod avx512 { const LANE_SIZE: usize = AVX512::LANE_SIZE_64; - impl SIMD for AVX512 { + impl SIMDOps for AVX512 { const INITIAL_INDEX: __m512i = unsafe { std::mem::transmute([0i64, 1i64, 2i64, 3i64, 4i64, 5i64, 6i64, 7i64]) }; - const MAX_INDEX: usize = i64::MAX as usize; + const INDEX_INCREMENT: __m512i = + unsafe { std::mem::transmute([LANE_SIZE as i64; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; #[inline(always)] unsafe fn _reg_to_arr(reg: __m512i) -> [i64; LANE_SIZE] { @@ -278,11 +259,6 @@ mod avx512 { _mm512_loadu_epi64(data as *const i64) } - #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m512i { - _mm512_set1_epi64(a as i64) - } - #[inline(always)] unsafe fn _mm_add(a: __m512i, b: __m512i) -> __m512i { _mm512_add_epi64(a, b) @@ -302,9 +278,9 @@ mod avx512 { unsafe fn _mm_blendv(a: __m512i, b: __m512i, mask: u8) -> __m512i { _mm512_mask_blend_epi64(mask, a, b) } + } - // ------------------------------------ ARGMINMAX -------------------------------------- - + impl SIMDArgMinMax for AVX512 { #[target_feature(enable = "avx512f")] unsafe fn argminmax(data: &[i64]) -> (usize, usize) { Self::_argminmax(data) @@ -315,7 +291,7 @@ mod avx512 { #[cfg(test)] mod tests { - use super::{AVX512, SIMD}; + use super::{SIMDArgMinMax, AVX512}; use crate::scalar::generic::scalar_argminmax; extern crate dev_utils; @@ -346,16 +322,7 @@ mod avx512 { return; } - let data = [ - std::i64::MIN, - std::i64::MIN, - 4, - 6, - 9, - std::i64::MAX, - 22, - std::i64::MAX, - ]; + let data = [i64::MIN, i64::MIN, 4, 6, 9, i64::MAX, 22, i64::MAX]; let data: Vec = data.iter().map(|x| *x).collect(); let data: &[i64] = &data; @@ -395,12 +362,13 @@ mod avx512 { #[cfg(any(target_arch = "arm", target_arch = "aarch64"))] mod neon { use super::super::config::NEON; - use super::super::generic::unimplement_simd; + use super::super::generic::{unimpl_SIMDArgMinMax, unimpl_SIMDOps}; use super::*; // We need to (un)implement the SIMD trait for the NEON struct as otherwise the // compiler will complain that the trait is not implemented for the struct - // even though we are not using the trait for the NEON struct when dealing with // > 
64 bit data types. - unimplement_simd!(i64, usize, NEON); + unimpl_SIMDOps!(i64, usize, NEON); + unimpl_SIMDArgMinMax!(i64, usize, NEON); } diff --git a/src/simd/simd_i8.rs b/src/simd/simd_i8.rs index d6baac8..ae2d1f5 100644 --- a/src/simd/simd_i8.rs +++ b/src/simd/simd_i8.rs @@ -1,5 +1,5 @@ use super::config::SIMDInstructionSet; -use super::generic::SIMD; +use super::generic::{SIMDArgMinMax, SIMDOps}; #[cfg(target_arch = "aarch64")] use std::arch::aarch64::*; #[cfg(target_arch = "arm")] @@ -9,6 +9,8 @@ use std::arch::x86::*; #[cfg(target_arch = "x86_64")] use std::arch::x86_64::*; +const MAX_INDEX: usize = i8::MAX as usize; + // ------------------------------------------ AVX2 ------------------------------------------ #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] @@ -18,7 +20,7 @@ mod avx2 { const LANE_SIZE: usize = AVX2::LANE_SIZE_8; - impl SIMD for AVX2 { + impl SIMDOps for AVX2 { const INITIAL_INDEX: __m256i = unsafe { std::mem::transmute([ 0i8, 1i8, 2i8, 3i8, 4i8, 5i8, 6i8, 7i8, 8i8, 9i8, 10i8, 11i8, 12i8, 13i8, 14i8, @@ -26,7 +28,9 @@ mod avx2 { 29i8, 30i8, 31i8, ]) }; - const MAX_INDEX: usize = i8::MAX as usize; + const INDEX_INCREMENT: __m256i = + unsafe { std::mem::transmute([LANE_SIZE as i8; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; #[inline(always)] unsafe fn _reg_to_arr(reg: __m256i) -> [i8; LANE_SIZE] { @@ -38,11 +42,6 @@ mod avx2 { _mm256_loadu_si256(data as *const __m256i) } - #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m256i { - _mm256_set1_epi8(a as i8) - } - #[inline(always)] unsafe fn _mm_add(a: __m256i, b: __m256i) -> __m256i { _mm256_add_epi8(a, b) @@ -63,13 +62,6 @@ mod avx2 { _mm256_blendv_epi8(a, b, mask) } - // ------------------------------------ ARGMINMAX -------------------------------------- - - #[target_feature(enable = "avx2")] - unsafe fn argminmax(data: &[i8]) -> (usize, usize) { - Self::_argminmax(data) - } - #[inline(always)] unsafe fn _horiz_min(index: __m256i, value: __m256i) -> (usize, i8) { // 0. 
Find the minimum value @@ -135,11 +127,18 @@ mod avx2 { } } + impl SIMDArgMinMax for AVX2 { + #[target_feature(enable = "avx2")] + unsafe fn argminmax(data: &[i8]) -> (usize, usize) { + Self::_argminmax(data) + } + } + // ------------------------------------ TESTS -------------------------------------- #[cfg(test)] mod tests { - use super::{AVX2, SIMD}; + use super::{SIMDArgMinMax, AVX2}; use crate::scalar::generic::scalar_argminmax; extern crate dev_utils; @@ -170,7 +169,7 @@ mod avx2 { return; } - let data = [10, std::i8::MIN, 6, 9, 9, 22, std::i8::MAX, 4, std::i8::MAX]; + let data = [10, i8::MIN, 6, 9, 9, 22, i8::MAX, 4, i8::MAX]; let data: Vec = data.iter().map(|x| *x).collect(); let data: &[i8] = &data; @@ -224,14 +223,16 @@ mod sse { const LANE_SIZE: usize = SSE::LANE_SIZE_8; - impl SIMD for SSE { + impl SIMDOps for SSE { const INITIAL_INDEX: __m128i = unsafe { std::mem::transmute([ 0i8, 1i8, 2i8, 3i8, 4i8, 5i8, 6i8, 7i8, 8i8, 9i8, 10i8, 11i8, 12i8, 13i8, 14i8, 15i8, ]) }; - const MAX_INDEX: usize = i8::MAX as usize; + const INDEX_INCREMENT: __m128i = + unsafe { std::mem::transmute([LANE_SIZE as i8; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; #[inline(always)] unsafe fn _reg_to_arr(reg: __m128i) -> [i8; LANE_SIZE] { @@ -243,11 +244,6 @@ mod sse { _mm_loadu_si128(data as *const __m128i) } - #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m128i { - _mm_set1_epi8(a as i8) - } - #[inline(always)] unsafe fn _mm_add(a: __m128i, b: __m128i) -> __m128i { _mm_add_epi8(a, b) @@ -268,13 +264,6 @@ mod sse { _mm_blendv_epi8(a, b, mask) } - // ------------------------------------ ARGMINMAX -------------------------------------- - - #[target_feature(enable = "sse4.1")] - unsafe fn argminmax(data: &[i8]) -> (usize, usize) { - Self::_argminmax(data) - } - #[inline(always)] unsafe fn _horiz_min(index: __m128i, value: __m128i) -> (usize, i8) { // 0. 
Find the minimum value @@ -336,11 +325,18 @@ mod sse { } } + impl SIMDArgMinMax for SSE { + #[target_feature(enable = "sse4.1")] + unsafe fn argminmax(data: &[i8]) -> (usize, usize) { + Self::_argminmax(data) + } + } + // ------------------------------------ TESTS -------------------------------------- #[cfg(test)] mod tests { - use super::{SIMD, SSE}; + use super::{SIMDArgMinMax, SSE}; use crate::scalar::generic::scalar_argminmax; extern crate dev_utils; @@ -363,7 +359,7 @@ mod sse { #[test] fn test_first_index_is_returned_when_identical_values_found() { - let data = [10, std::i8::MIN, 6, 9, 9, 22, std::i8::MAX, 4, std::i8::MAX]; + let data = [10, i8::MIN, 6, 9, 9, 22, i8::MAX, 4, i8::MAX]; let data: Vec = data.iter().map(|x| *x).collect(); let data: &[i8] = &data; @@ -409,7 +405,7 @@ mod avx512 { const LANE_SIZE: usize = AVX512::LANE_SIZE_8; - impl SIMD for AVX512 { + impl SIMDOps for AVX512 { const INITIAL_INDEX: __m512i = unsafe { std::mem::transmute([ 0i8, 1i8, 2i8, 3i8, 4i8, 5i8, 6i8, 7i8, 8i8, 9i8, 10i8, 11i8, 12i8, 13i8, 14i8, @@ -419,7 +415,9 @@ mod avx512 { 57i8, 58i8, 59i8, 60i8, 61i8, 62i8, 63i8, ]) }; - const MAX_INDEX: usize = i8::MAX as usize; + const INDEX_INCREMENT: __m512i = + unsafe { std::mem::transmute([LANE_SIZE as i8; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; #[inline(always)] unsafe fn _reg_to_arr(reg: __m512i) -> [i8; LANE_SIZE] { @@ -431,11 +429,6 @@ mod avx512 { _mm512_loadu_epi8(data as *const i8) } - #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m512i { - _mm512_set1_epi8(a as i8) - } - #[inline(always)] unsafe fn _mm_add(a: __m512i, b: __m512i) -> __m512i { _mm512_add_epi8(a, b) @@ -456,13 +449,6 @@ mod avx512 { _mm512_mask_blend_epi8(mask, a, b) } - // ------------------------------------ ARGMINMAX -------------------------------------- - - #[target_feature(enable = "avx512bw")] - unsafe fn argminmax(data: &[i8]) -> (usize, usize) { - Self::_argminmax(data) - } - #[inline(always)] unsafe fn _horiz_min(index: __m512i, value: __m512i) -> (usize, i8) { // 0. 
Find the minimum value @@ -532,11 +518,18 @@ mod avx512 { } } + impl SIMDArgMinMax for AVX512 { + #[target_feature(enable = "avx512bw")] + unsafe fn argminmax(data: &[i8]) -> (usize, usize) { + Self::_argminmax(data) + } + } + // ------------------------------------ TESTS -------------------------------------- #[cfg(test)] mod tests { - use super::{AVX512, SIMD}; + use super::{SIMDArgMinMax, AVX512}; use crate::scalar::generic::scalar_argminmax; extern crate dev_utils; @@ -567,7 +560,7 @@ mod avx512 { return; } - let data = [10, std::i8::MIN, 6, 9, 9, 22, std::i8::MAX, 4, std::i8::MAX]; + let data = [10, i8::MIN, 6, 9, 9, 22, i8::MAX, 4, i8::MAX]; let data: Vec = data.iter().map(|x| *x).collect(); let data: &[i8] = &data; @@ -621,14 +614,16 @@ mod neon { const LANE_SIZE: usize = NEON::LANE_SIZE_8; - impl SIMD for NEON { + impl SIMDOps for NEON { const INITIAL_INDEX: int8x16_t = unsafe { std::mem::transmute([ 0i8, 1i8, 2i8, 3i8, 4i8, 5i8, 6i8, 7i8, 8i8, 9i8, 10i8, 11i8, 12i8, 13i8, 14i8, 15i8, ]) }; - const MAX_INDEX: usize = i8::MAX as usize; + const INDEX_INCREMENT: int8x16_t = + unsafe { std::mem::transmute([LANE_SIZE as i8; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; #[inline(always)] unsafe fn _reg_to_arr(reg: int8x16_t) -> [i8; LANE_SIZE] { @@ -641,11 +636,6 @@ mod neon { vld1q_s8(data as *const i8) } - #[inline(always)] - unsafe fn _mm_set1(a: usize) -> int8x16_t { - vdupq_n_s8(a as i8) - } - #[inline(always)] unsafe fn _mm_add(a: int8x16_t, b: int8x16_t) -> int8x16_t { vaddq_s8(a, b) @@ -666,13 +656,6 @@ mod neon { vbslq_s8(mask, b, a) } - // ------------------------------------ ARGMINMAX -------------------------------------- - - #[target_feature(enable = "neon")] - unsafe fn argminmax(data: &[i8]) -> (usize, usize) { - Self::_argminmax(data) - } - #[inline(always)] unsafe fn _horiz_min(index: int8x16_t, value: int8x16_t) -> (usize, i8) { // 0. 
Find the minimum value @@ -734,11 +717,18 @@ mod neon { } } + impl SIMDArgMinMax for NEON { + #[target_feature(enable = "neon")] + unsafe fn argminmax(data: &[i8]) -> (usize, usize) { + Self::_argminmax(data) + } + } + // ------------------------------------ TESTS -------------------------------------- #[cfg(test)] mod tests { - use super::{NEON, SIMD}; + use super::{SIMDArgMinMax, NEON}; use crate::scalar::generic::scalar_argminmax; extern crate dev_utils; @@ -761,7 +751,7 @@ mod neon { #[test] fn test_first_index_is_returned_when_identical_values_found() { - let data = [10, std::i8::MIN, 6, 9, 9, 22, std::i8::MAX, 4, std::i8::MAX]; + let data = [10, i8::MIN, 6, 9, 9, 22, i8::MAX, 4, i8::MAX]; let data: Vec = data.iter().map(|x| *x).collect(); let data: &[i8] = &data; diff --git a/src/simd/simd_u16.rs b/src/simd/simd_u16.rs index 3df6c2d..98dd416 100644 --- a/src/simd/simd_u16.rs +++ b/src/simd/simd_u16.rs @@ -1,5 +1,5 @@ use super::config::SIMDInstructionSet; -use super::generic::SIMD; +use super::generic::{SIMDArgMinMax, SIMDOps}; #[cfg(target_arch = "aarch64")] use std::arch::aarch64::*; #[cfg(target_arch = "arm")] @@ -10,15 +10,17 @@ use std::arch::x86::*; use std::arch::x86_64::*; #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -const XOR_VALUE: i16 = 0x7FFF; +const XOR_VALUE: i16 = -0x8000; // i16::MIN #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] #[inline(always)] -fn _i16decrord_to_u16(decrord_i16: i16) -> u16 { - // let v = ord_i16 ^ 0x7FFF; - unsafe { std::mem::transmute::(decrord_i16 ^ XOR_VALUE) } +fn _i16ord_to_u16(ord_i16: i16) -> u16 { + // let v = ord_i16 ^ -0x8000; + unsafe { std::mem::transmute::(ord_i16 ^ XOR_VALUE) } } +const MAX_INDEX: usize = i16::MAX as usize; + // ------------------------------------------ AVX2 ------------------------------------------ #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] @@ -30,10 +32,10 @@ mod avx2 { const XOR_MASK: __m256i = unsafe { std::mem::transmute([XOR_VALUE; LANE_SIZE]) }; #[inline(always)] - unsafe fn _u16_to_i16decrord(u16: __m256i) -> __m256i { - // on a scalar: v^ 0x7FFF - // transforms to monotonically **decreasing** order - _mm256_xor_si256(u16, XOR_MASK) // Only 1 operation + unsafe fn _u16_as_m256i_to_i16ord(u16_as_m256i: __m256i) -> __m256i { + // on a scalar: v ^ -0x8000 + // transforms to monotonically increasing order + _mm256_xor_si256(u16_as_m256i, XOR_MASK) } #[inline(always)] @@ -41,29 +43,28 @@ mod avx2 { std::mem::transmute::<__m256i, [i16; LANE_SIZE]>(reg) } - impl SIMD for AVX2 { + impl SIMDOps for AVX2 { const INITIAL_INDEX: __m256i = unsafe { std::mem::transmute([ 0i16, 1i16, 2i16, 3i16, 4i16, 5i16, 6i16, 7i16, 8i16, 9i16, 10i16, 11i16, 12i16, 13i16, 14i16, 15i16, ]) }; - const MAX_INDEX: usize = i16::MAX as usize; + const INDEX_INCREMENT: __m256i = + unsafe { std::mem::transmute([LANE_SIZE as i16; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; #[inline(always)] unsafe fn _reg_to_arr(_: __m256i) -> [u16; LANE_SIZE] { - // Not used because we work with i16ord and override _get_min_index_value and _get_max_index_value + // Not implemented because we will perform the horizontal operations on the + // signed integer values instead of trying to retransform **only** the values + // (and thus not the indices) to signed integers. 
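The u16 hunks that follow replace the old 0x7FFF mask with XOR_VALUE = -0x8000 (the i16 sign bit): flipping only the sign bit of a u16 and reinterpreting the bits as i16 keeps the values in the same increasing order, and applying the same XOR again (_i16ord_to_u16) maps the winning value back. Below is a minimal standalone sketch of that scalar property; it is not the crate's code, and it uses the unsigned literal 0x8000 instead of the -0x8000 i16 constant (same bit pattern).

fn u16_to_i16ord(v: u16) -> i16 {
    // flip only the sign bit; bit pattern identical to (v as i16) ^ -0x8000
    (v ^ 0x8000) as i16
}

fn i16ord_to_u16(v: i16) -> u16 {
    // the same XOR undoes the transform
    (v as u16) ^ 0x8000
}

fn main() {
    let samples: [u16; 5] = [0, 1, 0x7FFF, 0x8000, u16::MAX];
    // order is preserved ...
    for w in samples.windows(2) {
        assert!(u16_to_i16ord(w[0]) < u16_to_i16ord(w[1]));
    }
    // ... and the mapping round-trips
    for &v in &samples {
        assert_eq!(i16ord_to_u16(u16_to_i16ord(v)), v);
    }
}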
unimplemented!() } #[inline(always)] unsafe fn _mm_loadu(data: *const u16) -> __m256i { - _u16_to_i16decrord(_mm256_loadu_si256(data as *const __m256i)) - } - - #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m256i { - _mm256_set1_epi16(a as i16) + _u16_as_m256i_to_i16ord(_mm256_loadu_si256(data as *const __m256i)) } #[inline(always)] @@ -86,13 +87,6 @@ mod avx2 { _mm256_blendv_epi8(a, b, mask) } - // ------------------------------------ ARGMINMAX -------------------------------------- - - #[target_feature(enable = "avx2")] - unsafe fn argminmax(data: &[u16]) -> (usize, usize) { - Self::_argminmax(data) - } - #[inline(always)] unsafe fn _horiz_min(index: __m256i, value: __m256i) -> (usize, u16) { // 0. Find the minimum value @@ -120,7 +114,7 @@ mod avx2 { imin = _mm256_min_epi16(imin, _mm256_alignr_epi8(imin, imin, 2)); let min_index: usize = _mm256_extract_epi16(imin, 0) as usize; - (min_index, _i16decrord_to_u16(min_value)) + (min_index, _i16ord_to_u16(min_value)) } #[inline(always)] @@ -150,20 +144,14 @@ mod avx2 { imin = _mm256_min_epi16(imin, _mm256_alignr_epi8(imin, imin, 2)); let max_index: usize = _mm256_extract_epi16(imin, 0) as usize; - (max_index, _i16decrord_to_u16(max_value)) + (max_index, _i16ord_to_u16(max_value)) } + } - #[inline(always)] - unsafe fn _get_min_max_index_value( - index_low: __m256i, - values_low: __m256i, - index_high: __m256i, - values_high: __m256i, - ) -> (usize, u16, usize, u16) { - let (min_index, min_value) = Self::_horiz_min(index_low, values_low); - let (max_index, max_value) = Self::_horiz_max(index_high, values_high); - // Swap min and max here because we worked with i16ord in decreasing order (max => actual min, and vice versa) - (max_index, max_value, min_index, min_value) + impl SIMDArgMinMax for AVX2 { + #[target_feature(enable = "avx2")] + unsafe fn argminmax(data: &[u16]) -> (usize, usize) { + Self::_argminmax(data) } } @@ -171,7 +159,7 @@ mod avx2 { #[cfg(test)] mod tests { - use super::{AVX2, SIMD}; + use super::{SIMDArgMinMax, AVX2}; use crate::scalar::generic::scalar_argminmax; extern crate dev_utils; @@ -202,17 +190,7 @@ mod avx2 { return; } - let data = [ - 10, - std::u16::MIN, - 6, - 9, - 9, - 22, - std::u16::MAX, - 4, - std::u16::MAX, - ]; + let data = [10, u16::MIN, 6, 9, 9, 22, u16::MAX, 4, u16::MAX]; let data: Vec = data.iter().map(|x| *x).collect(); let data: &[u16] = &data; @@ -268,10 +246,10 @@ mod sse { const XOR_MASK: __m128i = unsafe { std::mem::transmute([XOR_VALUE; LANE_SIZE]) }; #[inline(always)] - unsafe fn _u16_to_i16decrord(u16: __m128i) -> __m128i { - // on a scalar: v^ 0x7F - // transforms to monotonically **decreasing** order - _mm_xor_si128(u16, XOR_MASK) + unsafe fn _u16_as_m128i_to_i16ord(u16_as_m128i: __m128i) -> __m128i { + // on a scalar: v ^ -0x8000 + // transforms to monotonically increasing order + _mm_xor_si128(u16_as_m128i, XOR_MASK) } #[inline(always)] @@ -279,25 +257,24 @@ mod sse { std::mem::transmute::<__m128i, [i16; LANE_SIZE]>(reg) } - impl SIMD for SSE { + impl SIMDOps for SSE { const INITIAL_INDEX: __m128i = unsafe { std::mem::transmute([0i16, 1i16, 2i16, 3i16, 4i16, 5i16, 6i16, 7i16]) }; - const MAX_INDEX: usize = i16::MAX as usize; + const INDEX_INCREMENT: __m128i = + unsafe { std::mem::transmute([LANE_SIZE as i16; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; #[inline(always)] unsafe fn _reg_to_arr(_: __m128i) -> [u16; LANE_SIZE] { - // Not used because we work with i16ord and override _get_min_index_value and _get_max_index_value + // Not implemented because we will perform the 
horizontal operations on the + // signed integer values instead of trying to retransform **only** the values + // (and thus not the indices) to signed integers. unimplemented!() } #[inline(always)] unsafe fn _mm_loadu(data: *const u16) -> __m128i { - _u16_to_i16decrord(_mm_loadu_si128(data as *const __m128i)) - } - - #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m128i { - _mm_set1_epi16(a as i16) + _u16_as_m128i_to_i16ord(_mm_loadu_si128(data as *const __m128i)) } #[inline(always)] @@ -320,13 +297,6 @@ mod sse { _mm_blendv_epi8(a, b, mask) } - // ------------------------------------ ARGMINMAX -------------------------------------- - - #[target_feature(enable = "sse4.1")] - unsafe fn argminmax(data: &[u16]) -> (usize, usize) { - Self::_argminmax(data) - } - #[inline(always)] unsafe fn _horiz_min(index: __m128i, value: __m128i) -> (usize, u16) { // 0. Find the minimum value @@ -352,7 +322,7 @@ mod sse { imin = _mm_min_epi16(imin, _mm_alignr_epi8(imin, imin, 2)); let min_index: usize = _mm_extract_epi16(imin, 0) as usize; - (min_index, _i16decrord_to_u16(min_value)) + (min_index, _i16ord_to_u16(min_value)) } #[inline(always)] @@ -380,20 +350,14 @@ mod sse { imin = _mm_min_epi16(imin, _mm_alignr_epi8(imin, imin, 2)); let max_index: usize = _mm_extract_epi16(imin, 0) as usize; - (max_index, _i16decrord_to_u16(max_value)) + (max_index, _i16ord_to_u16(max_value)) } + } - #[inline(always)] - unsafe fn _get_min_max_index_value( - index_low: __m128i, - values_low: __m128i, - index_high: __m128i, - values_high: __m128i, - ) -> (usize, u16, usize, u16) { - let (min_index, min_value) = Self::_horiz_min(index_low, values_low); - let (max_index, max_value) = Self::_horiz_max(index_high, values_high); - // Swap min and max here because we worked with i16ord in decreasing order (max => actual min, and vice versa) - (max_index, max_value, min_index, min_value) + impl SIMDArgMinMax for SSE { + #[target_feature(enable = "sse4.1")] + unsafe fn argminmax(data: &[u16]) -> (usize, usize) { + Self::_argminmax(data) } } @@ -401,7 +365,7 @@ mod sse { #[cfg(test)] mod tests { - use super::{SIMD, SSE}; + use super::{SIMDArgMinMax, SSE}; use crate::scalar::generic::scalar_argminmax; extern crate dev_utils; @@ -424,17 +388,7 @@ mod sse { #[test] fn test_first_index_is_returned_when_identical_values_found() { - let data = [ - 10, - std::u16::MIN, - 6, - 9, - 9, - 22, - std::u16::MAX, - 4, - std::u16::MAX, - ]; + let data = [10, u16::MIN, 6, 9, 9, 22, u16::MAX, 4, u16::MAX]; let data: Vec = data.iter().map(|x| *x).collect(); let data: &[u16] = &data; @@ -481,13 +435,11 @@ mod avx512 { const LANE_SIZE: usize = AVX512::LANE_SIZE_16; const XOR_MASK: __m512i = unsafe { std::mem::transmute([XOR_VALUE; LANE_SIZE]) }; - // TODO - comparison swappen => dan moeten we opt einde niet meer swappen? 
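The _get_min_max_index_value override deleted in the hunk above existed only because the old transform (v ^ 0x7FFF) mapped u16 to a monotonically decreasing i16 ordering, so the signed minimum found by the SIMD loop was really the unsigned maximum and the two results had to be swapped at the end. With the sign-bit flip the ordering is preserved and the override can simply be dropped. A small scalar illustration of the difference (standalone sketch, not the crate's code):

fn old_decr_ord(v: u16) -> i16 {
    // old mask: flips the value bits but not the sign bit, so order is reversed
    (v ^ 0x7FFF) as i16
}

fn new_incr_ord(v: u16) -> i16 {
    // new mask: flips only the sign bit, so order is preserved
    (v ^ 0x8000) as i16
}

fn main() {
    let (small, large) = (1u16, u16::MAX);
    assert!(old_decr_ord(small) > old_decr_ord(large)); // min/max came out swapped
    assert!(new_incr_ord(small) < new_incr_ord(large)); // no swap needed
}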
- #[inline(always)] - unsafe fn _u16_to_i16decrord(u16: __m512i) -> __m512i { - // on a scalar: v^ 0x7FFF - // transforms to monotonically **decreasing** order - _mm512_xor_si512(u16, XOR_MASK) + unsafe fn _u16_as_m512i_to_i16ord(u16_as_m512i: __m512i) -> __m512i { + // on a scalar: v ^ -0x8000 + // transforms to monotonically increasing order + _mm512_xor_si512(u16_as_m512i, XOR_MASK) } #[inline(always)] @@ -495,7 +447,7 @@ mod avx512 { std::mem::transmute::<__m512i, [i16; LANE_SIZE]>(reg) } - impl SIMD for AVX512 { + impl SIMDOps for AVX512 { const INITIAL_INDEX: __m512i = unsafe { std::mem::transmute([ 0i16, 1i16, 2i16, 3i16, 4i16, 5i16, 6i16, 7i16, 8i16, 9i16, 10i16, 11i16, 12i16, @@ -503,21 +455,21 @@ mod avx512 { 25i16, 26i16, 27i16, 28i16, 29i16, 30i16, 31i16, ]) }; - const MAX_INDEX: usize = i16::MAX as usize; + const INDEX_INCREMENT: __m512i = + unsafe { std::mem::transmute([LANE_SIZE as i16; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; #[inline(always)] unsafe fn _reg_to_arr(_: __m512i) -> [u16; LANE_SIZE] { - unimplemented!("We work with decrordi16 and override _get_min_index_value and _get_max_index_value") + // Not implemented because we will perform the horizontal operations on the + // signed integer values instead of trying to retransform **only** the values + // (and thus not the indices) to signed integers. + unimplemented!() } #[inline(always)] unsafe fn _mm_loadu(data: *const u16) -> __m512i { - _u16_to_i16decrord(_mm512_loadu_epi16(data as *const i16)) - } - - #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m512i { - _mm512_set1_epi16(a as i16) + _u16_as_m512i_to_i16ord(_mm512_loadu_epi16(data as *const i16)) } #[inline(always)] @@ -540,13 +492,6 @@ mod avx512 { _mm512_mask_blend_epi16(mask, a, b) } - // ------------------------------------ ARGMINMAX -------------------------------------- - - #[target_feature(enable = "avx512bw")] - unsafe fn argminmax(data: &[u16]) -> (usize, usize) { - Self::_argminmax(data) - } - #[inline(always)] unsafe fn _horiz_min(index: __m512i, value: __m512i) -> (usize, u16) { // 0. 
Find the minimum value @@ -576,7 +521,7 @@ mod avx512 { imin = _mm512_min_epi16(imin, _mm512_alignr_epi8(imin, imin, 2)); let min_index: usize = _mm_extract_epi16(_mm512_castsi512_si128(imin), 0) as usize; - (min_index, _i16decrord_to_u16(min_value)) + (min_index, _i16ord_to_u16(min_value)) } #[inline(always)] @@ -608,20 +553,14 @@ mod avx512 { imin = _mm512_min_epi16(imin, _mm512_alignr_epi8(imin, imin, 2)); let max_index: usize = _mm_extract_epi16(_mm512_castsi512_si128(imin), 0) as usize; - (max_index, _i16decrord_to_u16(max_value)) + (max_index, _i16ord_to_u16(max_value)) } + } - #[inline(always)] - unsafe fn _get_min_max_index_value( - index_low: __m512i, - values_low: __m512i, - index_high: __m512i, - values_high: __m512i, - ) -> (usize, u16, usize, u16) { - let (min_index, min_value) = Self::_horiz_min(index_low, values_low); - let (max_index, max_value) = Self::_horiz_max(index_high, values_high); - // Swap min and max here because we worked with i16ord in decreasing order (max => actual min, and vice versa) - (max_index, max_value, min_index, min_value) + impl SIMDArgMinMax for AVX512 { + #[target_feature(enable = "avx512bw")] + unsafe fn argminmax(data: &[u16]) -> (usize, usize) { + Self::_argminmax(data) } } @@ -630,7 +569,7 @@ mod avx512 { #[cfg(test)] mod tests { - use super::{AVX512, SIMD}; + use super::{SIMDArgMinMax, AVX512}; use crate::scalar::generic::scalar_argminmax; extern crate dev_utils; @@ -642,7 +581,7 @@ mod avx512 { #[test] fn test_both_versions_return_the_same_results() { - if !is_x86_feature_detected!("avx512f") { + if !is_x86_feature_detected!("avx512bw") { return; } @@ -657,21 +596,11 @@ mod avx512 { #[test] fn test_first_index_is_returned_when_identical_values_found() { - if !is_x86_feature_detected!("avx512f") { + if !is_x86_feature_detected!("avx512bw") { return; } - let data = [ - 10, - std::u16::MIN, - 6, - 9, - 9, - 22, - std::u16::MAX, - 4, - std::u16::MAX, - ]; + let data = [10, u16::MIN, 6, 9, 9, 22, u16::MAX, 4, u16::MAX]; let data: Vec = data.iter().map(|x| *x).collect(); let data: &[u16] = &data; @@ -686,7 +615,7 @@ mod avx512 { #[test] fn test_no_overflow() { - if !is_x86_feature_detected!("avx512f") { + if !is_x86_feature_detected!("avx512bw") { return; } @@ -701,7 +630,7 @@ mod avx512 { #[test] fn test_many_random_runs() { - if !is_x86_feature_detected!("avx512f") { + if !is_x86_feature_detected!("avx512bw") { return; } @@ -725,10 +654,12 @@ mod neon { const LANE_SIZE: usize = NEON::LANE_SIZE_16; - impl SIMD for NEON { + impl SIMDOps for NEON { const INITIAL_INDEX: uint16x8_t = unsafe { std::mem::transmute([0i16, 1i16, 2i16, 3i16, 4i16, 5i16, 6i16, 7i16]) }; - const MAX_INDEX: usize = u16::MAX as usize; + const INDEX_INCREMENT: uint16x8_t = + unsafe { std::mem::transmute([LANE_SIZE as i16; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; #[inline(always)] unsafe fn _reg_to_arr(reg: uint16x8_t) -> [u16; LANE_SIZE] { @@ -740,11 +671,6 @@ mod neon { vld1q_u16(data as *const u16) } - #[inline(always)] - unsafe fn _mm_set1(a: usize) -> uint16x8_t { - vdupq_n_u16(a as u16) - } - #[inline(always)] unsafe fn _mm_add(a: uint16x8_t, b: uint16x8_t) -> uint16x8_t { vaddq_u16(a, b) @@ -765,13 +691,6 @@ mod neon { vbslq_u16(mask, b, a) } - // ------------------------------------ ARGMINMAX -------------------------------------- - - #[target_feature(enable = "neon")] - unsafe fn argminmax(data: &[u16]) -> (usize, usize) { - Self::_argminmax(data) - } - #[inline(always)] unsafe fn _horiz_min(index: uint16x8_t, value: uint16x8_t) -> (usize, u16) { // 0. 
Find the minimum value @@ -829,11 +748,18 @@ mod neon { } } + impl SIMDArgMinMax for NEON { + #[target_feature(enable = "neon")] + unsafe fn argminmax(data: &[u16]) -> (usize, usize) { + Self::_argminmax(data) + } + } + // ----------------------------------------- TESTS ----------------------------------------- #[cfg(test)] mod tests { - use super::{NEON, SIMD}; + use super::{SIMDArgMinMax, NEON}; use crate::scalar::generic::scalar_argminmax; extern crate dev_utils; @@ -856,17 +782,7 @@ mod neon { #[test] fn test_first_index_is_returned_when_identical_values_found() { - let data = [ - 10, - std::u16::MIN, - 6, - 9, - 9, - 22, - std::u16::MAX, - 4, - std::u16::MAX, - ]; + let data = [10, u16::MIN, 6, 9, 9, 22, u16::MAX, 4, u16::MAX]; let data: Vec = data.iter().map(|x| *x).collect(); let data: &[u16] = &data; diff --git a/src/simd/simd_u32.rs b/src/simd/simd_u32.rs index 4991164..a671342 100644 --- a/src/simd/simd_u32.rs +++ b/src/simd/simd_u32.rs @@ -1,5 +1,5 @@ use super::config::SIMDInstructionSet; -use super::generic::SIMD; +use super::generic::{SIMDArgMinMax, SIMDOps}; #[cfg(target_arch = "aarch64")] use std::arch::aarch64::*; #[cfg(target_arch = "arm")] @@ -10,15 +10,20 @@ use std::arch::x86::*; use std::arch::x86_64::*; #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -const XOR_VALUE: i32 = 0x7FFFFFFF; +use super::task::{max_index_value, min_index_value}; + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +const XOR_VALUE: i32 = -0x80000000; // i32::MIN #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] #[inline(always)] -fn _i32decrord_to_u32(ord_i32: i32) -> u32 { - // let v = ord_i32 ^ 0x7FFFFFFF; +fn _i32ord_to_u32(ord_i32: i32) -> u32 { + // let v = ord_i32 ^ -0x80000000; unsafe { std::mem::transmute::(ord_i32 ^ XOR_VALUE) } } +const MAX_INDEX: usize = i32::MAX as usize; + // ------------------------------------------ AVX2 ------------------------------------------ #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] @@ -30,10 +35,10 @@ mod avx2 { const XOR_MASK: __m256i = unsafe { std::mem::transmute([XOR_VALUE; LANE_SIZE]) }; #[inline(always)] - unsafe fn _u32_to_i32decrord(u32: __m256i) -> __m256i { - // on a scalar: v^ 0x7FFFFFFF - // transforms to monotonically **decreasing** order - _mm256_xor_si256(u32, XOR_MASK) + unsafe fn _u32_as_m256i_to_i32ord(u32_as_m256i: __m256i) -> __m256i { + // on a scalar: v ^ -0x80000000 + // transforms to monotonically increasing order + _mm256_xor_si256(u32_as_m256i, XOR_MASK) } #[inline(always)] @@ -41,25 +46,24 @@ mod avx2 { std::mem::transmute::<__m256i, [i32; LANE_SIZE]>(reg) } - impl SIMD for AVX2 { + impl SIMDOps for AVX2 { const INITIAL_INDEX: __m256i = unsafe { std::mem::transmute([0i32, 1i32, 2i32, 3i32, 4i32, 5i32, 6i32, 7i32]) }; - const MAX_INDEX: usize = i32::MAX as usize; + const INDEX_INCREMENT: __m256i = + unsafe { std::mem::transmute([LANE_SIZE as i32; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; #[inline(always)] unsafe fn _reg_to_arr(_: __m256i) -> [u32; LANE_SIZE] { - // Not used because we work with i32ord and override _get_min_index_value and _get_max_index_value + // Not implemented because we will perform the horizontal operations on the + // signed integer values instead of trying to retransform **only** the values + // (and thus not the indices) to signed integers. 
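Throughout these hunks the runtime _mm_set1 splat is replaced by the INDEX_INCREMENT constant, built by transmuting an array that holds the lane count in every element into the register type. The bit pattern is the same one _mm256_set1_epi32(LANE_SIZE as i32) would produce at run time; it is just materialized at compile time. A standalone sketch for the 8-lane i32/AVX2 case follows; the lane count is taken from the diff, only the plain __m256i type is used, and no AVX2 intrinsics are called.

#[cfg(target_arch = "x86_64")]
fn main() {
    use std::arch::x86_64::__m256i;

    const LANE_SIZE: usize = 8; // an __m256i holds 8 x i32
    // compile-time splat of the lane count into every lane
    const INDEX_INCREMENT: __m256i =
        unsafe { std::mem::transmute([LANE_SIZE as i32; LANE_SIZE]) };

    // reinterpret the register constant as an array to inspect it
    let lanes: [i32; LANE_SIZE] = unsafe { std::mem::transmute(INDEX_INCREMENT) };
    assert_eq!(lanes, [LANE_SIZE as i32; LANE_SIZE]); // [8, 8, 8, 8, 8, 8, 8, 8]
}

#[cfg(not(target_arch = "x86_64"))]
fn main() {}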
unimplemented!() } #[inline(always)] unsafe fn _mm_loadu(data: *const u32) -> __m256i { - _u32_to_i32decrord(_mm256_loadu_si256(data as *const __m256i)) - } - - #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m256i { - _mm256_set1_epi32(a as i32) + _u32_as_m256i_to_i32ord(_mm256_loadu_si256(data as *const __m256i)) } #[inline(always)] @@ -82,37 +86,35 @@ mod avx2 { _mm256_blendv_epi8(a, b, mask) } - // ------------------------------------ ARGMINMAX -------------------------------------- + #[inline(always)] + unsafe fn _horiz_min(index: __m256i, value: __m256i) -> (usize, u32) { + let index_arr: [i32; LANE_SIZE] = _reg_to_i32_arr(index); + let value_arr: [i32; LANE_SIZE] = _reg_to_i32_arr(value); + let (min_index, min_value) = min_index_value(&index_arr, &value_arr); + (min_index as usize, _i32ord_to_u32(min_value)) + } + + #[inline(always)] + unsafe fn _horiz_max(index: __m256i, value: __m256i) -> (usize, u32) { + let index_arr: [i32; LANE_SIZE] = _reg_to_i32_arr(index); + let value_arr: [i32; LANE_SIZE] = _reg_to_i32_arr(value); + let (max_index, max_value) = max_index_value(&index_arr, &value_arr); + (max_index as usize, _i32ord_to_u32(max_value)) + } + } + impl SIMDArgMinMax for AVX2 { #[target_feature(enable = "avx2")] unsafe fn argminmax(data: &[u32]) -> (usize, usize) { Self::_argminmax(data) } - - #[inline(always)] - unsafe fn _get_min_max_index_value( - index_low: __m256i, - values_low: __m256i, - index_high: __m256i, - values_high: __m256i, - ) -> (usize, u32, usize, u32) { - let (min_index, min_value) = Self::_horiz_min(index_low, values_low); - let (max_index, max_value) = Self::_horiz_max(index_high, values_high); - // Swap min and max here because we worked with i64ord in decreasing order (max => actual min, and vice versa) - ( - max_index, - _i32decrord_to_u32(max_value), - min_index, - _i32decrord_to_u32(min_value), - ) - } } // ------------------------------------ TESTS -------------------------------------- #[cfg(test)] mod tests { - use super::{AVX2, SIMD}; + use super::{SIMDArgMinMax, AVX2}; use crate::scalar::generic::scalar_argminmax; extern crate dev_utils; @@ -143,17 +145,7 @@ mod avx2 { return; } - let data = [ - 10, - std::u32::MIN, - 6, - 9, - 9, - 22, - std::u32::MAX, - 4, - std::u32::MAX, - ]; + let data = [10, u32::MIN, 6, 9, 9, 22, u32::MAX, 4, u32::MAX]; let data: Vec = data.iter().map(|x| *x).collect(); let data: &[u32] = &data; @@ -194,10 +186,10 @@ mod sse { const XOR_MASK: __m128i = unsafe { std::mem::transmute([XOR_VALUE; LANE_SIZE]) }; #[inline(always)] - unsafe fn _u32_to_i32decrord(u32: __m128i) -> __m128i { - // on a scalar: v^ 0x7FFFFFFF - // transforms to monotonically **decreasing** order - _mm_xor_si128(u32, XOR_MASK) + unsafe fn _u32_as_m128i_to_i32ord(u32_as_m128i: __m128i) -> __m128i { + // on a scalar: v ^ -0x80000000 + // transforms to monotonically increasing order + _mm_xor_si128(u32_as_m128i, XOR_MASK) } #[inline(always)] @@ -205,24 +197,23 @@ mod sse { std::mem::transmute::<__m128i, [i32; LANE_SIZE]>(reg) } - impl SIMD for SSE { + impl SIMDOps for SSE { const INITIAL_INDEX: __m128i = unsafe { std::mem::transmute([0i32, 1i32, 2i32, 3i32]) }; - const MAX_INDEX: usize = i32::MAX as usize; + const INDEX_INCREMENT: __m128i = + unsafe { std::mem::transmute([LANE_SIZE as i32; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; #[inline(always)] unsafe fn _reg_to_arr(_: __m128i) -> [u32; LANE_SIZE] { - // Not used because we work with i32ord and override _get_min_index_value and _get_max_index_value + // Not implemented because we will 
perform the horizontal operations on the + // signed integer values instead of trying to retransform **only** the values + // (and thus not the indices) to signed integers. unimplemented!() } #[inline(always)] unsafe fn _mm_loadu(data: *const u32) -> __m128i { - _u32_to_i32decrord(_mm_loadu_si128(data as *const __m128i)) - } - - #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m128i { - _mm_set1_epi32(a as i32) + _u32_as_m128i_to_i32ord(_mm_loadu_si128(data as *const __m128i)) } #[inline(always)] @@ -245,37 +236,35 @@ mod sse { _mm_blendv_epi8(a, b, mask) } - // ------------------------------------ ARGMINMAX -------------------------------------- + #[inline(always)] + unsafe fn _horiz_min(index: __m128i, value: __m128i) -> (usize, u32) { + let index_arr: [i32; LANE_SIZE] = _reg_to_i32_arr(index); + let value_arr: [i32; LANE_SIZE] = _reg_to_i32_arr(value); + let (min_index, min_value) = min_index_value(&index_arr, &value_arr); + (min_index as usize, _i32ord_to_u32(min_value)) + } + + #[inline(always)] + unsafe fn _horiz_max(index: __m128i, value: __m128i) -> (usize, u32) { + let index_arr: [i32; LANE_SIZE] = _reg_to_i32_arr(index); + let value_arr: [i32; LANE_SIZE] = _reg_to_i32_arr(value); + let (max_index, max_value) = max_index_value(&index_arr, &value_arr); + (max_index as usize, _i32ord_to_u32(max_value)) + } + } + impl SIMDArgMinMax for SSE { #[target_feature(enable = "sse4.1")] unsafe fn argminmax(data: &[u32]) -> (usize, usize) { Self::_argminmax(data) } - - #[inline(always)] - unsafe fn _get_min_max_index_value( - index_low: __m128i, - values_low: __m128i, - index_high: __m128i, - values_high: __m128i, - ) -> (usize, u32, usize, u32) { - let (min_index, min_value) = Self::_horiz_min(index_low, values_low); - let (max_index, max_value) = Self::_horiz_max(index_high, values_high); - // Swap min and max here because we worked with i64ord in decreasing order (max => actual min, and vice versa) - ( - max_index, - _i32decrord_to_u32(max_value), - min_index, - _i32decrord_to_u32(min_value), - ) - } } // ----------------------------------------- TESTS ----------------------------------------- #[cfg(test)] mod tests { - use super::{SIMD, SSE}; + use super::{SIMDArgMinMax, SSE}; use crate::scalar::generic::scalar_argminmax; extern crate dev_utils; @@ -298,17 +287,7 @@ mod sse { #[test] fn test_first_index_is_returned_when_identical_values_found() { - let data = [ - 10, - std::u32::MIN, - 6, - 9, - 9, - 22, - std::u32::MAX, - 4, - std::u32::MAX, - ]; + let data = [10, u32::MIN, 6, 9, 9, 22, u32::MAX, 4, u32::MAX]; let data: Vec = data.iter().map(|x| *x).collect(); let data: &[u32] = &data; @@ -344,13 +323,11 @@ mod avx512 { const LANE_SIZE: usize = AVX512::LANE_SIZE_32; const XOR_MASK: __m512i = unsafe { std::mem::transmute([XOR_VALUE; LANE_SIZE]) }; - // TODO - comparison swappen => dan moeten we opt einde niet meer swappen? 
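For u32 (and u64 further down) the horizontal reduction is now done by dumping the index and value registers to plain arrays and scanning them with min_index_value / max_index_value from super::task, then undoing the ordinal transform on the winning value only. Those helpers are not shown in this diff; the sketch below is only an assumed shape inferred from the call sites (generic over the lane type, returning the stored index together with its value, and keeping the lowest index on ties so the first-occurrence guarantee holds).

// Assumed shape of the helpers in src/simd/task.rs (not shown in this diff).
// e.g. min_index_value(&[0i32, 1, 2, 3], &[5i32, -1, 7, -1]) == (1, -1)
pub fn min_index_value<T: Copy + PartialOrd>(index: &[T], values: &[T]) -> (T, T) {
    debug_assert_eq!(index.len(), values.len());
    let mut best = 0;
    for i in 1..values.len() {
        // strict '<' plus the index tie-break keeps the earliest occurrence
        if values[i] < values[best] || (values[i] == values[best] && index[i] < index[best]) {
            best = i;
        }
    }
    (index[best], values[best])
}

pub fn max_index_value<T: Copy + PartialOrd>(index: &[T], values: &[T]) -> (T, T) {
    debug_assert_eq!(index.len(), values.len());
    let mut best = 0;
    for i in 1..values.len() {
        if values[i] > values[best] || (values[i] == values[best] && index[i] < index[best]) {
            best = i;
        }
    }
    (index[best], values[best])
}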
- #[inline(always)] - unsafe fn _u32_to_i32decrord(u32: __m512i) -> __m512i { - // on scalar: v ^ 0x7FFFFFFF - // transforms to monotonically **decreasing** order - _mm512_xor_si512(u32, XOR_MASK) + unsafe fn _u32_as_m512i_to_i32ord(u32_as_m512i: __m512i) -> __m512i { + // on scalar: v ^ -0x80000000 + // transforms to monotonically increasing order + _mm512_xor_si512(u32_as_m512i, XOR_MASK) } #[inline(always)] @@ -358,28 +335,28 @@ mod avx512 { std::mem::transmute::<__m512i, [i32; LANE_SIZE]>(reg) } - impl SIMD for AVX512 { + impl SIMDOps for AVX512 { const INITIAL_INDEX: __m512i = unsafe { std::mem::transmute([ 0i32, 1i32, 2i32, 3i32, 4i32, 5i32, 6i32, 7i32, 8i32, 9i32, 10i32, 11i32, 12i32, 13i32, 14i32, 15i32, ]) }; - const MAX_INDEX: usize = i32::MAX as usize; + const INDEX_INCREMENT: __m512i = + unsafe { std::mem::transmute([LANE_SIZE as i32; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; #[inline(always)] unsafe fn _reg_to_arr(_: __m512i) -> [u32; LANE_SIZE] { - unimplemented!("We work with decrordu32 and override _get_min_index_value and _get_max_index_value") + // Not implemented because we will perform the horizontal operations on the + // signed integer values instead of trying to retransform **only** the values + // (and thus not the indices) to signed integers. + unimplemented!() } #[inline(always)] unsafe fn _mm_loadu(data: *const u32) -> __m512i { - _u32_to_i32decrord(_mm512_loadu_epi32(data as *const i32)) - } - - #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m512i { - _mm512_set1_epi32(a as i32) + _u32_as_m512i_to_i32ord(_mm512_loadu_epi32(data as *const i32)) } #[inline(always)] @@ -402,37 +379,35 @@ mod avx512 { _mm512_mask_blend_epi32(mask, a, b) } - // ------------------------------------ ARGMINMAX -------------------------------------- + #[inline(always)] + unsafe fn _horiz_min(index: __m512i, value: __m512i) -> (usize, u32) { + let index_arr: [i32; LANE_SIZE] = _reg_to_i32_arr(index); + let value_arr: [i32; LANE_SIZE] = _reg_to_i32_arr(value); + let (min_index, min_value) = min_index_value(&index_arr, &value_arr); + (min_index as usize, _i32ord_to_u32(min_value)) + } + #[inline(always)] + unsafe fn _horiz_max(index: __m512i, value: __m512i) -> (usize, u32) { + let index_arr: [i32; LANE_SIZE] = _reg_to_i32_arr(index); + let value_arr: [i32; LANE_SIZE] = _reg_to_i32_arr(value); + let (max_index, max_value) = max_index_value(&index_arr, &value_arr); + (max_index as usize, _i32ord_to_u32(max_value)) + } + } + + impl SIMDArgMinMax for AVX512 { #[target_feature(enable = "avx512f")] unsafe fn argminmax(data: &[u32]) -> (usize, usize) { Self::_argminmax(data) } - - #[inline(always)] - unsafe fn _get_min_max_index_value( - index_low: __m512i, - values_low: __m512i, - index_high: __m512i, - values_high: __m512i, - ) -> (usize, u32, usize, u32) { - let (min_index, min_value) = Self::_horiz_min(index_low, values_low); - let (max_index, max_value) = Self::_horiz_max(index_high, values_high); - // Swap min and max here because we worked with i64ord in decreasing order (max => actual min, and vice versa) - ( - max_index, - _i32decrord_to_u32(max_value), - min_index, - _i32decrord_to_u32(min_value), - ) - } } // ------------------------------------ TESTS -------------------------------------- #[cfg(test)] mod tests { - use super::{AVX512, SIMD}; + use super::{SIMDArgMinMax, AVX512}; use crate::scalar::generic::scalar_argminmax; extern crate dev_utils; @@ -463,17 +438,7 @@ mod avx512 { return; } - let data = [ - 10, - std::u32::MIN, - 6, - 9, - 9, - 22, - std::u32::MAX, - 
4, - std::u32::MAX, - ]; + let data = [10, u32::MIN, 6, 9, 9, 22, u32::MAX, 4, u32::MAX]; let data: Vec = data.iter().map(|x| *x).collect(); let data: &[u32] = &data; @@ -512,9 +477,11 @@ mod neon { const LANE_SIZE: usize = NEON::LANE_SIZE_32; - impl SIMD for NEON { + impl SIMDOps for NEON { const INITIAL_INDEX: uint32x4_t = unsafe { std::mem::transmute([0u32, 1u32, 2u32, 3u32]) }; - const MAX_INDEX: usize = u32::MAX as usize; + const INDEX_INCREMENT: uint32x4_t = + unsafe { std::mem::transmute([LANE_SIZE as i32; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; #[inline(always)] unsafe fn _reg_to_arr(reg: uint32x4_t) -> [u32; LANE_SIZE] { @@ -526,11 +493,6 @@ mod neon { vld1q_u32(data) } - #[inline(always)] - unsafe fn _mm_set1(a: usize) -> uint32x4_t { - vdupq_n_u32(a as u32) - } - #[inline(always)] unsafe fn _mm_add(a: uint32x4_t, b: uint32x4_t) -> uint32x4_t { vaddq_u32(a, b) @@ -550,9 +512,9 @@ mod neon { unsafe fn _mm_blendv(a: uint32x4_t, b: uint32x4_t, mask: uint32x4_t) -> uint32x4_t { vbslq_u32(mask, b, a) } + } - // ------------------------------------ ARGMINMAX -------------------------------------- - + impl SIMDArgMinMax for NEON { #[target_feature(enable = "neon")] unsafe fn argminmax(data: &[u32]) -> (usize, usize) { Self::_argminmax(data) @@ -563,7 +525,7 @@ mod neon { #[cfg(test)] mod tests { - use super::{NEON, SIMD}; + use super::{SIMDArgMinMax, NEON}; use crate::scalar::generic::scalar_argminmax; extern crate dev_utils; @@ -586,17 +548,7 @@ mod neon { #[test] fn test_first_index_is_returned_when_identical_values_found() { - let data = [ - 10, - std::u32::MIN, - 6, - 9, - 9, - 22, - std::u32::MAX, - 4, - std::u32::MAX, - ]; + let data = [10, u32::MIN, 6, 9, 9, 22, u32::MAX, 4, u32::MAX]; let data: Vec = data.iter().map(|x| *x).collect(); let data: &[u32] = &data; diff --git a/src/simd/simd_u64.rs b/src/simd/simd_u64.rs index b3b9ead..edbee5a 100644 --- a/src/simd/simd_u64.rs +++ b/src/simd/simd_u64.rs @@ -1,21 +1,27 @@ #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] use super::config::SIMDInstructionSet; -use super::generic::SIMD; +use super::generic::{SIMDArgMinMax, SIMDOps}; #[cfg(target_arch = "x86")] use std::arch::x86::*; #[cfg(target_arch = "x86_64")] use std::arch::x86_64::*; #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -const XOR_VALUE: i64 = 0x7FFFFFFFFFFFFFFF; +use super::task::{max_index_value, min_index_value}; + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +const XOR_VALUE: i64 = -0x8000000000000000; // i64::MIN #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] #[inline(always)] -fn _i64decrord_to_u64(ord_i64: i64) -> u64 { - // let v = ord_i64 ^ 0x7FFFFFFFFFFFFFFF; +fn _i64ord_to_u64(ord_i64: i64) -> u64 { + // let v = ord_i64 ^ -0x8000000000000000; unsafe { std::mem::transmute::(ord_i64 ^ XOR_VALUE) } } +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +const MAX_INDEX: usize = i64::MAX as usize; + // ------------------------------------------ AVX2 ------------------------------------------ #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] @@ -27,10 +33,10 @@ mod avx2 { const XOR_MASK: __m256i = unsafe { std::mem::transmute([XOR_VALUE; LANE_SIZE]) }; #[inline(always)] - unsafe fn _u64_to_i64decrord(u64: __m256i) -> __m256i { - // on a scalar: v^ 0x7FFFFFFFFFFFFFFF - // transforms to monotonically **decreasing** order - _mm256_xor_si256(u64, XOR_MASK) + unsafe fn _u64_as_m256i_to_i64ord(u64_as_m256i: __m256i) -> __m256i { + // on a scalar: v ^ -0x8000000000000000 + // transforms to monotonically 
increasing order + _mm256_xor_si256(u64_as_m256i, XOR_MASK) } #[inline(always)] @@ -38,24 +44,23 @@ mod avx2 { std::mem::transmute::<__m256i, [i64; LANE_SIZE]>(reg) } - impl SIMD for AVX2 { + impl SIMDOps for AVX2 { const INITIAL_INDEX: __m256i = unsafe { std::mem::transmute([0i64, 1i64, 2i64, 3i64]) }; - const MAX_INDEX: usize = i64::MAX as usize; + const INDEX_INCREMENT: __m256i = + unsafe { std::mem::transmute([LANE_SIZE as i64; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; #[inline(always)] unsafe fn _reg_to_arr(_: __m256i) -> [u64; LANE_SIZE] { - // Not used because we work with i64ord and override _get_min_index_value and _get_max_index_value + // Not implemented because we will perform the horizontal operations on the + // signed integer values instead of trying to retransform **only** the values + // (and thus not the indices) to signed integers. unimplemented!() } #[inline(always)] unsafe fn _mm_loadu(data: *const u64) -> __m256i { - _u64_to_i64decrord(_mm256_loadu_si256(data as *const __m256i)) - } - - #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m256i { - _mm256_set1_epi64x(a as i64) + _u64_as_m256i_to_i64ord(_mm256_loadu_si256(data as *const __m256i)) } #[inline(always)] @@ -78,37 +83,35 @@ mod avx2 { _mm256_blendv_epi8(a, b, mask) } - // ------------------------------------ ARGMINMAX -------------------------------------- + #[inline(always)] + unsafe fn _horiz_min(index: __m256i, value: __m256i) -> (usize, u64) { + let index_arr: [i64; LANE_SIZE] = _reg_to_i64_arr(index); + let value_arr: [i64; LANE_SIZE] = _reg_to_i64_arr(value); + let (min_index, min_value) = min_index_value(&index_arr, &value_arr); + (min_index as usize, _i64ord_to_u64(min_value)) + } + #[inline(always)] + unsafe fn _horiz_max(index: __m256i, value: __m256i) -> (usize, u64) { + let index_arr: [i64; LANE_SIZE] = _reg_to_i64_arr(index); + let value_arr: [i64; LANE_SIZE] = _reg_to_i64_arr(value); + let (max_index, max_value) = max_index_value(&index_arr, &value_arr); + (max_index as usize, _i64ord_to_u64(max_value)) + } + } + + impl SIMDArgMinMax for AVX2 { #[target_feature(enable = "avx2")] unsafe fn argminmax(data: &[u64]) -> (usize, usize) { Self::_argminmax(data) } - - #[inline(always)] - unsafe fn _get_min_max_index_value( - index_low: __m256i, - values_low: __m256i, - index_high: __m256i, - values_high: __m256i, - ) -> (usize, u64, usize, u64) { - let (min_index, min_value) = Self::_horiz_min(index_low, values_low); - let (max_index, max_value) = Self::_horiz_max(index_high, values_high); - // Swap min and max here because we worked with i64ord in decreasing order (max => actual min, and vice versa) - ( - max_index, - _i64decrord_to_u64(max_value), - min_index, - _i64decrord_to_u64(min_value), - ) - } } // ------------------------------------ TESTS -------------------------------------- #[cfg(test)] mod tests { - use super::{AVX2, SIMD}; + use super::{SIMDArgMinMax, AVX2}; use crate::scalar::generic::scalar_argminmax; extern crate dev_utils; @@ -139,17 +142,7 @@ mod avx2 { return; } - let data = [ - 10, - std::u64::MIN, - 6, - 9, - 9, - 22, - std::u64::MAX, - 4, - std::u64::MAX, - ]; + let data = [10, u64::MIN, 6, 9, 9, 22, u64::MAX, 4, u64::MAX]; let data: Vec = data.iter().map(|x| *x).collect(); let data: &[u64] = &data; @@ -190,10 +183,10 @@ mod sse { const XOR_MASK: __m128i = unsafe { std::mem::transmute([XOR_VALUE; LANE_SIZE]) }; #[inline(always)] - unsafe fn _u64_to_i64decrord(u64: __m128i) -> __m128i { - // on a scalar: v^ 0x7FFFFFFF - // transforms to monotonically **decreasing** 
order - _mm_xor_si128(u64, XOR_MASK) + unsafe fn _u64_as_m128i_to_i64ord(u64_as_m128i: __m128i) -> __m128i { + // on a scalar: v ^ -0x8000000000000000 + // transforms to monotonically increasing order + _mm_xor_si128(u64_as_m128i, XOR_MASK) } #[inline(always)] @@ -201,24 +194,23 @@ mod sse { std::mem::transmute::<__m128i, [i64; LANE_SIZE]>(reg) } - impl SIMD for SSE { + impl SIMDOps for SSE { const INITIAL_INDEX: __m128i = unsafe { std::mem::transmute([0i64, 1i64]) }; - const MAX_INDEX: usize = i64::MAX as usize; + const INDEX_INCREMENT: __m128i = + unsafe { std::mem::transmute([LANE_SIZE as i64; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; #[inline(always)] unsafe fn _reg_to_arr(_: __m128i) -> [u64; LANE_SIZE] { - // Not used because we work with i64ord and override _get_min_index_value and _get_max_index_value + // Not implemented because we will perform the horizontal operations on the + // signed integer values instead of trying to retransform **only** the values + // (and thus not the indices) to signed integers. unimplemented!() } #[inline(always)] unsafe fn _mm_loadu(data: *const u64) -> __m128i { - _u64_to_i64decrord(_mm_loadu_si128(data as *const __m128i)) - } - - #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m128i { - _mm_set1_epi64x(a as i64) + _u64_as_m128i_to_i64ord(_mm_loadu_si128(data as *const __m128i)) } #[inline(always)] @@ -241,37 +233,35 @@ mod sse { _mm_blendv_epi8(a, b, mask) } - // ------------------------------------ ARGMINMAX -------------------------------------- + #[inline(always)] + unsafe fn _horiz_min(index: __m128i, value: __m128i) -> (usize, u64) { + let index_arr: [i64; LANE_SIZE] = _reg_to_i64_arr(index); + let value_arr: [i64; LANE_SIZE] = _reg_to_i64_arr(value); + let (min_index, min_value) = min_index_value(&index_arr, &value_arr); + (min_index as usize, _i64ord_to_u64(min_value)) + } + + #[inline(always)] + unsafe fn _horiz_max(index: __m128i, value: __m128i) -> (usize, u64) { + let index_arr: [i64; LANE_SIZE] = _reg_to_i64_arr(index); + let value_arr: [i64; LANE_SIZE] = _reg_to_i64_arr(value); + let (max_index, max_value) = max_index_value(&index_arr, &value_arr); + (max_index as usize, _i64ord_to_u64(max_value)) + } + } + impl SIMDArgMinMax for SSE { #[target_feature(enable = "sse4.2")] unsafe fn argminmax(data: &[u64]) -> (usize, usize) { Self::_argminmax(data) } - - #[inline(always)] - unsafe fn _get_min_max_index_value( - index_low: __m128i, - values_low: __m128i, - index_high: __m128i, - values_high: __m128i, - ) -> (usize, u64, usize, u64) { - let (min_index, min_value) = Self::_horiz_min(index_low, values_low); - let (max_index, max_value) = Self::_horiz_max(index_high, values_high); - // Swap min and max here because we worked with i64ord in decreasing order (max => actual min, and vice versa) - ( - max_index, - _i64decrord_to_u64(max_value), - min_index, - _i64decrord_to_u64(min_value), - ) - } } // ----------------------------------------- TESTS ----------------------------------------- #[cfg(test)] mod tests { - use super::{SIMD, SSE}; + use super::{SIMDArgMinMax, SSE}; use crate::scalar::generic::scalar_argminmax; extern crate dev_utils; @@ -294,17 +284,7 @@ mod sse { #[test] fn test_first_index_is_returned_when_identical_values_found() { - let data = [ - 10, - std::u64::MIN, - 6, - 9, - 9, - 22, - std::u64::MAX, - 4, - std::u64::MAX, - ]; + let data = [10, u64::MIN, 6, 9, 9, 22, u64::MAX, 4, u64::MAX]; let data: Vec = data.iter().map(|x| *x).collect(); let data: &[u64] = &data; @@ -340,13 +320,11 @@ mod avx512 { const 
LANE_SIZE: usize = AVX512::LANE_SIZE_64; const XOR_MASK: __m512i = unsafe { std::mem::transmute([XOR_VALUE; LANE_SIZE]) }; - // TODO - comparison swappen => dan moeten we opt einde niet meer swappen? - #[inline(always)] - unsafe fn _u64_to_i64decrord(u64: __m512i) -> __m512i { - // on a scalar: v^ 0x7FFFFFFFFFFFFFFF - // transforms to monotonically **decreasing** order - _mm512_xor_si512(u64, XOR_MASK) + unsafe fn _u64_as_m512i_to_i64ord(u64_as_m512i: __m512i) -> __m512i { + // on a scalar: v ^ -0x8000000000000000 + // transforms to monotonically increasing order + _mm512_xor_si512(u64_as_m512i, XOR_MASK) } #[inline(always)] @@ -354,24 +332,24 @@ mod avx512 { std::mem::transmute::<__m512i, [i64; LANE_SIZE]>(reg) } - impl SIMD for AVX512 { + impl SIMDOps for AVX512 { const INITIAL_INDEX: __m512i = unsafe { std::mem::transmute([0i64, 1i64, 2i64, 3i64, 4i64, 5i64, 6i64, 7i64]) }; - const MAX_INDEX: usize = i64::MAX as usize; + const INDEX_INCREMENT: __m512i = + unsafe { std::mem::transmute([LANE_SIZE as i64; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; #[inline(always)] unsafe fn _reg_to_arr(_: __m512i) -> [u64; LANE_SIZE] { - unimplemented!("We work with decrordi64 and override _get_min_index_value and _get_max_index_value") + // Not implemented because we will perform the horizontal operations on the + // signed integer values instead of trying to retransform **only** the values + // (and thus not the indices) to signed integers. + unimplemented!() } #[inline(always)] unsafe fn _mm_loadu(data: *const u64) -> __m512i { - _u64_to_i64decrord(_mm512_loadu_epi64(data as *const i64)) - } - - #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m512i { - _mm512_set1_epi64(a as i64) + _u64_as_m512i_to_i64ord(_mm512_loadu_epi64(data as *const i64)) } #[inline(always)] @@ -394,37 +372,35 @@ mod avx512 { _mm512_mask_blend_epi64(mask, a, b) } - // ------------------------------------ ARGMINMAX -------------------------------------- + #[inline(always)] + unsafe fn _horiz_min(index: __m512i, value: __m512i) -> (usize, u64) { + let index_arr: [i64; LANE_SIZE] = _reg_to_i64_arr(index); + let value_arr: [i64; LANE_SIZE] = _reg_to_i64_arr(value); + let (min_index, min_value) = min_index_value(&index_arr, &value_arr); + (min_index as usize, _i64ord_to_u64(min_value)) + } + + #[inline(always)] + unsafe fn _horiz_max(index: __m512i, value: __m512i) -> (usize, u64) { + let index_arr: [i64; LANE_SIZE] = _reg_to_i64_arr(index); + let value_arr: [i64; LANE_SIZE] = _reg_to_i64_arr(value); + let (max_index, max_value) = max_index_value(&index_arr, &value_arr); + (max_index as usize, _i64ord_to_u64(max_value)) + } + } + impl SIMDArgMinMax for AVX512 { #[target_feature(enable = "avx512f")] unsafe fn argminmax(data: &[u64]) -> (usize, usize) { Self::_argminmax(data) } - - #[inline(always)] - unsafe fn _get_min_max_index_value( - index_low: __m512i, - values_low: __m512i, - index_high: __m512i, - values_high: __m512i, - ) -> (usize, u64, usize, u64) { - let (min_index, min_value) = Self::_horiz_min(index_low, values_low); - let (max_index, max_value) = Self::_horiz_max(index_high, values_high); - // Swap min and max here because we worked with i64ord in decreasing order (max => actual min, and vice versa) - ( - max_index, - _i64decrord_to_u64(max_value), - min_index, - _i64decrord_to_u64(min_value), - ) - } } // ------------------------------------ TESTS -------------------------------------- #[cfg(test)] mod tests { - use super::{AVX512, SIMD}; + use super::{SIMDArgMinMax, AVX512}; use 
crate::scalar::generic::scalar_argminmax; extern crate dev_utils; @@ -455,17 +431,7 @@ mod avx512 { return; } - let data = [ - 10, - std::u64::MIN, - 6, - 9, - 9, - 22, - std::u64::MAX, - 4, - std::u64::MAX, - ]; + let data = [10, u64::MIN, 6, 9, 9, 22, u64::MAX, 4, u64::MAX]; let data: Vec = data.iter().map(|x| *x).collect(); let data: &[u64] = &data; @@ -505,12 +471,13 @@ mod avx512 { #[cfg(any(target_arch = "arm", target_arch = "aarch64"))] mod neon { use super::super::config::NEON; - use super::super::generic::unimplement_simd; + use super::super::generic::{unimpl_SIMDArgMinMax, unimpl_SIMDOps}; use super::*; // We need to (un)implement the SIMD trait for the NEON struct as otherwise the // compiler will complain that the trait is not implemented for the struct - // even though we are not using the trait for the NEON struct when dealing with // > 64 bit data types. - unimplement_simd!(u64, usize, NEON); + unimpl_SIMDOps!(u64, usize, NEON); + unimpl_SIMDArgMinMax!(u64, usize, NEON); } diff --git a/src/simd/simd_u8.rs b/src/simd/simd_u8.rs index ae20c0d..e562831 100644 --- a/src/simd/simd_u8.rs +++ b/src/simd/simd_u8.rs @@ -1,5 +1,5 @@ use super::config::SIMDInstructionSet; -use super::generic::SIMD; +use super::generic::{SIMDArgMinMax, SIMDOps}; #[cfg(target_arch = "aarch64")] use std::arch::aarch64::*; #[cfg(target_arch = "arm")] @@ -10,15 +10,17 @@ use std::arch::x86::*; use std::arch::x86_64::*; #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -const XOR_VALUE: i8 = 0x7F; +const XOR_VALUE: i8 = -0x80; // i8::MIN #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] #[inline(always)] -fn _i8decrord_to_u8(decrord_i8: i8) -> u8 { - // let v = ord_i8 ^ 0x7F; - unsafe { std::mem::transmute::(decrord_i8 ^ XOR_VALUE) } +fn _i8ord_to_u8(ord_i8: i8) -> u8 { + // let v = ord_i8 ^ -0x80; + unsafe { std::mem::transmute::(ord_i8 ^ XOR_VALUE) } } +const MAX_INDEX: usize = i8::MAX as usize; + // ------------------------------------------ AVX2 ------------------------------------------ #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] @@ -30,10 +32,10 @@ mod avx2 { const XOR_MASK: __m256i = unsafe { std::mem::transmute([XOR_VALUE; LANE_SIZE]) }; #[inline(always)] - unsafe fn _u8_to_i8decrord(u8: __m256i) -> __m256i { - // on a scalar: v^ 0x7F - // transforms to monotonically **decreasing** order - _mm256_xor_si256(u8, XOR_MASK) // Only 1 operation + unsafe fn _u8_as_m256i_to_i8ord(u8_as_m256i: __m256i) -> __m256i { + // on a scalar: v ^ -0x80 + // transforms to monotonically increasing order + _mm256_xor_si256(u8_as_m256i, XOR_MASK) } #[inline(always)] @@ -41,7 +43,7 @@ mod avx2 { std::mem::transmute::<__m256i, [i8; LANE_SIZE]>(reg) } - impl SIMD for AVX2 { + impl SIMDOps for AVX2 { const INITIAL_INDEX: __m256i = unsafe { std::mem::transmute([ 0i8, 1i8, 2i8, 3i8, 4i8, 5i8, 6i8, 7i8, 8i8, 9i8, 10i8, 11i8, 12i8, 13i8, 14i8, @@ -49,22 +51,21 @@ mod avx2 { 29i8, 30i8, 31i8, ]) }; - const MAX_INDEX: usize = i8::MAX as usize; + const INDEX_INCREMENT: __m256i = + unsafe { std::mem::transmute([LANE_SIZE as i8; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; #[inline(always)] unsafe fn _reg_to_arr(_: __m256i) -> [u8; LANE_SIZE] { - // Not used because we work with i8ord and override _get_min_index_value and _get_max_index_value + // Not implemented because we will perform the horizontal operations on the + // signed integer values instead of trying to retransform **only** the values + // (and thus not the indices) to signed integers. 
unimplemented!() } #[inline(always)] unsafe fn _mm_loadu(data: *const u8) -> __m256i { - _u8_to_i8decrord(_mm256_loadu_si256(data as *const __m256i)) - } - - #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m256i { - _mm256_set1_epi8(a as i8) + _u8_as_m256i_to_i8ord(_mm256_loadu_si256(data as *const __m256i)) } #[inline(always)] @@ -87,13 +88,6 @@ mod avx2 { _mm256_blendv_epi8(a, b, mask) } - // ------------------------------------ ARGMINMAX -------------------------------------- - - #[target_feature(enable = "avx2")] - unsafe fn argminmax(data: &[u8]) -> (usize, usize) { - Self::_argminmax(data) - } - #[inline(always)] unsafe fn _horiz_min(index: __m256i, value: __m256i) -> (usize, u8) { // 0. Find the minimum value @@ -123,7 +117,7 @@ mod avx2 { imin = _mm256_min_epi8(imin, _mm256_alignr_epi8(imin, imin, 1)); let min_index: usize = _mm256_extract_epi8(imin, 0) as usize; - (min_index, _i8decrord_to_u8(min_value)) + (min_index, _i8ord_to_u8(min_value)) } #[inline(always)] @@ -155,20 +149,14 @@ mod avx2 { imin = _mm256_min_epi8(imin, _mm256_alignr_epi8(imin, imin, 1)); let max_index: usize = _mm256_extract_epi8(imin, 0) as usize; - (max_index, _i8decrord_to_u8(max_value)) + (max_index, _i8ord_to_u8(max_value)) } + } - #[inline(always)] - unsafe fn _get_min_max_index_value( - index_low: __m256i, - values_low: __m256i, - index_high: __m256i, - values_high: __m256i, - ) -> (usize, u8, usize, u8) { - let (min_index, min_value) = Self::_horiz_min(index_low, values_low); - let (max_index, max_value) = Self::_horiz_max(index_high, values_high); - // Swap min and max here because we worked with i8ord in decreasing order (max => actual min, and vice versa) - (max_index, max_value, min_index, min_value) + impl SIMDArgMinMax for AVX2 { + #[target_feature(enable = "avx2")] + unsafe fn argminmax(data: &[u8]) -> (usize, usize) { + Self::_argminmax(data) } } @@ -176,7 +164,7 @@ mod avx2 { #[cfg(test)] mod tests { - use super::{AVX2, SIMD}; + use super::{SIMDArgMinMax, AVX2}; use crate::scalar::generic::scalar_argminmax; extern crate dev_utils; @@ -207,7 +195,7 @@ mod avx2 { return; } - let data = [10, std::u8::MIN, 6, 9, 9, 22, std::u8::MAX, 4, std::u8::MAX]; + let data = [10, u8::MIN, 6, 9, 9, 22, u8::MAX, 4, u8::MAX]; let data: Vec = data.iter().map(|x| *x).collect(); let data: &[u8] = &data; @@ -263,10 +251,10 @@ mod sse { const XOR_MASK: __m128i = unsafe { std::mem::transmute([XOR_VALUE; LANE_SIZE]) }; #[inline(always)] - unsafe fn _u8_to_i8decrord(u8: __m128i) -> __m128i { - // on a scalar: v^ 0x7F - // transforms to monotonically **decreasing** order - _mm_xor_si128(u8, XOR_MASK) + unsafe fn _u8_as_m128i_to_i8ord(u8_as_m128i: __m128i) -> __m128i { + // on a scalar: v ^ -0x80 + // transforms to monotonically increasing order + _mm_xor_si128(u8_as_m128i, XOR_MASK) } #[inline(always)] @@ -274,29 +262,28 @@ mod sse { std::mem::transmute::<__m128i, [i8; LANE_SIZE]>(reg) } - impl SIMD for SSE { + impl SIMDOps for SSE { const INITIAL_INDEX: __m128i = unsafe { std::mem::transmute([ 0i8, 1i8, 2i8, 3i8, 4i8, 5i8, 6i8, 7i8, 8i8, 9i8, 10i8, 11i8, 12i8, 13i8, 14i8, 15i8, ]) }; - const MAX_INDEX: usize = i8::MAX as usize; + const INDEX_INCREMENT: __m128i = + unsafe { std::mem::transmute([LANE_SIZE as i8; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; #[inline(always)] unsafe fn _reg_to_arr(_: __m128i) -> [u8; LANE_SIZE] { - // Not used because we work with i8ord and override _get_min_index_value and _get_max_index_value + // Not implemented because we will perform the horizontal operations on the + // 
signed integer values instead of trying to retransform **only** the values + // (and thus not the indices) to signed integers. unimplemented!() } #[inline(always)] unsafe fn _mm_loadu(data: *const u8) -> __m128i { - _u8_to_i8decrord(_mm_loadu_si128(data as *const __m128i)) - } - - #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m128i { - _mm_set1_epi8(a as i8) + _u8_as_m128i_to_i8ord(_mm_loadu_si128(data as *const __m128i)) } #[inline(always)] @@ -319,13 +306,6 @@ mod sse { _mm_blendv_epi8(a, b, mask) } - // ------------------------------------ ARGMINMAX -------------------------------------- - - #[target_feature(enable = "sse4.1")] - unsafe fn argminmax(data: &[u8]) -> (usize, usize) { - Self::_argminmax(data) - } - #[inline(always)] unsafe fn _horiz_min(index: __m128i, value: __m128i) -> (usize, u8) { // 0. Find the minimum value @@ -353,7 +333,7 @@ mod sse { imin = _mm_min_epi8(imin, _mm_alignr_epi8(imin, imin, 1)); let min_index: usize = _mm_extract_epi8(imin, 0) as usize; - (min_index, _i8decrord_to_u8(min_value)) + (min_index, _i8ord_to_u8(min_value)) } #[inline(always)] @@ -383,20 +363,14 @@ mod sse { imin = _mm_min_epi8(imin, _mm_alignr_epi8(imin, imin, 1)); let max_index: usize = _mm_extract_epi8(imin, 0) as usize; - (max_index, _i8decrord_to_u8(max_value)) + (max_index, _i8ord_to_u8(max_value)) } + } - #[inline(always)] - unsafe fn _get_min_max_index_value( - index_low: __m128i, - values_low: __m128i, - index_high: __m128i, - values_high: __m128i, - ) -> (usize, u8, usize, u8) { - let (min_index, min_value) = Self::_horiz_min(index_low, values_low); - let (max_index, max_value) = Self::_horiz_max(index_high, values_high); - // Swap min and max here because we worked with i8ord in decreasing order (max => actual min, and vice versa) - (max_index, max_value, min_index, min_value) + impl SIMDArgMinMax for SSE { + #[target_feature(enable = "sse4.1")] + unsafe fn argminmax(data: &[u8]) -> (usize, usize) { + Self::_argminmax(data) } } @@ -404,7 +378,7 @@ mod sse { #[cfg(test)] mod tests { - use super::{SIMD, SSE}; + use super::{SIMDArgMinMax, SSE}; use crate::scalar::generic::scalar_argminmax; extern crate dev_utils; @@ -427,7 +401,7 @@ mod sse { #[test] fn test_first_index_is_returned_when_identical_values_found() { - let data = [10, std::u8::MIN, 6, 9, 9, 22, std::u8::MAX, 4, std::u8::MAX]; + let data = [10, u8::MIN, 6, 9, 9, 22, u8::MAX, 4, u8::MAX]; let data: Vec = data.iter().map(|x| *x).collect(); let data: &[u8] = &data; @@ -474,13 +448,11 @@ mod avx512 { const LANE_SIZE: usize = AVX512::LANE_SIZE_8; const XOR_MASK: __m512i = unsafe { std::mem::transmute([XOR_VALUE; LANE_SIZE]) }; - // TODO - comparison swappen => dan moeten we opt einde niet meer swappen? 
- #[inline(always)] - unsafe fn _u8_to_i8decrord(u8: __m512i) -> __m512i { - // on a scalar: v^ 0x7F - // transforms to monotonically **decreasing** order - _mm512_xor_si512(u8, XOR_MASK) + unsafe fn _u8_as_m512i_to_i8ord(u8_as_m512i: __m512i) -> __m512i { + // on a scalar: v ^ -0x80 + // transforms to monotonically increasing order + _mm512_xor_si512(u8_as_m512i, XOR_MASK) } #[inline(always)] @@ -488,7 +460,7 @@ mod avx512 { std::mem::transmute::<__m512i, [i8; LANE_SIZE]>(reg) } - impl SIMD for AVX512 { + impl SIMDOps for AVX512 { const INITIAL_INDEX: __m512i = unsafe { std::mem::transmute([ 0i8, 1i8, 2i8, 3i8, 4i8, 5i8, 6i8, 7i8, 8i8, 9i8, 10i8, 11i8, 12i8, 13i8, 14i8, @@ -498,23 +470,21 @@ mod avx512 { 57i8, 58i8, 59i8, 60i8, 61i8, 62i8, 63i8, ]) }; - const MAX_INDEX: usize = i8::MAX as usize; + const INDEX_INCREMENT: __m512i = + unsafe { std::mem::transmute([LANE_SIZE as i8; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; #[inline(always)] unsafe fn _reg_to_arr(_: __m512i) -> [u8; LANE_SIZE] { - unimplemented!( - "We work with decrordi8 and override _get_min_index_value and _get_max_index_value" - ) + // Not implemented because we will perform the horizontal operations on the + // signed integer values instead of trying to retransform **only** the values + // (and thus not the indices) to signed integers. + unimplemented!() } #[inline(always)] unsafe fn _mm_loadu(data: *const u8) -> __m512i { - _u8_to_i8decrord(_mm512_loadu_epi8(data as *const i8)) - } - - #[inline(always)] - unsafe fn _mm_set1(a: usize) -> __m512i { - _mm512_set1_epi8(a as i8) + _u8_as_m512i_to_i8ord(_mm512_loadu_epi8(data as *const i8)) } #[inline(always)] @@ -537,13 +507,6 @@ mod avx512 { _mm512_mask_blend_epi8(mask, a, b) } - // ------------------------------------ ARGMINMAX -------------------------------------- - - #[target_feature(enable = "avx512bw")] - unsafe fn argminmax(data: &[u8]) -> (usize, usize) { - Self::_argminmax(data) - } - #[inline(always)] unsafe fn _horiz_min(index: __m512i, value: __m512i) -> (usize, u8) { // 0. 
Find the minimum value @@ -575,7 +538,7 @@ mod avx512 { imin = _mm512_min_epi8(imin, _mm512_alignr_epi8(imin, imin, 1)); let min_index: usize = _mm_extract_epi8(_mm512_castsi512_si128(imin), 0) as usize; - (min_index, _i8decrord_to_u8(min_value)) + (min_index, _i8ord_to_u8(min_value)) } #[inline(always)] @@ -609,20 +572,14 @@ mod avx512 { imin = _mm512_min_epi8(imin, _mm512_alignr_epi8(imin, imin, 1)); let max_index: usize = _mm_extract_epi8(_mm512_castsi512_si128(imin), 0) as usize; - (max_index, _i8decrord_to_u8(max_value)) + (max_index, _i8ord_to_u8(max_value)) } + } - #[inline(always)] - unsafe fn _get_min_max_index_value( - index_low: __m512i, - values_low: __m512i, - index_high: __m512i, - values_high: __m512i, - ) -> (usize, u8, usize, u8) { - let (min_index, min_value) = Self::_horiz_min(index_low, values_low); - let (max_index, max_value) = Self::_horiz_max(index_high, values_high); - // Swap min and max here because we worked with i8ord in decreasing order (max => actual min, and vice versa) - (max_index, max_value, min_index, min_value) + impl SIMDArgMinMax for AVX512 { + #[target_feature(enable = "avx512bw")] + unsafe fn argminmax(data: &[u8]) -> (usize, usize) { + Self::_argminmax(data) } } @@ -631,7 +588,7 @@ mod avx512 { #[cfg(test)] mod tests { - use super::{AVX512, SIMD}; + use super::{SIMDArgMinMax, AVX512}; use crate::scalar::generic::scalar_argminmax; extern crate dev_utils; @@ -643,7 +600,7 @@ mod avx512 { #[test] fn test_both_versions_return_the_same_results() { - if !is_x86_feature_detected!("avx512f") { + if !is_x86_feature_detected!("avx512bw") { return; } @@ -658,11 +615,11 @@ mod avx512 { #[test] fn test_first_index_is_returned_when_identical_values_found() { - if !is_x86_feature_detected!("avx512f") { + if !is_x86_feature_detected!("avx512bw") { return; } - let data = [10, std::u8::MIN, 6, 9, 9, 22, std::u8::MAX, 4, std::u8::MAX]; + let data = [10, u8::MIN, 6, 9, 9, 22, u8::MAX, 4, u8::MAX]; let data: Vec = data.iter().map(|x| *x).collect(); let data: &[u8] = &data; @@ -677,7 +634,7 @@ mod avx512 { #[test] fn test_no_overflow() { - if !is_x86_feature_detected!("avx512f") { + if !is_x86_feature_detected!("avx512bw") { return; } @@ -692,7 +649,7 @@ mod avx512 { #[test] fn test_many_random_runs() { - if !is_x86_feature_detected!("avx512f") { + if !is_x86_feature_detected!("avx512bw") { return; } @@ -716,14 +673,16 @@ mod neon { const LANE_SIZE: usize = NEON::LANE_SIZE_8; - impl SIMD for NEON { + impl SIMDOps for NEON { const INITIAL_INDEX: uint8x16_t = unsafe { std::mem::transmute([ 0u8, 1u8, 2u8, 3u8, 4u8, 5u8, 6u8, 7u8, 8u8, 9u8, 10u8, 11u8, 12u8, 13u8, 14u8, 15u8, ]) }; - const MAX_INDEX: usize = u8::MAX as usize; + const INDEX_INCREMENT: uint8x16_t = + unsafe { std::mem::transmute([LANE_SIZE as i8; LANE_SIZE]) }; + const MAX_INDEX: usize = MAX_INDEX; #[inline(always)] unsafe fn _reg_to_arr(reg: uint8x16_t) -> [u8; LANE_SIZE] { @@ -735,11 +694,6 @@ mod neon { vld1q_u8(data as *const u8) } - #[inline(always)] - unsafe fn _mm_set1(a: usize) -> uint8x16_t { - vdupq_n_u8(a as u8) - } - #[inline(always)] unsafe fn _mm_add(a: uint8x16_t, b: uint8x16_t) -> uint8x16_t { vaddq_u8(a, b) @@ -760,13 +714,6 @@ mod neon { vbslq_u8(mask, b, a) } - // ------------------------------------ ARGMINMAX -------------------------------------- - - #[target_feature(enable = "neon")] - unsafe fn argminmax(data: &[u8]) -> (usize, usize) { - Self::_argminmax(data) - } - #[inline(always)] unsafe fn _horiz_min(index: uint8x16_t, value: uint8x16_t) -> (usize, u8) { // 0. 
Find the minimum value @@ -828,11 +775,18 @@ mod neon { } } + impl SIMDArgMinMax for NEON { + #[target_feature(enable = "neon")] + unsafe fn argminmax(data: &[u8]) -> (usize, usize) { + Self::_argminmax(data) + } + } + // ----------------------------------------- TESTS ----------------------------------------- #[cfg(test)] mod tests { - use super::{NEON, SIMD}; + use super::{SIMDArgMinMax, NEON}; use crate::scalar::generic::scalar_argminmax; extern crate dev_utils; @@ -855,7 +809,7 @@ mod neon { #[test] fn test_first_index_is_returned_when_identical_values_found() { - let data = [10, std::u8::MIN, 6, 9, 9, 22, std::u8::MAX, 4, std::u8::MAX]; + let data = [10, u8::MIN, 6, 9, 9, 22, u8::MAX, 4, u8::MAX]; let data: Vec = data.iter().map(|x| *x).collect(); let data: &[u8] = &data; diff --git a/src/simd/task.rs b/src/simd/task.rs index 7ccc0b8..2db9a45 100644 --- a/src/simd/task.rs +++ b/src/simd/task.rs @@ -1,5 +1,3 @@ -use crate::scalar::{ScalarArgMinMax, SCALAR}; - use std::cmp::Ordering; #[inline(always)] @@ -7,31 +5,44 @@ pub(crate) fn argminmax_generic( arr: &[T], lane_size: usize, core_argminmax: unsafe fn(&[T]) -> (usize, T, usize, T), -) -> (usize, usize) -where - SCALAR: ScalarArgMinMax, -{ + ignore_nan: bool, // if false, NaNs will be returned + scalar_argminmax: fn(&[T]) -> (usize, usize), +) -> (usize, usize) { assert!(!arr.is_empty()); // split_array should never return (None, None) match split_array(arr, lane_size) { - (Some(sim), Some(rem)) => { - let (rem_min_index, rem_max_index) = SCALAR::argminmax(rem); + (Some(simd_arr), Some(rem)) => { + // Perform SIMD operation on the first part of the array + let simd_result = unsafe { core_argminmax(simd_arr) }; + // Perform scalar operation on the remainder of the array + let (rem_min_index, rem_max_index) = scalar_argminmax(rem); + // let (rem_min_index, rem_max_index) = SCALAR::argminmax(rem); let rem_result = ( - rem_min_index + sim.len(), + rem_min_index + simd_arr.len(), rem[rem_min_index], - rem_max_index + sim.len(), + rem_max_index + simd_arr.len(), rem[rem_max_index], ); - let sim_result = unsafe { core_argminmax(sim) }; - find_final_index_minmax(rem_result, sim_result) + // Find the final min and max values + let (min_index, min_value) = find_final_index_min( + (simd_result.0, simd_result.1), + (rem_result.0, rem_result.1), + ignore_nan, + ); + let (max_index, max_value) = find_final_index_max( + (simd_result.2, simd_result.3), + (rem_result.2, rem_result.3), + ignore_nan, + ); + get_correct_argminmax_result(min_index, min_value, max_index, max_value, ignore_nan) + } + (Some(simd_arr), None) => { + let (min_index, min_value, max_index, max_value) = unsafe { core_argminmax(simd_arr) }; + get_correct_argminmax_result(min_index, min_value, max_index, max_value, ignore_nan) } (None, Some(rem)) => { - let (rem_min_index, rem_max_index) = SCALAR::argminmax(rem); + let (rem_min_index, rem_max_index) = scalar_argminmax(rem); (rem_min_index, rem_max_index) } - (Some(sim), None) => { - let sim_result = unsafe { core_argminmax(sim) }; - (sim_result.0, sim_result.2) - } (None, None) => panic!("Array is empty"), // Should never occur because of assert } } @@ -55,24 +66,131 @@ fn split_array(arr: &[T], lane_size: usize) -> (Option<&[T]>, Option<&[ } } +/// Get the final index of the min value when both a SIMD and scalar result is available +/// If not ignoring NaNs (thus returning NaN index if any present): +/// - If both values are NaN, returns the index of the simd result (as the first part +/// of the array is passed to the SIMD 
function) +/// - If one value is NaN, returns the index of the non-NaN value +/// - If neither value is NaN, returns the index of the min value +/// If ignoring NaNs: returns the index of the min value +/// +/// Note: when the values are equal, the index of the simd result is returned (as the +/// first part of the array is passed to the SIMD function) #[inline(always)] -fn find_final_index_minmax( - remainder_result: (usize, T, usize, T), - simd_result: (usize, T, usize, T), -) -> (usize, usize) { - let min_result = match remainder_result.1.partial_cmp(&simd_result.1).unwrap() { - Ordering::Less => remainder_result.0, - Ordering::Equal => std::cmp::min(remainder_result.0, simd_result.0), - Ordering::Greater => simd_result.0, +fn find_final_index_min( + simd_result: (usize, T), + remainder_result: (usize, T), + ignore_nan: bool, +) -> (usize, T) { + let (min_index, min_value) = match simd_result.1.partial_cmp(&remainder_result.1) { + Some(Ordering::Less) => simd_result, + Some(Ordering::Equal) => simd_result, + Some(Ordering::Greater) => remainder_result, + None => { + if !ignore_nan { + // --- Return NaNs + // Should prefer the simd result over the remainder result if both are + // NaN + if simd_result.1 != simd_result.1 { + // because NaN != NaN + simd_result + } else { + remainder_result + } + } else { + // --- Ignore NaNs + // If both are NaN raise panic, otherwise return the index of the + // non-NaN value + if simd_result.1 != simd_result.1 && remainder_result.1 != remainder_result.1 { + panic!("Data contains only NaNs (or +/- inf)") + } else if remainder_result.1 != remainder_result.1 { + simd_result + } else { + remainder_result + } + } + } }; + (min_index, min_value) +} - let max_result = match simd_result.3.partial_cmp(&remainder_result.3).unwrap() { - Ordering::Less => remainder_result.2, - Ordering::Equal => std::cmp::min(remainder_result.2, simd_result.2), - Ordering::Greater => simd_result.2, +/// Get the final index of the max value when both a SIMD and scalar result is available +/// If not ignoring NaNs (thus returning NaN index if any present): +/// - If both values are NaN, returns the index of the simd result (as the first part +/// of the array is passed to the SIMD function) +/// - If one value is NaN, returns the index of the non-NaN value +/// - If neither value is NaN, returns the index of the max value +/// If ignoring NaNs: returns the index of the max value +/// +/// Note: when the values are equal, the index of the simd result is returned (as the +/// first part of the array is passed to the SIMD function) +#[inline(always)] +fn find_final_index_max( + simd_result: (usize, T), + remainder_result: (usize, T), + ignore_nan: bool, +) -> (usize, T) { + let (max_index, max_value) = match simd_result.1.partial_cmp(&remainder_result.1) { + Some(Ordering::Greater) => simd_result, + Some(Ordering::Equal) => simd_result, + Some(Ordering::Less) => remainder_result, + None => { + if !ignore_nan { + // --- Return NaNs + // Should prefer the simd result over the remainder result if both are + // NaN + if simd_result.1 != simd_result.1 { + // because NaN != NaN + simd_result + } else { + remainder_result + } + } else { + // --- Ignore NaNs + // If both are NaN raise panic, otherwise return the index of the + // non-NaN value + if simd_result.1 != simd_result.1 && remainder_result.1 != remainder_result.1 { + panic!("Data contains only NaNs (or +/- inf)") + } else if remainder_result.1 != remainder_result.1 { + simd_result + } else { + remainder_result + } + } + } }; + 
(max_index, max_value) +} - (min_result, max_result) +/// Get the correct index(es) for the argmin and argmax functions +/// If not ignoring NaNs (thus returning NaN index if any present): +/// - If both values are NaN, returns the lowest index twice +/// - If one value is NaN, returns the index of the non-NaN value twice +/// - If neither value is NaN, returns the min_index and max_index +/// If ignoring NaNs: returns the min_index and max_index +fn get_correct_argminmax_result( + min_index: usize, + min_value: T, + max_index: usize, + max_value: T, + ignore_nan: bool, +) -> (usize, usize) { + if !ignore_nan && (min_value != min_value || max_value != max_value) { + // --- Return NaNs + // -> at least one of the values is NaN + if min_value != min_value && max_value != max_value { + // If both are NaN, return lowest index + let lowest_index = std::cmp::min(min_index, max_index); + return (lowest_index, lowest_index); + } else if min_value != min_value { + // If min is the only NaN, return min index + return (min_index, min_index); + } else { + // If max is the only NaN, return max index + return (max_index, max_index); + } + } + (min_index, max_index) } // ------------ Other helper functions
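For reference, the simd_u64.rs / simd_u8.rs changes above all lean on the same bit trick: XOR-ing an unsigned value with its type's sign bit maps unsigned order onto signed order, so the signed SIMD compare/min/max intrinsics can be reused for unsigned lanes, and the same XOR maps the result back. A minimal scalar sketch of that transform (the function names here are illustrative, not part of the crate):

fn u64_to_i64ord(v: u64) -> i64 {
    // flip the sign bit: 0 maps to i64::MIN, u64::MAX maps to i64::MAX
    (v ^ 0x8000_0000_0000_0000) as i64
}

fn i64ord_to_u64(v: i64) -> u64 {
    // the same XOR undoes the transform
    (v as u64) ^ 0x8000_0000_0000_0000
}

// For all a, b: a < b  <=>  u64_to_i64ord(a) < u64_to_i64ord(b)

The task.rs helpers combine the SIMD result (computed on the first, lane-aligned part of the array) with the scalar result (computed on the remainder) while respecting the chosen NaN policy. A minimal scalar sketch of the intended combination semantics for the minimum, assuming f64 values (combine_min is a hypothetical name, not an API of the crate):

use std::cmp::Ordering;

fn combine_min(simd: (usize, f64), rem: (usize, f64), ignore_nan: bool) -> (usize, f64) {
    match simd.1.partial_cmp(&rem.1) {
        // ties go to the SIMD result, whose index is always the lower one
        Some(Ordering::Less) | Some(Ordering::Equal) => simd,
        Some(Ordering::Greater) => rem,
        // partial_cmp is None iff at least one value is NaN
        None if !ignore_nan => {
            // "return NaN" policy: propagate the NaN, preferring the SIMD result
            if simd.1.is_nan() { simd } else { rem }
        }
        None => {
            // "ignore NaN" policy: keep the non-NaN value; all-NaN data panics
            if simd.1.is_nan() && rem.1.is_nan() {
                panic!("Data contains only NaNs (or +/- inf)")
            } else if rem.1.is_nan() {
                simd
            } else {
                rem
            }
        }
    }
}

// e.g. combine_min((3, f64::NAN), (10, 1.0), false) -> (3, NaN)   (NaN index is returned)
//      combine_min((3, f64::NAN), (10, 1.0), true)  -> (10, 1.0)  (NaN is ignored)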