Skip to content

Commit e808e2b

Browse files
authored
Merge pull request #844 from andrei-papou/ap/stack-and-concatenate
Stack and concatenate
2 parents f69248e + 06e6145 commit e808e2b

File tree

5 files changed

+226
-11
lines changed

5 files changed

+226
-11
lines changed

src/doc/ndarray_for_numpy_users/mod.rs

+3
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,7 @@
531531
//! `a[:] = 3.` | [`a.fill(3.)`][.fill()] | set all array elements to the same scalar value
532532
//! `a[:] = b` | [`a.assign(&b)`][.assign()] | copy the data from array `b` into array `a`
533533
//! `np.concatenate((a,b), axis=1)` | [`stack![Axis(1), a, b]`][stack!] or [`stack(Axis(1), &[a.view(), b.view()])`][stack()] | concatenate arrays `a` and `b` along axis 1
534+
//! `np.stack((a,b), axis=1)` | [`stack_new_axis![Axis(1), a, b]`][stack_new_axis!] or [`stack_new_axis(Axis(1), vec![a.view(), b.view()])`][stack_new_axis()] | stack arrays `a` and `b` along axis 1
534535
//! `a[:,np.newaxis]` or `np.expand_dims(a, axis=1)` | [`a.insert_axis(Axis(1))`][.insert_axis()] | create an array from `a`, inserting a new axis 1
535536
//! `a.transpose()` or `a.T` | [`a.t()`][.t()] or [`a.reversed_axes()`][.reversed_axes()] | transpose of array `a` (view for `.t()` or by-move for `.reversed_axes()`)
536537
//! `np.diag(a)` | [`a.diag()`][.diag()] | view the diagonal of `a`
@@ -640,6 +641,8 @@
640641
//! [.shape()]: ../../struct.ArrayBase.html#method.shape
641642
//! [stack!]: ../../macro.stack.html
642643
//! [stack()]: ../../fn.stack.html
644+
//! [stack_new_axis!]: ../../macro.stack_new_axis.html
645+
//! [stack_new_axis()]: ../../fn.stack_new_axis.html
643646
//! [.strides()]: ../../struct.ArrayBase.html#method.strides
644647
//! [.index_axis()]: ../../struct.ArrayBase.html#method.index_axis
645648
//! [.sum_axis()]: ../../struct.ArrayBase.html#method.sum_axis

src/impl_methods.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ use crate::iter::{
2828
IndexedIter, IndexedIterMut, Iter, IterMut, Lanes, LanesMut, Windows,
2929
};
3030
use crate::slice::MultiSlice;
31-
use crate::stacking::stack;
31+
use crate::stacking::concatenate;
3232
use crate::{NdIndex, Slice, SliceInfo, SliceOrIndex};
3333

3434
/// # Methods For All Array Types
@@ -840,7 +840,7 @@ where
840840
dim.set_axis(axis, 0);
841841
unsafe { Array::from_shape_vec_unchecked(dim, vec![]) }
842842
} else {
843-
stack(axis, &subs).unwrap()
843+
concatenate(axis, &subs).unwrap()
844844
}
845845
}
846846

src/lib.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,9 @@ use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut, Lanes, Lane
131131

132132
pub use crate::arraytraits::AsArray;
133133
pub use crate::linalg_traits::{LinalgScalar, NdFloat};
134-
pub use crate::stacking::stack;
134+
135+
#[allow(deprecated)]
136+
pub use crate::stacking::{concatenate, stack, stack_new_axis};
135137

136138
pub use crate::impl_views::IndexLonger;
137139
pub use crate::shape_builder::ShapeBuilder;

src/stacking.rs

+175-6
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
use crate::error::{from_kind, ErrorKind, ShapeError};
1010
use crate::imp_prelude::*;
1111

12-
/// Stack arrays along the given axis.
12+
/// Concatenate arrays along the given axis.
1313
///
1414
/// ***Errors*** if the arrays have mismatching shapes, apart from along `axis`.
1515
/// (may be made more flexible in the future).<br>
@@ -29,10 +29,11 @@ use crate::imp_prelude::*;
2929
/// [3., 3.]]))
3030
/// );
3131
/// ```
32-
pub fn stack<'a, A, D>(
33-
axis: Axis,
34-
arrays: &[ArrayView<'a, A, D>],
35-
) -> Result<Array<A, D>, ShapeError>
32+
#[deprecated(
33+
since = "0.13.2",
34+
note = "Please use the `concatenate` function instead"
35+
)]
36+
pub fn stack<A, D>(axis: Axis, arrays: &[ArrayView<A, D>]) -> Result<Array<A, D>, ShapeError>
3637
where
3738
A: Copy,
3839
D: RemoveAxis,
@@ -76,7 +77,103 @@ where
7677
Ok(res)
7778
}
7879

