@@ -28,8 +28,15 @@ pub trait LeastSquaresSvdDivideConquer_: Scalar {
28
28
) -> Result < LeastSquaresOutput < Self > > ;
29
29
}
30
30
31
- macro_rules! impl_least_squares_real {
32
- ( $scalar: ty, $gelsd: path) => {
31
+ macro_rules! impl_least_squares {
32
+ ( @real, $scalar: ty, $gelsd: path) => {
33
+ impl_least_squares!( @body, $scalar, $gelsd, ) ;
34
+ } ;
35
+ ( @complex, $scalar: ty, $gelsd: path) => {
36
+ impl_least_squares!( @body, $scalar, $gelsd, rwork) ;
37
+ } ;
38
+
39
+ ( @body, $scalar: ty, $gelsd: path, $( $rwork: ident) ,* ) => {
33
40
impl LeastSquaresSvdDivideConquer_ for $scalar {
34
41
unsafe fn least_squares(
35
42
l: MatrixLayout ,
@@ -85,6 +92,9 @@ macro_rules! impl_least_squares_real {
85
92
let mut info = 0 ;
86
93
let mut work_size = [ Self :: zero( ) ] ;
87
94
let mut iwork_size = [ 0 ] ;
95
+ $(
96
+ let mut $rwork = [ Self :: Real :: zero( ) ] ;
97
+ ) *
88
98
$gelsd(
89
99
m,
90
100
n,
@@ -98,6 +108,7 @@ macro_rules! impl_least_squares_real {
98
108
& mut rank,
99
109
& mut work_size,
100
110
-1 ,
111
+ $( & mut $rwork, ) *
101
112
& mut iwork_size,
102
113
& mut info,
103
114
) ;
@@ -108,6 +119,10 @@ macro_rules! impl_least_squares_real {
108
119
let mut work = vec![ Self :: zero( ) ; lwork] ;
109
120
let liwork = iwork_size[ 0 ] . to_usize( ) . unwrap( ) ;
110
121
let mut iwork = vec![ 0 ; liwork] ;
122
+ $(
123
+ let lrwork = $rwork[ 0 ] . to_usize( ) . unwrap( ) ;
124
+ let mut $rwork = vec![ Self :: Real :: zero( ) ; lrwork] ;
125
+ ) *
111
126
$gelsd(
112
127
m,
113
128
n,
@@ -121,6 +136,7 @@ macro_rules! impl_least_squares_real {
121
136
& mut rank,
122
137
& mut work,
123
138
lwork as i32 ,
139
+ $( & mut $rwork, ) *
124
140
& mut iwork,
125
141
& mut info,
126
142
) ;
@@ -141,93 +157,7 @@ macro_rules! impl_least_squares_real {
141
157
} ;
142
158
}
143
159
144
- impl_least_squares_real ! ( f64 , lapack:: dgelsd) ;
145
- impl_least_squares_real ! ( f32 , lapack:: sgelsd) ;
146
-
147
- macro_rules! impl_least_squares {
148
- ( $scalar: ty, $gelsd: path) => {
149
- impl LeastSquaresSvdDivideConquer_ for $scalar {
150
- unsafe fn least_squares(
151
- a_layout: MatrixLayout ,
152
- a: & mut [ Self ] ,
153
- b: & mut [ Self ] ,
154
- ) -> Result <LeastSquaresOutput <Self >> {
155
- let ( m, n) = a_layout. size( ) ;
156
- if ( m as usize ) > b. len( ) || ( n as usize ) > b. len( ) {
157
- return Err ( Error :: InvalidShape ) ;
158
- }
159
- let k = :: std:: cmp:: min( m, n) ;
160
- let nrhs = 1 ;
161
- let ldb = match a_layout {
162
- MatrixLayout :: F { .. } => m. max( n) ,
163
- MatrixLayout :: C { .. } => 1 ,
164
- } ;
165
- let rcond: Self :: Real = -1. ;
166
- let mut singular_values: Vec <Self :: Real > = vec![ Self :: Real :: zero( ) ; k as usize ] ;
167
- let mut rank: i32 = 0 ;
168
-
169
- $gelsd(
170
- a_layout. lapacke_layout( ) ,
171
- m,
172
- n,
173
- nrhs,
174
- a,
175
- a_layout. lda( ) ,
176
- b,
177
- ldb,
178
- & mut singular_values,
179
- rcond,
180
- & mut rank,
181
- )
182
- . as_lapack_result( ) ?;
183
-
184
- Ok ( LeastSquaresOutput {
185
- singular_values,
186
- rank,
187
- } )
188
- }
189
-
190
- unsafe fn least_squares_nrhs(
191
- a_layout: MatrixLayout ,
192
- a: & mut [ Self ] ,
193
- b_layout: MatrixLayout ,
194
- b: & mut [ Self ] ,
195
- ) -> Result <LeastSquaresOutput <Self >> {
196
- let ( m, n) = a_layout. size( ) ;
197
- if ( m as usize ) > b. len( )
198
- || ( n as usize ) > b. len( )
199
- || a_layout. lapacke_layout( ) != b_layout. lapacke_layout( )
200
- {
201
- return Err ( Error :: InvalidShape ) ;
202
- }
203
- let k = :: std:: cmp:: min( m, n) ;
204
- let nrhs = b_layout. size( ) . 1 ;
205
- let rcond: Self :: Real = -1. ;
206
- let mut singular_values: Vec <Self :: Real > = vec![ Self :: Real :: zero( ) ; k as usize ] ;
207
- let mut rank: i32 = 0 ;
208
-
209
- $gelsd(
210
- a_layout. lapacke_layout( ) ,
211
- m,
212
- n,
213
- nrhs,
214
- a,
215
- a_layout. lda( ) ,
216
- b,
217
- b_layout. lda( ) ,
218
- & mut singular_values,
219
- rcond,
220
- & mut rank,
221
- )
222
- . as_lapack_result( ) ?;
223
- Ok ( LeastSquaresOutput {
224
- singular_values,
225
- rank,
226
- } )
227
- }
228
- }
229
- } ;
230
- }
231
-
232
- impl_least_squares ! ( c64, lapacke:: zgelsd) ;
233
- impl_least_squares ! ( c32, lapacke:: cgelsd) ;
160
+ impl_least_squares ! ( @real, f64 , lapack:: dgelsd) ;
161
+ impl_least_squares ! ( @real, f32 , lapack:: sgelsd) ;
162
+ impl_least_squares ! ( @complex, c64, lapack:: zgelsd) ;
163
+ impl_least_squares ! ( @complex, c32, lapack:: cgelsd) ;
0 commit comments