@@ -13,7 +13,7 @@ pub struct Shape<D> {
13
13
}
14
14
15
15
#[ derive( Copy , Clone , Debug ) ]
16
- pub ( crate ) enum Contiguous { }
16
+ pub enum Contiguous { }
17
17
18
18
impl < D > Shape < D > {
19
19
pub ( crate ) fn is_c ( & self ) -> bool {
@@ -31,7 +31,7 @@ pub struct StrideShape<D> {
31
31
32
32
/// Stride description
33
33
#[ derive( Copy , Clone , Debug ) ]
34
- pub ( crate ) enum Strides < D > {
34
+ pub enum Strides < D > {
35
35
/// Row-major ("C"-order)
36
36
C ,
37
37
/// Column-major ("F"-order)
@@ -168,3 +168,63 @@ where
168
168
self . dim . size ( )
169
169
}
170
170
}
171
+
172
+
173
+ use crate :: order:: Order ;
174
+
175
+ pub trait ShapeArg {
176
+ type Dim : Dimension ;
177
+ type StrideType ;
178
+
179
+ fn into_shape_and_order ( self , default : Order ) -> ( Self :: Dim , Order ) ;
180
+ fn into_shape_and_strides ( self , default : Order ) -> ( Self :: Dim , Strides < Self :: StrideType > ) ;
181
+ }
182
+
183
+ impl < T > ShapeArg for T where T : IntoDimension {
184
+ type Dim = T :: Dim ;
185
+ type StrideType = Contiguous ;
186
+
187
+ fn into_shape_and_order ( self , default : Order ) -> ( Self :: Dim , Order ) {
188
+ ( self . into_dimension ( ) , default)
189
+ }
190
+
191
+ fn into_shape_and_strides ( self , _default : Order ) -> ( Self :: Dim , Strides < Contiguous > ) {
192
+ unimplemented ! ( )
193
+ }
194
+ }
195
+
196
+ impl < T > ShapeArg for ( T , Order ) where T : IntoDimension {
197
+ type Dim = T :: Dim ;
198
+ type StrideType = Contiguous ;
199
+
200
+ fn into_shape_and_order ( self , _default : Order ) -> ( Self :: Dim , Order ) {
201
+ ( self . 0 . into_dimension ( ) , self . 1 )
202
+ }
203
+
204
+ fn into_shape_and_strides ( self , _default : Order ) -> ( Self :: Dim , Strides < Contiguous > ) {
205
+ unimplemented ! ( )
206
+ }
207
+ }
208
+
209
+ /// Custom strides
210
+ #[ derive( Copy , Clone , Debug ) ]
211
+ pub struct CustomStrides < D > { strides : D }
212
+
213
+ // newtype constructor without public field
214
+ #[ allow( non_snake_case) ]
215
+ pub fn CustomStrides < T > ( strides : T ) -> CustomStrides < T > {
216
+ CustomStrides { strides }
217
+ }
218
+
219
+ impl < T > ShapeArg for ( T , CustomStrides < T > ) where T : IntoDimension {
220
+ type Dim = T :: Dim ;
221
+ type StrideType = T :: Dim ;
222
+
223
+ fn into_shape_and_order ( self , _default : Order ) -> ( Self :: Dim , Order ) {
224
+ ( self . 0 . into_dimension ( ) , _default)
225
+ }
226
+
227
+ fn into_shape_and_strides ( self , _default : Order ) -> ( Self :: Dim , Strides < T :: Dim > ) {
228
+ ( self . 0 . into_dimension ( ) , Strides :: Custom ( self . 1 . strides . into_dimension ( ) ) )
229
+ }
230
+ }
0 commit comments