Skip to content

Commit 70fbc80

Browse files
committed
Impl a lifetime-relaxed broadcast for ArrayView
ArrayView::broadcast has a lifetime that depends on &self instead of its internal buffer. This prevents writing some types of functions in an allocation-free way. For instance, take the numpy `meshgrid` function: It could be implemented like so: ```rust fn meshgrid_2d<'a, 'b>(coords_x: ArrayView1<'a, X>, coords_y: ArrayView1<'b, X>) -> (ArrayView2<'a, X>, ArrayView2<'b, X>) { let x_len = coords_x.shape()[0]; let y_len = coords_y.shape()[0]; let coords_x_s = coords_x.into_shape((1, y_len)).unwrap(); let coords_x_b = coords_x_s.broadcast((x_len, y_len)).unwrap(); let coords_y_s = coords_y.into_shape((x_len, 1)).unwrap(); let coords_y_b = coords_y_s.broadcast((x_len, y_len)).unwrap(); (coords_x_b, coords_y_b) } ``` Unfortunately, this doesn't work, because `coords_x_b` is bound to the lifetime of `coord_x_s`, instead of being bound to 'a. This commit introduces a new function, broadcast_ref, that does just that.
1 parent e080d62 commit 70fbc80

File tree

2 files changed

+85
-0
lines changed

2 files changed

+85
-0
lines changed

src/impl_views/methods.rs

+84
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
// Copyright 2014-2016 bluss and ndarray developers.
2+
//
3+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4+
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5+
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6+
// option. This file may not be copied, modified, or distributed
7+
// except according to those terms.
8+
9+
use crate::imp_prelude::*;
10+
use crate::dimension::IntoDimension;
11+
use crate::dimension::size_of_shape_checked;
12+
13+
impl<'a, A, D> ArrayView<'a, A, D>
14+
where
15+
D: Dimension,
16+
{
17+
/// Broadcasts an `ArrayView`. See [`ArrayBase::broadcast`].
18+
///
19+
/// This is a specialized version of [`ArrayBase::broadcast`] that transfers
20+
/// the view's lifetime to the output.
21+
pub fn broadcast_ref<E>(&self, dim: E) -> Option<ArrayView<'a, A, E::Dim>>
22+
where
23+
E: IntoDimension,
24+
{
25+
/// Return new stride when trying to grow `from` into shape `to`
26+
///
27+
/// Broadcasting works by returning a "fake stride" where elements
28+
/// to repeat are in axes with 0 stride, so that several indexes point
29+
/// to the same element.
30+
///
31+
/// **Note:** Cannot be used for mutable iterators, since repeating
32+
/// elements would create aliasing pointers.
33+
fn upcast<D: Dimension, E: Dimension>(to: &D, from: &E, stride: &E) -> Option<D> {
34+
// Make sure the product of non-zero axis lengths does not exceed
35+
// `isize::MAX`. This is the only safety check we need to perform
36+
// because all the other constraints of `ArrayBase` are guaranteed
37+
// to be met since we're starting from a valid `ArrayBase`.
38+
let _ = size_of_shape_checked(to).ok()?;
39+
40+
let mut new_stride = to.clone();
41+
// begin at the back (the least significant dimension)
42+
// size of the axis has to either agree or `from` has to be 1
43+
if to.ndim() < from.ndim() {
44+
return None;
45+
}
46+
47+
{
48+
let mut new_stride_iter = new_stride.slice_mut().iter_mut().rev();
49+
for ((er, es), dr) in from
50+
.slice()
51+
.iter()
52+
.rev()
53+
.zip(stride.slice().iter().rev())
54+
.zip(new_stride_iter.by_ref())
55+
{
56+
/* update strides */
57+
if *dr == *er {
58+
/* keep stride */
59+
*dr = *es;
60+
} else if *er == 1 {
61+
/* dead dimension, zero stride */
62+
*dr = 0
63+
} else {
64+
return None;
65+
}
66+
}
67+
68+
/* set remaining strides to zero */
69+
for dr in new_stride_iter {
70+
*dr = 0;
71+
}
72+
}
73+
Some(new_stride)
74+
}
75+
let dim = dim.into_dimension();
76+
77+
// Note: zero strides are safe precisely because we return an read-only view
78+
let broadcast_strides = match upcast(&dim, &self.dim, &self.strides) {
79+
Some(st) => st,
80+
None => return None,
81+
};
82+
unsafe { Some(ArrayView::new(self.ptr, dim, broadcast_strides)) }
83+
}
84+
}

src/impl_views/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
mod constructors;
22
mod conversions;
33
mod indexing;
4+
mod methods;
45
mod splitting;
56

67
pub use constructors::*;

0 commit comments

Comments
 (0)