Skip to content

Commit 2e86888

Browse files
committed
CustomStrides array constructor
1 parent 12851b7 commit 2e86888

File tree

3 files changed

+29
-2
lines changed

3 files changed

+29
-2
lines changed

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ pub use crate::stacking::{concatenate, stack, stack_new_axis};
163163

164164
pub use crate::math_cell::MathCell;
165165
pub use crate::impl_views::IndexLonger;
166-
pub use crate::shape_builder::{Shape, ShapeBuilder, StrideShape};
166+
pub use crate::shape_builder::{Shape, ShapeBuilder, StrideShape, CustomStrides};
167167

168168
#[macro_use]
169169
mod macro_utils;

src/shape_builder.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,3 +268,17 @@ impl<T> ShapeArg for (T, CustomStrides<T>) where T: IntoDimension {
268268
(self.0.into_dimension(), Strides::Custom(self.1.strides.into_dimension()))
269269
}
270270
}
271+
272+
impl<T, D> From<(T, CustomStrides<T>)> for StrideShape<D>
273+
where
274+
D: Dimension,
275+
T: IntoDimension<Dim = D>,
276+
{
277+
fn from(value: (T, CustomStrides<T>)) -> Self {
278+
let shape = value.0.into_shape();
279+
StrideShape {
280+
strides: Strides::Custom(value.1.strides.into_dimension()),
281+
dim: shape.dim,
282+
}
283+
}
284+
}

tests/array-construct.rs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
use defmac::defmac;
99
use ndarray::prelude::*;
1010
use ndarray::Zip;
11-
use ndarray::Order;
11+
use ndarray::{Order, CustomStrides};
1212

1313
#[test]
1414
fn test_from_shape_fn() {
@@ -259,3 +259,16 @@ fn test_order_constructor() {
259259
assert_eq!(a.strides(), b.t().strides());
260260
assert_eq!(a.strides(), c.strides());
261261
}
262+
263+
#[test]
264+
fn test_custom_strides() {
265+
type Mat<D> = Array<f32, D>;
266+
let v = vec![0.; 3 * 3 * 2];
267+
let a = Mat::from_shape_vec(((3, 3), CustomStrides((3, 1))), v.clone()).unwrap();
268+
let b = Mat::from_shape_vec(((3, 3), CustomStrides((1, 6))), v.clone()).unwrap();
269+
let c = Mat::from_shape_vec(((3, 3), CustomStrides((2, 6))), v.clone()).unwrap();
270+
271+
assert_eq!(a.strides(), &[3, 1]);
272+
assert_eq!(b.strides(), &[1, 6]);
273+
assert_eq!(c.strides(), &[2, 6]);
274+
}

0 commit comments

Comments
 (0)