Skip to content

Commit 49bf28a

Browse files
committedApr 12, 2021
FEAT: Use Baseiter optimizations and arbitrary order where possible
1 parent 33506cf commit 49bf28a

File tree

11 files changed

+81
-82
lines changed

11 files changed

+81
-82
lines changed
 

‎src/array_serde.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use alloc::vec::Vec;
1717
use crate::imp_prelude::*;
1818

1919
use super::arraytraits::ARRAY_FORMAT_VERSION;
20-
use super::Iter;
20+
use super::iter::Iter;
2121
use crate::IntoDimension;
2222

2323
/// Verifies that the version of the deserialized array matches the current

‎src/dimension/mod.rs

+4-31
Original file line numberDiff line numberDiff line change
@@ -728,36 +728,6 @@ where
728728
}
729729
}
730730

731-
/// Move the axis which has the smallest absolute stride and a length
732-
/// greater than one to be the last axis.
733-
pub fn move_min_stride_axis_to_last<D>(dim: &mut D, strides: &mut D)
734-
where
735-
D: Dimension,
736-
{
737-
debug_assert_eq!(dim.ndim(), strides.ndim());
738-
match dim.ndim() {
739-
0 | 1 => {}
740-
2 => {
741-
if dim[1] <= 1
742-
|| dim[0] > 1 && (strides[0] as isize).abs() < (strides[1] as isize).abs()
743-
{
744-
dim.slice_mut().swap(0, 1);
745-
strides.slice_mut().swap(0, 1);
746-
}
747-
}
748-
n => {
749-
if let Some(min_stride_axis) = (0..n)
750-
.filter(|&ax| dim[ax] > 1)
751-
.min_by_key(|&ax| (strides[ax] as isize).abs())
752-
{
753-
let last = n - 1;
754-
dim.slice_mut().swap(last, min_stride_axis);
755-
strides.slice_mut().swap(last, min_stride_axis);
756-
}
757-
}
758-
}
759-
}
760-
761731
/// Remove axes with length one, except never removing the last axis.
762732
pub(crate) fn squeeze<D>(dim: &mut D, strides: &mut D)
763733
where
@@ -801,14 +771,17 @@ pub(crate) fn sort_axes_to_standard<D>(dim: &mut D, strides: &mut D)
801771
where
802772
D: Dimension,
803773
{
804-
debug_assert!(dim.ndim() > 1);
774+
if dim.ndim() <= 1 {
775+
return;
776+
}
805777
debug_assert_eq!(dim.ndim(), strides.ndim());
806778
// bubble sort axes
807779
let mut changed = true;
808780
while changed {
809781
changed = false;
810782
for i in 0..dim.ndim() - 1 {
811783
// make sure higher stride axes sort before.
784+
debug_assert!(strides.get_stride(Axis(i)) >= 0);
812785
if strides.get_stride(Axis(i)).abs() < strides.get_stride(Axis(i + 1)).abs() {
813786
changed = true;
814787
dim.slice_mut().swap(i, i + 1);

‎src/impl_methods.rs

+5-9
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use crate::argument_traits::AssignElem;
1919
use crate::dimension;
2020
use crate::dimension::IntoDimension;
2121
use crate::dimension::{
22-
abs_index, axes_of, do_slice, merge_axes, move_min_stride_axis_to_last,
22+
abs_index, axes_of, do_slice, merge_axes,
2323
offset_from_ptr_to_memory, size_of_shape_checked, stride_offset, Axes,
2424
};
2525
use crate::dimension::broadcast::co_broadcast;
@@ -316,7 +316,7 @@ where
316316
where
317317
S: Data,
318318
{
319-
IndexedIter::new(self.view().into_elements_base())
319+
IndexedIter::new(self.view().into_elements_base_keep_dims())
320320
}
321321

322322
/// Return an iterator of indexes and mutable references to the elements of the array.
@@ -329,7 +329,7 @@ where
329329
where
330330
S: DataMut,
331331
{
332-
IndexedIterMut::new(self.view_mut().into_elements_base())
332+
IndexedIterMut::new(self.view_mut().into_elements_base_keep_dims())
333333
}
334334

335335
/// Return a sliced view of the array.
@@ -2175,9 +2175,7 @@ where
21752175
if let Some(slc) = self.as_slice_memory_order() {
21762176
slc.iter().fold(init, f)
21772177
} else {
2178-
let mut v = self.view();
2179-
move_min_stride_axis_to_last(&mut v.dim, &mut v.strides);
2180-
v.into_elements_base().fold(init, f)
2178+
self.view().into_elements_base_any_order().fold(init, f)
21812179
}
21822180
}
21832181

@@ -2295,9 +2293,7 @@ where
22952293
match self.try_as_slice_memory_order_mut() {
22962294
Ok(slc) => slc.iter_mut().for_each(f),
22972295
Err(arr) => {
2298-
let mut v = arr.view_mut();
2299-
move_min_stride_axis_to_last(&mut v.dim, &mut v.strides);
2300-
v.into_elements_base().for_each(f);
2296+
arr.view_mut().into_elements_base_any_order().for_each(f);
23012297
}
23022298
}
23032299
}

‎src/impl_views/conversions.rs

+34-11
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ use alloc::slice;
1010

1111
use crate::imp_prelude::*;
1212

13-
use crate::{Baseiter, ElementsBase, ElementsBaseMut, Iter, IterMut};
14-
15-
use crate::iter::{self, AxisIter, AxisIterMut};
13+
use crate::iter::{self, Iter, IterMut, AxisIter, AxisIterMut};
14+
use crate::iterators::base::{Baseiter, ElementsBase, ElementsBaseMut, OrderOption, PreserveOrder,
15+
ArbitraryOrder, NoOptimization};
1616
use crate::math_cell::MathCell;
1717
use crate::IndexLonger;
1818

@@ -140,14 +140,25 @@ impl<'a, A, D> ArrayView<'a, A, D>
140140
where
141141
D: Dimension,
142142
{
143+
/// Create a base iter fromt the view with the given order option
144+
#[inline]
145+
pub(crate) fn into_base_iter<F: OrderOption>(self) -> Baseiter<A, D> {
146+
unsafe { Baseiter::new_with_order::<F>(self.ptr.as_ptr(), self.dim, self.strides) }
147+
}
148+
149+
#[inline]
150+
pub(crate) fn into_elements_base_keep_dims(self) -> ElementsBase<'a, A, D> {
151+
ElementsBase::new::<NoOptimization>(self)
152+
}
153+
143154
#[inline]
144-
pub(crate) fn into_base_iter(self) -> Baseiter<A, D> {
145-
unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) }
155+
pub(crate) fn into_elements_base_preserve_order(self) -> ElementsBase<'a, A, D> {
156+
ElementsBase::new::<PreserveOrder>(self)
146157
}
147158

