Skip to content

Commit 8b55efc

Browse files
authored
Merge pull request #186 from doraneko94/eig_lapacke
Add EigenValue/Vector for general matrix.
2 parents b311f5f + c837076 commit 8b55efc

File tree

7 files changed

+271
-2
lines changed

7 files changed

+271
-2
lines changed

examples/eig.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
use ndarray::*;
2+
use ndarray_linalg::*;
3+
4+
fn main() {
5+
let a = arr2(&[[2.0, 1.0, 2.0], [-2.0, 2.0, 1.0], [1.0, 2.0, -2.0]]);
6+
let (e, vecs) = a.clone().eig().unwrap();
7+
println!("eigenvalues = \n{:?}", e);
8+
println!("V = \n{:?}", vecs);
9+
let a_c: Array2<c64> = a.map(|f| c64::new(*f, 0.0));
10+
let av = a_c.dot(&vecs);
11+
println!("AV = \n{:?}", av);
12+
}

examples/solveh.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ fn solve() -> Result<(), error::LinalgError> {
1010
let b: Array1<c64> = random(3);
1111
println!("b = {:?}", &b);
1212
let x = a.solveh(&b)?;
13-
println!("Ax = {:?}", a.dot(&x));;
13+
println!("Ax = {:?}", a.dot(&x));
1414
Ok(())
1515
}
1616

src/eig.rs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
//! Eigenvalue decomposition for non-symmetric square matrices
2+
3+
use ndarray::*;
4+
use crate::error::*;
5+
use crate::layout::*;
6+
use crate::types::*;
7+
8+
/// Eigenvalue decomposition of general matrix reference
9+
pub trait Eig {
10+
/// EigVec is the right eivenvector
11+
type EigVal;
12+
type EigVec;
13+
/// Calculate eigenvalues with the right eigenvector
14+
fn eig(&self) -> Result<(Self::EigVal, Self::EigVec)>;
15+
}
16+
17+
impl<A, S> Eig for ArrayBase<S, Ix2>
18+
where
19+
A: Scalar + Lapack,
20+
S: Data<Elem = A>,
21+
{
22+
type EigVal = Array1<A::Complex>;
23+
type EigVec = Array2<A::Complex>;
24+
25+
fn eig(&self) -> Result<(Self::EigVal, Self::EigVec)> {
26+
let mut a = self.to_owned();
27+
let layout = a.square_layout()?;
28+
let (s, t) = unsafe { A::eig(true, layout, a.as_allocated_mut()?)? };
29+
let (n, _) = layout.size();
30+
Ok((ArrayBase::from(s), ArrayBase::from(t).into_shape((n as usize, n as usize)).unwrap()))
31+
}
32+
}
33+
34+
/// Calculate eigenvalues without eigenvectors
35+
pub trait EigVals {
36+
type EigVal;
37+
fn eigvals(&self) -> Result<Self::EigVal>;
38+
}
39+
40+
impl<A, S> EigVals for ArrayBase<S, Ix2>
41+
where
42+
A: Scalar + Lapack,
43+
S: DataMut<Elem = A>,
44+
{
45+
type EigVal = Array1<A::Complex>;
46+
47+
fn eigvals(&self) -> Result<Self::EigVal> {
48+
let mut a = self.to_owned();
49+
let (s, _) = unsafe { A::eig(true, a.square_layout()?, a.as_allocated_mut()?)? };
50+
Ok(ArrayBase::from(s))
51+
}
52+
}

