Skip to content

Commit 52e7aa1

Browse files
authored
Merge pull request #85 from jturner314/add-determinant
Add determinant methods in solve module
2 parents 3d0dfa8 + 2ce0be8 commit 52e7aa1

File tree

5 files changed

+267
-0
lines changed

5 files changed

+267
-0
lines changed

src/lapack_traits/solve.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,14 @@ use super::{Pivot, Transpose, into_result};
1010

1111
/// Wraps `*getrf`, `*getri`, and `*getrs`
1212
pub trait Solve_: Sized {
13+
/// Computes the LU factorization of a general `m x n` matrix `a` using
14+
/// partial pivoting with row interchanges.
15+
///
16+
/// If the result matches `Err(LinalgError::Lapack(LapackError {
17+
/// return_code )) if return_code > 0`, then `U[(return_code-1,
18+
/// return_code-1)]` is exactly zero. The factorization has been completed,
19+
/// but the factor `U` is exactly singular, and division by zero will occur
20+
/// if it is used to solve a system of equations.
1321
unsafe fn lu(MatrixLayout, a: &mut [Self]) -> Result<Pivot>;
1422
unsafe fn inv(MatrixLayout, a: &mut [Self], &Pivot) -> Result<()>;
1523
unsafe fn solve(MatrixLayout, Transpose, a: &[Self], &Pivot, b: &mut [Self]) -> Result<()>;

src/layout.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ pub trait AllocatedArray {
7575
type Elem;
7676
fn layout(&self) -> Result<MatrixLayout>;
7777
fn square_layout(&self) -> Result<MatrixLayout>;
78+
/// Returns Ok iff the matrix is square (without computing the layout).
79+
fn ensure_square(&self) -> Result<()>;
7880
fn as_allocated(&self) -> Result<&[Self::Elem]>;
7981
}
8082

@@ -110,6 +112,14 @@ where
110112
}
111113
}
112114

115+
fn ensure_square(&self) -> Result<()> {
116+
if self.is_square() {
117+
Ok(())
118+
} else {
119+
Err(NotSquareError::new(self.rows() as i32, self.cols() as i32).into())
120+
}
121+
}
122+
113123
fn as_allocated(&self) -> Result<&[A]> {
114124
Ok(self.as_slice_memory_order().ok_or(MemoryContError::new())?)
115125
}

src/solve.rs

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,3 +330,92 @@ where
330330
f.inv_into()
331331
}
332332
}
333+
334+
/// An interface for calculating determinants of matrix refs.
335+
pub trait Determinant<A: Scalar> {
336+
/// Computes the determinant of the matrix.
337+
fn det(&self) -> Result<A>;
338+
}
339+
340+
/// An interface for calculating determinants of matrices.
341+
pub trait DeterminantInto<A: Scalar> {
342+
/// Computes the determinant of the matrix.
343+
fn det_into(self) -> Result<A>;
344+
}
345+
346+
fn lu_det<'a, A, P, U>(ipiv_iter: P, u_diag_iter: U) -> A
347+
where
348+
A: Scalar,
349+
P: Iterator<Item = i32>,
350+
U: Iterator<Item = &'a A>,
351+
{
352+
let pivot_sign = if ipiv_iter
353+
.enumerate()
354+
.filter(|&(i, pivot)| pivot != i as i32 + 1)
355+
.count() % 2 == 0
356+
{
357+
A::one()
358+
} else {
359+
-A::one()
360+
};
361+
let (upper_sign, ln_det) = u_diag_iter.fold((A::one(), A::zero()), |(upper_sign, ln_det), &elem| {
362+
let abs_elem = elem.abs();
363+
(
364+
upper_sign * elem.div_real(abs_elem),
365+
ln_det.add_real(abs_elem.ln()),
366+
)
367+
});
368+
pivot_sign * upper_sign * ln_det.exp()
369+
}
370+
371+
impl<A, S> Determinant<A> for LUFactorized<S>
372+
where
373+
A: Scalar,
374+
S: Data<Elem = A>,
375+
{
376+
fn det(&self) -> Result<A> {
377+
self.a.ensure_square()?;
378+
Ok(lu_det(self.ipiv.iter().cloned(), self.a.diag().iter()))
379+
}
380+
}
381+
382+
impl<A, S> DeterminantInto<A> for LUFactorized<S>
383+
where
384+
A: Scalar,
385+
S: Data<Elem = A>,
386+
{
387+
fn det_into(self) -> Result<A> {
388+
self.a.ensure_square()?;
389+
Ok(lu_det(self.ipiv.into_iter(), self.a.into_diag().iter()))
390+
}
391+
}
392+
393+
impl<A, S> Determinant<A> for ArrayBase<S, Ix2>
394+
where
395+
A: Scalar,
396+
S: Data<Elem = A>,
397+
{
398+
fn det(&self) -> Result<A> {
399+
self.ensure_square()?;
400+
match self.factorize() {
401+
Ok(fac) => fac.det(),
402+
Err(LinalgError::Lapack(LapackError { return_code })) if return_code > 0 => Ok(A::zero()),
403+
Err(err) => Err(err),
404+
}
405+
}
406+
}
407+
408+
impl<A, S> DeterminantInto<A> for ArrayBase<S, Ix2>
409+
where
410+
A: Scalar,
411+
S: DataMut<Elem = A>,
412+
{
413+
fn det_into(self) -> Result<A> {
414+
self.ensure_square()?;
415+
match self.factorize_into() {
416+
Ok(fac) => fac.det_into(),
417+
Err(LinalgError::Lapack(LapackError { return_code })) if return_code > 0 => Ok(A::zero()),
418+
Err(err) => Err(err),
419+
}
420+
}
421+
}

