Skip to content

Commit f31add8

Browse files
committed
FEAT: Use Baseiter optimizations in some places where it's possible
1 parent be87fe7 commit f31add8

File tree

10 files changed

+58
-68
lines changed

10 files changed

+58
-68
lines changed

src/array_serde.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use alloc::vec::Vec;
1717
use crate::imp_prelude::*;
1818

1919
use super::arraytraits::ARRAY_FORMAT_VERSION;
20-
use super::Iter;
20+
use super::iter::Iter;
2121
use crate::IntoDimension;
2222

2323
/// Verifies that the version of the deserialized array matches the current

src/dimension/mod.rs

+4-31
Original file line numberDiff line numberDiff line change
@@ -784,36 +784,6 @@ where
784784
}
785785
}
786786

787-
/// Move the axis which has the smallest absolute stride and a length
788-
/// greater than one to be the last axis.
789-
pub fn move_min_stride_axis_to_last<D>(dim: &mut D, strides: &mut D)
790-
where
791-
D: Dimension,
792-
{
793-
debug_assert_eq!(dim.ndim(), strides.ndim());
794-
match dim.ndim() {
795-
0 | 1 => {}
796-
2 => {
797-
if dim[1] <= 1
798-
|| dim[0] > 1 && (strides[0] as isize).abs() < (strides[1] as isize).abs()
799-
{
800-
dim.slice_mut().swap(0, 1);
801-
strides.slice_mut().swap(0, 1);
802-
}
803-
}
804-
n => {
805-
if let Some(min_stride_axis) = (0..n)
806-
.filter(|&ax| dim[ax] > 1)
807-
.min_by_key(|&ax| (strides[ax] as isize).abs())
808-
{
809-
let last = n - 1;
810-
dim.slice_mut().swap(last, min_stride_axis);
811-
strides.slice_mut().swap(last, min_stride_axis);
812-
}
813-
}
814-
}
815-
}
816-
817787
/// Remove axes with length one, except never removing the last axis.
818788
pub(crate) fn squeeze<D>(dim: &mut D, strides: &mut D)
819789
where
@@ -857,14 +827,17 @@ pub(crate) fn sort_axes_to_standard<D>(dim: &mut D, strides: &mut D)
857827
where
858828
D: Dimension,
859829
{
860-
debug_assert!(dim.ndim() > 1);
830+
if dim.ndim() <= 1 {
831+
return;
832+
}
861833
debug_assert_eq!(dim.ndim(), strides.ndim());
862834
// bubble sort axes
863835
let mut changed = true;
864836
while changed {
865837
changed = false;
866838
for i in 0..dim.ndim() - 1 {
867839
// make sure higher stride axes sort before.
840+
debug_assert!(strides.get_stride(Axis(i)) >= 0);
868841
if strides.get_stride(Axis(i)).abs() < strides.get_stride(Axis(i + 1)).abs() {
869842
changed = true;
870843
dim.slice_mut().swap(i, i + 1);

src/impl_methods.rs

+5-9
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use crate::argument_traits::AssignElem;
1919
use crate::dimension;
2020
use crate::dimension::IntoDimension;
2121
use crate::dimension::{
22-
abs_index, axes_of, do_slice, merge_axes, move_min_stride_axis_to_last,
22+
abs_index, axes_of, do_slice, merge_axes,
2323
offset_from_low_addr_ptr_to_logical_ptr, size_of_shape_checked, stride_offset, Axes,
2424
};
2525
use crate::dimension::broadcast::co_broadcast;
@@ -433,7 +433,7 @@ where
433433
where
434434
S: Data,
435435
{
436-
IndexedIter::new(self.view().into_elements_base())
436+
IndexedIter::new(self.view().into_elements_base_keep_dims())
437437
}
438438

439439
/// Return an iterator of indexes and mutable references to the elements of the array.
@@ -446,7 +446,7 @@ where
446446
where
447447
S: DataMut,
448448
{
449-
IndexedIterMut::new(self.view_mut().into_elements_base())
449+
IndexedIterMut::new(self.view_mut().into_elements_base_keep_dims())
450450
}
451451

452452
/// Return a sliced view of the array.
@@ -2441,9 +2441,7 @@ where
24412441
if let Some(slc) = self.as_slice_memory_order() {
24422442
slc.iter().fold(init, f)
24432443
} else {
2444-
let mut v = self.view();
2445-
move_min_stride_axis_to_last(&mut v.dim, &mut v.strides);
2446-
v.into_elements_base().fold(init, f)
2444+
self.view().into_elements_base_any_order().fold(init, f)
24472445
}
24482446
}
24492447

@@ -2599,9 +2597,7 @@ where
25992597
match self.try_as_slice_memory_order_mut() {
26002598
Ok(slc) => slc.iter_mut().for_each(f),
26012599
Err(arr) => {
2602-
let mut v = arr.view_mut();
2603-
move_min_stride_axis_to_last(&mut v.dim, &mut v.strides);
2604-
v.into_elements_base().for_each(f);
2600+
arr.view_mut().into_elements_base_any_order().for_each(f);
26052601
}
26062602
}
26072603
}

src/impl_views/conversions.rs

+34-11
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@ use std::mem::MaybeUninit;
1212

1313
use crate::imp_prelude::*;
1414

15-
use crate::{Baseiter, ElementsBase, ElementsBaseMut, Iter, IterMut};
16-
1715
use crate::dimension::offset_from_low_addr_ptr_to_logical_ptr;
18-
use crate::iter::{self, AxisIter, AxisIterMut};
16+
use crate::iter::{self, Iter, IterMut, AxisIter, AxisIterMut};
17+
use crate::iterators::base::{Baseiter, ElementsBase, ElementsBaseMut, OrderOption, PreserveOrder,
18+
ArbitraryOrder, NoOptimization};
1919
use crate::math_cell::MathCell;
2020
use crate::IndexLonger;
2121

@@ -188,14 +188,25 @@ impl<'a, A, D> ArrayView<'a, A, D>
188188
where
189189
D: Dimension,
190190
{
191+
/// Create a base iter fromt the view with the given order option
192+
#[inline]
193+
pub(crate) fn into_base_iter<F: OrderOption>(self) -> Baseiter<A, D> {
194+
unsafe { Baseiter::new_with_order::<F>(self.ptr.as_ptr(), self.dim, self.strides) }
195+
}
196+
197+
#[inline]
198+
pub(crate) fn into_elements_base_keep_dims(self) -> ElementsBase<'a, A, D> {
199+
ElementsBase::new::<NoOptimization>(self)
200+
}
201+
191202
#[inline]
192-
pub(crate) fn into_base_iter(self) -> Baseiter<A, D> {
193-
unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) }
203+
pub(crate) fn into_elements_base_preserve_order(self) -> ElementsBase<'a, A, D> {
204+
ElementsBase::new::<PreserveOrder>(self)
194205
}
195206

