Skip to content

Commit da4da09

Browse files
committed
Add determinant methods in solve module
1 parent d5898c4 commit da4da09

File tree

4 files changed

+213
-0
lines changed

4 files changed

+213
-0
lines changed

src/lapack_traits/solve.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,14 @@ use super::{Pivot, Transpose, into_result};
1010

1111
/// Wraps `*getrf`, `*getri`, and `*getrs`
1212
pub trait Solve_: Sized {
13+
/// Computes the LU factorization of a general `m x n` matrix `a` using
14+
/// partial pivoting with row interchanges.
15+
///
16+
/// If the result matches `Err(LinalgError::Lapack(LapackError {
17+
/// return_code )) if return_code > 0`, then `U[(return_code,
18+
/// return_code)]` is exactly zero. The factorization has been completed,
19+
/// but the factor `U` is exactly singular, and division by zero will occur
20+
/// if it is used to solve a system of equations.
1321
unsafe fn lu(MatrixLayout, a: &mut [Self]) -> Result<Pivot>;
1422
unsafe fn inv(MatrixLayout, a: &mut [Self], &Pivot) -> Result<()>;
1523
unsafe fn solve(MatrixLayout, Transpose, a: &[Self], &Pivot, b: &mut [Self]) -> Result<()>;

src/solve.rs

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,3 +328,100 @@ where
328328
f.inv_into()
329329
}
330330
}
331+
332+
/// An interface for calculating determinants of matrix refs.
333+
pub trait Determinant<A: Scalar> {
334+
/// Computes the determinant of the matrix.
335+
fn det(&self) -> Result<A>;
336+
}
337+
338+
/// An interface for calculating determinants of matrices.
339+
pub trait DeterminantInto<A: Scalar> {
340+
/// Computes the determinant of the matrix.
341+
fn det_into(self) -> Result<A>;
342+
}
343+
344+
fn lu_det<'a, A, P, U>(ipiv_iter: P, u_diag_iter: U) -> A
345+
where
346+
A: Scalar,
347+
P: Iterator<Item = i32>,
348+
U: Iterator<Item = &'a A>,
349+
{
350+
let pivot_sign = if ipiv_iter
351+
.enumerate()
352+
.filter(|&(i, pivot)| pivot != i as i32 + 1)
353+
.count() % 2 == 0
354+
{
355+
A::one()
356+
} else {
357+
-A::one()
358+
};
359+
let (upper_sign, ln_det) = u_diag_iter.fold((A::one(), A::zero()), |(upper_sign, ln_det), &elem| {
360+
let abs_elem = elem.abs();
361+
(
362+
upper_sign * elem.div_real(abs_elem),
363+
ln_det.add_real(abs_elem.ln()),
364+
)
365+
});
366+
pivot_sign * upper_sign * ln_det.exp()
367+
}
368+
369+
fn check_square<S: Data>(a: &ArrayBase<S, Ix2>) -> Result<()> {
370+
if a.is_square() {
371+
Ok(())
372+
} else {
373+
Err(NotSquareError::new(a.rows() as i32, a.cols() as i32).into())
374+
}
375+
}
376+
377+
impl<A, S> Determinant<A> for LUFactorized<S>
378+
where
379+
A: Scalar,
380+
S: Data<Elem = A>,
381+
{
382+
fn det(&self) -> Result<A> {
383+
check_square(&self.a)?;
384+
Ok(lu_det(self.ipiv.iter().cloned(), self.a.diag().iter()))
385+
}
386+
}
387+
388+
impl<A, S> DeterminantInto<A> for LUFactorized<S>
389+
where
390+
A: Scalar,
391+
S: Data<Elem = A>,
392+
{
393+
fn det_into(self) -> Result<A> {
394+
check_square(&self.a)?;
395+
Ok(lu_det(self.ipiv.into_iter(), self.a.into_diag().iter()))
396+
}
397+
}
398+
399+
impl<A, S> Determinant<A> for ArrayBase<S, Ix2>
400+
where
401+
A: Scalar,
402+
S: Data<Elem = A>,
403+
{
404+
fn det(&self) -> Result<A> {
405+
check_square(&self)?;
406+
match self.factorize() {
407+
Ok(fac) => fac.det(),
408+
Err(LinalgError::Lapack(LapackError { return_code })) if return_code > 0 => Ok(A::zero()),
409+
Err(err) => Err(err),
410+
}
411+
}
412+
}
413+
414+
impl<A, S> DeterminantInto<A> for ArrayBase<S, Ix2>
415+
where
416+
A: Scalar,
417+
S: DataMut<Elem = A>,
418+
{
419+
fn det_into(self) -> Result<A> {
420+
check_square(&self)?;
421+
match self.factorize_into() {
422+
Ok(fac) => fac.det_into(),
423+
Err(LinalgError::Lapack(LapackError { return_code })) if return_code > 0 => Ok(A::zero()),
424+
Err(err) => Err(err),
425+
}
426+
}
427+
}

