From 4c19f96cab08d8349cb3d734587fa532fefd4ed9 Mon Sep 17 00:00:00 2001 From: YuhanLiin Date: Mon, 12 Apr 2021 21:19:31 -0400 Subject: [PATCH 1/4] Add IntoNdProducer impl for &[T; N] Add From impl to convert from 2D slices to 2D views Move NdIndex implementations for arrays out of macros Remove FixedInitializer trait Make bounds checks in aview2 and aview_mut2 conditional on slices of ZSTs Refactor aview2 and aview_mut2 implementations into From --- .github/workflows/ci.yml | 2 +- src/arraytraits.rs | 58 +++++++++++++++-- src/dimension/ndindex.rs | 102 ++++++++++++++++------------- src/free_functions.rs | 135 +++++++-------------------------------- src/zip/ndproducer.rs | 22 ++++++- tests/array.rs | 88 ++++++++----------------- 6 files changed, 184 insertions(+), 223 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a44568ae8..1c76bd9d1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -21,7 +21,7 @@ jobs: - stable - beta - nightly - - 1.49.0 # MSRV + - 1.51.0 # MSRV steps: - uses: actions/checkout@v2 diff --git a/src/arraytraits.rs b/src/arraytraits.rs index a7f22c1f7..0bec319e6 100644 --- a/src/arraytraits.rs +++ b/src/arraytraits.rs @@ -6,17 +6,17 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use std::hash; -use std::iter::FromIterator; +use alloc::boxed::Box; +use alloc::vec::Vec; use std::iter::IntoIterator; use std::mem; use std::ops::{Index, IndexMut}; -use alloc::boxed::Box; -use alloc::vec::Vec; +use std::{hash, mem::size_of}; +use std::{iter::FromIterator, slice}; -use crate::imp_prelude::*; use crate::iter::{Iter, IterMut}; use crate::NdIndex; +use crate::{dimension, imp_prelude::*}; use crate::numeric_util; use crate::{FoldWhile, Zip}; @@ -323,6 +323,30 @@ where } } +/// Implementation of ArrayView2::from(&S) where S is a slice to a 2D array +/// +/// **Panics** if the product of non-zero axis lengths overflows `isize` (This can only occur if A +/// is zero-sized because slices cannot contain more than `isize::MAX` number of bytes). +impl<'a, A, const N: usize> From<&'a [[A; N]]> for ArrayView<'a, A, Ix2> { + /// Create a two-dimensional read-only array view of the data in `slice` + fn from(xs: &'a [[A; N]]) -> Self { + let cols = N; + let rows = xs.len(); + let dim = Ix2(rows, cols); + if size_of::() == 0 { + dimension::size_of_shape_checked(&dim) + .expect("Product of non-zero axis lengths must not overflow isize."); + } + + // `cols * rows` is guaranteed to fit in `isize` because we checked that it fits in + // `isize::MAX` + unsafe { + let data = slice::from_raw_parts(xs.as_ptr() as *const A, cols * rows); + ArrayView::from_shape_ptr(dim, data.as_ptr()) + } + } +} + /// Implementation of `ArrayView::from(&A)` where `A` is an array. impl<'a, A, S, D> From<&'a ArrayBase> for ArrayView<'a, A, D> where @@ -355,6 +379,30 @@ where } } +/// Implementation of ArrayViewMut2::from(&S) where S is a slice to a 2D array +/// +/// **Panics** if the product of non-zero axis lengths overflows `isize` (This can only occur if A +/// is zero-sized because slices cannot contain more than `isize::MAX` number of bytes). +impl<'a, A, const N: usize> From<&'a mut [[A; N]]> for ArrayViewMut<'a, A, Ix2> { + /// Create a two-dimensional read-write array view of the data in `slice` + fn from(xs: &'a mut [[A; N]]) -> Self { + let cols = N; + let rows = xs.len(); + let dim = Ix2(rows, cols); + if size_of::() == 0 { + dimension::size_of_shape_checked(&dim) + .expect("Product of non-zero axis lengths must not overflow isize."); + } + + // `cols * rows` is guaranteed to fit in `isize` because we checked that it fits in + // `isize::MAX` + unsafe { + let data = slice::from_raw_parts_mut(xs.as_mut_ptr() as *mut A, cols * rows); + ArrayViewMut::from_shape_ptr(dim, data.as_mut_ptr()) + } + } +} + /// Implementation of `ArrayViewMut::from(&mut A)` where `A` is an array. impl<'a, A, S, D> From<&'a mut ArrayBase> for ArrayViewMut<'a, A, D> where diff --git a/src/dimension/ndindex.rs b/src/dimension/ndindex.rs index 810a5b097..718ee059b 100644 --- a/src/dimension/ndindex.rs +++ b/src/dimension/ndindex.rs @@ -140,50 +140,6 @@ macro_rules! ndindex_with_array { 0 } } - - // implement NdIndex for Dim<[Ix; 2]> and so on - unsafe impl NdIndex for Dim<[Ix; $n]> { - #[inline] - fn index_checked(&self, dim: &IxDyn, strides: &IxDyn) -> Option { - debug_assert_eq!(strides.ndim(), $n, - "Attempted to index with {:?} in array with {} axes", - self, strides.ndim()); - stride_offset_checked(dim.ix(), strides.ix(), self.ix()) - } - - #[inline] - fn index_unchecked(&self, strides: &IxDyn) -> isize { - debug_assert_eq!(strides.ndim(), $n, - "Attempted to index with {:?} in array with {} axes", - self, strides.ndim()); - $( - stride_offset(get!(self, $index), get!(strides, $index)) + - )* - 0 - } - } - - // implement NdIndex for [Ix; 2] and so on - unsafe impl NdIndex for [Ix; $n] { - #[inline] - fn index_checked(&self, dim: &IxDyn, strides: &IxDyn) -> Option { - debug_assert_eq!(strides.ndim(), $n, - "Attempted to index with {:?} in array with {} axes", - self, strides.ndim()); - stride_offset_checked(dim.ix(), strides.ix(), self) - } - - #[inline] - fn index_unchecked(&self, strides: &IxDyn) -> isize { - debug_assert_eq!(strides.ndim(), $n, - "Attempted to index with {:?} in array with {} axes", - self, strides.ndim()); - $( - stride_offset(self[$index], get!(strides, $index)) + - )* - 0 - } - } )+ }; } @@ -198,6 +154,64 @@ ndindex_with_array! { [6, Ix6 0 1 2 3 4 5] } +// implement NdIndex for Dim<[Ix; 2]> and so on +unsafe impl NdIndex for Dim<[Ix; N]> { + #[inline] + fn index_checked(&self, dim: &IxDyn, strides: &IxDyn) -> Option { + debug_assert_eq!( + strides.ndim(), + N, + "Attempted to index with {:?} in array with {} axes", + self, + strides.ndim() + ); + stride_offset_checked(dim.ix(), strides.ix(), self.ix()) + } + + #[inline] + fn index_unchecked(&self, strides: &IxDyn) -> isize { + debug_assert_eq!( + strides.ndim(), + N, + "Attempted to index with {:?} in array with {} axes", + self, + strides.ndim() + ); + (0..N) + .map(|i| stride_offset(get!(self, i), get!(strides, i))) + .sum() + } +} + +// implement NdIndex for [Ix; 2] and so on +unsafe impl NdIndex for [Ix; N] { + #[inline] + fn index_checked(&self, dim: &IxDyn, strides: &IxDyn) -> Option { + debug_assert_eq!( + strides.ndim(), + N, + "Attempted to index with {:?} in array with {} axes", + self, + strides.ndim() + ); + stride_offset_checked(dim.ix(), strides.ix(), self) + } + + #[inline] + fn index_unchecked(&self, strides: &IxDyn) -> isize { + debug_assert_eq!( + strides.ndim(), + N, + "Attempted to index with {:?} in array with {} axes", + self, + strides.ndim() + ); + (0..N) + .map(|i| stride_offset(self[i], get!(strides, i))) + .sum() + } +} + impl<'a> IntoDimension for &'a [Ix] { type Dim = IxDyn; fn into_dimension(self) -> Self::Dim { diff --git a/src/free_functions.rs b/src/free_functions.rs index 2b30e0bd3..156eee6b9 100644 --- a/src/free_functions.rs +++ b/src/free_functions.rs @@ -6,10 +6,9 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use std::mem::{forget, size_of}; -use alloc::slice; use alloc::vec; use alloc::vec::Vec; +use std::mem::{forget, size_of}; use crate::imp_prelude::*; use crate::{dimension, ArcArray1, ArcArray2}; @@ -87,26 +86,10 @@ pub fn aview1(xs: &[A]) -> ArrayView1<'_, A> { /// Create a two-dimensional array view with elements borrowing `xs`. /// -/// **Panics** if the product of non-zero axis lengths overflows `isize`. (This -/// can only occur when `V` is zero-sized.) -pub fn aview2>(xs: &[V]) -> ArrayView2<'_, A> { - let cols = V::len(); - let rows = xs.len(); - let dim = Ix2(rows, cols); - if size_of::() == 0 { - dimension::size_of_shape_checked(&dim) - .expect("Product of non-zero axis lengths must not overflow isize."); - } - // `rows` is guaranteed to fit in `isize` because we've checked the ZST - // case and slices never contain > `isize::MAX` bytes. `cols` is guaranteed - // to fit in `isize` because `FixedInitializer` is not implemented for any - // array lengths > `isize::MAX`. `cols * rows` is guaranteed to fit in - // `isize` because we've checked the ZST case and slices never contain > - // `isize::MAX` bytes. - unsafe { - let data = slice::from_raw_parts(xs.as_ptr() as *const A, cols * rows); - ArrayView::from_shape_ptr(dim, data.as_ptr()) - } +/// **Panics** if the product of non-zero axis lengths overflows `isize` (This can only occur if A +/// is zero-sized because slices cannot contain more than `isize::MAX` number of bytes). +pub fn aview2(xs: &[[A; N]]) -> ArrayView2<'_, A> { + ArrayView2::from(xs) } /// Create a one-dimensional read-write array view with elements borrowing `xs`. @@ -127,16 +110,15 @@ pub fn aview_mut1(xs: &mut [A]) -> ArrayViewMut1<'_, A> { /// Create a two-dimensional read-write array view with elements borrowing `xs`. /// -/// **Panics** if the product of non-zero axis lengths overflows `isize`. (This -/// can only occur when `V` is zero-sized.) +/// **Panics** if the product of non-zero axis lengths overflows `isize` (This can only occur if A +/// is zero-sized because slices cannot contain more than `isize::MAX` number of bytes). /// /// # Example /// /// ``` /// use ndarray::aview_mut2; /// -/// // The inner (nested) array must be of length 1 to 16, but the outer -/// // can be of any length. +/// // The inner (nested) and outer arrays can be of any length. /// let mut data = [[0.; 2]; 128]; /// { /// // Make a 128 x 2 mut array view then turn it into 2 x 128 @@ -148,57 +130,10 @@ pub fn aview_mut1(xs: &mut [A]) -> ArrayViewMut1<'_, A> { /// // look at the start of the result /// assert_eq!(&data[..3], [[1., -1.], [1., -1.], [1., -1.]]); /// ``` -pub fn aview_mut2>(xs: &mut [V]) -> ArrayViewMut2<'_, A> { - let cols = V::len(); - let rows = xs.len(); - let dim = Ix2(rows, cols); - if size_of::() == 0 { - dimension::size_of_shape_checked(&dim) - .expect("Product of non-zero axis lengths must not overflow isize."); - } - // `rows` is guaranteed to fit in `isize` because we've checked the ZST - // case and slices never contain > `isize::MAX` bytes. `cols` is guaranteed - // to fit in `isize` because `FixedInitializer` is not implemented for any - // array lengths > `isize::MAX`. `cols * rows` is guaranteed to fit in - // `isize` because we've checked the ZST case and slices never contain > - // `isize::MAX` bytes. - unsafe { - let data = slice::from_raw_parts_mut(xs.as_mut_ptr() as *mut A, cols * rows); - ArrayViewMut::from_shape_ptr(dim, data.as_mut_ptr()) - } +pub fn aview_mut2(xs: &mut [[A; N]]) -> ArrayViewMut2<'_, A> { + ArrayViewMut2::from(xs) } -/// Fixed-size array used for array initialization -#[allow(clippy::missing_safety_doc)] // Should not be implemented downstream and to be deprecated. -pub unsafe trait FixedInitializer { - type Elem; - fn as_init_slice(&self) -> &[Self::Elem]; - fn len() -> usize; -} - -macro_rules! impl_arr_init { - (__impl $n: expr) => ( - unsafe impl FixedInitializer for [T; $n] { - type Elem = T; - fn as_init_slice(&self) -> &[T] { self } - fn len() -> usize { $n } - } - ); - () => (); - ($n: expr, $($m:expr,)*) => ( - impl_arr_init!(__impl $n); - impl_arr_init!($($m,)*); - ) - -} - -// For implementors: If you ever implement `FixedInitializer` for array lengths -// > `isize::MAX` (e.g. once Rust adds const generics), you must update -// `aview2` and `aview_mut2` to perform the necessary checks. In particular, -// the assumption that `cols` can never exceed `isize::MAX` would be incorrect. -// (Consider e.g. `let xs: &[[i32; ::std::usize::MAX]] = &[]`.) -impl_arr_init!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,); - /// Create a two-dimensional array with elements from `xs`. /// /// ``` @@ -210,22 +145,16 @@ impl_arr_init!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,); /// a.shape() == [2, 3] /// ); /// ``` -pub fn arr2>(xs: &[V]) -> Array2 -where - V: Clone, -{ +pub fn arr2(xs: &[[A; N]]) -> Array2 { Array2::from(xs.to_vec()) } -impl From> for Array2 -where - V: FixedInitializer, -{ +impl From> for Array2 { /// Converts the `Vec` of arrays to an owned 2-D array. /// /// **Panics** if the product of non-zero axis lengths overflows `isize`. - fn from(mut xs: Vec) -> Self { - let dim = Ix2(xs.len(), V::len()); + fn from(mut xs: Vec<[A; N]>) -> Self { + let dim = Ix2(xs.len(), N); let ptr = xs.as_mut_ptr(); let cap = xs.capacity(); let expand_len = dimension::size_of_shape_checked(&dim) @@ -234,12 +163,12 @@ where unsafe { let v = if size_of::() == 0 { Vec::from_raw_parts(ptr as *mut A, expand_len, expand_len) - } else if V::len() == 0 { + } else if N == 0 { Vec::new() } else { // Guaranteed not to overflow in this case since A is non-ZST // and Vec never allocates more than isize bytes. - let expand_cap = cap * V::len(); + let expand_cap = cap * N; Vec::from_raw_parts(ptr as *mut A, expand_len, expand_cap) }; ArrayBase::from_shape_vec_unchecked(dim, v) @@ -247,16 +176,12 @@ where } } -impl From> for Array3 -where - V: FixedInitializer, - U: FixedInitializer, -{ +impl From> for Array3 { /// Converts the `Vec` of arrays to an owned 3-D array. /// /// **Panics** if the product of non-zero axis lengths overflows `isize`. - fn from(mut xs: Vec) -> Self { - let dim = Ix3(xs.len(), V::len(), U::len()); + fn from(mut xs: Vec<[[A; M]; N]>) -> Self { + let dim = Ix3(xs.len(), N, M); let ptr = xs.as_mut_ptr(); let cap = xs.capacity(); let expand_len = dimension::size_of_shape_checked(&dim) @@ -265,12 +190,12 @@ where unsafe { let v = if size_of::() == 0 { Vec::from_raw_parts(ptr as *mut A, expand_len, expand_len) - } else if V::len() == 0 || U::len() == 0 { + } else if N == 0 || M == 0 { Vec::new() } else { // Guaranteed not to overflow in this case since A is non-ZST // and Vec never allocates more than isize bytes. - let expand_cap = cap * V::len() * U::len(); + let expand_cap = cap * N * M; Vec::from_raw_parts(ptr as *mut A, expand_len, expand_cap) }; ArrayBase::from_shape_vec_unchecked(dim, v) @@ -280,7 +205,7 @@ where /// Create a two-dimensional array with elements from `xs`. /// -pub fn rcarr2>(xs: &[V]) -> ArcArray2 { +pub fn rcarr2(xs: &[[A; N]]) -> ArcArray2 { arr2(xs).into_shared() } @@ -301,23 +226,11 @@ pub fn rcarr2>(xs: &[V]) -> ArcA /// a.shape() == [3, 2, 2] /// ); /// ``` -pub fn arr3, U: FixedInitializer>( - xs: &[V], -) -> Array3 -where - V: Clone, - U: Clone, -{ +pub fn arr3(xs: &[[[A; M]; N]]) -> Array3 { Array3::from(xs.to_vec()) } /// Create a three-dimensional array with elements from `xs`. -pub fn rcarr3, U: FixedInitializer>( - xs: &[V], -) -> ArcArray -where - V: Clone, - U: Clone, -{ +pub fn rcarr3(xs: &[[[A; M]; N]]) -> ArcArray { arr3(xs).into_shared() } diff --git a/src/zip/ndproducer.rs b/src/zip/ndproducer.rs index 619fadcc3..ca7e75fd3 100644 --- a/src/zip/ndproducer.rs +++ b/src/zip/ndproducer.rs @@ -1,4 +1,3 @@ - use crate::imp_prelude::*; use crate::Layout; use crate::NdIndex; @@ -168,6 +167,26 @@ impl<'a, A: 'a> IntoNdProducer for &'a mut [A] { } } +/// A one-dimensional array is a one-dimensional producer +impl<'a, A: 'a, const N: usize> IntoNdProducer for &'a [A; N] { + type Item = ::Item; + type Dim = Ix1; + type Output = ArrayView1<'a, A>; + fn into_producer(self) -> Self::Output { + <_>::from(self) + } +} + +/// A mutable one-dimensional array is a mutable one-dimensional producer +impl<'a, A: 'a, const N: usize> IntoNdProducer for &'a mut [A; N] { + type Item = ::Item; + type Dim = Ix1; + type Output = ArrayViewMut1<'a, A>; + fn into_producer(self) -> Self::Output { + <_>::from(self) + } +} + /// A Vec is a one-dimensional producer impl<'a, A: 'a> IntoNdProducer for &'a Vec { type Item = ::Item; @@ -399,4 +418,3 @@ impl NdProducer for RawArrayViewMut { self.split_at(axis, index) } } - diff --git a/tests/array.rs b/tests/array.rs index a16e75fd0..821246be1 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -12,7 +12,6 @@ use defmac::defmac; use itertools::{zip, Itertools}; use ndarray::prelude::*; use ndarray::{arr3, rcarr2}; -use ndarray::indices; use ndarray::{Slice, SliceInfo, SliceInfoElem}; use num_complex::Complex; use std::convert::TryFrom; @@ -731,7 +730,7 @@ fn diag() { let a = arr2(&[[1., 2., 3.0f32], [0., 0., 0.]]); let d = a.view().into_diag(); assert_eq!(d.dim(), 2); - let d = arr2::(&[[]]).into_diag(); + let d = arr2::(&[[]]).into_diag(); assert_eq!(d.dim(), 0); let d = ArcArray::::zeros(()).into_diag(); assert_eq!(d.dim(), 1); @@ -960,7 +959,7 @@ fn zero_axes() { a.map_inplace(|_| panic!()); a.for_each(|_| panic!()); println!("{:?}", a); - let b = arr2::(&[[], [], [], []]); + let b = arr2::(&[[], [], [], []]); println!("{:?}\n{:?}", b.shape(), b); // we can even get a subarray of b @@ -2071,9 +2070,8 @@ fn test_view_from_shape_ptr() { #[test] fn test_view_from_shape_ptr_deny_neg_strides() { let data = [0, 1, 2, 3, 4, 5]; - let _view = unsafe { - ArrayView::from_shape_ptr((2, 3).strides((-3isize as usize, 1)), data.as_ptr()) - }; + let _view = + unsafe { ArrayView::from_shape_ptr((2, 3).strides((-3isize as usize, 1)), data.as_ptr()) }; } #[should_panic(expected = "Unsupported")] @@ -2466,74 +2464,48 @@ mod array_cow_tests { #[test] fn test_remove_index() { - let mut a = arr2(&[[1, 2, 3], - [4, 5, 6], - [7, 8, 9], - [10,11,12]]); + let mut a = arr2(&[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]); a.remove_index(Axis(0), 1); a.remove_index(Axis(1), 2); assert_eq!(a.shape(), &[3, 2]); - assert_eq!(a, - array![[1, 2], - [7, 8], - [10,11]]); - - let mut a = arr2(&[[1, 2, 3], - [4, 5, 6], - [7, 8, 9], - [10,11,12]]); + assert_eq!(a, array![[1, 2], [7, 8], [10, 11]]); + + let mut a = arr2(&[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]); a.invert_axis(Axis(0)); a.remove_index(Axis(0), 1); a.remove_index(Axis(1), 2); assert_eq!(a.shape(), &[3, 2]); - assert_eq!(a, - array![[10,11], - [4, 5], - [1, 2]]); + assert_eq!(a, array![[10, 11], [4, 5], [1, 2]]); a.remove_index(Axis(1), 1); assert_eq!(a.shape(), &[3, 1]); - assert_eq!(a, - array![[10], - [4], - [1]]); + assert_eq!(a, array![[10], [4], [1]]); a.remove_index(Axis(1), 0); assert_eq!(a.shape(), &[3, 0]); - assert_eq!(a, - array![[], - [], - []]); + assert_eq!(a, array![[], [], []]); } -#[should_panic(expected="must be less")] +#[should_panic(expected = "must be less")] #[test] fn test_remove_index_oob1() { - let mut a = arr2(&[[1, 2, 3], - [4, 5, 6], - [7, 8, 9], - [10,11,12]]); + let mut a = arr2(&[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]); a.remove_index(Axis(0), 4); } -#[should_panic(expected="must be less")] +#[should_panic(expected = "must be less")] #[test] fn test_remove_index_oob2() { let mut a = array![[10], [4], [1]]; a.remove_index(Axis(1), 0); assert_eq!(a.shape(), &[3, 0]); - assert_eq!(a, - array![[], - [], - []]); + assert_eq!(a, array![[], [], []]); a.remove_index(Axis(0), 1); // ok - assert_eq!(a, - array![[], - []]); + assert_eq!(a, array![[], []]); a.remove_index(Axis(1), 0); // oob } -#[should_panic(expected="index out of bounds")] +#[should_panic(expected = "index out of bounds")] #[test] fn test_remove_index_oob3() { let mut a = array![[10], [4], [1]]; @@ -2552,14 +2524,10 @@ fn test_split_complex_view() { #[test] fn test_split_complex_view_roundtrip() { - let a_re = Array3::from_shape_fn((3,1,5), |(i, j, _k)| { - i * j - }); - let a_im = Array3::from_shape_fn((3,1,5), |(_i, _j, k)| { - k - }); - let a = Array3::from_shape_fn((3,1,5), |(i,j,k)| { - Complex::new(a_re[[i,j,k]], a_im[[i,j,k]]) + let a_re = Array3::from_shape_fn((3, 1, 5), |(i, j, _k)| i * j); + let a_im = Array3::from_shape_fn((3, 1, 5), |(_i, _j, k)| k); + let a = Array3::from_shape_fn((3, 1, 5), |(i, j, k)| { + Complex::new(a_re[[i, j, k]], a_im[[i, j, k]]) }); let Complex { re, im } = a.view().split_complex(); assert_eq!(a_re, re); @@ -2590,18 +2558,18 @@ fn test_split_complex_zerod() { #[test] fn test_split_complex_permuted() { - let a = Array3::from_shape_fn((3, 4, 5), |(i, j, k)| { - Complex::new(i * k + j, k) - }); - let permuted = a.view().permuted_axes([1,0,2]); + let a = Array3::from_shape_fn((3, 4, 5), |(i, j, k)| Complex::new(i * k + j, k)); + let permuted = a.view().permuted_axes([1, 0, 2]); let Complex { re, im } = permuted.split_complex(); - assert_eq!(re.get((3,2,4)).unwrap(), &11); - assert_eq!(im.get((3,2,4)).unwrap(), &4); + assert_eq!(re.get((3, 2, 4)).unwrap(), &11); + assert_eq!(im.get((3, 2, 4)).unwrap(), &4); } #[test] fn test_split_complex_invert_axis() { - let mut a = Array::from_shape_fn((2, 3, 2), |(i, j, k)| Complex::new(i as f64 + j as f64, i as f64 + k as f64)); + let mut a = Array::from_shape_fn((2, 3, 2), |(i, j, k)| { + Complex::new(i as f64 + j as f64, i as f64 + k as f64) + }); a.invert_axis(Axis(1)); let cmplx = a.view().split_complex(); assert_eq!(cmplx.re, a.mapv(|z| z.re)); From ad66360f09874931ee19bf1356f4076c31c39b49 Mon Sep 17 00:00:00 2001 From: YuhanLiin Date: Fri, 3 Dec 2021 21:30:36 -0500 Subject: [PATCH 2/4] Fix imports --- tests/array.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/array.rs b/tests/array.rs index 821246be1..c4b590b4a 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -11,8 +11,7 @@ use approx::assert_relative_eq; use defmac::defmac; use itertools::{zip, Itertools}; use ndarray::prelude::*; -use ndarray::{arr3, rcarr2}; -use ndarray::{Slice, SliceInfo, SliceInfoElem}; +use ndarray::{arr3, indices, rcarr2, Slice, SliceInfo, SliceInfoElem}; use num_complex::Complex; use std::convert::TryFrom; From a1d268b4d243ca59e63277f055e2afaeb9d11d75 Mon Sep 17 00:00:00 2001 From: YuhanLiin Date: Sat, 4 Dec 2021 13:16:21 -0500 Subject: [PATCH 3/4] Fix arraytraits imports --- src/arraytraits.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/arraytraits.rs b/src/arraytraits.rs index 0bec319e6..6a4fd1137 100644 --- a/src/arraytraits.rs +++ b/src/arraytraits.rs @@ -14,12 +14,12 @@ use std::ops::{Index, IndexMut}; use std::{hash, mem::size_of}; use std::{iter::FromIterator, slice}; -use crate::iter::{Iter, IterMut}; -use crate::NdIndex; -use crate::{dimension, imp_prelude::*}; - -use crate::numeric_util; -use crate::{FoldWhile, Zip}; +use crate::imp_prelude::*; +use crate::{ + dimension, + iter::{Iter, IterMut}, + numeric_util, FoldWhile, NdIndex, Zip, +}; #[cold] #[inline(never)] From 75a27e515d494ac72b3a909bca45401b70ea18e6 Mon Sep 17 00:00:00 2001 From: YuhanLiin Date: Sat, 4 Dec 2021 13:31:01 -0500 Subject: [PATCH 4/4] Undo formatting changes in tests --- tests/array.rs | 87 ++++++++++++++++++++++++++++++++++---------------- 1 file changed, 60 insertions(+), 27 deletions(-) diff --git a/tests/array.rs b/tests/array.rs index c4b590b4a..e3922ea8d 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -11,7 +11,9 @@ use approx::assert_relative_eq; use defmac::defmac; use itertools::{zip, Itertools}; use ndarray::prelude::*; -use ndarray::{arr3, indices, rcarr2, Slice, SliceInfo, SliceInfoElem}; +use ndarray::{arr3, rcarr2}; +use ndarray::indices; +use ndarray::{Slice, SliceInfo, SliceInfoElem}; use num_complex::Complex; use std::convert::TryFrom; @@ -2069,8 +2071,9 @@ fn test_view_from_shape_ptr() { #[test] fn test_view_from_shape_ptr_deny_neg_strides() { let data = [0, 1, 2, 3, 4, 5]; - let _view = - unsafe { ArrayView::from_shape_ptr((2, 3).strides((-3isize as usize, 1)), data.as_ptr()) }; + let _view = unsafe { + ArrayView::from_shape_ptr((2, 3).strides((-3isize as usize, 1)), data.as_ptr()) + }; } #[should_panic(expected = "Unsupported")] @@ -2463,48 +2466,74 @@ mod array_cow_tests { #[test] fn test_remove_index() { - let mut a = arr2(&[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]); + let mut a = arr2(&[[1, 2, 3], + [4, 5, 6], + [7, 8, 9], + [10,11,12]]); a.remove_index(Axis(0), 1); a.remove_index(Axis(1), 2); assert_eq!(a.shape(), &[3, 2]); - assert_eq!(a, array![[1, 2], [7, 8], [10, 11]]); - - let mut a = arr2(&[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]); + assert_eq!(a, + array![[1, 2], + [7, 8], + [10,11]]); + + let mut a = arr2(&[[1, 2, 3], + [4, 5, 6], + [7, 8, 9], + [10,11,12]]); a.invert_axis(Axis(0)); a.remove_index(Axis(0), 1); a.remove_index(Axis(1), 2); assert_eq!(a.shape(), &[3, 2]); - assert_eq!(a, array![[10, 11], [4, 5], [1, 2]]); + assert_eq!(a, + array![[10,11], + [4, 5], + [1, 2]]); a.remove_index(Axis(1), 1); assert_eq!(a.shape(), &[3, 1]); - assert_eq!(a, array![[10], [4], [1]]); + assert_eq!(a, + array![[10], + [4], + [1]]); a.remove_index(Axis(1), 0); assert_eq!(a.shape(), &[3, 0]); - assert_eq!(a, array![[], [], []]); + assert_eq!(a, + array![[], + [], + []]); } -#[should_panic(expected = "must be less")] +#[should_panic(expected="must be less")] #[test] fn test_remove_index_oob1() { - let mut a = arr2(&[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]); + let mut a = arr2(&[[1, 2, 3], + [4, 5, 6], + [7, 8, 9], + [10,11,12]]); a.remove_index(Axis(0), 4); } -#[should_panic(expected = "must be less")] +#[should_panic(expected="must be less")] #[test] fn test_remove_index_oob2() { let mut a = array![[10], [4], [1]]; a.remove_index(Axis(1), 0); assert_eq!(a.shape(), &[3, 0]); - assert_eq!(a, array![[], [], []]); + assert_eq!(a, + array![[], + [], + []]); a.remove_index(Axis(0), 1); // ok - assert_eq!(a, array![[], []]); + assert_eq!(a, + array![[], + []]); a.remove_index(Axis(1), 0); // oob } -#[should_panic(expected = "index out of bounds")] +#[should_panic(expected="index out of bounds")] #[test] fn test_remove_index_oob3() { let mut a = array![[10], [4], [1]]; @@ -2523,10 +2552,14 @@ fn test_split_complex_view() { #[test] fn test_split_complex_view_roundtrip() { - let a_re = Array3::from_shape_fn((3, 1, 5), |(i, j, _k)| i * j); - let a_im = Array3::from_shape_fn((3, 1, 5), |(_i, _j, k)| k); - let a = Array3::from_shape_fn((3, 1, 5), |(i, j, k)| { - Complex::new(a_re[[i, j, k]], a_im[[i, j, k]]) + let a_re = Array3::from_shape_fn((3,1,5), |(i, j, _k)| { + i * j + }); + let a_im = Array3::from_shape_fn((3,1,5), |(_i, _j, k)| { + k + }); + let a = Array3::from_shape_fn((3,1,5), |(i,j,k)| { + Complex::new(a_re[[i,j,k]], a_im[[i,j,k]]) }); let Complex { re, im } = a.view().split_complex(); assert_eq!(a_re, re); @@ -2557,18 +2590,18 @@ fn test_split_complex_zerod() { #[test] fn test_split_complex_permuted() { - let a = Array3::from_shape_fn((3, 4, 5), |(i, j, k)| Complex::new(i * k + j, k)); - let permuted = a.view().permuted_axes([1, 0, 2]); + let a = Array3::from_shape_fn((3, 4, 5), |(i, j, k)| { + Complex::new(i * k + j, k) + }); + let permuted = a.view().permuted_axes([1,0,2]); let Complex { re, im } = permuted.split_complex(); - assert_eq!(re.get((3, 2, 4)).unwrap(), &11); - assert_eq!(im.get((3, 2, 4)).unwrap(), &4); + assert_eq!(re.get((3,2,4)).unwrap(), &11); + assert_eq!(im.get((3,2,4)).unwrap(), &4); } #[test] fn test_split_complex_invert_axis() { - let mut a = Array::from_shape_fn((2, 3, 2), |(i, j, k)| { - Complex::new(i as f64 + j as f64, i as f64 + k as f64) - }); + let mut a = Array::from_shape_fn((2, 3, 2), |(i, j, k)| Complex::new(i as f64 + j as f64, i as f64 + k as f64)); a.invert_axis(Axis(1)); let cmplx = a.view().split_complex(); assert_eq!(cmplx.re, a.mapv(|z| z.re));