Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Basic support for axis simplification and arbitrary order in iterators #979

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
72 changes: 72 additions & 0 deletions benches/iter.rs
Original file line number Diff line number Diff line change
@@ -87,6 +87,78 @@ fn iter_filter_sum_2d_stride_f32(bench: &mut Bencher)
bench.iter(|| b.iter().filter(|&&x| x < 75.).sum::<f32>());
}

#[bench]
fn iter_sum_2d_row_matrix(bench: &mut Bencher)
{
let a = Array::from_iter(0i32..64 * 64);
let v = a.view().insert_axis(Axis(1));
bench.iter(|| {
let mut s = 0;
for &elt in v.iter() {
s += elt;
}
s
});
}

#[bench]
fn iter_sum_2d_row_matrix_for_strided(bench: &mut Bencher)
{
let a = Array::from_iter(0i32..64 * 64).slice_move(s![..;2]);
let v = a.view().insert_axis(Axis(1));
bench.iter(|| {
let mut s = 0;
for &elt in v.iter() {
s += elt;
}
s
});
}

#[bench]
fn iter_sum_2d_row_matrix_sum_strided(bench: &mut Bencher)
{
let a = Array::from_iter(0i32..64 * 64).slice_move(s![..;2]);
let v = a.view().insert_axis(Axis(1));
bench.iter(|| v.iter().sum::<i32>());
}

#[bench]
fn iter_sum_2d_col_matrix(bench: &mut Bencher)
{
let a = Array::from_iter(0i32..64 * 64);
let v = a.view().insert_axis(Axis(0));
bench.iter(|| {
let mut s = 0;
for &elt in v.iter() {
s += elt;
}
s
});
}

#[bench]
fn iter_sum_2d_col_matrix_for_strided(bench: &mut Bencher)
{
let a = Array::from_iter(0i32..64 * 64).slice_move(s![..;2]);
let v = a.view().insert_axis(Axis(0));
bench.iter(|| {
let mut s = 0;
for &elt in v.iter() {
s += elt;
}
s
});
}

#[bench]
fn iter_sum_2d_col_matrix_sum_strided(bench: &mut Bencher)
{
let a = Array::from_iter(0i32..64 * 64).slice_move(s![..;2]);
let v = a.view().insert_axis(Axis(0));
bench.iter(|| v.iter().sum::<i32>());
}

#[bench]
fn iter_rev_step_by_contiguous(bench: &mut Bencher)
{
157 changes: 157 additions & 0 deletions src/dimension/mod.rs
Original file line number Diff line number Diff line change
@@ -313,6 +313,20 @@ pub trait DimensionExt
/// *Panics* if `axis` is out of bounds.
#[track_caller]
fn set_axis(&mut self, axis: Axis, value: Ix);

/// Get as stride
#[inline]
fn get_stride(&self, axis: Axis) -> isize
{
self.axis(axis) as isize
}

/// Set as stride
#[inline]
fn set_stride(&mut self, axis: Axis, value: isize)
{
self.set_axis(axis, value as usize)
}
}

impl<D> DimensionExt for D
@@ -745,6 +759,32 @@ where D: Dimension
}
}

/// Attempt to merge axes if possible, starting from the back
///
/// Given axes [Axis(0), Axis(1), Axis(2), Axis(3)] this attempts
/// to merge all axes one by one into Axis(3); when/if this fails,
/// it attempts to merge the rest of the axes together into the next
/// axis in line, for example a result could be:
///
/// [1, Axis(0) + Axis(1), 1, Axis(2) + Axis(3)] where `+` would
/// mean axes were merged.
pub(crate) fn merge_axes_from_the_back<D>(dim: &mut D, strides: &mut D)
where D: Dimension
{
debug_assert_eq!(dim.ndim(), strides.ndim());
match dim.ndim() {
0 | 1 => {}
n => {
let mut last = n - 1;
for i in (0..last).rev() {
if !merge_axes(dim, strides, Axis(i), Axis(last)) {
last = i;
}
}
}
}
}

/// Move the axis which has the smallest absolute stride and a length
/// greater than one to be the last axis.
pub fn move_min_stride_axis_to_last<D>(dim: &mut D, strides: &mut D)
@@ -771,6 +811,67 @@ where D: Dimension
}
}