148159
#[inline]
149-
pub(crate) fn into_elements_base(self) -> ElementsBase<'a, A, D> {
150-
ElementsBase::new(self)
160+
pub(crate) fn into_elements_base_any_order(self) -> ElementsBase<'a, A, D> {
161+
ElementsBase::new::<ArbitraryOrder>(self)
151162
}
152163

153164
pub(crate) fn into_iter_(self) -> Iter<'a, A, D> {
@@ -179,16 +190,28 @@ where
179190
unsafe { RawArrayViewMut::new(self.ptr, self.dim, self.strides) }
180191
}
181192

193+
/// Create a base iter fromt the view with the given order option
182194
#[inline]
183-
pub(crate) fn into_base_iter(self) -> Baseiter<A, D> {
184-
unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) }
195+
pub(crate) fn into_base_iter<F: OrderOption>(self) -> Baseiter<A, D> {
196+
unsafe { Baseiter::new_with_order::<F>(self.ptr.as_ptr(), self.dim, self.strides) }
185197
}
186198

187199
#[inline]
188-
pub(crate) fn into_elements_base(self) -> ElementsBaseMut<'a, A, D> {
189-
ElementsBaseMut::new(self)
200+
pub(crate) fn into_elements_base_keep_dims(self) -> ElementsBaseMut<'a, A, D> {
201+
ElementsBaseMut::new::<NoOptimization>(self)
190202
}
191203

