Skip to content

Commit e51f966

Browse files
committed
Rewrite test cases
1 parent 9cf8331 commit e51f966

File tree

1 file changed

+41
-139
lines changed

1 file changed

+41
-139
lines changed

ndarray-linalg/tests/least_squares.rs

Lines changed: 41 additions & 139 deletions
Original file line numberDiff line numberDiff line change
@@ -1,161 +1,63 @@
1+
/// Solve least square problem `|b - Ax|`
12
use approx::AbsDiffEq;
23
use ndarray::*;
34
use ndarray_linalg::*;
4-
use num_complex::Complex;
55

6-
fn c(re: f64, im: f64) -> Complex<f64> {
7-
Complex::new(re, im)
8-
}
9-
10-
//
11-
// Test cases taken from the scipy test suite for the scipy lstsq function
12-
// https://github.com/scipy/scipy/blob/v1.4.1/scipy/linalg/tests/basic.py
13-
//
6+
/// A is square. `x = A^{-1} b`, `|b - Ax| = 0`
147
#[test]
158
fn least_squares_exact() {
16-
let a = array![[1., 20.], [-30., 4.]];
17-
let bs = vec![
18-
array![[1., 0.], [0., 1.]],
19-
array![[1.], [0.]],
20-
array![[2., 1.], [-30., 4.]],
21-
];
22-
for b in &bs {
23-
let res = a.least_squares(b).unwrap();
24-
assert_eq!(res.rank, 2);
25-
let b_hat = a.dot(&res.solution);
26-
let rssq = (b - &b_hat).mapv(|x| x.powi(2)).sum_axis(Axis(0));
27-
assert!(res
28-
.residual_sum_of_squares
29-
.unwrap()
30-
.abs_diff_eq(&rssq, 1e-12));
31-
assert!(b_hat.abs_diff_eq(&b, 1e-12));
32-
}
33-
}
9+
let a: Array2<f64> = random((3, 3));
10+
let b: Array1<f64> = random(3);
11+
let result = a.least_squares(&b).unwrap();
12+
// unpack result
13+
let x = result.solution;
14+
let residual_l2_square = result.residual_sum_of_squares.unwrap()[()];
3415

35-
#[test]
36-
fn least_squares_overdetermined() {
37-
let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
38-
let b: Array1<f64> = array![1., 2., 3.];
39-
let res = a.least_squares(&b).unwrap();
40-
assert_eq!(res.rank, 2);
41-
let b_hat = a.dot(&res.solution);
42-
let rssq = (&b - &b_hat).mapv(|x| x.powi(2)).sum();
43-
assert!(res.residual_sum_of_squares.unwrap()[()].abs_diff_eq(&rssq, 1e-12));
44-
assert!(res
45-
.solution
46-
.abs_diff_eq(&array![-0.428571428571429, 0.85714285714285], 1e-12));
47-
}
16+
// must be full-rank
17+
assert_eq!(result.rank, 3);
4818

49-
#[test]
50-
fn least_squares_overdetermined_complex() {
51-
let a: Array2<c64> = array![
52-
[c(1., 2.), c(2., 0.)],
53-
[c(4., 0.), c(5., 0.)],
54-
[c(3., 0.), c(4., 0.)]
55-
];
56-
let b: Array1<c64> = array![c(1., 0.), c(2., 4.), c(3., 0.)];
57-
let res = a.least_squares(&b).unwrap();
58-
assert_eq!(res.rank, 2);
59-
let b_hat = a.dot(&res.solution);
60-
let rssq = (&b_hat - &b).mapv(|x| x.powi(2).abs()).sum();
61-
assert!(res.residual_sum_of_squares.unwrap()[()].abs_diff_eq(&rssq, 1e-12));
62-
assert!(res.solution.abs_diff_eq(
63-
&array![
64-
c(-0.4831460674157303, 0.258426966292135),
65-
c(0.921348314606741, 0.292134831460674)
66-
],
67-
1e-12
68-
));
69-
}
19+
// |b - Ax| == 0
20+
assert!(residual_l2_square < 1.0e-7);
7021

71-
#[test]
72-
fn least_squares_underdetermined() {
73-
let a: Array2<f64> = array![[1., 2., 3.], [4., 5., 6.]];
74-
let b: Array1<f64> = array![1., 2.];
75-
let res = a.least_squares(&b).unwrap();
76-
assert_eq!(res.rank, 2);
77-
assert!(res.residual_sum_of_squares.is_none());
78-
let expected = array![-0.055555555555555, 0.111111111111111, 0.277777777777777];
79-
assert!(res.solution.abs_diff_eq(&expected, 1e-12));
22+
// b == Ax
23+
let ax = a.dot(&x);
24+
assert_close_l2!(&b, &ax, 1.0e-7);
8025
}
8126