79-
/// Stack arrays along the given axis.
80+
/// Concatenate arrays along the given axis.
81+
///
82+
/// ***Errors*** if the arrays have mismatching shapes, apart from along `axis`.
83+
/// (may be made more flexible in the future).<br>
84+
/// ***Errors*** if `arrays` is empty, if `axis` is out of bounds,
85+
/// if the result is larger than is possible to represent.
86+
///
87+
/// ```
88+
/// use ndarray::{arr2, Axis, concatenate};
89+
///
90+
/// let a = arr2(&[[2., 2.],
91+
/// [3., 3.]]);
92+
/// assert!(
93+
/// concatenate(Axis(0), &[a.view(), a.view()])
94+
/// == Ok(arr2(&[[2., 2.],
95+
/// [3., 3.],
96+
/// [2., 2.],
97+
/// [3., 3.]]))
98+
/// );
99+
/// ```
100+
#[allow(deprecated)]
101+
pub fn concatenate<A, D>(axis: Axis, arrays: &[ArrayView<A, D>]) -> Result<Array<A, D>, ShapeError>
102+
where
103+
A: Copy,
104+
D: RemoveAxis,
105+
{
106+
stack(axis, arrays)
107+
}
108+
109+
/// Stack arrays along the new axis.
110+
///
111+
/// ***Errors*** if the arrays have mismatching shapes.
112+
/// ***Errors*** if `arrays` is empty, if `axis` is out of bounds,
113+
/// if the result is larger than is possible to represent.
114+
///
115+
/// ```
116+
/// extern crate ndarray;
117+
///
118+
/// use ndarray::{arr2, arr3, stack_new_axis, Axis};
119+
///
120+
/// # fn main() {
121+
///
122+
/// let a = arr2(&[[2., 2.],
123+
/// [3., 3.]]);
124+
/// assert!(
125+
/// stack_new_axis(Axis(0), &[a.view(), a.view()])
126+
/// == Ok(arr3(&[[[2., 2.],
127+
/// [3., 3.]],
128+
/// [[2., 2.],
129+
/// [3., 3.]]]))
130+
/// );
131+
/// # }
132+
/// ```
133+
pub fn stack_new_axis<A, D>(
134+
axis: Axis,
135+
arrays: &[ArrayView<A, D>],
136+
) -> Result<Array<A, D::Larger>, ShapeError>
137+
where
138+
A: Copy,
139+
D: Dimension,
140+
D::Larger: RemoveAxis,
141+
{
142+
if arrays.is_empty() {
143+
return Err(from_kind(ErrorKind::Unsupported));
144+
}
145+
let common_dim = arrays[0].raw_dim();
146+
// Avoid panic on `insert_axis` call, return an Err instead of it.
147+
if axis.index() > common_dim.ndim() {
148+
return Err(from_kind(ErrorKind::OutOfBounds));
149+
}
150+
let mut res_dim = common_dim.insert_axis(axis);
151+
152+
if arrays.iter().any(|a| a.raw_dim() != common_dim) {
153+
return Err(from_kind(ErrorKind::IncompatibleShape));
154+
}
155+
156+
res_dim.set_axis(axis, arrays.len());
157+
158+
// we can safely use uninitialized values here because they are Copy
159+
// and we will only ever write to them
160+
let size = res_dim.size();
161+
let mut v = Vec::with_capacity(size);
162+
unsafe {
163+
v.set_len(size);
164+
}
165+
let mut res = Array::from_shape_vec(res_dim, v)?;
166+
167+
res.axis_iter_mut(axis)
168+
.zip(arrays.into_iter())
169+
.for_each(|(mut assign_view, array)| {
170+
assign_view.assign(&array);
171+
});
172+
173+
Ok(res)
174+
}
175+
176+
/// Concatenate arrays along the given axis.
80177
///
81178
/// Uses the [`stack`][1] function, calling `ArrayView::from(&a)` on each
82179
/// argument `a`.
@@ -101,9 +198,81 @@ where
101198
/// );
102199
/// # }
103200
/// ```
201+
#[deprecated(
202+
since = "0.13.2",
203+
note = "Please use the `concatenate!` macro instead"
204+
)]
104205
#[macro_export]
105206
macro_rules! stack {
106207
($axis:expr, $( $array:expr ),+ ) => {
107208
$crate::stack($axis, &[ $($crate::ArrayView::from(&$array) ),* ]).unwrap()
108209
}
109210
}
211+
212+
/// Concatenate arrays along the given axis.
213+
///
214+
/// Uses the [`concatenate`][1] function, calling `ArrayView::from(&a)` on each
215+
/// argument `a`.
216+
///
217+
/// [1]: fn.concatenate.html
218+
///
219+
/// ***Panics*** if the `concatenate` function would return an error.
220+
///
221+
/// ```
222+
/// extern crate ndarray;
223+
///
224+
/// use ndarray::{arr2, concatenate, Axis};
225+
///
226+
/// # fn main() {
227+
///
228+
/// let a = arr2(&[[2., 2.],
229+
/// [3., 3.]]);
230+
/// assert!(
231+
/// concatenate![Axis(0), a, a]
232+
/// == arr2(&[[2., 2.],
233+
/// [3., 3.],
234+
/// [2., 2.],
235+
/// [3., 3.]])
236+
/// );
237+
/// # }
238+
/// ```
239+
#[macro_export]
240+
macro_rules! concatenate {
241+
($axis:expr, $( $array:expr ),+ ) => {
242+
$crate::concatenate($axis, &[ $($crate::ArrayView::from(&$array) ),* ]).unwrap()
243+
}
244+
}
245+
246+
/// Stack arrays along the new axis.
247+
///
248+
/// Uses the [`stack_new_axis`][1] function, calling `ArrayView::from(&a)` on each
249+
/// argument `a`.
250+
///
251+
/// [1]: fn.stack_new_axis.html
252+
///
253+
/// ***Panics*** if the `stack` function would return an error.
254+
///
255+
/// ```
256+
/// extern crate ndarray;
257+
///
258+
/// use ndarray::{arr2, arr3, stack_new_axis, Axis};
259+
///
260+
/// # fn main() {
261+
///
262+
/// let a = arr2(&[[2., 2.],
263+
/// [3., 3.]]);
264+
/// assert!(
265+
/// stack_new_axis![Axis(0), a, a]
266+
/// == arr3(&[[[2., 2.],
267+
/// [3., 3.]],
268+
/// [[2., 2.],
269+
/// [3., 3.]]])
270+
/// );
271+
/// # }
272+
/// ```
273+
#[macro_export]
274+
macro_rules! stack_new_axis {
275+
($axis:expr, $( $array:expr ),+ ) => {
276+
$crate::stack_new_axis($axis, &[ $($crate::ArrayView::from(&$array) ),* ]).unwrap()
277+
}
278+
}

