Skip to content

Commit 373a18a

Browse files
committed
Impl tridiagonal by LAPACK
1 parent e9c3481 commit 373a18a

File tree

1 file changed

+30
-22
lines changed

1 file changed

+30
-22
lines changed

lax/src/tridiagonal.rs

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,13 @@ pub trait Tridiagonal_: Scalar + Sized {
143143
}
144144

145145
macro_rules! impl_tridiagonal {
146-
($scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path) => {
146+
(@real, $scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path) => {
147+
impl_tridiagonal!(@body, $scalar, $gttrf, $gtcon, $gttrs, iwork);
148+
};
149+
(@complex, $scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path) => {
150+
impl_tridiagonal!(@body, $scalar, $gttrf, $gtcon, $gttrs, );
151+
};
152+
(@body, $scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path, $($iwork:ident)*) => {
147153
impl Tridiagonal_ for $scalar {
148154
unsafe fn lu_tridiagonal(
149155
mut a: Tridiagonal<Self>,
@@ -153,8 +159,11 @@ macro_rules! impl_tridiagonal {
153159
let mut ipiv = vec![0; n as usize];
154160
// We have to calc one-norm before LU factorization
155161
let a_opnorm_one = a.opnorm_one();
156-
$gttrf(n, &mut a.dl, &mut a.d, &mut a.du, &mut du2, &mut ipiv)
157-
.as_lapack_result()?;
162+
let mut info = 0;
163+
$gttrf(
164+
n, &mut a.dl, &mut a.d, &mut a.du, &mut du2, &mut ipiv, &mut info,
165+
);
166+
info.as_lapack_result()?;
158167
Ok(LUFactorizedTridiagonal {
159168
a,
160169
du2,
@@ -166,7 +175,12 @@ macro_rules! impl_tridiagonal {
166175
unsafe fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal<Self>) -> Result<Self::Real> {
167176
let (n, _) = lu.a.l.size();
168177
let ipiv = &lu.ipiv;
178+
let mut work = vec![Self::zero(); 2 * n as usize];
179+
$(
180+
let mut $iwork = vec![0; n as usize];
181+
)*
169182
let mut rcond = Self::Real::zero();
183+
let mut info = 0;
170184
$gtcon(
171185
NormType::One as u8,
172186
n,
@@ -177,8 +191,11 @@ macro_rules! impl_tridiagonal {
177191
ipiv,
178192
lu.a_opnorm_one,
179193
&mut rcond,
180-
)
181-
.as_lapack_result()?;
194+
&mut work,
195+
$(&mut $iwork,)*
196+
&mut info,
197+
);
198+
info.as_lapack_result()?;
182199
Ok(rcond)
183200
}
184201

@@ -192,27 +209,18 @@ macro_rules! impl_tridiagonal {
192209
let (_, nrhs) = bl.size();
193210
let ipiv = &lu.ipiv;
194211
let ldb = bl.lda();
212+
let mut info = 0;
195213
$gttrs(
196-
lu.a.l.lapacke_layout(),
197-
t as u8,
198-
n,
199-
nrhs,
200-
&lu.a.dl,
201-
&lu.a.d,
202-
&lu.a.du,
203-
&lu.du2,
204-
ipiv,
205-
b,
206-
ldb,
207-
)
208-
.as_lapack_result()?;
214+
t as u8, n, nrhs, &lu.a.dl, &lu.a.d, &lu.a.du, &lu.du2, ipiv, b, ldb, &mut info,
215+
);
216+
info.as_lapack_result()?;
209217
Ok(())
210218
}
211219
}
212220
};
213221
} // impl_tridiagonal!
214222

215-
impl_tridiagonal!(f64, lapacke::dgttrf, lapacke::dgtcon, lapacke::dgttrs);
216-
impl_tridiagonal!(f32, lapacke::sgttrf, lapacke::sgtcon, lapacke::sgttrs);
217-
impl_tridiagonal!(c64, lapacke::zgttrf, lapacke::zgtcon, lapacke::zgttrs);
218-
impl_tridiagonal!(c32, lapacke::cgttrf, lapacke::cgtcon, lapacke::cgttrs);
223+
impl_tridiagonal!(@real, f64, lapack::dgttrf, lapack::dgtcon, lapack::dgttrs);
224+
impl_tridiagonal!(@real, f32, lapack::sgttrf, lapack::sgtcon, lapack::sgttrs);
225+
impl_tridiagonal!(@complex, c64, lapack::zgttrf, lapack::zgtcon, lapack::zgttrs);
226+
impl_tridiagonal!(@complex, c32, lapack::cgttrf, lapack::cgtcon, lapack::cgttrs);

0 commit comments

Comments
 (0)