Skip to content

Commit 992e3a2

Browse files
committed
Add more tests for AxisIter/Mut
1 parent be3b74f commit 992e3a2

File tree

1 file changed

+146
-1
lines changed

1 file changed

+146
-1
lines changed

tests/iterators.rs

Lines changed: 146 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ extern crate ndarray;
99

1010
use ndarray::prelude::*;
1111
use ndarray::Ix;
12-
use ndarray::{arr2, arr3, aview1, indices, s, Axis, Data, Dimension, Slice};
12+
use ndarray::{arr2, arr3, aview1, indices, s, Axis, Data, Dimension, Slice, Zip};
1313

1414
use itertools::assert_equal;
1515
use itertools::{enumerate, rev};
@@ -262,6 +262,68 @@ fn axis_iter() {
262262
);
263263
}
264264

265+
#[test]
266+
fn axis_iter_split_at() {
267+
let a = Array::from_iter(0..5);
268+
let iter = a.axis_iter(Axis(0));
269+
let all: Vec<_> = iter.clone().collect();
270+
for mid in 0..=all.len() {
271+
let (left, right) = iter.clone().split_at(mid);
272+
assert_eq!(&all[..mid], &left.collect::<Vec<_>>()[..]);
273+
assert_eq!(&all[mid..], &right.collect::<Vec<_>>()[..]);
274+
}
275+
}
276+
277+
#[test]
278+
fn axis_iter_split_at_partially_consumed() {
279+
let a = Array::from_iter(0..5);
280+
let mut iter = a.axis_iter(Axis(0));
281+
while iter.next().is_some() {
282+
let remaining: Vec<_> = iter.clone().collect();
283+
for mid in 0..=remaining.len() {
284+
let (left, right) = iter.clone().split_at(mid);
285+
assert_eq!(&remaining[..mid], &left.collect::<Vec<_>>()[..]);
286+
assert_eq!(&remaining[mid..], &right.collect::<Vec<_>>()[..]);
287+
}
288+
}
289+
}
290+
291+
#[test]
292+
fn axis_iter_zip() {
293+
let a = Array::from_iter(0..5);
294+
let iter = a.axis_iter(Axis(0));
295+
let mut b = Array::zeros(5);
296+
Zip::from(&mut b).and(iter).apply(|b, a| *b = a[()]);
297+
assert_eq!(a, b);
298+
}
299+
300+
#[test]
301+
fn axis_iter_zip_partially_consumed() {
302+
let a = Array::from_iter(0..5);
303+
let mut iter = a.axis_iter(Axis(0));
304+
let mut consumed = 0;
305+
while iter.next().is_some() {
306+
consumed += 1;
307+
let mut b = Array::zeros(a.len() - consumed);
308+
Zip::from(&mut b).and(iter.clone()).apply(|b, a| *b = a[()]);
309+
assert_eq!(a.slice(s![consumed..]), b);
310+
}
311+
}
312+
313+
#[test]
314+
fn axis_iter_zip_partially_consumed_discontiguous() {
315+
let a = Array::from_iter(0..5);
316+
let mut iter = a.axis_iter(Axis(0));
317+
let mut consumed = 0;
318+
while iter.next().is_some() {
319+
consumed += 1;
320+
let mut b = Array::zeros((a.len() - consumed) * 2);
321+
b.slice_collapse(s![..;2]);
322+
Zip::from(&mut b).and(iter.clone()).apply(|b, a| *b = a[()]);
323+
assert_eq!(a.slice(s![consumed..]), b);
324+
}
325+
}
326+
265327
#[test]
266328
fn outer_iter_corner_cases() {
267329
let a2 = ArcArray::<i32, _>::zeros((0, 3));
@@ -366,6 +428,89 @@ fn axis_chunks_iter() {
366428
assert_equal(it, vec![a.view()]);
367429
}
368430

431+
#[test]
432+
fn axis_iter_mut_split_at() {
433+
let mut a = Array::from_iter(0..5);
434+
let mut a_clone = a.clone();
435+
let all: Vec<_> = a_clone.axis_iter_mut(Axis(0)).collect();
436+
for mid in 0..=all.len() {
437+
let (left, right) = a.axis_iter_mut(Axis(0)).split_at(mid);
438+
assert_eq!(&all[..mid], &left.collect::<Vec<_>>()[..]);
439+
assert_eq!(&all[mid..], &right.collect::<Vec<_>>()[..]);
440+
}
441+
}
442+
443+
#[test]
444+
fn axis_iter_mut_split_at_partially_consumed() {
445+
let mut a = Array::from_iter(0..5);
446+
for consumed in 1..=a.len() {
447+
for mid in 0..=(a.len() - consumed) {
448+
let mut a_clone = a.clone();
449+
let remaining: Vec<_> = {
450+
let mut iter = a_clone.axis_iter_mut(Axis(0));
451+
for _ in 0..consumed {
452+
iter.next();
453+
}
454+
iter.collect()
455+
};
456+
let (left, right) = {
457+
let mut iter = a.axis_iter_mut(Axis(0));
458+
for _ in 0..consumed {
459+
iter.next();
460+
}
461+
iter.split_at(mid)
462+
};
463+
assert_eq!(&remaining[..mid], &left.collect::<Vec<_>>()[..]);
464+
assert_eq!(&remaining[mid..], &right.collect::<Vec<_>>()[..]);
465+
}
466+
}
467+
}
468+
469+
#[test]
470+
fn axis_iter_mut_zip() {
471+
let orig = Array::from_iter(0..5);
472+
let mut cloned = orig.clone();
473+
let iter = cloned.axis_iter_mut(Axis(0));
474+
let mut b = Array::zeros(5);
475+
Zip::from(&mut b).and(iter).apply(|b, mut a| {
476+
a[()] += 1;
477+
*b = a[()];
478+
});
479+
assert_eq!(cloned, b);
480+
assert_eq!(cloned, orig + 1);
481+
}
482+
483+
#[test]
484+
fn axis_iter_mut_zip_partially_consumed() {
485+
let mut a = Array::from_iter(0..5);
486+
for consumed in 1..=a.len() {
487+
let remaining = a.len() - consumed;
488+
let mut iter = a.axis_iter_mut(Axis(0));
489+
for _ in 0..consumed {
490+
iter.next();
491+
}
492+
let mut b = Array::zeros(remaining);
493+
Zip::from(&mut b).and(iter).apply(|b, a| *b = a[()]);
494+
assert_eq!(a.slice(s![consumed..]), b);
495+
}
496+
}
497+
498+
#[test]
499+
fn axis_iter_mut_zip_partially_consumed_discontiguous() {
500+
let mut a = Array::from_iter(0..5);
501+
for consumed in 1..=a.len() {
502+
let remaining = a.len() - consumed;
503+
let mut iter = a.axis_iter_mut(Axis(0));
504+
for _ in 0..consumed {
505+
iter.next();
506+
}
507+
let mut b = Array::zeros(remaining * 2);
508+
b.slice_collapse(s![..;2]);
509+
Zip::from(&mut b).and(iter).apply(|b, a| *b = a[()]);
510+
assert_eq!(a.slice(s![consumed..]), b);
511+
}
512+
}
513+
369514
#[test]
370515
fn axis_chunks_iter_corner_cases() {
371516
// examples provided by @bluss in PR #65

0 commit comments

Comments
 (0)