/// Remove axes with length one, except never removing the last axis.
///
/// This only has effect on dynamic dimensions.
pub(crate) fn squeeze<D>(dim: &mut D, strides: &mut D)
where D: Dimension
{
if let Some(_) = D::NDIM {
return;
}
debug_assert_eq!(dim.ndim(), strides.ndim());

// Count axes with dim == 1; we keep axes with d == 0 or d > 1
let mut ndim_new = 0;
for &d in dim.slice() {
if d != 1 {
ndim_new += 1;
}
}
ndim_new = Ord::max(1, ndim_new);
let mut new_dim = D::zeros(ndim_new);
let mut new_strides = D::zeros(ndim_new);
let mut i = 0;
for (&d, &s) in izip!(dim.slice(), strides.slice()) {
if d != 1 {
new_dim[i] = d;
new_strides[i] = s;
i += 1;
}
}
if i == 0 {
new_dim[i] = 1;
new_strides[i] = 1;
}
*dim = new_dim;
*strides = new_strides;
}

/// Sort axes to standard/row major order, i.e Axis(0) has biggest stride and Axis(n - 1) least
/// stride
///
/// The axes are sorted according to the .abs() of their stride.
pub(crate) fn sort_axes_to_standard<D>(dim: &mut D, strides: &mut D)
where D: Dimension
{
debug_assert!(dim.ndim() > 1);
debug_assert_eq!(dim.ndim(), strides.ndim());
// bubble sort axes
let mut changed = true;
while changed {
changed = false;
for i in 0..dim.ndim() - 1 {
// make sure higher stride axes sort before.
if strides.get_stride(Axis(i)).abs() < strides.get_stride(Axis(i + 1)).abs() {
changed = true;
dim.slice_mut().swap(i, i + 1);
strides.slice_mut().swap(i, i + 1);
}
}
}
}

#[cfg(test)]
mod test
{
@@ -780,9 +881,11 @@ mod test
can_index_slice_not_custom,
extended_gcd,
max_abs_offset_check_overflow,
merge_axes_from_the_back,
slice_min_max,
slices_intersect,
solve_linear_diophantine_eq,
squeeze,
IntoDimension,
};
use crate::error::{from_kind, ErrorKind};
@@ -1132,4 +1235,58 @@ mod test
s![.., 3..;6, NewAxis]
));
}

#[test]
#[cfg(feature = "std")]
fn test_squeeze()
{
let dyndim = Dim::<&[usize]>;

let mut d = dyndim(&[1, 2, 1, 1, 3, 1]);
let mut s = dyndim(&[!0, !0, !0, 9, 10, !0]);
let dans = dyndim(&[2, 3]);
let sans = dyndim(&[!0, 10]);
squeeze(&mut d, &mut s);
assert_eq!(d, dans);
assert_eq!(s, sans);

let mut d = dyndim(&[1, 1]);
let mut s = dyndim(&[3, 4]);
let dans = dyndim(&[1]);
let sans = dyndim(&[1]);
squeeze(&mut d, &mut s);
assert_eq!(d, dans);
assert_eq!(s, sans);

let mut d = dyndim(&[0, 1, 3, 4]);
let mut s = dyndim(&[2, 3, 4, 5]);
let dans = dyndim(&[0, 3, 4]);
let sans = dyndim(&[2, 4, 5]);
squeeze(&mut d, &mut s);
assert_eq!(d, dans);
assert_eq!(s, sans);
}

#[test]
fn test_merge_axes_from_the_back()
{
let dyndim = Dim::<&[usize]>;

let mut d = Dim([3, 4, 5]);
let mut s = Dim([20, 5, 1]);
merge_axes_from_the_back(&mut d, &mut s);
assert_eq!(d, Dim([1, 1, 60]));
assert_eq!(s, Dim([20, 5, 1]));

let mut d = Dim([3, 4, 5, 2]);
let mut s = Dim([80, 20, 2, 1]);
merge_axes_from_the_back(&mut d, &mut s);
assert_eq!(d, Dim([1, 12, 1, 10]));
assert_eq!(s, Dim([80, 20, 2, 1]));
let mut d = d.into_dyn();
let mut s = s.into_dyn();
squeeze(&mut d, &mut s);
assert_eq!(d, dyndim(&[12, 10]));
assert_eq!(s, dyndim(&[20, 1]));
}
}
Loading