Skip to content

Commit 2a9ba60

Browse files
committed
FEAT: Add method squeeze_into
This method can squeeze into a particular dimensionality. Squeezing means removing axes of length 1. When squeezing to a particular dimensionality, we may have to still pad out the shape with extra 1-shape axes to fill the dimensionality.
1 parent f31add8 commit 2a9ba60

File tree

1 file changed

+134
-9
lines changed

1 file changed

+134
-9
lines changed

src/dimension/mod.rs

+134-9
Original file line numberDiff line numberDiff line change
@@ -785,37 +785,77 @@ where
785785
}
786786

787787
/// Remove axes with length one, except never removing the last axis.
788+
///
789+
/// This function is a no-op for const dim.
788790
pub(crate) fn squeeze<D>(dim: &mut D, strides: &mut D)
789791
where
790792
D: Dimension,
791793
{
792794
if let Some(_) = D::NDIM {
793795
return;
794796
}
797+
798+
// infallible for dyn dim
799+
let (d, s) = squeeze_into(dim, strides).unwrap();
800+
*dim = d;
801+
*strides = s;
802+
}
803+
804+
/// Remove axes with length one, except never removing the last axis.
805+
///
806+
/// Return an error if there are more non-unitary dimensions than can be stored
807+
/// in `E`. Infallible for dyn dim.
808+
///
809+
/// Squeeze does not shrink dyn dim down to smaller than 1D, but if the input is
810+
/// dynamic 0D, the output can be too.
811+
///
812+
/// For const dim, this may instead pad the dimensionality with ones if it needs
813+
/// to grow to fill the target dimensionality; the dimension is padded in the
814+
/// start.
815+
pub(crate) fn squeeze_into<D, E>(dim: &D, strides: &D) -> Result<(E, E), ShapeError>
816+
where
817+
D: Dimension,
818+
E: Dimension,
819+
{
795820
debug_assert_eq!(dim.ndim(), strides.ndim());
796821

797822
// Count axes with dim == 1; we keep axes with d == 0 or d > 1
798823
let mut ndim_new = 0;
799824
for &d in dim.slice() {
800825
if d != 1 { ndim_new += 1; }
801826
}
802-
ndim_new = Ord::max(1, ndim_new);
803-
let mut new_dim = D::zeros(ndim_new);
804-
let mut new_strides = D::zeros(ndim_new);
827+
let mut fill_ones = 0;
828+
if let Some(e_ndim) = E::NDIM {
829+
if e_ndim < ndim_new {
830+
return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape));
831+
}
832+
fill_ones = e_ndim - ndim_new;
833+
ndim_new = e_ndim;
834+
} else {
835+
// dynamic-dimensional
836+
// use minimum one dimension unless input has less than one dim
837+
if dim.ndim() > 0 && ndim_new == 0 {
838+
ndim_new = 1;
839+
fill_ones = 1;
840+
}
841+
}
842+
843+
let mut new_dim = E::zeros(ndim_new);
844+
let mut new_strides = E::zeros(ndim_new);
805845
let mut i = 0;
846+
while i < fill_ones {
847+
new_dim[i] = 1;
848+
new_strides[i] = 1;
849+
i += 1;
850+
}
806851
for (&d, &s) in izip!(dim.slice(), strides.slice()) {
807852
if d != 1 {
808853
new_dim[i] = d;
809854
new_strides[i] = s;
810855
i += 1;
811856
}
812857
}
813-
if i == 0 {
814-
new_dim[i] = 1;
815-
new_strides[i] = 1;
816-
}
817-
*dim = new_dim;
818-
*strides = new_strides;
858+
Ok((new_dim, new_strides))
819859
}
820860

821861

@@ -1220,6 +1260,91 @@ mod test {
12201260
assert_eq!(s, sans);
12211261
}
12221262

