@@ -414,117 +414,9 @@ fn compute_residual_array1<E: Scalar, D: Data<Elem = E>>(
414
414
415
415
#[ cfg( test) ]
416
416
mod tests {
417
- use super :: * ;
417
+ use crate :: { error :: LinalgError , * } ;
418
418
use approx:: AbsDiffEq ;
419
- use ndarray:: { ArcArray1 , ArcArray2 , Array1 , Array2 , CowArray } ;
420
- use num_complex:: Complex ;
421
-
422
- //
423
- // Test cases taken from the scipy test suite for the scipy lstsq function
424
- // https://github.com/scipy/scipy/blob/v1.4.1/scipy/linalg/tests/test_basic.py
425
- //
426
- #[ test]
427
- fn scipy_test_simple_exact ( ) {
428
- let a = array ! [ [ 1. , 20. ] , [ -30. , 4. ] ] ;
429
- let bs = vec ! [
430
- array![ [ 1. , 0. ] , [ 0. , 1. ] ] ,
431
- array![ [ 1. ] , [ 0. ] ] ,
432
- array![ [ 2. , 1. ] , [ -30. , 4. ] ] ,
433
- ] ;
434
- for b in & bs {
435
- let res = a. least_squares ( b) . unwrap ( ) ;
436
- assert_eq ! ( res. rank, 2 ) ;
437
- let b_hat = a. dot ( & res. solution ) ;
438
- let rssq = ( b - & b_hat) . mapv ( |x| x. powi ( 2 ) ) . sum_axis ( Axis ( 0 ) ) ;
439
- assert ! ( res
440
- . residual_sum_of_squares
441
- . unwrap( )
442
- . abs_diff_eq( & rssq, 1e-12 ) ) ;
443
- assert ! ( b_hat. abs_diff_eq( & b, 1e-12 ) ) ;
444
- }
445
- }
446
-
447
- #[ test]
448
- fn scipy_test_simple_overdetermined ( ) {
449
- let a: Array2 < f64 > = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
450
- let b: Array1 < f64 > = array ! [ 1. , 2. , 3. ] ;
451
- let res = a. least_squares ( & b) . unwrap ( ) ;
452
- assert_eq ! ( res. rank, 2 ) ;
453
- let b_hat = a. dot ( & res. solution ) ;
454
- let rssq = ( & b - & b_hat) . mapv ( |x| x. powi ( 2 ) ) . sum ( ) ;
455
- assert ! ( res. residual_sum_of_squares. unwrap( ) [ ( ) ] . abs_diff_eq( & rssq, 1e-12 ) ) ;
456
- assert ! ( res
457
- . solution
458
- . abs_diff_eq( & array![ -0.428571428571429 , 0.85714285714285 ] , 1e-12 ) ) ;
459
- }
460
-
461
- #[ test]
462
- fn scipy_test_simple_overdetermined_f32 ( ) {
463
- let a: Array2 < f32 > = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
464
- let b: Array1 < f32 > = array ! [ 1. , 2. , 3. ] ;
465
- let res = a. least_squares ( & b) . unwrap ( ) ;
466
- assert_eq ! ( res. rank, 2 ) ;
467
- let b_hat = a. dot ( & res. solution ) ;
468
- let rssq = ( & b - & b_hat) . mapv ( |x| x. powi ( 2 ) ) . sum ( ) ;
469
- assert ! ( res. residual_sum_of_squares. unwrap( ) [ ( ) ] . abs_diff_eq( & rssq, 1e-6 ) ) ;
470
- assert ! ( res
471
- . solution
472
- . abs_diff_eq( & array![ -0.428571428571429 , 0.85714285714285 ] , 1e-6 ) ) ;
473
- }
474
-
475
- fn c ( re : f64 , im : f64 ) -> Complex < f64 > {
476
- Complex :: new ( re, im)
477
- }
478
-
479
- #[ test]
480
- fn scipy_test_simple_overdetermined_complex ( ) {
481
- let a: Array2 < c64 > = array ! [
482
- [ c( 1. , 2. ) , c( 2. , 0. ) ] ,
483
- [ c( 4. , 0. ) , c( 5. , 0. ) ] ,
484
- [ c( 3. , 0. ) , c( 4. , 0. ) ]
485
- ] ;
486
- let b: Array1 < c64 > = array ! [ c( 1. , 0. ) , c( 2. , 4. ) , c( 3. , 0. ) ] ;
487
- let res = a. least_squares ( & b) . unwrap ( ) ;
488
- assert_eq ! ( res. rank, 2 ) ;
489
- let b_hat = a. dot ( & res. solution ) ;
490
- let rssq = ( & b_hat - & b) . mapv ( |x| x. powi ( 2 ) . abs ( ) ) . sum ( ) ;
491
- assert ! ( res. residual_sum_of_squares. unwrap( ) [ ( ) ] . abs_diff_eq( & rssq, 1e-12 ) ) ;
492
- assert ! ( res. solution. abs_diff_eq(
493
- & array![
494
- c( -0.4831460674157303 , 0.258426966292135 ) ,
495
- c( 0.921348314606741 , 0.292134831460674 )
496
- ] ,
497
- 1e-12
498
- ) ) ;
499
- }
500
-
501
- #[ test]
502
- fn scipy_test_simple_underdetermined ( ) {
503
- let a: Array2 < f64 > = array ! [ [ 1. , 2. , 3. ] , [ 4. , 5. , 6. ] ] ;
504
- let b: Array1 < f64 > = array ! [ 1. , 2. ] ;
505
- let res = a. least_squares ( & b) . unwrap ( ) ;
506
- assert_eq ! ( res. rank, 2 ) ;
507
- assert ! ( res. residual_sum_of_squares. is_none( ) ) ;
508
- let expected = array ! [ -0.055555555555555 , 0.111111111111111 , 0.277777777777777 ] ;
509
- assert ! ( res. solution. abs_diff_eq( & expected, 1e-12 ) ) ;
510
- }
511
-
512
- /// This test case tests the underdetermined case for multiple right hand
513
- /// sides. Adapted from scipy lstsq tests.
514
- #[ test]
515
- fn scipy_test_simple_underdetermined_nrhs ( ) {
516
- let a: Array2 < f64 > = array ! [ [ 1. , 2. , 3. ] , [ 4. , 5. , 6. ] ] ;
517
- let b: Array2 < f64 > = array ! [ [ 1. , 1. ] , [ 2. , 2. ] ] ;
518
- let res = a. least_squares ( & b) . unwrap ( ) ;
519
- assert_eq ! ( res. rank, 2 ) ;
520
- assert ! ( res. residual_sum_of_squares. is_none( ) ) ;
521
- let expected = array ! [
522
- [ -0.055555555555555 , -0.055555555555555 ] ,
523
- [ 0.111111111111111 , 0.111111111111111 ] ,
524
- [ 0.277777777777777 , 0.277777777777777 ]
525
- ] ;
526
- assert ! ( res. solution. abs_diff_eq( & expected, 1e-12 ) ) ;
527
- }
419
+ use ndarray:: * ;
528
420
529
421
//
530
422
// Test that the different lest squares traits work as intended on the
@@ -554,23 +446,23 @@ mod tests {
554
446
}
555
447
556
448
#[ test]
557
- fn test_least_squares_on_arc ( ) {
449
+ fn on_arc ( ) {
558
450
let a: ArcArray2 < f64 > = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] . into_shared ( ) ;
559
451
let b: ArcArray1 < f64 > = array ! [ 1. , 2. , 3. ] . into_shared ( ) ;
560
452
let res = a. least_squares ( & b) . unwrap ( ) ;
561
453
assert_result ( & a, & b, & res) ;
562
454
}
563
455
564
456
#[ test]
565
- fn test_least_squares_on_cow ( ) {
457
+ fn on_cow ( ) {
566
458
let a = CowArray :: from ( array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ) ;
567
459
let b = CowArray :: from ( array ! [ 1. , 2. , 3. ] ) ;
568
460
let res = a. least_squares ( & b) . unwrap ( ) ;
569
461
assert_result ( & a, & b, & res) ;
570
462
}
571
463
572
464
#[ test]
573
- fn test_least_squares_on_view ( ) {
465
+ fn on_view ( ) {
574
466
let a: Array2 < f64 > = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
575
467
let b: Array1 < f64 > = array ! [ 1. , 2. , 3. ] ;
576
468
let av = a. view ( ) ;
@@ -580,7 +472,7 @@ mod tests {
580
472
}
581
473
582
474
#[ test]
583
- fn test_least_squares_on_view_mut ( ) {
475
+ fn on_view_mut ( ) {
584
476
let mut a: Array2 < f64 > = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
585
477
let mut b: Array1 < f64 > = array ! [ 1. , 2. , 3. ] ;
586
478
let av = a. view_mut ( ) ;
@@ -590,7 +482,7 @@ mod tests {
590
482
}
591
483
592
484
#[ test]
593
- fn test_least_squares_into_on_owned ( ) {
485
+ fn into_on_owned ( ) {
594
486
let a: Array2 < f64 > = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
595
487
let b: Array1 < f64 > = array ! [ 1. , 2. , 3. ] ;
596
488
let ac = a. clone ( ) ;
@@ -600,7 +492,7 @@ mod tests {
600
492
}
601
493
602
494
#[ test]
603
- fn test_least_squares_into_on_arc ( ) {
495
+ fn into_on_arc ( ) {
604
496
let a: ArcArray2 < f64 > = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] . into_shared ( ) ;
605
497
let b: ArcArray1 < f64 > = array ! [ 1. , 2. , 3. ] . into_shared ( ) ;
606
498
let a2 = a. clone ( ) ;
@@ -610,7 +502,7 @@ mod tests {
610
502
}
611
503
612
504
#[ test]
613
- fn test_least_squares_into_on_cow ( ) {
505
+ fn into_on_cow ( ) {
614
506
let a = CowArray :: from ( array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ) ;
615
507
let b = CowArray :: from ( array ! [ 1. , 2. , 3. ] ) ;
616
508
let a2 = a. clone ( ) ;
@@ -620,7 +512,7 @@ mod tests {
620
512
}
621
513
622
514
#[ test]
623
- fn test_least_squares_in_place_on_owned ( ) {
515
+ fn in_place_on_owned ( ) {
624
516
let a = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
625
517
let b = array ! [ 1. , 2. , 3. ] ;
626
518
let mut a2 = a. clone ( ) ;
@@ -630,7 +522,7 @@ mod tests {
630
522
}
631
523
632
524
#[ test]
633
- fn test_least_squares_in_place_on_cow ( ) {
525
+ fn in_place_on_cow ( ) {
634
526
let a = CowArray :: from ( array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ) ;
635
527
let b = CowArray :: from ( array ! [ 1. , 2. , 3. ] ) ;
636
528
let mut a2 = a. clone ( ) ;
@@ -640,7 +532,7 @@ mod tests {
640
532
}
641
533
642
534
#[ test]
643
- fn test_least_squares_in_place_on_mut_view ( ) {
535
+ fn in_place_on_mut_view ( ) {
644
536
let a = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
645
537
let b = array ! [ 1. , 2. , 3. ] ;
646
538
let mut a2 = a. clone ( ) ;
@@ -651,95 +543,30 @@ mod tests {
651
543
assert_result ( & a, & b, & res) ;
652
544
}
653
545
654
- //
655
- // Test cases taken from the netlib documentation at
656
- // https://www.netlib.org/lapack/lapacke.html#_calling_code_dgels_code
657
- //
658
- #[ test]
659
- fn netlib_lapack_example_for_dgels_1 ( ) {
660
- let a: Array2 < f64 > = array ! [
661
- [ 1. , 1. , 1. ] ,
662
- [ 2. , 3. , 4. ] ,
663
- [ 3. , 5. , 2. ] ,
664
- [ 4. , 2. , 5. ] ,
665
- [ 5. , 4. , 3. ]
666
- ] ;
667
- let b: Array1 < f64 > = array ! [ -10. , 12. , 14. , 16. , 18. ] ;
668
- let expected: Array1 < f64 > = array ! [ 2. , 1. , 1. ] ;
669
- let result = a. least_squares ( & b) . unwrap ( ) ;
670
- assert ! ( result. solution. abs_diff_eq( & expected, 1e-12 ) ) ;
671
-
672
- let residual = b - a. dot ( & result. solution ) ;
673
- let resid_ssq = result. residual_sum_of_squares . unwrap ( ) ;
674
- assert ! ( ( resid_ssq[ ( ) ] - residual. dot( & residual) ) . abs( ) < 1e-12 ) ;
675
- }
676
-
677
- #[ test]
678
- fn netlib_lapack_example_for_dgels_2 ( ) {
679
- let a: Array2 < f64 > = array ! [
680
- [ 1. , 1. , 1. ] ,
681
- [ 2. , 3. , 4. ] ,
682
- [ 3. , 5. , 2. ] ,
683
- [ 4. , 2. , 5. ] ,
684
- [ 5. , 4. , 3. ]
685
- ] ;
686
- let b: Array1 < f64 > = array ! [ -3. , 14. , 12. , 16. , 16. ] ;
687
- let expected: Array1 < f64 > = array ! [ 1. , 1. , 2. ] ;
688
- let result = a. least_squares ( & b) . unwrap ( ) ;
689
- assert ! ( result. solution. abs_diff_eq( & expected, 1e-12 ) ) ;
690
-
691
- let residual = b - a. dot ( & result. solution ) ;
692
- let resid_ssq = result. residual_sum_of_squares . unwrap ( ) ;
693
- assert ! ( ( resid_ssq[ ( ) ] - residual. dot( & residual) ) . abs( ) < 1e-12 ) ;
694
- }
695
-
696
- #[ test]
697
- fn netlib_lapack_example_for_dgels_nrhs ( ) {
698
- let a: Array2 < f64 > = array ! [
699
- [ 1. , 1. , 1. ] ,
700
- [ 2. , 3. , 4. ] ,
701
- [ 3. , 5. , 2. ] ,
702
- [ 4. , 2. , 5. ] ,
703
- [ 5. , 4. , 3. ]
704
- ] ;
705
- let b: Array2 < f64 > = array ! [ [ -10. , -3. ] , [ 12. , 14. ] , [ 14. , 12. ] , [ 16. , 16. ] , [ 18. , 16. ] ] ;
706
- let expected: Array2 < f64 > = array ! [ [ 2. , 1. ] , [ 1. , 1. ] , [ 1. , 2. ] ] ;
707
- let result = a. least_squares ( & b) . unwrap ( ) ;
708
- assert ! ( result. solution. abs_diff_eq( & expected, 1e-12 ) ) ;
709
-
710
- let residual = & b - & a. dot ( & result. solution ) ;
711
- let residual_ssq = residual. mapv ( |x| x. powi ( 2 ) ) . sum_axis ( Axis ( 0 ) ) ;
712
- assert ! ( result
713
- . residual_sum_of_squares
714
- . unwrap( )
715
- . abs_diff_eq( & residual_ssq, 1e-12 ) ) ;
716
- }
717
-
718
546
//
719
547
// Testing error cases
720
548
//
721
- use crate :: layout:: MatrixLayout ;
722
549
723
550
#[ test]
724
- fn test_incompatible_shape_error_on_mismatching_num_rows ( ) {
551
+ fn incompatible_shape_error_on_mismatching_num_rows ( ) {
725
552
let a: Array2 < f64 > = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
726
553
let b: Array1 < f64 > = array ! [ 1. , 2. ] ;
727
554
let res = a. least_squares ( & b) ;
728
555
match res {
729
- Err ( LinalgError :: Lapack ( err) ) if matches ! ( err, lapack :: error:: Error :: InvalidShape ) => { }
556
+ Err ( LinalgError :: Lapack ( err) ) if matches ! ( err, lax :: error:: Error :: InvalidShape ) => { }
730
557
_ => panic ! ( "Expected Err()" ) ,
731
558
}
732
559
}
733
560
734
561
#[ test]
735
- fn test_incompatible_shape_error_on_mismatching_layout ( ) {
562
+ fn incompatible_shape_error_on_mismatching_layout ( ) {
736
563
let a: Array2 < f64 > = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
737
564
let b = array ! [ [ 1. ] , [ 2. ] ] . t ( ) . to_owned ( ) ;
738
565
assert_eq ! ( b. layout( ) . unwrap( ) , MatrixLayout :: F { col: 2 , lda: 1 } ) ;
739
566
740
567
let res = a. least_squares ( & b) ;
741
568
match res {
742
- Err ( LinalgError :: Lapack ( err) ) if matches ! ( err, lapack :: error:: Error :: InvalidShape ) => { }
569
+ Err ( LinalgError :: Lapack ( err) ) if matches ! ( err, lax :: error:: Error :: InvalidShape ) => { }
743
570
_ => panic ! ( "Expected Err()" ) ,
744
571
}
745
572
}
0 commit comments