1
1
use crate :: dimension:: IntoDimension ;
2
2
use crate :: Dimension ;
3
- use crate :: { Shape , StrideShape } ;
3
+
4
+ /// A contiguous array shape of n dimensions.
5
+ ///
6
+ /// Either c- or f- memory ordered (*c* a.k.a *row major* is the default).
7
+ #[ derive( Copy , Clone , Debug ) ]
8
+ pub struct Shape < D > {
9
+ /// Shape (axis lengths)
10
+ pub ( crate ) dim : D ,
11
+ /// Strides can only be C or F here
12
+ pub ( crate ) strides : Strides < Contiguous > ,
13
+ }
14
+
15
+ #[ derive( Copy , Clone , Debug ) ]
16
+ pub ( crate ) enum Contiguous { }
17
+
18
+ impl < D > Shape < D > {
19
+ pub ( crate ) fn is_c ( & self ) -> bool {
20
+ matches ! ( self . strides, Strides :: C )
21
+ }
22
+ }
23
+
24
+
25
+ /// An array shape of n dimensions in c-order, f-order or custom strides.
26
+ #[ derive( Copy , Clone , Debug ) ]
27
+ pub struct StrideShape < D > {
28
+ pub ( crate ) dim : D ,
29
+ pub ( crate ) strides : Strides < D > ,
30
+ }
31
+
32
+ /// Stride description
33
+ #[ derive( Copy , Clone , Debug ) ]
34
+ pub ( crate ) enum Strides < D > {
35
+ /// Row-major ("C"-order)
36
+ C ,
37
+ /// Column-major ("F"-order)
38
+ F ,
39
+ /// Custom strides
40
+ Custom ( D )
41
+ }
42
+
43
+ impl < D > Strides < D > {
44
+ /// Return strides for `dim` (computed from dimension if c/f, else return the custom stride)
45
+ pub ( crate ) fn strides_for_dim ( self , dim : & D ) -> D
46
+ where D : Dimension
47
+ {
48
+ match self {
49
+ Strides :: C => dim. default_strides ( ) ,
50
+ Strides :: F => dim. fortran_strides ( ) ,
51
+ Strides :: Custom ( c) => {
52
+ debug_assert_eq ! ( c. ndim( ) , dim. ndim( ) ,
53
+ "Custom strides given with {} dimensions, expected {}" ,
54
+ c. ndim( ) , dim. ndim( ) ) ;
55
+ c
56
+ }
57
+ }
58
+ }
59
+
60
+ pub ( crate ) fn is_custom ( & self ) -> bool {
61
+ matches ! ( * self , Strides :: Custom ( _) )
62
+ }
63
+ }
4
64
5
65
/// A trait for `Shape` and `D where D: Dimension` that allows
6
66
/// customizing the memory layout (strides) of an array shape.
@@ -34,36 +94,18 @@ where
34
94
{
35
95
fn from ( value : T ) -> Self {
36
96
let shape = value. into_shape ( ) ;
37
- let d = shape. dim ;
38
- let st = if shape. is_c {
39
- d. default_strides ( )
97
+ let st = if shape. is_c ( ) {
98
+ Strides :: C
40
99
} else {
41
- d . fortran_strides ( )
100
+ Strides :: F
42
101
} ;
43
102
StrideShape {
44
103
strides : st,
45
- dim : d,
46
- custom : false ,
104
+ dim : shape. dim ,
47
105
}
48
106
}
49
107
}
50
108
51
- /*
52
- impl<D> From<Shape<D>> for StrideShape<D>
53
- where D: Dimension
54
- {
55
- fn from(shape: Shape<D>) -> Self {
56
- let d = shape.dim;
57
- let st = if shape.is_c { d.default_strides() } else { d.fortran_strides() };
58
- StrideShape {
59
- strides: st,
60
- dim: d,
61
- custom: false,
62
- }
63
- }
64
- }
65
- */
66
-
67
109
impl < T > ShapeBuilder for T
68
110
where
69
111
T : IntoDimension ,
73
115
fn into_shape ( self ) -> Shape < Self :: Dim > {
74
116
Shape {
75
117
dim : self . into_dimension ( ) ,
76
- is_c : true ,
118
+ strides : Strides :: C ,
77
119
}
78
120
}
79
121
fn f ( self ) -> Shape < Self :: Dim > {
@@ -93,21 +135,24 @@ where
93
135
{
94
136
type Dim = D ;
95
137
type Strides = D ;
138
+
96
139
fn into_shape ( self ) -> Shape < D > {
97
140
self
98
141
}
142
+
99
143
fn f ( self ) -> Self {
100
144
self . set_f ( true )
101
145
}
146
+
102
147
fn set_f ( mut self , is_f : bool ) -> Self {
103
- self . is_c = !is_f;
148
+ self . strides = if !is_f { Strides :: C } else { Strides :: F } ;
104
149
self
105
150
}
151
+
106
152
fn strides ( self , st : D ) -> StrideShape < D > {
107
153
StrideShape {
108
154
dim : self . dim ,
109
- strides : st,
110
- custom : true ,
155
+ strides : Strides :: Custom ( st) ,
111
156
}
112
157
}
113
158
}
0 commit comments