Skip to content

Commit 86f38a7

Browse files
authored
Merge pull request #20 from termoshtt/memory_layout
unify memory layout management
2 parents 8ffd77c + 43bc590 commit 86f38a7

File tree

10 files changed

+131
-223
lines changed

10 files changed

+131
-223
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,6 @@
55
# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
66
# More information here http://doc.crates.io/guide.html#cargotoml-vs-cargolock
77
Cargo.lock
8+
9+
# cargo fmt
10+
*.bk

rustfmt.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
max_width = 120
2+
use_try_shorthand = true

src/eigh.rs

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,20 @@
11
//! Implement eigenvalue decomposition of Hermite matrix
22
3-
use lapack::fortran::*;
3+
use lapack::c::*;
44
use num_traits::Zero;
55

66
use error::LapackError;
77

88
pub trait ImplEigh: Sized {
9-
fn eigh(n: usize, mut a: Vec<Self>) -> Result<(Vec<Self>, Vec<Self>), LapackError>;
9+
fn eigh(layout: Layout, n: usize, mut a: Vec<Self>) -> Result<(Vec<Self>, Vec<Self>), LapackError>;
1010
}
1111

1212
macro_rules! impl_eigh {
1313
($scalar:ty, $syev:path) => {
1414
impl ImplEigh for $scalar {
15-
fn eigh(n: usize, mut a: Vec<Self>) -> Result<(Vec<Self>, Vec<Self>), LapackError> {
15+
fn eigh(layout: Layout, n: usize, mut a: Vec<Self>) -> Result<(Vec<Self>, Vec<Self>), LapackError> {
1616
let mut w = vec![Self::zero(); n];
17-
let mut work = vec![Self::zero(); 4 * n];
18-
let mut info = 0;
19-
$syev(b'V',
20-
b'U',
21-
n as i32,
22-
&mut a,
23-
n as i32,
24-
&mut w,
25-
&mut work,
26-
4 * n as i32,
27-
&mut info);
17+
let info = $syev(layout, b'V', b'U', n as i32, &mut a, n as i32, &mut w);
2818
if info == 0 {
2919
Ok((w, a))
3020
} else {

src/error.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
33
use std::error;
44
use std::fmt;
5+
use ndarray::Ixs;
56

67
#[derive(Debug)]
78
pub struct LapackError {
@@ -44,17 +45,37 @@ impl error::Error for NotSquareError {
4445
}
4546
}
4647

48+
#[derive(Debug)]
49+
pub struct StrideError {
50+
pub s0: Ixs,
51+
pub s1: Ixs,
52+
}
53+
54+
impl fmt::Display for StrideError {
55+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
56+
write!(f, "invalid stride: s0={}, s1={}", self.s0, self.s1)
57+
}
58+
}
59+
60+
impl error::Error for StrideError {
61+
fn description(&self) -> &str {
62+
"invalid stride"
63+
}
64+
}
65+
4766
#[derive(Debug)]
4867
pub enum LinalgError {
4968
NotSquare(NotSquareError),
5069
Lapack(LapackError),
70+
Stride(StrideError),
5171
}
5272

5373
impl fmt::Display for LinalgError {
5474
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
5575
match *self {
5676
LinalgError::NotSquare(ref err) => err.fmt(f),
5777
LinalgError::Lapack(ref err) => err.fmt(f),
78+
LinalgError::Stride(ref err) => err.fmt(f),
5879
}
5980
}
6081
}
@@ -64,6 +85,7 @@ impl error::Error for LinalgError {
6485
match *self {
6586
LinalgError::NotSquare(ref err) => err.description(),
6687
LinalgError::Lapack(ref err) => err.description(),
88+
LinalgError::Stride(ref err) => err.description(),
6789
}
6890
}
6991
}
@@ -79,3 +101,9 @@ impl From<LapackError> for LinalgError {
79101
LinalgError::Lapack(err)
80102
}
81103
}
104+
105+
impl From<StrideError> for LinalgError {
106+
fn from(err: StrideError) -> LinalgError {
107+
LinalgError::Stride(err)
108+
}
109+
}

src/hermite.rs

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,20 @@ impl<A> HermiteMatrix for Array<A, Ix2>
2929
where A: ImplQR + ImplSVD + ImplNorm + ImplSolve + ImplEigh + ImplCholesky + LinalgScalar + Float + Debug
3030
{
3131
fn eigh(self) -> Result<(Self::Vector, Self), LinalgError> {
32-
try!(self.check_square());
32+
self.check_square()?;
33+
let layout = self.layout()?;
3334
let (rows, cols) = self.size();
34-
let (w, a) = try!(ImplEigh::eigh(rows, self.into_raw_vec()));
35+
let (w, a) = ImplEigh::eigh(layout, rows, self.into_raw_vec())?;
3536
let ea = Array::from_vec(w);
36-
let va = Array::from_vec(a).into_shape((rows, cols)).unwrap().reversed_axes();
37+
let va = match layout {
38+
Layout::ColumnMajor => Array::from_vec(a).into_shape((rows, cols)).unwrap().reversed_axes(),
39+
Layout::RowMajor => Array::from_vec(a).into_shape((rows, cols)).unwrap(),
40+
};
3741
Ok((ea, va))
3842
}
3943
fn ssqrt(self) -> Result<Self, LinalgError> {
4044
let (n, _) = self.size();
41-
let (e, v) = try!(self.eigh());
45+
let (e, v) = self.eigh()?;
4246
let mut res = Array::zeros((n, n));
4347
for i in 0..n {
4448
for j in 0..n {
@@ -48,11 +52,10 @@ impl<A> HermiteMatrix for Array<A, Ix2>
4852
Ok(v.dot(&res))
4953
}
5054
fn cholesky(self) -> Result<Self, LinalgError> {
51-
try!(self.check_square());
52-
println!("layout = {:?}", self.layout());
55+
self.check_square()?;
5356
let (n, _) = self.size();
54-
let layout = self.layout();
55-
let a = try!(ImplCholesky::cholesky(layout, n, self.into_raw_vec()));
57+
let layout = self.layout()?;
58+
let a = ImplCholesky::cholesky(layout, n, self.into_raw_vec())?;
5659
let mut c = match layout {
5760
Layout::RowMajor => Array::from_vec(a).into_shape((n, n)).unwrap(),
5861
Layout::ColumnMajor => Array::from_vec(a).into_shape((n, n)).unwrap().reversed_axes(),

src/matrix.rs

Lines changed: 31 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use ndarray::prelude::*;
66
use ndarray::LinalgScalar;
77
use lapack::c::Layout;
88

9-
use error::LapackError;
9+
use error::{LinalgError, StrideError};
1010
use qr::ImplQR;
1111
use svd::ImplSVD;
1212
use norm::ImplNorm;
@@ -20,19 +20,19 @@ pub trait Matrix: Sized {
2020
/// number of (rows, columns)
2121
fn size(&self) -> (usize, usize);
2222
/// Layout (C/Fortran) of matrix
23-
fn layout(&self) -> Layout;
23+
fn layout(&self) -> Result<Layout, StrideError>;
2424
/// Operator norm for L-1 norm
2525
fn norm_1(&self) -> Self::Scalar;
2626
/// Operator norm for L-inf norm
2727
fn norm_i(&self) -> Self::Scalar;
2828
/// Frobenius norm
2929
fn norm_f(&self) -> Self::Scalar;
3030
/// singular-value decomposition (SVD)
31-
fn svd(self) -> Result<(Self, Self::Vector, Self), LapackError>;
31+
fn svd(self) -> Result<(Self, Self::Vector, Self), LinalgError>;
3232
/// QR decomposition
33-
fn qr(self) -> Result<(Self, Self), LapackError>;
33+
fn qr(self) -> Result<(Self, Self), LinalgError>;
3434
/// LU decomposition
35-
fn lu(self) -> Result<(Self::Permutator, Self, Self), LapackError>;
35+
fn lu(self) -> Result<(Self::Permutator, Self, Self), LinalgError>;
3636
/// permutate matrix (inplace)
3737
fn permutate(&mut self, p: &Self::Permutator);
3838
/// permutate matrix (outplace)
@@ -52,12 +52,18 @@ impl<A> Matrix for Array<A, Ix2>
5252
fn size(&self) -> (usize, usize) {
5353
(self.rows(), self.cols())
5454
}
55-
fn layout(&self) -> Layout {
55+
fn layout(&self) -> Result<Layout, StrideError> {
5656
let strides = self.strides();
57+
if min(strides[0], strides[1]) != 1 {
58+
return Err(StrideError {
59+
s0: strides[0],
60+
s1: strides[1],
61+
});;
62+
}
5763
if strides[0] < strides[1] {
58-
Layout::ColumnMajor
64+
Ok(Layout::ColumnMajor)
5965
} else {
60-
Layout::RowMajor
66+
Ok(Layout::RowMajor)
6167
}
6268
}
6369
fn norm_1(&self) -> Self::Scalar {
@@ -82,35 +88,24 @@ impl<A> Matrix for Array<A, Ix2>
8288
let (m, n) = self.size();
8389
ImplNorm::norm_f(m, n, self.clone().into_raw_vec())
8490
}
85-
fn svd(self) -> Result<(Self, Self::Vector, Self), LapackError> {
86-
let strides = self.strides();
87-
let (m, n) = if strides[0] > strides[1] {
88-
self.size()
89-
} else {
90-
let (n, m) = self.size();
91-
(m, n)
92-
};
93-
let (u, s, vt) = try!(ImplSVD::svd(m, n, self.clone().into_raw_vec()));
91+
fn svd(self) -> Result<(Self, Self::Vector, Self), LinalgError> {
92+
let (n, m) = self.size();
93+
let layout = self.layout()?;
94+
let (u, s, vt) = ImplSVD::svd(layout, m, n, self.clone().into_raw_vec())?;
9495
let sv = Array::from_vec(s);
95-
if strides[0] > strides[1] {
96-
let ua = Array::from_vec(u).into_shape((n, n)).unwrap();
97-
let va = Array::from_vec(vt).into_shape((m, m)).unwrap();
98-
Ok((va, sv, ua))
99-
} else {
100-
let ua = Array::from_vec(u).into_shape((n, n)).unwrap().reversed_axes();
101-
let va = Array::from_vec(vt).into_shape((m, m)).unwrap().reversed_axes();
102-
Ok((ua, sv, va))
96+
let ua = Array::from_vec(u).into_shape((n, n)).unwrap();
97+
let va = Array::from_vec(vt).into_shape((m, m)).unwrap();
98+
match layout {
99+
Layout::RowMajor => Ok((ua, sv, va)),
100+
Layout::ColumnMajor => Ok((ua.reversed_axes(), sv, va.reversed_axes())),
103101
}
104102
}
105-
fn qr(self) -> Result<(Self, Self), LapackError> {
103+
fn qr(self) -> Result<(Self, Self), LinalgError> {
106104
let (n, m) = self.size();
107105
let strides = self.strides();
108106
let k = min(n, m);
109-
let (q, r) = if strides[0] < strides[1] {
110-
try!(ImplQR::qr(m, n, self.clone().into_raw_vec()))
111-
} else {
112-
try!(ImplQR::lq(n, m, self.clone().into_raw_vec()))
113-
};
107+
let layout = self.layout()?;
108+
let (q, r) = ImplQR::qr(layout, m, n, self.clone().into_raw_vec())?;
114109
let (qa, ra) = if strides[0] < strides[1] {
115110
(Array::from_vec(q).into_shape((m, n)).unwrap().reversed_axes(),
116111
Array::from_vec(r).into_shape((m, n)).unwrap().reversed_axes())
@@ -136,23 +131,14 @@ impl<A> Matrix for Array<A, Ix2>
136131
}
137132
Ok((qm, rm))
138133
}
139-
fn lu(self) -> Result<(Self::Permutator, Self, Self), LapackError> {
134+
fn lu(self) -> Result<(Self::Permutator, Self, Self), LinalgError> {
140135
let (n, m) = self.size();
141-
println!("n={}, m={}", n, m);
142136
let k = min(n, m);
143-
let (p, mut a) = match self.layout() {
144-
Layout::ColumnMajor => {
145-
println!("ColumnMajor");
146-
let (p, l) = ImplSolve::lu(self.layout(), n, m, self.clone().into_raw_vec())?;
147-
(p, Array::from_vec(l).into_shape((m, n)).unwrap().reversed_axes())
148-
}
149-
Layout::RowMajor => {
150-
println!("RowMajor");
151-
let (p, l) = ImplSolve::lu(self.layout(), n, m, self.clone().into_raw_vec())?;
152-
(p, Array::from_vec(l).into_shape((n, m)).unwrap())
153-
}
137+
let (p, l) = ImplSolve::lu(self.layout()?, n, m, self.clone().into_raw_vec())?;
138+
let mut a = match self.layout()? {
139+
Layout::ColumnMajor => Array::from_vec(l).into_shape((m, n)).unwrap().reversed_axes(),
140+
Layout::RowMajor => Array::from_vec(l).into_shape((n, m)).unwrap(),
154141
};
155-
println!("a (after LU) = \n{:?}", &a);
156142
let mut lm = Array::zeros((n, k));
157143
for ((i, j), val) in lm.indexed_iter_mut() {
158144
if i > j {
@@ -171,7 +157,6 @@ impl<A> Matrix for Array<A, Ix2>
171157
} else {
172158
a
173159
};
174-
println!("am = \n{:?}", am);
175160
Ok((p, lm, am))
176161
}
177162
fn permutate(&mut self, ipiv: &Self::Permutator) {

src/norm.rs

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
//! Implement Norms for matrices
22
3-
use lapack::fortran::*;
4-
use num_traits::Zero;
3+
use lapack::c::*;
54

65
pub trait ImplNorm: Sized {
76
fn norm_1(m: usize, n: usize, mut a: Vec<Self>) -> Self;
@@ -13,16 +12,13 @@ macro_rules! impl_norm {
1312
($scalar:ty, $lange:path) => {
1413
impl ImplNorm for $scalar {
1514
fn norm_1(m: usize, n: usize, mut a: Vec<Self>) -> Self {
16-
let mut work = Vec::<Self>::new();
17-
$lange(b'o', m as i32, n as i32, &mut a, m as i32, &mut work)
15+
$lange(Layout::ColumnMajor, b'o', m as i32, n as i32, &mut a, m as i32)
1816
}
1917
fn norm_i(m: usize, n: usize, mut a: Vec<Self>) -> Self {
20-
let mut work = vec![Self::zero(); m];
21-
$lange(b'i', m as i32, n as i32, &mut a, m as i32, &mut work)
18+
$lange(Layout::ColumnMajor, b'i', m as i32, n as i32, &mut a, m as i32)
2219
}
2320
fn norm_f(m: usize, n: usize, mut a: Vec<Self>) -> Self {
24-
let mut work = Vec::<Self>::new();
25-
$lange(b'f', m as i32, n as i32, &mut a, m as i32, &mut work)
21+
$lange(Layout::ColumnMajor, b'f', m as i32, n as i32, &mut a, m as i32)
2622
}
2723
}
2824
}} // end macro_rules

0 commit comments

Comments
 (0)