@@ -9,7 +9,7 @@ extern crate ndarray;
9
9
10
10
use ndarray:: prelude:: * ;
11
11
use ndarray:: Ix ;
12
- use ndarray:: { arr2, arr3, aview1, indices, s, Axis , Data , Dimension , Slice } ;
12
+ use ndarray:: { arr2, arr3, aview1, indices, s, Axis , Data , Dimension , Slice , Zip } ;
13
13
14
14
use itertools:: assert_equal;
15
15
use itertools:: { enumerate, rev} ;
@@ -262,6 +262,68 @@ fn axis_iter() {
262
262
) ;
263
263
}
264
264
265
+ #[ test]
266
+ fn axis_iter_split_at ( ) {
267
+ let a = Array :: from_iter ( 0 ..5 ) ;
268
+ let iter = a. axis_iter ( Axis ( 0 ) ) ;
269
+ let all: Vec < _ > = iter. clone ( ) . collect ( ) ;
270
+ for mid in 0 ..=all. len ( ) {
271
+ let ( left, right) = iter. clone ( ) . split_at ( mid) ;
272
+ assert_eq ! ( & all[ ..mid] , & left. collect:: <Vec <_>>( ) [ ..] ) ;
273
+ assert_eq ! ( & all[ mid..] , & right. collect:: <Vec <_>>( ) [ ..] ) ;
274
+ }
275
+ }
276
+
277
+ #[ test]
278
+ fn axis_iter_split_at_partially_consumed ( ) {
279
+ let a = Array :: from_iter ( 0 ..5 ) ;
280
+ let mut iter = a. axis_iter ( Axis ( 0 ) ) ;
281
+ while iter. next ( ) . is_some ( ) {
282
+ let remaining: Vec < _ > = iter. clone ( ) . collect ( ) ;
283
+ for mid in 0 ..=remaining. len ( ) {
284
+ let ( left, right) = iter. clone ( ) . split_at ( mid) ;
285
+ assert_eq ! ( & remaining[ ..mid] , & left. collect:: <Vec <_>>( ) [ ..] ) ;
286
+ assert_eq ! ( & remaining[ mid..] , & right. collect:: <Vec <_>>( ) [ ..] ) ;
287
+ }
288
+ }
289
+ }
290
+
291
+ #[ test]
292
+ fn axis_iter_zip ( ) {
293
+ let a = Array :: from_iter ( 0 ..5 ) ;
294
+ let iter = a. axis_iter ( Axis ( 0 ) ) ;
295
+ let mut b = Array :: zeros ( 5 ) ;
296
+ Zip :: from ( & mut b) . and ( iter) . apply ( |b, a| * b = a[ ( ) ] ) ;
297
+ assert_eq ! ( a, b) ;
298
+ }
299
+
300
+ #[ test]
301
+ fn axis_iter_zip_partially_consumed ( ) {
302
+ let a = Array :: from_iter ( 0 ..5 ) ;
303
+ let mut iter = a. axis_iter ( Axis ( 0 ) ) ;
304
+ let mut consumed = 0 ;
305
+ while iter. next ( ) . is_some ( ) {
306
+ consumed += 1 ;
307
+ let mut b = Array :: zeros ( a. len ( ) - consumed) ;
308
+ Zip :: from ( & mut b) . and ( iter. clone ( ) ) . apply ( |b, a| * b = a[ ( ) ] ) ;
309
+ assert_eq ! ( a. slice( s![ consumed..] ) , b) ;
310
+ }
311
+ }
312
+
313
+ #[ test]
314
+ fn axis_iter_zip_partially_consumed_discontiguous ( ) {
315
+ let a = Array :: from_iter ( 0 ..5 ) ;
316
+ let mut iter = a. axis_iter ( Axis ( 0 ) ) ;
317
+ let mut consumed = 0 ;
318
+ while iter. next ( ) . is_some ( ) {
319
+ consumed += 1 ;
320
+ let mut b = Array :: zeros ( ( a. len ( ) - consumed) * 2 ) ;
321
+ b. slice_collapse ( s ! [ ..; 2 ] ) ;
322
+ Zip :: from ( & mut b) . and ( iter. clone ( ) ) . apply ( |b, a| * b = a[ ( ) ] ) ;
323
+ assert_eq ! ( a. slice( s![ consumed..] ) , b) ;
324
+ }
325
+ }
326
+
265
327
#[ test]
266
328
fn outer_iter_corner_cases ( ) {
267
329
let a2 = ArcArray :: < i32 , _ > :: zeros ( ( 0 , 3 ) ) ;
@@ -366,6 +428,89 @@ fn axis_chunks_iter() {
366
428
assert_equal ( it, vec ! [ a. view( ) ] ) ;
367
429
}
368
430
431
+ #[ test]
432
+ fn axis_iter_mut_split_at ( ) {
433
+ let mut a = Array :: from_iter ( 0 ..5 ) ;
434
+ let mut a_clone = a. clone ( ) ;
435
+ let all: Vec < _ > = a_clone. axis_iter_mut ( Axis ( 0 ) ) . collect ( ) ;
436
+ for mid in 0 ..=all. len ( ) {
437
+ let ( left, right) = a. axis_iter_mut ( Axis ( 0 ) ) . split_at ( mid) ;
438
+ assert_eq ! ( & all[ ..mid] , & left. collect:: <Vec <_>>( ) [ ..] ) ;
439
+ assert_eq ! ( & all[ mid..] , & right. collect:: <Vec <_>>( ) [ ..] ) ;
440
+ }
441
+ }
442
+
443
+ #[ test]
444
+ fn axis_iter_mut_split_at_partially_consumed ( ) {
445
+ let mut a = Array :: from_iter ( 0 ..5 ) ;
446
+ for consumed in 1 ..=a. len ( ) {
447
+ for mid in 0 ..=( a. len ( ) - consumed) {
448
+ let mut a_clone = a. clone ( ) ;
449
+ let remaining: Vec < _ > = {
450
+ let mut iter = a_clone. axis_iter_mut ( Axis ( 0 ) ) ;
451
+ for _ in 0 ..consumed {
452
+ iter. next ( ) ;
453
+ }
454
+ iter. collect ( )
455
+ } ;
456
+ let ( left, right) = {
457
+ let mut iter = a. axis_iter_mut ( Axis ( 0 ) ) ;
458
+ for _ in 0 ..consumed {
459
+ iter. next ( ) ;
460
+ }
461
+ iter. split_at ( mid)
462
+ } ;
463
+ assert_eq ! ( & remaining[ ..mid] , & left. collect:: <Vec <_>>( ) [ ..] ) ;
464
+ assert_eq ! ( & remaining[ mid..] , & right. collect:: <Vec <_>>( ) [ ..] ) ;
465
+ }
466
+ }
467
+ }
468
+
469
+ #[ test]
470
+ fn axis_iter_mut_zip ( ) {
471
+ let orig = Array :: from_iter ( 0 ..5 ) ;
472
+ let mut cloned = orig. clone ( ) ;
473
+ let iter = cloned. axis_iter_mut ( Axis ( 0 ) ) ;
474
+ let mut b = Array :: zeros ( 5 ) ;
475
+ Zip :: from ( & mut b) . and ( iter) . apply ( |b, mut a| {
476
+ a[ ( ) ] += 1 ;
477
+ * b = a[ ( ) ] ;
478
+ } ) ;
479
+ assert_eq ! ( cloned, b) ;
480
+ assert_eq ! ( cloned, orig + 1 ) ;
481
+ }
482
+
483
+ #[ test]
484
+ fn axis_iter_mut_zip_partially_consumed ( ) {
485
+ let mut a = Array :: from_iter ( 0 ..5 ) ;
486
+ for consumed in 1 ..=a. len ( ) {
487
+ let remaining = a. len ( ) - consumed;
488
+ let mut iter = a. axis_iter_mut ( Axis ( 0 ) ) ;
489
+ for _ in 0 ..consumed {
490
+ iter. next ( ) ;
491
+ }
492
+ let mut b = Array :: zeros ( remaining) ;
493
+ Zip :: from ( & mut b) . and ( iter) . apply ( |b, a| * b = a[ ( ) ] ) ;
494
+ assert_eq ! ( a. slice( s![ consumed..] ) , b) ;
495
+ }
496
+ }
497
+
498
+ #[ test]
499
+ fn axis_iter_mut_zip_partially_consumed_discontiguous ( ) {
500
+ let mut a = Array :: from_iter ( 0 ..5 ) ;
501
+ for consumed in 1 ..=a. len ( ) {
502
+ let remaining = a. len ( ) - consumed;
503
+ let mut iter = a. axis_iter_mut ( Axis ( 0 ) ) ;
504
+ for _ in 0 ..consumed {
505
+ iter. next ( ) ;
506
+ }
507
+ let mut b = Array :: zeros ( remaining * 2 ) ;
508
+ b. slice_collapse ( s ! [ ..; 2 ] ) ;
509
+ Zip :: from ( & mut b) . and ( iter) . apply ( |b, a| * b = a[ ( ) ] ) ;
510
+ assert_eq ! ( a. slice( s![ consumed..] ) , b) ;
511
+ }
512
+ }
513
+
369
514
#[ test]
370
515
fn axis_chunks_iter_corner_cases ( ) {
371
516
// examples provided by @bluss in PR #65
0 commit comments