|
| 1 | +extern crate ndarray; |
| 2 | +#[macro_use] |
| 3 | +extern crate ndarray_linalg; |
| 4 | +extern crate num_traits; |
| 5 | + |
| 6 | +use ndarray::*; |
| 7 | +use ndarray_linalg::*; |
| 8 | +use num_traits::{One, Zero}; |
| 9 | + |
| 10 | +/// Returns the matrix with the specified `row` and `col` removed. |
| 11 | +fn matrix_minor<A, S>(a: ArrayBase<S, Ix2>, (row, col): (usize, usize)) -> Array2<A> |
| 12 | +where |
| 13 | + A: Scalar, |
| 14 | + S: Data<Elem = A>, |
| 15 | +{ |
| 16 | + let mut select_rows = (0..a.rows()).collect::<Vec<_>>(); |
| 17 | + select_rows.remove(row); |
| 18 | + let mut select_cols = (0..a.cols()).collect::<Vec<_>>(); |
| 19 | + select_cols.remove(col); |
| 20 | + a.select(Axis(0), &select_rows).select( |
| 21 | + Axis(1), |
| 22 | + &select_cols, |
| 23 | + ) |
| 24 | +} |
| 25 | + |
| 26 | +/// Computes the determinant of matrix `a`. |
| 27 | +/// |
| 28 | +/// Note: This implementation is written to be clearly correct so that it's |
| 29 | +/// useful for verification, but it's very inefficient. |
| 30 | +fn det_naive<A, S>(a: ArrayBase<S, Ix2>) -> A |
| 31 | +where |
| 32 | + A: Scalar, |
| 33 | + S: Data<Elem = A>, |
| 34 | +{ |
| 35 | + assert_eq!(a.rows(), a.cols()); |
| 36 | + match a.cols() { |
| 37 | + 0 => A::one(), |
| 38 | + 1 => a[(0, 0)], |
| 39 | + cols => { |
| 40 | + (0..cols) |
| 41 | + .map(|col| { |
| 42 | + let sign = if col % 2 == 0 { A::one() } else { -A::one() }; |
| 43 | + sign * a[(0, col)] * det_naive(matrix_minor(a.view(), (0, col))) |
| 44 | + }) |
| 45 | + .fold(A::zero(), |sum, subdet| sum + subdet) |
| 46 | + } |
| 47 | + } |
| 48 | +} |
| 49 | + |
| 50 | +#[test] |
| 51 | +fn det_empty() { |
| 52 | + macro_rules! det_empty { |
| 53 | + ($elem:ty) => { |
| 54 | + let a: Array2<$elem> = Array2::zeros((0, 0)); |
| 55 | + assert_eq!(a.factorize().unwrap().det().unwrap(), One::one()); |
| 56 | + assert_eq!(a.factorize().unwrap().det_into().unwrap(), One::one()); |
| 57 | + assert_eq!(a.det().unwrap(), One::one()); |
| 58 | + assert_eq!(a.det_into().unwrap(), One::one()); |
| 59 | + } |
| 60 | + } |
| 61 | + det_empty!(f64); |
| 62 | + det_empty!(f32); |
| 63 | + det_empty!(c64); |
| 64 | + det_empty!(c32); |
| 65 | +} |
| 66 | + |
| 67 | +#[test] |
| 68 | +fn det_zero() { |
| 69 | + macro_rules! det_zero { |
| 70 | + ($elem:ty) => { |
| 71 | + let a: Array2<$elem> = Array2::zeros((1, 1)); |
| 72 | + assert_eq!(a.det().unwrap(), Zero::zero()); |
| 73 | + assert_eq!(a.det_into().unwrap(), Zero::zero()); |
| 74 | + } |
| 75 | + } |
| 76 | + det_zero!(f64); |
| 77 | + det_zero!(f32); |
| 78 | + det_zero!(c64); |
| 79 | + det_zero!(c32); |
| 80 | +} |
| 81 | + |
| 82 | +#[test] |
| 83 | +fn det_zero_nonsquare() { |
| 84 | + macro_rules! det_zero_nonsquare { |
| 85 | + ($elem:ty, $shape:expr) => { |
| 86 | + let a: Array2<$elem> = Array2::zeros($shape); |
| 87 | + assert!(a.det().is_err()); |
| 88 | + assert!(a.det_into().is_err()); |
| 89 | + } |
| 90 | + } |
| 91 | + for &shape in &[(1, 2).into_shape(), (1, 2).f()] { |
| 92 | + det_zero_nonsquare!(f64, shape); |
| 93 | + det_zero_nonsquare!(f32, shape); |
| 94 | + det_zero_nonsquare!(c64, shape); |
| 95 | + det_zero_nonsquare!(c32, shape); |
| 96 | + } |
| 97 | +} |
| 98 | + |
| 99 | +#[test] |
| 100 | +fn det() { |
| 101 | + macro_rules! det { |
| 102 | + ($elem:ty, $shape:expr, $rtol:expr) => { |
| 103 | + let a: Array2<$elem> = random($shape); |
| 104 | + println!("a = \n{:?}", a); |
| 105 | + let det = det_naive(a.view()); |
| 106 | + assert_rclose!(a.factorize().unwrap().det().unwrap(), det, $rtol); |
| 107 | + assert_rclose!(a.factorize().unwrap().det_into().unwrap(), det, $rtol); |
| 108 | + assert_rclose!(a.det().unwrap(), det, $rtol); |
| 109 | + assert_rclose!(a.det_into().unwrap(), det, $rtol); |
| 110 | + } |
| 111 | + } |
| 112 | + for rows in 1..5 { |
| 113 | + for &shape in &[(rows, rows).into_shape(), (rows, rows).f()] { |
| 114 | + det!(f64, shape, 1e-9); |
| 115 | + det!(f32, shape, 1e-4); |
| 116 | + det!(c64, shape, 1e-9); |
| 117 | + det!(c32, shape, 1e-4); |
| 118 | + } |
| 119 | + } |
| 120 | +} |
| 121 | + |
| 122 | +#[test] |
| 123 | +fn det_nonsquare() { |
| 124 | + macro_rules! det_nonsquare { |
| 125 | + ($elem:ty, $shape:expr) => { |
| 126 | + let a: Array2<$elem> = random($shape); |
| 127 | + assert!(a.factorize().unwrap().det().is_err()); |
| 128 | + assert!(a.factorize().unwrap().det_into().is_err()); |
| 129 | + assert!(a.det().is_err()); |
| 130 | + assert!(a.det_into().is_err()); |
| 131 | + } |
| 132 | + } |
| 133 | + for &dims in &[(1, 0), (1, 2), (2, 1), (2, 3)] { |
| 134 | + for &shape in &[dims.clone().into_shape(), dims.clone().f()] { |
| 135 | + det_nonsquare!(f64, shape); |
| 136 | + det_nonsquare!(f32, shape); |
| 137 | + det_nonsquare!(c64, shape); |
| 138 | + det_nonsquare!(c32, shape); |
| 139 | + } |
| 140 | + } |
| 141 | +} |
0 commit comments