Skip to content

Commit 18208ea

Browse files
authored
Merge pull request #104 from jturner314/add-ln-det
Add methods to get natural log of determinant directly
2 parents b453f8e + 636ae56 commit 18208ea

File tree

6 files changed

+280
-54
lines changed

6 files changed

+280
-54
lines changed

src/cholesky.rs

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,12 +102,15 @@ where
102102
type Output = <A as AssociatedReal>::Real;
103103

104104
fn detc(&self) -> Self::Output {
105+
self.ln_detc().exp()
106+
}
107+
108+
fn ln_detc(&self) -> Self::Output {
105109
self.factor
106110
.diag()
107111
.iter()
108112
.map(|elem| elem.abs_sqr().ln())
109113
.sum::<Self::Output>()
110-
.exp()
111114
}
112115
}
113116

@@ -121,6 +124,10 @@ where
121124
fn detc_into(self) -> Self::Output {
122125
self.detc()
123126
}
127+
128+
fn ln_detc_into(self) -> Self::Output {
129+
self.ln_detc()
130+
}
124131
}
125132

126133
impl<A, S> InverseC for CholeskyFactorized<S>
@@ -391,6 +398,14 @@ pub trait DeterminantC {
391398
/// Computes the determinant of the Hermitian (or real symmetric) positive
392399
/// definite matrix.
393400
fn detc(&self) -> Self::Output;
401+
402+
/// Computes the natural log of the determinant of the Hermitian (or real
403+
/// symmetric) positive definite matrix.
404+
///
405+
/// This method is more robust than `.detc()` to very small or very large
406+
/// determinants since it returns the natural logarithm of the determinant
407+
/// rather than the determinant itself.
408+
fn ln_detc(&self) -> Self::Output;
394409
}
395410

396411

@@ -401,6 +416,14 @@ pub trait DeterminantCInto {
401416
/// Computes the determinant of the Hermitian (or real symmetric) positive
402417
/// definite matrix.
403418
fn detc_into(self) -> Self::Output;
419+
420+
/// Computes the natural log of the determinant of the Hermitian (or real
421+
/// symmetric) positive definite matrix.
422+
///
423+
/// This method is more robust than `.detc_into()` to very small or very
424+
/// large determinants since it returns the natural logarithm of the
425+
/// determinant rather than the determinant itself.
426+
fn ln_detc_into(self) -> Self::Output;
404427
}
405428

406429
impl<A, S> DeterminantC for ArrayBase<S, Ix2>
@@ -411,7 +434,11 @@ where
411434
type Output = Result<<A as AssociatedReal>::Real>;
412435

413436
fn detc(&self) -> Self::Output {
414-
Ok(self.factorizec(UPLO::Upper)?.detc())
437+
Ok(self.ln_detc()?.exp())
438+
}
439+
440+
fn ln_detc(&self) -> Self::Output {
441+
Ok(self.factorizec(UPLO::Upper)?.ln_detc())
415442
}
416443
}
417444

@@ -423,6 +450,10 @@ where
423450
type Output = Result<<A as AssociatedReal>::Real>;
424451

425452
fn detc_into(self) -> Self::Output {
426-
Ok(self.factorizec_into(UPLO::Upper)?.detc_into())
453+
Ok(self.ln_detc_into()?.exp())
454+
}
455+
456+
fn ln_detc_into(self) -> Self::Output {
457+
Ok(self.factorizec_into(UPLO::Upper)?.ln_detc_into())
427458
}
428459
}

src/solve.rs

Lines changed: 73 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
//! ```
4848
4949
use ndarray::*;
50+
use num_traits::{Float, Zero};
5051

