Skip to content

Commit f7aac04

Browse files
authored
Merge pull request #211 from rust-ndarray/rewrite-layout
Use named struct for MatrixLayout
2 parents cb4f764 + 9a3e9ce commit f7aac04

File tree

8 files changed

+87
-49
lines changed

8 files changed

+87
-49
lines changed

lax/src/layout.rs

Lines changed: 55 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,76 @@
11
//! Memory layout of matrices
2+
//!
3+
//! Different from ndarray format which consists of shape and strides,
4+
//! matrix format in LAPACK consists of row or column size and leading dimension.
5+
//!
6+
//! ndarray format and stride
7+
//! --------------------------
8+
//!
9+
//! Let us consider 3-dimensional array for explaining ndarray structure.
10+
//! The address of `(x,y,z)`-element in ndarray satisfies following relation:
11+
//!
12+
//! ```text
13+
//! shape = [Nx, Ny, Nz]
14+
//! where Nx > 0, Ny > 0, Nz > 0
15+
//! stride = [Sx, Sy, Sz]
16+
//!
17+
//! &data[(x, y, z)] = &data[(0, 0, 0)] + Sx*x + Sy*y + Sz*z
18+
//! for x < Nx, y < Ny, z < Nz
19+
//! ```
20+
//!
21+
//! The array is called
22+
//!
23+
//! - C-continuous if `[Sx, Sy, Sz] = [Nz*Ny, Nz, 1]`
24+
//! - F(Fortran)-continuous if `[Sx, Sy, Sz] = [1, Nx, Nx*Ny]`
25+
//!
26+
//! Strides of ndarray `[Sx, Sy, Sz]` take arbitrary value,
27+
//! e.g. it can be non-ordered `Sy > Sx > Sz`, or can be negative `Sx < 0`.
28+
//! If the minimum of `[Sx, Sy, Sz]` equals to `1`,
29+
//! the value of elements fills `data` memory region and called "continuous".
30+
//! Non-continuous ndarray is useful to get sub-array without copying data.
31+
//!
32+
//! Matrix layout for LAPACK
33+
//! -------------------------
34+
//!
35+
//! LAPACK interface focuses on the linear algebra operations for F-continuous 2-dimensional array.
36+
//! Under this restriction, stride becomes far simpler; we only have to consider the case `[1, S]`
37+
//! This `S` for a matrix `A` is called "leading dimension of the array A" in LAPACK document, and denoted by `lda`.
38+
//!
239
3-
pub type LDA = i32;
4-
pub type LEN = i32;
5-
pub type Col = i32;
6-
pub type Row = i32;
7-
8-
#[derive(Debug, Clone, Copy, PartialEq)]
40+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
941
pub enum MatrixLayout {
10-
C((Row, LDA)),
11-
F((Col, LDA)),
42+
C { row: i32, lda: i32 },
43+
F { col: i32, lda: i32 },
1244
}
1345

