Skip to content

Commit 2e2f83f

Browse files
committed
FEAT: Add dimension::squeeze to remove dimensions with len == 1
1 parent 3070113 commit 2e2f83f

File tree

1 file changed

+65
-1
lines changed

1 file changed

+65
-1
lines changed

src/dimension/mod.rs

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -787,12 +787,46 @@ where
787787
}
788788
}
789789

790+
/// Remove axes with length one, except never removing the last axis.
791+
pub(crate) fn squeeze<D>(dim: &mut D, strides: &mut D)
792+
where
793+
D: Dimension,
794+
{
795+
if let Some(_) = D::NDIM {
796+
return;
797+
}
798+
debug_assert_eq!(dim.ndim(), strides.ndim());
799+
800+
// Count axes with dim == 1; we keep axes with d == 0 or d > 1
801+
let mut ndim_new = 0;
802+
for &d in dim.slice() {
803+
if d != 1 { ndim_new += 1; }
804+
}
805+
ndim_new = Ord::max(1, ndim_new);
806+
let mut new_dim = D::zeros(ndim_new);
807+
let mut new_strides = D::zeros(ndim_new);
808+
let mut i = 0;
809+
for (&d, &s) in izip!(dim.slice(), strides.slice()) {
810+
if d != 1 {
811+
new_dim[i] = d;
812+
new_strides[i] = s;
813+
i += 1;
814+
}
815+
}
816+
if i == 0 {
817+
new_dim[i] = 1;
818+
new_strides[i] = 1;
819+
}
820+
*dim = new_dim;
821+
*strides = new_strides;
822+
}
823+
790824
#[cfg(test)]
791825
mod test {
792826
use super::{
793827
arith_seq_intersect, can_index_slice, can_index_slice_not_custom, extended_gcd,
794828
max_abs_offset_check_overflow, slice_min_max, slices_intersect,
795-
solve_linear_diophantine_eq, IntoDimension,
829+
solve_linear_diophantine_eq, IntoDimension, squeeze,
796830
};
797831
use crate::error::{from_kind, ErrorKind};
798832
use crate::slice::Slice;
@@ -1127,4 +1161,34 @@ mod test {
11271161
s![.., 3..;6, NewAxis]
11281162
));
11291163
}
1164+
1165+
#[test]
1166+
#[cfg(feature = "std")]
1167+
fn test_squeeze() {
1168+
let dyndim = Dim::<&[usize]>;
1169+
1170+
let mut d = dyndim(&[1, 2, 1, 1, 3, 1]);
1171+
let mut s = dyndim(&[!0, !0, !0, 9, 10, !0]);
1172+
let dans = dyndim(&[2, 3]);
1173+
let sans = dyndim(&[!0, 10]);
1174+
squeeze(&mut d, &mut s);
1175+
assert_eq!(d, dans);
1176+
assert_eq!(s, sans);
1177+
1178+
let mut d = dyndim(&[1, 1]);
1179+
let mut s = dyndim(&[3, 4]);
1180+
let dans = dyndim(&[1]);
1181+
let sans = dyndim(&[1]);
1182+
squeeze(&mut d, &mut s);
1183+
assert_eq!(d, dans);
1184+
assert_eq!(s, sans);
1185+
1186+
let mut d = dyndim(&[0, 1, 3, 4]);
1187+
let mut s = dyndim(&[2, 3, 4, 5]);
1188+
let dans = dyndim(&[0, 3, 4]);
1189+
let sans = dyndim(&[2, 4, 5]);
1190+
squeeze(&mut d, &mut s);
1191+
assert_eq!(d, dans);
1192+
assert_eq!(s, sans);
1193+
}
11301194
}

0 commit comments

Comments
 (0)