@@ -5,22 +5,70 @@ extern crate num_traits;
5
5
6
6
use ndarray:: * ;
7
7
use ndarray_linalg:: * ;
8
- use num_traits:: Zero ;
8
+ use num_traits:: { One , Zero } ;
9
9
10
- fn det_3x3 < A , S > ( a : ArrayBase < S , Ix2 > ) -> A
10
+ /// Returns the matrix with the specified `row` and `col` removed.
11
+ fn matrix_minor < A , S > ( a : ArrayBase < S , Ix2 > , ( row, col) : ( usize , usize ) ) -> Array2 < A >
11
12
where
12
13
A : Scalar ,
13
14
S : Data < Elem = A > ,
14
15
{
15
- a[ ( 0 , 0 ) ] * a[ ( 1 , 1 ) ] * a[ ( 2 , 2 ) ] + a[ ( 0 , 1 ) ] * a[ ( 1 , 2 ) ] * a[ ( 2 , 0 ) ] + a[ ( 0 , 2 ) ] * a[ ( 1 , 0 ) ] * a[ ( 2 , 1 ) ] -
16
- a[ ( 0 , 2 ) ] * a[ ( 1 , 1 ) ] * a[ ( 2 , 0 ) ] - a[ ( 0 , 1 ) ] * a[ ( 1 , 0 ) ] * a[ ( 2 , 2 ) ] - a[ ( 0 , 0 ) ] * a[ ( 1 , 2 ) ] * a[ ( 2 , 1 ) ]
16
+ let mut select_rows = ( 0 ..a. rows ( ) ) . collect :: < Vec < _ > > ( ) ;
17
+ select_rows. remove ( row) ;
18
+ let mut select_cols = ( 0 ..a. cols ( ) ) . collect :: < Vec < _ > > ( ) ;
19
+ select_cols. remove ( col) ;
20
+ a. select ( Axis ( 0 ) , & select_rows) . select (
21
+ Axis ( 1 ) ,
22
+ & select_cols,
23
+ )
24
+ }
25
+
26
+ /// Computes the determinant of matrix `a`.
27
+ ///
28
+ /// Note: This implementation is written to be clearly correct so that it's
29
+ /// useful for verification, but it's very inefficient.
30
+ fn det_naive < A , S > ( a : ArrayBase < S , Ix2 > ) -> A
31
+ where
32
+ A : Scalar ,
33
+ S : Data < Elem = A > ,
34
+ {
35
+ assert_eq ! ( a. rows( ) , a. cols( ) ) ;
36
+ match a. cols ( ) {
37
+ 0 => A :: one ( ) ,
38
+ 1 => a[ ( 0 , 0 ) ] ,
39
+ cols => {
40
+ ( 0 ..cols)
41
+ . map ( |col| {
42
+ let sign = if col % 2 == 0 { A :: one ( ) } else { -A :: one ( ) } ;
43
+ sign * a[ ( 0 , col) ] * det_naive ( matrix_minor ( a. view ( ) , ( 0 , col) ) )
44
+ } )
45
+ . fold ( A :: zero ( ) , |sum, subdet| sum + subdet)
46
+ }
47
+ }
48
+ }
49
+
50
+ #[ test]
51
+ fn det_empty ( ) {
52
+ macro_rules! det_empty {
53
+ ( $elem: ty) => {
54
+ let a: Array2 <$elem> = Array2 :: zeros( ( 0 , 0 ) ) ;
55
+ assert_eq!( a. factorize( ) . unwrap( ) . det( ) . unwrap( ) , One :: one( ) ) ;
56
+ assert_eq!( a. factorize( ) . unwrap( ) . det_into( ) . unwrap( ) , One :: one( ) ) ;
57
+ assert_eq!( a. det( ) . unwrap( ) , One :: one( ) ) ;
58
+ assert_eq!( a. det_into( ) . unwrap( ) , One :: one( ) ) ;
59
+ }
60
+ }
61
+ det_empty ! ( f64 ) ;
62
+ det_empty ! ( f32 ) ;
63
+ det_empty ! ( c64) ;
64
+ det_empty ! ( c32) ;
17
65
}
18
66
19
67
#[ test]
20
68
fn det_zero ( ) {
21
69
macro_rules! det_zero {
22
70
( $elem: ty) => {
23
- let a: Array2 <$elem> = array! [ [ Zero :: zero ( ) ] ] ;
71
+ let a: Array2 <$elem> = Array2 :: zeros ( ( 1 , 1 ) ) ;
24
72
assert_eq!( a. det( ) . unwrap( ) , Zero :: zero( ) ) ;
25
73
assert_eq!( a. det_into( ) . unwrap( ) , Zero :: zero( ) ) ;
26
74
}
@@ -54,18 +102,20 @@ fn det() {
54
102
( $elem: ty, $shape: expr, $rtol: expr) => {
55
103
let a: Array2 <$elem> = random( $shape) ;
56
104
println!( "a = \n {:?}" , a) ;
57
- let det = det_3x3 ( a. view( ) ) ;
105
+ let det = det_naive ( a. view( ) ) ;
58
106
assert_rclose!( a. factorize( ) . unwrap( ) . det( ) . unwrap( ) , det, $rtol) ;
59
107
assert_rclose!( a. factorize( ) . unwrap( ) . det_into( ) . unwrap( ) , det, $rtol) ;
60
108
assert_rclose!( a. det( ) . unwrap( ) , det, $rtol) ;
61
109
assert_rclose!( a. det_into( ) . unwrap( ) , det, $rtol) ;
62
110
}
63
111
}
64
- for & shape in & [ ( 3 , 3 ) . into_shape ( ) , ( 3 , 3 ) . f ( ) ] {
65
- det ! ( f64 , shape, 1e-9 ) ;
66
- det ! ( f32 , shape, 1e-4 ) ;
67
- det ! ( c64, shape, 1e-9 ) ;
68
- det ! ( c32, shape, 1e-4 ) ;
112
+ for rows in 1 ..5 {
113
+ for & shape in & [ ( rows, rows) . into_shape ( ) , ( rows, rows) . f ( ) ] {
114
+ det ! ( f64 , shape, 1e-9 ) ;
115
+ det ! ( f32 , shape, 1e-4 ) ;
116
+ det ! ( c64, shape, 1e-9 ) ;
117
+ det ! ( c32, shape, 1e-4 ) ;
118
+ }
69
119
}
70
120
}
71
121
@@ -80,10 +130,18 @@ fn det_nonsquare() {
80
130
assert!( a. det_into( ) . is_err( ) ) ;
81
131
}
82
132
}
83
- for & shape in & [ ( 1 , 2 ) . into_shape ( ) , ( 1 , 2 ) . f ( ) , ( 2 , 1 ) . into_shape ( ) , ( 2 , 1 ) . f ( ) ] {
84
- det_nonsquare ! ( f64 , shape) ;
85
- det_nonsquare ! ( f32 , shape) ;
86
- det_nonsquare ! ( c64, shape) ;
87
- det_nonsquare ! ( c32, shape) ;
133
+ for & dims in & [ ( 1 , 0 ) , ( 1 , 2 ) , ( 2 , 1 ) , ( 2 , 3 ) ] {
134
+ // Work around bug in ndarray: https://github.com/bluss/rust-ndarray/issues/361
135
+ let shapes = if dims == ( 1 , 0 ) {
136
+ vec ! [ dims. clone( ) . into_shape( ) ]
137
+ } else {
138
+ vec ! [ dims. clone( ) . into_shape( ) , dims. clone( ) . f( ) ]
139
+ } ;
140
+ for & shape in & shapes {
141
+ det_nonsquare ! ( f64 , shape) ;
142
+ det_nonsquare ! ( f32 , shape) ;
143
+ det_nonsquare ! ( c64, shape) ;
144
+ det_nonsquare ! ( c32, shape) ;
145
+ }
88
146
}
89
147
}
0 commit comments