Skip to content

Commit 9b9481e

Browse files
committed
Add layout module
1 parent 184d615 commit 9b9481e

File tree

2 files changed

+59
-0
lines changed

2 files changed

+59
-0
lines changed

src/layout.rs

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
2+
use ndarray::*;
3+
4+
use super::error::*;
5+
6+
pub enum Layout {
7+
C((usize, usize)),
8+
F((usize, usize)),
9+
}
10+
11+
impl Layout {
12+
pub fn size(&self) -> (usize, usize) {
13+
match self {
14+
&Layout::C(s) => s,
15+
&Layout::F(s) => s,
16+
}
17+
}
18+
}
19+
20+
pub trait AllocatedArray2D {
21+
type Scalar;
22+
fn layout(&self) -> LResult<Layout>;
23+
fn square_layout(&self) -> LResult<Layout>;
24+
fn as_allocated(&self) -> LResult<&[Self::Scalar]>;
25+
}
26+
27+
impl<A, S> AllocatedArray2D for ArrayBase<S, Ix2>
28+
where S: Data<Elem = A>
29+
{
30+
type Scalar = A;
31+
32+
fn layout(&self) -> LResult<Layout> {
33+
let strides = self.strides();
34+
if ::std::cmp::min(strides[0], strides[1]) != 1 {
35+
return Err(StrideError::new(strides[0], strides[1]).into());
36+
}
37+
if strides[0] < strides[1] {
38+
Ok(Layout::C((self.rows(), self.cols())))
39+
} else {
40+
Ok(Layout::F((self.rows(), self.cols())))
41+
}
42+
}
43+
44+
fn square_layout(&self) -> LResult<Layout> {
45+
let l = self.layout()?;
46+
let (n, m) = l.size();
47+
if n == m {
48+
Ok(l)
49+
} else {
50+
Err(NotSquareError::new(n, m).into())
51+
}
52+
}
53+
54+
fn as_allocated(&self) -> LResult<&[A]> {
55+
let slice = self.as_slice_memory_order().ok_or(MemoryContError::new())?;
56+
Ok(slice)
57+
}
58+
}

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ extern crate derive_new;
4444

4545
pub mod impls;
4646
pub mod error;
47+
pub mod layout;
4748

4849
pub mod vector;
4950
pub mod matrix;

0 commit comments

Comments
 (0)