60
60
//! // `a` and `b` have been moved, no longer valid
61
61
//! ```
62
62
63
- use ndarray:: { s , Array , Array1 , Array2 , ArrayBase , Axis , Data , DataMut , Dimension , Ix0 , Ix1 , Ix2 } ;
63
+ use ndarray:: * ;
64
64
65
65
use crate :: error:: * ;
66
66
use crate :: lapack:: least_squares:: * ;
@@ -352,7 +352,10 @@ where
352
352
// we need a new rhs b/c it will be overwritten with the solution
353
353
// for which we need `n` entries
354
354
let k = rhs. shape ( ) [ 1 ] ;
355
- let mut new_rhs = Array2 :: < E > :: zeros ( ( n, k) ) ;
355
+ let mut new_rhs = match self . layout ( ) ? {
356
+ MatrixLayout :: C { .. } => Array2 :: < E > :: zeros ( ( n, k) ) ,
357
+ MatrixLayout :: F { .. } => Array2 :: < E > :: zeros ( ( n, k) . f ( ) ) ,
358
+ } ;
356
359
new_rhs. slice_mut ( s ! [ 0 ..m, ..] ) . assign ( rhs) ;
357
360
compute_least_squares_nrhs ( self , & mut new_rhs)
358
361
} else {
@@ -414,117 +417,9 @@ fn compute_residual_array1<E: Scalar, D: Data<Elem = E>>(
414
417
415
418
#[ cfg( test) ]
416
419
mod tests {
417
- use super :: * ;
420
+ use crate :: { error :: LinalgError , * } ;
418
421
use approx:: AbsDiffEq ;
419
- use ndarray:: { ArcArray1 , ArcArray2 , Array1 , Array2 , CowArray } ;
420
- use num_complex:: Complex ;
421
-
422
- //
423
- // Test cases taken from the scipy test suite for the scipy lstsq function
424
- // https://github.com/scipy/scipy/blob/v1.4.1/scipy/linalg/tests/test_basic.py
425
- //
426
- #[ test]
427
- fn scipy_test_simple_exact ( ) {
428
- let a = array ! [ [ 1. , 20. ] , [ -30. , 4. ] ] ;
429
- let bs = vec ! [
430
- array![ [ 1. , 0. ] , [ 0. , 1. ] ] ,
431
- array![ [ 1. ] , [ 0. ] ] ,
432
- array![ [ 2. , 1. ] , [ -30. , 4. ] ] ,
433
- ] ;
434
- for b in & bs {
435
- let res = a. least_squares ( b) . unwrap ( ) ;
436
- assert_eq ! ( res. rank, 2 ) ;
437
- let b_hat = a. dot ( & res. solution ) ;
438
- let rssq = ( b - & b_hat) . mapv ( |x| x. powi ( 2 ) ) . sum_axis ( Axis ( 0 ) ) ;
439
- assert ! ( res
440
- . residual_sum_of_squares
441
- . unwrap( )
442
- . abs_diff_eq( & rssq, 1e-12 ) ) ;
443
- assert ! ( b_hat. abs_diff_eq( & b, 1e-12 ) ) ;
444
- }
445
- }
446
-
447
- #[ test]
448
- fn scipy_test_simple_overdetermined ( ) {
449
- let a: Array2 < f64 > = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
450
- let b: Array1 < f64 > = array ! [ 1. , 2. , 3. ] ;
451
- let res = a. least_squares ( & b) . unwrap ( ) ;
452
- assert_eq ! ( res. rank, 2 ) ;
453
- let b_hat = a. dot ( & res. solution ) ;
454
- let rssq = ( & b - & b_hat) . mapv ( |x| x. powi ( 2 ) ) . sum ( ) ;
455
- assert ! ( res. residual_sum_of_squares. unwrap( ) [ ( ) ] . abs_diff_eq( & rssq, 1e-12 ) ) ;
456
- assert ! ( res
457
- . solution
458
- . abs_diff_eq( & array![ -0.428571428571429 , 0.85714285714285 ] , 1e-12 ) ) ;
459
- }
460
-
461
- #[ test]
462
- fn scipy_test_simple_overdetermined_f32 ( ) {
463
- let a: Array2 < f32 > = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
464
- let b: Array1 < f32 > = array ! [ 1. , 2. , 3. ] ;
465
- let res = a. least_squares ( & b) . unwrap ( ) ;
466
- assert_eq ! ( res. rank, 2 ) ;
467
- let b_hat = a. dot ( & res. solution ) ;
468
- let rssq = ( & b - & b_hat) . mapv ( |x| x. powi ( 2 ) ) . sum ( ) ;
469
- assert ! ( res. residual_sum_of_squares. unwrap( ) [ ( ) ] . abs_diff_eq( & rssq, 1e-6 ) ) ;
470
- assert ! ( res
471
- . solution
472
- . abs_diff_eq( & array![ -0.428571428571429 , 0.85714285714285 ] , 1e-6 ) ) ;
473
- }
474
-
475
- fn c ( re : f64 , im : f64 ) -> Complex < f64 > {
476
- Complex :: new ( re, im)
477
- }
478
-
479
- #[ test]
480
- fn scipy_test_simple_overdetermined_complex ( ) {
481
- let a: Array2 < c64 > = array ! [
482
- [ c( 1. , 2. ) , c( 2. , 0. ) ] ,
483
- [ c( 4. , 0. ) , c( 5. , 0. ) ] ,
484
- [ c( 3. , 0. ) , c( 4. , 0. ) ]
485
- ] ;
486
- let b: Array1 < c64 > = array ! [ c( 1. , 0. ) , c( 2. , 4. ) , c( 3. , 0. ) ] ;
487
- let res = a. least_squares ( & b) . unwrap ( ) ;
488
- assert_eq ! ( res. rank, 2 ) ;
489
- let b_hat = a. dot ( & res. solution ) ;
490
- let rssq = ( & b_hat - & b) . mapv ( |x| x. powi ( 2 ) . abs ( ) ) . sum ( ) ;
491
- assert ! ( res. residual_sum_of_squares. unwrap( ) [ ( ) ] . abs_diff_eq( & rssq, 1e-12 ) ) ;
492
- assert ! ( res. solution. abs_diff_eq(
493
- & array![
494
- c( -0.4831460674157303 , 0.258426966292135 ) ,
495
- c( 0.921348314606741 , 0.292134831460674 )
496
- ] ,
497
- 1e-12
498
- ) ) ;
499
- }
500
-
501
- #[ test]
502
- fn scipy_test_simple_underdetermined ( ) {
503
- let a: Array2 < f64 > = array ! [ [ 1. , 2. , 3. ] , [ 4. , 5. , 6. ] ] ;
504
- let b: Array1 < f64 > = array ! [ 1. , 2. ] ;
505
- let res = a. least_squares ( & b) . unwrap ( ) ;
506
- assert_eq ! ( res. rank, 2 ) ;
507
- assert ! ( res. residual_sum_of_squares. is_none( ) ) ;
508
- let expected = array ! [ -0.055555555555555 , 0.111111111111111 , 0.277777777777777 ] ;
509
- assert ! ( res. solution. abs_diff_eq( & expected, 1e-12 ) ) ;
510
- }
511
-
512
- /// This test case tests the underdetermined case for multiple right hand
513
- /// sides. Adapted from scipy lstsq tests.
514
- #[ test]
515
- fn scipy_test_simple_underdetermined_nrhs ( ) {
516
- let a: Array2 < f64 > = array ! [ [ 1. , 2. , 3. ] , [ 4. , 5. , 6. ] ] ;
517
- let b: Array2 < f64 > = array ! [ [ 1. , 1. ] , [ 2. , 2. ] ] ;
518
- let res = a. least_squares ( & b) . unwrap ( ) ;
519
- assert_eq ! ( res. rank, 2 ) ;
520
- assert ! ( res. residual_sum_of_squares. is_none( ) ) ;
521
- let expected = array ! [
522
- [ -0.055555555555555 , -0.055555555555555 ] ,
523
- [ 0.111111111111111 , 0.111111111111111 ] ,
524
- [ 0.277777777777777 , 0.277777777777777 ]
525
- ] ;
526
- assert ! ( res. solution. abs_diff_eq( & expected, 1e-12 ) ) ;
527
- }
422
+ use ndarray:: * ;
528
423
529
424
//
530
425
// Test that the different lest squares traits work as intended on the
@@ -554,23 +449,23 @@ mod tests {
554
449
}
555
450
556
451
#[ test]
557
- fn test_least_squares_on_arc ( ) {
452
+ fn on_arc ( ) {
558
453
let a: ArcArray2 < f64 > = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] . into_shared ( ) ;
559
454
let b: ArcArray1 < f64 > = array ! [ 1. , 2. , 3. ] . into_shared ( ) ;
560
455
let res = a. least_squares ( & b) . unwrap ( ) ;
561
456
assert_result ( & a, & b, & res) ;
562
457
}
563
458
564
459
#[ test]
565
- fn test_least_squares_on_cow ( ) {
460
+ fn on_cow ( ) {
566
461
let a = CowArray :: from ( array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ) ;
567
462
let b = CowArray :: from ( array ! [ 1. , 2. , 3. ] ) ;
568
463
let res = a. least_squares ( & b) . unwrap ( ) ;
569
464
assert_result ( & a, & b, & res) ;
570
465
}
571
466
572
467
#[ test]
573
- fn test_least_squares_on_view ( ) {
468
+ fn on_view ( ) {
574
469
let a: Array2 < f64 > = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
575
470
let b: Array1 < f64 > = array ! [ 1. , 2. , 3. ] ;
576
471
let av = a. view ( ) ;
@@ -580,7 +475,7 @@ mod tests {
580
475
}
581
476
582
477
#[ test]
583
- fn test_least_squares_on_view_mut ( ) {
478
+ fn on_view_mut ( ) {
584
479
let mut a: Array2 < f64 > = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
585
480
let mut b: Array1 < f64 > = array ! [ 1. , 2. , 3. ] ;
586
481
let av = a. view_mut ( ) ;
@@ -590,7 +485,7 @@ mod tests {
590
485
}
591
486
592
487
#[ test]
593
- fn test_least_squares_into_on_owned ( ) {
488
+ fn into_on_owned ( ) {
594
489
let a: Array2 < f64 > = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
595
490
let b: Array1 < f64 > = array ! [ 1. , 2. , 3. ] ;
596
491
let ac = a. clone ( ) ;
@@ -600,7 +495,7 @@ mod tests {
600
495
}
601
496
602
497
#[ test]
603
- fn test_least_squares_into_on_arc ( ) {
498
+ fn into_on_arc ( ) {
604
499
let a: ArcArray2 < f64 > = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] . into_shared ( ) ;
605
500
let b: ArcArray1 < f64 > = array ! [ 1. , 2. , 3. ] . into_shared ( ) ;
606
501
let a2 = a. clone ( ) ;
@@ -610,7 +505,7 @@ mod tests {
610
505
}
611
506
612
507
#[ test]
613
- fn test_least_squares_into_on_cow ( ) {
508
+ fn into_on_cow ( ) {
614
509
let a = CowArray :: from ( array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ) ;
615
510
let b = CowArray :: from ( array ! [ 1. , 2. , 3. ] ) ;
616
511
let a2 = a. clone ( ) ;
@@ -620,7 +515,7 @@ mod tests {
620
515
}
621
516
622
517
#[ test]
623
- fn test_least_squares_in_place_on_owned ( ) {
518
+ fn in_place_on_owned ( ) {
624
519
let a = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
625
520
let b = array ! [ 1. , 2. , 3. ] ;
626
521
let mut a2 = a. clone ( ) ;
@@ -630,7 +525,7 @@ mod tests {
630
525
}
631
526
632
527
#[ test]
633
- fn test_least_squares_in_place_on_cow ( ) {
528
+ fn in_place_on_cow ( ) {
634
529
let a = CowArray :: from ( array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ) ;
635
530
let b = CowArray :: from ( array ! [ 1. , 2. , 3. ] ) ;
636
531
let mut a2 = a. clone ( ) ;
@@ -640,7 +535,7 @@ mod tests {
640
535
}
641
536
642
537
#[ test]
643
- fn test_least_squares_in_place_on_mut_view ( ) {
538
+ fn in_place_on_mut_view ( ) {
644
539
let a = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
645
540
let b = array ! [ 1. , 2. , 3. ] ;
646
541
let mut a2 = a. clone ( ) ;
@@ -651,95 +546,30 @@ mod tests {
651
546
assert_result ( & a, & b, & res) ;
652
547
}
653
548
654
- //
655
- // Test cases taken from the netlib documentation at
656
- // https://www.netlib.org/lapack/lapacke.html#_calling_code_dgels_code
657
- //
658
- #[ test]
659
- fn netlib_lapack_example_for_dgels_1 ( ) {
660
- let a: Array2 < f64 > = array ! [
661
- [ 1. , 1. , 1. ] ,
662
- [ 2. , 3. , 4. ] ,
663
- [ 3. , 5. , 2. ] ,
664
- [ 4. , 2. , 5. ] ,
665
- [ 5. , 4. , 3. ]
666
- ] ;
667
- let b: Array1 < f64 > = array ! [ -10. , 12. , 14. , 16. , 18. ] ;
668
- let expected: Array1 < f64 > = array ! [ 2. , 1. , 1. ] ;
669
- let result = a. least_squares ( & b) . unwrap ( ) ;
670
- assert ! ( result. solution. abs_diff_eq( & expected, 1e-12 ) ) ;
671
-
672
- let residual = b - a. dot ( & result. solution ) ;
673
- let resid_ssq = result. residual_sum_of_squares . unwrap ( ) ;
674
- assert ! ( ( resid_ssq[ ( ) ] - residual. dot( & residual) ) . abs( ) < 1e-12 ) ;
675
- }
676
-
677
- #[ test]
678
- fn netlib_lapack_example_for_dgels_2 ( ) {
679
- let a: Array2 < f64 > = array ! [
680
- [ 1. , 1. , 1. ] ,
681
- [ 2. , 3. , 4. ] ,
682
- [ 3. , 5. , 2. ] ,
683
- [ 4. , 2. , 5. ] ,
684
- [ 5. , 4. , 3. ]
685
- ] ;
686
- let b: Array1 < f64 > = array ! [ -3. , 14. , 12. , 16. , 16. ] ;
687
- let expected: Array1 < f64 > = array ! [ 1. , 1. , 2. ] ;
688
- let result = a. least_squares ( & b) . unwrap ( ) ;
689
- assert ! ( result. solution. abs_diff_eq( & expected, 1e-12 ) ) ;
690
-
691
- let residual = b - a. dot ( & result. solution ) ;
692
- let resid_ssq = result. residual_sum_of_squares . unwrap ( ) ;
693
- assert ! ( ( resid_ssq[ ( ) ] - residual. dot( & residual) ) . abs( ) < 1e-12 ) ;
694
- }
695
-
696
- #[ test]
697
- fn netlib_lapack_example_for_dgels_nrhs ( ) {
698
- let a: Array2 < f64 > = array ! [
699
- [ 1. , 1. , 1. ] ,
700
- [ 2. , 3. , 4. ] ,
701
- [ 3. , 5. , 2. ] ,
702
- [ 4. , 2. , 5. ] ,
703
- [ 5. , 4. , 3. ]
704
- ] ;
705
- let b: Array2 < f64 > = array ! [ [ -10. , -3. ] , [ 12. , 14. ] , [ 14. , 12. ] , [ 16. , 16. ] , [ 18. , 16. ] ] ;
706
- let expected: Array2 < f64 > = array ! [ [ 2. , 1. ] , [ 1. , 1. ] , [ 1. , 2. ] ] ;
707
- let result = a. least_squares ( & b) . unwrap ( ) ;
708
- assert ! ( result. solution. abs_diff_eq( & expected, 1e-12 ) ) ;
709
-
710
- let residual = & b - & a. dot ( & result. solution ) ;
711
- let residual_ssq = residual. mapv ( |x| x. powi ( 2 ) ) . sum_axis ( Axis ( 0 ) ) ;
712
- assert ! ( result
713
- . residual_sum_of_squares
714
- . unwrap( )
715
- . abs_diff_eq( & residual_ssq, 1e-12 ) ) ;
716
- }
717
-
718
549
//
719
550
// Testing error cases
720
551
//
721
- use crate :: layout:: MatrixLayout ;
722
552
723
553
#[ test]
724
- fn test_incompatible_shape_error_on_mismatching_num_rows ( ) {
554
+ fn incompatible_shape_error_on_mismatching_num_rows ( ) {
725
555
let a: Array2 < f64 > = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
726
556
let b: Array1 < f64 > = array ! [ 1. , 2. ] ;
727
557
let res = a. least_squares ( & b) ;
728
558
match res {
729
- Err ( LinalgError :: Lapack ( err) ) if matches ! ( err, lapack :: error:: Error :: InvalidShape ) => { }
559
+ Err ( LinalgError :: Lapack ( err) ) if matches ! ( err, lax :: error:: Error :: InvalidShape ) => { }
730
560
_ => panic ! ( "Expected Err()" ) ,
731
561
}
732
562
}
733
563
734
564
#[ test]
735
- fn test_incompatible_shape_error_on_mismatching_layout ( ) {
565
+ fn incompatible_shape_error_on_mismatching_layout ( ) {
736
566
let a: Array2 < f64 > = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
737
567
let b = array ! [ [ 1. ] , [ 2. ] ] . t ( ) . to_owned ( ) ;
738
568
assert_eq ! ( b. layout( ) . unwrap( ) , MatrixLayout :: F { col: 2 , lda: 1 } ) ;
739
569
740
570
let res = a. least_squares ( & b) ;
741
571
match res {
742
- Err ( LinalgError :: Lapack ( err) ) if matches ! ( err, lapack :: error:: Error :: InvalidShape ) => { }
572
+ Err ( LinalgError :: Lapack ( err) ) if matches ! ( err, lax :: error:: Error :: InvalidShape ) => { }
743
573
_ => panic ! ( "Expected Err()" ) ,
744
574
}
745
575
}
0 commit comments