Skip to content

Commit 9cf8331

Browse files
committed
Move test of least_squares into tests/, and minor cleanup
1 parent 9613cfe commit 9cf8331

File tree

2 files changed

+177
-189
lines changed

2 files changed

+177
-189
lines changed

ndarray-linalg/src/least_squares.rs

Lines changed: 16 additions & 189 deletions
Original file line numberDiff line numberDiff line change
@@ -414,117 +414,9 @@ fn compute_residual_array1<E: Scalar, D: Data<Elem = E>>(
414414

415415
#[cfg(test)]
416416
mod tests {
417-
use super::*;
417+
use crate::{error::LinalgError, *};
418418
use approx::AbsDiffEq;
419-
use ndarray::{ArcArray1, ArcArray2, Array1, Array2, CowArray};
420-
use num_complex::Complex;
421-
422-
//
423-
// Test cases taken from the scipy test suite for the scipy lstsq function
424-
// https://github.com/scipy/scipy/blob/v1.4.1/scipy/linalg/tests/test_basic.py
425-
//
426-
#[test]
427-
fn scipy_test_simple_exact() {
428-
let a = array![[1., 20.], [-30., 4.]];
429-
let bs = vec![
430-
array![[1., 0.], [0., 1.]],
431-
array![[1.], [0.]],
432-
array![[2., 1.], [-30., 4.]],
433-
];
434-
for b in &bs {
435-
let res = a.least_squares(b).unwrap();
436-
assert_eq!(res.rank, 2);
437-
let b_hat = a.dot(&res.solution);
438-
let rssq = (b - &b_hat).mapv(|x| x.powi(2)).sum_axis(Axis(0));
439-
assert!(res
440-
.residual_sum_of_squares
441-
.unwrap()
442-
.abs_diff_eq(&rssq, 1e-12));
443-
assert!(b_hat.abs_diff_eq(&b, 1e-12));
444-
}
445-
}
446-
447-
#[test]
448-
fn scipy_test_simple_overdetermined() {
449-
let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
450-
let b: Array1<f64> = array![1., 2., 3.];
451-
let res = a.least_squares(&b).unwrap();
452-
assert_eq!(res.rank, 2);
453-
let b_hat = a.dot(&res.solution);
454-
let rssq = (&b - &b_hat).mapv(|x| x.powi(2)).sum();
455-
assert!(res.residual_sum_of_squares.unwrap()[()].abs_diff_eq(&rssq, 1e-12));
456-
assert!(res
457-
.solution
458-
.abs_diff_eq(&array![-0.428571428571429, 0.85714285714285], 1e-12));
459-
}
460-
461-
#[test]
462-
fn scipy_test_simple_overdetermined_f32() {
463-
let a: Array2<f32> = array![[1., 2.], [4., 5.], [3., 4.]];
464-
let b: Array1<f32> = array![1., 2., 3.];
465-
let res = a.least_squares(&b).unwrap();
466-
assert_eq!(res.rank, 2);
467-
let b_hat = a.dot(&res.solution);
468-
let rssq = (&b - &b_hat).mapv(|x| x.powi(2)).sum();
469-
assert!(res.residual_sum_of_squares.unwrap()[()].abs_diff_eq(&rssq, 1e-6));
470-
assert!(res
471-
.solution
472-
.abs_diff_eq(&array![-0.428571428571429, 0.85714285714285], 1e-6));
473-
}
474-
475-
fn c(re: f64, im: f64) -> Complex<f64> {
476-
Complex::new(re, im)
477-
}
478-
479-
#[test]
480-
fn scipy_test_simple_overdetermined_complex() {
481-
let a: Array2<c64> = array![
482-
[c(1., 2.), c(2., 0.)],
483-
[c(4., 0.), c(5., 0.)],
484-
[c(3., 0.), c(4., 0.)]
485-
];
486-
let b: Array1<c64> = array![c(1., 0.), c(2., 4.), c(3., 0.)];
487-
let res = a.least_squares(&b).unwrap();
488-
assert_eq!(res.rank, 2);
489-
let b_hat = a.dot(&res.solution);
490-
let rssq = (&b_hat - &b).mapv(|x| x.powi(2).abs()).sum();
491-
assert!(res.residual_sum_of_squares.unwrap()[()].abs_diff_eq(&rssq, 1e-12));
492-
assert!(res.solution.abs_diff_eq(
493-
&array![
494-
c(-0.4831460674157303, 0.258426966292135),
495-
c(0.921348314606741, 0.292134831460674)
496-
],
497-
1e-12
498-
));
499-
}
500-
501-
#[test]
502-
fn scipy_test_simple_underdetermined() {
503-
let a: Array2<f64> = array![[1., 2., 3.], [4., 5., 6.]];
504-
let b: Array1<f64> = array![1., 2.];
505-
let res = a.least_squares(&b).unwrap();
506-
assert_eq!(res.rank, 2);
507-
assert!(res.residual_sum_of_squares.is_none());
508-
let expected = array![-0.055555555555555, 0.111111111111111, 0.277777777777777];
509-
assert!(res.solution.abs_diff_eq(&expected, 1e-12));
510-
}
511-
512-
/// This test case tests the underdetermined case for multiple right hand
513-
/// sides. Adapted from scipy lstsq tests.
514-
#[test]
515-
fn scipy_test_simple_underdetermined_nrhs() {
516-
let a: Array2<f64> = array![[1., 2., 3.], [4., 5., 6.]];
517-
let b: Array2<f64> = array![[1., 1.], [2., 2.]];
518-
let res = a.least_squares(&b).unwrap();
519-
assert_eq!(res.rank, 2);
520-
assert!(res.residual_sum_of_squares.is_none());
521-
let expected = array![
522-
[-0.055555555555555, -0.055555555555555],
523-
[0.111111111111111, 0.111111111111111],
524-
[0.277777777777777, 0.277777777777777]
525-
];
526-
assert!(res.solution.abs_diff_eq(&expected, 1e-12));
527-
}
419+
use ndarray::*;
528420

529421
//
530422
// Test that the different lest squares traits work as intended on the
@@ -554,23 +446,23 @@ mod tests {
554446
}
555447

556448
#[test]
557-
fn test_least_squares_on_arc() {
449+
fn on_arc() {
558450
let a: ArcArray2<f64> = array![[1., 2.], [4., 5.], [3., 4.]].into_shared();
559451
let b: ArcArray1<f64> = array![1., 2., 3.].into_shared();
560452
let res = a.least_squares(&b).unwrap();
561453
assert_result(&a, &b, &res);
562454
}
563455

564456
#[test]
565-
fn test_least_squares_on_cow() {
457+
fn on_cow() {
566458
let a = CowArray::from(array![[1., 2.], [4., 5.], [3., 4.]]);
567459
let b = CowArray::from(array![1., 2., 3.]);
568460
let res = a.least_squares(&b).unwrap();
569461
assert_result(&a, &b, &res);
570462
}
571463

572464
#[test]
573-
fn test_least_squares_on_view() {
465+
fn on_view() {
574466
let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
575467
let b: Array1<f64> = array![1., 2., 3.];
576468
let av = a.view();
@@ -580,7 +472,7 @@ mod tests {
580472
}
581473

582474
#[test]
583-
fn test_least_squares_on_view_mut() {
475+
fn on_view_mut() {
584476
let mut a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
585477
let mut b: Array1<f64> = array![1., 2., 3.];
586478
let av = a.view_mut();
@@ -590,7 +482,7 @@ mod tests {
590482
}
591483

592484
#[test]
593-
fn test_least_squares_into_on_owned() {
485+
fn into_on_owned() {
594486
let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
595487
let b: Array1<f64> = array![1., 2., 3.];
596488
let ac = a.clone();
@@ -600,7 +492,7 @@ mod tests {
600492
}
601493

602494
#[test]
603-
fn test_least_squares_into_on_arc() {
495+
fn into_on_arc() {
604496
let a: ArcArray2<f64> = array![[1., 2.], [4., 5.], [3., 4.]].into_shared();
605497
let b: ArcArray1<f64> = array![1., 2., 3.].into_shared();
606498
let a2 = a.clone();
@@ -610,7 +502,7 @@ mod tests {
610502
}
611503

612504
#[test]
613-
fn test_least_squares_into_on_cow() {
505+
fn into_on_cow() {
614506
let a = CowArray::from(array![[1., 2.], [4., 5.], [3., 4.]]);
615507
let b = CowArray::from(array![1., 2., 3.]);
616508
let a2 = a.clone();
@@ -620,7 +512,7 @@ mod tests {
620512
}
621513

622514
#[test]
623-
fn test_least_squares_in_place_on_owned() {
515+
fn in_place_on_owned() {
624516
let a = array![[1., 2.], [4., 5.], [3., 4.]];
625517
let b = array![1., 2., 3.];
626518
let mut a2 = a.clone();
@@ -630,7 +522,7 @@ mod tests {
630522
}
631523

632524
#[test]
633-
fn test_least_squares_in_place_on_cow() {
525+
fn in_place_on_cow() {
634526
let a = CowArray::from(array![[1., 2.], [4., 5.], [3., 4.]]);
635527
let b = CowArray::from(array![1., 2., 3.]);
636528
let mut a2 = a.clone();
@@ -640,7 +532,7 @@ mod tests {
640532
}
641533

642534
#[test]
643-
fn test_least_squares_in_place_on_mut_view() {
535+
fn in_place_on_mut_view() {
644536
let a = array![[1., 2.], [4., 5.], [3., 4.]];
645537
let b = array![1., 2., 3.];
646538
let mut a2 = a.clone();
@@ -651,95 +543,30 @@ mod tests {
651543
assert_result(&a, &b, &res);
652544
}
653545

654-
//
655-
// Test cases taken from the netlib documentation at
656-
// https://www.netlib.org/lapack/lapacke.html#_calling_code_dgels_code
657-
//
658-
#[test]
659-
fn netlib_lapack_example_for_dgels_1() {
660-
let a: Array2<f64> = array![
661-
[1., 1., 1.],
662-
[2., 3., 4.],
663-
[3., 5., 2.],
664-
[4., 2., 5.],
665-
[5., 4., 3.]
666-
];
667-
let b: Array1<f64> = array![-10., 12., 14., 16., 18.];
668-
let expected: Array1<f64> = array![2., 1., 1.];
669-
let result = a.least_squares(&b).unwrap();
670-
assert!(result.solution.abs_diff_eq(&expected, 1e-12));
671-
672-
let residual = b - a.dot(&result.solution);
673-
let resid_ssq = result.residual_sum_of_squares.unwrap();
674-
assert!((resid_ssq[()] - residual.dot(&residual)).abs() < 1e-12);
675-
}
676-
677-
#[test]
678-
fn netlib_lapack_example_for_dgels_2() {
679-
let a: Array2<f64> = array![
680-
[1., 1., 1.],
681-
[2., 3., 4.],
682-
[3., 5., 2.],
683-
[4., 2., 5.],
684-
[5., 4., 3.]
685-
];
686-
let b: Array1<f64> = array![-3., 14., 12., 16., 16.];
687-
let expected: Array1<f64> = array![1., 1., 2.];
688-
let result = a.least_squares(&b).unwrap();
689-
assert!(result.solution.abs_diff_eq(&expected, 1e-12));
690-
691-
let residual = b - a.dot(&result.solution);
692-
let resid_ssq = result.residual_sum_of_squares.unwrap();
693-
assert!((resid_ssq[()] - residual.dot(&residual)).abs() < 1e-12);
694-
}
695-
696-
#[test]
697-
fn netlib_lapack_example_for_dgels_nrhs() {
698-
let a: Array2<f64> = array![
699-
[1., 1., 1.],
700-
[2., 3., 4.],
701-
[3., 5., 2.],
702-
[4., 2., 5.],
703-
[5., 4., 3.]
704-
];
705-
let b: Array2<f64> = array![[-10., -3.], [12., 14.], [14., 12.], [16., 16.], [18., 16.]];
706-
let expected: Array2<f64> = array![[2., 1.], [1., 1.], [1., 2.]];
707-
let result = a.least_squares(&b).unwrap();
708-
assert!(result.solution.abs_diff_eq(&expected, 1e-12));
709-
710-
let residual = &b - &a.dot(&result.solution);
711-
let residual_ssq = residual.mapv(|x| x.powi(2)).sum_axis(Axis(0));
712-
assert!(result
713-
.residual_sum_of_squares
714-
.unwrap()
715-
.abs_diff_eq(&residual_ssq, 1e-12));
716-
}
717-
718546
//
719547
// Testing error cases
720548
//
721-
use crate::layout::MatrixLayout;
722549

723550
#[test]
724-
fn test_incompatible_shape_error_on_mismatching_num_rows() {
551+
fn incompatible_shape_error_on_mismatching_num_rows() {
725552
let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
726553
let b: Array1<f64> = array![1., 2.];
727554
let res = a.least_squares(&b);
728555
match res {
729-
Err(LinalgError::Lapack(err)) if matches!(err, lapack::error::Error::InvalidShape) => {}
556+
Err(LinalgError::Lapack(err)) if matches!(err, lax::error::Error::InvalidShape) => {}
730557
_ => panic!("Expected Err()"),
731558
}
732559
}
733560

734561
#[test]
735-
fn test_incompatible_shape_error_on_mismatching_layout() {
562+
fn incompatible_shape_error_on_mismatching_layout() {
736563
let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
737564
let b = array![[1.], [2.]].t().to_owned();
738565
assert_eq!(b.layout().unwrap(), MatrixLayout::F { col: 2, lda: 1 });
739566

740567
let res = a.least_squares(&b);
741568
match res {
742-
Err(LinalgError::Lapack(err)) if matches!(err, lapack::error::Error::InvalidShape) => {}
569+
Err(LinalgError::Lapack(err)) if matches!(err, lax::error::Error::InvalidShape) => {}
743570
_ => panic!("Expected Err()"),
744571
}
745572
}

0 commit comments

Comments
 (0)