1263+
#[test]
1264+
#[cfg(feature = "std")]
1265+
fn test_squeeze_into() {
1266+
use super::squeeze_into;
1267+
1268+
let dyndim = Dim::<&[usize]>;
1269+
1270+
// squeeze to ixdyn
1271+
let d = dyndim(&[1, 2, 1, 1, 3, 1]);
1272+
let s = dyndim(&[!0, !0, !0, 9, 10, !0]);
1273+
let dans = dyndim(&[2, 3]);
1274+
let sans = dyndim(&[!0, 10]);
1275+
let (d2, s2) = squeeze_into::<_, IxDyn>(&d, &s).unwrap();
1276+
assert_eq!(d2, dans);
1277+
assert_eq!(s2, sans);
1278+
1279+
// squeeze to ixdyn does not go below 1D
1280+
let d = dyndim(&[1, 1]);
1281+
let s = dyndim(&[3, 4]);
1282+
let dans = dyndim(&[1]);
1283+
let sans = dyndim(&[1]);
1284+
let (d2, s2) = squeeze_into::<_, IxDyn>(&d, &s).unwrap();
1285+
assert_eq!(d2, dans);
1286+
assert_eq!(s2, sans);
1287+
1288+
let d = Dim([1, 1]);
1289+
let s = Dim([3, 4]);
1290+
let dans = Dim([1]);
1291+
let sans = Dim([1]);
1292+
let (d2, s2) = squeeze_into::<_, Ix1>(&d, &s).unwrap();
1293+
assert_eq!(d2, dans);
1294+
assert_eq!(s2, sans);
1295+
1296+
// squeeze to zero-dim
1297+
let (d2, s2) = squeeze_into::<_, Ix0>(&d, &s).unwrap();
1298+
assert_eq!(d2, Ix0());
1299+
assert_eq!(s2, Ix0());
1300+
1301+
let d = Dim([0, 1, 3, 4]);
1302+
let s = Dim([2, 3, 4, 5]);
1303+
let dans = Dim([0, 3, 4]);
1304+
let sans = Dim([2, 4, 5]);
1305+
let (d2, s2) = squeeze_into::<_, Ix3>(&d, &s).unwrap();
1306+
assert_eq!(d2, dans);
1307+
assert_eq!(s2, sans);
1308+
1309+
// Pad with ones
1310+
let d = Dim([0, 1, 3, 1]);
1311+
let s = Dim([2, 3, 4, 5]);
1312+
let dans = Dim([1, 0, 3]);
1313+
let sans = Dim([1, 2, 4]);
1314+
let (d2, s2) = squeeze_into::<_, Ix3>(&d, &s).unwrap();
1315+
assert_eq!(d2, dans);
1316+
assert_eq!(s2, sans);
1317+
1318+
// Try something that doesn't fit
1319+
let d = Dim([0, 1, 3, 1]);
1320+
let s = Dim([2, 3, 4, 5]);
1321+
let res = squeeze_into::<_, Ix1>(&d, &s);
1322+
assert!(res.is_err());
1323+
let res = squeeze_into::<_, Ix0>(&d, &s);
1324+
assert!(res.is_err());
1325+
1326+
// Squeeze 0d to 0d
1327+
let d = Dim([]);
1328+
let s = Dim([]);
1329+
let res = squeeze_into::<_, Ix0>(&d, &s);
1330+
assert!(res.is_ok());
1331+
// grow 0d to 2d
1332+
let dans = Dim([1, 1]);
1333+
let sans = Dim([1, 1]);
1334+
let (d2, s2) = squeeze_into::<_, Ix2>(&d, &s).unwrap();
1335+
assert_eq!(d2, dans);
1336+
assert_eq!(s2, sans);
1337+
1338+
// Squeeze 0d to 0d dynamic
1339+
let d = dyndim(&[]);
1340+
let s = dyndim(&[]);
1341+
let (d2, s2) = squeeze_into::<_, IxDyn>(&d, &s).unwrap();
1342+
let dans = d;
1343+
let sans = s;
1344+
assert_eq!(d2, dans);
1345+
assert_eq!(s2, sans);
1346+
}
1347+
12231348
#[test]
12241349
fn test_merge_axes_from_the_back() {
12251350
let dyndim = Dim::<&[usize]>;

0 commit comments

Comments
 (0)