@@ -738,11 +738,19 @@ where
738
738
739
739
#[ derive( Debug ) ]
740
740
pub struct AxisIterCore < A , D > {
741
+ /// Index along the axis of the value of `.next()`, relative to the start
742
+ /// of the axis.
741
743
index : Ix ,
742
- len : Ix ,
744
+ /// (Exclusive) upper bound on `index`. Initially, this is equal to the
745
+ /// length of the axis.
746
+ end : Ix ,
747
+ /// Stride along the axis (offset between consecutive pointers).
743
748
stride : Ixs ,
749
+ /// Shape of the iterator's items.
744
750
inner_dim : D ,
751
+ /// Strides of the iterator's items.
745
752
inner_strides : D ,
753
+ /// Pointer corresponding to `index == 0`.
746
754
ptr : * mut A ,
747
755
}
748
756
@@ -751,7 +759,7 @@ clone_bounds!(
751
759
AxisIterCore [ A , D ] {
752
760
@copy {
753
761
index,
754
- len ,
762
+ end ,
755
763
stride,
756
764
ptr,
757
765
}
@@ -767,54 +775,53 @@ impl<A, D: Dimension> AxisIterCore<A, D> {
767
775
Di : RemoveAxis < Smaller = D > ,
768
776
S : Data < Elem = A > ,
769
777
{
770
- let shape = v. shape ( ) [ axis. index ( ) ] ;
771
- let stride = v. strides ( ) [ axis. index ( ) ] ;
772
778
AxisIterCore {
773
779
index : 0 ,
774
- len : shape ,
775
- stride,
780
+ end : v . len_of ( axis ) ,
781
+ stride : v . stride_of ( axis ) ,
776
782
inner_dim : v. dim . remove_axis ( axis) ,
777
783
inner_strides : v. strides . remove_axis ( axis) ,
778
784
ptr : v. ptr ,
779
785
}
780
786
}
781
787
788
+ #[ inline]
782
789
unsafe fn offset ( & self , index : usize ) -> * mut A {
783
790
debug_assert ! (
784
- index <= self . len ,
785
- "index={}, len ={}, stride={}" ,
791
+ index <= self . end ,
792
+ "index={}, end ={}, stride={}" ,
786
793
index,
787
- self . len ,
794
+ self . end ,
788
795
self . stride
789
796
) ;
790
797
self . ptr . offset ( index as isize * self . stride )
791
798
}
792
799
793
- /// Split the iterator at index, yielding two disjoint iterators.
800
+ /// Splits the iterator at `index`, yielding two disjoint iterators.
801
+ ///
802
+ /// `index` is relative to the current state of the iterator (which is not
803
+ /// necessarily the start of the axis).
794
804
///
795
- /// **Panics** if `index` is strictly greater than the iterator's length
805
+ /// **Panics** if `index` is strictly greater than the iterator's remaining
806
+ /// length.
796
807
fn split_at ( self , index : usize ) -> ( Self , Self ) {
797
- assert ! ( index <= self . len) ;
798
- let right_ptr = if index != self . len {
799
- unsafe { self . offset ( index) }
800
- } else {
801
- self . ptr
802
- } ;
808
+ assert ! ( index <= self . len( ) ) ;
809
+ let mid = self . index + index;
803
810
let left = AxisIterCore {
804
- index : 0 ,
805
- len : index ,
811
+ index : self . index ,
812
+ end : mid ,
806
813
stride : self . stride ,
807
814
inner_dim : self . inner_dim . clone ( ) ,
808
815
inner_strides : self . inner_strides . clone ( ) ,
809
816
ptr : self . ptr ,
810
817
} ;
811
818
let right = AxisIterCore {
812
- index : 0 ,
813
- len : self . len - index ,
819
+ index : mid ,
820
+ end : self . end ,
814
821
stride : self . stride ,
815
822
inner_dim : self . inner_dim ,
816
823
inner_strides : self . inner_strides ,
817
- ptr : right_ptr ,
824
+ ptr : self . ptr ,
818
825
} ;
819
826
( left, right)
820
827
}
@@ -827,7 +834,7 @@ where
827
834
type Item = * mut A ;
828
835
829
836
fn next ( & mut self ) -> Option < Self :: Item > {
830
- if self . index >= self . len {
837
+ if self . index >= self . end {
831
838
None
832
839
} else {
833
840
let ptr = unsafe { self . offset ( self . index ) } ;
@@ -837,7 +844,7 @@ where
837
844
}
838
845
839
846
fn size_hint ( & self ) -> ( usize , Option < usize > ) {
840
- let len = self . len - self . index ;
847
+ let len = self . len ( ) ;
841
848
( len, Some ( len) )
842
849
}
843
850
}
@@ -847,16 +854,25 @@ where
847
854
D : Dimension ,
848
855
{
849
856
fn next_back ( & mut self ) -> Option < Self :: Item > {
850
- if self . index >= self . len {
857
+ if self . index >= self . end {
851
858
None
852
859
} else {
853
- self . len -= 1 ;
854
- let ptr = unsafe { self . offset ( self . len ) } ;
860
+ self . end -= 1 ;
861
+ let ptr = unsafe { self . offset ( self . end ) } ;
855
862
Some ( ptr)
856
863
}
857
864
}
858
865
}
859
866
867
+ impl < A , D > ExactSizeIterator for AxisIterCore < A , D >
868
+ where
869
+ D : Dimension ,
870
+ {
871
+ fn len ( & self ) -> usize {
872
+ self . end - self . index
873
+ }
874
+ }
875
+
860
876
/// An iterator that traverses over an axis and
861
877
/// and yields each subview.
862
878
///
@@ -899,9 +915,13 @@ impl<'a, A, D: Dimension> AxisIter<'a, A, D> {
899
915
}
900
916
}
901
917
902
- /// Split the iterator at index, yielding two disjoint iterators.
918
+ /// Splits the iterator at `index`, yielding two disjoint iterators.
919
+ ///
920
+ /// `index` is relative to the current state of the iterator (which is not
921
+ /// necessarily the start of the axis).
903
922
///
904
- /// **Panics** if `index` is strictly greater than the iterator's length
923
+ /// **Panics** if `index` is strictly greater than the iterator's remaining
924
+ /// length.
905
925
pub fn split_at ( self , index : usize ) -> ( Self , Self ) {
906
926
let ( left, right) = self . iter . split_at ( index) ;
907
927
(
@@ -946,7 +966,7 @@ where
946
966
D : Dimension ,
947
967
{
948
968
fn len ( & self ) -> usize {
949
- self . size_hint ( ) . 0
969
+ self . iter . len ( )
950
970
}
951
971
}
952
972
@@ -981,9 +1001,13 @@ impl<'a, A, D: Dimension> AxisIterMut<'a, A, D> {
981
1001
}
982
1002
}
983
1003
984
- /// Split the iterator at index, yielding two disjoint iterators.
1004
+ /// Splits the iterator at ` index` , yielding two disjoint iterators.
985
1005
///
986
- /// **Panics** if `index` is strictly greater than the iterator's length
1006
+ /// `index` is relative to the current state of the iterator (which is not
1007
+ /// necessarily the start of the axis).
1008
+ ///
1009
+ /// **Panics** if `index` is strictly greater than the iterator's remaining
1010
+ /// length.
987
1011
pub fn split_at ( self , index : usize ) -> ( Self , Self ) {
988
1012
let ( left, right) = self . iter . split_at ( index) ;
989
1013
(
@@ -1028,7 +1052,7 @@ where
1028
1052
D : Dimension ,
1029
1053
{
1030
1054
fn len ( & self ) -> usize {
1031
- self . size_hint ( ) . 0
1055
+ self . iter . len ( )
1032
1056
}
1033
1057
}
1034
1058
@@ -1048,7 +1072,16 @@ impl<'a, A, D: Dimension> NdProducer for AxisIter<'a, A, D> {
1048
1072
}
1049
1073
#[ doc( hidden) ]
1050
1074
fn as_ptr ( & self ) -> Self :: Ptr {
1051
- self . iter . ptr
1075
+ if self . len ( ) > 0 {
1076
+ // `self.iter.index` is guaranteed to be in-bounds if any of the
1077
+ // iterator remains (i.e. if `self.len() > 0`).
1078
+ unsafe { self . iter . offset ( self . iter . index ) }
1079
+ } else {
1080
+ // In this case, `self.iter.index` may be past the end, so we must
1081
+ // not call `.offset()`. It's okay to return a dangling pointer
1082
+ // because it will never be used in the length 0 case.
1083
+ std:: ptr:: NonNull :: dangling ( ) . as_ptr ( )
1084
+ }
1052
1085
}
1053
1086
1054
1087
fn contiguous_stride ( & self ) -> isize {
@@ -1065,7 +1098,7 @@ impl<'a, A, D: Dimension> NdProducer for AxisIter<'a, A, D> {
1065
1098
}
1066
1099
#[ doc( hidden) ]
1067
1100
unsafe fn uget_ptr ( & self , i : & Self :: Dim ) -> Self :: Ptr {
1068
- self . iter . ptr . offset ( self . iter . stride * i[ 0 ] as isize )
1101
+ self . iter . offset ( self . iter . index + i[ 0 ] )
1069
1102
}
1070
1103
1071
1104
#[ doc( hidden) ]
@@ -1096,7 +1129,16 @@ impl<'a, A, D: Dimension> NdProducer for AxisIterMut<'a, A, D> {
1096
1129
}
1097
1130
#[ doc( hidden) ]
1098
1131
fn as_ptr ( & self ) -> Self :: Ptr {
1099
- self . iter . ptr
1132
+ if self . len ( ) > 0 {
1133
+ // `self.iter.index` is guaranteed to be in-bounds if any of the
1134
+ // iterator remains (i.e. if `self.len() > 0`).
1135
+ unsafe { self . iter . offset ( self . iter . index ) }
1136
+ } else {
1137
+ // In this case, `self.iter.index` may be past the end, so we must
1138
+ // not call `.offset()`. It's okay to return a dangling pointer
1139
+ // because it will never be used in the length 0 case.
1140
+ std:: ptr:: NonNull :: dangling ( ) . as_ptr ( )
1141
+ }
1100
1142
}
1101
1143
1102
1144
fn contiguous_stride ( & self ) -> isize {
@@ -1113,7 +1155,7 @@ impl<'a, A, D: Dimension> NdProducer for AxisIterMut<'a, A, D> {
1113
1155
}
1114
1156
#[ doc( hidden) ]
1115
1157
unsafe fn uget_ptr ( & self , i : & Self :: Dim ) -> Self :: Ptr {
1116
- self . iter . ptr . offset ( self . iter . stride * i[ 0 ] as isize )
1158
+ self . iter . offset ( self . iter . index + i[ 0 ] )
1117
1159
}
1118
1160
1119
1161
#[ doc( hidden) ]
@@ -1193,7 +1235,7 @@ fn chunk_iter_parts<A, D: Dimension>(
1193
1235
1194
1236
let iter = AxisIterCore {
1195
1237
index : 0 ,
1196
- len : iter_len,
1238
+ end : iter_len,
1197
1239
stride,
1198
1240
inner_dim,
1199
1241
inner_strides : v. strides ,
@@ -1270,7 +1312,7 @@ macro_rules! chunk_iter_impl {
1270
1312
D : Dimension ,
1271
1313
{
1272
1314
fn next_back( & mut self ) -> Option <Self :: Item > {
1273
- let is_uneven = self . iter. len > self . n_whole_chunks;
1315
+ let is_uneven = self . iter. end > self . n_whole_chunks;
1274
1316
let res = self . iter. next_back( ) ;
1275
1317
self . get_subview( res, is_uneven)
1276
1318
}
0 commit comments