Skip to content

Commit 4d0d8c3

Browse files
authored
Merge pull request #227 from rust-ndarray/least-square-test
Revise tests for least-square problems
2 parents 9613cfe + f6a9c2a commit 4d0d8c3

File tree

4 files changed

+356
-194
lines changed

4 files changed

+356
-194
lines changed

lax/src/least_squares.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ macro_rules! impl_least_squares {
4242
}
4343
let k = ::std::cmp::min(m, n);
4444
let nrhs = 1;
45+
let ldb = match a_layout {
46+
MatrixLayout::F { .. } => m.max(n),
47+
MatrixLayout::C { .. } => 1,
48+
};
4549
let rcond: Self::Real = -1.;
4650
let mut singular_values: Vec<Self::Real> = vec![Self::Real::zero(); k as usize];
4751
let mut rank: i32 = 0;
@@ -54,9 +58,7 @@ macro_rules! impl_least_squares {
5458
a,
5559
a_layout.lda(),
5660
b,
57-
// this is the 'leading dimension of b', in the case where
58-
// b is a single vector, this is 1
59-
nrhs,
61+
ldb,
6062
&mut singular_values,
6163
rcond,
6264
&mut rank,

ndarray-linalg/src/least_squares.rs

Lines changed: 21 additions & 191 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
//! // `a` and `b` have been moved, no longer valid
6161
//! ```
6262
63-
use ndarray::{s, Array, Array1, Array2, ArrayBase, Axis, Data, DataMut, Dimension, Ix0, Ix1, Ix2};
63+
use ndarray::*;
6464

6565
use crate::error::*;
6666
use crate::lapack::least_squares::*;
@@ -352,7 +352,10 @@ where
352352
// we need a new rhs b/c it will be overwritten with the solution
353353
// for which we need `n` entries
354354
let k = rhs.shape()[1];
355-
let mut new_rhs = Array2::<E>::zeros((n, k));
355+
let mut new_rhs = match self.layout()? {
356+
MatrixLayout::C { .. } => Array2::<E>::zeros((n, k)),
357+
MatrixLayout::F { .. } => Array2::<E>::zeros((n, k).f()),
358+
};
356359
new_rhs.slice_mut(s![0..m, ..]).assign(rhs);
357360
compute_least_squares_nrhs(self, &mut new_rhs)
358361
} else {
@@ -414,117 +417,9 @@ fn compute_residual_array1<E: Scalar, D: Data<Elem = E>>(
414417

415418
#[cfg(test)]
416419
mod tests {
417-
use super::*;
420+
use crate::{error::LinalgError, *};
418421
use approx::AbsDiffEq;
419-
use ndarray::{ArcArray1, ArcArray2, Array1, Array2, CowArray};
420-
use num_complex::Complex;
421-
422-
//
423-
// Test cases taken from the scipy test suite for the scipy lstsq function
424-
// https://github.com/scipy/scipy/blob/v1.4.1/scipy/linalg/tests/test_basic.py
425-
//
426-
#[test]
427-
fn scipy_test_simple_exact() {
428-
let a = array![[1., 20.], [-30., 4.]];
429-
let bs = vec![
430-
array![[1., 0.], [0., 1.]],
431-
array![[1.], [0.]],
432-
array![[2., 1.], [-30., 4.]],
433-
];
434-
for b in &bs {
435-
let res = a.least_squares(b).unwrap();
436-
assert_eq!(res.rank, 2);
437-
let b_hat = a.dot(&res.solution);
438-
let rssq = (b - &b_hat).mapv(|x| x.powi(2)).sum_axis(Axis(0));
439-
assert!(res
440-
.residual_sum_of_squares
441-
.unwrap()
442-
.abs_diff_eq(&rssq, 1e-12));
443-
assert!(b_hat.abs_diff_eq(&b, 1e-12));
444-
}
445-
}
446-
447-
#[test]
448-
fn scipy_test_simple_overdetermined() {
449-
let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
450-
let b: Array1<f64> = array![1., 2., 3.];
451-
let res = a.least_squares(&b).unwrap();
452-
assert_eq!(res.rank, 2);
453-
let b_hat = a.dot(&res.solution);
454-
let rssq = (&b - &b_hat).mapv(|x| x.powi(2)).sum();
455-
assert!(res.residual_sum_of_squares.unwrap()[()].abs_diff_eq(&rssq, 1e-12));
456-
assert!(res
457-
.solution
458-
.abs_diff_eq(&array![-0.428571428571429, 0.85714285714285], 1e-12));
459-
}
460-
461-
#[test]
462-
fn scipy_test_simple_overdetermined_f32() {
463-
let a: Array2<f32> = array![[1., 2.], [4., 5.], [3., 4.]];
464-
let b: Array1<f32> = array![1., 2., 3.];
465-
let res = a.least_squares(&b).unwrap();
466-
assert_eq!(res.rank, 2);
467-
let b_hat = a.dot(&res.solution);
468-
let rssq = (&b - &b_hat).mapv(|x| x.powi(2)).sum();
469-
assert!(res.residual_sum_of_squares.unwrap()[()].abs_diff_eq(&rssq, 1e-6));
470-
assert!(res
471-
.solution
472-
.abs_diff_eq(&array![-0.428571428571429, 0.85714285714285], 1e-6));
473-
}
474-
475-
fn c(re: f64, im: f64) -> Complex<f64> {
476-
Complex::new(re, im)
477-
}
478-
479-
#[test]
480-
fn scipy_test_simple_overdetermined_complex() {
481-
let a: Array2<c64> = array![
482-
[c(1., 2.), c(2., 0.)],
483-
[c(4., 0.), c(5., 0.)],
484-
[c(3., 0.), c(4., 0.)]
485-
];
486-
let b: Array1<c64> = array![c(1., 0.), c(2., 4.), c(3., 0.)];
487-
let res = a.least_squares(&b).unwrap();
488-
assert_eq!(res.rank, 2);
489-
let b_hat = a.dot(&res.solution);
490-
let rssq = (&b_hat - &b).mapv(|x| x.powi(2).abs()).sum();
491-
assert!(res.residual_sum_of_squares.unwrap()[()].abs_diff_eq(&rssq, 1e-12));
492-
assert!(res.solution.abs_diff_eq(
493-
&array![
494-
c(-0.4831460674157303, 0.258426966292135),
495-
c(0.921348314606741, 0.292134831460674)
496-
],
497-
1e-12
498-
));
499-
}
500-
501-
#[test]
502-
fn scipy_test_simple_underdetermined() {
503-
let a: Array2<f64> = array![[1., 2., 3.], [4., 5., 6.]];
504-
let b: Array1<f64> = array![1., 2.];
505-
let res = a.least_squares(&b).unwrap();
506-
assert_eq!(res.rank, 2);
507-
assert!(res.residual_sum_of_squares.is_none());
508-
let expected = array![-0.055555555555555, 0.111111111111111, 0.277777777777777];
509-
assert!(res.solution.abs_diff_eq(&expected, 1e-12));
510-
}
511-
512-
/// This test case tests the underdetermined case for multiple right hand
513-
/// sides. Adapted from scipy lstsq tests.
514-
#[test]
515-
fn scipy_test_simple_underdetermined_nrhs() {
516-
let a: Array2<f64> = array![[1., 2., 3.], [4., 5., 6.]];
517-
let b: Array2<f64> = array![[1., 1.], [2., 2.]];
518-
let res = a.least_squares(&b).unwrap();
519-
assert_eq!(res.rank, 2);
520-
assert!(res.residual_sum_of_squares.is_none());
521-
let expected = array![
522-
[-0.055555555555555, -0.055555555555555],
523-
[0.111111111111111, 0.111111111111111],
524-
[0.277777777777777, 0.277777777777777]
525-
];
526-
assert!(res.solution.abs_diff_eq(&expected, 1e-12));
527-
}
422+
use ndarray::*;
528423

529424
//
530425
// Test that the different lest squares traits work as intended on the
@@ -554,23 +449,23 @@ mod tests {
554449
}
555450

556451
#[test]
557-
fn test_least_squares_on_arc() {
452+
fn on_arc() {
558453
let a: ArcArray2<f64> = array![[1., 2.], [4., 5.], [3., 4.]].into_shared();
559454
let b: ArcArray1<f64> = array![1., 2., 3.].into_shared();
560455
let res = a.least_squares(&b).unwrap();
561456
assert_result(&a, &b, &res);
562457
}
563458

564459
#[test]
565-
fn test_least_squares_on_cow() {
460+
fn on_cow() {
566461
let a = CowArray::from(array![[1., 2.], [4., 5.], [3., 4.]]);
567462
let b = CowArray::from(array![1., 2., 3.]);
568463
let res = a.least_squares(&b).unwrap();
569464
assert_result(&a, &b, &res);
570465
}
571466

572467
#[test]
573-
fn test_least_squares_on_view() {
468+
fn on_view() {
574469
let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
575470
let b: Array1<f64> = array![1., 2., 3.];
576471
let av = a.view();
@@ -580,7 +475,7 @@ mod tests {
580475
}
581476

582477
#[test]
583-
fn test_least_squares_on_view_mut() {
478+
fn on_view_mut() {
584479
let mut a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
585480
let mut b: Array1<f64> = array![1., 2., 3.];
586481
let av = a.view_mut();
@@ -590,7 +485,7 @@ mod tests {
590485
}
591486

592487
#[test]
593-
fn test_least_squares_into_on_owned() {
488+
fn into_on_owned() {
594489
let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
595490
let b: Array1<f64> = array![1., 2., 3.];
596491
let ac = a.clone();
@@ -600,7 +495,7 @@ mod tests {
600495
}
601496

602497
#[test]
603-
fn test_least_squares_into_on_arc() {
498+
fn into_on_arc() {
604499
let a: ArcArray2<f64> = array![[1., 2.], [4., 5.], [3., 4.]].into_shared();
605500
let b: ArcArray1<f64> = array![1., 2., 3.].into_shared();
606501
let a2 = a.clone();
@@ -610,7 +505,7 @@ mod tests {
610505
}
611506

612507
#[test]
613-
fn test_least_squares_into_on_cow() {
508+
fn into_on_cow() {
614509
let a = CowArray::from(array![[1., 2.], [4., 5.], [3., 4.]]);
615510
let b = CowArray::from(array![1., 2., 3.]);
616511
let a2 = a.clone();
@@ -620,7 +515,7 @@ mod tests {
620515
}
621516

622517
#[test]
623-
fn test_least_squares_in_place_on_owned() {
518+
fn in_place_on_owned() {
624519
let a = array![[1., 2.], [4., 5.], [3., 4.]];
625520
let b = array![1., 2., 3.];
626521
let mut a2 = a.clone();
@@ -630,7 +525,7 @@ mod tests {
630525
}
631526

632527
#[test]
633-
fn test_least_squares_in_place_on_cow() {
528+
fn in_place_on_cow() {
634529
let a = CowArray::from(array![[1., 2.], [4., 5.], [3., 4.]]);
635530
let b = CowArray::from(array![1., 2., 3.]);
636531
let mut a2 = a.clone();
@@ -640,7 +535,7 @@ mod tests {
640535
}
641536

642537
#[test]
643-
fn test_least_squares_in_place_on_mut_view() {
538+
fn in_place_on_mut_view() {
644539
let a = array![[1., 2.], [4., 5.], [3., 4.]];
645540
let b = array![1., 2., 3.];
646541
let mut a2 = a.clone();
@@ -651,95 +546,30 @@ mod tests {
651546
assert_result(&a, &b, &res);
652547
}
653548

654-
//
655-
// Test cases taken from the netlib documentation at
656-
// https://www.netlib.org/lapack/lapacke.html#_calling_code_dgels_code
657-
//
658-
#[test]
659-
fn netlib_lapack_example_for_dgels_1() {
660-
let a: Array2<f64> = array![
661-
[1., 1., 1.],
662-
[2., 3., 4.],
663-
[3., 5., 2.],
664-
[4., 2., 5.],
665-
[5., 4., 3.]
666-
];
667-
let b: Array1<f64> = array![-10., 12., 14., 16., 18.];
668-
let expected: Array1<f64> = array![2., 1., 1.];
669-
let result = a.least_squares(&b).unwrap();
670-
assert!(result.solution.abs_diff_eq(&expected, 1e-12));
671-
672-
let residual = b - a.dot(&result.solution);
673-
let resid_ssq = result.residual_sum_of_squares.unwrap();
674-
assert!((resid_ssq[()] - residual.dot(&residual)).abs() < 1e-12);
675-
}
676-
677-
#[test]
678-
fn netlib_lapack_example_for_dgels_2() {
679-
let a: Array2<f64> = array![
680-
[1., 1., 1.],
681-
[2., 3., 4.],
682-
[3., 5., 2.],
683-
[4., 2., 5.],
684-
[5., 4., 3.]
685-
];
686-
let b: Array1<f64> = array![-3., 14., 12., 16., 16.];
687-
let expected: Array1<f64> = array![1., 1., 2.];
688-
let result = a.least_squares(&b).unwrap();
689-
assert!(result.solution.abs_diff_eq(&expected, 1e-12));
690-
691-
let residual = b - a.dot(&result.solution);
692-
let resid_ssq = result.residual_sum_of_squares.unwrap();
693-
assert!((resid_ssq[()] - residual.dot(&residual)).abs() < 1e-12);
694-
}
695-
696-
#[test]
697-
fn netlib_lapack_example_for_dgels_nrhs() {
698-
let a: Array2<f64> = array![
699-
[1., 1., 1.],
700-
[2., 3., 4.],
701-
[3., 5., 2.],
702-
[4., 2., 5.],
703-
[5., 4., 3.]
704-
];
705-
let b: Array2<f64> = array![[-10., -3.], [12., 14.], [14., 12.], [16., 16.], [18., 16.]];
706-
let expected: Array2<f64> = array![[2., 1.], [1., 1.], [1., 2.]];
707-
let result = a.least_squares(&b).unwrap();
708-
assert!(result.solution.abs_diff_eq(&expected, 1e-12));
709-
710-
let residual = &b - &a.dot(&result.solution);
711-
let residual_ssq = residual.mapv(|x| x.powi(2)).sum_axis(Axis(0));
712-
assert!(result
713-
.residual_sum_of_squares
714-
.unwrap()
715-
.abs_diff_eq(&residual_ssq, 1e-12));
716-
}
717-
718549
//
719550
// Testing error cases
720551
//
721-
use crate::layout::MatrixLayout;
722552

723553
#[test]
724-
fn test_incompatible_shape_error_on_mismatching_num_rows() {
554+
fn incompatible_shape_error_on_mismatching_num_rows() {
725555
let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
726556
let b: Array1<f64> = array![1., 2.];
727557
let res = a.least_squares(&b);
728558
match res {
729-
Err(LinalgError::Lapack(err)) if matches!(err, lapack::error::Error::InvalidShape) => {}
559+
Err(LinalgError::Lapack(err)) if matches!(err, lax::error::Error::InvalidShape) => {}
730560
_ => panic!("Expected Err()"),
731561
}
732562
}
733563

734564
#[test]
735-
fn test_incompatible_shape_error_on_mismatching_layout() {
565+
fn incompatible_shape_error_on_mismatching_layout() {
736566
let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
737567
let b = array![[1.], [2.]].t().to_owned();
738568
assert_eq!(b.layout().unwrap(), MatrixLayout::F { col: 2, lda: 1 });
739569

740570
let res = a.least_squares(&b);
741571
match res {
742-
Err(LinalgError::Lapack(err)) if matches!(err, lapack::error::Error::InvalidShape) => {}
572+
Err(LinalgError::Lapack(err)) if matches!(err, lax::error::Error::InvalidShape) => {}
743573
_ => panic!("Expected Err()"),
744574
}
745575
}

0 commit comments

Comments
 (0)