196207
#[inline]
197-
pub(crate) fn into_elements_base(self) -> ElementsBase<'a, A, D> {
198-
ElementsBase::new(self)
208+
pub(crate) fn into_elements_base_any_order(self) -> ElementsBase<'a, A, D> {
209+
ElementsBase::new::<ArbitraryOrder>(self)
199210
}
200211

201212
pub(crate) fn into_iter_(self) -> Iter<'a, A, D> {
@@ -227,16 +238,28 @@ where
227238
unsafe { RawArrayViewMut::new(self.ptr, self.dim, self.strides) }
228239
}
229240

241+
/// Create a base iter fromt the view with the given order option
230242
#[inline]
231-
pub(crate) fn into_base_iter(self) -> Baseiter<A, D> {
232-
unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) }
243+
pub(crate) fn into_base_iter<F: OrderOption>(self) -> Baseiter<A, D> {
244+
unsafe { Baseiter::new_with_order::<F>(self.ptr.as_ptr(), self.dim, self.strides) }
233245
}
234246

235247
#[inline]
236-
pub(crate) fn into_elements_base(self) -> ElementsBaseMut<'a, A, D> {
237-
ElementsBaseMut::new(self)
248+
pub(crate) fn into_elements_base_keep_dims(self) -> ElementsBaseMut<'a, A, D> {
249+
ElementsBaseMut::new::<NoOptimization>(self)
238250
}
239251

252+
#[inline]
253+
pub(crate) fn into_elements_base_preserve_order(self) -> ElementsBaseMut<'a, A, D> {
254+
ElementsBaseMut::new::<PreserveOrder>(self)
255+
}
256+
257+
#[inline]
258+
pub(crate) fn into_elements_base_any_order(self) -> ElementsBaseMut<'a, A, D> {
259+
ElementsBaseMut::new::<ArbitraryOrder>(self)
260+
}
261+
262+
240263
/// Return the array’s data as a slice, if it is contiguous and in standard order.
241264
/// Otherwise return self in the Err branch of the result.
242265
pub(crate) fn try_into_slice(self) -> Result<&'a mut [A], Self> {

src/iterators/base.rs

+6-8
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,7 @@ impl<A, D: Dimension> Baseiter<A, D> {
7171

7272
/// Return the iter strides
7373
pub(crate) fn raw_strides(&self) -> D { self.strides.clone() }
74-
}
7574

