|
| 1 | +/// Solve least square problem `|b - Ax|` |
1 | 2 | use approx::AbsDiffEq;
|
2 | 3 | use ndarray::*;
|
3 | 4 | use ndarray_linalg::*;
|
4 |
| -use num_complex::Complex; |
5 | 5 |
|
6 |
| -fn c(re: f64, im: f64) -> Complex<f64> { |
7 |
| - Complex::new(re, im) |
8 |
| -} |
9 |
| - |
10 |
| -// |
11 |
| -// Test cases taken from the scipy test suite for the scipy lstsq function |
12 |
| -// https://github.com/scipy/scipy/blob/v1.4.1/scipy/linalg/tests/basic.py |
13 |
| -// |
| 6 | +/// A is square. `x = A^{-1} b`, `|b - Ax| = 0` |
14 | 7 | #[test]
|
15 | 8 | fn least_squares_exact() {
|
16 |
| - let a = array![[1., 20.], [-30., 4.]]; |
17 |
| - let bs = vec![ |
18 |
| - array![[1., 0.], [0., 1.]], |
19 |
| - array![[1.], [0.]], |
20 |
| - array![[2., 1.], [-30., 4.]], |
21 |
| - ]; |
22 |
| - for b in &bs { |
23 |
| - let res = a.least_squares(b).unwrap(); |
24 |
| - assert_eq!(res.rank, 2); |
25 |
| - let b_hat = a.dot(&res.solution); |
26 |
| - let rssq = (b - &b_hat).mapv(|x| x.powi(2)).sum_axis(Axis(0)); |
27 |
| - assert!(res |
28 |
| - .residual_sum_of_squares |
29 |
| - .unwrap() |
30 |
| - .abs_diff_eq(&rssq, 1e-12)); |
31 |
| - assert!(b_hat.abs_diff_eq(&b, 1e-12)); |
32 |
| - } |
33 |
| -} |
| 9 | + let a: Array2<f64> = random((3, 3)); |
| 10 | + let b: Array1<f64> = random(3); |
| 11 | + let result = a.least_squares(&b).unwrap(); |
| 12 | + // unpack result |
| 13 | + let x = result.solution; |
| 14 | + let residual_l2_square = result.residual_sum_of_squares.unwrap()[()]; |
34 | 15 |
|
35 |
| -#[test] |
36 |
| -fn least_squares_overdetermined() { |
37 |
| - let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]]; |
38 |
| - let b: Array1<f64> = array![1., 2., 3.]; |
39 |
| - let res = a.least_squares(&b).unwrap(); |
40 |
| - assert_eq!(res.rank, 2); |
41 |
| - let b_hat = a.dot(&res.solution); |
42 |
| - let rssq = (&b - &b_hat).mapv(|x| x.powi(2)).sum(); |
43 |
| - assert!(res.residual_sum_of_squares.unwrap()[()].abs_diff_eq(&rssq, 1e-12)); |
44 |
| - assert!(res |
45 |
| - .solution |
46 |
| - .abs_diff_eq(&array![-0.428571428571429, 0.85714285714285], 1e-12)); |
47 |
| -} |
| 16 | + // must be full-rank |
| 17 | + assert_eq!(result.rank, 3); |
48 | 18 |
|
49 |
| -#[test] |
50 |
| -fn least_squares_overdetermined_complex() { |
51 |
| - let a: Array2<c64> = array![ |
52 |
| - [c(1., 2.), c(2., 0.)], |
53 |
| - [c(4., 0.), c(5., 0.)], |
54 |
| - [c(3., 0.), c(4., 0.)] |
55 |
| - ]; |
56 |
| - let b: Array1<c64> = array![c(1., 0.), c(2., 4.), c(3., 0.)]; |
57 |
| - let res = a.least_squares(&b).unwrap(); |
58 |
| - assert_eq!(res.rank, 2); |
59 |
| - let b_hat = a.dot(&res.solution); |
60 |
| - let rssq = (&b_hat - &b).mapv(|x| x.powi(2).abs()).sum(); |
61 |
| - assert!(res.residual_sum_of_squares.unwrap()[()].abs_diff_eq(&rssq, 1e-12)); |
62 |
| - assert!(res.solution.abs_diff_eq( |
63 |
| - &array![ |
64 |
| - c(-0.4831460674157303, 0.258426966292135), |
65 |
| - c(0.921348314606741, 0.292134831460674) |
66 |
| - ], |
67 |
| - 1e-12 |
68 |
| - )); |
69 |
| -} |
| 19 | + // |b - Ax| == 0 |
| 20 | + assert!(residual_l2_square < 1.0e-7); |
70 | 21 |
|
71 |
| -#[test] |
72 |
| -fn least_squares_underdetermined() { |
73 |
| - let a: Array2<f64> = array![[1., 2., 3.], [4., 5., 6.]]; |
74 |
| - let b: Array1<f64> = array![1., 2.]; |
75 |
| - let res = a.least_squares(&b).unwrap(); |
76 |
| - assert_eq!(res.rank, 2); |
77 |
| - assert!(res.residual_sum_of_squares.is_none()); |
78 |
| - let expected = array![-0.055555555555555, 0.111111111111111, 0.277777777777777]; |
79 |
| - assert!(res.solution.abs_diff_eq(&expected, 1e-12)); |
| 22 | + // b == Ax |
| 23 | + let ax = a.dot(&x); |
| 24 | + assert_close_l2!(&b, &ax, 1.0e-7); |
80 | 25 | }
|
81 | 26 |
|
82 |
| -/// This test case tests the underdetermined case for multiple right hand |
83 |
| -/// sides. Adapted from scipy lstsq tests. |
| 27 | +/// #column < #row case. |
| 28 | +/// Linear problem is overdetermined, `|b - Ax| > 0`. |
84 | 29 | #[test]
|
85 |
| -fn least_squares_underdetermined_nrhs() { |
86 |
| - let a: Array2<f64> = array![[1., 2., 3.], [4., 5., 6.]]; |
87 |
| - let b: Array2<f64> = array![[1., 1.], [2., 2.]]; |
88 |
| - let res = a.least_squares(&b).unwrap(); |
89 |
| - assert_eq!(res.rank, 2); |
90 |
| - assert!(res.residual_sum_of_squares.is_none()); |
91 |
| - let expected = array![ |
92 |
| - [-0.055555555555555, -0.055555555555555], |
93 |
| - [0.111111111111111, 0.111111111111111], |
94 |
| - [0.277777777777777, 0.277777777777777] |
95 |
| - ]; |
96 |
| - assert!(res.solution.abs_diff_eq(&expected, 1e-12)); |
97 |
| -} |
98 |
| - |
99 |
| -// |
100 |
| -// Test cases taken from the netlib documentation at |
101 |
| -// https://www.netlib.org/lapack/lapacke.html#_calling_code_dgels_code |
102 |
| -// |
103 |
| -#[test] |
104 |
| -fn netlib_lapack_example_for_dgels_1() { |
105 |
| - let a: Array2<f64> = array![ |
106 |
| - [1., 1., 1.], |
107 |
| - [2., 3., 4.], |
108 |
| - [3., 5., 2.], |
109 |
| - [4., 2., 5.], |
110 |
| - [5., 4., 3.] |
111 |
| - ]; |
112 |
| - let b: Array1<f64> = array![-10., 12., 14., 16., 18.]; |
113 |
| - let expected: Array1<f64> = array![2., 1., 1.]; |
| 30 | +fn least_squares_overdetermined() { |
| 31 | + let a: Array2<f64> = random((4, 3)); |
| 32 | + let b: Array1<f64> = random(4); |
114 | 33 | let result = a.least_squares(&b).unwrap();
|
115 |
| - assert!(result.solution.abs_diff_eq(&expected, 1e-12)); |
| 34 | + // unpack result |
| 35 | + let x = result.solution; |
| 36 | + let residual_l2_square = result.residual_sum_of_squares.unwrap()[()]; |
116 | 37 |
|
117 |
| - let residual = b - a.dot(&result.solution); |
118 |
| - let resid_ssq = result.residual_sum_of_squares.unwrap(); |
119 |
| - assert!((resid_ssq[()] - residual.dot(&residual)).abs() < 1e-12); |
120 |
| -} |
| 38 | + // Must be full-rank |
| 39 | + assert_eq!(result.rank, 3); |
121 | 40 |
|
122 |
| -#[test] |
123 |
| -fn netlib_lapack_example_for_dgels_2() { |
124 |
| - let a: Array2<f64> = array![ |
125 |
| - [1., 1., 1.], |
126 |
| - [2., 3., 4.], |
127 |
| - [3., 5., 2.], |
128 |
| - [4., 2., 5.], |
129 |
| - [5., 4., 3.] |
130 |
| - ]; |
131 |
| - let b: Array1<f64> = array![-3., 14., 12., 16., 16.]; |
132 |
| - let expected: Array1<f64> = array![1., 1., 2.]; |
133 |
| - let result = a.least_squares(&b).unwrap(); |
134 |
| - assert!(result.solution.abs_diff_eq(&expected, 1e-12)); |
| 41 | + // eval `residual = b - Ax` |
| 42 | + let residual = &b - &a.dot(&x); |
| 43 | + assert!(residual_l2_square.abs_diff_eq(&residual.norm_l2().powi(2), 1e-12)); |
135 | 44 |
|
136 |
| - let residual = b - a.dot(&result.solution); |
137 |
| - let resid_ssq = result.residual_sum_of_squares.unwrap(); |
138 |
| - assert!((resid_ssq[()] - residual.dot(&residual)).abs() < 1e-12); |
| 45 | + // `|residual| < |b|` |
| 46 | + assert!(residual.norm_l2() < b.norm_l2()); |
139 | 47 | }
|
140 | 48 |
|
| 49 | +/// #column > #row case. |
| 50 | +/// Linear problem is underdetermined, `|b - Ax| = 0` and `x` is not unique |
141 | 51 | #[test]
|
142 |
| -fn netlib_lapack_example_for_dgels_nrhs() { |
143 |
| - let a: Array2<f64> = array![ |
144 |
| - [1., 1., 1.], |
145 |
| - [2., 3., 4.], |
146 |
| - [3., 5., 2.], |
147 |
| - [4., 2., 5.], |
148 |
| - [5., 4., 3.] |
149 |
| - ]; |
150 |
| - let b: Array2<f64> = array![[-10., -3.], [12., 14.], [14., 12.], [16., 16.], [18., 16.]]; |
151 |
| - let expected: Array2<f64> = array![[2., 1.], [1., 1.], [1., 2.]]; |
| 52 | +fn least_squares_underdetermined() { |
| 53 | + let a: Array2<f64> = random((3, 4)); |
| 54 | + let b: Array1<f64> = random(3); |
152 | 55 | let result = a.least_squares(&b).unwrap();
|
153 |
| - assert!(result.solution.abs_diff_eq(&expected, 1e-12)); |
| 56 | + assert_eq!(result.rank, 3); |
| 57 | + assert!(result.residual_sum_of_squares.is_none()); |
154 | 58 |
|
155 |
| - let residual = &b - &a.dot(&result.solution); |
156 |
| - let residual_ssq = residual.mapv(|x| x.powi(2)).sum_axis(Axis(0)); |
157 |
| - assert!(result |
158 |
| - .residual_sum_of_squares |
159 |
| - .unwrap() |
160 |
| - .abs_diff_eq(&residual_ssq, 1e-12)); |
| 59 | + // b == Ax |
| 60 | + let x = result.solution; |
| 61 | + let ax = a.dot(&x); |
| 62 | + assert_close_l2!(&b, &ax, 1.0e-7); |
161 | 63 | }
|
0 commit comments