Skip to content

Commit 25a7bb0

Browse files
jturner314bluss
authored andcommitted
Make slice_collapse return Err(_) for NewAxis
1 parent 815e708 commit 25a7bb0

File tree

5 files changed

+52
-42
lines changed

5 files changed

+52
-42
lines changed

examples/axis_ops.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ fn main() {
5151
}
5252
a.swap_axes(0, 1);
5353
a.swap_axes(0, 2);
54-
a.slice_collapse(s![.., ..;-1, ..]);
54+
a.slice_collapse(s![.., ..;-1, ..]).unwrap();
5555
regularize(&mut a).ok();
5656

5757
let mut b = Array::<u8, _>::zeros((2, 3, 4));
@@ -68,6 +68,6 @@ fn main() {
6868
for (i, elt) in (0..).zip(&mut a) {
6969
*elt = i;
7070
}
71-
a.slice_collapse(s![..;-1, ..;2, ..]);
71+
a.slice_collapse(s![..;-1, ..;2, ..]).unwrap();
7272
regularize(&mut a).ok();
7373
}

serialization-tests/tests/serialize.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ fn serial_many_dim_serde() {
4646
{
4747
// Test a sliced array.
4848
let mut a = ArcArray::linspace(0., 31., 32).reshape((2, 2, 2, 4));
49-
a.slice_collapse(s![..;-1, .., .., ..2]);
49+
a.slice_collapse(s![..;-1, .., .., ..2]).unwrap();
5050
let serial = serde_json::to_string(&a).unwrap();
5151
println!("Encode {:?} => {:?}", a, serial);
5252
let res = serde_json::from_str::<ArcArray<f32, _>>(&serial);
@@ -156,7 +156,7 @@ fn serial_many_dim_serde_msgpack() {
156156
{
157157
// Test a sliced array.
158158
let mut a = ArcArray::linspace(0., 31., 32).reshape((2, 2, 2, 4));
159-
a.slice_collapse(s![..;-1, .., .., ..2]);
159+
a.slice_collapse(s![..;-1, .., .., ..2]).unwrap();
160160

161161
let mut buf = Vec::new();
162162
serde::Serialize::serialize(&a, &mut rmp_serde::Serializer::new(&mut buf))
@@ -209,7 +209,7 @@ fn serial_many_dim_ron() {
209209
{
210210
// Test a sliced array.
211211
let mut a = ArcArray::linspace(0., 31., 32).reshape((2, 2, 2, 4));
212-
a.slice_collapse(s![..;-1, .., .., ..2]);
212+
a.slice_collapse(s![..;-1, .., .., ..2]).unwrap();
213213

214214
let a_s = ron_serialize(&a).unwrap();
215215

src/impl_methods.rs

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -461,13 +461,15 @@ where
461461

462462
/// Slice the array in place without changing the number of dimensions.
463463
///
464-
/// Note that `NewAxis` elements in `info` are ignored.
464+
/// If there are any `NewAxis` elements in `info`, slicing is performed
465+
/// using the other elements in `info` (i.e. ignoring the `NewAxis`
466+
/// elements), and `Err(_)` is returned to notify the caller.
465467
///
466468
/// See [*Slicing*](#slicing) for full documentation.
467469
///
468470
/// **Panics** if an index is out of bounds or step size is zero.<br>
469471
/// (**Panics** if `D` is `IxDyn` and `info` does not match the number of array axes.)
470-
pub fn slice_collapse<I>(&mut self, info: &I)
472+
pub fn slice_collapse<I>(&mut self, info: &I) -> Result<(), ShapeError>
471473
where
472474
I: CanSlice<D> + ?Sized,
473475
{
@@ -476,20 +478,28 @@ where
476478
self.ndim(),
477479
"The input dimension of `info` must match the array to be sliced.",
478480
);
481+
let mut new_axis_in_info = false;
479482
let mut axis = 0;
480483
info.as_ref().iter().for_each(|&ax_info| match ax_info {
481-
AxisSliceInfo::Slice { start, end, step } => {
482-
self.slice_axis_inplace(Axis(axis), Slice { start, end, step });
483-
axis += 1;
484-
}
485-
AxisSliceInfo::Index(index) => {
486-
let i_usize = abs_index(self.len_of(Axis(axis)), index);
487-
self.collapse_axis(Axis(axis), i_usize);
488-
axis += 1;
489-
}
490-
AxisSliceInfo::NewAxis => {}
491-
});
484+
AxisSliceInfo::Slice { start, end, step } => {
485+
self.slice_axis_inplace(Axis(axis), Slice { start, end, step });
486+
axis += 1;
487+
}
488+
AxisSliceInfo::Index(index) => {
489+
let i_usize = abs_index(self.len_of(Axis(axis)), index);
490+
self.collapse_axis(Axis(axis), i_usize);
491+
axis += 1;
492+
}
493+
AxisSliceInfo::NewAxis => {
494+
new_axis_in_info = true;
495+
}
496+
});
492497
debug_assert_eq!(axis, self.ndim());
498+
if new_axis_in_info {
499+
Err(ShapeError::from_kind(ErrorKind::IncompatibleShape))
500+
} else {
501+
Ok(())
502+
}
493503
}
494504

495505
/// Return a view of the array, sliced along the specified axis.

tests/array.rs

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,10 @@ fn test_slice_ix0() {
103103
#[test]
104104
fn test_slice_edge_cases() {
105105
let mut arr = Array3::<u8>::zeros((3, 4, 5));
106-
arr.slice_collapse(s![0..0;-1, .., ..]);
106+
arr.slice_collapse(s![0..0;-1, .., ..]).unwrap();
107107
assert_eq!(arr.shape(), &[0, 4, 5]);
108108
let mut arr = Array2::<u8>::from_shape_vec((1, 1).strides((10, 1)), vec![5]).unwrap();
109-
arr.slice_collapse(s![1..1, ..]);
109+
arr.slice_collapse(s![1..1, ..]).unwrap();
110110
assert_eq!(arr.shape(), &[0, 1]);
111111
}
112112

@@ -201,7 +201,7 @@ fn test_slice_array_fixed() {
201201
arr.slice(info);
202202
arr.slice_mut(info);
203203
arr.view().slice_move(info);
204-
arr.view().slice_collapse(info);
204+
arr.view().slice_collapse(info).unwrap_err();
205205
}
206206

207207
#[test]
@@ -211,7 +211,7 @@ fn test_slice_dyninput_array_fixed() {
211211
arr.slice(info);
212212
arr.slice_mut(info);
213213
arr.view().slice_move(info);
214-
arr.view().slice_collapse(info);
214+
arr.view().slice_collapse(info).unwrap_err();
215215
}
216216

217217
#[test]
@@ -227,7 +227,7 @@ fn test_slice_array_dyn() {
227227
arr.slice(info);
228228
arr.slice_mut(info);
229229
arr.view().slice_move(info);
230-
arr.view().slice_collapse(info);
230+
arr.view().slice_collapse(info).unwrap_err();
231231
}
232232

233233
#[test]
@@ -243,7 +243,7 @@ fn test_slice_dyninput_array_dyn() {
243243
arr.slice(info);
244244
arr.slice_mut(info);
245245
arr.view().slice_move(info);
246-
arr.view().slice_collapse(info);
246+
arr.view().slice_collapse(info).unwrap_err();
247247
}
248248

249249
#[test]
@@ -259,7 +259,7 @@ fn test_slice_dyninput_vec_fixed() {
259259
arr.slice(info);
260260
arr.slice_mut(info);
261261
arr.view().slice_move(info);
262-
arr.view().slice_collapse(info);
262+
arr.view().slice_collapse(info).unwrap_err();
263263
}
264264

265265
#[test]
@@ -275,7 +275,7 @@ fn test_slice_dyninput_vec_dyn() {
275275
arr.slice(info);
276276
arr.slice_mut(info);
277277
arr.view().slice_move(info);
278-
arr.view().slice_collapse(info);
278+
arr.view().slice_collapse(info).unwrap_err();
279279
}
280280

281281
#[test]
@@ -324,31 +324,31 @@ fn test_slice_collapse_with_indices() {
324324

325325
{
326326
let mut vi = arr.view();
327-
vi.slice_collapse(s![NewAxis, 1.., 2, ..;2]);
327+
vi.slice_collapse(s![NewAxis, 1.., 2, ..;2]).unwrap_err();
328328
assert_eq!(vi.shape(), &[2, 1, 2]);
329329
assert!(vi
330330
.iter()
331331
.zip(arr.slice(s![1.., 2..3, ..;2]).iter())
332332
.all(|(a, b)| a == b));
333333

334334
let mut vi = arr.view();
335-
vi.slice_collapse(s![1, NewAxis, 2, ..;2]);
335+
vi.slice_collapse(s![1, NewAxis, 2, ..;2]).unwrap_err();
336336
assert_eq!(vi.shape(), &[1, 1, 2]);
337337
assert!(vi
338338
.iter()
339339
.zip(arr.slice(s![1..2, 2..3, ..;2]).iter())
340340
.all(|(a, b)| a == b));
341341

342342
let mut vi = arr.view();
343-
vi.slice_collapse(s![1, 2, NewAxis, 3]);
343+
vi.slice_collapse(s![1, 2, 3]).unwrap();
344344
assert_eq!(vi.shape(), &[1, 1, 1]);
345345
assert_eq!(vi, Array3::from_elem((1, 1, 1), arr[(1, 2, 3)]));
346346
}
347347

348348
// Do it to the ArcArray itself
349349
let elem = arr[(1, 2, 3)];
350350
let mut vi = arr;
351-
vi.slice_collapse(s![1, 2, 3, NewAxis]);
351+
vi.slice_collapse(s![1, 2, 3, NewAxis]).unwrap_err();
352352
assert_eq!(vi.shape(), &[1, 1, 1]);
353353
assert_eq!(vi, Array3::from_elem((1, 1, 1), elem));
354354
}
@@ -567,7 +567,7 @@ fn test_cow() {
567567
assert_eq!(n[[0, 1]], 0);
568568
assert_eq!(n.get((0, 1)), Some(&0));
569569
let mut rev = mat.reshape(4);
570-
rev.slice_collapse(s![..;-1]);
570+
rev.slice_collapse(s![..;-1]).unwrap();
571571
assert_eq!(rev[0], 4);
572572
assert_eq!(rev[1], 3);
573573
assert_eq!(rev[2], 2);
@@ -591,7 +591,7 @@ fn test_cow_shrink() {
591591
// mutation shrinks the array and gives it different strides
592592
//
593593
let mut mat = ArcArray::zeros((2, 3));
594-
//mat.slice_collapse(s![.., ..;2]);
594+
//mat.slice_collapse(s![.., ..;2]).unwrap();
595595
mat[[0, 0]] = 1;
596596
let n = mat.clone();
597597
mat[[0, 1]] = 2;
@@ -606,7 +606,7 @@ fn test_cow_shrink() {
606606
assert_eq!(n.get((0, 1)), Some(&0));
607607
// small has non-C strides this way
608608
let mut small = mat.reshape(6);
609-
small.slice_collapse(s![4..;-1]);
609+
small.slice_collapse(s![4..;-1]).unwrap();
610610
assert_eq!(small[0], 6);
611611
assert_eq!(small[1], 5);
612612
let before = small.clone();
@@ -886,7 +886,7 @@ fn assign() {
886886
let mut a = arr2(&[[1, 2], [3, 4]]);
887887
{
888888
let mut v = a.view_mut();
889-
v.slice_collapse(s![..1, ..]);
889+
v.slice_collapse(s![..1, ..]).unwrap();
890890
v.fill(0);
891891
}
892892
assert_eq!(a, arr2(&[[0, 0], [3, 4]]));
@@ -1093,7 +1093,7 @@ fn owned_array_discontiguous_drop() {
10931093
.collect();
10941094
let mut a = Array::from_shape_vec((2, 6), v).unwrap();
10951095
// discontiguous and non-zero offset
1096-
a.slice_collapse(s![.., 1..]);
1096+
a.slice_collapse(s![.., 1..]).unwrap();
10971097
}
10981098
// each item was dropped exactly once
10991099
itertools::assert_equal(set.borrow().iter().cloned(), 0..12);
@@ -1792,7 +1792,7 @@ fn to_owned_memory_order() {
17921792
#[test]
17931793
fn to_owned_neg_stride() {
17941794
let mut c = arr2(&[[1, 2, 3], [4, 5, 6]]);
1795-
c.slice_collapse(s![.., ..;-1]);
1795+
c.slice_collapse(s![.., ..;-1]).unwrap();
17961796
let co = c.to_owned();
17971797
assert_eq!(c, co);
17981798
assert_eq!(c.strides(), co.strides());
@@ -1801,7 +1801,7 @@ fn to_owned_neg_stride() {
18011801
#[test]
18021802
fn discontiguous_owned_to_owned() {
18031803
let mut c = arr2(&[[1, 2, 3], [4, 5, 6]]);
1804-
c.slice_collapse(s![.., ..;2]);
1804+
c.slice_collapse(s![.., ..;2]).unwrap();
18051805

18061806
let co = c.to_owned();
18071807
assert_eq!(c.strides(), &[3, 2]);
@@ -2062,10 +2062,10 @@ fn test_accumulate_axis_inplace_nonstandard_layout() {
20622062
fn test_to_vec() {
20632063
let mut a = arr2(&[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]);
20642064

2065-
a.slice_collapse(s![..;-1, ..]);
2065+
a.slice_collapse(s![..;-1, ..]).unwrap();
20662066
assert_eq!(a.row(3).to_vec(), vec![1, 2, 3]);
20672067
assert_eq!(a.column(2).to_vec(), vec![12, 9, 6, 3]);
2068-
a.slice_collapse(s![.., ..;-1]);
2068+
a.slice_collapse(s![.., ..;-1]).unwrap();
20692069
assert_eq!(a.row(3).to_vec(), vec![3, 2, 1]);
20702070
}
20712071

@@ -2081,7 +2081,7 @@ fn test_array_clone_unalias() {
20812081
#[test]
20822082
fn test_array_clone_same_view() {
20832083
let mut a = Array::from_iter(0..9).into_shape((3, 3)).unwrap();
2084-
a.slice_collapse(s![..;-1, ..;-1]);
2084+
a.slice_collapse(s![..;-1, ..;-1]).unwrap();
20852085
let b = a.clone();
20862086
assert_eq!(a, b);
20872087
}

tests/iterators.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ fn axis_iter_zip_partially_consumed_discontiguous() {
332332
while iter.next().is_some() {
333333
consumed += 1;
334334
let mut b = Array::zeros((a.len() - consumed) * 2);
335-
b.slice_collapse(s![..;2]);
335+
b.slice_collapse(s![..;2]).unwrap();
336336
Zip::from(&mut b).and(iter.clone()).for_each(|b, a| *b = a[()]);
337337
assert_eq!(a.slice(s![consumed..]), b);
338338
}
@@ -519,7 +519,7 @@ fn axis_iter_mut_zip_partially_consumed_discontiguous() {
519519
iter.next();
520520
}
521521
let mut b = Array::zeros(remaining * 2);
522-
b.slice_collapse(s![..;2]);
522+
b.slice_collapse(s![..;2]).unwrap();
523523
Zip::from(&mut b).and(iter).for_each(|b, a| *b = a[()]);
524524
assert_eq!(a.slice(s![consumed..]), b);
525525
}

0 commit comments

Comments
 (0)