Skip to content

Commit 1b77e5c

Browse files
committed
Add more tests for determinant methods
1 parent da4da09 commit 1b77e5c

File tree

1 file changed

+74
-16
lines changed

1 file changed

+74
-16
lines changed

tests/solve.rs

Lines changed: 74 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,70 @@ extern crate num_traits;
55

66
use ndarray::*;
77
use ndarray_linalg::*;
8-
use num_traits::Zero;
8+
use num_traits::{One, Zero};
99

10-
fn det_3x3<A, S>(a: ArrayBase<S, Ix2>) -> A
10+
/// Returns the matrix with the specified `row` and `col` removed.
11+
fn matrix_minor<A, S>(a: ArrayBase<S, Ix2>, (row, col): (usize, usize)) -> Array2<A>
1112
where
1213
A: Scalar,
1314
S: Data<Elem = A>,
1415
{
15-
a[(0, 0)] * a[(1, 1)] * a[(2, 2)] + a[(0, 1)] * a[(1, 2)] * a[(2, 0)] + a[(0, 2)] * a[(1, 0)] * a[(2, 1)] -
16-
a[(0, 2)] * a[(1, 1)] * a[(2, 0)] - a[(0, 1)] * a[(1, 0)] * a[(2, 2)] - a[(0, 0)] * a[(1, 2)] * a[(2, 1)]
16+
let mut select_rows = (0..a.rows()).collect::<Vec<_>>();
17+
select_rows.remove(row);
18+
let mut select_cols = (0..a.cols()).collect::<Vec<_>>();
19+
select_cols.remove(col);
20+
a.select(Axis(0), &select_rows).select(
21+
Axis(1),
22+
&select_cols,
23+
)
24+
}
25+
26+
/// Computes the determinant of matrix `a`.
27+
///
28+
/// Note: This implementation is written to be clearly correct so that it's
29+
/// useful for verification, but it's very inefficient.
30+
fn det_naive<A, S>(a: ArrayBase<S, Ix2>) -> A
31+
where
32+
A: Scalar,
33+
S: Data<Elem = A>,
34+
{
35+
assert_eq!(a.rows(), a.cols());
36+
match a.cols() {
37+
0 => A::one(),
38+
1 => a[(0, 0)],
39+
cols => {
40+
(0..cols)
41+
.map(|col| {
42+
let sign = if col % 2 == 0 { A::one() } else { -A::one() };
43+
sign * a[(0, col)] * det_naive(matrix_minor(a.view(), (0, col)))
44+
})
45+
.fold(A::zero(), |sum, subdet| sum + subdet)
46+
}
47+
}
48+
}
49+
50+
#[test]
51+
fn det_empty() {
52+
macro_rules! det_empty {
53+
($elem:ty) => {
54+
let a: Array2<$elem> = Array2::zeros((0, 0));
55+
assert_eq!(a.factorize().unwrap().det().unwrap(), One::one());
56+
assert_eq!(a.factorize().unwrap().det_into().unwrap(), One::one());
57+
assert_eq!(a.det().unwrap(), One::one());
58+
assert_eq!(a.det_into().unwrap(), One::one());
59+
}
60+
}
61+
det_empty!(f64);
62+
det_empty!(f32);
63+
det_empty!(c64);
64+
det_empty!(c32);
1765
}
1866

1967
#[test]
2068
fn det_zero() {
2169
macro_rules! det_zero {
2270
($elem:ty) => {
23-
let a: Array2<$elem> = array![[Zero::zero()]];
71+
let a: Array2<$elem> = Array2::zeros((1, 1));
2472
assert_eq!(a.det().unwrap(), Zero::zero());
2573
assert_eq!(a.det_into().unwrap(), Zero::zero());
2674
}
@@ -54,18 +102,20 @@ fn det() {
54102
($elem:ty, $shape:expr, $rtol:expr) => {
55103
let a: Array2<$elem> = random($shape);
56104
println!("a = \n{:?}", a);
57-
let det = det_3x3(a.view());
105+
let det = det_naive(a.view());
58106
assert_rclose!(a.factorize().unwrap().det().unwrap(), det, $rtol);
59107
assert_rclose!(a.factorize().unwrap().det_into().unwrap(), det, $rtol);
60108
assert_rclose!(a.det().unwrap(), det, $rtol);
61109
assert_rclose!(a.det_into().unwrap(), det, $rtol);
62110
}
63111
}
64-
for &shape in &[(3, 3).into_shape(), (3, 3).f()] {
65-
det!(f64, shape, 1e-9);
66-
det!(f32, shape, 1e-4);
67-
det!(c64, shape, 1e-9);
68-
det!(c32, shape, 1e-4);
112+
for rows in 1..5 {
113+
for &shape in &[(rows, rows).into_shape(), (rows, rows).f()] {
114+
det!(f64, shape, 1e-9);
115+
det!(f32, shape, 1e-4);
116+
det!(c64, shape, 1e-9);
117+
det!(c32, shape, 1e-4);
118+
}
69119
}
70120
}
71121

@@ -80,10 +130,18 @@ fn det_nonsquare() {
80130
assert!(a.det_into().is_err());
81131
}
82132
}
83-
for &shape in &[(1, 2).into_shape(), (1, 2).f(), (2, 1).into_shape(), (2, 1).f()] {
84-
det_nonsquare!(f64, shape);
85-
det_nonsquare!(f32, shape);
86-
det_nonsquare!(c64, shape);
87-
det_nonsquare!(c32, shape);
133+
for &dims in &[(1, 0), (1, 2), (2, 1), (2, 3)] {
134+
// Work around bug in ndarray: https://github.com/bluss/rust-ndarray/issues/361
135+
let shapes = if dims == (1, 0) {
136+
vec![dims.clone().into_shape()]
137+
} else {
138+
vec![dims.clone().into_shape(), dims.clone().f()]
139+
};
140+
for &shape in &shapes {
141+
det_nonsquare!(f64, shape);
142+
det_nonsquare!(f32, shape);
143+
det_nonsquare!(c64, shape);
144+
det_nonsquare!(c32, shape);
145+
}
88146
}
89147
}

0 commit comments

Comments
 (0)