src/lapack/eig.rs

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
//! Eigenvalue decomposition for general matrices
2+
3+
use lapacke;
4+
use num_traits::Zero;
5+
6+
use crate::error::*;
7+
use crate::layout::MatrixLayout;
8+
use crate::types::*;
9+
10+
use super::into_result;
11+
12+
/// Wraps `*geev` for real/complex
13+
pub trait Eig_: Scalar {
14+
unsafe fn eig(calc_v: bool, l: MatrixLayout, a: &mut [Self]) -> Result<(Vec<Self::Complex>, Vec<Self::Complex>)>;
15+
}
16+
17+
macro_rules! impl_eig_complex {
18+
($scalar:ty, $ev:path) => {
19+
impl Eig_ for $scalar {
20+
unsafe fn eig(calc_v: bool, l: MatrixLayout, mut a: &mut [Self]) -> Result<(Vec<Self::Complex>, Vec<Self::Complex>)> {
21+
let (n, _) = l.size();
22+
let jobvr = if calc_v { b'V' } else { b'N' };
23+
let mut w = vec![Self::Complex::zero(); n as usize];
24+
let mut vl = Vec::new();
25+
let mut vr = vec![Self::Complex::zero(); (n * n) as usize];
26+
let info = $ev(l.lapacke_layout(), b'N', jobvr, n, &mut a, n, &mut w, &mut vl, n, &mut vr, n);
27+
into_result(info, (w, vr))
28+
}
29+
}
30+
};
31+
}
32+
33+
macro_rules! impl_eig_real {
34+
($scalar:ty, $ev:path) => {
35+
impl Eig_ for $scalar {
36+
unsafe fn eig(calc_v: bool, l: MatrixLayout, mut a: &mut [Self]) -> Result<(Vec<Self::Complex>, Vec<Self::Complex>)> {
37+
let (n, _) = l.size();
38+
let jobvr = if calc_v { b'V' } else { b'N' };
39+
let mut wr = vec![Self::Real::zero(); n as usize];
40+
let mut wi = vec![Self::Real::zero(); n as usize];
41+
let mut vl = Vec::new();
42+
let mut vr = vec![Self::Real::zero(); (n * n) as usize];
43+
let info = $ev(l.lapacke_layout(), b'N', jobvr, n, &mut a, n, &mut wr, &mut wi, &mut vl, n, &mut vr, n);
44+
let w: Vec<Self::Complex> = wr.iter().zip(wi.iter()).map(|(&r, &i)| Self::Complex::new(r, i)).collect();
45+
// If the j-th eigenvalue is real, then
46+
// eigenvector = [ vr[j], vr[j+n], vr[j+2*n], ... ].
47+
//
48+
// If the j-th and (j+1)-st eigenvalues form a complex conjugate pair,
49+
// eigenvector(j) = [ vr[j] + i*vr[j+1], vr[j+n] + i*vr[j+n+1], vr[j+2*n] + i*vr[j+2*n+1], ... ] and
50+
// eigenvector(j+1) = [ vr[j] - i*vr[j+1], vr[j+n] - i*vr[j+n+1], vr[j+2*n] - i*vr[j+2*n+1], ... ].
51+
//
52+
// Therefore, if eigenvector(j) is written as [ v_{j0}, v_{j1}, v_{j2}, ... ],
53+
// you have to make
54+
// v = vec![ v_{00}, v_{10}, v_{20}, ..., v_{jk}, v_{(j+1)k}, v_{(j+2)k}, ... ] (v.len() = n*n)
55+
// based on wi and vr.
56+
// After that, v is converted to Array2 (see ../eig.rs).
57+
let n = n as usize;
58+
let mut flg = false;
59+
let conj: Vec<i8> = wi.iter().map(|&i| {
60+
if flg {
61+
flg = false;
62+
-1
63+
} else if i != 0.0 {
64+
flg = true;
65+
1
66+
} else {
67+
0
68+
}
69+
}).collect();
70+
let v: Vec<Self::Complex> = (0..n*n).map(|i| {
71+
let j = i % n;
72+
match conj[j] {
73+
1 => Self::Complex::new(vr[i], vr[i+1]),
74+
-1 => Self::Complex::new(vr[i-1], -vr[i]),
75+
_ => Self::Complex::new(vr[i], 0.0),
76+
}
77+
}).collect();
78+
79+
into_result(info, (w, v))
80+
}
81+
}
82+
};
83+
}
84+
85+
impl_eig_real!(f64, lapacke::dgeev);
86+
impl_eig_real!(f32, lapacke::sgeev);
87+
impl_eig_complex!(c64, lapacke::zgeev);
88+
impl_eig_complex!(c32, lapacke::cgeev);

