Skip to content

Commit e27eed7

Browse files
committed
Merge branch 'determinant'
2 parents 38e1ca3 + cf4adff commit e27eed7

File tree

3 files changed

+47
-1
lines changed

3 files changed

+47
-1
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,4 @@ features = ["blas"]
2020
[dev-dependencies]
2121
rand = "0.3.14"
2222
ndarray-rand = "0.3"
23+
float-cmp = "0.2.3"

src/hermite.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
use ndarray::{Ix2, Array, LinalgScalar};
44
use std::fmt::Debug;
55
use num_traits::float::Float;
6+
use num_traits::One;
67
use lapack::c::Layout;
78

89
use matrix::Matrix;
@@ -23,10 +24,12 @@ pub trait HermiteMatrix: SquareMatrix + Matrix {
2324
fn ssqrt(self) -> Result<Self, LinalgError>;
2425
/// Cholesky factorization
2526
fn cholesky(self) -> Result<Self, LinalgError>;
27+
/// calc determinant using Cholesky factorization
28+
fn deth(self) -> Result<Self::Scalar, LinalgError>;
2629
}
2730

2831
impl<A> HermiteMatrix for Array<A, Ix2>
29-
where A: ImplQR + ImplSVD + ImplNorm + ImplSolve + ImplEigh + ImplCholesky + LinalgScalar + Float + Debug
32+
where A: ImplQR + ImplSVD + ImplNorm + ImplSolve + ImplEigh + ImplCholesky + LinalgScalar + Float + Debug + One
3033
{
3134
fn eigh(self) -> Result<(Self::Vector, Self), LinalgError> {
3235
self.check_square()?;
@@ -67,4 +70,10 @@ impl<A> HermiteMatrix for Array<A, Ix2>
6770
}
6871
Ok(c)
6972
}
73+
fn deth(self) -> Result<Self::Scalar, LinalgError> {
74+
let (n, _) = self.size();
75+
let c = self.cholesky()?;
76+
let rt = (0..n).map(|i| c[(i, i)]).fold(A::one(), |det, c| det * c);
77+
Ok(rt * rt)
78+
}
7079
}

tests/det.rs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
2+
extern crate ndarray;
3+
extern crate ndarray_linalg;
4+
extern crate ndarray_rand;
5+
extern crate rand;
6+
extern crate float_cmp;
7+
8+
use ndarray::prelude::*;
9+
use ndarray_linalg::prelude::*;
10+
use rand::distributions::*;
11+
use ndarray_rand::RandomExt;
12+
use float_cmp::ApproxEqRatio;
13+
14+
fn approx_eq(val: f64, truth: f64, ratio: f64) {
15+
if !val.approx_eq_ratio(&truth, ratio) {
16+
panic!("Not almost equal! val={:?}, truth={:?}, ratio={:?}",
17+
val,
18+
truth,
19+
ratio);
20+
}
21+
}
22+
23+
fn random_hermite(n: usize) -> Array<f64, Ix2> {
24+
let r_dist = Range::new(0., 1.);
25+
let a = Array::<f64, _>::random((n, n), r_dist);
26+
a.dot(&a.t())
27+
}
28+
29+
#[test]
30+
fn deth() {
31+
let a = random_hermite(3);
32+
let (e, _) = a.clone().eigh().unwrap();
33+
let deth = a.clone().deth().unwrap();
34+
let det_eig = e.iter().fold(1.0, |x, y| x * y);
35+
approx_eq(deth, det_eig, 1.0e-7);
36+
}

0 commit comments

Comments
 (0)