@@ -13,7 +13,7 @@ pub struct Shape<D> {
13
13
}
14
14
15
15
#[ derive( Copy , Clone , Debug ) ]
16
- pub ( crate ) enum Contiguous { }
16
+ pub enum Contiguous { }
17
17
18
18
impl < D > Shape < D > {
19
19
pub ( crate ) fn is_c ( & self ) -> bool {
44
44
45
45
/// Stride description
46
46
#[ derive( Copy , Clone , Debug ) ]
47
- pub ( crate ) enum Strides < D > {
47
+ pub enum Strides < D > {
48
48
/// Row-major ("C"-order)
49
49
C ,
50
50
/// Column-major ("F"-order)
@@ -184,3 +184,63 @@ where
184
184
self . dim . size ( )
185
185
}
186
186
}
187
+
188
+
189
+ use crate :: order:: Order ;
190
+
191
+ pub trait ShapeArg {
192
+ type Dim : Dimension ;
193
+ type StrideType ;
194
+
195
+ fn into_shape_and_order ( self , default : Order ) -> ( Self :: Dim , Order ) ;
196
+ fn into_shape_and_strides ( self , default : Order ) -> ( Self :: Dim , Strides < Self :: StrideType > ) ;
197
+ }
198
+
199
+ impl < T > ShapeArg for T where T : IntoDimension {
200
+ type Dim = T :: Dim ;
201
+ type StrideType = Contiguous ;
202
+
203
+ fn into_shape_and_order ( self , default : Order ) -> ( Self :: Dim , Order ) {
204
+ ( self . into_dimension ( ) , default)
205
+ }
206
+
207
+ fn into_shape_and_strides ( self , _default : Order ) -> ( Self :: Dim , Strides < Contiguous > ) {
208
+ unimplemented ! ( )
209
+ }
210
+ }
211
+
212
+ impl < T > ShapeArg for ( T , Order ) where T : IntoDimension {
213
+ type Dim = T :: Dim ;
214
+ type StrideType = Contiguous ;
215
+
216
+ fn into_shape_and_order ( self , _default : Order ) -> ( Self :: Dim , Order ) {
217
+ ( self . 0 . into_dimension ( ) , self . 1 )
218
+ }
219
+
220
+ fn into_shape_and_strides ( self , _default : Order ) -> ( Self :: Dim , Strides < Contiguous > ) {
221
+ unimplemented ! ( )
222
+ }
223
+ }
224
+
225
+ /// Custom strides
226
+ #[ derive( Copy , Clone , Debug ) ]
227
+ pub struct CustomStrides < D > { strides : D }
228
+
229
+ // newtype constructor without public field
230
+ #[ allow( non_snake_case) ]
231
+ pub fn CustomStrides < T > ( strides : T ) -> CustomStrides < T > {
232
+ CustomStrides { strides }
233
+ }
234
+
235
+ impl < T > ShapeArg for ( T , CustomStrides < T > ) where T : IntoDimension {
236
+ type Dim = T :: Dim ;
237
+ type StrideType = T :: Dim ;
238
+
239
+ fn into_shape_and_order ( self , _default : Order ) -> ( Self :: Dim , Order ) {
240
+ ( self . 0 . into_dimension ( ) , _default)
241
+ }
242
+
243
+ fn into_shape_and_strides ( self , _default : Order ) -> ( Self :: Dim , Strides < T :: Dim > ) {
244
+ ( self . 0 . into_dimension ( ) , Strides :: Custom ( self . 1 . strides . into_dimension ( ) ) )
245
+ }
246
+ }
0 commit comments