Skip to content

Commit 206c19e

Browse files
committed
shape: Add trait ShapeArg
1 parent 4b2b0d6 commit 206c19e

File tree

1 file changed

+62
-2
lines changed

1 file changed

+62
-2
lines changed

src/shape_builder.rs

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ pub struct Shape<D> {
1313
}
1414

1515
#[derive(Copy, Clone, Debug)]
16-
pub(crate) enum Contiguous {}
16+
pub enum Contiguous {}
1717

1818
impl<D> Shape<D> {
1919
pub(crate) fn is_c(&self) -> bool {
@@ -44,7 +44,7 @@ where
4444

4545
/// Stride description
4646
#[derive(Copy, Clone, Debug)]
47-
pub(crate) enum Strides<D> {
47+
pub enum Strides<D> {
4848
/// Row-major ("C"-order)
4949
C,
5050
/// Column-major ("F"-order)
@@ -184,3 +184,63 @@ where
184184
self.dim.size()
185185
}
186186
}
187+
188+
189+
use crate::order::Order;
190+
191+
pub trait ShapeArg {
192+
type Dim: Dimension;
193+
type StrideType;
194+
195+
fn into_shape_and_order(self, default: Order) -> (Self::Dim, Order);
196+
fn into_shape_and_strides(self, default: Order) -> (Self::Dim, Strides<Self::StrideType>);
197+
}
198+
199+
impl<T> ShapeArg for T where T: IntoDimension {
200+
type Dim = T::Dim;
201+
type StrideType = Contiguous;
202+
203+
fn into_shape_and_order(self, default: Order) -> (Self::Dim, Order) {
204+
(self.into_dimension(), default)
205+
}
206+
207+
fn into_shape_and_strides(self, _default: Order) -> (Self::Dim, Strides<Contiguous>) {
208+
unimplemented!()
209+
}
210+
}
211+
212+
impl<T> ShapeArg for (T, Order) where T: IntoDimension {
213+
type Dim = T::Dim;
214+
type StrideType = Contiguous;
215+
216+
fn into_shape_and_order(self, _default: Order) -> (Self::Dim, Order) {
217+
(self.0.into_dimension(), self.1)
218+
}
219+
220+
fn into_shape_and_strides(self, _default: Order) -> (Self::Dim, Strides<Contiguous>) {
221+
unimplemented!()
222+
}
223+
}
224+
225+
/// Custom strides
226+
#[derive(Copy, Clone, Debug)]
227+
pub struct CustomStrides<D> { strides: D }
228+
229+
// newtype constructor without public field
230+
#[allow(non_snake_case)]
231+
pub fn CustomStrides<T>(strides: T) -> CustomStrides<T> {
232+
CustomStrides { strides }
233+
}
234+
235+
impl<T> ShapeArg for (T, CustomStrides<T>) where T: IntoDimension {
236+
type Dim = T::Dim;
237+
type StrideType = T::Dim;
238+
239+
fn into_shape_and_order(self, _default: Order) -> (Self::Dim, Order) {
240+
(self.0.into_dimension(), _default)
241+
}
242+
243+
fn into_shape_and_strides(self, _default: Order) -> (Self::Dim, Strides<T::Dim>) {
244+
(self.0.into_dimension(), Strides::Custom(self.1.strides.into_dimension()))
245+
}
246+
}

0 commit comments

Comments
 (0)