204+
#[inline]
205+
pub(crate) fn into_elements_base_preserve_order(self) -> ElementsBaseMut<'a, A, D> {
206+
ElementsBaseMut::new::<PreserveOrder>(self)
207+
}
208+
209+
#[inline]
210+
pub(crate) fn into_elements_base_any_order(self) -> ElementsBaseMut<'a, A, D> {
211+
ElementsBaseMut::new::<ArbitraryOrder>(self)
212+
}
213+
214+
192215
/// Return the array’s data as a slice, if it is contiguous and in standard order.
193216
/// Otherwise return self in the Err branch of the result.
194217
pub(crate) fn try_into_slice(self) -> Result<&'a mut [A], Self> {

‎src/iterators/base.rs

+6-16
Original file line numberDiff line numberDiff line change
@@ -57,16 +57,6 @@ pub(crate) struct Baseiter<A, D> {
5757
index: Option<D>,
5858
}
5959

60-
impl<A, D: Dimension> Baseiter<A, D> {
61-
/// Creating a Baseiter is unsafe because shape and stride parameters need
62-
/// to be correct to avoid performing an unsafe pointer offset while
63-
/// iterating.
64-
#[inline]
65-
pub unsafe fn new(ptr: *mut A, dim: D, strides: D) -> Baseiter<A, D> {
66-
Self::new_with_order::<NoOptimization>(ptr, dim, strides)
67-
}
68-
}
69-
7060
impl<A, D: Dimension> Baseiter<A, D> {
7161
/// Creating a Baseiter is unsafe because shape and stride parameters need
7262
/// to be correct to avoid performing an unsafe pointer offset while
@@ -246,9 +236,9 @@ clone_bounds!(
246236
);
247237

248238
impl<'a, A, D: Dimension> ElementsBase<'a, A, D> {
249-
pub fn new(v: ArrayView<'a, A, D>) -> Self {
239+
pub fn new<F: OrderOption>(v: ArrayView<'a, A, D>) -> Self {
250240
ElementsBase {
251-
inner: v.into_base_iter(),
241+
inner: v.into_base_iter::<F>(),
252242
life: PhantomData,
253243
}
254244
}
@@ -332,7 +322,7 @@ where
332322
inner: if let Some(slc) = self_.to_slice() {
333323
ElementsRepr::Slice(slc.iter())
334324
} else {
335-
ElementsRepr::Counted(self_.into_elements_base())
325+
ElementsRepr::Counted(self_.into_elements_base_preserve_order())
336326
},
337327
}
338328
}
@@ -346,7 +336,7 @@ where
346336
IterMut {
347337
inner: match self_.try_into_slice() {
348338
Ok(x) => ElementsRepr::Slice(x.iter_mut()),
349-
Err(self_) => ElementsRepr::Counted(self_.into_elements_base()),
339+
Err(self_) => ElementsRepr::Counted(self_.into_elements_base_preserve_order()),
350340
},
351341
}
352342
}
@@ -391,9 +381,9 @@ pub(crate) struct ElementsBaseMut<'a, A, D> {
391381
}
392382