tests/stacking.rs

+43-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1-
use ndarray::{arr2, aview1, stack, Array2, Axis, ErrorKind};
1+
#![allow(deprecated)]
2+
3+
use ndarray::{arr2, arr3, aview1, concatenate, stack, Array2, Axis, ErrorKind, Ix1};
24

35
#[test]
4-
fn stacking() {
6+
fn concatenating() {
57
let a = arr2(&[[2., 2.], [3., 3.]]);
68
let b = ndarray::stack(Axis(0), &[a.view(), a.view()]).unwrap();
79
assert_eq!(b, arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.]]));
@@ -23,4 +25,43 @@ fn stacking() {
2325

2426
let res: Result<Array2<f64>, _> = ndarray::stack(Axis(0), &[]);
2527
assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported);
28+
29+
let a = arr2(&[[2., 2.], [3., 3.]]);
30+
let b = ndarray::concatenate(Axis(0), &[a.view(), a.view()]).unwrap();
31+
assert_eq!(b, arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.]]));
32+
33+
let c = concatenate![Axis(0), a, b];
34+
assert_eq!(
35+
c,
36+
arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.], [2., 2.], [3., 3.]])
37+
);
38+
39+
let d = concatenate![Axis(0), a.row(0), &[9., 9.]];
40+
assert_eq!(d, aview1(&[2., 2., 9., 9.]));
41+
42+
let res = ndarray::concatenate(Axis(1), &[a.view(), c.view()]);
43+
assert_eq!(res.unwrap_err().kind(), ErrorKind::IncompatibleShape);
44+
45+
let res = ndarray::concatenate(Axis(2), &[a.view(), c.view()]);
46+
assert_eq!(res.unwrap_err().kind(), ErrorKind::OutOfBounds);
47+
48+
let res: Result<Array2<f64>, _> = ndarray::concatenate(Axis(0), &[]);
49+
assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported);
50+
}
51+
52+
#[test]
53+
fn stacking() {
54+
let a = arr2(&[[2., 2.], [3., 3.]]);
55+
let b = ndarray::stack_new_axis(Axis(0), &[a.view(), a.view()]).unwrap();
56+
assert_eq!(b, arr3(&[[[2., 2.], [3., 3.]], [[2., 2.], [3., 3.]]]));
57+
58+
let c = arr2(&[[3., 2., 3.], [2., 3., 2.]]);
59+
let res = ndarray::stack_new_axis(Axis(1), &[a.view(), c.view()]);
60+
assert_eq!(res.unwrap_err().kind(), ErrorKind::IncompatibleShape);
61+
62+
let res = ndarray::stack_new_axis(Axis(3), &[a.view(), a.view()]);
63+
assert_eq!(res.unwrap_err().kind(), ErrorKind::OutOfBounds);
64+
65+
let res: Result<Array2<f64>, _> = ndarray::stack_new_axis::<_, Ix1>(Axis(0), &[]);
66+
assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported);
2667
}

0 commit comments

Comments
 (0)