Skip to content

Commit 7184170

Browse files
committed
Merge branch 'cholesky'
2 parents 7d54e16 + ea0ebc3 commit 7184170

File tree

4 files changed

+89
-1
lines changed

4 files changed

+89
-1
lines changed

src/cholesky.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
2+
use lapack::c::*;
3+
use error::LapackError;
4+
5+
pub trait ImplCholesky: Sized {
6+
fn cholesky(layout: Layout, n: usize, a: Vec<Self>) -> Result<Vec<Self>, LapackError>;
7+
}
8+
9+
macro_rules! impl_cholesky {
10+
($scalar:ty, $potrf:path) => {
11+
impl ImplCholesky for $scalar {
12+
fn cholesky(layout: Layout, n: usize, mut a: Vec<Self>) -> Result<Vec<Self>, LapackError> {
13+
let info = $potrf(layout, b'U', n as i32, &mut a, n as i32);
14+
if info == 0 {
15+
Ok(a)
16+
} else {
17+
Err(From::from(info))
18+
}
19+
}
20+
}
21+
}} // end macro_rules
22+
23+
impl_cholesky!(f64, dpotrf);
24+
impl_cholesky!(f32, spotrf);

src/hermite.rs

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use std::fmt::Debug;
44
use ndarray::prelude::*;
55
use ndarray::LinalgScalar;
66
use num_traits::float::Float;
7+
use lapack::c::Layout;
78

89
use matrix::Matrix;
910
use square::SquareMatrix;
@@ -13,17 +14,20 @@ use qr::ImplQR;
1314
use svd::ImplSVD;
1415
use norm::ImplNorm;
1516
use solve::ImplSolve;
17+
use cholesky::ImplCholesky;
1618

1719
/// Methods for Hermite matrix
1820
pub trait HermiteMatrix: SquareMatrix + Matrix {
1921
/// eigenvalue decomposition
2022
fn eigh(self) -> Result<(Self::Vector, Self), LinalgError>;
2123
/// symmetric square root of Hermite matrix
2224
fn ssqrt(self) -> Result<Self, LinalgError>;
25+
/// Cholesky factorization
26+
fn cholesky(self) -> Result<Self, LinalgError>;
2327
}
2428

2529
impl<A> HermiteMatrix for Array<A, (Ix, Ix)>
26-
where A: ImplQR + ImplSVD + ImplNorm + ImplSolve + ImplEigh + LinalgScalar + Float + Debug
30+
where A: ImplQR + ImplSVD + ImplNorm + ImplSolve + ImplEigh + ImplCholesky + LinalgScalar + Float + Debug
2731
{
2832
fn eigh(self) -> Result<(Self::Vector, Self), LinalgError> {
2933
try!(self.check_square());
@@ -44,4 +48,21 @@ impl<A> HermiteMatrix for Array<A, (Ix, Ix)>
4448
}
4549
Ok(v.dot(&res))
4650
}
51+
fn cholesky(self) -> Result<Self, LinalgError> {
52+
try!(self.check_square());
53+
println!("layout = {:?}", self.layout());
54+
let (n, _) = self.size();
55+
let layout = self.layout();
56+
let a = try!(ImplCholesky::cholesky(layout, n, self.into_raw_vec()));
57+
let mut c = match layout {
58+
Layout::RowMajor => Array::from_vec(a).into_shape((n, n)).unwrap(),
59+
Layout::ColumnMajor => Array::from_vec(a).into_shape((n, n)).unwrap().reversed_axes(),
60+
};
61+
for ((i, j), val) in c.indexed_iter_mut() {
62+
if i > j {
63+
*val = A::zero();
64+
}
65+
}
66+
Ok(c)
67+
}
4768
}

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,4 @@ pub mod svd;
4141
pub mod eigh;
4242
pub mod norm;
4343
pub mod solve;
44+
pub mod cholesky;

tests/cholesky.rs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
2+
extern crate rand;
3+
extern crate ndarray;
4+
extern crate ndarray_rand;
5+
extern crate ndarray_linalg;
6+
7+
use rand::distributions::*;
8+
use ndarray::prelude::*;
9+
use ndarray_linalg::prelude::*;
10+
use ndarray_rand::RandomExt;
11+
12+
fn all_close(a: Array<f64, (Ix, Ix)>, b: Array<f64, (Ix, Ix)>) {
13+
if !a.all_close(&b, 1.0e-7) {
14+
panic!("\nTwo matrices are not equal:\na = \n{:?}\nb = \n{:?}\n",
15+
a,
16+
b);
17+
}
18+
}
19+
20+
#[test]
21+
fn cholesky() {
22+
let r_dist = Range::new(0., 1.);
23+
let mut a = Array::<f64, _>::random((3, 3), r_dist);
24+
a = a.dot(&a.t());
25+
println!("a = \n{:?}", a);
26+
let c = a.clone().cholesky().unwrap();
27+
println!("c = \n{:?}", c);
28+
println!("cc = \n{:?}", c.t().dot(&c));
29+
all_close(c.t().dot(&c), a);
30+
}
31+
32+
#[test]
33+
fn cholesky_t() {
34+
let r_dist = Range::new(0., 1.);
35+
let mut a = Array::<f64, _>::random((3, 3), r_dist);
36+
a = a.dot(&a.t()).reversed_axes();
37+
println!("a = \n{:?}", a);
38+
let c = a.clone().cholesky().unwrap();
39+
println!("c = \n{:?}", c);
40+
println!("cc = \n{:?}", c.t().dot(&c));
41+
all_close(c.t().dot(&c), a);
42+
}

0 commit comments

Comments
 (0)