5152
use super::convert::*;
5253
use super::error::*;
@@ -336,16 +337,54 @@ where
336337
/// An interface for calculating determinants of matrix refs.
337338
pub trait Determinant<A: Scalar> {
338339
/// Computes the determinant of the matrix.
339-
fn det(&self) -> Result<A>;
340+
fn det(&self) -> Result<A> {
341+
let (sign, ln_det) = self.sln_det()?;
342+
Ok(sign.mul_real(ln_det.exp()))
343+
}
344+
345+
/// Computes the `(sign, natural_log)` of the determinant of the matrix.
346+
///
347+
/// For real matrices, `sign` is `1`, `0`, or `-1`. For complex matrices,
348+
/// `sign` is `0` or a complex number with absolute value 1. The
349+
/// `natural_log` is the natural logarithm of the absolute value of the
350+
/// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
351+
/// is negative infinity.
352+
///
353+
/// To obtain the determinant, you can compute `sign * natural_log.exp()`
354+
/// or just call `.det()` instead.
355+
///
356+
/// This method is more robust than `.det()` to very small or very large
357+
/// determinants since it returns the natural logarithm of the determinant
358+
/// rather than the determinant itself.
359+
fn sln_det(&self) -> Result<(A, A::Real)>;
340360
}
341361

342362
/// An interface for calculating determinants of matrices.
343-
pub trait DeterminantInto<A: Scalar> {
363+
pub trait DeterminantInto<A: Scalar>: Sized {
344364
/// Computes the determinant of the matrix.
345-
fn det_into(self) -> Result<A>;
365+
fn det_into(self) -> Result<A> {
366+
let (sign, ln_det) = self.sln_det_into()?;
367+
Ok(sign.mul_real(ln_det.exp()))
368+
}
369+
370+
/// Computes the `(sign, natural_log)` of the determinant of the matrix.
371+
///
372+
/// For real matrices, `sign` is `1`, `0`, or `-1`. For complex matrices,
373+
/// `sign` is `0` or a complex number with absolute value 1. The
374+
/// `natural_log` is the natural logarithm of the absolute value of the
375+
/// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
376+
/// is negative infinity.
377+
///
378+
/// To obtain the determinant, you can compute `sign * natural_log.exp()`
379+
/// or just call `.det_into()` instead.
380+
///
381+
/// This method is more robust than `.det()` to very small or very large
382+
/// determinants since it returns the natural logarithm of the determinant
383+
/// rather than the determinant itself.
384+
fn sln_det_into(self) -> Result<(A, A::Real)>;
346385
}
347386

348-
fn lu_det<'a, A, P, U>(ipiv_iter: P, u_diag_iter: U) -> A
387+
fn lu_sln_det<'a, A, P, U>(ipiv_iter: P, u_diag_iter: U) -> (A, A::Real)
349388
where
350389
A: Scalar,
351390
P: Iterator<Item = i32>,
@@ -360,24 +399,27 @@ where
360399
} else {
361400
-A::one()
362401
};
363-
let (upper_sign, ln_det) = u_diag_iter.fold((A::one(), A::zero()), |(upper_sign, ln_det), &elem| {
364-
let abs_elem = elem.abs();
365-
(
366-
upper_sign * elem.div_real(abs_elem),
367-
ln_det.add_real(abs_elem.ln()),
368-
)
369-
});
370-
pivot_sign * upper_sign * ln_det.exp()
402+
let (upper_sign, ln_det) = u_diag_iter.fold(
403+
(A::one(), A::Real::zero()),
404+
|(upper_sign, ln_det), &elem| {
405+
let abs_elem: A::Real = elem.abs();
406+
(upper_sign * elem.div_real(abs_elem), ln_det + abs_elem.ln())
407+
},
408+
);
409+
(pivot_sign * upper_sign, ln_det)
371410
}
372411