82-
/// This test case tests the underdetermined case for multiple right hand
83-
/// sides. Adapted from scipy lstsq tests.
27+
/// #column < #row case.
28+
/// Linear problem is overdetermined, `|b - Ax| > 0`.
8429
#[test]
85-
fn least_squares_underdetermined_nrhs() {
86-
let a: Array2<f64> = array![[1., 2., 3.], [4., 5., 6.]];
87-
let b: Array2<f64> = array![[1., 1.], [2., 2.]];
88-
let res = a.least_squares(&b).unwrap();
89-
assert_eq!(res.rank, 2);
90-
assert!(res.residual_sum_of_squares.is_none());
91-
let expected = array![
92-
[-0.055555555555555, -0.055555555555555],
93-
[0.111111111111111, 0.111111111111111],
94-
[0.277777777777777, 0.277777777777777]
95-
];
96-
assert!(res.solution.abs_diff_eq(&expected, 1e-12));
97-
}
98-
99-
//
100-
// Test cases taken from the netlib documentation at
101-
// https://www.netlib.org/lapack/lapacke.html#_calling_code_dgels_code
102-
//
103-
#[test]
104-
fn netlib_lapack_example_for_dgels_1() {
105-
let a: Array2<f64> = array![
106-
[1., 1., 1.],
107-
[2., 3., 4.],
108-
[3., 5., 2.],
109-
[4., 2., 5.],
110-
[5., 4., 3.]
111-
];
112-
let b: Array1<f64> = array![-10., 12., 14., 16., 18.];
113-
let expected: Array1<f64> = array![2., 1., 1.];
30+
fn least_squares_overdetermined() {
31+
let a: Array2<f64> = random((4, 3));
32+
let b: Array1<f64> = random(4);
11433
let result = a.least_squares(&b).unwrap();
115-
assert!(result.solution.abs_diff_eq(&expected, 1e-12));
34+
// unpack result
35+
let x = result.solution;
36+
let residual_l2_square = result.residual_sum_of_squares.unwrap()[()];
11637

117-
let residual = b - a.dot(&result.solution);
118-
let resid_ssq = result.residual_sum_of_squares.unwrap();
119-
assert!((resid_ssq[()] - residual.dot(&residual)).abs() < 1e-12);
120-
}
38+
// Must be full-rank
39+
assert_eq!(result.rank, 3);
12140

122-
#[test]
123-
fn netlib_lapack_example_for_dgels_2() {
124-
let a: Array2<f64> = array![
125-
[1., 1., 1.],
126-
[2., 3., 4.],
127-
[3., 5., 2.],
128-
[4., 2., 5.],
129-
[5., 4., 3.]
130-
];
131-
let b: Array1<f64> = array![-3., 14., 12., 16., 16.];
132-
let expected: Array1<f64> = array![1., 1., 2.];
133-
let result = a.least_squares(&b).unwrap();
134-
assert!(result.solution.abs_diff_eq(&expected, 1e-12));
41+
// eval `residual = b - Ax`
42+
let residual = &b - &a.dot(&x);
43+
assert!(residual_l2_square.abs_diff_eq(&residual.norm_l2().powi(2), 1e-12));
13544

136-
let residual = b - a.dot(&result.solution);
137-
let resid_ssq = result.residual_sum_of_squares.unwrap();
138-
assert!((resid_ssq[()] - residual.dot(&residual)).abs() < 1e-12);
45+
// `|residual| < |b|`
46+
assert!(residual.norm_l2() < b.norm_l2());
13947
}
14048

49+
/// #column > #row case.
50+
/// Linear problem is underdetermined, `|b - Ax| = 0` and `x` is not unique
14151
#[test]
142-
fn netlib_lapack_example_for_dgels_nrhs() {
143-
let a: Array2<f64> = array![
144-
[1., 1., 1.],
145-
[2., 3., 4.],
146-
[3., 5., 2.],
147-
[4., 2., 5.],
148-
[5., 4., 3.]
149-
];
150-
let b: Array2<f64> = array![[-10., -3.], [12., 14.], [14., 12.], [16., 16.], [18., 16.]];
151-
let expected: Array2<f64> = array![[2., 1.], [1., 1.], [1., 2.]];
52+
fn least_squares_underdetermined() {
53+
let a: Array2<f64> = random((3, 4));
54+
let b: Array1<f64> = random(3);
15255
let result = a.least_squares(&b).unwrap();
153-
assert!(result.solution.abs_diff_eq(&expected, 1e-12));
56+
assert_eq!(result.rank, 3);
57+
assert!(result.residual_sum_of_squares.is_none());
15458

155-
let residual = &b - &a.dot(&result.solution);
156-
let residual_ssq = residual.mapv(|x| x.powi(2)).sum_axis(Axis(0));
157-
assert!(result
158-
.residual_sum_of_squares
159-
.unwrap()
160-
.abs_diff_eq(&residual_ssq, 1e-12));
59+
// b == Ax
60+
let x = result.solution;
61+
let ax = a.dot(&x);
62+
assert_close_l2!(&b, &ax, 1.0e-7);
16163
}

0 commit comments

Comments
 (0)