Skip to content

Commit 12851b7

Browse files
committed
Order constructor for arrays
1 parent 9544e0b commit 12851b7

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

src/shape_builder.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,30 @@ impl<T> ShapeArg for (T, Order) where T: IntoDimension {
222222
}
223223
}
224224

225+
impl<T> ShapeBuilder for (T, Order) where T: IntoDimension {
226+
type Dim = T::Dim;
227+
type Strides = T;
228+
fn into_shape(self) -> Shape<Self::Dim> {
229+
let strides = match self.1 {
230+
Order::C | Order::A => Strides::C,
231+
Order::F => Strides::F,
232+
};
233+
Shape {
234+
dim: self.0.into_dimension(),
235+
strides,
236+
}
237+
}
238+
fn f(self) -> Shape<Self::Dim> {
239+
self.set_f(true)
240+
}
241+
fn set_f(self, is_f: bool) -> Shape<Self::Dim> {
242+
self.into_shape().set_f(is_f)
243+
}
244+
fn strides(self, st: T) -> StrideShape<Self::Dim> {
245+
self.into_shape().strides(st.into_dimension())
246+
}
247+
}
248+
225249
/// Custom strides
226250
#[derive(Copy, Clone, Debug)]
227251
pub struct CustomStrides<D> { strides: D }

tests/array-construct.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
use defmac::defmac;
99
use ndarray::prelude::*;
1010
use ndarray::Zip;
11+
use ndarray::Order;
1112

1213
#[test]
1314
fn test_from_shape_fn() {
@@ -248,3 +249,13 @@ fn maybe_uninit_1() {
248249

249250
}
250251
}
252+
253+
#[test]
254+
fn test_order_constructor() {
255+
type Mat<D> = Array<f32, D>;
256+
let a = Mat::zeros(((3, 3), Order::C));
257+
let b = Mat::zeros(((3, 3), Order::F));
258+
let c = Mat::zeros(((3, 3), Order::A));
259+
assert_eq!(a.strides(), b.t().strides());
260+
assert_eq!(a.strides(), c.strides());
261+
}

0 commit comments

Comments
 (0)