Skip to content

Commit 346d43c

Browse files
committed
Add .sign_ln_deth*() methods to DeterminantH*
1 parent b9ed2c9 commit 346d43c

File tree

2 files changed

+116
-16
lines changed

2 files changed

+116
-16
lines changed

src/solveh.rs

Lines changed: 71 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -254,20 +254,53 @@ where
254254
/// An interface for calculating determinants of Hermitian (or real symmetric) matrix refs.
255255
pub trait DeterminantH {
256256
type Output;
257+
type SignLnOutput;
257258

258259
/// Computes the determinant of the Hermitian (or real symmetric) matrix.
259260
fn deth(&self) -> Self::Output;
261+
262+
/// Computes the `(sign, natural_log)` of the determinant of the Hermitian
263+
/// (or real symmetric) matrix.
264+
///
265+
/// The `natural_log` is the natural logarithm of the absolute value of the
266+
/// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
267+
/// is negative infinity.
268+
///
269+
/// To obtain the determinant, you can compute `sign * natural_log.exp()`
270+
/// or just call `.deth()` instead.
271+
///
272+
/// This method is more robust than `.deth()` to very small or very large
273+
/// determinants since it returns the natural logarithm of the determinant
274+
/// rather than the determinant itself.
275+
fn sln_deth(&self) -> Self::SignLnOutput;
260276
}
261277

262278
/// An interface for calculating determinants of Hermitian (or real symmetric) matrices.
263279
pub trait DeterminantHInto {
264280
type Output;
281+
type SignLnOutput;
265282

266283
/// Computes the determinant of the Hermitian (or real symmetric) matrix.
267284
fn deth_into(self) -> Self::Output;
285+
286+
/// Computes the `(sign, natural_log)` of the determinant of the Hermitian
287+
/// (or real symmetric) matrix.
288+
///
289+
/// The `natural_log` is the natural logarithm of the absolute value of the
290+
/// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
291+
/// is negative infinity.
292+
///
293+
/// To obtain the determinant, you can compute `sign * natural_log.exp()`
294+
/// or just call `.deth_into()` instead.
295+
///
296+
/// This method is more robust than `.deth_into()` to very small or very
297+
/// large determinants since it returns the natural logarithm of the
298+
/// determinant rather than the determinant itself.
299+
fn sln_deth_into(self) -> Self::SignLnOutput;
268300
}
269301

270-
fn bk_det<P, S, A>(uplo: UPLO, ipiv_iter: P, a: &ArrayBase<S, Ix2>) -> A::Real
302+
/// Returns the sign and natural log of the determinant.
303+
fn bk_sln_det<P, S, A>(uplo: UPLO, ipiv_iter: P, a: &ArrayBase<S, Ix2>) -> (A::Real, A::Real)
271304
where
272305
P: Iterator<Item = i32>,
273306
S: Data<Elem = A>,
@@ -310,7 +343,7 @@ where
310343
ipiv_enum.next();
311344
}
312345
}
313-
sign * ln_det.exp()
346+
(sign, ln_det)
314347
}
315348

316349
impl<A, S> DeterminantH for BKFactorized<S>
@@ -319,9 +352,15 @@ where
319352
S: Data<Elem = A>,
320353
{
321354
type Output = A::Real;
355+
type SignLnOutput = (A::Real, A::Real);
322356

323357
fn deth(&self) -> A::Real {
324-
bk_det(UPLO::Upper, self.ipiv.iter().cloned(), &self.a)
358+
let (sign, ln_det) = self.sln_deth();
359+
sign * ln_det.exp()
360+
}
361+
362+
fn sln_deth(&self) -> (A::Real, A::Real) {
363+
bk_sln_det(UPLO::Upper, self.ipiv.iter().cloned(), &self.a)
325364
}
326365
}
327366

