Skip to content

Commit 87f0cea

Browse files
committed
Support complex case
1 parent 200b9d4 commit 87f0cea

File tree

1 file changed

+22
-92
lines changed

1 file changed

+22
-92
lines changed

lax/src/least_squares.rs

Lines changed: 22 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,15 @@ pub trait LeastSquaresSvdDivideConquer_: Scalar {
2828
) -> Result<LeastSquaresOutput<Self>>;
2929
}
3030

31-
macro_rules! impl_least_squares_real {
32-
($scalar:ty, $gelsd:path) => {
31+
macro_rules! impl_least_squares {
32+
(@real, $scalar:ty, $gelsd:path) => {
33+
impl_least_squares!(@body, $scalar, $gelsd, );
34+
};
35+
(@complex, $scalar:ty, $gelsd:path) => {
36+
impl_least_squares!(@body, $scalar, $gelsd, rwork);
37+
};
38+
39+
(@body, $scalar:ty, $gelsd:path, $($rwork:ident),*) => {
3340
impl LeastSquaresSvdDivideConquer_ for $scalar {
3441
unsafe fn least_squares(
3542
l: MatrixLayout,
@@ -85,6 +92,9 @@ macro_rules! impl_least_squares_real {
8592
let mut info = 0;
8693
let mut work_size = [Self::zero()];
8794
let mut iwork_size = [0];
95+
$(
96+
let mut $rwork = [Self::Real::zero()];
97+
)*
8898
$gelsd(
8999
m,
90100
n,
@@ -98,6 +108,7 @@ macro_rules! impl_least_squares_real {
98108
&mut rank,
99109
&mut work_size,
100110
-1,
111+
$(&mut $rwork,)*
101112
&mut iwork_size,
102113
&mut info,
103114
);
@@ -108,6 +119,10 @@ macro_rules! impl_least_squares_real {
108119
let mut work = vec![Self::zero(); lwork];
109120
let liwork = iwork_size[0].to_usize().unwrap();
110121
let mut iwork = vec![0; liwork];
122+
$(
123+
let lrwork = $rwork[0].to_usize().unwrap();
124+
let mut $rwork = vec![Self::Real::zero(); lrwork];
125+
)*
111126
$gelsd(
112127
m,
113128
n,
@@ -121,6 +136,7 @@ macro_rules! impl_least_squares_real {
121136
&mut rank,
122137
&mut work,
123138
lwork as i32,
139+
$(&mut $rwork,)*
124140
&mut iwork,
125141
&mut info,
126142
);
@@ -141,93 +157,7 @@ macro_rules! impl_least_squares_real {
141157
};
142158
}
143159

144-
impl_least_squares_real!(f64, lapack::dgelsd);
145-
impl_least_squares_real!(f32, lapack::sgelsd);
146-
147-
macro_rules! impl_least_squares {
148-
($scalar:ty, $gelsd:path) => {
149-
impl LeastSquaresSvdDivideConquer_ for $scalar {
150-
unsafe fn least_squares(
151-
a_layout: MatrixLayout,
152-
a: &mut [Self],
153-
b: &mut [Self],
154-
) -> Result<LeastSquaresOutput<Self>> {
155-
let (m, n) = a_layout.size();
156-
if (m as usize) > b.len() || (n as usize) > b.len() {
157-
return Err(Error::InvalidShape);
158-
}
159-
let k = ::std::cmp::min(m, n);
160-
let nrhs = 1;
161-
let ldb = match a_layout {
162-
MatrixLayout::F { .. } => m.max(n),
163-
MatrixLayout::C { .. } => 1,
164-
};
165-
let rcond: Self::Real = -1.;
166-
let mut singular_values: Vec<Self::Real> = vec![Self::Real::zero(); k as usize];
167-
let mut rank: i32 = 0;
168-
169-
$gelsd(
170-
a_layout.lapacke_layout(),
171-
m,
172-
n,
173-
nrhs,
174-
a,
175-
a_layout.lda(),
176-
b,
177-
ldb,
178-
&mut singular_values,
179-
rcond,
180-
&mut rank,
181-
)
182-
.as_lapack_result()?;
183-
184-
Ok(LeastSquaresOutput {
185-
singular_values,
186-
rank,
187-
})
188-
}
189-
190-
unsafe fn least_squares_nrhs(
191-
a_layout: MatrixLayout,
192-
a: &mut [Self],
193-
b_layout: MatrixLayout,
194-
b: &mut [Self],
195-
) -> Result<LeastSquaresOutput<Self>> {
196-
let (m, n) = a_layout.size();
197-
if (m as usize) > b.len()
198-
|| (n as usize) > b.len()
199-
|| a_layout.lapacke_layout() != b_layout.lapacke_layout()
200-
{
201-
return Err(Error::InvalidShape);
202-
}
203-
let k = ::std::cmp::min(m, n);
204-
let nrhs = b_layout.size().1;
205-
let rcond: Self::Real = -1.;
206-
let mut singular_values: Vec<Self::Real> = vec![Self::Real::zero(); k as usize];
207-
let mut rank: i32 = 0;
208-
209-
$gelsd(
210-
a_layout.lapacke_layout(),
211-
m,
212-
n,
213-
nrhs,
214-
a,
215-
a_layout.lda(),
216-
b,
217-
b_layout.lda(),
218-
&mut singular_values,
219-
rcond,
220-
&mut rank,
221-
)
222-
.as_lapack_result()?;
223-
Ok(LeastSquaresOutput {
224-
singular_values,
225-
rank,
226-
})
227-
}
228-
}
229-
};
230-
}
231-
232-
impl_least_squares!(c64, lapacke::zgelsd);
233-
impl_least_squares!(c32, lapacke::cgelsd);
160+
impl_least_squares!(@real, f64, lapack::dgelsd);
161+
impl_least_squares!(@real, f32, lapack::sgelsd);
162+
impl_least_squares!(@complex, c64, lapack::zgelsd);
163+
impl_least_squares!(@complex, c32, lapack::cgelsd);

0 commit comments

Comments
 (0)