src/lapack/mod.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
//! Define traits wrapping LAPACK routines
22
33
pub mod cholesky;
4+
pub mod eig;
45
pub mod eigh;
56
pub mod opnorm;
67
pub mod qr;
@@ -11,6 +12,7 @@ pub mod svddc;
1112
pub mod triangular;
1213

1314
pub use self::cholesky::*;
15+
pub use self::eig::*;
1416
pub use self::eigh::*;
1517
pub use self::opnorm::*;
1618
pub use self::qr::*;
@@ -26,7 +28,7 @@ use super::types::*;
2628
pub type Pivot = Vec<i32>;
2729

2830
/// Trait for primitive types which implements LAPACK subroutines
29-
pub trait Lapack: OperatorNorm_ + QR_ + SVD_ + SVDDC_ + Solve_ + Solveh_ + Cholesky_ + Eigh_ + Triangular_ {}
31+
pub trait Lapack: OperatorNorm_ + QR_ + SVD_ + SVDDC_ + Solve_ + Solveh_ + Cholesky_ + Eig_ + Eigh_ + Triangular_ {}
3032

3133
impl Lapack for f32 {}
3234
impl Lapack for f64 {}

src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//! - Decomposition methods:
88
//! - [QR decomposition](qr/index.html)
99
//! - [Cholesky/LU decomposition](cholesky/index.html)
10+
//! - [Eigenvalue decomposition](eig/index.html)
1011
//! - [Eigenvalue decomposition for Hermite matrices](eigh/index.html)
1112
//! - [**S**ingular **V**alue **D**ecomposition](svd/index.html)
1213
//! - Solution of linear systems:
@@ -43,6 +44,7 @@ pub mod assert;
4344
pub mod cholesky;
4445
pub mod convert;
4546
pub mod diagonal;
47+
pub mod eig;
4648
pub mod eigh;
4749
pub mod error;
4850
pub mod generate;
@@ -66,6 +68,7 @@ pub use assert::*;
6668
pub use cholesky::*;
6769
pub use convert::*;
6870
pub use diagonal::*;
71+
pub use eig::*;
6972
pub use eigh::*;
7073
pub use generate::*;
7174
pub use inner::*;

