Skip to content

Commit 546b69c

Browse files
jturner314bluss
authored andcommitted
Add support for inserting new axes while slicing
1 parent 6a16b88 commit 546b69c

File tree

7 files changed

+173
-69
lines changed

7 files changed

+173
-69
lines changed

src/dimension/mod.rs

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -601,7 +601,11 @@ pub fn slices_intersect<D: Dimension>(
601601
indices2: &impl CanSlice<D>,
602602
) -> bool {
603603
debug_assert_eq!(indices1.in_ndim(), indices2.in_ndim());
604-
for (&axis_len, &si1, &si2) in izip!(dim.slice(), indices1.as_ref(), indices2.as_ref()) {
604+
for (&axis_len, &si1, &si2) in izip!(
605+
dim.slice(),
606+
indices1.as_ref().iter().filter(|si| !si.is_new_axis()),
607+
indices2.as_ref().iter().filter(|si| !si.is_new_axis()),
608+
) {
605609
// The slices do not intersect iff any pair of `AxisSliceInfo` does not intersect.
606610
match (si1, si2) {
607611
(
@@ -649,6 +653,7 @@ pub fn slices_intersect<D: Dimension>(
649653
return false;
650654
}
651655
}
656+
(AxisSliceInfo::NewAxis, _) | (_, AxisSliceInfo::NewAxis) => unreachable!(),
652657
}
653658
}
654659
true
@@ -720,7 +725,7 @@ mod test {
720725
};
721726
use crate::error::{from_kind, ErrorKind};
722727
use crate::slice::Slice;
723-
use crate::{Dim, Dimension, Ix0, Ix1, Ix2, Ix3, IxDyn};
728+
use crate::{Dim, Dimension, Ix0, Ix1, Ix2, Ix3, IxDyn, NewAxis};
724729
use num_integer::gcd;
725730
use quickcheck::{quickcheck, TestResult};
726731

@@ -994,17 +999,45 @@ mod test {
994999

9951000
#[test]
9961001
fn slices_intersect_true() {
997-
assert!(slices_intersect(&Dim([4, 5]), s![.., ..], s![.., ..]));
998-
assert!(slices_intersect(&Dim([4, 5]), s![0, ..], s![0, ..]));
999-
assert!(slices_intersect(&Dim([4, 5]), s![..;2, ..], s![..;3, ..]));
1000-
assert!(slices_intersect(&Dim([4, 5]), s![.., ..;2], s![.., 1..;3]));
1002+
assert!(slices_intersect(
1003+
&Dim([4, 5]),
1004+
s![NewAxis, .., NewAxis, ..],
1005+
s![.., NewAxis, .., NewAxis]
1006+
));
1007+
assert!(slices_intersect(
1008+
&Dim([4, 5]),
1009+
s![NewAxis, 0, ..],
1010+
s![0, ..]
1011+
));
1012+
assert!(slices_intersect(
1013+
&Dim([4, 5]),
1014+
s![..;2, ..],
1015+
s![..;3, NewAxis, ..]
1016+
));
1017+
assert!(slices_intersect(
1018+
&Dim([4, 5]),
1019+
s![.., ..;2],
1020+
s![.., 1..;3, NewAxis]
1021+
));
10011022
assert!(slices_intersect(&Dim([4, 10]), s![.., ..;9], s![.., 3..;6]));
10021023
}
10031024

10041025
#[test]
10051026
fn slices_intersect_false() {
1006-
assert!(!slices_intersect(&Dim([4, 5]), s![..;2, ..], s![1..;2, ..]));
1007-
assert!(!slices_intersect(&Dim([4, 5]), s![..;2, ..], s![1..;3, ..]));
1008-
assert!(!slices_intersect(&Dim([4, 5]), s![.., ..;9], s![.., 3..;6]));
1027+
assert!(!slices_intersect(
1028+
&Dim([4, 5]),
1029+
s![..;2, ..],
1030+
s![NewAxis, 1..;2, ..]
1031+
));
1032+
assert!(!slices_intersect(
1033+
&Dim([4, 5]),
1034+
s![..;2, NewAxis, ..],
1035+
s![1..;3, ..]
1036+
));
1037+
assert!(!slices_intersect(
1038+
&Dim([4, 5]),
1039+
s![.., ..;9],
1040+
s![.., 3..;6, NewAxis]
1041+
));
10091042
}
10101043
}

src/doc/ndarray_for_numpy_users/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,7 @@
532532
//! `a[:] = b` | [`a.assign(&b)`][.assign()] | copy the data from array `b` into array `a`
533533
//! `np.concatenate((a,b), axis=1)` | [`concatenate![Axis(1), a, b]`][concatenate!] or [`concatenate(Axis(1), &[a.view(), b.view()])`][concatenate()] | concatenate arrays `a` and `b` along axis 1
534534
//! `np.stack((a,b), axis=1)` | [`stack![Axis(1), a, b]`][stack!] or [`stack(Axis(1), vec![a.view(), b.view()])`][stack()] | stack arrays `a` and `b` along axis 1
535-
//! `a[:,np.newaxis]` or `np.expand_dims(a, axis=1)` | [`a.insert_axis(Axis(1))`][.insert_axis()] | create an array from `a`, inserting a new axis 1
535+
//! `a[:,np.newaxis]` or `np.expand_dims(a, axis=1)` | [`a.slice(s![.., NewAxis])`][.slice()] or [`a.insert_axis(Axis(1))`][.insert_axis()] | create an view of 1-D array `a`, inserting a new axis 1
536536
//! `a.transpose()` or `a.T` | [`a.t()`][.t()] or [`a.reversed_axes()`][.reversed_axes()] | transpose of array `a` (view for `.t()` or by-move for `.reversed_axes()`)
537537
//! `np.diag(a)` | [`a.diag()`][.diag()] | view the diagonal of `a`
538538
//! `a.flatten()` | [`use std::iter::FromIterator; Array::from_iter(a.iter().cloned())`][::from_iter()] | create a 1-D array by flattening `a`

src/impl_methods.rs

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,12 @@ where
437437
// Skip the old axis since it should be removed.
438438
old_axis += 1;
439439
}
440+
AxisSliceInfo::NewAxis => {
441+
// Set the dim and stride of the new axis.
442+
new_dim[new_axis] = 1;
443+
new_strides[new_axis] = 0;
444+
new_axis += 1;
445+
}
440446
});
441447
debug_assert_eq!(old_axis, self.ndim());
442448
debug_assert_eq!(new_axis, out_ndim);
@@ -450,6 +456,8 @@ where
450456