76-
impl<A, D: Dimension> Baseiter<A, D> {
7775
/// Creating a Baseiter is unsafe because shape and stride parameters need
7876
/// to be correct to avoid performing an unsafe pointer offset while
7977
/// iterating.
@@ -252,9 +250,9 @@ clone_bounds!(
252250
);
253251

254252
impl<'a, A, D: Dimension> ElementsBase<'a, A, D> {
255-
pub fn new(v: ArrayView<'a, A, D>) -> Self {
253+
pub fn new<F: OrderOption>(v: ArrayView<'a, A, D>) -> Self {
256254
ElementsBase {
257-
inner: v.into_base_iter(),
255+
inner: v.into_base_iter::<F>(),
258256
life: PhantomData,
259257
}
260258
}
@@ -338,7 +336,7 @@ where
338336
inner: if let Some(slc) = self_.to_slice() {
339337
ElementsRepr::Slice(slc.iter())
340338
} else {
341-
ElementsRepr::Counted(self_.into_elements_base())
339+
ElementsRepr::Counted(self_.into_elements_base_preserve_order())
342340
},
343341
}
344342
}
@@ -352,7 +350,7 @@ where
352350
IterMut {
353351
inner: match self_.try_into_slice() {
354352
Ok(x) => ElementsRepr::Slice(x.iter_mut()),
355-
Err(self_) => ElementsRepr::Counted(self_.into_elements_base()),
353+
Err(self_) => ElementsRepr::Counted(self_.into_elements_base_preserve_order()),
356354
},
357355
}
358356
}
@@ -397,9 +395,9 @@ pub(crate) struct ElementsBaseMut<'a, A, D> {
397395
}
398396

399397
impl<'a, A, D: Dimension> ElementsBaseMut<'a, A, D> {
400-
pub fn new(v: ArrayViewMut<'a, A, D>) -> Self {
398+
pub fn new<F: OrderOption>(v: ArrayViewMut<'a, A, D>) -> Self {
401399
ElementsBaseMut {
402-
inner: v.into_base_iter(),
400+
inner: v.into_base_iter::<F>(),
403401
life: PhantomData,
404402
}
405403
}

src/iterators/chunks.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ where
7979
type IntoIter = ExactChunksIter<'a, A, D>;
8080
fn into_iter(self) -> Self::IntoIter {
8181
ExactChunksIter {
82-
iter: self.base.into_elements_base(),
82+
iter: self.base.into_elements_base_any_order(),
8383
chunk: self.chunk,
8484
inner_strides: self.inner_strides,
8585
}
@@ -169,7 +169,7 @@ where
169169
type IntoIter = ExactChunksIterMut<'a, A, D>;
170170
fn into_iter(self) -> Self::IntoIter {
171171
ExactChunksIterMut {
172-
iter: self.base.into_elements_base(),
172+
iter: self.base.into_elements_base_any_order(),
173173
chunk: self.chunk,
174174
inner_strides: self.inner_strides,
175175
}

src/iterators/lanes.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::marker::PhantomData;
33
use crate::imp_prelude::*;
44
use crate::{Layout, NdProducer};
55
use crate::iterators::Baseiter;
6+
use crate::iterators::base::NoOptimization;
67

78
impl_ndproducer! {
89
['a, A, D: Dimension]
@@ -83,7 +84,7 @@ where
8384
type IntoIter = LanesIter<'a, A, D>;
8485
fn into_iter(self) -> Self::IntoIter {
8586
LanesIter {
86-
iter: self.base.into_base_iter(),
87+
iter: self.base.into_base_iter::<NoOptimization>(),
8788
inner_len: self.inner_len,
8889
inner_stride: self.inner_stride,
8990
life: PhantomData,
@@ -134,7 +135,7 @@ where
134135
type IntoIter = LanesIterMut<'a, A, D>;
135136
fn into_iter(self) -> Self::IntoIter {
136137
LanesIterMut {
137-
iter: self.base.into_base_iter(),
138+
iter: self.base.into_base_iter::<NoOptimization>(),
138139
inner_len: self.inner_len,
139140
inner_stride: self.inner_stride,
140141
life: PhantomData,

src/iterators/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
mod macros;
1111

1212
mod axis;
13-
mod base;
13+
pub(crate) mod base;
1414
mod chunks;
1515
mod into_iter;
1616
pub mod iter;

src/iterators/windows.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ where
7777
type IntoIter = WindowsIter<'a, A, D>;
7878
fn into_iter(self) -> Self::IntoIter {
7979
WindowsIter {
80-
iter: self.base.into_elements_base(),
80+
iter: self.base.into_elements_base_preserve_order(),
8181
window: self.window,
8282
strides: self.strides,
8383
}

src/lib.rs

+1-2
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,7 @@ pub use crate::slice::{
142142
MultiSliceArg, NewAxis, Slice, SliceArg, SliceInfo, SliceInfoElem, SliceNextDim,
143143
};
144144

145-
use crate::iterators::Baseiter;
146-
use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut};
145+
use crate::iterators::{ElementsBase, ElementsBaseMut};
147146

148147
pub use crate::arraytraits::AsArray;
149148
#[cfg(feature = "std")]

0 commit comments

Comments
 (0)