@@ -331,9 +370,15 @@ where
331370
S: Data<Elem = A>,
332371
{
333372
type Output = A::Real;
373+
type SignLnOutput = (A::Real, A::Real);
334374

335375
fn deth_into(self) -> A::Real {
336-
bk_det(UPLO::Upper, self.ipiv.into_iter(), &self.a)
376+
let (sign, ln_det) = self.sln_deth_into();
377+
sign * ln_det.exp()
378+
}
379+
380+
fn sln_deth_into(self) -> (A::Real, A::Real) {
381+
bk_sln_det(UPLO::Upper, self.ipiv.into_iter(), &self.a)
337382
}
338383
}
339384

@@ -343,11 +388,20 @@ where
343388
S: Data<Elem = A>,
344389
{
345390
type Output = Result<A::Real>;
391+
type SignLnOutput = Result<(A::Real, A::Real)>;
346392

347393
fn deth(&self) -> Result<A::Real> {
394+
let (sign, ln_det) = self.sln_deth()?;
395+
Ok(sign * ln_det.exp())
396+
}
397+
398+
fn sln_deth(&self) -> Result<(A::Real, A::Real)> {
348399
match self.factorizeh() {
349-
Ok(fac) => Ok(fac.deth()),
350-
Err(LinalgError::Lapack(LapackError { return_code })) if return_code > 0 => Ok(A::Real::zero()),
400+
Ok(fac) => Ok(fac.sln_deth()),
401+
Err(LinalgError::Lapack(LapackError { return_code })) if return_code > 0 => {
402+
// Determinant is zero.
403+
Ok((A::Real::zero(), A::Real::neg_infinity()))
404+
}
351405
Err(err) => Err(err),
352406
}
353407
}
@@ -359,11 +413,20 @@ where
359413
S: DataMut<Elem = A>,
360414
{
361415
type Output = Result<A::Real>;
416+
type SignLnOutput = Result<(A::Real, A::Real)>;
362417

363418
fn deth_into(self) -> Result<A::Real> {
419+
let (sign, ln_det) = self.sln_deth_into()?;
420+
Ok(sign * ln_det.exp())
421+
}
422+
423+
fn sln_deth_into(self) -> Result<(A::Real, A::Real)> {
364424
match self.factorizeh_into() {
365-
Ok(fac) => Ok(fac.deth_into()),
366-
Err(LinalgError::Lapack(LapackError { return_code })) if return_code > 0 => Ok(A::Real::zero()),
425+
Ok(fac) => Ok(fac.sln_deth_into()),
426+
Err(LinalgError::Lapack(LapackError { return_code })) if return_code > 0 => {
427+
// Determinant is zero.
428+
Ok((A::Real::zero(), A::Real::neg_infinity()))
429+
}
367430
Err(err) => Err(err),
368431
}
369432
}

tests/deth.rs

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,21 @@ extern crate num_traits;
55

66
use ndarray::*;
77
use ndarray_linalg::*;
8-
use num_traits::{One, Zero};
8+
use num_traits::{Float, One, Zero};
99

1010
#[test]
1111
fn deth_empty() {
1212
macro_rules! deth_empty {
1313
($elem:ty) => {
1414
let a: Array2<$elem> = Array2::zeros((0, 0));
1515
assert_eq!(a.factorizeh().unwrap().deth(), One::one());
16+
assert_eq!(a.factorizeh().unwrap().sln_deth(), (One::one(), Zero::zero()));
1617
assert_eq!(a.factorizeh().unwrap().deth_into(), One::one());
18+
assert_eq!(a.factorizeh().unwrap().sln_deth_into(), (One::one(), Zero::zero()));
1719
assert_eq!(a.deth().unwrap(), One::one());
18-
assert_eq!(a.deth_into().unwrap(), One::one());
20+
assert_eq!(a.sln_deth().unwrap(), (One::one(), Zero::zero()));
21+
assert_eq!(a.clone().deth_into().unwrap(), One::one());
22+
assert_eq!(a.sln_deth_into().unwrap(), (One::one(), Zero::zero()));
1923
}
2024
}
2125
deth_empty!(f64);
@@ -30,7 +34,9 @@ fn deth_zero() {
3034
($elem:ty) => {
3135
let a: Array2<$elem> = Array2::zeros((1, 1));
3236
assert_eq!(a.deth().unwrap(), Zero::zero());
33-
assert_eq!(a.deth_into().unwrap(), Zero::zero());
37+
assert_eq!(a.sln_deth().unwrap(), (Zero::zero(), Float::neg_infinity()));
38+
assert_eq!(a.clone().deth_into().unwrap(), Zero::zero());
39+
assert_eq!(a.sln_deth_into().unwrap(), (Zero::zero(), Float::neg_infinity()));
3440
}
3541
}
3642
deth_zero!(f64);
@@ -45,7 +51,9 @@ fn deth_zero_nonsquare() {
4551
($elem:ty, $shape:expr) => {
4652
let a: Array2<$elem> = Array2::zeros($shape);
4753
assert!(a.deth().is_err());
48-
assert!(a.deth_into().is_err());
54+
assert!(a.sln_deth().is_err());
55+
assert!(a.clone().deth_into().is_err());
56+
assert!(a.sln_deth_into().is_err());
4957
}
5058
}
5159
for &shape in &[(1, 2).into_shape(), (1, 2).f()] {
@@ -62,11 +70,39 @@ fn deth() {
6270
($elem:ty, $rows:expr, $atol:expr) => {
6371
let a: Array2<$elem> = random_hermite($rows);
6472
println!("a = \n{:?}", a);
65-
let det = a.eigvalsh(UPLO::Upper).unwrap().iter().product();
73+
74+
// Compute determinant from eigenvalues.
75+
let (sign, ln_det) = a.eigvalsh(UPLO::Upper).unwrap().iter().fold(
76+
(<$elem as AssociatedReal>::Real::one(), <$elem as AssociatedReal>::Real::zero()),
77+
|(sign, ln_det), eigval| (sign * eigval.signum(), ln_det + eigval.abs().ln())
78+
);
79+
let det = sign * ln_det.exp();
80+
assert_aclose!(det, a.eigvalsh(UPLO::Upper).unwrap().iter().product(), $atol);
81+
6682
assert_aclose!(a.factorizeh().unwrap().deth(), det, $atol);
83+
{
84+
let result = a.factorizeh().unwrap().sln_deth();
85+
assert_aclose!(result.0, sign, $atol);
86+
assert_aclose!(result.1, ln_det, $atol);
87+
}
6788
assert_aclose!(a.factorizeh().unwrap().deth_into(), det, $atol);
89+
{
90+
let result = a.factorizeh().unwrap().sln_deth_into();
91+
assert_aclose!(result.0, sign, $atol);
92+
assert_aclose!(result.1, ln_det, $atol);
93+
}
6894
assert_aclose!(a.deth().unwrap(), det, $atol);
69-
assert_aclose!(a.deth_into().unwrap(), det, $atol);
95+
{
96+
let result = a.sln_deth().unwrap();
97+
assert_aclose!(result.0, sign, $atol);
98+
assert_aclose!(result.1, ln_det, $atol);
99+
}
100+
assert_aclose!(a.clone().deth_into().unwrap(), det, $atol);
101+
{
102+
let result = a.sln_deth_into().unwrap();
103+
assert_aclose!(result.0, sign, $atol);
104+
assert_aclose!(result.1, ln_det, $atol);
105+
}
70106
}
71107
}
72108
for rows in 1..6 {
@@ -83,9 +119,10 @@ fn deth_nonsquare() {
83119
($elem:ty, $shape:expr) => {
84120
let a: Array2<$elem> = Array2::zeros($shape);
85121
assert!(a.factorizeh().is_err());
86-
assert!(a.factorizeh().is_err());
87122
assert!(a.deth().is_err());
88-
assert!(a.deth_into().is_err());
123+
assert!(a.sln_deth().is_err());
124+
assert!(a.clone().deth_into().is_err());
125+
assert!(a.sln_deth_into().is_err());
89126
}
90127
}
91128
for &dims in &[(1, 0), (1, 2), (2, 1), (2, 3)] {

0 commit comments

Comments
 (0)