Skip to content

Commit a10a6b4

Browse files
committed
Add .deth*() for determinants of Hermitian matrices
1 parent 9d9ee9d commit a10a6b4

File tree

2 files changed

+218
-0
lines changed

2 files changed

+218
-0
lines changed

src/solveh.rs

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
//! ```
5151
5252
use ndarray::*;
53+
use num_traits::{Float, One, Zero};
5354

5455
use super::convert::*;
5556
use super::error::*;
@@ -258,3 +259,121 @@ where
258259
f.invh_into()
259260
}
260261
}
262+
263+
/// An interface for calculating determinants of Hermitian (or real symmetric) matrix refs.
264+
pub trait DeterminantH {
265+
type Output;
266+
267+
/// Computes the determinant of the Hermitian (or real symmetric) matrix.
268+
fn deth(&self) -> Self::Output;
269+
}
270+
271+
/// An interface for calculating determinants of Hermitian (or real symmetric) matrices.
272+
pub trait DeterminantHInto {
273+
type Output;
274+
275+
/// Computes the determinant of the Hermitian (or real symmetric) matrix.
276+
fn deth_into(self) -> Self::Output;
277+
}
278+
279+
fn bk_det<P, S, A>(uplo: UPLO, ipiv_iter: P, a: &ArrayBase<S, Ix2>) -> A::Real
280+
where
281+
P: Iterator<Item = i32>,
282+
S: Data<Elem = A>,
283+
A: Scalar,
284+
{
285+
let mut sign = A::Real::one();
286+
let mut ln_det = A::Real::zero();
287+
let mut ipiv_enum = ipiv_iter.enumerate();
288+
while let Some((k, ipiv_k)) = ipiv_enum.next() {
289+
debug_assert!(k < a.rows() && k < a.cols());
290+
if ipiv_k > 0 {
291+
// 1x1 block at k, must be real.
292+
let elem = unsafe { a.uget((k, k)) }.real();
293+
debug_assert_eq!(elem.imag(), Zero::zero());
294+
sign = sign * elem.signum();
295+
ln_det = ln_det + elem.abs().ln();
296+
} else {
297+
// 2x2 block at k..k+2.
298+
299+
// Upper left diagonal elem, must be real.
300+
let upper_diag = unsafe { a.uget((k, k)) }.real();
301+
debug_assert_eq!(upper_diag.imag(), Zero::zero());
302+
303+
// Lower right diagonal elem, must be real.
304+
let lower_diag = unsafe { a.uget((k + 1, k + 1)) }.real();
305+
debug_assert_eq!(lower_diag.imag(), Zero::zero());
306+
307+
// Off-diagonal elements, can be complex.
308+
let off_diag = match uplo {
309+
UPLO::Upper => unsafe { a.uget((k, k + 1)) },
310+
UPLO::Lower => unsafe { a.uget((k + 1, k)) },
311+
};
312+
313+
// Determinant of 2x2 block.
314+
let block_det = upper_diag * lower_diag - off_diag.abs_sqr();
315+
sign = sign * block_det.signum();
316+
ln_det = ln_det + block_det.abs().ln();
317+
318+
// Skip the k+1 ipiv value.
319+
ipiv_enum.next();
320+
}
321+
}
322+
sign * ln_det.exp()
323+
}
324+
325+
impl<A, S> DeterminantH for BKFactorized<S>
326+
where
327+
A: Scalar,
328+
S: Data<Elem = A>,
329+
{
330+
type Output = A::Real;
331+
332+
fn deth(&self) -> A::Real {
333+
bk_det(UPLO::Upper, self.ipiv.iter().cloned(), &self.a)
334+
}
335+
}
336+
337+
impl<A, S> DeterminantHInto for BKFactorized<S>
338+
where
339+
A: Scalar,
340+
S: Data<Elem = A>,
341+
{
342+
type Output = A::Real;
343+
344+
fn deth_into(self) -> A::Real {
345+
bk_det(UPLO::Upper, self.ipiv.into_iter(), &self.a)
346+
}
347+
}
348+
349+
impl<A, S> DeterminantH for ArrayBase<S, Ix2>
350+
where
351+
A: Scalar,
352+
S: Data<Elem = A>,
353+
{
354+
type Output = Result<A::Real>;
355+
356+
fn deth(&self) -> Result<A::Real> {
357+
match self.factorizeh() {
358+
Ok(fac) => Ok(fac.deth()),
359+
Err(LinalgError::Lapack(LapackError { return_code })) if return_code > 0 => Ok(A::Real::zero()),
360+
Err(err) => Err(err),
361+
}
362+
}
363+
}
364+
365+
impl<A, S> DeterminantHInto for ArrayBase<S, Ix2>
366+
where
367+
A: Scalar,
368+
S: DataMut<Elem = A>,
369+
{
370+
type Output = Result<A::Real>;
371+
372+
fn deth_into(self) -> Result<A::Real> {
373+
match self.factorizeh_into() {
374+
Ok(fac) => Ok(fac.deth_into()),
375+
Err(LinalgError::Lapack(LapackError { return_code })) if return_code > 0 => Ok(A::Real::zero()),
376+
Err(err) => Err(err),
377+
}
378+
}
379+
}