src/types.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ pub use num_complex::Complex64 as c64;
2222
/// - [abs_sqr](trait.Absolute.html#tymethod.abs_sqr)
2323
/// - [sqrt](trait.SquareRoot.html#tymethod.sqrt)
2424
/// - [exp](trait.Exponential.html#tymethod.exp)
25+
/// - [ln](trait.NaturalLogarithm.html#tymethod.ln)
2526
/// - [conj](trait.Conjugate.html#tymethod.conj)
2627
/// - [randn](trait.RandNormal.html#tymethod.randn)
2728
///
@@ -33,6 +34,7 @@ pub trait Scalar
3334
+ Absolute
3435
+ SquareRoot
3536
+ Exponential
37+
+ NaturalLogarithm
3638
+ Conjugate
3739
+ RandNormal
3840
+ Neg<Output = Self>
@@ -118,6 +120,11 @@ pub trait Exponential {
118120
fn exp(&self) -> Self;
119121
}
120122

123+
/// Define `ln()` more generally
124+
pub trait NaturalLogarithm {
125+
fn ln(&self) -> Self;
126+
}
127+
121128
/// Complex conjugate value
122129
pub trait Conjugate: Copy {
123130
fn conj(self) -> Self;
@@ -207,6 +214,18 @@ impl Exponential for $complex {
207214
}
208215
}
209216

