@@ -785,37 +785,77 @@ where
785
785
}
786
786
787
787
/// Remove axes with length one, except never removing the last axis.
788
+ ///
789
+ /// This function is a no-op for const dim.
788
790
pub ( crate ) fn squeeze < D > ( dim : & mut D , strides : & mut D )
789
791
where
790
792
D : Dimension ,
791
793
{
792
794
if let Some ( _) = D :: NDIM {
793
795
return ;
794
796
}
797
+
798
+ // infallible for dyn dim
799
+ let ( d, s) = squeeze_into ( dim, strides) . unwrap ( ) ;
800
+ * dim = d;
801
+ * strides = s;
802
+ }
803
+
804
+ /// Remove axes with length one, except never removing the last axis.
805
+ ///
806
+ /// Return an error if there are more non-unitary dimensions than can be stored
807
+ /// in `E`. Infallible for dyn dim.
808
+ ///
809
+ /// Squeeze does not shrink dyn dim down to smaller than 1D, but if the input is
810
+ /// dynamic 0D, the output can be too.
811
+ ///
812
+ /// For const dim, this may instead pad the dimensionality with ones if it needs
813
+ /// to grow to fill the target dimensionality; the dimension is padded in the
814
+ /// start.
815
+ pub ( crate ) fn squeeze_into < D , E > ( dim : & D , strides : & D ) -> Result < ( E , E ) , ShapeError >
816
+ where
817
+ D : Dimension ,
818
+ E : Dimension ,
819
+ {
795
820
debug_assert_eq ! ( dim. ndim( ) , strides. ndim( ) ) ;
796
821
797
822
// Count axes with dim == 1; we keep axes with d == 0 or d > 1
798
823
let mut ndim_new = 0 ;
799
824
for & d in dim. slice ( ) {
800
825
if d != 1 { ndim_new += 1 ; }
801
826
}
802
- ndim_new = Ord :: max ( 1 , ndim_new) ;
803
- let mut new_dim = D :: zeros ( ndim_new) ;
804
- let mut new_strides = D :: zeros ( ndim_new) ;
827
+ let mut fill_ones = 0 ;
828
+ if let Some ( e_ndim) = E :: NDIM {
829
+ if e_ndim < ndim_new {
830
+ return Err ( ShapeError :: from_kind ( ErrorKind :: IncompatibleShape ) ) ;
831
+ }
832
+ fill_ones = e_ndim - ndim_new;
833
+ ndim_new = e_ndim;
834
+ } else {
835
+ // dynamic-dimensional
836
+ // use minimum one dimension unless input has less than one dim
837
+ if dim. ndim ( ) > 0 && ndim_new == 0 {
838
+ ndim_new = 1 ;
839
+ fill_ones = 1 ;
840
+ }
841
+ }
842
+
843
+ let mut new_dim = E :: zeros ( ndim_new) ;
844
+ let mut new_strides = E :: zeros ( ndim_new) ;
805
845
let mut i = 0 ;
846
+ while i < fill_ones {
847
+ new_dim[ i] = 1 ;
848
+ new_strides[ i] = 1 ;
849
+ i += 1 ;
850
+ }
806
851
for ( & d, & s) in izip ! ( dim. slice( ) , strides. slice( ) ) {
807
852
if d != 1 {
808
853
new_dim[ i] = d;
809
854
new_strides[ i] = s;
810
855
i += 1 ;
811
856
}
812
857
}
813
- if i == 0 {
814
- new_dim[ i] = 1 ;
815
- new_strides[ i] = 1 ;
816
- }
817
- * dim = new_dim;
818
- * strides = new_strides;
858
+ Ok ( ( new_dim, new_strides) )
819
859
}
820
860
821
861
@@ -1220,6 +1260,91 @@ mod test {
1220
1260
assert_eq ! ( s, sans) ;
1221
1261
}
1222
1262
1263
+ #[ test]
1264
+ #[ cfg( feature = "std" ) ]
1265
+ fn test_squeeze_into ( ) {
1266
+ use super :: squeeze_into;
1267
+
1268
+ let dyndim = Dim :: < & [ usize ] > ;
1269
+
1270
+ // squeeze to ixdyn
1271
+ let d = dyndim ( & [ 1 , 2 , 1 , 1 , 3 , 1 ] ) ;
1272
+ let s = dyndim ( & [ !0 , !0 , !0 , 9 , 10 , !0 ] ) ;
1273
+ let dans = dyndim ( & [ 2 , 3 ] ) ;
1274
+ let sans = dyndim ( & [ !0 , 10 ] ) ;
1275
+ let ( d2, s2) = squeeze_into :: < _ , IxDyn > ( & d, & s) . unwrap ( ) ;
1276
+ assert_eq ! ( d2, dans) ;
1277
+ assert_eq ! ( s2, sans) ;
1278
+
1279
+ // squeeze to ixdyn does not go below 1D
1280
+ let d = dyndim ( & [ 1 , 1 ] ) ;
1281
+ let s = dyndim ( & [ 3 , 4 ] ) ;
1282
+ let dans = dyndim ( & [ 1 ] ) ;
1283
+ let sans = dyndim ( & [ 1 ] ) ;
1284
+ let ( d2, s2) = squeeze_into :: < _ , IxDyn > ( & d, & s) . unwrap ( ) ;
1285
+ assert_eq ! ( d2, dans) ;
1286
+ assert_eq ! ( s2, sans) ;
1287
+
1288
+ let d = Dim ( [ 1 , 1 ] ) ;
1289
+ let s = Dim ( [ 3 , 4 ] ) ;
1290
+ let dans = Dim ( [ 1 ] ) ;
1291
+ let sans = Dim ( [ 1 ] ) ;
1292
+ let ( d2, s2) = squeeze_into :: < _ , Ix1 > ( & d, & s) . unwrap ( ) ;
1293
+ assert_eq ! ( d2, dans) ;
1294
+ assert_eq ! ( s2, sans) ;
1295
+
1296
+ // squeeze to zero-dim
1297
+ let ( d2, s2) = squeeze_into :: < _ , Ix0 > ( & d, & s) . unwrap ( ) ;
1298
+ assert_eq ! ( d2, Ix0 ( ) ) ;
1299
+ assert_eq ! ( s2, Ix0 ( ) ) ;
1300
+
1301
+ let d = Dim ( [ 0 , 1 , 3 , 4 ] ) ;
1302
+ let s = Dim ( [ 2 , 3 , 4 , 5 ] ) ;
1303
+ let dans = Dim ( [ 0 , 3 , 4 ] ) ;
1304
+ let sans = Dim ( [ 2 , 4 , 5 ] ) ;
1305
+ let ( d2, s2) = squeeze_into :: < _ , Ix3 > ( & d, & s) . unwrap ( ) ;
1306
+ assert_eq ! ( d2, dans) ;
1307
+ assert_eq ! ( s2, sans) ;
1308
+
1309
+ // Pad with ones
1310
+ let d = Dim ( [ 0 , 1 , 3 , 1 ] ) ;
1311
+ let s = Dim ( [ 2 , 3 , 4 , 5 ] ) ;
1312
+ let dans = Dim ( [ 1 , 0 , 3 ] ) ;
1313
+ let sans = Dim ( [ 1 , 2 , 4 ] ) ;
1314
+ let ( d2, s2) = squeeze_into :: < _ , Ix3 > ( & d, & s) . unwrap ( ) ;
1315
+ assert_eq ! ( d2, dans) ;
1316
+ assert_eq ! ( s2, sans) ;
1317
+
1318
+ // Try something that doesn't fit
1319
+ let d = Dim ( [ 0 , 1 , 3 , 1 ] ) ;
1320
+ let s = Dim ( [ 2 , 3 , 4 , 5 ] ) ;
1321
+ let res = squeeze_into :: < _ , Ix1 > ( & d, & s) ;
1322
+ assert ! ( res. is_err( ) ) ;
1323
+ let res = squeeze_into :: < _ , Ix0 > ( & d, & s) ;
1324
+ assert ! ( res. is_err( ) ) ;
1325
+
1326
+ // Squeeze 0d to 0d
1327
+ let d = Dim ( [ ] ) ;
1328
+ let s = Dim ( [ ] ) ;
1329
+ let res = squeeze_into :: < _ , Ix0 > ( & d, & s) ;
1330
+ assert ! ( res. is_ok( ) ) ;
1331
+ // grow 0d to 2d
1332
+ let dans = Dim ( [ 1 , 1 ] ) ;
1333
+ let sans = Dim ( [ 1 , 1 ] ) ;
1334
+ let ( d2, s2) = squeeze_into :: < _ , Ix2 > ( & d, & s) . unwrap ( ) ;
1335
+ assert_eq ! ( d2, dans) ;
1336
+ assert_eq ! ( s2, sans) ;
1337
+
1338
+ // Squeeze 0d to 0d dynamic
1339
+ let d = dyndim ( & [ ] ) ;
1340
+ let s = dyndim ( & [ ] ) ;
1341
+ let ( d2, s2) = squeeze_into :: < _ , IxDyn > ( & d, & s) . unwrap ( ) ;
1342
+ let dans = d;
1343
+ let sans = s;
1344
+ assert_eq ! ( d2, dans) ;
1345
+ assert_eq ! ( s2, sans) ;
1346
+ }
1347
+
1223
1348
#[ test]
1224
1349
fn test_merge_axes_from_the_back ( ) {
1225
1350
let dyndim = Dim :: < & [ usize ] > ;
0 commit comments