1446
impl MatrixLayout {
15-
pub fn size(&self) -> (Row, Col) {
47+
pub fn size(&self) -> (i32, i32) {
1648
match *self {
17-
MatrixLayout::C((row, lda)) => (row, lda),
18-
MatrixLayout::F((col, lda)) => (lda, col),
49+
MatrixLayout::C { row, lda } => (row, lda),
50+
MatrixLayout::F { col, lda } => (lda, col),
1951
}
2052
}
2153

22-
pub fn resized(&self, row: Row, col: Col) -> MatrixLayout {
54+
pub fn resized(&self, row: i32, col: i32) -> MatrixLayout {
2355
match *self {
24-
MatrixLayout::C(_) => MatrixLayout::C((row, col)),
25-
MatrixLayout::F(_) => MatrixLayout::F((col, row)),
56+
MatrixLayout::C { .. } => MatrixLayout::C { row, lda: col },
57+
MatrixLayout::F { .. } => MatrixLayout::F { col, lda: row },
2658
}
2759
}
2860

29-
pub fn lda(&self) -> LDA {
61+
pub fn lda(&self) -> i32 {
3062
std::cmp::max(
3163
1,
3264
match *self {
33-
MatrixLayout::C((_, lda)) | MatrixLayout::F((_, lda)) => lda,
65+
MatrixLayout::C { lda, .. } | MatrixLayout::F { lda, .. } => lda,
3466
},
3567
)
3668
}
3769

38-
pub fn len(&self) -> LEN {
70+
pub fn len(&self) -> i32 {
3971
match *self {
40-
MatrixLayout::C((row, _)) => row,
41-
MatrixLayout::F((col, _)) => col,
72+
MatrixLayout::C { row, .. } => row,
73+
MatrixLayout::F { col, .. } => col,
4274
}
4375
}
4476

@@ -48,8 +80,8 @@ impl MatrixLayout {
4880

4981
pub fn lapacke_layout(&self) -> lapacke::Layout {
5082
match *self {
51-
MatrixLayout::C(_) => lapacke::Layout::RowMajor,
52-
MatrixLayout::F(_) => lapacke::Layout::ColumnMajor,
83+
MatrixLayout::C { .. } => lapacke::Layout::RowMajor,
84+
MatrixLayout::F { .. } => lapacke::Layout::ColumnMajor,
5385
}
5486
}
5587

@@ -59,8 +91,8 @@ impl MatrixLayout {
5991

6092
pub fn toggle_order(&self) -> Self {
6193
match *self {
62-
MatrixLayout::C((row, col)) => MatrixLayout::F((col, row)),
63-
MatrixLayout::F((col, row)) => MatrixLayout::C((row, col)),
94+
MatrixLayout::C { row, lda } => MatrixLayout::F { lda: row, col: lda },
95+
MatrixLayout::F { col, lda } => MatrixLayout::C { row: lda, lda: col },
6496
}
6597
}
6698
}

lax/src/opnorm.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ macro_rules! impl_opnorm {
1515
impl OperatorNorm_ for $scalar {
1616
unsafe fn opnorm(t: NormType, l: MatrixLayout, a: &[Self]) -> Self::Real {
1717
match l {
18-
MatrixLayout::F((col, lda)) => $lange(cm, t as u8, lda, col, a, lda),
19-
MatrixLayout::C((row, lda)) => {
18+
MatrixLayout::F { col, lda } => $lange(cm, t as u8, lda, col, a, lda),
19+
MatrixLayout::C { row, lda } => {
2020
$lange(cm, t.transpose() as u8, lda, row, a, lda)
2121
}
2222
}

lax/src/solveh.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ macro_rules! impl_solveh {
5858
let (n, _) = l.size();
5959
let nrhs = 1;
6060
let ldb = match l {
61-
MatrixLayout::C(_) => 1,
62-
MatrixLayout::F(_) => n,
61+
MatrixLayout::C { .. } => 1,
62+
MatrixLayout::F { .. } => n,
6363
};
6464
$trs(
6565
l.lapacke_layout(),

ndarray-linalg/src/convert.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,11 @@ where
3636
S: DataOwned<Elem = A>,
3737
{
3838
match l {
39-
MatrixLayout::C((row, col)) => {
40-
Ok(ArrayBase::from_shape_vec((row as usize, col as usize), a)?)
39+
MatrixLayout::C { row, lda } => {
40+
Ok(ArrayBase::from_shape_vec((row as usize, lda as usize), a)?)
4141
}
42-
MatrixLayout::F((col, row)) => Ok(ArrayBase::from_shape_vec(
43-
(row as usize, col as usize).f(),
42+
MatrixLayout::F { col, lda } => Ok(ArrayBase::from_shape_vec(
43+
(lda as usize, col as usize).f(),
4444
a,
4545
)?),
4646
}
@@ -52,11 +52,11 @@ where
5252
S: DataOwned<Elem = A>,
5353
{
5454
match l {
55-
MatrixLayout::C((row, col)) => unsafe {
56-
ArrayBase::uninitialized((row as usize, col as usize))
55+
MatrixLayout::C { row, lda } => unsafe {
56+
ArrayBase::uninitialized((row as usize, lda as usize))
5757
},
58-
MatrixLayout::F((col, row)) => unsafe {
59-
ArrayBase::uninitialized((row as usize, col as usize).f())
58+
MatrixLayout::F { col, lda } => unsafe {
59+
ArrayBase::uninitialized((lda as usize, col as usize).f())
6060
},
6161
}
6262
}

ndarray-linalg/src/eigh.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@ where
9696
let layout = self.square_layout()?;
9797
// XXX Force layout to be Fortran (see #146)
9898
match layout {
99-
MatrixLayout::C(_) => self.swap_axes(0, 1),
100-
MatrixLayout::F(_) => {}
99+
MatrixLayout::C { .. } => self.swap_axes(0, 1),
100+
MatrixLayout::F { .. } => {}
101101
}
102102
let s = unsafe { A::eigh(true, self.square_layout()?, uplo, self.as_allocated_mut()?)? };
103103
Ok((ArrayBase::from(s), self))
@@ -116,14 +116,14 @@ where
116116
let layout = self.0.square_layout()?;
117117
// XXX Force layout to be Fortran (see #146)
118118
match layout {
119-
MatrixLayout::C(_) => self.0.swap_axes(0, 1),
120-
MatrixLayout::F(_) => {}
119+
MatrixLayout::C { .. } => self.0.swap_axes(0, 1),
120+
MatrixLayout::F { .. } => {}
121121
}
122122

123123
let layout = self.1.square_layout()?;
124124
match layout {
125-
MatrixLayout::C(_) => self.1.swap_axes(0, 1),
126-
MatrixLayout::F(_) => {}
125+
MatrixLayout::C { .. } => self.1.swap_axes(0, 1),
126+
MatrixLayout::F { .. } => {}
127127
}
128128

129129
let s = unsafe {

ndarray-linalg/src/layout.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//! Memory layout of matrices
1+
//! Convert ndarray into LAPACK-compatible matrix format
22
33
use super::error::*;
44
use ndarray::*;
@@ -28,10 +28,16 @@ where
2828
let shape = self.shape();
2929
let strides = self.strides();
3030
if shape[0] == strides[1] as usize {
31-
return Ok(MatrixLayout::F((self.ncols() as i32, self.nrows() as i32)));
31+
return Ok(MatrixLayout::F {
32+
col: self.ncols() as i32,
33+
lda: self.nrows() as i32,
34+
});
3235
}
3336
if shape[1] == strides[0] as usize {
34-
return Ok(MatrixLayout::C((self.nrows() as i32, self.ncols() as i32)));
37+
return Ok(MatrixLayout::C {
38+
row: self.nrows() as i32,
39+
lda: self.ncols() as i32,
40+
});
3541
}
3642
Err(LinalgError::InvalidStride {
3743
s0: strides[0],

ndarray-linalg/src/least_squares.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -735,7 +735,7 @@ mod tests {
735735
fn test_incompatible_shape_error_on_mismatching_layout() {
736736
let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
737737
let b = array![[1.], [2.]].t().to_owned();
738-
assert_eq!(b.layout().unwrap(), MatrixLayout::F((2, 1)));
738+
assert_eq!(b.layout().unwrap(), MatrixLayout::F { col: 2, lda: 1 });
739739

740740
let res = a.least_squares(&b);
741741
match res {

ndarray-linalg/tests/layout.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,26 +6,26 @@ use ndarray_linalg::*;
66
fn layout_c_3x1() {
77
let a: Array2<f64> = Array::zeros((3, 1));
88
println!("a = {:?}", &a);
9-
assert_eq!(a.layout().unwrap(), MatrixLayout::C((3, 1)));
9+
assert_eq!(a.layout().unwrap(), MatrixLayout::C { row: 3, lda: 1 });
1010
}
1111

1212
#[test]
1313
fn layout_f_3x1() {
1414
let a: Array2<f64> = Array::zeros((3, 1).f());
1515
println!("a = {:?}", &a);
16-
assert_eq!(a.layout().unwrap(), MatrixLayout::F((1, 3)));
16+
assert_eq!(a.layout().unwrap(), MatrixLayout::F { col: 1, lda: 3 });
1717
}
1818

1919
#[test]
2020
fn layout_c_3x2() {
2121
let a: Array2<f64> = Array::zeros((3, 2));
2222
println!("a = {:?}", &a);
23-
assert_eq!(a.layout().unwrap(), MatrixLayout::C((3, 2)));
23+
assert_eq!(a.layout().unwrap(), MatrixLayout::C { row: 3, lda: 2 });
2424
}
2525

2626
#[test]
2727
fn layout_f_3x2() {
2828
let a: Array2<f64> = Array::zeros((3, 2).f());
2929
println!("a = {:?}", &a);
30-
assert_eq!(a.layout().unwrap(), MatrixLayout::F((2, 3)));
30+
assert_eq!(a.layout().unwrap(), MatrixLayout::F { col: 2, lda: 3 });
3131
}

0 commit comments

Comments
 (0)