217+
impl NaturalLogarithm for $real {
218+
fn ln(&self) -> Self {
219+
Float::ln(*self)
220+
}
221+
}
222+
223+
impl NaturalLogarithm for $complex {
224+
fn ln(&self) -> Self {
225+
Complex::ln(self)
226+
}
227+
}
228+
210229
impl Conjugate for $real {
211230
fn conj(self) -> Self {
212231
self

tests/solve.rs

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
extern crate ndarray;
2+
#[macro_use]
3+
extern crate ndarray_linalg;
4+
extern crate num_traits;
5+
6+
use ndarray::*;
7+
use ndarray_linalg::*;
8+
use num_traits::{One, Zero};
9+
10+
/// Returns the matrix with the specified `row` and `col` removed.
11+
fn matrix_minor<A, S>(a: ArrayBase<S, Ix2>, (row, col): (usize, usize)) -> Array2<A>
12+
where
13+
A: Scalar,
14+
S: Data<Elem = A>,
15+
{
16+
let mut select_rows = (0..a.rows()).collect::<Vec<_>>();
17+
select_rows.remove(row);
18+
let mut select_cols = (0..a.cols()).collect::<Vec<_>>();
19+
select_cols.remove(col);
20+
a.select(Axis(0), &select_rows).select(
21+
Axis(1),
22+
&select_cols,
23+
)
24+
}
25+
26+
/// Computes the determinant of matrix `a`.
27+
///
28+
/// Note: This implementation is written to be clearly correct so that it's
29+
/// useful for verification, but it's very inefficient.
30+
fn det_naive<A, S>(a: ArrayBase<S, Ix2>) -> A
31+
where
32+
A: Scalar,
33+
S: Data<Elem = A>,
34+
{
35+
assert_eq!(a.rows(), a.cols());
36+
match a.cols() {
37+
0 => A::one(),
38+
1 => a[(0, 0)],
39+
cols => {
40+
(0..cols)
41+
.map(|col| {
42+
let sign = if col % 2 == 0 { A::one() } else { -A::one() };
43+
sign * a[(0, col)] * det_naive(matrix_minor(a.view(), (0, col)))
44+
})
45+
.fold(A::zero(), |sum, subdet| sum + subdet)
46+
}
47+
}
48+
}
49+
50+
#[test]
51+
fn det_empty() {
52+
macro_rules! det_empty {
53+
($elem:ty) => {
54+
let a: Array2<$elem> = Array2::zeros((0, 0));
55+
assert_eq!(a.factorize().unwrap().det().unwrap(), One::one());
56+
assert_eq!(a.factorize().unwrap().det_into().unwrap(), One::one());
57+
assert_eq!(a.det().unwrap(), One::one());
58+
assert_eq!(a.det_into().unwrap(), One::one());
59+
}
60+
}
61+
det_empty!(f64);
62+
det_empty!(f32);
63+
det_empty!(c64);
64+
det_empty!(c32);
65+
}
66+
67+
#[test]
68+
fn det_zero() {
69+
macro_rules! det_zero {
70+
($elem:ty) => {
71+
let a: Array2<$elem> = Array2::zeros((1, 1));
72+
assert_eq!(a.det().unwrap(), Zero::zero());
73+
assert_eq!(a.det_into().unwrap(), Zero::zero());
74+
}
75+
}
76+
det_zero!(f64);
77+
det_zero!(f32);
78+
det_zero!(c64);
79+
det_zero!(c32);
80+
}
81+
82+
#[test]
83+
fn det_zero_nonsquare() {
84+
macro_rules! det_zero_nonsquare {
85+
($elem:ty, $shape:expr) => {
86+
let a: Array2<$elem> = Array2::zeros($shape);
87+
assert!(a.det().is_err());
88+
assert!(a.det_into().is_err());
89+
}
90+
}
91+
for &shape in &[(1, 2).into_shape(), (1, 2).f()] {
92+
det_zero_nonsquare!(f64, shape);
93+
det_zero_nonsquare!(f32, shape);
94+
det_zero_nonsquare!(c64, shape);
95+
det_zero_nonsquare!(c32, shape);
96+
}
97+
}
98+
99+
#[test]
100+
fn det() {
101+
macro_rules! det {
102+
($elem:ty, $shape:expr, $rtol:expr) => {
103+
let a: Array2<$elem> = random($shape);
104+
println!("a = \n{:?}", a);
105+
let det = det_naive(a.view());
106+
assert_rclose!(a.factorize().unwrap().det().unwrap(), det, $rtol);
107+
assert_rclose!(a.factorize().unwrap().det_into().unwrap(), det, $rtol);
108+
assert_rclose!(a.det().unwrap(), det, $rtol);
109+
assert_rclose!(a.det_into().unwrap(), det, $rtol);
110+
}
111+
}
112+
for rows in 1..5 {
113+
for &shape in &[(rows, rows).into_shape(), (rows, rows).f()] {
114+
det!(f64, shape, 1e-9);
115+
det!(f32, shape, 1e-4);
116+
det!(c64, shape, 1e-9);
117+
det!(c32, shape, 1e-4);
118+
}
119+
}
120+
}
121+
122+
#[test]
123+
fn det_nonsquare() {
124+
macro_rules! det_nonsquare {
125+
($elem:ty, $shape:expr) => {
126+
let a: Array2<$elem> = random($shape);
127+
assert!(a.factorize().unwrap().det().is_err());
128+
assert!(a.factorize().unwrap().det_into().is_err());
129+
assert!(a.det().is_err());
130+
assert!(a.det_into().is_err());
131+
}
132+
}
133+
for &dims in &[(1, 0), (1, 2), (2, 1), (2, 3)] {
134+
for &shape in &[dims.clone().into_shape(), dims.clone().f()] {
135+
det_nonsquare!(f64, shape);
136+
det_nonsquare!(f32, shape);
137+
det_nonsquare!(c64, shape);
138+
det_nonsquare!(c32, shape);
139+
}
140+
}
141+
}

0 commit comments

Comments
 (0)