Skip to content

Commit a8a07fb

Browse files
committed
Test for F-continuous and complex cases
1 parent e51f966 commit a8a07fb

File tree

1 file changed

+82
-16
lines changed

1 file changed

+82
-16
lines changed

ndarray-linalg/tests/least_squares.rs

Lines changed: 82 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,8 @@ use ndarray::*;
44
use ndarray_linalg::*;
55

66
/// A is square. `x = A^{-1} b`, `|b - Ax| = 0`
7-
#[test]
8-
fn least_squares_exact() {
9-
let a: Array2<f64> = random((3, 3));
10-
let b: Array1<f64> = random(3);
7+
fn test_exact<T: Scalar + Lapack>(a: Array2<T>) {
8+
let b: Array1<T> = random(3);
119
let result = a.least_squares(&b).unwrap();
1210
// unpack result
1311
let x = result.solution;
@@ -17,19 +15,43 @@ fn least_squares_exact() {
1715
assert_eq!(result.rank, 3);
1816

1917
// |b - Ax| == 0
20-
assert!(residual_l2_square < 1.0e-7);
18+
assert!(residual_l2_square < T::real(1.0e-4));
2119

2220
// b == Ax
2321
let ax = a.dot(&x);
24-
assert_close_l2!(&b, &ax, 1.0e-7);
22+
assert_close_l2!(&b, &ax, T::real(1.0e-4));
2523
}
2624

25+
macro_rules! impl_exact {
26+
($scalar:ty) => {
27+
paste::item! {
28+
#[test]
29+
fn [<least_squares_ $scalar _exact>]() {
30+
let a: Array2<f64> = random((3, 3));
31+
test_exact(a)
32+
}
33+
34+
#[test]
35+
fn [<least_squares_ $scalar _exact_t>]() {
36+
let a: Array2<f64> = random((3, 3).f());
37+
test_exact(a)
38+
}
39+
}
40+
};
41+
}
42+
43+
impl_exact!(f32);
44+
impl_exact!(f64);
45+
impl_exact!(c32);
46+
impl_exact!(c64);
47+
2748
/// #column < #row case.
2849
/// Linear problem is overdetermined, `|b - Ax| > 0`.
29-
#[test]
30-
fn least_squares_overdetermined() {
31-
let a: Array2<f64> = random((4, 3));
32-
let b: Array1<f64> = random(4);
50+
fn test_overdetermined<T: Scalar + Lapack>(a: Array2<T>)
51+
where
52+
T::Real: AbsDiffEq<Epsilon = T::Real>,
53+
{
54+
let b: Array1<T> = random(4);
3355
let result = a.least_squares(&b).unwrap();
3456
// unpack result
3557
let x = result.solution;
@@ -40,24 +62,68 @@ fn least_squares_overdetermined() {
4062

4163
// eval `residual = b - Ax`
4264
let residual = &b - &a.dot(&x);
43-
assert!(residual_l2_square.abs_diff_eq(&residual.norm_l2().powi(2), 1e-12));
65+
assert!(residual_l2_square.abs_diff_eq(&residual.norm_l2().powi(2), T::real(1.0e-4)));
4466

4567
// `|residual| < |b|`
4668
assert!(residual.norm_l2() < b.norm_l2());
4769
}
4870

71+
macro_rules! impl_overdetermined {
72+
($scalar:ty) => {
73+
paste::item! {
74+
#[test]
75+
fn [<least_squares_ $scalar _overdetermined>]() {
76+
let a: Array2<f64> = random((4, 3));
77+
test_overdetermined(a)
78+
}
79+
80+
#[test]
81+
fn [<least_squares_ $scalar _overdetermined_t>]() {
82+
let a: Array2<f64> = random((4, 3).f());
83+
test_overdetermined(a)
84+
}
85+
}
86+
};
87+
}
88+
89+
impl_overdetermined!(f32);
90+
impl_overdetermined!(f64);
91+
impl_overdetermined!(c32);
92+
impl_overdetermined!(c64);
93+
4994
/// #column > #row case.
5095
/// Linear problem is underdetermined, `|b - Ax| = 0` and `x` is not unique
51-
#[test]
52-
fn least_squares_underdetermined() {
53-
let a: Array2<f64> = random((3, 4));
54-
let b: Array1<f64> = random(3);
96+
fn test_underdetermined<T: Scalar + Lapack>(a: Array2<T>) {
97+
let b: Array1<T> = random(3);
5598
let result = a.least_squares(&b).unwrap();
5699
assert_eq!(result.rank, 3);
57100
assert!(result.residual_sum_of_squares.is_none());
58101

59102
// b == Ax
60103
let x = result.solution;
61104
let ax = a.dot(&x);
62-
assert_close_l2!(&b, &ax, 1.0e-7);
105+
assert_close_l2!(&b, &ax, T::real(1.0e-4));
106+
}
107+
108+
macro_rules! impl_underdetermined {
109+
($scalar:ty) => {
110+
paste::item! {
111+
#[test]
112+
fn [<least_squares_ $scalar _underdetermined>]() {
113+
let a: Array2<f64> = random((3, 4));
114+
test_underdetermined(a)
115+
}
116+
117+
#[test]
118+
fn [<least_squares_ $scalar _underdetermined_t>]() {
119+
let a: Array2<f64> = random((3, 4).f());
120+
test_underdetermined(a)
121+
}
122+
}
123+
};
63124
}
125+
126+
impl_underdetermined!(f32);
127+
impl_underdetermined!(f64);
128+
impl_underdetermined!(c32);
129+
impl_underdetermined!(c64);

0 commit comments

Comments
 (0)