@@ -4,10 +4,8 @@ use ndarray::*;
4
4
use ndarray_linalg:: * ;
5
5
6
6
/// A is square. `x = A^{-1} b`, `|b - Ax| = 0`
7
- #[ test]
8
- fn least_squares_exact ( ) {
9
- let a: Array2 < f64 > = random ( ( 3 , 3 ) ) ;
10
- let b: Array1 < f64 > = random ( 3 ) ;
7
+ fn test_exact < T : Scalar + Lapack > ( a : Array2 < T > ) {
8
+ let b: Array1 < T > = random ( 3 ) ;
11
9
let result = a. least_squares ( & b) . unwrap ( ) ;
12
10
// unpack result
13
11
let x = result. solution ;
@@ -17,19 +15,43 @@ fn least_squares_exact() {
17
15
assert_eq ! ( result. rank, 3 ) ;
18
16
19
17
// |b - Ax| == 0
20
- assert ! ( residual_l2_square < 1.0e-7 ) ;
18
+ assert ! ( residual_l2_square < T :: real ( 1.0e-4 ) ) ;
21
19
22
20
// b == Ax
23
21
let ax = a. dot ( & x) ;
24
- assert_close_l2 ! ( & b, & ax, 1.0e-7 ) ;
22
+ assert_close_l2 ! ( & b, & ax, T :: real ( 1.0e-4 ) ) ;
25
23
}
26
24
25
+ macro_rules! impl_exact {
26
+ ( $scalar: ty) => {
27
+ paste:: item! {
28
+ #[ test]
29
+ fn [ <least_squares_ $scalar _exact>] ( ) {
30
+ let a: Array2 <f64 > = random( ( 3 , 3 ) ) ;
31
+ test_exact( a)
32
+ }
33
+
34
+ #[ test]
35
+ fn [ <least_squares_ $scalar _exact_t>] ( ) {
36
+ let a: Array2 <f64 > = random( ( 3 , 3 ) . f( ) ) ;
37
+ test_exact( a)
38
+ }
39
+ }
40
+ } ;
41
+ }
42
+
43
+ impl_exact ! ( f32 ) ;
44
+ impl_exact ! ( f64 ) ;
45
+ impl_exact ! ( c32) ;
46
+ impl_exact ! ( c64) ;
47
+
27
48
/// #column < #row case.
28
49
/// Linear problem is overdetermined, `|b - Ax| > 0`.
29
- #[ test]
30
- fn least_squares_overdetermined ( ) {
31
- let a: Array2 < f64 > = random ( ( 4 , 3 ) ) ;
32
- let b: Array1 < f64 > = random ( 4 ) ;
50
+ fn test_overdetermined < T : Scalar + Lapack > ( a : Array2 < T > )
51
+ where
52
+ T :: Real : AbsDiffEq < Epsilon = T :: Real > ,
53
+ {
54
+ let b: Array1 < T > = random ( 4 ) ;
33
55
let result = a. least_squares ( & b) . unwrap ( ) ;
34
56
// unpack result
35
57
let x = result. solution ;
@@ -40,24 +62,68 @@ fn least_squares_overdetermined() {
40
62
41
63
// eval `residual = b - Ax`
42
64
let residual = & b - & a. dot ( & x) ;
43
- assert ! ( residual_l2_square. abs_diff_eq( & residual. norm_l2( ) . powi( 2 ) , 1e-12 ) ) ;
65
+ assert ! ( residual_l2_square. abs_diff_eq( & residual. norm_l2( ) . powi( 2 ) , T :: real ( 1.0e-4 ) ) ) ;
44
66
45
67
// `|residual| < |b|`
46
68
assert ! ( residual. norm_l2( ) < b. norm_l2( ) ) ;
47
69
}
48
70
71
+ macro_rules! impl_overdetermined {
72
+ ( $scalar: ty) => {
73
+ paste:: item! {
74
+ #[ test]
75
+ fn [ <least_squares_ $scalar _overdetermined>] ( ) {
76
+ let a: Array2 <f64 > = random( ( 4 , 3 ) ) ;
77
+ test_overdetermined( a)
78
+ }
79
+
80
+ #[ test]
81
+ fn [ <least_squares_ $scalar _overdetermined_t>] ( ) {
82
+ let a: Array2 <f64 > = random( ( 4 , 3 ) . f( ) ) ;
83
+ test_overdetermined( a)
84
+ }
85
+ }
86
+ } ;
87
+ }
88
+
89
+ impl_overdetermined ! ( f32 ) ;
90
+ impl_overdetermined ! ( f64 ) ;
91
+ impl_overdetermined ! ( c32) ;
92
+ impl_overdetermined ! ( c64) ;
93
+
49
94
/// #column > #row case.
50
95
/// Linear problem is underdetermined, `|b - Ax| = 0` and `x` is not unique
51
- #[ test]
52
- fn least_squares_underdetermined ( ) {
53
- let a: Array2 < f64 > = random ( ( 3 , 4 ) ) ;
54
- let b: Array1 < f64 > = random ( 3 ) ;
96
+ fn test_underdetermined < T : Scalar + Lapack > ( a : Array2 < T > ) {
97
+ let b: Array1 < T > = random ( 3 ) ;
55
98
let result = a. least_squares ( & b) . unwrap ( ) ;
56
99
assert_eq ! ( result. rank, 3 ) ;
57
100
assert ! ( result. residual_sum_of_squares. is_none( ) ) ;
58
101
59
102
// b == Ax
60
103
let x = result. solution ;
61
104
let ax = a. dot ( & x) ;
62
- assert_close_l2 ! ( & b, & ax, 1.0e-7 ) ;
105
+ assert_close_l2 ! ( & b, & ax, T :: real( 1.0e-4 ) ) ;
106
+ }
107
+
108
+ macro_rules! impl_underdetermined {
109
+ ( $scalar: ty) => {
110
+ paste:: item! {
111
+ #[ test]
112
+ fn [ <least_squares_ $scalar _underdetermined>] ( ) {
113
+ let a: Array2 <f64 > = random( ( 3 , 4 ) ) ;
114
+ test_underdetermined( a)
115
+ }
116
+
117
+ #[ test]
118
+ fn [ <least_squares_ $scalar _underdetermined_t>] ( ) {
119
+ let a: Array2 <f64 > = random( ( 3 , 4 ) . f( ) ) ;
120
+ test_underdetermined( a)
121
+ }
122
+ }
123
+ } ;
63
124
}
125
+
126
+ impl_underdetermined ! ( f32 ) ;
127
+ impl_underdetermined ! ( f64 ) ;
128
+ impl_underdetermined ! ( c32) ;
129
+ impl_underdetermined ! ( c64) ;
0 commit comments