@@ -16,22 +16,57 @@ mod windows;
16
16
use std:: iter:: FromIterator ;
17
17
use std:: marker:: PhantomData ;
18
18
use std:: ptr;
19
+ use std:: slice:: { self , Iter as SliceIter , IterMut as SliceIterMut } ;
19
20
use alloc:: vec:: Vec ;
20
21
22
+ use crate :: imp_prelude:: * ;
21
23
use crate :: Ix1 ;
22
24
23
- use super :: { ArrayBase , ArrayView , ArrayViewMut , Axis , Data , NdProducer , RemoveAxis } ;
24
- use super :: { Dimension , Ix , Ixs } ;
25
+ use super :: { NdProducer , RemoveAxis } ;
25
26
26
27
pub use self :: chunks:: { ExactChunks , ExactChunksIter , ExactChunksIterMut , ExactChunksMut } ;
27
28
pub use self :: lanes:: { Lanes , LanesMut } ;
28
29
pub use self :: windows:: Windows ;
29
30
30
- use std:: slice:: { self , Iter as SliceIter , IterMut as SliceIterMut } ;
31
+ use crate :: dimension;
32
+
33
+ /// No traversal optmizations that would change element order or axis dimensions are permitted.
34
+ ///
35
+ /// This option is suitable for example for the indexed iterator.
36
+ pub ( crate ) enum NoOptimization { }
37
+
38
+ /// Preserve element iteration order, but modify dimensions if profitable; for example we can
39
+ /// change from shape [10, 1] to [1, 10], because that axis has len == 1, without consequence here.
40
+ ///
41
+ /// This option is suitable for example for the default .iter() iterator.
42
+ pub ( crate ) enum PreserveOrder { }
43
+
44
+ /// Allow use of arbitrary element iteration order
45
+ ///
46
+ /// This option is suitable for example for an arbitrary order iterator.
47
+ pub ( crate ) enum ArbitraryOrder { }
48
+
49
+ pub ( crate ) trait OrderOption {
50
+ const ALLOW_REMOVE_REDUNDANT_AXES : bool = false ;
51
+ const ALLOW_ARBITRARY_ORDER : bool = false ;
52
+ }
53
+
54
+ impl OrderOption for NoOptimization { }
55
+
56
+ impl OrderOption for PreserveOrder {
57
+ const ALLOW_REMOVE_REDUNDANT_AXES : bool = true ;
58
+ }
59
+
60
+ impl OrderOption for ArbitraryOrder {
61
+ const ALLOW_REMOVE_REDUNDANT_AXES : bool = true ;
62
+ const ALLOW_ARBITRARY_ORDER : bool = true ;
63
+ }
31
64
32
65
/// Base for iterators over all axes.
33
66
///
34
67
/// Iterator element type is `*mut A`.
68
+ ///
69
+ /// `F` is for layout/iteration order flags
35
70
pub ( crate ) struct Baseiter < A , D > {
36
71
ptr : * mut A ,
37
72
dim : D ,
@@ -44,12 +79,43 @@ impl<A, D: Dimension> Baseiter<A, D> {
44
79
/// to be correct to avoid performing an unsafe pointer offset while
45
80
/// iterating.
46
81
#[ inline]
47
- pub unsafe fn new ( ptr : * mut A , len : D , stride : D ) -> Baseiter < A , D > {
82
+ pub unsafe fn new ( ptr : * mut A , dim : D , strides : D ) -> Baseiter < A , D > {
83
+ Self :: new_with_order :: < NoOptimization > ( ptr, dim, strides)
84
+ }
85
+ }
86
+
87
+ impl < A , D : Dimension > Baseiter < A , D > {
88
+ /// Creating a Baseiter is unsafe because shape and stride parameters need
89
+ /// to be correct to avoid performing an unsafe pointer offset while
90
+ /// iterating.
91
+ #[ inline]
92
+ pub unsafe fn new_with_order < Flags : OrderOption > ( mut ptr : * mut A , mut dim : D , mut strides : D )
93
+ -> Baseiter < A , D >
94
+ {
95
+ debug_assert_eq ! ( dim. ndim( ) , strides. ndim( ) ) ;
96
+ if Flags :: ALLOW_ARBITRARY_ORDER {
97
+ // iterate in memory order; merge axes if possible
98
+ // make all axes positive and put the pointer back to the first element in memory
99
+ let offset = dimension:: offset_from_ptr_to_memory ( & dim, & strides) ;
100
+ ptr = ptr. offset ( offset) ;
101
+ for i in 0 ..strides. ndim ( ) {
102
+ let s = strides. get_stride ( Axis ( i) ) ;
103
+ if s < 0 {
104
+ strides. set_stride ( Axis ( i) , -s) ;
105
+ }
106
+ }
107
+ dimension:: sort_axes_to_standard ( & mut dim, & mut strides) ;
108
+ }
109
+ if Flags :: ALLOW_REMOVE_REDUNDANT_AXES {
110
+ // preserve element order but shift dimensions
111
+ dimension:: merge_axes_from_the_back ( & mut dim, & mut strides) ;
112
+ dimension:: squeeze ( & mut dim, & mut strides) ;
113
+ }
48
114
Baseiter {
49
115
ptr,
50
- index : len . first_index ( ) ,
51
- dim : len ,
52
- strides : stride ,
116
+ index : dim . first_index ( ) ,
117
+ dim,
118
+ strides,
53
119
}
54
120
}
55
121
}
@@ -1496,3 +1562,147 @@ where
1496
1562
debug_assert_eq ! ( size, result. len( ) ) ;
1497
1563
result
1498
1564
}
1565
+
1566
+ #[ cfg( test) ]
1567
+ #[ cfg( feature = "std" ) ]
1568
+ mod tests {
1569
+ use crate :: prelude:: * ;
1570
+ use super :: Baseiter ;
1571
+ use super :: { ArbitraryOrder , PreserveOrder , NoOptimization } ;
1572
+ use itertools:: assert_equal;
1573
+ use itertools:: Itertools ;
1574
+
1575
+ // 3-d axis swaps
1576
+ fn swaps ( ) -> impl Iterator < Item =Vec < ( usize , usize ) > > {
1577
+ vec ! [
1578
+ vec![ ] ,
1579
+ vec![ ( 0 , 1 ) ] ,
1580
+ vec![ ( 0 , 2 ) ] ,
1581
+ vec![ ( 1 , 2 ) ] ,
1582
+ vec![ ( 0 , 1 ) , ( 1 , 2 ) ] ,
1583
+ vec![ ( 0 , 1 ) , ( 0 , 2 ) ] ,
1584
+ ] . into_iter ( )
1585
+ }
1586
+
1587
+ // 3-d axis inverts
1588
+ fn inverts ( ) -> impl Iterator < Item =Vec < Axis > > {
1589
+ vec ! [
1590
+ vec![ ] ,
1591
+ vec![ Axis ( 0 ) ] ,
1592
+ vec![ Axis ( 1 ) ] ,
1593
+ vec![ Axis ( 2 ) ] ,
1594
+ vec![ Axis ( 0 ) , Axis ( 1 ) ] ,
1595
+ vec![ Axis ( 0 ) , Axis ( 2 ) ] ,
1596
+ vec![ Axis ( 1 ) , Axis ( 2 ) ] ,
1597
+ vec![ Axis ( 0 ) , Axis ( 1 ) , Axis ( 2 ) ] ,
1598
+ ] . into_iter ( )
1599
+ }
1600
+
1601
+ #[ test]
1602
+ fn test_arbitrary_order ( ) {
1603
+ for swap in swaps ( ) {
1604
+ for invert in inverts ( ) {
1605
+ for & slice in & [ false , true ] {
1606
+ // pattern is 0, 1; 4, 5; 8, 9; etc..
1607
+ let mut a = Array :: from_iter ( 0 ..24 ) . into_shape ( ( 3 , 4 , 2 ) ) . unwrap ( ) ;
1608
+ if slice {
1609
+ a. slice_collapse ( s ! [ .., ..; 2 , ..] ) ;
1610
+ }
1611
+ for & ( i, j) in & swap {
1612
+ a. swap_axes ( i, j) ;
1613
+ }
1614
+ for & i in & invert {
1615
+ a. invert_axis ( i) ;
1616
+ }
1617
+ unsafe {
1618
+ // Should have in-memory order for arbitrary order
1619
+ let iter = Baseiter :: new_with_order :: < ArbitraryOrder > ( a. as_mut_ptr ( ) ,
1620
+ a. dim , a. strides ) ;
1621
+ if !slice {
1622
+ assert_equal ( iter. map ( |ptr| * ptr) , 0 ..a. len ( ) ) ;
1623
+ } else {
1624
+ assert_eq ! ( iter. map( |ptr| * ptr) . collect_vec( ) ,
1625
+ ( 0 ..a. len( ) * 2 ) . filter( |& x| ( x / 2 ) % 2 == 0 ) . collect_vec( ) ) ;
1626
+ }
1627
+ }
1628
+ }
1629
+ }
1630
+ }
1631
+ }
1632
+
1633
+ #[ test]
1634
+ fn test_logical_order ( ) {
1635
+ for swap in swaps ( ) {
1636
+ for invert in inverts ( ) {
1637
+ for & slice in & [ false , true ] {
1638
+ let mut a = Array :: from_iter ( 0 ..24 ) . into_shape ( ( 3 , 4 , 2 ) ) . unwrap ( ) ;
1639
+ for & ( i, j) in & swap {
1640
+ a. swap_axes ( i, j) ;
1641
+ }
1642
+ for & i in & invert {
1643
+ a. invert_axis ( i) ;
1644
+ }
1645
+ if slice {
1646
+ a. slice_collapse ( s ! [ .., ..; 2 , ..] ) ;
1647
+ }
1648
+
1649
+ unsafe {
1650
+ let mut iter = Baseiter :: new_with_order :: < NoOptimization > ( a. as_mut_ptr ( ) ,
1651
+ a. dim , a. strides ) ;
1652
+ let mut index = Dim ( [ 0 , 0 , 0 ] ) ;
1653
+ let mut elts = 0 ;
1654
+ while let Some ( elt) = iter. next ( ) {
1655
+ assert_eq ! ( * elt, a[ index] ) ;
1656
+ if let Some ( index_) = a. raw_dim ( ) . next_for ( index) {
1657
+ index = index_;
1658
+ }
1659
+ elts += 1 ;
1660
+ }
1661
+ assert_eq ! ( elts, a. len( ) ) ;
1662
+ }
1663
+ }
1664
+ }
1665
+ }
1666
+ }
1667
+
1668
+ #[ test]
1669
+ fn test_preserve_order ( ) {
1670
+ for swap in swaps ( ) {
1671
+ for invert in inverts ( ) {
1672
+ for & slice in & [ false , true ] {
1673
+ let mut a = Array :: from_iter ( 0 ..20 ) . into_shape ( ( 2 , 10 , 1 ) ) . unwrap ( ) ;
1674
+ for & ( i, j) in & swap {
1675
+ a. swap_axes ( i, j) ;
1676
+ }
1677
+ for & i in & invert {
1678
+ a. invert_axis ( i) ;
1679
+ }
1680
+ if slice {
1681
+ a. slice_collapse ( s ! [ .., ..; 2 , ..] ) ;
1682
+ }
1683
+
1684
+ unsafe {
1685
+ let mut iter = Baseiter :: new_with_order :: < PreserveOrder > (
1686
+ a. as_mut_ptr ( ) , a. dim , a. strides ) ;
1687
+
1688
+ // check that axes have been merged (when it's easy to check)
1689
+ if a. shape ( ) == & [ 2 , 10 , 1 ] && invert. is_empty ( ) {
1690
+ assert_eq ! ( iter. dim, Dim ( [ 1 , 1 , 20 ] ) ) ;
1691
+ }
1692
+
1693
+ let mut index = Dim ( [ 0 , 0 , 0 ] ) ;
1694
+ let mut elts = 0 ;
1695
+ while let Some ( elt) = iter. next ( ) {
1696
+ assert_eq ! ( * elt, a[ index] ) ;
1697
+ if let Some ( index_) = a. raw_dim ( ) . next_for ( index) {
1698
+ index = index_;
1699
+ }
1700
+ elts += 1 ;
1701
+ }
1702
+ assert_eq ! ( elts, a. len( ) ) ;
1703
+ }
1704
+ }
1705
+ }
1706
+ }
1707
+ }
1708
+ }
0 commit comments