Skip to content

Commit 87da884

Browse files
committed
shape: Add trait ShapeArg
1 parent 68eb5c5 commit 87da884

File tree

1 file changed

+62
-2
lines changed

1 file changed

+62
-2
lines changed

src/shape_builder.rs

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ pub struct Shape<D> {
1313
}
1414

1515
#[derive(Copy, Clone, Debug)]
16-
pub(crate) enum Contiguous { }
16+
pub enum Contiguous { }
1717

1818
impl<D> Shape<D> {
1919
pub(crate) fn is_c(&self) -> bool {
@@ -31,7 +31,7 @@ pub struct StrideShape<D> {
3131

3232
/// Stride description
3333
#[derive(Copy, Clone, Debug)]
34-
pub(crate) enum Strides<D> {
34+
pub enum Strides<D> {
3535
/// Row-major ("C"-order)
3636
C,
3737
/// Column-major ("F"-order)
@@ -168,3 +168,63 @@ where
168168
self.dim.size()
169169
}
170170
}
171+
172+
173+
use crate::order::Order;
174+
175+
pub trait ShapeArg {
176+
type Dim: Dimension;
177+
type StrideType;
178+
179+
fn into_shape_and_order(self, default: Order) -> (Self::Dim, Order);
180+
fn into_shape_and_strides(self, default: Order) -> (Self::Dim, Strides<Self::StrideType>);
181+
}
182+
183+
impl<T> ShapeArg for T where T: IntoDimension {
184+
type Dim = T::Dim;
185+
type StrideType = Contiguous;
186+
187+
fn into_shape_and_order(self, default: Order) -> (Self::Dim, Order) {
188+
(self.into_dimension(), default)
189+
}
190+
191+
fn into_shape_and_strides(self, _default: Order) -> (Self::Dim, Strides<Contiguous>) {
192+
unimplemented!()
193+
}
194+
}
195+
196+
impl<T> ShapeArg for (T, Order) where T: IntoDimension {
197+
type Dim = T::Dim;
198+
type StrideType = Contiguous;
199+
200+
fn into_shape_and_order(self, _default: Order) -> (Self::Dim, Order) {
201+
(self.0.into_dimension(), self.1)
202+
}
203+
204+
fn into_shape_and_strides(self, _default: Order) -> (Self::Dim, Strides<Contiguous>) {
205+
unimplemented!()
206+
}
207+
}
208+
209+
/// Custom strides
210+
#[derive(Copy, Clone, Debug)]
211+
pub struct CustomStrides<D> { strides: D }
212+
213+
// newtype constructor without public field
214+
#[allow(non_snake_case)]
215+
pub fn CustomStrides<T>(strides: T) -> CustomStrides<T> {
216+
CustomStrides { strides }
217+
}
218+
219+
impl<T> ShapeArg for (T, CustomStrides<T>) where T: IntoDimension {
220+
type Dim = T::Dim;
221+
type StrideType = T::Dim;
222+
223+
fn into_shape_and_order(self, _default: Order) -> (Self::Dim, Order) {
224+
(self.0.into_dimension(), _default)
225+
}
226+
227+
fn into_shape_and_strides(self, _default: Order) -> (Self::Dim, Strides<T::Dim>) {
228+
(self.0.into_dimension(), Strides::Custom(self.1.strides.into_dimension()))
229+
}
230+
}

0 commit comments

Comments
 (0)