Skip to content

Commit e53a328

Browse files
committed
Revise layout func
1 parent f50a6b8 commit e53a328

File tree

3 files changed

+43
-7
lines changed

3 files changed

+43
-7
lines changed

src/layout.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ pub type LEN = i32;
99
pub type Col = i32;
1010
pub type Row = i32;
1111

12-
#[derive(Debug, Clone, Copy)]
12+
#[derive(Debug, Clone, Copy, PartialEq)]
1313
pub enum Layout {
1414
C((Row, LDA)),
1515
F((Col, LDA)),
@@ -87,15 +87,15 @@ impl<A, S> AllocatedArray for ArrayBase<S, Ix2>
8787
type Elem = A;
8888

8989
fn layout(&self) -> Result<Layout> {
90+
let shape = self.shape();
9091
let strides = self.strides();
91-
if ::std::cmp::min(strides[0], strides[1]) != 1 {
92-
return Err(StrideError::new(strides[0], strides[1]).into());
92+
if shape[0] == strides[1] as usize {
93+
return Ok(Layout::F((self.cols() as i32, self.rows() as i32)));
9394
}
94-
if strides[0] > strides[1] {
95-
Ok(Layout::C((self.rows() as i32, self.cols() as i32)))
96-
} else {
97-
Ok(Layout::F((self.cols() as i32, self.rows() as i32)))
95+
if shape[1] == strides[0] as usize {
96+
return Ok(Layout::C((self.rows() as i32, self.cols() as i32)));
9897
}
98+
Err(StrideError::new(strides[0], strides[1]).into())
9999
}
100100

101101
fn square_layout(&self) -> Result<Layout> {

src/prelude.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
pub use assert::*;
33
pub use generate::*;
44
pub use types::*;
5+
pub use layout::*;
56

67
pub use cholesky::*;
78
pub use eigh::*;

tests/layout.rs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
2+
extern crate ndarray;
3+
extern crate ndarray_linalg;
4+
5+
use ndarray::*;
6+
use ndarray_linalg::prelude::*;
7+
use ndarray_linalg::layout::Layout;
8+
9+
#[test]
10+
fn layout_c_3x1() {
11+
let a: Array2<f64> = Array::zeros((3, 1));
12+
println!("a = {:?}", &a);
13+
assert_eq!(a.layout().unwrap(), Layout::C((3, 1)));
14+
}
15+
16+
#[test]
17+
fn layout_f_3x1() {
18+
let a: Array2<f64> = Array::zeros((3, 1).f());
19+
println!("a = {:?}", &a);
20+
assert_eq!(a.layout().unwrap(), Layout::F((1, 3)));
21+
}
22+
23+
#[test]
24+
fn layout_c_3x2() {
25+
let a: Array2<f64> = Array::zeros((3, 2));
26+
println!("a = {:?}", &a);
27+
assert_eq!(a.layout().unwrap(), Layout::C((3, 2)));
28+
}
29+
30+
#[test]
31+
fn layout_f_3x2() {
32+
let a: Array2<f64> = Array::zeros((3, 2).f());
33+
println!("a = {:?}", &a);
34+
assert_eq!(a.layout().unwrap(), Layout::F((2, 3)));
35+
}

0 commit comments

Comments
 (0)