451457
/// Slice the array in place without changing the number of dimensions.
452458
///
459+
/// Note that `NewAxis` elements in `info` are ignored.
460+
///
453461
/// See [*Slicing*](#slicing) for full documentation.
454462
///
455463
/// **Panics** if an index is out of bounds or step size is zero.<br>
@@ -463,18 +471,20 @@ where
463471
self.ndim(),
464472
"The input dimension of `info` must match the array to be sliced.",
465473
);
466-
info.as_ref()
467-
.iter()
468-
.enumerate()
469-
.for_each(|(axis, &ax_info)| match ax_info {
474+
let mut axis = 0;
475+
info.as_ref().iter().for_each(|&ax_info| match ax_info {
470476
AxisSliceInfo::Slice { start, end, step } => {
471-
self.slice_axis_inplace(Axis(axis), Slice { start, end, step })
477+
self.slice_axis_inplace(Axis(axis), Slice { start, end, step });
478+
axis += 1;
472479
}
473480
AxisSliceInfo::Index(index) => {
474481
let i_usize = abs_index(self.len_of(Axis(axis)), index);
475-
self.collapse_axis(Axis(axis), i_usize)
482+
self.collapse_axis(Axis(axis), i_usize);
483+
axis += 1;
476484
}
485+
AxisSliceInfo::NewAxis => {}
477486
});
487+
debug_assert_eq!(axis, self.ndim());
478488
}
479489

480490
/// Return a view of the array, sliced along the specified axis.

src/lib.rs

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ pub use crate::dimension::IxDynImpl;
141141
pub use crate::dimension::NdIndex;
142142
pub use crate::error::{ErrorKind, ShapeError};
143143
pub use crate::indexes::{indices, indices_of};
144-
pub use crate::slice::{AxisSliceInfo, Slice, SliceInfo, SliceNextInDim, SliceNextOutDim};
144+
pub use crate::slice::{AxisSliceInfo, NewAxis, Slice, SliceInfo, SliceNextInDim, SliceNextOutDim};
145145

146146
use crate::iterators::Baseiter;
147147
use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut, Lanes};
@@ -496,14 +496,16 @@ pub type Ixs = isize;
496496
///
497497
/// If a range is used, the axis is preserved. If an index is used, that index
498498
/// is selected and the axis is removed; this selects a subview. See
499-
/// [*Subviews*](#subviews) for more information about subviews. Note that
500-
/// [`.slice_collapse()`] behaves like [`.collapse_axis()`] by preserving
501-
/// the number of dimensions.
499+
/// [*Subviews*](#subviews) for more information about subviews. If a
500+
/// [`NewAxis`] instance is used, a new axis is inserted. Note that
501+
/// [`.slice_collapse()`] ignores `NewAxis` elements and behaves like
502+
/// [`.collapse_axis()`] by preserving the number of dimensions.
502503
///
503504
/// [`.slice()`]: #method.slice
504505
/// [`.slice_mut()`]: #method.slice_mut
505506
/// [`.slice_move()`]: #method.slice_move
506507
/// [`.slice_collapse()`]: #method.slice_collapse
508+
/// [`NewAxis`]: struct.NewAxis.html
507509
///
508510
/// When slicing arrays with generic dimensionality, creating an instance of
509511
/// [`&SliceInfo`] to pass to the multi-axis slicing methods like [`.slice()`]
@@ -526,7 +528,7 @@ pub type Ixs = isize;
526528
/// [`.multi_slice_move()`]: type.ArrayViewMut.html#method.multi_slice_move
527529
///
528530
/// ```
529-
/// use ndarray::{arr2, arr3, s, ArrayBase, DataMut, Dimension, Slice};
531+
/// use ndarray::{arr2, arr3, s, ArrayBase, DataMut, Dimension, NewAxis, Slice};
530532
///
531533
/// // 2 submatrices of 2 rows with 3 elements per row, means a shape of `[2, 2, 3]`.
532534
///
@@ -561,16 +563,17 @@ pub type Ixs = isize;
561563
/// assert_eq!(d, e);
562564
/// assert_eq!(d.shape(), &[2, 1, 3]);
563565
///
564-
/// // Let’s create a slice while selecting a subview with
566+
/// // Let’s create a slice while selecting a subview and inserting a new axis with
565567
/// //
566568
/// // - Both submatrices of the greatest dimension: `..`
567569
/// // - The last row in each submatrix, removing that axis: `-1`
568570
/// // - Row elements in reverse order: `..;-1`
569-
/// let f = a.slice(s![.., -1, ..;-1]);
570-
/// let g = arr2(&[[ 6, 5, 4],
571-
/// [12, 11, 10]]);
571+
/// // - A new axis at the end.
572+
/// let f = a.slice(s![.., -1, ..;-1, NewAxis]);
573+
/// let g = arr3(&[[ [6], [5], [4]],
574+
/// [[12], [11], [10]]]);
572575
/// assert_eq!(f, g);
573-
/// assert_eq!(f.shape(), &[2, 3]);
576+
/// assert_eq!(f.shape(), &[2, 3, 1]);
574577
///
575578
/// // Let's take two disjoint, mutable slices of a matrix with
576579
/// //

src/prelude.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ pub use crate::{array, azip, s};
4949
#[doc(no_inline)]
5050
pub use crate::ShapeBuilder;
5151

52+
#[doc(no_inline)]
53+
pub use crate::NewAxis;
54+
5255
#[doc(no_inline)]
5356
pub use crate::AsArray;
5457

0 commit comments

Comments
 (0)