tests/eig.rs

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
use ndarray::*;
2+
use ndarray_linalg::*;
3+
4+
#[test]
5+
fn dgeev() {
6+
// https://software.intel.com/sites/products/documentation/doclib/mkl_sa/11/mkl_lapack_examples/dgeev_ex.f.htm
7+
let a: Array2<f64> = arr2(&[[-1.01, 0.86, -4.60, 3.31, -4.81],
8+
[ 3.98, 0.53, -7.04, 5.29, 3.55],
9+
[ 3.30, 8.26, -3.89, 8.20, -1.51],
10+
[ 4.43, 4.96, -7.66, -7.33, 6.18],
11+
[ 7.31, -6.43, -6.16, 2.47, 5.58]]);
12+
let (e, vecs): (Array1<_>, Array2<_>) = (&a).eig().unwrap();
13+
assert_close_l2!(&e,
14+
&arr1(&[c64::new( 2.86, 10.76), c64::new( 2.86,-10.76), c64::new( -0.69, 4.70), c64::new( -0.69, -4.70), c64::new(-10.46, 0.00)]),
15+
1.0e-3);
16+
17+
/*
18+
let answer = &arr2(&[[c64::new( 0.11, 0.17), c64::new( 0.11, -0.17), c64::new( 0.73, 0.00), c64::new( 0.73, 0.00), c64::new( 0.46, 0.00)],
19+
[c64::new( 0.41, -0.26), c64::new( 0.41, 0.26), c64::new( -0.03, -0.02), c64::new( -0.03, 0.02), c64::new( 0.34, 0.00)],
20+
[c64::new( 0.10, -0.51), c64::new( 0.10, 0.51), c64::new( 0.19, -0.29), c64::new( 0.19, 0.29), c64::new( 0.31, 0.00)],
21+
[c64::new( 0.40, -0.09), c64::new( 0.40, 0.09), c64::new( -0.08, -0.08), c64::new( -0.08, 0.08), c64::new( -0.74, 0.00)],
22+
[c64::new( 0.54, 0.00), c64::new( 0.54, 0.00), c64::new( -0.29, -0.49), c64::new( -0.29, 0.49), c64::new( 0.16, 0.00)]]);
23+
*/
24+
25+
let a_c: Array2<c64> = a.map(|f| c64::new(*f, 0.0));
26+
for (i, v) in vecs.axis_iter(Axis(1)).enumerate() {
27+
let av = a_c.dot(&v);
28+
let ev = v.mapv(|f| e[i] * f);
29+
assert_close_l2!(&av, &ev, 1.0e-7);
30+
}
31+
}
32+
33+
#[test]
34+
fn fgeev() {
35+
// https://software.intel.com/sites/products/documentation/doclib/mkl_sa/11/mkl_lapack_examples/dgeev_ex.f.htm
36+
let a: Array2<f32> = arr2(&[[-1.01, 0.86, -4.60, 3.31, -4.81],
37+
[ 3.98, 0.53, -7.04, 5.29, 3.55],
38+
[ 3.30, 8.26, -3.89, 8.20, -1.51],
39+
[ 4.43, 4.96, -7.66, -7.33, 6.18],
40+
[ 7.31, -6.43, -6.16, 2.47, 5.58]]);
41+
let (e, vecs): (Array1<_>, Array2<_>) = (&a).eig().unwrap();
42+
assert_close_l2!(&e,
43+
&arr1(&[c32::new( 2.86, 10.76), c32::new( 2.86,-10.76), c32::new( -0.69, 4.70), c32::new( -0.69, -4.70), c32::new(-10.46, 0.00)]),
44+
1.0e-3);
45+
46+
/*
47+
let answer = &arr2(&[[c32::new( 0.11, 0.17), c32::new( 0.11, -0.17), c32::new( 0.73, 0.00), c32::new( 0.73, 0.00), c32::new( 0.46, 0.00)],
48+
[c32::new( 0.41, -0.26), c32::new( 0.41, 0.26), c32::new( -0.03, -0.02), c32::new( -0.03, 0.02), c32::new( 0.34, 0.00)],
49+
[c32::new( 0.10, -0.51), c32::new( 0.10, 0.51), c32::new( 0.19, -0.29), c32::new( 0.19, 0.29), c32::new( 0.31, 0.00)],
50+
[c32::new( 0.40, -0.09), c32::new( 0.40, 0.09), c32::new( -0.08, -0.08), c32::new( -0.08, 0.08), c32::new( -0.74, 0.00)],
51+
[c32::new( 0.54, 0.00), c32::new( 0.54, 0.00), c32::new( -0.29, -0.49), c32::new( -0.29, 0.49), c32::new( 0.16, 0.00)]]);
52+
*/
53+
54+
let a_c: Array2<c32> = a.map(|f| c32::new(*f, 0.0));
55+
for (i, v) in vecs.axis_iter(Axis(1)).enumerate() {
56+
let av = a_c.dot(&v);
57+
let ev = v.mapv(|f| e[i] * f);
58+
assert_close_l2!(&av, &ev, 1.0e-5);
59+
}
60+
}
61+
62+
#[test]
63+
fn zgeev() {
64+
// https://software.intel.com/sites/products/documentation/doclib/mkl_sa/11/mkl_lapack_examples/zgeev_ex.f.htm
65+
let a: Array2<c64> = arr2(&[[c64::new( -3.84, 2.25), c64::new( -8.94, -4.75), c64::new( 8.95, -6.53), c64::new( -9.87, 4.82)],
66+
[c64::new( -0.66, 0.83), c64::new( -4.40, -3.82), c64::new( -3.50, -4.26), c64::new( -3.15, 7.36)],
67+
[c64::new( -3.99, -4.73), c64::new( -5.88, -6.60), c64::new( -3.36, -0.40), c64::new( -0.75, 5.23)],
68+
[c64::new( 7.74, 4.18), c64::new( 3.66, -7.53), c64::new( 2.58, 3.60), c64::new( 4.59, 5.41)],]);
69+
let (e, vecs): (Array1<_>, Array2<_>) = (&a).eig().unwrap();
70+
assert_close_l2!(&e,
71+
&arr1(&[c64::new( -9.43,-12.98), c64::new( -3.44, 12.69), c64::new( 0.11, -3.40), c64::new( 5.76, 7.13)]),
72+
1.0e-3);
73+
74+
/*
75+
let answer = &arr2(&[[c64::new( 0.43, 0.33), c64::new( 0.83, 0.00), c64::new( 0.60, 0.00), c64::new( -0.31, 0.03)],
76+
[c64::new( 0.51, -0.03), c64::new( 0.08, -0.25), c64::new( -0.40, -0.20), c64::new( 0.04, 0.34)],
77+
[c64::new( 0.62, 0.00), c64::new( -0.25, 0.28), c64::new( -0.09, -0.48), c64::new( 0.36, 0.06)],
78+
[c64::new( -0.23, 0.11), c64::new( -0.10, -0.32), c64::new( -0.43, 0.13), c64::new( 0.81, 0.00)]]);
79+
*/
80+
81+
for (i, v) in vecs.axis_iter(Axis(1)).enumerate() {
82+
let av = a.dot(&v);
83+
let ev = v.mapv(|f| e[i] * f);
84+
assert_close_l2!(&av, &ev, 1.0e-7);
85+
}
86+
}
87+
88+
#[test]
89+
fn cgeev() {
90+
// https://software.intel.com/sites/products/documentation/doclib/mkl_sa/11/mkl_lapack_examples/zgeev_ex.f.htm
91+
let a: Array2<c32> = arr2(&[[c32::new( -3.84, 2.25), c32::new( -8.94, -4.75), c32::new( 8.95, -6.53), c32::new( -9.87, 4.82)],
92+
[c32::new( -0.66, 0.83), c32::new( -4.40, -3.82), c32::new( -3.50, -4.26), c32::new( -3.15, 7.36)],
93+
[c32::new( -3.99, -4.73), c32::new( -5.88, -6.60), c32::new( -3.36, -0.40), c32::new( -0.75, 5.23)],
94+
[c32::new( 7.74, 4.18), c32::new( 3.66, -7.53), c32::new( 2.58, 3.60), c32::new( 4.59, 5.41)],]);
95+
let (e, vecs): (Array1<_>, Array2<_>) = (&a).eig().unwrap();
96+
assert_close_l2!(&e,
97+
&arr1(&[c32::new( -9.43,-12.98), c32::new( -3.44, 12.69), c32::new( 0.11, -3.40), c32::new( 5.76, 7.13)]),
98+
1.0e-3);
99+
100+
/*
101+
let answer = &arr2(&[[c32::new( 0.43, 0.33), c32::new( 0.83, 0.00), c32::new( 0.60, 0.00), c32::new( -0.31, 0.03)],
102+
[c32::new( 0.51, -0.03), c32::new( 0.08, -0.25), c32::new( -0.40, -0.20), c32::new( 0.04, 0.34)],
103+
[c32::new( 0.62, 0.00), c32::new( -0.25, 0.28), c32::new( -0.09, -0.48), c32::new( 0.36, 0.06)],
104+
[c32::new( -0.23, 0.11), c32::new( -0.10, -0.32), c32::new( -0.43, 0.13), c32::new( 0.81, 0.00)]]);
105+
*/
106+
107+
for (i, v) in vecs.axis_iter(Axis(1)).enumerate() {
108+
let av = a.dot(&v);
109+
let ev = v.mapv(|f| e[i] * f);
110+
assert_close_l2!(&av, &ev, 1.0e-5);
111+
}
112+
}

0 commit comments

Comments
 (0)