@@ -5,137 +5,21 @@ extern crate num_traits;
5
5
6
6
use ndarray:: * ;
7
7
use ndarray_linalg:: * ;
8
- use num_traits:: { One , Zero } ;
9
-
10
- /// Returns the matrix with the specified `row` and `col` removed.
11
- fn matrix_minor < A , S > ( a : ArrayBase < S , Ix2 > , ( row, col) : ( usize , usize ) ) -> Array2 < A >
12
- where
13
- A : Scalar ,
14
- S : Data < Elem = A > ,
15
- {
16
- let mut select_rows = ( 0 ..a. rows ( ) ) . collect :: < Vec < _ > > ( ) ;
17
- select_rows. remove ( row) ;
18
- let mut select_cols = ( 0 ..a. cols ( ) ) . collect :: < Vec < _ > > ( ) ;
19
- select_cols. remove ( col) ;
20
- a. select ( Axis ( 0 ) , & select_rows) . select (
21
- Axis ( 1 ) ,
22
- & select_cols,
23
- )
24
- }
25
-
26
- /// Computes the determinant of matrix `a`.
27
- ///
28
- /// Note: This implementation is written to be clearly correct so that it's
29
- /// useful for verification, but it's very inefficient.
30
- fn det_naive < A , S > ( a : ArrayBase < S , Ix2 > ) -> A
31
- where
32
- A : Scalar ,
33
- S : Data < Elem = A > ,
34
- {
35
- assert_eq ! ( a. rows( ) , a. cols( ) ) ;
36
- match a. cols ( ) {
37
- 0 => A :: one ( ) ,
38
- 1 => a[ ( 0 , 0 ) ] ,
39
- cols => {
40
- ( 0 ..cols)
41
- . map ( |col| {
42
- let sign = if col % 2 == 0 { A :: one ( ) } else { -A :: one ( ) } ;
43
- sign * a[ ( 0 , col) ] * det_naive ( matrix_minor ( a. view ( ) , ( 0 , col) ) )
44
- } )
45
- . fold ( A :: zero ( ) , |sum, subdet| sum + subdet)
46
- }
47
- }
48
- }
49
-
50
- #[ test]
51
- fn det_empty ( ) {
52
- macro_rules! det_empty {
53
- ( $elem: ty) => {
54
- let a: Array2 <$elem> = Array2 :: zeros( ( 0 , 0 ) ) ;
55
- assert_eq!( a. factorize( ) . unwrap( ) . det( ) . unwrap( ) , One :: one( ) ) ;
56
- assert_eq!( a. factorize( ) . unwrap( ) . det_into( ) . unwrap( ) , One :: one( ) ) ;
57
- assert_eq!( a. det( ) . unwrap( ) , One :: one( ) ) ;
58
- assert_eq!( a. det_into( ) . unwrap( ) , One :: one( ) ) ;
59
- }
60
- }
61
- det_empty ! ( f64 ) ;
62
- det_empty ! ( f32 ) ;
63
- det_empty ! ( c64) ;
64
- det_empty ! ( c32) ;
65
- }
66
-
67
- #[ test]
68
- fn det_zero ( ) {
69
- macro_rules! det_zero {
70
- ( $elem: ty) => {
71
- let a: Array2 <$elem> = Array2 :: zeros( ( 1 , 1 ) ) ;
72
- assert_eq!( a. det( ) . unwrap( ) , Zero :: zero( ) ) ;
73
- assert_eq!( a. det_into( ) . unwrap( ) , Zero :: zero( ) ) ;
74
- }
75
- }
76
- det_zero ! ( f64 ) ;
77
- det_zero ! ( f32 ) ;
78
- det_zero ! ( c64) ;
79
- det_zero ! ( c32) ;
80
- }
81
-
82
- #[ test]
83
- fn det_zero_nonsquare ( ) {
84
- macro_rules! det_zero_nonsquare {
85
- ( $elem: ty, $shape: expr) => {
86
- let a: Array2 <$elem> = Array2 :: zeros( $shape) ;
87
- assert!( a. det( ) . is_err( ) ) ;
88
- assert!( a. det_into( ) . is_err( ) ) ;
89
- }
90
- }
91
- for & shape in & [ ( 1 , 2 ) . into_shape ( ) , ( 1 , 2 ) . f ( ) ] {
92
- det_zero_nonsquare ! ( f64 , shape) ;
93
- det_zero_nonsquare ! ( f32 , shape) ;
94
- det_zero_nonsquare ! ( c64, shape) ;
95
- det_zero_nonsquare ! ( c32, shape) ;
96
- }
97
- }
98
8
99
9
#[ test]
100
- fn det ( ) {
101
- macro_rules! det {
102
- ( $elem: ty, $shape: expr, $rtol: expr) => {
103
- let a: Array2 <$elem> = random( $shape) ;
104
- println!( "a = \n {:?}" , a) ;
105
- let det = det_naive( a. view( ) ) ;
106
- assert_rclose!( a. factorize( ) . unwrap( ) . det( ) . unwrap( ) , det, $rtol) ;
107
- assert_rclose!( a. factorize( ) . unwrap( ) . det_into( ) . unwrap( ) , det, $rtol) ;
108
- assert_rclose!( a. det( ) . unwrap( ) , det, $rtol) ;
109
- assert_rclose!( a. det_into( ) . unwrap( ) , det, $rtol) ;
110
- }
111
- }
112
- for rows in 1 ..5 {
113
- for & shape in & [ ( rows, rows) . into_shape ( ) , ( rows, rows) . f ( ) ] {
114
- det ! ( f64 , shape, 1e-9 ) ;
115
- det ! ( f32 , shape, 1e-4 ) ;
116
- det ! ( c64, shape, 1e-9 ) ;
117
- det ! ( c32, shape, 1e-4 ) ;
118
- }
119
- }
10
+ fn solve_random ( ) {
11
+ let a: Array2 < f64 > = random ( ( 3 , 3 ) ) ;
12
+ let x: Array1 < f64 > = random ( 3 ) ;
13
+ let b = a. dot ( & x) ;
14
+ let y = a. solve_into ( b) . unwrap ( ) ;
15
+ assert_close_l2 ! ( & x, & y, 1e-7 ) ;
120
16
}
121
17
122
18
#[ test]
123
- fn det_nonsquare ( ) {
124
- macro_rules! det_nonsquare {
125
- ( $elem: ty, $shape: expr) => {
126
- let a: Array2 <$elem> = random( $shape) ;
127
- assert!( a. factorize( ) . unwrap( ) . det( ) . is_err( ) ) ;
128
- assert!( a. factorize( ) . unwrap( ) . det_into( ) . is_err( ) ) ;
129
- assert!( a. det( ) . is_err( ) ) ;
130
- assert!( a. det_into( ) . is_err( ) ) ;
131
- }
132
- }
133
- for & dims in & [ ( 1 , 0 ) , ( 1 , 2 ) , ( 2 , 1 ) , ( 2 , 3 ) ] {
134
- for & shape in & [ dims. clone ( ) . into_shape ( ) , dims. clone ( ) . f ( ) ] {
135
- det_nonsquare ! ( f64 , shape) ;
136
- det_nonsquare ! ( f32 , shape) ;
137
- det_nonsquare ! ( c64, shape) ;
138
- det_nonsquare ! ( c32, shape) ;
139
- }
140
- }
19
+ fn solve_random_t ( ) {
20
+ let a: Array2 < f64 > = random ( ( 3 , 3 ) . f ( ) ) ;
21
+ let x: Array1 < f64 > = random ( 3 ) ;
22
+ let b = a. dot ( & x) ;
23
+ let y = a. solve_into ( b) . unwrap ( ) ;
24
+ assert_close_l2 ! ( & x, & y, 1e-7 ) ;
141
25
}
0 commit comments