tests/solveh.rs

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
extern crate ndarray;
2+
#[macro_use]
3+
extern crate ndarray_linalg;
4+
extern crate num_traits;
5+
6+
use ndarray::*;
7+
use ndarray_linalg::*;
8+
use num_traits::{One, Zero};
9+
10+
#[test]
11+
fn deth_empty() {
12+
macro_rules! deth_empty {
13+
($elem:ty) => {
14+
let a: Array2<$elem> = Array2::zeros((0, 0));
15+
assert_eq!(a.factorizeh().unwrap().deth(), One::one());
16+
assert_eq!(a.factorizeh().unwrap().deth_into(), One::one());
17+
assert_eq!(a.deth().unwrap(), One::one());
18+
assert_eq!(a.deth_into().unwrap(), One::one());
19+
}
20+
}
21+
deth_empty!(f64);
22+
deth_empty!(f32);
23+
deth_empty!(c64);
24+
deth_empty!(c32);
25+
}
26+
27+
#[test]
28+
fn deth_zero() {
29+
macro_rules! deth_zero {
30+
($elem:ty) => {
31+
let a: Array2<$elem> = Array2::zeros((1, 1));
32+
assert_eq!(a.deth().unwrap(), Zero::zero());
33+
assert_eq!(a.deth_into().unwrap(), Zero::zero());
34+
}
35+
}
36+
deth_zero!(f64);
37+
deth_zero!(f32);
38+
deth_zero!(c64);
39+
deth_zero!(c32);
40+
}
41+
42+
#[test]
43+
fn deth_zero_nonsquare() {
44+
macro_rules! deth_zero_nonsquare {
45+
($elem:ty, $shape:expr) => {
46+
let a: Array2<$elem> = Array2::zeros($shape);
47+
assert!(a.deth().is_err());
48+
assert!(a.deth_into().is_err());
49+
}
50+
}
51+
for &shape in &[(1, 2).into_shape(), (1, 2).f()] {
52+
deth_zero_nonsquare!(f64, shape);
53+
deth_zero_nonsquare!(f32, shape);
54+
deth_zero_nonsquare!(c64, shape);
55+
deth_zero_nonsquare!(c32, shape);
56+
}
57+
}
58+
59+
#[test]
60+
fn deth() {
61+
macro_rules! deth {
62+
($elem:ty, $rows:expr, $atol:expr) => {
63+
let a: Array2<$elem> = random_hermite($rows);
64+
println!("a = \n{:?}", a);
65+
let det = a.eigvalsh(UPLO::Upper).unwrap().iter().product();
66+
assert_aclose!(a.factorizeh().unwrap().deth(), det, $atol);
67+
assert_aclose!(a.factorizeh().unwrap().deth_into(), det, $atol);
68+
assert_aclose!(a.deth().unwrap(), det, $atol);
69+
assert_aclose!(a.deth_into().unwrap(), det, $atol);
70+
}
71+
}
72+
for rows in 1..6 {
73+
deth!(f64, rows, 1e-9);
74+
deth!(f32, rows, 1e-3);
75+
deth!(c64, rows, 1e-9);
76+
deth!(c32, rows, 1e-3);
77+
}
78+
}
79+
80+
#[test]
81+
fn deth_nonsquare() {
82+
macro_rules! deth_nonsquare {
83+
($elem:ty, $shape:expr) => {
84+
let a: Array2<$elem> = Array2::zeros($shape);
85+
assert!(a.factorizeh().is_err());
86+
assert!(a.factorizeh().is_err());
87+
assert!(a.deth().is_err());
88+
assert!(a.deth_into().is_err());
89+
}
90+
}
91+
for &dims in &[(1, 0), (1, 2), (2, 1), (2, 3)] {
92+
for &shape in &[dims.clone().into_shape(), dims.clone().f()] {
93+
deth_nonsquare!(f64, shape);
94+
deth_nonsquare!(f32, shape);
95+
deth_nonsquare!(c64, shape);
96+
deth_nonsquare!(c32, shape);
97+
}
98+
}
99+
}

0 commit comments

Comments
 (0)