src/types.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ pub use num_complex::Complex64 as c64;
2222
/// - [abs_sqr](trait.Absolute.html#tymethod.abs_sqr)
2323
/// - [sqrt](trait.SquareRoot.html#tymethod.sqrt)
2424
/// - [exp](trait.Exponential.html#tymethod.exp)
25+
/// - [ln](trait.NaturalLogarithm.html#tymethod.ln)
2526
/// - [conj](trait.Conjugate.html#tymethod.conj)
2627
/// - [randn](trait.RandNormal.html#tymethod.randn)
2728
///
@@ -33,6 +34,7 @@ pub trait Scalar
3334
+ Absolute
3435
+ SquareRoot
3536
+ Exponential
37+
+ NaturalLogarithm
3638
+ Conjugate
3739
+ RandNormal
3840
+ Neg<Output = Self>
@@ -118,6 +120,11 @@ pub trait Exponential {
118120
fn exp(&self) -> Self;
119121
}
120122

123+
/// Define `ln()` more generally
124+
pub trait NaturalLogarithm {
125+
fn ln(&self) -> Self;
126+
}
127+
121128
/// Complex conjugate value
122129
pub trait Conjugate: Copy {
123130
fn conj(self) -> Self;
@@ -207,6 +214,18 @@ impl Exponential for $complex {
207214
}
208215
}
209216

217+
impl NaturalLogarithm for $real {
218+
fn ln(&self) -> Self {
219+
Float::ln(*self)
220+
}
221+
}
222+
223+
impl NaturalLogarithm for $complex {
224+
fn ln(&self) -> Self {
225+
Complex::ln(self)
226+
}
227+
}
228+
210229
impl Conjugate for $real {
211230
fn conj(self) -> Self {
212231
self

tests/solve.rs

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
extern crate ndarray;
2+
#[macro_use]
3+
extern crate ndarray_linalg;
4+
extern crate num_traits;
5+
6+
use ndarray::*;
7+
use ndarray_linalg::*;
8+
use num_traits::Zero;
9+
10+
fn det_3x3<A, S>(a: ArrayBase<S, Ix2>) -> A
11+
where
12+
A: Scalar,
13+
S: Data<Elem = A>,
14+
{
15+
a[(0, 0)] * a[(1, 1)] * a[(2, 2)] + a[(0, 1)] * a[(1, 2)] * a[(2, 0)] + a[(0, 2)] * a[(1, 0)] * a[(2, 1)] -
16+
a[(0, 2)] * a[(1, 1)] * a[(2, 0)] - a[(0, 1)] * a[(1, 0)] * a[(2, 2)] - a[(0, 0)] * a[(1, 2)] * a[(2, 1)]
17+
}
18+
19+
#[test]
20+
fn det_zero() {
21+
macro_rules! det_zero {
22+
($elem:ty) => {
23+
let a: Array2<$elem> = array![[Zero::zero()]];
24+
assert_eq!(a.det().unwrap(), Zero::zero());
25+
assert_eq!(a.det_into().unwrap(), Zero::zero());
26+
}
27+
}
28+
det_zero!(f64);
29+
det_zero!(f32);
30+
det_zero!(c64);
31+
det_zero!(c32);
32+
}
33+
34+
#[test]
35+
fn det_zero_nonsquare() {
36+
macro_rules! det_zero_nonsquare {
37+
($elem:ty, $shape:expr) => {
38+
let a: Array2<$elem> = Array2::zeros($shape);
39+
assert!(a.det().is_err());
40+
assert!(a.det_into().is_err());
41+
}
42+
}
43+
for &shape in &[(1, 2).into_shape(), (1, 2).f()] {
44+
det_zero_nonsquare!(f64, shape);
45+
det_zero_nonsquare!(f32, shape);
46+
det_zero_nonsquare!(c64, shape);
47+
det_zero_nonsquare!(c32, shape);
48+
}
49+
}
50+
51+
#[test]
52+
fn det() {
53+
macro_rules! det {
54+
($elem:ty, $shape:expr, $rtol:expr) => {
55+
let a: Array2<$elem> = random($shape);
56+
println!("a = \n{:?}", a);
57+
let det = det_3x3(a.view());
58+
assert_rclose!(a.factorize().unwrap().det().unwrap(), det, $rtol);
59+
assert_rclose!(a.factorize().unwrap().det_into().unwrap(), det, $rtol);
60+
assert_rclose!(a.det().unwrap(), det, $rtol);
61+
assert_rclose!(a.det_into().unwrap(), det, $rtol);
62+
}
63+
}
64+
for &shape in &[(3, 3).into_shape(), (3, 3).f()] {
65+
det!(f64, shape, 1e-9);
66+
det!(f32, shape, 1e-4);
67+
det!(c64, shape, 1e-9);
68+
det!(c32, shape, 1e-4);
69+
}
70+
}
71+
72+
#[test]
73+
fn det_nonsquare() {
74+
macro_rules! det_nonsquare {
75+
($elem:ty, $shape:expr) => {
76+
let a: Array2<$elem> = random($shape);
77+
assert!(a.factorize().unwrap().det().is_err());
78+
assert!(a.factorize().unwrap().det_into().is_err());
79+
assert!(a.det().is_err());
80+
assert!(a.det_into().is_err());
81+
}
82+
}
83+
for &shape in &[(1, 2).into_shape(), (1, 2).f(), (2, 1).into_shape(), (2, 1).f()] {
84+
det_nonsquare!(f64, shape);
85+
det_nonsquare!(f32, shape);
86+
det_nonsquare!(c64, shape);
87+
det_nonsquare!(c32, shape);
88+
}
89+
}

0 commit comments

Comments
 (0)