Skip to content

Commit 31da57b

Browse files
committed
Use Layout in eigh
1 parent b12d68e commit 31da57b

File tree

2 files changed

+10
-16
lines changed

2 files changed

+10
-16
lines changed

src/eigh.rs

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,20 @@
11
//! Implement eigenvalue decomposition of Hermite matrix
22
3-
use lapack::fortran::*;
3+
use lapack::c::*;
44
use num_traits::Zero;
55

66
use error::LapackError;
77

88
pub trait ImplEigh: Sized {
9-
fn eigh(n: usize, mut a: Vec<Self>) -> Result<(Vec<Self>, Vec<Self>), LapackError>;
9+
fn eigh(layout: Layout, n: usize, mut a: Vec<Self>) -> Result<(Vec<Self>, Vec<Self>), LapackError>;
1010
}
1111

1212
macro_rules! impl_eigh {
1313
($scalar:ty, $syev:path) => {
1414
impl ImplEigh for $scalar {
15-
fn eigh(n: usize, mut a: Vec<Self>) -> Result<(Vec<Self>, Vec<Self>), LapackError> {
15+
fn eigh(layout: Layout, n: usize, mut a: Vec<Self>) -> Result<(Vec<Self>, Vec<Self>), LapackError> {
1616
let mut w = vec![Self::zero(); n];
17-
let mut work = vec![Self::zero(); 4 * n];
18-
let mut info = 0;
19-
$syev(b'V',
20-
b'U',
21-
n as i32,
22-
&mut a,
23-
n as i32,
24-
&mut w,
25-
&mut work,
26-
4 * n as i32,
27-
&mut info);
17+
let info = $syev(layout, b'V', b'U', n as i32, &mut a, n as i32, &mut w);
2818
if info == 0 {
2919
Ok((w, a))
3020
} else {

src/hermite.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,14 @@ impl<A> HermiteMatrix for Array<A, Ix2>
3030
{
3131
fn eigh(self) -> Result<(Self::Vector, Self), LinalgError> {
3232
try!(self.check_square());
33+
let layout = self.layout()?;
3334
let (rows, cols) = self.size();
34-
let (w, a) = try!(ImplEigh::eigh(rows, self.into_raw_vec()));
35+
let (w, a) = ImplEigh::eigh(layout, rows, self.into_raw_vec())?;
3536
let ea = Array::from_vec(w);
36-
let va = Array::from_vec(a).into_shape((rows, cols)).unwrap().reversed_axes();
37+
let va = match layout {
38+
Layout::ColumnMajor => Array::from_vec(a).into_shape((rows, cols)).unwrap().reversed_axes(),
39+
Layout::RowMajor => Array::from_vec(a).into_shape((rows, cols)).unwrap(),
40+
};
3741
Ok((ea, va))
3842
}
3943
fn ssqrt(self) -> Result<Self, LinalgError> {

0 commit comments

Comments
 (0)