Skip to content

Commit 81f08e3

Browse files
committed
Use struct instead of tuple for MatrixLayout::{C, F}
Fix for layout change
1 parent cb4f764 commit 81f08e3

File tree

8 files changed

+45
-39
lines changed

8 files changed

+45
-39
lines changed

lax/src/layout.rs

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,40 +5,40 @@ pub type LEN = i32;
55
pub type Col = i32;
66
pub type Row = i32;
77

8-
#[derive(Debug, Clone, Copy, PartialEq)]
8+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
99
pub enum MatrixLayout {
10-
C((Row, LDA)),
11-
F((Col, LDA)),
10+
C { row: i32, lda: i32 },
11+
F { col: i32, lda: i32 },
1212
}
1313

1414
impl MatrixLayout {
1515
pub fn size(&self) -> (Row, Col) {
1616
match *self {
17-
MatrixLayout::C((row, lda)) => (row, lda),
18-
MatrixLayout::F((col, lda)) => (lda, col),
17+
MatrixLayout::C { row, lda } => (row, lda),
18+
MatrixLayout::F { col, lda } => (lda, col),
1919
}
2020
}
2121

2222
pub fn resized(&self, row: Row, col: Col) -> MatrixLayout {
2323
match *self {
24-
MatrixLayout::C(_) => MatrixLayout::C((row, col)),
25-
MatrixLayout::F(_) => MatrixLayout::F((col, row)),
24+
MatrixLayout::C { .. } => MatrixLayout::C { row, lda: col },
25+
MatrixLayout::F { .. } => MatrixLayout::F { col, lda: row },
2626
}
2727
}
2828

2929
pub fn lda(&self) -> LDA {
3030
std::cmp::max(
3131
1,
3232
match *self {
33-
MatrixLayout::C((_, lda)) | MatrixLayout::F((_, lda)) => lda,
33+
MatrixLayout::C { lda, .. } | MatrixLayout::F { lda, .. } => lda,
3434
},
3535
)
3636
}
3737

3838
pub fn len(&self) -> LEN {
3939
match *self {
40-
MatrixLayout::C((row, _)) => row,
41-
MatrixLayout::F((col, _)) => col,
40+
MatrixLayout::C { row, .. } => row,
41+
MatrixLayout::F { col, .. } => col,
4242
}
4343
}
4444

@@ -48,8 +48,8 @@ impl MatrixLayout {
4848

4949
pub fn lapacke_layout(&self) -> lapacke::Layout {
5050
match *self {
51-
MatrixLayout::C(_) => lapacke::Layout::RowMajor,
52-
MatrixLayout::F(_) => lapacke::Layout::ColumnMajor,
51+
MatrixLayout::C { .. } => lapacke::Layout::RowMajor,
52+
MatrixLayout::F { .. } => lapacke::Layout::ColumnMajor,
5353
}
5454
}
5555

@@ -59,8 +59,8 @@ impl MatrixLayout {
5959

6060
pub fn toggle_order(&self) -> Self {
6161
match *self {
62-
MatrixLayout::C((row, col)) => MatrixLayout::F((col, row)),
63-
MatrixLayout::F((col, row)) => MatrixLayout::C((row, col)),
62+
MatrixLayout::C { row, lda } => MatrixLayout::F { lda: row, col: lda },
63+
MatrixLayout::F { col, lda } => MatrixLayout::C { row: lda, lda: col },
6464
}
6565
}
6666
}

lax/src/opnorm.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ macro_rules! impl_opnorm {
1515
impl OperatorNorm_ for $scalar {
1616
unsafe fn opnorm(t: NormType, l: MatrixLayout, a: &[Self]) -> Self::Real {
1717
match l {
18-
MatrixLayout::F((col, lda)) => $lange(cm, t as u8, lda, col, a, lda),
19-
MatrixLayout::C((row, lda)) => {
18+
MatrixLayout::F { col, lda } => $lange(cm, t as u8, lda, col, a, lda),
19+
MatrixLayout::C { row, lda } => {
2020
$lange(cm, t.transpose() as u8, lda, row, a, lda)
2121
}
2222
}

lax/src/solveh.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ macro_rules! impl_solveh {
5858
let (n, _) = l.size();
5959
let nrhs = 1;
6060
let ldb = match l {
61-
MatrixLayout::C(_) => 1,
62-
MatrixLayout::F(_) => n,
61+
MatrixLayout::C { .. } => 1,
62+
MatrixLayout::F { .. } => n,
6363
};
6464
$trs(
6565
l.lapacke_layout(),

ndarray-linalg/src/convert.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,11 @@ where
3636
S: DataOwned<Elem = A>,
3737
{
3838
match l {
39-
MatrixLayout::C((row, col)) => {
40-
Ok(ArrayBase::from_shape_vec((row as usize, col as usize), a)?)
39+
MatrixLayout::C { row, lda } => {
40+
Ok(ArrayBase::from_shape_vec((row as usize, lda as usize), a)?)
4141
}
42-
MatrixLayout::F((col, row)) => Ok(ArrayBase::from_shape_vec(
43-
(row as usize, col as usize).f(),
42+
MatrixLayout::F { col, lda } => Ok(ArrayBase::from_shape_vec(
43+
(lda as usize, col as usize).f(),
4444
a,
4545
)?),
4646
}
@@ -52,11 +52,11 @@ where
5252
S: DataOwned<Elem = A>,
5353
{
5454
match l {
55-
MatrixLayout::C((row, col)) => unsafe {
56-
ArrayBase::uninitialized((row as usize, col as usize))
55+
MatrixLayout::C { row, lda } => unsafe {
56+
ArrayBase::uninitialized((row as usize, lda as usize))
5757
},
58-
MatrixLayout::F((col, row)) => unsafe {
59-
ArrayBase::uninitialized((row as usize, col as usize).f())
58+
MatrixLayout::F { col, lda } => unsafe {
59+
ArrayBase::uninitialized((lda as usize, col as usize).f())
6060
},
6161
}
6262
}

ndarray-linalg/src/eigh.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@ where
9696
let layout = self.square_layout()?;
9797
// XXX Force layout to be Fortran (see #146)
9898
match layout {
99-
MatrixLayout::C(_) => self.swap_axes(0, 1),
100-
MatrixLayout::F(_) => {}
99+
MatrixLayout::C { .. } => self.swap_axes(0, 1),
100+
MatrixLayout::F { .. } => {}
101101
}
102102
let s = unsafe { A::eigh(true, self.square_layout()?, uplo, self.as_allocated_mut()?)? };
103103
Ok((ArrayBase::from(s), self))
@@ -116,14 +116,14 @@ where
116116
let layout = self.0.square_layout()?;
117117
// XXX Force layout to be Fortran (see #146)
118118
match layout {
119-
MatrixLayout::C(_) => self.0.swap_axes(0, 1),
120-
MatrixLayout::F(_) => {}
119+
MatrixLayout::C { .. } => self.0.swap_axes(0, 1),
120+
MatrixLayout::F { .. } => {}
121121
}
122122

123123
let layout = self.1.square_layout()?;
124124
match layout {
125-
MatrixLayout::C(_) => self.1.swap_axes(0, 1),
126-
MatrixLayout::F(_) => {}
125+
MatrixLayout::C { .. } => self.1.swap_axes(0, 1),
126+
MatrixLayout::F { .. } => {}
127127
}
128128

129129
let s = unsafe {

ndarray-linalg/src/layout.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,16 @@ where
2828
let shape = self.shape();
2929
let strides = self.strides();
3030
if shape[0] == strides[1] as usize {
31-
return Ok(MatrixLayout::F((self.ncols() as i32, self.nrows() as i32)));
31+
return Ok(MatrixLayout::F {
32+
col: self.ncols() as i32,
33+
lda: self.nrows() as i32,
34+
});
3235
}
3336
if shape[1] == strides[0] as usize {
34-
return Ok(MatrixLayout::C((self.nrows() as i32, self.ncols() as i32)));
37+
return Ok(MatrixLayout::C {
38+
row: self.nrows() as i32,
39+
lda: self.ncols() as i32,
40+
});
3541
}
3642
Err(LinalgError::InvalidStride {
3743
s0: strides[0],

ndarray-linalg/src/least_squares.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -735,7 +735,7 @@ mod tests {
735735
fn test_incompatible_shape_error_on_mismatching_layout() {
736736
let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
737737
let b = array![[1.], [2.]].t().to_owned();
738-
assert_eq!(b.layout().unwrap(), MatrixLayout::F((2, 1)));
738+
assert_eq!(b.layout().unwrap(), MatrixLayout::F { col: 2, lda: 1 });
739739

740740
let res = a.least_squares(&b);
741741
match res {

ndarray-linalg/tests/layout.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,26 +6,26 @@ use ndarray_linalg::*;
66
fn layout_c_3x1() {
77
let a: Array2<f64> = Array::zeros((3, 1));
88
println!("a = {:?}", &a);
9-
assert_eq!(a.layout().unwrap(), MatrixLayout::C((3, 1)));
9+
assert_eq!(a.layout().unwrap(), MatrixLayout::C { row: 3, lda: 1 });
1010
}
1111

1212
#[test]
1313
fn layout_f_3x1() {
1414
let a: Array2<f64> = Array::zeros((3, 1).f());
1515
println!("a = {:?}", &a);
16-
assert_eq!(a.layout().unwrap(), MatrixLayout::F((1, 3)));
16+
assert_eq!(a.layout().unwrap(), MatrixLayout::F { col: 1, lda: 3 });
1717
}
1818

1919
#[test]
2020
fn layout_c_3x2() {
2121
let a: Array2<f64> = Array::zeros((3, 2));
2222
println!("a = {:?}", &a);
23-
assert_eq!(a.layout().unwrap(), MatrixLayout::C((3, 2)));
23+
assert_eq!(a.layout().unwrap(), MatrixLayout::C { row: 3, lda: 2 });
2424
}
2525

2626
#[test]
2727
fn layout_f_3x2() {
2828
let a: Array2<f64> = Array::zeros((3, 2).f());
2929
println!("a = {:?}", &a);
30-
assert_eq!(a.layout().unwrap(), MatrixLayout::F((2, 3)));
30+
assert_eq!(a.layout().unwrap(), MatrixLayout::F { col: 2, lda: 3 });
3131
}

0 commit comments

Comments
 (0)