1
1
//! Least squares
2
2
3
- use crate :: { error:: * , layout:: MatrixLayout } ;
3
+ use crate :: { error:: * , layout:: * } ;
4
4
use cauchy:: * ;
5
- use num_traits:: Zero ;
5
+ use num_traits:: { ToPrimitive , Zero } ;
6
6
7
7
/// Result of LeastSquares
8
8
pub struct LeastSquaresOutput < A : Scalar > {
@@ -14,13 +14,13 @@ pub struct LeastSquaresOutput<A: Scalar> {
14
14
15
15
/// Wraps `*gelsd`
16
16
pub trait LeastSquaresSvdDivideConquer_ : Scalar {
17
- unsafe fn least_squares (
17
+ fn least_squares (
18
18
a_layout : MatrixLayout ,
19
19
a : & mut [ Self ] ,
20
20
b : & mut [ Self ] ,
21
21
) -> Result < LeastSquaresOutput < Self > > ;
22
22
23
- unsafe fn least_squares_nrhs (
23
+ fn least_squares_nrhs (
24
24
a_layout : MatrixLayout ,
25
25
a : & mut [ Self ] ,
26
26
b_layout : MatrixLayout ,
@@ -29,81 +29,129 @@ pub trait LeastSquaresSvdDivideConquer_: Scalar {
29
29
}
30
30
31
31
macro_rules! impl_least_squares {
32
- ( $scalar: ty, $gelsd: path) => {
32
+ ( @real, $scalar: ty, $gelsd: path) => {
33
+ impl_least_squares!( @body, $scalar, $gelsd, ) ;
34
+ } ;
35
+ ( @complex, $scalar: ty, $gelsd: path) => {
36
+ impl_least_squares!( @body, $scalar, $gelsd, rwork) ;
37
+ } ;
38
+
39
+ ( @body, $scalar: ty, $gelsd: path, $( $rwork: ident) ,* ) => {
33
40
impl LeastSquaresSvdDivideConquer_ for $scalar {
34
- unsafe fn least_squares(
35
- a_layout : MatrixLayout ,
41
+ fn least_squares(
42
+ l : MatrixLayout ,
36
43
a: & mut [ Self ] ,
37
44
b: & mut [ Self ] ,
38
45
) -> Result <LeastSquaresOutput <Self >> {
39
- let ( m, n) = a_layout. size( ) ;
40
- if ( m as usize ) > b. len( ) || ( n as usize ) > b. len( ) {
41
- return Err ( Error :: InvalidShape ) ;
42
- }
43
- let k = :: std:: cmp:: min( m, n) ;
44
- let nrhs = 1 ;
45
- let ldb = match a_layout {
46
- MatrixLayout :: F { .. } => m. max( n) ,
47
- MatrixLayout :: C { .. } => 1 ,
48
- } ;
49
- let rcond: Self :: Real = -1. ;
50
- let mut singular_values: Vec <Self :: Real > = vec![ Self :: Real :: zero( ) ; k as usize ] ;
51
- let mut rank: i32 = 0 ;
52
-
53
- $gelsd(
54
- a_layout. lapacke_layout( ) ,
55
- m,
56
- n,
57
- nrhs,
58
- a,
59
- a_layout. lda( ) ,
60
- b,
61
- ldb,
62
- & mut singular_values,
63
- rcond,
64
- & mut rank,
65
- )
66
- . as_lapack_result( ) ?;
67
-
68
- Ok ( LeastSquaresOutput {
69
- singular_values,
70
- rank,
71
- } )
46
+ let b_layout = l. resized( b. len( ) as i32 , 1 ) ;
47
+ Self :: least_squares_nrhs( l, a, b_layout, b)
72
48
}
73
49
74
- unsafe fn least_squares_nrhs(
50
+ fn least_squares_nrhs(
75
51
a_layout: MatrixLayout ,
76
52
a: & mut [ Self ] ,
77
53
b_layout: MatrixLayout ,
78
54
b: & mut [ Self ] ,
79
55
) -> Result <LeastSquaresOutput <Self >> {
56
+ // Minimize |b - Ax|_2
57
+ //
58
+ // where
59
+ // A : (m, n)
60
+ // b : (max(m, n), nrhs) // `b` has to store `x` on exit
61
+ // x : (n, nrhs)
80
62
let ( m, n) = a_layout. size( ) ;
81
- if ( m as usize ) > b. len( )
82
- || ( n as usize ) > b. len( )
83
- || a_layout. lapacke_layout( ) != b_layout. lapacke_layout( )
84
- {
85
- return Err ( Error :: InvalidShape ) ;
86
- }
87
- let k = :: std:: cmp:: min( m, n) ;
88
- let nrhs = b_layout. size( ) . 1 ;
63
+ let ( m_, nrhs) = b_layout. size( ) ;
64
+ let k = m. min( n) ;
65
+ assert!( m_ >= m) ;
66
+
67
+ // Transpose if a is C-continuous
68
+ let mut a_t = None ;
69
+ let a_layout = match a_layout {
70
+ MatrixLayout :: C { .. } => {
71
+ a_t = Some ( vec![ Self :: zero( ) ; a. len( ) ] ) ;
72
+ transpose( a_layout, a, a_t. as_mut( ) . unwrap( ) )
73
+ }
74
+ MatrixLayout :: F { .. } => a_layout,
75
+ } ;
76
+
77
+ // Transpose if b is C-continuous
78
+ let mut b_t = None ;
79
+ let b_layout = match b_layout {
80
+ MatrixLayout :: C { .. } => {
81
+ b_t = Some ( vec![ Self :: zero( ) ; b. len( ) ] ) ;
82
+ transpose( b_layout, b, b_t. as_mut( ) . unwrap( ) )
83
+ }
84
+ MatrixLayout :: F { .. } => b_layout,
85
+ } ;
86
+
89
87
let rcond: Self :: Real = -1. ;
90
88
let mut singular_values: Vec <Self :: Real > = vec![ Self :: Real :: zero( ) ; k as usize ] ;
91
89
let mut rank: i32 = 0 ;
92
90
93
- $gelsd(
94
- a_layout. lapacke_layout( ) ,
95
- m,
96
- n,
97
- nrhs,
98
- a,
99
- a_layout. lda( ) ,
100
- b,
101
- b_layout. lda( ) ,
102
- & mut singular_values,
103
- rcond,
104
- & mut rank,
105
- )
106
- . as_lapack_result( ) ?;
91
+ // eval work size
92
+ let mut info = 0 ;
93
+ let mut work_size = [ Self :: zero( ) ] ;
94
+ let mut iwork_size = [ 0 ] ;
95
+ $(
96
+ let mut $rwork = [ Self :: Real :: zero( ) ] ;
97
+ ) *
98
+ unsafe {
99
+ $gelsd(
100
+ m,
101
+ n,
102
+ nrhs,
103
+ a_t. as_mut( ) . map( |v| v. as_mut_slice( ) ) . unwrap_or( a) ,
104
+ a_layout. lda( ) ,
105
+ b_t. as_mut( ) . map( |v| v. as_mut_slice( ) ) . unwrap_or( b) ,
106
+ b_layout. lda( ) ,
107
+ & mut singular_values,
108
+ rcond,
109
+ & mut rank,
110
+ & mut work_size,
111
+ -1 ,
112
+ $( & mut $rwork, ) *
113
+ & mut iwork_size,
114
+ & mut info,
115
+ )
116
+ } ;
117
+ info. as_lapack_result( ) ?;
118
+
119
+ // calc
120
+ let lwork = work_size[ 0 ] . to_usize( ) . unwrap( ) ;
121
+ let mut work = vec![ Self :: zero( ) ; lwork] ;
122
+ let liwork = iwork_size[ 0 ] . to_usize( ) . unwrap( ) ;
123
+ let mut iwork = vec![ 0 ; liwork] ;
124
+ $(
125
+ let lrwork = $rwork[ 0 ] . to_usize( ) . unwrap( ) ;
126
+ let mut $rwork = vec![ Self :: Real :: zero( ) ; lrwork] ;
127
+ ) *
128
+ unsafe {
129
+ $gelsd(
130
+ m,
131
+ n,
132
+ nrhs,
133
+ a_t. as_mut( ) . map( |v| v. as_mut_slice( ) ) . unwrap_or( a) ,
134
+ a_layout. lda( ) ,
135
+ b_t. as_mut( ) . map( |v| v. as_mut_slice( ) ) . unwrap_or( b) ,
136
+ b_layout. lda( ) ,
137
+ & mut singular_values,
138
+ rcond,
139
+ & mut rank,
140
+ & mut work,
141
+ lwork as i32 ,
142
+ $( & mut $rwork, ) *
143
+ & mut iwork,
144
+ & mut info,
145
+ ) ;
146
+ }
147
+ info. as_lapack_result( ) ?;
148
+
149
+ // Skip a_t -> a transpose because A has been destroyed
150
+ // Re-transpose b
151
+ if let Some ( b_t) = b_t {
152
+ transpose( b_layout, & b_t, b) ;
153
+ }
154
+
107
155
Ok ( LeastSquaresOutput {
108
156
singular_values,
109
157
rank,
@@ -113,7 +161,7 @@ macro_rules! impl_least_squares {
113
161
} ;
114
162
}
115
163
116
- impl_least_squares ! ( f64 , lapacke :: dgelsd) ;
117
- impl_least_squares ! ( f32 , lapacke :: sgelsd) ;
118
- impl_least_squares ! ( c64, lapacke :: zgelsd) ;
119
- impl_least_squares ! ( c32, lapacke :: cgelsd) ;
164
+ impl_least_squares ! ( @real , f64 , lapack :: dgelsd) ;
165
+ impl_least_squares ! ( @real , f32 , lapack :: sgelsd) ;
166
+ impl_least_squares ! ( @complex , c64, lapack :: zgelsd) ;
167
+ impl_least_squares ! ( @complex , c32, lapack :: cgelsd) ;
0 commit comments