From 633793fb2d502bf3f954e25434ccc5d16fbe2d03 Mon Sep 17 00:00:00 2001 From: Jeroen Van Der Donckt Date: Thu, 6 Oct 2022 11:14:05 +0200 Subject: [PATCH 1/6] :recycle: faster scalar implementation + no Option output :zap: --- README.md | 2 +- src/generic.rs | 94 ++++++++++++++++++++++++++------------------ src/lib.rs | 16 ++++---- src/simd/simd_f32.rs | 43 +++++++------------- src/simd/simd_f64.rs | 43 +++++++------------- src/simd/simd_i16.rs | 43 +++++++------------- src/simd/simd_i32.rs | 43 +++++++------------- src/simd/simd_i64.rs | 43 +++++++------------- src/simd/simd_i8.rs | 43 +++++++------------- src/simd/simd_u32.rs | 46 ++++++++-------------- src/task.rs | 42 ++++++++++++++++++-- 11 files changed, 202 insertions(+), 256 deletions(-) diff --git a/README.md b/README.md index 69979b5..4434aff 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ use ndarray::Array1; let arr: Vec = (0..200_000).collect(); let arr: Array1 = Array1::from(arr); -let (min, max) = arr.view().argminmax().unwrap(); // apply extension +let (min, max) = arr.view().argminmax(); // apply extension println!("min: {}, max: {}", min, max); println!("arr[min]: {}, arr[max]: {}", arr[min], arr[max]); diff --git a/src/generic.rs b/src/generic.rs index 5948e60..882358b 100644 --- a/src/generic.rs +++ b/src/generic.rs @@ -2,48 +2,66 @@ use ndarray::ArrayView1; // ------ On ArrayView1 -#[inline] -pub fn simple_argmin(arr: ArrayView1) -> usize { - let mut low_index = 0usize; - let mut low = arr[low_index]; - for (i, item) in arr.iter().enumerate() { - if *item < low { - low = *item; - low_index = i; - } - } - low_index -} +// #[inline] +// pub(crate) fn simple_argmin(arr: ArrayView1) -> usize { +// let mut low_index = 0usize; +// let mut low = arr[low_index]; +// for (i, item) in arr.iter().enumerate() { +// if *item < low { +// low = *item; +// low_index = i; +// } +// } +// low_index +// } -#[inline] -pub fn simple_argmax(arr: ArrayView1) -> usize { - let mut high_index = 0usize; - let mut high = arr[high_index]; - for (i, item) in arr.iter().enumerate() { - if *item > high { - high = *item; - high_index = i; - } - } - high_index -} +// #[inline] +// pub(crate) fn simple_argmax(arr: ArrayView1) -> usize { +// let mut high_index = 0usize; +// let mut high = arr[high_index]; +// for (i, item) in arr.iter().enumerate() { +// if *item > high { +// high = *item; +// high_index = i; +// } +// } +// high_index +// } + +// #[inline] +// pub fn simple_argminmax(arr: ArrayView1) -> (usize, usize) { +// let mut low_index: usize = 0; +// let mut high_index: usize = 0; +// let mut low = arr[low_index]; +// let mut high = arr[high_index]; +// for (i, item) in arr.iter().enumerate() { +// if *item < low { +// low = *item; +// low_index = i; +// } else if *item > high { +// high = *item; +// high_index = i; +// } +// } +// (low_index, high_index) +// } +// Note: 5-7% faster than the above implementation #[inline] pub fn simple_argminmax(arr: ArrayView1) -> (usize, usize) { - let mut low_index: usize = 0; - let mut high_index: usize = 0; - let mut low = arr[low_index]; - let mut high = arr[high_index]; - for (i, item) in arr.iter().enumerate() { - if *item < low { - low = *item; - low_index = i; - } else if *item > high { - high = *item; - high_index = i; - } - } - (low_index, high_index) + let minmax_tuple: (usize, T, usize, T) = arr.iter().enumerate().fold( + (0usize, arr[0], 0usize, arr[0]), + |(min_idx, min, max_idx, max), (idx, item)| { + if *item < min { + (idx, *item, max_idx, max) + } else if *item > max { + (min_idx, min, idx, *item) + } else { + (min_idx, min, max_idx, max) + } + }, + ); + (minmax_tuple.0, minmax_tuple.2) } // ------ On &[T] diff --git a/src/lib.rs b/src/lib.rs index 4722e2d..e5218ef 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,20 +1,18 @@ pub mod generic; -// #[cfg(target_feature = "sse")] mod simd; -// #[cfg(target_feature = "sse")] mod task; -pub use generic::{simple_argminmax}; +pub use generic::simple_argminmax; pub use simd::{simd_f32, simd_f64, simd_i16, simd_i32, simd_i64}; use ndarray::ArrayView1; pub trait ArgMinMax { - fn argminmax(self) -> Option<(usize, usize)>; + fn argminmax(self) -> (usize, usize); } impl ArgMinMax for ArrayView1<'_, f64> { - fn argminmax(self) -> Option<(usize, usize)> { + fn argminmax(self) -> (usize, usize) { // #[cfg(not(target_feature = "sse"))] // return Some(simple_argminmax(self)); // #[cfg(target_feature = "sse")] @@ -23,7 +21,7 @@ impl ArgMinMax for ArrayView1<'_, f64> { } impl ArgMinMax for ArrayView1<'_, i64> { - fn argminmax(self) -> Option<(usize, usize)> { + fn argminmax(self) -> (usize, usize) { // #[cfg(not(target_feature = "sse"))] // return Some(simple_argminmax(self)); // #[cfg(target_feature = "sse")] @@ -32,7 +30,7 @@ impl ArgMinMax for ArrayView1<'_, i64> { } impl ArgMinMax for ArrayView1<'_, f32> { - fn argminmax(self) -> Option<(usize, usize)> { + fn argminmax(self) -> (usize, usize) { // #[cfg(not(target_feature = "sse"))] // return Some(simple_argminmax(self)); // #[cfg(target_feature = "sse")] @@ -41,7 +39,7 @@ impl ArgMinMax for ArrayView1<'_, f32> { } impl ArgMinMax for ArrayView1<'_, i32> { - fn argminmax(self) -> Option<(usize, usize)> { + fn argminmax(self) -> (usize, usize) { // #[cfg(not(target_feature = "sse"))] // return Some(simple_argminmax(self)); // #[cfg(target_feature = "sse")] @@ -50,7 +48,7 @@ impl ArgMinMax for ArrayView1<'_, i32> { } impl ArgMinMax for ArrayView1<'_, i16> { - fn argminmax(self) -> Option<(usize, usize)> { + fn argminmax(self) -> (usize, usize) { // #[cfg(not(target_feature = "sse"))] // return Some(simple_argminmax(self)); // #[cfg(target_feature = "sse")] diff --git a/src/simd/simd_f32.rs b/src/simd/simd_f32.rs index 46f5592..2d53f19 100644 --- a/src/simd/simd_f32.rs +++ b/src/simd/simd_f32.rs @@ -1,33 +1,14 @@ -use crate::generic::{max_index_value, min_index_value, simple_argminmax}; -use crate::task::{find_final_index_minmax, split_array}; +use crate::generic::{max_index_value, min_index_value}; +use crate::task::argminmax_generic; use ndarray::ArrayView1; use std::arch::x86_64::*; +const LANE_SIZE: usize = 8; + // ------------------------------------ ARGMINMAX -------------------------------------- -pub fn argminmax_f32(arr: ArrayView1) -> Option<(usize, usize)> { - match split_array(arr, 8) { - (Some(rem), Some(sim)) => { - let (rem_min_index, rem_max_index) = simple_argminmax(rem); - let rem_result = ( - rem[rem_min_index], - rem_min_index, - rem[rem_max_index], - rem_max_index, - ); - let sim_result = unsafe { core_argminmax_256(sim, rem.len()) }; - find_final_index_minmax(rem_result, sim_result) - } - (Some(rem), None) => { - let (rem_min_index, rem_max_index) = simple_argminmax(rem); - Some((rem_min_index, rem_max_index)) - } - (None, Some(sim)) => { - let sim_result = unsafe { core_argminmax_256(sim, 0) }; - Some((sim_result.1, sim_result.3)) - } - (None, None) => None, - } +pub fn argminmax_f32(arr: ArrayView1) -> (usize, usize) { + argminmax_generic(arr, LANE_SIZE, core_argminmax_256) } #[inline] @@ -35,6 +16,7 @@ fn reg_to_f32_arr(reg: __m256) -> [f32; 8] { unsafe { std::mem::transmute::<__m256, [f32; 8]>(reg) } } +#[inline] #[target_feature(enable = "avx2")] unsafe fn core_argminmax_256(sim_arr: ArrayView1, offset: usize) -> (f32, usize, f32, usize) { // Efficient calculation of argmin and argmax together @@ -92,7 +74,10 @@ unsafe fn core_argminmax_256(sim_arr: ArrayView1, offset: usize) -> (f32, u #[cfg(test)] mod tests { - use super::{argminmax_f32, simple_argminmax}; + use super::argminmax_f32; + use crate::generic; + use generic::simple_argminmax; + use ndarray::Array1; extern crate dev_utils; @@ -108,7 +93,7 @@ mod tests { assert_eq!(data.len() % 8, 1); let (argmin_index, argmax_index) = simple_argminmax(data.view()); - let (argmin_simd_index, argmax_simd_index) = argminmax_f32(data.view()).unwrap(); + let (argmin_simd_index, argmax_simd_index) = argminmax_f32(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); } @@ -131,7 +116,7 @@ mod tests { assert_eq!(argmin_index, 3); assert_eq!(argmax_index, 1); - let (argmin_simd_index, argmax_simd_index) = argminmax_f32(data.view()).unwrap(); + let (argmin_simd_index, argmax_simd_index) = argminmax_f32(data.view()); assert_eq!(argmin_simd_index, 3); assert_eq!(argmax_simd_index, 1); } @@ -141,7 +126,7 @@ mod tests { for _ in 0..10_000 { let data = get_array_f32(32 * 8 + 1); let (argmin_index, argmax_index) = simple_argminmax(data.view()); - let (argmin_simd_index, argmax_simd_index) = argminmax_f32(data.view()).unwrap(); + let (argmin_simd_index, argmax_simd_index) = argminmax_f32(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); } diff --git a/src/simd/simd_f64.rs b/src/simd/simd_f64.rs index 6672785..55377ee 100644 --- a/src/simd/simd_f64.rs +++ b/src/simd/simd_f64.rs @@ -1,33 +1,14 @@ -use crate::generic::{max_index_value, min_index_value, simple_argminmax}; -use crate::task::{find_final_index_minmax, split_array}; +use crate::generic::{max_index_value, min_index_value}; +use crate::task::argminmax_generic; use ndarray::ArrayView1; use std::arch::x86_64::*; +const LANE_SIZE: usize = 4; + // ------------------------------------ ARGMINMAX -------------------------------------- -pub fn argminmax_f64(arr: ArrayView1) -> Option<(usize, usize)> { - match split_array(arr, 4) { - (Some(rem), Some(sim)) => { - let (rem_min_index, rem_max_index) = simple_argminmax(rem); - let rem_result = ( - rem[rem_min_index], - rem_min_index, - rem[rem_max_index], - rem_max_index, - ); - let sim_result = unsafe { core_argminmax_256(sim, rem.len()) }; - find_final_index_minmax(rem_result, sim_result) - } - (Some(rem), None) => { - let (rem_min_index, rem_max_index) = simple_argminmax(rem); - Some((rem_min_index, rem_max_index)) - } - (None, Some(sim)) => { - let sim_result = unsafe { core_argminmax_256(sim, 0) }; - Some((sim_result.1, sim_result.3)) - } - (None, None) => None, - } +pub fn argminmax_f64(arr: ArrayView1) -> (usize, usize) { + argminmax_generic(arr, LANE_SIZE, core_argminmax_256) } #[inline] @@ -35,6 +16,7 @@ fn reg_to_f64_arr(reg: __m256d) -> [f64; 4] { unsafe { std::mem::transmute::<__m256d, [f64; 4]>(reg) } } +#[inline] #[target_feature(enable = "avx2")] unsafe fn core_argminmax_256(sim_arr: ArrayView1, offset: usize) -> (f64, usize, f64, usize) { // Efficient calculation of argmin and argmax together @@ -84,7 +66,10 @@ unsafe fn core_argminmax_256(sim_arr: ArrayView1, offset: usize) -> (f64, u #[cfg(test)] mod tests { - use super::{argminmax_f64, simple_argminmax}; + use super::argminmax_f64; + use crate::generic; + use generic::simple_argminmax; + use ndarray::Array1; extern crate dev_utils; @@ -100,7 +85,7 @@ mod tests { assert_eq!(data.len() % 4, 1); let (argmin_index, argmax_index) = simple_argminmax(data.view()); - let (argmin_simd_index, argmax_simd_index) = argminmax_f64(data.view()).unwrap(); + let (argmin_simd_index, argmax_simd_index) = argminmax_f64(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); } @@ -123,7 +108,7 @@ mod tests { assert_eq!(argmin_index, 3); assert_eq!(argmax_index, 1); - let (argmin_simd_index, argmax_simd_index) = argminmax_f64(data.view()).unwrap(); + let (argmin_simd_index, argmax_simd_index) = argminmax_f64(data.view()); assert_eq!(argmin_simd_index, 3); assert_eq!(argmax_simd_index, 1); } @@ -133,7 +118,7 @@ mod tests { for _ in 0..10_000 { let data = get_array_f64(32 * 8 + 1); let (argmin_index, argmax_index) = simple_argminmax(data.view()); - let (argmin_simd_index, argmax_simd_index) = argminmax_f64(data.view()).unwrap(); + let (argmin_simd_index, argmax_simd_index) = argminmax_f64(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); } diff --git a/src/simd/simd_i16.rs b/src/simd/simd_i16.rs index e4561fb..3f3c03e 100644 --- a/src/simd/simd_i16.rs +++ b/src/simd/simd_i16.rs @@ -1,33 +1,14 @@ -use crate::generic::{max_index_value, min_index_value, simple_argminmax}; -use crate::task::{find_final_index_minmax, split_array}; +use crate::generic::{max_index_value, min_index_value}; +use crate::task::argminmax_generic; use ndarray::ArrayView1; use std::arch::x86_64::*; +const LANE_SIZE: usize = 16; + // ------------------------------------ ARGMINMAX -------------------------------------- -pub fn argminmax_i16(arr: ArrayView1) -> Option<(usize, usize)> { - match split_array(arr, 16) { - (Some(rem), Some(sim)) => { - let (rem_min_index, rem_max_index) = simple_argminmax(rem); - let rem_result = ( - rem[rem_min_index], - rem_min_index, - rem[rem_max_index], - rem_max_index, - ); - let sim_result = unsafe { core_argminmax_256(sim, rem.len()) }; - find_final_index_minmax(rem_result, sim_result) - } - (Some(rem), None) => { - let (rem_min_index, rem_max_index) = simple_argminmax(rem); - Some((rem_min_index, rem_max_index)) - } - (None, Some(sim)) => { - let sim_result = unsafe { core_argminmax_256(sim, 0) }; - Some((sim_result.1, sim_result.3)) - } - (None, None) => None, - } +pub fn argminmax_i16(arr: ArrayView1) -> (usize, usize) { + argminmax_generic(arr, LANE_SIZE, core_argminmax_256) } #[inline] @@ -35,6 +16,7 @@ fn reg_to_i16_arr(reg: __m256i) -> [i16; 16] { unsafe { std::mem::transmute::<__m256i, [i16; 16]>(reg) } } +#[inline] #[target_feature(enable = "avx2")] unsafe fn core_argminmax_256(sim_arr: ArrayView1, offset: usize) -> (i16, usize, i16, usize) { // Efficient calculation of argmin and argmax together @@ -90,7 +72,10 @@ unsafe fn core_argminmax_256(sim_arr: ArrayView1, offset: usize) -> (i16, u #[cfg(test)] mod tests { - use super::{argminmax_i16, simple_argminmax}; + use super::argminmax_i16; + use crate::generic; + use generic::simple_argminmax; + use ndarray::Array1; extern crate dev_utils; @@ -106,7 +91,7 @@ mod tests { assert_eq!(data.len() % 16, 1); let (argmin_index, argmax_index) = simple_argminmax(data.view()); - let (argmin_simd_index, argmax_simd_index) = argminmax_i16(data.view()).unwrap(); + let (argmin_simd_index, argmax_simd_index) = argminmax_i16(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); } @@ -131,7 +116,7 @@ mod tests { assert_eq!(argmin_index, 1); assert_eq!(argmax_index, 6); - let (argmin_simd_index, argmax_simd_index) = argminmax_i16(data.view()).unwrap(); + let (argmin_simd_index, argmax_simd_index) = argminmax_i16(data.view()); assert_eq!(argmin_simd_index, 1); assert_eq!(argmax_simd_index, 6); } @@ -141,7 +126,7 @@ mod tests { for _ in 0..10_000 { let data = get_array_i16(32 * 2 + 1); let (argmin_index, argmax_index) = simple_argminmax(data.view()); - let (argmin_simd_index, argmax_simd_index) = argminmax_i16(data.view()).unwrap(); + let (argmin_simd_index, argmax_simd_index) = argminmax_i16(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); } diff --git a/src/simd/simd_i32.rs b/src/simd/simd_i32.rs index 15563ba..72fb7d0 100644 --- a/src/simd/simd_i32.rs +++ b/src/simd/simd_i32.rs @@ -1,33 +1,14 @@ -use crate::generic::{max_index_value, min_index_value, simple_argminmax}; -use crate::task::{find_final_index_minmax, split_array}; +use crate::generic::{max_index_value, min_index_value}; +use crate::task::argminmax_generic; use ndarray::ArrayView1; use std::arch::x86_64::*; +const LANE_SIZE: usize = 8; + // ------------------------------------ ARGMINMAX -------------------------------------- -pub fn argminmax_i32(arr: ArrayView1) -> Option<(usize, usize)> { - match split_array(arr, 8) { - (Some(rem), Some(sim)) => { - let (rem_min_index, rem_max_index) = simple_argminmax(rem); - let rem_result = ( - rem[rem_min_index], - rem_min_index, - rem[rem_max_index], - rem_max_index, - ); - let sim_result = unsafe { core_argminmax_256(sim, rem.len()) }; - find_final_index_minmax(rem_result, sim_result) - } - (Some(rem), None) => { - let (rem_min_index, rem_max_index) = simple_argminmax(rem); - Some((rem_min_index, rem_max_index)) - } - (None, Some(sim)) => { - let sim_result = unsafe { core_argminmax_256(sim, 0) }; - Some((sim_result.1, sim_result.3)) - } - (None, None) => None, - } +pub fn argminmax_i32(arr: ArrayView1) -> (usize, usize) { + argminmax_generic(arr, LANE_SIZE, core_argminmax_256) } #[inline] @@ -35,6 +16,7 @@ fn reg_to_i32_arr(reg: __m256i) -> [i32; 8] { unsafe { std::mem::transmute::<__m256i, [i32; 8]>(reg) } } +#[inline] #[target_feature(enable = "avx2")] unsafe fn core_argminmax_256(sim_arr: ArrayView1, offset: usize) -> (i32, usize, i32, usize) { // Efficient calculation of argmin and argmax together @@ -87,7 +69,10 @@ unsafe fn core_argminmax_256(sim_arr: ArrayView1, offset: usize) -> (i32, u #[cfg(test)] mod tests { - use super::{argminmax_i32, simple_argminmax}; + use super::argminmax_i32; + use crate::generic; + use generic::simple_argminmax; + use ndarray::Array1; extern crate dev_utils; @@ -103,7 +88,7 @@ mod tests { assert_eq!(data.len() % 8, 1); let (argmin_index, argmax_index) = simple_argminmax(data.view()); - let (argmin_simd_index, argmax_simd_index) = argminmax_i32(data.view()).unwrap(); + let (argmin_simd_index, argmax_simd_index) = argminmax_i32(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); } @@ -127,7 +112,7 @@ mod tests { assert_eq!(argmin_index, 0); assert_eq!(argmax_index, 5); - let (argmin_simd_index, argmax_simd_index) = argminmax_i32(data.view()).unwrap(); + let (argmin_simd_index, argmax_simd_index) = argminmax_i32(data.view()); assert_eq!(argmin_simd_index, 0); assert_eq!(argmax_simd_index, 5); } @@ -137,7 +122,7 @@ mod tests { for _ in 0..10_000 { let data = get_array_i32(32 * 8 + 1); let (argmin_index, argmax_index) = simple_argminmax(data.view()); - let (argmin_simd_index, argmax_simd_index) = argminmax_i32(data.view()).unwrap(); + let (argmin_simd_index, argmax_simd_index) = argminmax_i32(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); } diff --git a/src/simd/simd_i64.rs b/src/simd/simd_i64.rs index a4d6115..37925e3 100644 --- a/src/simd/simd_i64.rs +++ b/src/simd/simd_i64.rs @@ -1,33 +1,14 @@ -use crate::generic::{max_index_value, min_index_value, simple_argminmax}; -use crate::task::{find_final_index_minmax, split_array}; +use crate::generic::{max_index_value, min_index_value}; +use crate::task::argminmax_generic; use ndarray::ArrayView1; use std::arch::x86_64::*; +const LANE_SIZE: usize = 4; + // ------------------------------------ ARGMINMAX -------------------------------------- -pub fn argminmax_i64(arr: ArrayView1) -> Option<(usize, usize)> { - match split_array(arr, 4) { - (Some(rem), Some(sim)) => { - let (rem_min_index, rem_max_index) = simple_argminmax(rem); - let rem_result = ( - rem[rem_min_index], - rem_min_index, - rem[rem_max_index], - rem_max_index, - ); - let sim_result = unsafe { core_argminmax_256(sim, rem.len()) }; - find_final_index_minmax(rem_result, sim_result) - } - (Some(rem), None) => { - let (rem_min_index, rem_max_index) = simple_argminmax(rem); - Some((rem_min_index, rem_max_index)) - } - (None, Some(sim)) => { - let sim_result = unsafe { core_argminmax_256(sim, 0) }; - Some((sim_result.1, sim_result.3)) - } - (None, None) => None, - } +pub fn argminmax_i64(arr: ArrayView1) -> (usize, usize) { + argminmax_generic(arr, LANE_SIZE, core_argminmax_256) } #[inline] @@ -35,6 +16,7 @@ fn reg_to_i64_arr(reg: __m256i) -> [i64; 4] { unsafe { std::mem::transmute::<__m256i, [i64; 4]>(reg) } } +#[inline] #[target_feature(enable = "avx2")] unsafe fn core_argminmax_256(sim_arr: ArrayView1, offset: usize) -> (i64, usize, i64, usize) { // Efficient calculation of argmin and argmax together @@ -107,7 +89,10 @@ unsafe fn core_argminmax_256(sim_arr: ArrayView1, offset: usize) -> (i64, u #[cfg(test)] mod tests { - use super::{argminmax_i64, simple_argminmax}; + use super::argminmax_i64; + use crate::generic; + use generic::simple_argminmax; + use ndarray::Array1; extern crate dev_utils; @@ -123,7 +108,7 @@ mod tests { assert_eq!(data.len() % 4, 1); let (argmin_index, argmax_index) = simple_argminmax(data.view()); - let (argmin_simd_index, argmax_simd_index) = argminmax_i64(data.view()).unwrap(); + let (argmin_simd_index, argmax_simd_index) = argminmax_i64(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); } @@ -147,7 +132,7 @@ mod tests { assert_eq!(argmin_index, 0); assert_eq!(argmax_index, 5); - let (argmin_simd_index, argmax_simd_index) = argminmax_i64(data.view()).unwrap(); + let (argmin_simd_index, argmax_simd_index) = argminmax_i64(data.view()); assert_eq!(argmin_simd_index, 0); assert_eq!(argmax_simd_index, 5); } @@ -157,7 +142,7 @@ mod tests { for _ in 0..10_000 { let data = get_array_i64(32 * 8 + 1); let (argmin_index, argmax_index) = simple_argminmax(data.view()); - let (argmin_simd_index, argmax_simd_index) = argminmax_i64(data.view()).unwrap(); + let (argmin_simd_index, argmax_simd_index) = argminmax_i64(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); } diff --git a/src/simd/simd_i8.rs b/src/simd/simd_i8.rs index c2d2f76..8de252f 100644 --- a/src/simd/simd_i8.rs +++ b/src/simd/simd_i8.rs @@ -1,33 +1,14 @@ -use crate::generic::{max_index_value, min_index_value, simple_argminmax}; -use crate::task::{find_final_index_minmax, split_array}; +use crate::generic::{max_index_value, min_index_value}; +use crate::task::argminmax_generic; use ndarray::ArrayView1; use std::arch::x86_64::*; +const LANE_SIZE: usize = 32; + // ------------------------------------ ARGMINMAX -------------------------------------- -pub fn argminmax_i8(arr: ArrayView1) -> Option<(usize, usize)> { - match split_array(arr, 32) { - (Some(rem), Some(sim)) => { - let (rem_min_index, rem_max_index) = simple_argminmax(rem); - let rem_result = ( - rem[rem_min_index], - rem_min_index, - rem[rem_max_index], - rem_max_index, - ); - let sim_result = unsafe { core_argminmax_256(sim, rem.len()) }; - find_final_index_minmax(rem_result, sim_result) - } - (Some(rem), None) => { - let (rem_min_index, rem_max_index) = simple_argminmax(rem); - Some((rem_min_index, rem_max_index)) - } - (None, Some(sim)) => { - let sim_result = unsafe { core_argminmax_256(sim, 0) }; - Some((sim_result.1, sim_result.3)) - } - (None, None) => None, - } +pub fn argminmax_i8(arr: ArrayView1) -> (usize, usize) { + argminmax_generic(arr, LANE_SIZE, core_argminmax_256) } #[inline] @@ -35,6 +16,7 @@ fn reg_to_i8_arr(reg: __m256i) -> [i8; 32] { unsafe { std::mem::transmute::<__m256i, [i8; 32]>(reg) } } +#[inline] #[target_feature(enable = "avx2")] unsafe fn core_argminmax_256(sim_arr: ArrayView1, offset: usize) -> (i8, usize, i8, usize) { // Efficient calculation of argmin and argmax together @@ -106,7 +88,10 @@ unsafe fn core_argminmax_256(sim_arr: ArrayView1, offset: usize) -> (i8, usi #[cfg(test)] mod tests { - use super::{argminmax_i8, simple_argminmax}; + use super::argminmax_i8; + use crate::generic; + use generic::simple_argminmax; + use ndarray::Array1; extern crate dev_utils; @@ -122,7 +107,7 @@ mod tests { assert_eq!(data.len() % 32, 1); let (argmin_index, argmax_index) = simple_argminmax(data.view()); - let (argmin_simd_index, argmax_simd_index) = argminmax_i8(data.view()).unwrap(); + let (argmin_simd_index, argmax_simd_index) = argminmax_i8(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); } @@ -137,7 +122,7 @@ mod tests { assert_eq!(argmin_index, 1); assert_eq!(argmax_index, 6); - let (argmin_simd_index, argmax_simd_index) = argminmax_i8(data.view()).unwrap(); + let (argmin_simd_index, argmax_simd_index) = argminmax_i8(data.view()); assert_eq!(argmin_simd_index, 1); assert_eq!(argmax_simd_index, 6); } @@ -147,7 +132,7 @@ mod tests { for _ in 0..10_000 { let data = get_array_i8(32 * 2 + 1); let (argmin_index, argmax_index) = simple_argminmax(data.view()); - let (argmin_simd_index, argmax_simd_index) = argminmax_i8(data.view()).unwrap(); + let (argmin_simd_index, argmax_simd_index) = argminmax_i8(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); } diff --git a/src/simd/simd_u32.rs b/src/simd/simd_u32.rs index 29d672f..34a3c0f 100644 --- a/src/simd/simd_u32.rs +++ b/src/simd/simd_u32.rs @@ -1,33 +1,14 @@ -use crate::generic::{max_index_value, min_index_value, simple_argminmax}; -use crate::task::{find_final_index_minmax, split_array}; +use crate::generic::{max_index_value, min_index_value}; +use crate::task::argminmax_generic; use ndarray::ArrayView1; use std::arch::x86_64::*; +const LANE_SIZE: usize = 8; + // ------------------------------------ ARGMINMAX -------------------------------------- pub fn argminmax_u16(arr: ArrayView1) -> Option<(usize, usize)> { - match split_array(arr, 8) { - (Some(rem), Some(sim)) => { - let (rem_min_index, rem_max_index) = simple_argminmax(rem); - let rem_result = ( - rem[rem_min_index], - rem_min_index, - rem[rem_max_index], - rem_max_index, - ); - let sim_result = unsafe { core_argminmax_256(sim, rem.len()) }; - find_final_index_minmax(rem_result, sim_result) - } - (Some(rem), None) => { - let (rem_min_index, rem_max_index) = simple_argminmax(rem); - Some((rem_min_index, rem_max_index)) - } - (None, Some(sim)) => { - let sim_result = unsafe { core_argminmax_256(sim, 0) }; - Some((sim_result.1, sim_result.3)) - } - (None, None) => None, - } + argminmax_generic(arr, LANE_SIZE, core_argminmax_256) } #[inline] @@ -35,6 +16,7 @@ fn reg_to_u16_arr(reg: __m256i) -> [u16; 8] { unsafe { std::mem::transmute::<__m256i, [u16; 8]>(reg) } } +#[inline] #[target_feature(enable = "avx2")] unsafe fn core_argminmax_256(sim_arr: ArrayView1, offset: usize) -> (u16, usize, u16, usize) { // Efficient calculation of argmin and argmax together @@ -87,10 +69,14 @@ unsafe fn core_argminmax_256(sim_arr: ArrayView1, offset: usize) -> (u16, u #[cfg(test)] mod tests { - use super::{argminmax_u16, simple_argminmax}; + use super::argminmax_u32; + use crate::generic; + use generic::simple_argminmax; + use ndarray::Array1; - use rand::{thread_rng, Rng}; - use rand_distr::Uniform; + + extern crate dev_utils; + use dev_utils::utils; // TODO: duplicate code in bench config fn get_array_u16(n: usize) -> Array1 { @@ -106,7 +92,7 @@ mod tests { assert_eq!(data.len() % 8, 1); let (argmin_index, argmax_index) = simple_argminmax(data.view()); - let (argmin_simd_index, argmax_simd_index) = argminmax_u16(data.view()).unwrap(); + let (argmin_simd_index, argmax_simd_index) = argminmax_u16(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); } @@ -127,7 +113,7 @@ mod tests { assert_eq!(argmin_index, 3); assert_eq!(argmax_index, 1); - let (argmin_simd_index, argmax_simd_index) = argminmax_u16(data.view()).unwrap(); + let (argmin_simd_index, argmax_simd_index) = argminmax_u16(data.view()); assert_eq!(argmin_simd_index, 3); assert_eq!(argmax_simd_index, 1); } @@ -137,7 +123,7 @@ mod tests { for _ in 0..10_000 { let data = get_array_u16(32 * 8 + 1); let (argmin_index, argmax_index) = simple_argminmax(data.view()); - let (argmin_simd_index, argmax_simd_index) = argminmax_u16(data.view()).unwrap(); + let (argmin_simd_index, argmax_simd_index) = argminmax_u16(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); } diff --git a/src/task.rs b/src/task.rs index 825c214..1e07f55 100644 --- a/src/task.rs +++ b/src/task.rs @@ -1,8 +1,42 @@ +use crate::generic; +use generic::simple_argminmax; + use ndarray::{ArrayView1, Axis}; use std::cmp::Ordering; #[inline] -pub(crate) fn split_array( +pub(crate) fn argminmax_generic( + arr: ArrayView1, + lane_size: usize, + core_argminmax: unsafe fn(ArrayView1, usize) -> (T, usize, T, usize), +) -> (usize, usize) { + assert!(!arr.is_empty()); // split_array should never return (None, None) + match split_array(arr, lane_size) { + (Some(rem), Some(sim)) => { + let (rem_min_index, rem_max_index) = simple_argminmax(rem); + let rem_result = ( + rem[rem_min_index], + rem_min_index, + rem[rem_max_index], + rem_max_index, + ); + let sim_result = unsafe { core_argminmax(sim, rem.len()) }; + find_final_index_minmax(rem_result, sim_result) + } + (Some(rem), None) => { + let (rem_min_index, rem_max_index) = simple_argminmax(rem); + (rem_min_index, rem_max_index) + } + (None, Some(sim)) => { + let sim_result = unsafe { core_argminmax(sim, 0) }; + (sim_result.1, sim_result.3) + } + (None, None) => panic!("Array is empty"), // Should never occur because of assert + } +} + +#[inline] +fn split_array( arr: ArrayView1, lane_size: usize, ) -> (Option>, Option>) { @@ -23,10 +57,10 @@ pub(crate) fn split_array( } #[inline] -pub fn find_final_index_minmax( +fn find_final_index_minmax( remainder_result: (T, usize, T, usize), simd_result: (T, usize, T, usize), -) -> Option<(usize, usize)> { +) -> (usize, usize) { let min_result = match remainder_result.0.partial_cmp(&simd_result.0).unwrap() { Ordering::Less => remainder_result.1, Ordering::Equal => std::cmp::min(remainder_result.1, simd_result.1), @@ -39,5 +73,5 @@ pub fn find_final_index_minmax( Ordering::Greater => simd_result.3, }; - Some((min_result, max_result)) + (min_result, max_result) } From 6e1a3f699fc6b24de4c8a383136514d30a4ece83 Mon Sep 17 00:00:00 2001 From: Jeroen Van Der Donckt Date: Sat, 5 Nov 2022 10:14:16 +0100 Subject: [PATCH 2/6] :fire: support f16 efficiently --- Cargo.toml | 9 ++- benches/bench_f16.rs | 83 ++++++++++++++++++++++ src/generic.rs | 112 ++++++++++++++--------------- src/lib.rs | 86 ++++++++++------------- src/scalar_f16.rs | 36 ++++++++++ src/simd/mod.rs | 2 + src/simd/simd_f16.rs | 163 +++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 382 insertions(+), 109 deletions(-) create mode 100644 benches/bench_f16.rs create mode 100644 src/scalar_f16.rs create mode 100644 src/simd/simd_f16.rs diff --git a/Cargo.toml b/Cargo.toml index 070d9c5..b336541 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,13 +12,18 @@ categories = ["algorithms", "mathematics", "science"] [dependencies] -ndarray = "0.15.6" +ndarray = {version = "0.15.6", default-features = false} +half = {version = "2.1.0", default-features = false, optional = true} [dev-dependencies] criterion = "0.3.0" dev_utils = { path = "dev_utils" } +[[bench]] +name = "bench_f16" +harness = false + [[bench]] name = "bench_f32" harness = false @@ -37,4 +42,4 @@ harness = false [[bench]] name = "bench_i64" -harness = false +harness = false \ No newline at end of file diff --git a/benches/bench_f16.rs b/benches/bench_f16.rs new file mode 100644 index 0000000..a663acb --- /dev/null +++ b/benches/bench_f16.rs @@ -0,0 +1,83 @@ +#[macro_use] +extern crate criterion; +extern crate dev_utils; + +use argminmax::ArgMinMax; +use criterion::{black_box, Criterion}; +use dev_utils::{config, utils}; + +#[cfg(feature = "half")] +use half::f16; +use ndarray::Array1; + +#[cfg(feature = "half")] +fn get_random_f16_array(n: usize) -> Array1 { + let data = utils::get_random_array::(n, u16::MIN, u16::MAX); + let data: Vec = data.iter().map(|&x| f16::from_bits(x)).collect(); + // Replace NaNs and Infs with 0 + let data: Vec = data + .iter() + .map(|&x| if x.is_nan() || x.is_infinite() { f16::from_bits(0) } else { x }) + .collect(); + let arr:Array1 = Array1::from(data); + arr +} + +#[cfg(feature = "half")] +fn minmax_f16_random_array_long(c: &mut Criterion) { + let n = config::ARRAY_LENGTH_LONG; + let data = get_random_f16_array(n); + c.bench_function("simp_random_long_f16", |b| { + b.iter(|| argminmax::simple_argminmax_f16(black_box(data.view()))) + }); + c.bench_function("simd_random_long_f16", |b| { + b.iter(|| black_box(data.view().argminmax())) + }); +} + +#[cfg(feature = "half")] +fn minmax_f16_random_array_short(c: &mut Criterion) { + let n = config::ARRAY_LENGTH_SHORT; + let data = get_random_f16_array(n); + c.bench_function("simple_random_short_f16", |b| { + b.iter(|| argminmax::simple_argminmax_f16(black_box(data.view()))) + }); + c.bench_function("simd_random_short_f16", |b| { + b.iter(|| black_box(data.view().argminmax())) + }); +} + +#[cfg(feature = "half")] +fn minmax_f16_worst_case_array_long(c: &mut Criterion) { + let n = config::ARRAY_LENGTH_LONG; + let data = utils::get_worst_case_array::(n, f16::from_f32(1.)); + c.bench_function("simple_worst_long_f16", |b| { + b.iter(|| argminmax::simple_argminmax_f16(black_box(data.view()))) + }); + c.bench_function("simd_worst_long_f16", |b| { + b.iter(|| black_box(data.view().argminmax())) + }); +} + +#[cfg(feature = "half")] +fn minmax_f16_worst_case_array_short(c: &mut Criterion) { + let n = config::ARRAY_LENGTH_SHORT; + let data = utils::get_worst_case_array::(n, f16::from_f32(1.)); + c.bench_function("simple_worst_short_f16", |b| { + b.iter(|| argminmax:::simple_argminmax_f16(black_box(data.view()))) + }); + c.bench_function("simd_worst_short_f16", |b| { + b.iter(|| black_box(data.view().argminmax())) + }); +} + +#[cfg(feature = "half")] +criterion_group!( + benches, + minmax_f16_random_array_long, + minmax_f16_random_array_short, + minmax_f16_worst_case_array_long, + minmax_f16_worst_case_array_short +); +#[cfg(feature = "half")] +criterion_main!(benches); diff --git a/src/generic.rs b/src/generic.rs index 882358b..35c0955 100644 --- a/src/generic.rs +++ b/src/generic.rs @@ -2,68 +2,68 @@ use ndarray::ArrayView1; // ------ On ArrayView1 -// #[inline] -// pub(crate) fn simple_argmin(arr: ArrayView1) -> usize { -// let mut low_index = 0usize; -// let mut low = arr[low_index]; -// for (i, item) in arr.iter().enumerate() { -// if *item < low { -// low = *item; -// low_index = i; -// } -// } -// low_index -// } - -// #[inline] -// pub(crate) fn simple_argmax(arr: ArrayView1) -> usize { -// let mut high_index = 0usize; -// let mut high = arr[high_index]; -// for (i, item) in arr.iter().enumerate() { -// if *item > high { -// high = *item; -// high_index = i; -// } -// } -// high_index -// } +#[inline] +pub fn simple_argmin(arr: ArrayView1) -> usize { + let mut low_index = 0usize; + let mut low = arr[low_index]; + for (i, item) in arr.iter().enumerate() { + if *item < low { + low = *item; + low_index = i; + } + } + low_index +} -// #[inline] -// pub fn simple_argminmax(arr: ArrayView1) -> (usize, usize) { -// let mut low_index: usize = 0; -// let mut high_index: usize = 0; -// let mut low = arr[low_index]; -// let mut high = arr[high_index]; -// for (i, item) in arr.iter().enumerate() { -// if *item < low { -// low = *item; -// low_index = i; -// } else if *item > high { -// high = *item; -// high_index = i; -// } -// } -// (low_index, high_index) -// } +#[inline] +pub fn simple_argmax(arr: ArrayView1) -> usize { + let mut high_index = 0usize; + let mut high = arr[high_index]; + for (i, item) in arr.iter().enumerate() { + if *item > high { + high = *item; + high_index = i; + } + } + high_index +} -// Note: 5-7% faster than the above implementation #[inline] pub fn simple_argminmax(arr: ArrayView1) -> (usize, usize) { - let minmax_tuple: (usize, T, usize, T) = arr.iter().enumerate().fold( - (0usize, arr[0], 0usize, arr[0]), - |(min_idx, min, max_idx, max), (idx, item)| { - if *item < min { - (idx, *item, max_idx, max) - } else if *item > max { - (min_idx, min, idx, *item) - } else { - (min_idx, min, max_idx, max) - } - }, - ); - (minmax_tuple.0, minmax_tuple.2) + let mut low_index: usize = 0; + let mut high_index: usize = 0; + let mut low = arr[low_index]; + let mut high = arr[high_index]; + for (i, item) in arr.iter().enumerate() { + if *item < low { + low = *item; + low_index = i; + } else if *item > high { + high = *item; + high_index = i; + } + } + (low_index, high_index) } +// Note: 5-7% faster than the above implementation (for floats) +// #[inline] +// pub fn simple_argminmax_fold(arr: ArrayView1) -> (usize, usize) { +// let minmax_tuple: (usize, T, usize, T) = arr.iter().enumerate().fold( +// (0usize, arr[0], 0usize, arr[0]), +// |(min_idx, min, max_idx, max), (idx, item)| { +// if *item < min { +// (idx, *item, max_idx, max) +// } else if *item > max { +// (min_idx, min, idx, *item) +// } else { +// (min_idx, min, max_idx, max) +// } +// }, +// ); +// (minmax_tuple.0, minmax_tuple.2) +// } + // ------ On &[T] // Note: these two functions are necessary because in the final SIMD registers the diff --git a/src/lib.rs b/src/lib.rs index e5218ef..e5634d7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ pub mod generic; mod simd; mod task; +mod scalar_f16; pub use generic::simple_argminmax; pub use simd::{simd_f32, simd_f64, simd_i16, simd_i32, simd_i64}; @@ -8,59 +9,42 @@ pub use simd::{simd_f32, simd_f64, simd_i16, simd_i32, simd_i64}; use ndarray::ArrayView1; pub trait ArgMinMax { - fn argminmax(self) -> (usize, usize); -} + // TODO: future work implement these other functions + // fn min(self) -> Self::Item; + // fn max(self) -> Self::Item; + // fn minmax(self) -> (T, T); -impl ArgMinMax for ArrayView1<'_, f64> { - fn argminmax(self) -> (usize, usize) { - // #[cfg(not(target_feature = "sse"))] - // return Some(simple_argminmax(self)); - // #[cfg(target_feature = "sse")] - return simd_f64::argminmax_f64(self); - } -} - -impl ArgMinMax for ArrayView1<'_, i64> { - fn argminmax(self) -> (usize, usize) { - // #[cfg(not(target_feature = "sse"))] - // return Some(simple_argminmax(self)); - // #[cfg(target_feature = "sse")] - return simd_i64::argminmax_i64(self); - } -} - -impl ArgMinMax for ArrayView1<'_, f32> { - fn argminmax(self) -> (usize, usize) { - // #[cfg(not(target_feature = "sse"))] - // return Some(simple_argminmax(self)); - // #[cfg(target_feature = "sse")] - return simd_f32::argminmax_f32(self); - } -} - -impl ArgMinMax for ArrayView1<'_, i32> { - fn argminmax(self) -> (usize, usize) { - // #[cfg(not(target_feature = "sse"))] - // return Some(simple_argminmax(self)); - // #[cfg(target_feature = "sse")] - return simd_i32::argminmax_i32(self); - } + // fn argmin(self) -> usize; + // fn argmax(self) -> usize; + fn argminmax(self) -> (usize, usize); } -impl ArgMinMax for ArrayView1<'_, i16> { - fn argminmax(self) -> (usize, usize) { - // #[cfg(not(target_feature = "sse"))] - // return Some(simple_argminmax(self)); - // #[cfg(target_feature = "sse")] - return simd_i16::argminmax_i16(self); - } +macro_rules! impl_argminmax { + ($t:ty, $scalar_func:ident, $simd_mod:ident, $simd_func:ident) => { + impl ArgMinMax for ArrayView1<'_, $t> { + fn argminmax(self) -> (usize, usize) { + // TODO: what to do with cfg target_feature? + #[cfg(not(target_feature = "sse"))] + return $scalar_func(self); + #[cfg(target_feature = "sse")] + return $simd_mod::$simd_func(self); + } + } + }; } -// impl ArgMinMax for ArrayView1<'_, i8> { -// fn argminmax(self) -> Option<(usize, usize)> { -// #[cfg(not(target_feature = "sse"))] -// return Some(simple_argminmax(self)); -// #[cfg(target_feature = "sse")] -// return simd_i8::argminmax_i8(self); -// } -// } +// Implement ArgMinMax for the rust primitive types +impl_argminmax!(f32, simple_argminmax, simd_f32, argminmax_f32); +impl_argminmax!(f64, simple_argminmax, simd_f64, argminmax_f64); +impl_argminmax!(i16, simple_argminmax, simd_i16, argminmax_i16); +impl_argminmax!(i32, simple_argminmax, simd_i32, argminmax_i32); +impl_argminmax!(i64, simple_argminmax, simd_i64, argminmax_i64); + +#[cfg(feature = "half")] +use half::f16; +#[cfg(feature = "half")] +pub use scalar_f16::simple_argminmax_f16; +#[cfg(feature = "half")] +pub use simd::simd_f16; +#[cfg(feature = "half")] +impl_argminmax!(f16, simple_argminmax_f16, simd_f16, argminmax_f16); diff --git a/src/scalar_f16.rs b/src/scalar_f16.rs new file mode 100644 index 0000000..d986351 --- /dev/null +++ b/src/scalar_f16.rs @@ -0,0 +1,36 @@ +use ndarray::ArrayView1; +#[cfg(feature = "half")] +use half::f16; + +// ------ On ArrayView1 + +#[cfg(feature = "half")] +#[inline] +fn f16_to_i16ord(x: f16) -> i16 { + let x = unsafe { std::mem::transmute::(x) }; + ((x >> 15) & 0x7FFF) ^ x +} + +#[cfg(feature = "half")] +#[inline] +pub fn simple_argminmax_f16(arr: ArrayView1) -> (usize, usize) { + // f16 is transformed to i16ord + // benchmarks show: + // 1. this is 7-10x faster than using raw f16 + // 2. this is 3x faster than transforming to f32 or f64 + let mut low_index: usize = 0; + let mut high_index: usize = 0; + let mut low = f16_to_i16ord(arr[low_index]); + let mut high = f16_to_i16ord(arr[high_index]); + for (i, item) in arr.iter().enumerate() { + let item = f16_to_i16ord(*item); + if item < low { + low = item; + low_index = i; + } else if item > high { + high = item; + high_index = i; + } + } + (low_index, high_index) +} diff --git a/src/simd/mod.rs b/src/simd/mod.rs index 9508e4f..ca3e2d1 100644 --- a/src/simd/mod.rs +++ b/src/simd/mod.rs @@ -3,6 +3,8 @@ pub mod simd_f64; pub use simd_f64::*; pub mod simd_f32; pub use simd_f32::*; +pub mod simd_f16; +pub use simd_f16::*; // SIGNED INT pub mod simd_i64; pub use simd_i64::*; diff --git a/src/simd/simd_f16.rs b/src/simd/simd_f16.rs new file mode 100644 index 0000000..26bc739 --- /dev/null +++ b/src/simd/simd_f16.rs @@ -0,0 +1,163 @@ +use crate::generic::{max_index_value, min_index_value}; +use crate::task::argminmax_generic; +use ndarray::ArrayView1; +use std::arch::x86_64::*; + +#[cfg(feature = "half")] +use half::f16; + +const LANE_SIZE: usize = 16; + +// ------------------------------------ ARGMINMAX -------------------------------------- + +#[cfg(feature = "half")] +pub fn argminmax_f16(arr: ArrayView1) -> (usize, usize) { + argminmax_generic(arr, LANE_SIZE, core_argminmax_256) +} + +#[inline] +fn reg_to_i16_arr(reg: __m256i) -> [i16; 16] { + unsafe { std::mem::transmute::<__m256i, [i16; 16]>(reg) } +} + +#[inline] +#[target_feature(enable = "avx2")] +unsafe fn f16_as_m256i_to_ord_i16(f16_as_m256i: __m256i) -> __m256i { + // on a scalar: ((v >> 15) & 0x7FFF) ^ v + let sign_bit_shifted = _mm256_srai_epi16(f16_as_m256i, 15); + let sign_bit_masked = _mm256_and_si256(sign_bit_shifted, _mm256_set1_epi16(0x7FFF)); + _mm256_xor_si256(sign_bit_masked, f16_as_m256i) +} + +#[cfg(feature = "half")] +#[inline] +fn ord_i16_to_f16(ord_i16: i16) -> f16 { + let v = ((ord_i16 >> 15) & 0x7FFF) ^ ord_i16; + unsafe { std::mem::transmute::(v) } +} + +#[cfg(feature = "half")] +#[inline] +#[target_feature(enable = "avx2")] +unsafe fn core_argminmax_256(sim_arr: ArrayView1, offset: usize) -> (f16, usize, f16, usize) { + // Efficient calculation of argmin and argmax together + let offset = _mm256_set1_epi16(offset as i16); + let mut new_index = _mm256_add_epi16( + _mm256_set_epi16(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), + offset, + ); + let mut index_low = new_index; + let mut index_high = new_index; + + let increment = _mm256_set1_epi16(16); + + // println!("raw new values: {:?}", sim_arr.slice(s![0..16])); + let new_values = _mm256_loadu_si256(sim_arr.as_ptr() as *const __m256i); + // println!("new_values: {:?}", reg_to_i16_arr(new_values)); + let new_values = f16_as_m256i_to_ord_i16(new_values); + // println!("new_values: {:?}", reg_to_i16_arr(new_values)); + // println!(); + let mut values_low = new_values; + let mut values_high = new_values; + + sim_arr + .exact_chunks(16) + .into_iter() + .skip(1) + .for_each(|step| { + new_index = _mm256_add_epi16(new_index, increment); + + let new_values = _mm256_loadu_si256(step.as_ptr() as *const __m256i); + let new_values = f16_as_m256i_to_ord_i16(new_values); + let gt_mask = _mm256_cmpgt_epi16(new_values, values_high); + // Below does not work (bc instruction is not available) + // let lt_mask = _mm256_cmplt_epi16(new_values, values_low); + // Solution: swap parameters and use gt instead + let lt_mask = _mm256_cmpgt_epi16(values_low, new_values); + + index_low = _mm256_blendv_epi8(index_low, new_index, lt_mask); + index_high = _mm256_blendv_epi8(index_high, new_index, gt_mask); + + values_low = _mm256_blendv_epi8(values_low, new_values, lt_mask); + values_high = _mm256_blendv_epi8(values_high, new_values, gt_mask); + }); + + // Select max_index and max_value + let value_array = reg_to_i16_arr(values_high); + let index_array = reg_to_i16_arr(index_high); + let (index_max, value_max) = max_index_value(&index_array, &value_array); + + // Select min_index and min_value + let value_array = reg_to_i16_arr(values_low); + let index_array = reg_to_i16_arr(index_low); + let (index_min, value_min) = min_index_value(&index_array, &value_array); + + (ord_i16_to_f16(value_min), index_min as usize, ord_i16_to_f16(value_max), index_max as usize) +} + +//----- TESTS ----- + +#[cfg(feature = "half")] +#[cfg(test)] +mod tests { + use super::argminmax_f16; + use crate::generic; + use generic::simple_argminmax; + + use half::f16; + use ndarray::Array1; + + extern crate dev_utils; + use dev_utils::utils; + + fn get_array_f16(n: usize) -> Array1 { + let arr = utils::get_random_array(n, i16::MIN, i16::MAX); + let arr = arr.mapv(|x| f16::from_f32(x as f32)); + Array1::from(arr) + } + + #[test] + fn test_both_versions_return_the_same_results() { + let data = get_array_f16(1025); + assert_eq!(data.len() % 8, 1); + + let (argmin_index, argmax_index) = simple_argminmax(data.view()); + let (argmin_simd_index, argmax_simd_index) = argminmax_f16(data.view()); + assert_eq!(argmin_index, argmin_simd_index); + assert_eq!(argmax_index, argmax_simd_index); + } + + #[test] + fn test_first_index_is_returned_when_identical_values_found() { + let data = [ + f16::from_f32(10.), + f16::MAX, + f16::from_f32(6.), + f16::NEG_INFINITY, + f16::NEG_INFINITY, + f16::MAX, + f16::from_f32(5_000.0), + ]; + let data: Vec = data.iter().map(|x| *x).collect(); + let data = Array1::from(data); + + let (argmin_index, argmax_index) = simple_argminmax(data.view()); + assert_eq!(argmin_index, 3); + assert_eq!(argmax_index, 1); + + let (argmin_simd_index, argmax_simd_index) = argminmax_f16(data.view()); + assert_eq!(argmin_simd_index, 3); + assert_eq!(argmax_simd_index, 1); + } + + #[test] + fn test_many_random_runs() { + for _ in 0..10_000 { + let data = get_array_f16(32 * 8 + 1); + let (argmin_index, argmax_index) = simple_argminmax(data.view()); + let (argmin_simd_index, argmax_simd_index) = argminmax_f16(data.view()); + assert_eq!(argmin_index, argmin_simd_index); + assert_eq!(argmax_index, argmax_simd_index); + } + } +} From 6b6e14d203f053287660bb4a4bed3185dbbd2ed0 Mon Sep 17 00:00:00 2001 From: Jeroen Van Der Donckt Date: Sat, 5 Nov 2022 11:06:49 +0100 Subject: [PATCH 3/6] :broom: rename scalar --- benches/bench_f32.rs | 8 +++--- benches/bench_f64.rs | 8 +++--- benches/bench_i16.rs | 8 +++--- benches/bench_i32.rs | 8 +++--- benches/bench_i64.rs | 8 +++--- src/scalar_generic.rs | 65 +++++++++++++++++++++++++++++++++++++++++++ src/simd/simd_f32.rs | 11 ++++---- src/simd/simd_f64.rs | 11 ++++---- src/simd/simd_i16.rs | 11 ++++---- src/simd/simd_i32.rs | 11 ++++---- src/simd/simd_i64.rs | 11 ++++---- src/simd/simd_i8.rs | 11 ++++---- src/task.rs | 7 ++--- src/utils.rs | 38 +++++++++++++++++++++++++ 14 files changed, 156 insertions(+), 60 deletions(-) create mode 100644 src/scalar_generic.rs create mode 100644 src/utils.rs diff --git a/benches/bench_f32.rs b/benches/bench_f32.rs index 489cff9..eb4c14c 100644 --- a/benches/bench_f32.rs +++ b/benches/bench_f32.rs @@ -10,7 +10,7 @@ fn minmax_f32_random_array_long(c: &mut Criterion) { let n = config::ARRAY_LENGTH_LONG; let data = utils::get_random_array::(n, f32::MIN, f32::MAX); c.bench_function("simple_random_long_f32", |b| { - b.iter(|| argminmax::generic::simple_argminmax(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax(black_box(data.view()))) }); c.bench_function("simd_random_long_f32", |b| { b.iter(|| black_box(data.view().argminmax())) @@ -21,7 +21,7 @@ fn minmax_f32_random_array_short(c: &mut Criterion) { let n = config::ARRAY_LENGTH_SHORT; let data = utils::get_random_array::(n, f32::MIN, f32::MAX); c.bench_function("simple_random_short_f32", |b| { - b.iter(|| argminmax::generic::simple_argminmax(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax(black_box(data.view()))) }); c.bench_function("simd_random_short_f32", |b| { b.iter(|| black_box(data.view().argminmax())) @@ -32,7 +32,7 @@ fn minmax_f32_worst_case_array_long(c: &mut Criterion) { let n = config::ARRAY_LENGTH_LONG; let data = utils::get_worst_case_array::(n, 1.0); c.bench_function("simple_worst_long_f32", |b| { - b.iter(|| argminmax::generic::simple_argminmax(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax(black_box(data.view()))) }); c.bench_function("simd_worst_long_f32", |b| { b.iter(|| black_box(data.view().argminmax())) @@ -43,7 +43,7 @@ fn minmax_f32_worst_case_array_short(c: &mut Criterion) { let n = config::ARRAY_LENGTH_SHORT; let data = utils::get_worst_case_array::(n, 1.0); c.bench_function("simple_worst_short_f32", |b| { - b.iter(|| argminmax::generic::simple_argminmax(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax(black_box(data.view()))) }); c.bench_function("simd_worst_short_f32", |b| { b.iter(|| black_box(data.view().argminmax())) diff --git a/benches/bench_f64.rs b/benches/bench_f64.rs index 3dbac99..a00a158 100644 --- a/benches/bench_f64.rs +++ b/benches/bench_f64.rs @@ -10,7 +10,7 @@ fn minmax_f64_random_array_long(c: &mut Criterion) { let n = config::ARRAY_LENGTH_LONG; let data = utils::get_random_array::(n, f64::MIN, f64::MAX); c.bench_function("simple_random_long_f64", |b| { - b.iter(|| argminmax::generic::simple_argminmax(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax(black_box(data.view()))) }); c.bench_function("simd_random_long_f64", |b| { b.iter(|| black_box(data.view().argminmax())) @@ -21,7 +21,7 @@ fn minmax_f64_random_array_short(c: &mut Criterion) { let n = config::ARRAY_LENGTH_SHORT; let data = utils::get_random_array::(n, f64::MIN, f64::MAX); c.bench_function("simple_random_short_f64", |b| { - b.iter(|| argminmax::generic::simple_argminmax(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax(black_box(data.view()))) }); c.bench_function("simd_random_short_f64", |b| { b.iter(|| black_box(data.view().argminmax())) @@ -32,7 +32,7 @@ fn minmax_f64_worst_case_array_long(c: &mut Criterion) { let n = config::ARRAY_LENGTH_LONG; let data = utils::get_worst_case_array::(n, 1.0); c.bench_function("simple_worst_long_f64", |b| { - b.iter(|| argminmax::generic::simple_argminmax(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax(black_box(data.view()))) }); c.bench_function("simd_worst_long_f64", |b| { b.iter(|| black_box(data.view().argminmax())) @@ -43,7 +43,7 @@ fn minmax_f64_worst_case_array_short(c: &mut Criterion) { let n = config::ARRAY_LENGTH_SHORT; let data = utils::get_worst_case_array::(n, 1.0); c.bench_function("simple_worst_short_f64", |b| { - b.iter(|| argminmax::generic::simple_argminmax(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax(black_box(data.view()))) }); c.bench_function("simd_worst_short_f64", |b| { b.iter(|| black_box(data.view().argminmax())) diff --git a/benches/bench_i16.rs b/benches/bench_i16.rs index be9ee8f..6ffd236 100644 --- a/benches/bench_i16.rs +++ b/benches/bench_i16.rs @@ -10,7 +10,7 @@ fn minmax_i16_random_array_long(c: &mut Criterion) { let n = config::ARRAY_LENGTH_LONG; let data = utils::get_random_array::(n, i16::MIN, i16::MAX); c.bench_function("simple_random_long_i16", |b| { - b.iter(|| argminmax::generic::simple_argminmax(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax(black_box(data.view()))) }); c.bench_function("simd_random_long_i16", |b| { b.iter(|| black_box(data.view().argminmax())) @@ -21,7 +21,7 @@ fn minmax_i16_random_array_short(c: &mut Criterion) { let n = config::ARRAY_LENGTH_SHORT; let data = utils::get_random_array::(n, i16::MIN, i16::MAX); c.bench_function("simple_random_short_i16", |b| { - b.iter(|| argminmax::generic::simple_argminmax(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax(black_box(data.view()))) }); c.bench_function("simd_random_short_i16", |b| { b.iter(|| black_box(data.view().argminmax())) @@ -32,7 +32,7 @@ fn minmax_i16_worst_case_array_long(c: &mut Criterion) { let n = config::ARRAY_LENGTH_LONG; let data = utils::get_worst_case_array::(n, 1); c.bench_function("simple_worst_long_i16", |b| { - b.iter(|| argminmax::generic::simple_argminmax(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax(black_box(data.view()))) }); c.bench_function("simd_worst_long_i16", |b| { b.iter(|| black_box(data.view().argminmax())) @@ -43,7 +43,7 @@ fn minmax_i16_worst_case_array_short(c: &mut Criterion) { let n = config::ARRAY_LENGTH_SHORT; let data = utils::get_worst_case_array::(n, 1); c.bench_function("simple_worst_short_i16", |b| { - b.iter(|| argminmax::generic::simple_argminmax(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax(black_box(data.view()))) }); c.bench_function("simd_worst_short_i16", |b| { b.iter(|| black_box(data.view().argminmax())) diff --git a/benches/bench_i32.rs b/benches/bench_i32.rs index f284602..482e332 100644 --- a/benches/bench_i32.rs +++ b/benches/bench_i32.rs @@ -10,7 +10,7 @@ fn minmax_i32_random_array_long(c: &mut Criterion) { let n = config::ARRAY_LENGTH_LONG; let data = utils::get_random_array::(n, i32::MIN, i32::MAX); c.bench_function("simple_random_long_i32", |b| { - b.iter(|| argminmax::generic::simple_argminmax(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax(black_box(data.view()))) }); c.bench_function("simd_random_long_i32", |b| { b.iter(|| black_box(data.view().argminmax())) @@ -21,7 +21,7 @@ fn minmax_i32_random_array_short(c: &mut Criterion) { let n = config::ARRAY_LENGTH_SHORT; let data = utils::get_random_array::(n, i32::MIN, i32::MAX); c.bench_function("simple_random_short_i32", |b| { - b.iter(|| argminmax::generic::simple_argminmax(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax(black_box(data.view()))) }); c.bench_function("simd_random_short_i32", |b| { b.iter(|| black_box(data.view().argminmax())) @@ -32,7 +32,7 @@ fn minmax_i32_worst_case_array_long(c: &mut Criterion) { let n = config::ARRAY_LENGTH_LONG; let data = utils::get_worst_case_array::(n, 1); c.bench_function("simple_worst_long_i32", |b| { - b.iter(|| argminmax::generic::simple_argminmax(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax(black_box(data.view()))) }); c.bench_function("simd_worst_long_i32", |b| { b.iter(|| black_box(data.view().argminmax())) @@ -43,7 +43,7 @@ fn minmax_i32_worst_case_array_short(c: &mut Criterion) { let n = config::ARRAY_LENGTH_SHORT; let data = utils::get_worst_case_array::(n, 1); c.bench_function("simple_worst_short_i32", |b| { - b.iter(|| argminmax::generic::simple_argminmax(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax(black_box(data.view()))) }); c.bench_function("simd_worst_short_i32", |b| { b.iter(|| black_box(data.view().argminmax())) diff --git a/benches/bench_i64.rs b/benches/bench_i64.rs index 2548cdb..922644d 100644 --- a/benches/bench_i64.rs +++ b/benches/bench_i64.rs @@ -10,7 +10,7 @@ fn minmax_i64_random_array_long(c: &mut Criterion) { let n = config::ARRAY_LENGTH_LONG; let data = utils::get_random_array::(n, i64::MIN, i64::MAX); c.bench_function("simple_random_long_i64", |b| { - b.iter(|| argminmax::generic::simple_argminmax(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax(black_box(data.view()))) }); c.bench_function("simd_random_long_i64", |b| { b.iter(|| black_box(data.view().argminmax())) @@ -21,7 +21,7 @@ fn minmax_i64_random_array_short(c: &mut Criterion) { let n = config::ARRAY_LENGTH_SHORT; let data = utils::get_random_array::(n, i64::MIN, i64::MAX); c.bench_function("simple_random_short_i64", |b| { - b.iter(|| argminmax::generic::simple_argminmax(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax(black_box(data.view()))) }); c.bench_function("simd_random_short_i64", |b| { b.iter(|| black_box(data.view().argminmax())) @@ -32,7 +32,7 @@ fn minmax_i64_worst_case_array_long(c: &mut Criterion) { let n = config::ARRAY_LENGTH_LONG; let data = utils::get_worst_case_array::(n, 1); c.bench_function("simple_worst_long_i64", |b| { - b.iter(|| argminmax::generic::simple_argminmax(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax(black_box(data.view()))) }); c.bench_function("simd_worst_long_i64", |b| { b.iter(|| black_box(data.view().argminmax())) @@ -43,7 +43,7 @@ fn minmax_i64_worst_case_array_short(c: &mut Criterion) { let n = config::ARRAY_LENGTH_SHORT; let data = utils::get_worst_case_array::(n, 1); c.bench_function("simple_worst_short_i64", |b| { - b.iter(|| argminmax::generic::simple_argminmax(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax(black_box(data.view()))) }); c.bench_function("simd_worst_short_i64", |b| { b.iter(|| black_box(data.view().argminmax())) diff --git a/src/scalar_generic.rs b/src/scalar_generic.rs new file mode 100644 index 0000000..1d82a23 --- /dev/null +++ b/src/scalar_generic.rs @@ -0,0 +1,65 @@ +use ndarray::ArrayView1; + +// ------ On ArrayView1 + +#[inline] +pub fn scalar_argmin(arr: ArrayView1) -> usize { + let mut low_index = 0usize; + let mut low = arr[low_index]; + for (i, item) in arr.iter().enumerate() { + if *item < low { + low = *item; + low_index = i; + } + } + low_index +} + +#[inline] +pub fn scalar_argmax(arr: ArrayView1) -> usize { + let mut high_index = 0usize; + let mut high = arr[high_index]; + for (i, item) in arr.iter().enumerate() { + if *item > high { + high = *item; + high_index = i; + } + } + high_index +} + +#[inline] +pub fn scalar_argminmax(arr: ArrayView1) -> (usize, usize) { + let mut low_index: usize = 0; + let mut high_index: usize = 0; + let mut low = arr[low_index]; + let mut high = arr[high_index]; + for (i, item) in arr.iter().enumerate() { + if *item < low { + low = *item; + low_index = i; + } else if *item > high { + high = *item; + high_index = i; + } + } + (low_index, high_index) +} + +// Note: 5-7% faster than the above implementation (for floats) +// #[inline] +// pub fn scalar_argminmax_fold(arr: ArrayView1) -> (usize, usize) { +// let minmax_tuple: (usize, T, usize, T) = arr.iter().enumerate().fold( +// (0usize, arr[0], 0usize, arr[0]), +// |(min_idx, min, max_idx, max), (idx, item)| { +// if *item < min { +// (idx, *item, max_idx, max) +// } else if *item > max { +// (min_idx, min, idx, *item) +// } else { +// (min_idx, min, max_idx, max) +// } +// }, +// ); +// (minmax_tuple.0, minmax_tuple.2) +// } diff --git a/src/simd/simd_f32.rs b/src/simd/simd_f32.rs index 2d53f19..68b5c84 100644 --- a/src/simd/simd_f32.rs +++ b/src/simd/simd_f32.rs @@ -1,4 +1,4 @@ -use crate::generic::{max_index_value, min_index_value}; +use crate::utils::{max_index_value, min_index_value}; use crate::task::argminmax_generic; use ndarray::ArrayView1; use std::arch::x86_64::*; @@ -75,8 +75,7 @@ unsafe fn core_argminmax_256(sim_arr: ArrayView1, offset: usize) -> (f32, u #[cfg(test)] mod tests { use super::argminmax_f32; - use crate::generic; - use generic::simple_argminmax; + use crate::scalar_generic::scalar_argminmax; use ndarray::Array1; @@ -92,7 +91,7 @@ mod tests { let data = get_array_f32(1025); assert_eq!(data.len() % 8, 1); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); let (argmin_simd_index, argmax_simd_index) = argminmax_f32(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); @@ -112,7 +111,7 @@ mod tests { let data: Vec = data.iter().map(|x| *x).collect(); let data = Array1::from(data); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); assert_eq!(argmin_index, 3); assert_eq!(argmax_index, 1); @@ -125,7 +124,7 @@ mod tests { fn test_many_random_runs() { for _ in 0..10_000 { let data = get_array_f32(32 * 8 + 1); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); let (argmin_simd_index, argmax_simd_index) = argminmax_f32(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); diff --git a/src/simd/simd_f64.rs b/src/simd/simd_f64.rs index 55377ee..f2a4a78 100644 --- a/src/simd/simd_f64.rs +++ b/src/simd/simd_f64.rs @@ -1,4 +1,4 @@ -use crate::generic::{max_index_value, min_index_value}; +use crate::utils::{max_index_value, min_index_value}; use crate::task::argminmax_generic; use ndarray::ArrayView1; use std::arch::x86_64::*; @@ -67,8 +67,7 @@ unsafe fn core_argminmax_256(sim_arr: ArrayView1, offset: usize) -> (f64, u #[cfg(test)] mod tests { use super::argminmax_f64; - use crate::generic; - use generic::simple_argminmax; + use crate::scalar_generic::scalar_argminmax; use ndarray::Array1; @@ -84,7 +83,7 @@ mod tests { let data = get_array_f64(1025); assert_eq!(data.len() % 4, 1); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); let (argmin_simd_index, argmax_simd_index) = argminmax_f64(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); @@ -104,7 +103,7 @@ mod tests { let data: Vec = data.iter().map(|x| *x).collect(); let data = Array1::from(data); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); assert_eq!(argmin_index, 3); assert_eq!(argmax_index, 1); @@ -117,7 +116,7 @@ mod tests { fn test_many_random_runs() { for _ in 0..10_000 { let data = get_array_f64(32 * 8 + 1); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); let (argmin_simd_index, argmax_simd_index) = argminmax_f64(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); diff --git a/src/simd/simd_i16.rs b/src/simd/simd_i16.rs index 3f3c03e..e5cd550 100644 --- a/src/simd/simd_i16.rs +++ b/src/simd/simd_i16.rs @@ -1,4 +1,4 @@ -use crate::generic::{max_index_value, min_index_value}; +use crate::utils::{max_index_value, min_index_value}; use crate::task::argminmax_generic; use ndarray::ArrayView1; use std::arch::x86_64::*; @@ -73,8 +73,7 @@ unsafe fn core_argminmax_256(sim_arr: ArrayView1, offset: usize) -> (i16, u #[cfg(test)] mod tests { use super::argminmax_i16; - use crate::generic; - use generic::simple_argminmax; + use crate::scalar_generic::scalar_argminmax; use ndarray::Array1; @@ -90,7 +89,7 @@ mod tests { let data = get_array_i16(513); assert_eq!(data.len() % 16, 1); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); let (argmin_simd_index, argmax_simd_index) = argminmax_i16(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); @@ -112,7 +111,7 @@ mod tests { let data: Vec = data.iter().map(|x| *x).collect(); let data = Array1::from(data); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); assert_eq!(argmin_index, 1); assert_eq!(argmax_index, 6); @@ -125,7 +124,7 @@ mod tests { fn test_many_random_runs() { for _ in 0..10_000 { let data = get_array_i16(32 * 2 + 1); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); let (argmin_simd_index, argmax_simd_index) = argminmax_i16(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); diff --git a/src/simd/simd_i32.rs b/src/simd/simd_i32.rs index 72fb7d0..5b3ca8d 100644 --- a/src/simd/simd_i32.rs +++ b/src/simd/simd_i32.rs @@ -1,4 +1,4 @@ -use crate::generic::{max_index_value, min_index_value}; +use crate::utils::{max_index_value, min_index_value}; use crate::task::argminmax_generic; use ndarray::ArrayView1; use std::arch::x86_64::*; @@ -70,8 +70,7 @@ unsafe fn core_argminmax_256(sim_arr: ArrayView1, offset: usize) -> (i32, u #[cfg(test)] mod tests { use super::argminmax_i32; - use crate::generic; - use generic::simple_argminmax; + use crate::scalar_generic::scalar_argminmax; use ndarray::Array1; @@ -87,7 +86,7 @@ mod tests { let data = get_array_i32(1025); assert_eq!(data.len() % 8, 1); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); let (argmin_simd_index, argmax_simd_index) = argminmax_i32(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); @@ -108,7 +107,7 @@ mod tests { let data: Vec = data.iter().map(|x| *x).collect(); let data = Array1::from(data); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); assert_eq!(argmin_index, 0); assert_eq!(argmax_index, 5); @@ -121,7 +120,7 @@ mod tests { fn test_many_random_runs() { for _ in 0..10_000 { let data = get_array_i32(32 * 8 + 1); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); let (argmin_simd_index, argmax_simd_index) = argminmax_i32(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); diff --git a/src/simd/simd_i64.rs b/src/simd/simd_i64.rs index 37925e3..41fc10f 100644 --- a/src/simd/simd_i64.rs +++ b/src/simd/simd_i64.rs @@ -1,4 +1,4 @@ -use crate::generic::{max_index_value, min_index_value}; +use crate::utils::{max_index_value, min_index_value}; use crate::task::argminmax_generic; use ndarray::ArrayView1; use std::arch::x86_64::*; @@ -90,8 +90,7 @@ unsafe fn core_argminmax_256(sim_arr: ArrayView1, offset: usize) -> (i64, u #[cfg(test)] mod tests { use super::argminmax_i64; - use crate::generic; - use generic::simple_argminmax; + use crate::scalar_generic::scalar_argminmax; use ndarray::Array1; @@ -107,7 +106,7 @@ mod tests { let data = get_array_i64(1025); assert_eq!(data.len() % 4, 1); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); let (argmin_simd_index, argmax_simd_index) = argminmax_i64(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); @@ -128,7 +127,7 @@ mod tests { let data: Vec = data.iter().map(|x| *x).collect(); let data = Array1::from(data); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); assert_eq!(argmin_index, 0); assert_eq!(argmax_index, 5); @@ -141,7 +140,7 @@ mod tests { fn test_many_random_runs() { for _ in 0..10_000 { let data = get_array_i64(32 * 8 + 1); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); let (argmin_simd_index, argmax_simd_index) = argminmax_i64(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); diff --git a/src/simd/simd_i8.rs b/src/simd/simd_i8.rs index 8de252f..a6319b0 100644 --- a/src/simd/simd_i8.rs +++ b/src/simd/simd_i8.rs @@ -1,4 +1,4 @@ -use crate::generic::{max_index_value, min_index_value}; +use crate::utils::{max_index_value, min_index_value}; use crate::task::argminmax_generic; use ndarray::ArrayView1; use std::arch::x86_64::*; @@ -89,8 +89,7 @@ unsafe fn core_argminmax_256(sim_arr: ArrayView1, offset: usize) -> (i8, usi #[cfg(test)] mod tests { use super::argminmax_i8; - use crate::generic; - use generic::simple_argminmax; + use crate::scalar_generic::scalar_argminmax; use ndarray::Array1; @@ -106,7 +105,7 @@ mod tests { let data = get_array_i8(32 * 6 + 1); // TODO: lengte mag niet > 2^8 zijn... assert_eq!(data.len() % 32, 1); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); let (argmin_simd_index, argmax_simd_index) = argminmax_i8(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); @@ -118,7 +117,7 @@ mod tests { let data: Vec = data.iter().map(|x| *x).collect(); let data = Array1::from(data); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); assert_eq!(argmin_index, 1); assert_eq!(argmax_index, 6); @@ -131,7 +130,7 @@ mod tests { fn test_many_random_runs() { for _ in 0..10_000 { let data = get_array_i8(32 * 2 + 1); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); let (argmin_simd_index, argmax_simd_index) = argminmax_i8(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); diff --git a/src/task.rs b/src/task.rs index 1e07f55..75cde73 100644 --- a/src/task.rs +++ b/src/task.rs @@ -1,5 +1,4 @@ -use crate::generic; -use generic::simple_argminmax; +use crate::scalar_generic::scalar_argminmax; // TODO: dit in macro doorgeven use ndarray::{ArrayView1, Axis}; use std::cmp::Ordering; @@ -13,7 +12,7 @@ pub(crate) fn argminmax_generic( assert!(!arr.is_empty()); // split_array should never return (None, None) match split_array(arr, lane_size) { (Some(rem), Some(sim)) => { - let (rem_min_index, rem_max_index) = simple_argminmax(rem); + let (rem_min_index, rem_max_index) = scalar_argminmax(rem); let rem_result = ( rem[rem_min_index], rem_min_index, @@ -24,7 +23,7 @@ pub(crate) fn argminmax_generic( find_final_index_minmax(rem_result, sim_result) } (Some(rem), None) => { - let (rem_min_index, rem_max_index) = simple_argminmax(rem); + let (rem_min_index, rem_max_index) = scalar_argminmax(rem); (rem_min_index, rem_max_index) } (None, Some(sim)) => { diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..0038b11 --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,38 @@ +// ------ On &[T] + +// Note: these two functions are necessary because in the final SIMD registers the +// indexes are not in sorted order - this means that the first index (in the SIMD +// registers) is not necessarily the lowest min / max index when the min / max value +// occurs multiple times. + +#[inline] +pub(crate) fn min_index_value(index: &[T], values: &[T]) -> (T, T) { + assert_eq!(index.len(), values.len()); + let mut min_index: usize = 0; + let mut min_value = values[min_index]; + for (i, value) in values.iter().skip(1).enumerate() { + if *value < min_value { + min_value = *value; + min_index = i + 1; + } else if *value == min_value && index[i + 1] < index[min_index] { + min_index = i + 1; + } + } + (index[min_index], min_value) +} + +#[inline] +pub(crate) fn max_index_value(index: &[T], values: &[T]) -> (T, T) { + assert_eq!(index.len(), values.len()); + let mut max_index: usize = 0; + let mut max_value = values[max_index]; + for (i, value) in values.iter().skip(1).enumerate() { + if *value > max_value { + max_value = *value; + max_index = i + 1; + } else if *value == max_value && index[i + 1] < index[max_index] { + max_index = i + 1; + } + } + (index[max_index], max_value) +} From 768a6df6f09e0bd8106732cb38f14fecf939cb8f Mon Sep 17 00:00:00 2001 From: Jeroen Van Der Donckt Date: Sat, 5 Nov 2022 12:21:26 +0100 Subject: [PATCH 4/6] :broom: --- Cargo.toml | 8 ++-- README.md | 15 +++++-- benches/bench_f16.rs | 9 ++-- src/generic.rs | 104 ------------------------------------------- src/lib.rs | 21 ++++----- src/scalar_f16.rs | 32 ++++++++++++- src/simd/simd_f16.rs | 11 +++-- src/simd/simd_u32.rs | 8 ++-- 8 files changed, 72 insertions(+), 136 deletions(-) delete mode 100644 src/generic.rs diff --git a/Cargo.toml b/Cargo.toml index b336541..5fd0875 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "argminmax" -version = "0.1.1" +version = "0.2.0" authors = ["Jeroen Van Der Donckt"] edition = "2021" readme = "README.md" @@ -12,8 +12,8 @@ categories = ["algorithms", "mathematics", "science"] [dependencies] -ndarray = {version = "0.15.6", default-features = false} -half = {version = "2.1.0", default-features = false, optional = true} +ndarray = { version = "0.15.6", default-features = false } +half = { version = "2.1.0", default-features = false, optional = true } [dev-dependencies] criterion = "0.3.0" @@ -42,4 +42,4 @@ harness = false [[bench]] name = "bench_i64" -harness = false \ No newline at end of file +harness = false diff --git a/README.md b/README.md index 4434aff..bef6690 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,14 @@ # ArgMinMax -> Efficient argmin & argmax (in 1 function) with SIMD (avx2) for `f32`, `f64`, `i16`, `i32`, `i64` on `ndarray::ArrayView1` +> Efficient argmin & argmax (in 1 function) with SIMD (avx2) for `f16`, `f32`, `f64`, `i16`, `i32`, `i64` on `ndarray::ArrayView1` -🚀 The function is generic over the type of the array, so it can be used on an `ndarray::ArrayView1` where `T` can be `f32`, `f64`, `i16`, `i32`, `i64`. +🚀 The function is generic over the type of the array, so it can be used on an `ndarray::ArrayView1` where `T` can be `f16`*, `f32`, `f64`, `i16`, `i32`, `i64`. 👀 Note that this implementation contains no if checks, ensuring that the runtime of the function is independent of the input data its order (best-case = worst-case = average-case). +*for `f16` you should enable the 'half' feature. + ## Installing Add the following to your `Cargo.toml`: @@ -41,7 +43,14 @@ See `/benches/results`. Run the benchmarks yourself with the following command: ```bash -cargo bench --quiet --message-format=short | grep "time:" +cargo bench --quiet --message-format=short --features half | grep "time:" +``` + +## Tests + +To run the tests use the following command: +```bash +cargo test --message-format=short --features half ``` --- diff --git a/benches/bench_f16.rs b/benches/bench_f16.rs index a663acb..b6fcabd 100644 --- a/benches/bench_f16.rs +++ b/benches/bench_f16.rs @@ -2,6 +2,7 @@ extern crate criterion; extern crate dev_utils; +#[cfg(feature = "half")] use argminmax::ArgMinMax; use criterion::{black_box, Criterion}; use dev_utils::{config, utils}; @@ -28,7 +29,7 @@ fn minmax_f16_random_array_long(c: &mut Criterion) { let n = config::ARRAY_LENGTH_LONG; let data = get_random_f16_array(n); c.bench_function("simp_random_long_f16", |b| { - b.iter(|| argminmax::simple_argminmax_f16(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax_f16(black_box(data.view()))) }); c.bench_function("simd_random_long_f16", |b| { b.iter(|| black_box(data.view().argminmax())) @@ -40,7 +41,7 @@ fn minmax_f16_random_array_short(c: &mut Criterion) { let n = config::ARRAY_LENGTH_SHORT; let data = get_random_f16_array(n); c.bench_function("simple_random_short_f16", |b| { - b.iter(|| argminmax::simple_argminmax_f16(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax_f16(black_box(data.view()))) }); c.bench_function("simd_random_short_f16", |b| { b.iter(|| black_box(data.view().argminmax())) @@ -52,7 +53,7 @@ fn minmax_f16_worst_case_array_long(c: &mut Criterion) { let n = config::ARRAY_LENGTH_LONG; let data = utils::get_worst_case_array::(n, f16::from_f32(1.)); c.bench_function("simple_worst_long_f16", |b| { - b.iter(|| argminmax::simple_argminmax_f16(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax_f16(black_box(data.view()))) }); c.bench_function("simd_worst_long_f16", |b| { b.iter(|| black_box(data.view().argminmax())) @@ -64,7 +65,7 @@ fn minmax_f16_worst_case_array_short(c: &mut Criterion) { let n = config::ARRAY_LENGTH_SHORT; let data = utils::get_worst_case_array::(n, f16::from_f32(1.)); c.bench_function("simple_worst_short_f16", |b| { - b.iter(|| argminmax:::simple_argminmax_f16(black_box(data.view()))) + b.iter(|| argminmax::scalar_argminmax_f16(black_box(data.view()))) }); c.bench_function("simd_worst_short_f16", |b| { b.iter(|| black_box(data.view().argminmax())) diff --git a/src/generic.rs b/src/generic.rs deleted file mode 100644 index 35c0955..0000000 --- a/src/generic.rs +++ /dev/null @@ -1,104 +0,0 @@ -use ndarray::ArrayView1; - -// ------ On ArrayView1 - -#[inline] -pub fn simple_argmin(arr: ArrayView1) -> usize { - let mut low_index = 0usize; - let mut low = arr[low_index]; - for (i, item) in arr.iter().enumerate() { - if *item < low { - low = *item; - low_index = i; - } - } - low_index -} - -#[inline] -pub fn simple_argmax(arr: ArrayView1) -> usize { - let mut high_index = 0usize; - let mut high = arr[high_index]; - for (i, item) in arr.iter().enumerate() { - if *item > high { - high = *item; - high_index = i; - } - } - high_index -} - -#[inline] -pub fn simple_argminmax(arr: ArrayView1) -> (usize, usize) { - let mut low_index: usize = 0; - let mut high_index: usize = 0; - let mut low = arr[low_index]; - let mut high = arr[high_index]; - for (i, item) in arr.iter().enumerate() { - if *item < low { - low = *item; - low_index = i; - } else if *item > high { - high = *item; - high_index = i; - } - } - (low_index, high_index) -} - -// Note: 5-7% faster than the above implementation (for floats) -// #[inline] -// pub fn simple_argminmax_fold(arr: ArrayView1) -> (usize, usize) { -// let minmax_tuple: (usize, T, usize, T) = arr.iter().enumerate().fold( -// (0usize, arr[0], 0usize, arr[0]), -// |(min_idx, min, max_idx, max), (idx, item)| { -// if *item < min { -// (idx, *item, max_idx, max) -// } else if *item > max { -// (min_idx, min, idx, *item) -// } else { -// (min_idx, min, max_idx, max) -// } -// }, -// ); -// (minmax_tuple.0, minmax_tuple.2) -// } - -// ------ On &[T] - -// Note: these two functions are necessary because in the final SIMD registers the -// indexes are not in sorted order - this means that the first index (in the SIMD -// registers) is not necessarily the lowest min / max index when the min / max value -// occurs multiple times. - -#[inline] -pub fn min_index_value(index: &[T], values: &[T]) -> (T, T) { - assert_eq!(index.len(), values.len()); - let mut min_index: usize = 0; - let mut min_value = values[min_index]; - for (i, value) in values.iter().skip(1).enumerate() { - if *value < min_value { - min_value = *value; - min_index = i + 1; - } else if *value == min_value && index[i + 1] < index[min_index] { - min_index = i + 1; - } - } - (index[min_index], min_value) -} - -#[inline] -pub fn max_index_value(index: &[T], values: &[T]) -> (T, T) { - assert_eq!(index.len(), values.len()); - let mut max_index: usize = 0; - let mut max_value = values[max_index]; - for (i, value) in values.iter().skip(1).enumerate() { - if *value > max_value { - max_value = *value; - max_index = i + 1; - } else if *value == max_value && index[i + 1] < index[max_index] { - max_index = i + 1; - } - } - (index[max_index], max_value) -} diff --git a/src/lib.rs b/src/lib.rs index e5634d7..268bf35 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,9 +1,10 @@ -pub mod generic; mod simd; mod task; +mod utils; +mod scalar_generic; mod scalar_f16; -pub use generic::simple_argminmax; +pub use scalar_generic::*; pub use simd::{simd_f32, simd_f64, simd_i16, simd_i32, simd_i64}; use ndarray::ArrayView1; @@ -34,17 +35,17 @@ macro_rules! impl_argminmax { } // Implement ArgMinMax for the rust primitive types -impl_argminmax!(f32, simple_argminmax, simd_f32, argminmax_f32); -impl_argminmax!(f64, simple_argminmax, simd_f64, argminmax_f64); -impl_argminmax!(i16, simple_argminmax, simd_i16, argminmax_i16); -impl_argminmax!(i32, simple_argminmax, simd_i32, argminmax_i32); -impl_argminmax!(i64, simple_argminmax, simd_i64, argminmax_i64); - +impl_argminmax!(f32, scalar_argminmax, simd_f32, argminmax_f32); +impl_argminmax!(f64, scalar_argminmax, simd_f64, argminmax_f64); +impl_argminmax!(i16, scalar_argminmax, simd_i16, argminmax_i16); +impl_argminmax!(i32, scalar_argminmax, simd_i32, argminmax_i32); +impl_argminmax!(i64, scalar_argminmax, simd_i64, argminmax_i64); +// Implement ArgMinMax for other data types #[cfg(feature = "half")] use half::f16; #[cfg(feature = "half")] -pub use scalar_f16::simple_argminmax_f16; +pub use scalar_f16::scalar_argminmax_f16; #[cfg(feature = "half")] pub use simd::simd_f16; #[cfg(feature = "half")] -impl_argminmax!(f16, simple_argminmax_f16, simd_f16, argminmax_f16); +impl_argminmax!(f16, scalar_argminmax_f16, simd_f16, argminmax_f16); diff --git a/src/scalar_f16.rs b/src/scalar_f16.rs index d986351..3af8cfb 100644 --- a/src/scalar_f16.rs +++ b/src/scalar_f16.rs @@ -13,7 +13,7 @@ fn f16_to_i16ord(x: f16) -> i16 { #[cfg(feature = "half")] #[inline] -pub fn simple_argminmax_f16(arr: ArrayView1) -> (usize, usize) { +pub fn scalar_argminmax_f16(arr: ArrayView1) -> (usize, usize) { // f16 is transformed to i16ord // benchmarks show: // 1. this is 7-10x faster than using raw f16 @@ -34,3 +34,33 @@ pub fn simple_argminmax_f16(arr: ArrayView1) -> (usize, usize) { } (low_index, high_index) } + +#[cfg(feature = "half")] +#[cfg(test)] +mod tests { + use super::scalar_argminmax_f16; + use crate::scalar_generic::scalar_argminmax; + + use half::f16; + use ndarray::Array1; + + extern crate dev_utils; + use dev_utils::utils; + + fn get_array_f16(n: usize) -> Array1 { + let arr = utils::get_random_array(n, i16::MIN, i16::MAX); + let arr = arr.mapv(|x| f16::from_f32(x as f32)); + Array1::from(arr) + } + + #[test] + fn test_generic_and_specific_impl_return_the_same_results() { + for _ in 0..100 { + let data = get_array_f16(1025); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); + let (argmin_simd_index, argmax_simd_index) = scalar_argminmax_f16(data.view()); + assert_eq!(argmin_index, argmin_simd_index); + assert_eq!(argmax_index, argmax_simd_index); + } + } +} diff --git a/src/simd/simd_f16.rs b/src/simd/simd_f16.rs index 26bc739..ac45aef 100644 --- a/src/simd/simd_f16.rs +++ b/src/simd/simd_f16.rs @@ -1,4 +1,4 @@ -use crate::generic::{max_index_value, min_index_value}; +use crate::utils::{max_index_value, min_index_value}; use crate::task::argminmax_generic; use ndarray::ArrayView1; use std::arch::x86_64::*; @@ -101,8 +101,7 @@ unsafe fn core_argminmax_256(sim_arr: ArrayView1, offset: usize) -> (f16, u #[cfg(test)] mod tests { use super::argminmax_f16; - use crate::generic; - use generic::simple_argminmax; + use crate::scalar_generic::scalar_argminmax; use half::f16; use ndarray::Array1; @@ -121,7 +120,7 @@ mod tests { let data = get_array_f16(1025); assert_eq!(data.len() % 8, 1); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); let (argmin_simd_index, argmax_simd_index) = argminmax_f16(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); @@ -141,7 +140,7 @@ mod tests { let data: Vec = data.iter().map(|x| *x).collect(); let data = Array1::from(data); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); assert_eq!(argmin_index, 3); assert_eq!(argmax_index, 1); @@ -154,7 +153,7 @@ mod tests { fn test_many_random_runs() { for _ in 0..10_000 { let data = get_array_f16(32 * 8 + 1); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); let (argmin_simd_index, argmax_simd_index) = argminmax_f16(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); diff --git a/src/simd/simd_u32.rs b/src/simd/simd_u32.rs index 34a3c0f..2fc5af0 100644 --- a/src/simd/simd_u32.rs +++ b/src/simd/simd_u32.rs @@ -71,7 +71,7 @@ unsafe fn core_argminmax_256(sim_arr: ArrayView1, offset: usize) -> (u16, u mod tests { use super::argminmax_u32; use crate::generic; - use generic::simple_argminmax; + use generic::scalar_argminmax; use ndarray::Array1; @@ -91,7 +91,7 @@ mod tests { let data = get_array_u16(1025); assert_eq!(data.len() % 8, 1); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); let (argmin_simd_index, argmax_simd_index) = argminmax_u16(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); @@ -109,7 +109,7 @@ mod tests { let data: Vec = data.iter().map(|x| *x).collect(); let data = Array1::from(data); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); assert_eq!(argmin_index, 3); assert_eq!(argmax_index, 1); @@ -122,7 +122,7 @@ mod tests { fn test_many_random_runs() { for _ in 0..10_000 { let data = get_array_u16(32 * 8 + 1); - let (argmin_index, argmax_index) = simple_argminmax(data.view()); + let (argmin_index, argmax_index) = scalar_argminmax(data.view()); let (argmin_simd_index, argmax_simd_index) = argminmax_u16(data.view()); assert_eq!(argmin_index, argmin_simd_index); assert_eq!(argmax_index, argmax_simd_index); From ba60c71322e52a14a88e19a436bea8e373beb0a0 Mon Sep 17 00:00:00 2001 From: Jeroen Van Der Donckt Date: Sat, 5 Nov 2022 12:49:03 +0100 Subject: [PATCH 5/6] :recycle: format --- benches/bench_f16.rs | 10 ++++++++-- src/lib.rs | 4 ++-- src/scalar_f16.rs | 2 +- src/simd/simd_f16.rs | 11 ++++++++--- src/simd/simd_f32.rs | 2 +- src/simd/simd_f64.rs | 2 +- src/simd/simd_i16.rs | 2 +- src/simd/simd_i32.rs | 2 +- src/simd/simd_i64.rs | 2 +- src/task.rs | 2 +- 10 files changed, 25 insertions(+), 14 deletions(-) diff --git a/benches/bench_f16.rs b/benches/bench_f16.rs index b6fcabd..16a7172 100644 --- a/benches/bench_f16.rs +++ b/benches/bench_f16.rs @@ -18,9 +18,15 @@ fn get_random_f16_array(n: usize) -> Array1 { // Replace NaNs and Infs with 0 let data: Vec = data .iter() - .map(|&x| if x.is_nan() || x.is_infinite() { f16::from_bits(0) } else { x }) + .map(|&x| { + if x.is_nan() || x.is_infinite() { + f16::from_bits(0) + } else { + x + } + }) .collect(); - let arr:Array1 = Array1::from(data); + let arr: Array1 = Array1::from(data); arr } diff --git a/src/lib.rs b/src/lib.rs index 268bf35..5ec16ea 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,8 +1,8 @@ +mod scalar_f16; +mod scalar_generic; mod simd; mod task; mod utils; -mod scalar_generic; -mod scalar_f16; pub use scalar_generic::*; pub use simd::{simd_f32, simd_f64, simd_i16, simd_i32, simd_i64}; diff --git a/src/scalar_f16.rs b/src/scalar_f16.rs index 3af8cfb..ec506c6 100644 --- a/src/scalar_f16.rs +++ b/src/scalar_f16.rs @@ -1,6 +1,6 @@ -use ndarray::ArrayView1; #[cfg(feature = "half")] use half::f16; +use ndarray::ArrayView1; // ------ On ArrayView1 diff --git a/src/simd/simd_f16.rs b/src/simd/simd_f16.rs index ac45aef..9bec83d 100644 --- a/src/simd/simd_f16.rs +++ b/src/simd/simd_f16.rs @@ -1,5 +1,5 @@ -use crate::utils::{max_index_value, min_index_value}; use crate::task::argminmax_generic; +use crate::utils::{max_index_value, min_index_value}; use ndarray::ArrayView1; use std::arch::x86_64::*; @@ -66,7 +66,7 @@ unsafe fn core_argminmax_256(sim_arr: ArrayView1, offset: usize) -> (f16, u .skip(1) .for_each(|step| { new_index = _mm256_add_epi16(new_index, increment); - + let new_values = _mm256_loadu_si256(step.as_ptr() as *const __m256i); let new_values = f16_as_m256i_to_ord_i16(new_values); let gt_mask = _mm256_cmpgt_epi16(new_values, values_high); @@ -92,7 +92,12 @@ unsafe fn core_argminmax_256(sim_arr: ArrayView1, offset: usize) -> (f16, u let index_array = reg_to_i16_arr(index_low); let (index_min, value_min) = min_index_value(&index_array, &value_array); - (ord_i16_to_f16(value_min), index_min as usize, ord_i16_to_f16(value_max), index_max as usize) + ( + ord_i16_to_f16(value_min), + index_min as usize, + ord_i16_to_f16(value_max), + index_max as usize, + ) } //----- TESTS ----- diff --git a/src/simd/simd_f32.rs b/src/simd/simd_f32.rs index 68b5c84..ceb1dac 100644 --- a/src/simd/simd_f32.rs +++ b/src/simd/simd_f32.rs @@ -1,5 +1,5 @@ -use crate::utils::{max_index_value, min_index_value}; use crate::task::argminmax_generic; +use crate::utils::{max_index_value, min_index_value}; use ndarray::ArrayView1; use std::arch::x86_64::*; diff --git a/src/simd/simd_f64.rs b/src/simd/simd_f64.rs index f2a4a78..427c17b 100644 --- a/src/simd/simd_f64.rs +++ b/src/simd/simd_f64.rs @@ -1,5 +1,5 @@ -use crate::utils::{max_index_value, min_index_value}; use crate::task::argminmax_generic; +use crate::utils::{max_index_value, min_index_value}; use ndarray::ArrayView1; use std::arch::x86_64::*; diff --git a/src/simd/simd_i16.rs b/src/simd/simd_i16.rs index e5cd550..f4ea34e 100644 --- a/src/simd/simd_i16.rs +++ b/src/simd/simd_i16.rs @@ -1,5 +1,5 @@ -use crate::utils::{max_index_value, min_index_value}; use crate::task::argminmax_generic; +use crate::utils::{max_index_value, min_index_value}; use ndarray::ArrayView1; use std::arch::x86_64::*; diff --git a/src/simd/simd_i32.rs b/src/simd/simd_i32.rs index 5b3ca8d..31c7ccf 100644 --- a/src/simd/simd_i32.rs +++ b/src/simd/simd_i32.rs @@ -1,5 +1,5 @@ -use crate::utils::{max_index_value, min_index_value}; use crate::task::argminmax_generic; +use crate::utils::{max_index_value, min_index_value}; use ndarray::ArrayView1; use std::arch::x86_64::*; diff --git a/src/simd/simd_i64.rs b/src/simd/simd_i64.rs index 41fc10f..edee939 100644 --- a/src/simd/simd_i64.rs +++ b/src/simd/simd_i64.rs @@ -1,5 +1,5 @@ -use crate::utils::{max_index_value, min_index_value}; use crate::task::argminmax_generic; +use crate::utils::{max_index_value, min_index_value}; use ndarray::ArrayView1; use std::arch::x86_64::*; diff --git a/src/task.rs b/src/task.rs index 75cde73..573e280 100644 --- a/src/task.rs +++ b/src/task.rs @@ -1,4 +1,4 @@ -use crate::scalar_generic::scalar_argminmax; // TODO: dit in macro doorgeven +use crate::scalar_generic::scalar_argminmax; // TODO: dit in macro doorgeven use ndarray::{ArrayView1, Axis}; use std::cmp::Ordering; From f7e3097773a3f3ceafd4094ab46f8181faa8f491 Mon Sep 17 00:00:00 2001 From: Jeroen Van Der Donckt Date: Sat, 5 Nov 2022 12:49:17 +0100 Subject: [PATCH 6/6] :robot: add CI-CD --- .github/workflows/test.yaml | 51 +++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 .github/workflows/test.yaml diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml new file mode 100644 index 0000000..1cd8f60 --- /dev/null +++ b/.github/workflows/test.yaml @@ -0,0 +1,51 @@ +name: Test + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + test: + name: argminmax test + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: ['windows-latest', 'macOS-latest', 'ubuntu-latest'] + rust: ['stable', 'beta', 'nightly'] + + steps: + - name: Checkout + uses: actions/checkout@v2 + + - name: Install Rust toolchain + uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: ${{ matrix.rust }} + components: clippy, rustfmt + + - name: Rust toolchain info + run: | + cargo --version --verbose + rustc --version + cargo clippy --version + cargo fmt --version + + - name: Linting + run: | + cargo fmt -- --check + cargo clippy --features half -- -D warnings + + - name: Cache Dependencies + uses: Swatinem/rust-cache@v1 + + - name: Run cargo-tarpaulin + uses: actions-rs/tarpaulin@v0.1 + with: + args: '--features half -- --test-threads 1' + + - name: Upload to codecov.io + uses: codecov/codecov-action@v3