393383
impl<'a, A, D: Dimension> ElementsBaseMut<'a, A, D> {
394-
pub fn new(v: ArrayViewMut<'a, A, D>) -> Self {
384+
pub fn new<F: OrderOption>(v: ArrayViewMut<'a, A, D>) -> Self {
395385
ElementsBaseMut {
396-
inner: v.into_base_iter(),
386+
inner: v.into_base_iter::<F>(),
397387
life: PhantomData,
398388
}
399389
}

‎src/iterators/chunks.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ where
7979
type IntoIter = ExactChunksIter<'a, A, D>;
8080
fn into_iter(self) -> Self::IntoIter {
8181
ExactChunksIter {
82-
iter: self.base.into_elements_base(),
82+
iter: self.base.into_elements_base_any_order(),
8383
chunk: self.chunk,
8484
inner_strides: self.inner_strides,
8585
}
@@ -169,7 +169,7 @@ where
169169
type IntoIter = ExactChunksIterMut<'a, A, D>;
170170
fn into_iter(self) -> Self::IntoIter {
171171
ExactChunksIterMut {
172-
iter: self.base.into_elements_base(),
172+
iter: self.base.into_elements_base_any_order(),
173173
chunk: self.chunk,
174174
inner_strides: self.inner_strides,
175175
}

‎src/iterators/lanes.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::marker::PhantomData;
33
use crate::imp_prelude::*;
44
use crate::{Layout, NdProducer};
55
use crate::iterators::Baseiter;
6+
use crate::iterators::base::NoOptimization;
67

78
impl_ndproducer! {
89
['a, A, D: Dimension]
@@ -83,7 +84,7 @@ where
8384
type IntoIter = LanesIter<'a, A, D>;
8485
fn into_iter(self) -> Self::IntoIter {
8586
LanesIter {
86-
iter: self.base.into_base_iter(),
87+
iter: self.base.into_base_iter::<NoOptimization>(),
8788
inner_len: self.inner_len,
8889
inner_stride: self.inner_stride,
8990
life: PhantomData,
@@ -134,7 +135,7 @@ where
134135
type IntoIter = LanesIterMut<'a, A, D>;
135136
fn into_iter(self) -> Self::IntoIter {
136137
LanesIterMut {
137-
iter: self.base.into_base_iter(),
138+
iter: self.base.into_base_iter::<NoOptimization>(),
138139
inner_len: self.inner_len,
139140
inner_stride: self.inner_stride,
140141
life: PhantomData,

‎src/iterators/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
mod macros;
1111

1212
mod axis;
13-
mod base;
13+
pub(crate) mod base;
1414
mod chunks;
1515
pub mod iter;
1616
mod lanes;

‎src/iterators/windows.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ where
7777
type IntoIter = WindowsIter<'a, A, D>;
7878
fn into_iter(self) -> Self::IntoIter {
7979
WindowsIter {
80-
iter: self.base.into_elements_base(),
80+
iter: self.base.into_elements_base_any_order(),
8181
window: self.window,
8282
strides: self.strides,
8383
}

‎src/lib.rs

+1-2
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,7 @@ pub use crate::slice::{
147147
MultiSliceArg, NewAxis, Slice, SliceArg, SliceInfo, SliceInfoElem, SliceNextDim,
148148
};
149149

150-
use crate::iterators::Baseiter;
151-
use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut};
150+
use crate::iterators::{ElementsBase, ElementsBaseMut};
152151

153152
pub use crate::arraytraits::AsArray;
154153
#[cfg(feature = "std")]

‎tests/windows.rs

+23-6
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
clippy::many_single_char_names
66
)]
77

8+
use std::collections::HashSet;
9+
use std::hash::Hash;
10+
811
use ndarray::prelude::*;
912
use ndarray::Zip;
1013

@@ -117,6 +120,20 @@ fn test_window_zip() {
117120
}
118121
}
119122

123+
fn set<T>(iter: impl IntoIterator<Item = T>) -> HashSet<T>
124+
where
125+
T: Eq + Hash
126+
{
127+
iter.into_iter().collect()
128+
}
129+
130+
/// Assert equal sets (same collection but order doesn't matter)
131+
macro_rules! assert_set_eq {
132+
($a:expr, $b:expr) => {
133+
assert_eq!(set($a), set($b))
134+
}
135+
}
136+
120137
#[test]
121138
fn test_window_neg_stride() {
122139
let array = Array::from_iter(1..10).into_shape((3, 3)).unwrap();
@@ -131,24 +148,24 @@ fn test_window_neg_stride() {
131148
answer.invert_axis(Axis(1));
132149
answer.map_inplace(|a| a.invert_axis(Axis(1)));
133150

134-
itertools::assert_equal(
151+
assert_set_eq!(
135152
array.slice(s![.., ..;-1]).windows((2, 2)),
136-
answer.iter()
153+
answer.iter().map(Array::view)
137154
);
138155

139156
answer.invert_axis(Axis(0));
140157
answer.map_inplace(|a| a.invert_axis(Axis(0)));
141158

142-
itertools::assert_equal(
159+
assert_set_eq!(
143160
array.slice(s![..;-1, ..;-1]).windows((2, 2)),
144-
answer.iter()
161+
answer.iter().map(Array::view)
145162
);
146163

147164
answer.invert_axis(Axis(1));
148165
answer.map_inplace(|a| a.invert_axis(Axis(1)));
149166

150-
itertools::assert_equal(
167+
assert_set_eq!(
151168
array.slice(s![..;-1, ..]).windows((2, 2)),
152-
answer.iter()
169+
answer.iter().map(Array::view)
153170
);
154171
}

0 commit comments

Comments
 (0)