373412
impl<A, S> Determinant<A> for LUFactorized<S>
374413
where
375414
A: Scalar,
376415
S: Data<Elem = A>,
377416
{
378-
fn det(&self) -> Result<A> {
417+
fn sln_det(&self) -> Result<(A, A::Real)> {
379418
self.a.ensure_square()?;
380-
Ok(lu_det(self.ipiv.iter().cloned(), self.a.diag().iter()))
419+
Ok(lu_sln_det(
420+
self.ipiv.iter().cloned(),
421+
self.a.diag().iter(),
422+
))
381423
}
382424
}
383425

@@ -386,9 +428,12 @@ where
386428
A: Scalar,
387429
S: Data<Elem = A>,
388430
{
389-
fn det_into(self) -> Result<A> {
431+
fn sln_det_into(self) -> Result<(A, A::Real)> {
390432
self.a.ensure_square()?;
391-
Ok(lu_det(self.ipiv.into_iter(), self.a.into_diag().iter()))
433+
Ok(lu_sln_det(
434+
self.ipiv.into_iter(),
435+
self.a.into_diag().iter(),
436+
))
392437
}
393438
}
394439

@@ -397,11 +442,14 @@ where
397442
A: Scalar,
398443
S: Data<Elem = A>,
399444
{
400-
fn det(&self) -> Result<A> {
445+
fn sln_det(&self) -> Result<(A, A::Real)> {
401446
self.ensure_square()?;
402447
match self.factorize() {
403-
Ok(fac) => fac.det(),
404-
Err(LinalgError::Lapack(LapackError { return_code })) if return_code > 0 => Ok(A::zero()),
448+
Ok(fac) => fac.sln_det(),
449+
Err(LinalgError::Lapack(LapackError { return_code })) if return_code > 0 => {
450+
// The determinant is zero.
451+
Ok((A::zero(), A::Real::neg_infinity()))
452+
}
405453
Err(err) => Err(err),
406454
}
407455
}
@@ -412,11 +460,14 @@ where
412460
A: Scalar,
413461
S: DataMut<Elem = A>,
414462
{
415-
fn det_into(self) -> Result<A> {
463+
fn sln_det_into(self) -> Result<(A, A::Real)> {
416464
self.ensure_square()?;
417465
match self.factorize_into() {
418-
Ok(fac) => fac.det_into(),
419-
Err(LinalgError::Lapack(LapackError { return_code })) if return_code > 0 => Ok(A::zero()),
466+
Ok(fac) => fac.sln_det_into(),
467+
Err(LinalgError::Lapack(LapackError { return_code })) if return_code > 0 => {
468+
// The determinant is zero.
469+
Ok((A::zero(), A::Real::neg_infinity()))
470+
}
420471
Err(err) => Err(err),
421472
}
422473
}

src/solveh.rs

Lines changed: 71 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -254,20 +254,53 @@ where
254254
/// An interface for calculating determinants of Hermitian (or real symmetric) matrix refs.
255255
pub trait DeterminantH {
256256
type Output;
257+
type SignLnOutput;
257258

258259
/// Computes the determinant of the Hermitian (or real symmetric) matrix.
259260
fn deth(&self) -> Self::Output;
261+
262+
/// Computes the `(sign, natural_log)` of the determinant of the Hermitian
263+
/// (or real symmetric) matrix.
264+
///
265+
/// The `natural_log` is the natural logarithm of the absolute value of the
266+
/// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
267+
/// is negative infinity.
268+
///
269+
/// To obtain the determinant, you can compute `sign * natural_log.exp()`
270+
/// or just call `.deth()` instead.
271+
///
272+
/// This method is more robust than `.deth()` to very small or very large
273+
/// determinants since it returns the natural logarithm of the determinant
274+
/// rather than the determinant itself.
275+
fn sln_deth(&self) -> Self::SignLnOutput;
260276
}
261277

262278
/// An interface for calculating determinants of Hermitian (or real symmetric) matrices.
263279
pub trait DeterminantHInto {
264280
type Output;
281+
type SignLnOutput;
265282

266283
/// Computes the determinant of the Hermitian (or real symmetric) matrix.
267284
fn deth_into(self) -> Self::Output;
285+
286+
/// Computes the `(sign, natural_log)` of the determinant of the Hermitian
287+
/// (or real symmetric) matrix.
288+
///
289+
/// The `natural_log` is the natural logarithm of the absolute value of the
290+
/// determinant. If the determinant is zero, `sign` is 0 and `natural_log`
291+
/// is negative infinity.
292+
///
293+
/// To obtain the determinant, you can compute `sign * natural_log.exp()`
294+
/// or just call `.deth_into()` instead.
295+
///
296+
/// This method is more robust than `.deth_into()` to very small or very
297+
/// large determinants since it returns the natural logarithm of the
298+
/// determinant rather than the determinant itself.
299+
fn sln_deth_into(self) -> Self::SignLnOutput;
268300
}
269301

270-
fn bk_det<P, S, A>(uplo: UPLO, ipiv_iter: P, a: &ArrayBase<S, Ix2>) -> A::Real
302+
/// Returns the sign and natural log of the determinant.
303+
fn bk_sln_det<P, S, A>(uplo: UPLO, ipiv_iter: P, a: &ArrayBase<S, Ix2>) -> (A::Real, A::Real)
271304
where
272305
P: Iterator<Item = i32>,
273306
S: Data<Elem = A>,
@@ -310,7 +343,7 @@ where
310343
ipiv_enum.next();
311344
}
312345
}
313-
sign * ln_det.exp()
346+
(sign, ln_det)
314347
}
315348

316349
impl<A, S> DeterminantH for BKFactorized<S>
@@ -319,9 +352,15 @@ where
319352
S: Data<Elem = A>,
320353
{
321354
type Output = A::Real;
355+
type SignLnOutput = (A::Real, A::Real);
322356

323357
fn deth(&self) -> A::Real {
324-
bk_det(UPLO::Upper, self.ipiv.iter().cloned(), &self.a)
358+
let (sign, ln_det) = self.sln_deth();
359+
sign * ln_det.exp()
360+
}
361+
362+
fn sln_deth(&self) -> (A::Real, A::Real) {
363+
bk_sln_det(UPLO::Upper, self.ipiv.iter().cloned(), &self.a)
325364
}
326365
}
327366

@@ -331,9 +370,15 @@ where
331370
S: Data<Elem = A>,
332371
{
333372
type Output = A::Real;
373+
type SignLnOutput = (A::Real, A::Real);
334374

335375
fn deth_into(self) -> A::Real {
336-
bk_det(UPLO::Upper, self.ipiv.into_iter(), &self.a)
376+
let (sign, ln_det) = self.sln_deth_into();
377+
sign * ln_det.exp()
378+
}
379+
380+
fn sln_deth_into(self) -> (A::Real, A::Real) {
381+
bk_sln_det(UPLO::Upper, self.ipiv.into_iter(), &self.a)
337382
}
338383
}
339384

@@ -343,11 +388,20 @@ where
343388
S: Data<Elem = A>,
344389
{
345390
type Output = Result<A::Real>;
391+
type SignLnOutput = Result<(A::Real, A::Real)>;
346392

347393
fn deth(&self) -> Result<A::Real> {
394+
let (sign, ln_det) = self.sln_deth()?;
395+
Ok(sign * ln_det.exp())
396+
}
397+
398+
fn sln_deth(&self) -> Result<(A::Real, A::Real)> {
348399
match self.factorizeh() {
349-
Ok(fac) => Ok(fac.deth()),
350-
Err(LinalgError::Lapack(LapackError { return_code })) if return_code > 0 => Ok(A::Real::zero()),
400+
Ok(fac) => Ok(fac.sln_deth()),
401+
Err(LinalgError::Lapack(LapackError { return_code })) if return_code > 0 => {
402+
// Determinant is zero.
403+
Ok((A::Real::zero(), A::Real::neg_infinity()))
404+
}
351405
Err(err) => Err(err),
352406
}
353407
}
@@ -359,11 +413,20 @@ where
359413
S: DataMut<Elem = A>,
360414
{
361415
type Output = Result<A::Real>;
416+
type SignLnOutput = Result<(A::Real, A::Real)>;
362417

363418
fn deth_into(self) -> Result<A::Real> {
419+
let (sign, ln_det) = self.sln_deth_into()?;
420+
Ok(sign * ln_det.exp())
421+
}
422+
423+
fn sln_deth_into(self) -> Result<(A::Real, A::Real)> {
364424
match self.factorizeh_into() {
365-
Ok(fac) => Ok(fac.deth_into()),
366-
Err(LinalgError::Lapack(LapackError { return_code })) if return_code > 0 => Ok(A::Real::zero()),
425+
Ok(fac) => Ok(fac.sln_deth_into()),
426+
Err(LinalgError::Lapack(LapackError { return_code })) if return_code > 0 => {
427+
// Determinant is zero.
428+
Ok((A::Real::zero(), A::Real::neg_infinity()))
429+
}
367430
Err(err) => Err(err),
368431
}
369432
}

0 commit comments

Comments
 (0)