Skip to content

Commit f86d104

Browse files
authored
Merge pull request #235 from rust-ndarray/lapack-tridiagonal
Impl tridiagonal by LAPACK
2 parents e9c3481 + 2a6154f commit f86d104

File tree

1 file changed

+74
-45
lines changed

1 file changed

+74
-45
lines changed

lax/src/tridiagonal.rs

Lines changed: 74 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
//! for tridiagonal matrix
33
44
use super::*;
5-
use crate::{error::*, layout::MatrixLayout};
5+
use crate::{error::*, layout::*};
66
use cauchy::*;
77
use num_traits::Zero;
88
use std::ops::{Index, IndexMut};
@@ -130,11 +130,11 @@ impl<A: Scalar> IndexMut<[i32; 2]> for Tridiagonal<A> {
130130
pub trait Tridiagonal_: Scalar + Sized {
131131
/// Computes the LU factorization of a tridiagonal `n x n` matrix `a` using
132132
/// partial pivoting with row interchanges.
133-
unsafe fn lu_tridiagonal(a: Tridiagonal<Self>) -> Result<LUFactorizedTridiagonal<Self>>;
133+
fn lu_tridiagonal(a: Tridiagonal<Self>) -> Result<LUFactorizedTridiagonal<Self>>;
134134

135-
unsafe fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal<Self>) -> Result<Self::Real>;
135+
fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal<Self>) -> Result<Self::Real>;
136136

137-
unsafe fn solve_tridiagonal(
137+
fn solve_tridiagonal(
138138
lu: &LUFactorizedTridiagonal<Self>,
139139
bl: MatrixLayout,
140140
t: Transpose,
@@ -143,18 +143,23 @@ pub trait Tridiagonal_: Scalar + Sized {
143143
}
144144

145145
macro_rules! impl_tridiagonal {
146-
($scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path) => {
146+
(@real, $scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path) => {
147+
impl_tridiagonal!(@body, $scalar, $gttrf, $gtcon, $gttrs, iwork);
148+
};
149+
(@complex, $scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path) => {
150+
impl_tridiagonal!(@body, $scalar, $gttrf, $gtcon, $gttrs, );
151+
};
152+
(@body, $scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path, $($iwork:ident)*) => {
147153
impl Tridiagonal_ for $scalar {
148-
unsafe fn lu_tridiagonal(
149-
mut a: Tridiagonal<Self>,
150-
) -> Result<LUFactorizedTridiagonal<Self>> {
154+
fn lu_tridiagonal(mut a: Tridiagonal<Self>) -> Result<LUFactorizedTridiagonal<Self>> {
151155
let (n, _) = a.l.size();
152156
let mut du2 = vec![Zero::zero(); (n - 2) as usize];
153157
let mut ipiv = vec![0; n as usize];
154158
// The one-norm must be computed before the LU factorization, which overwrites `a` in place
155159
let a_opnorm_one = a.opnorm_one();
156-
$gttrf(n, &mut a.dl, &mut a.d, &mut a.du, &mut du2, &mut ipiv)
157-
.as_lapack_result()?;
160+
let mut info = 0;
161+
unsafe { $gttrf(n, &mut a.dl, &mut a.d, &mut a.du, &mut du2, &mut ipiv, &mut info,) };
162+
info.as_lapack_result()?;
158163
Ok(LUFactorizedTridiagonal {
159164
a,
160165
du2,
@@ -163,56 +168,80 @@ macro_rules! impl_tridiagonal {
163168
})
164169
}
165170

166-
unsafe fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal<Self>) -> Result<Self::Real> {
171+
fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal<Self>) -> Result<Self::Real> {
167172
let (n, _) = lu.a.l.size();
168173
let ipiv = &lu.ipiv;
174+
let mut work = vec![Self::zero(); 2 * n as usize];
175+
$(
176+
let mut $iwork = vec![0; n as usize];
177+
)*
169178
let mut rcond = Self::Real::zero();
170-
$gtcon(
171-
NormType::One as u8,
172-
n,
173-
&lu.a.dl,
174-
&lu.a.d,
175-
&lu.a.du,
176-
&lu.du2,
177-
ipiv,
178-
lu.a_opnorm_one,
179-
&mut rcond,
180-
)
181-
.as_lapack_result()?;
179+
let mut info = 0;
180+
unsafe {
181+
$gtcon(
182+
NormType::One as u8,
183+
n,
184+
&lu.a.dl,
185+
&lu.a.d,
186+
&lu.a.du,
187+
&lu.du2,
188+
ipiv,
189+
lu.a_opnorm_one,
190+
&mut rcond,
191+
&mut work,
192+
$(&mut $iwork,)*
193+
&mut info,
194+
);
195+
}
196+
info.as_lapack_result()?;
182197
Ok(rcond)
183198
}
184199

185-
unsafe fn solve_tridiagonal(
200+
fn solve_tridiagonal(
186201
lu: &LUFactorizedTridiagonal<Self>,
187-
bl: MatrixLayout,
202+
b_layout: MatrixLayout,
188203
t: Transpose,
189204
b: &mut [Self],
190205
) -> Result<()> {
191206
let (n, _) = lu.a.l.size();
192-
let (_, nrhs) = bl.size();
193207
let ipiv = &lu.ipiv;
194-
let ldb = bl.lda();
195-
$gttrs(
196-
lu.a.l.lapacke_layout(),
197-
t as u8,
198-
n,
199-
nrhs,
200-
&lu.a.dl,
201-
&lu.a.d,
202-
&lu.a.du,
203-
&lu.du2,
204-
ipiv,
205-
b,
206-
ldb,
207-
)
208-
.as_lapack_result()?;
208+
// Transpose if b is C-contiguous (row-major); LAPACK expects column-major data
209+
let mut b_t = None;
210+
let b_layout = match b_layout {
211+
MatrixLayout::C { .. } => {
212+
b_t = Some(vec![Self::zero(); b.len()]);
213+
transpose(b_layout, b, b_t.as_mut().unwrap())
214+
}
215+
MatrixLayout::F { .. } => b_layout,
216+
};
217+
let (ldb, nrhs) = b_layout.size();
218+
let mut info = 0;
219+
unsafe {
220+
$gttrs(
221+
t as u8,
222+
n,
223+
nrhs,
224+
&lu.a.dl,
225+
&lu.a.d,
226+
&lu.a.du,
227+
&lu.du2,
228+
ipiv,
229+
b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b),
230+
ldb,
231+
&mut info,
232+
);
233+
}
234+
info.as_lapack_result()?;
235+
if let Some(b_t) = b_t {
236+
transpose(b_layout, &b_t, b);
237+
}
209238
Ok(())
210239
}
211240
}
212241
};
213242
} // impl_tridiagonal!
214243

215-
impl_tridiagonal!(f64, lapacke::dgttrf, lapacke::dgtcon, lapacke::dgttrs);
216-
impl_tridiagonal!(f32, lapacke::sgttrf, lapacke::sgtcon, lapacke::sgttrs);
217-
impl_tridiagonal!(c64, lapacke::zgttrf, lapacke::zgtcon, lapacke::zgttrs);
218-
impl_tridiagonal!(c32, lapacke::cgttrf, lapacke::cgtcon, lapacke::cgttrs);
244+
impl_tridiagonal!(@real, f64, lapack::dgttrf, lapack::dgtcon, lapack::dgttrs);
245+
impl_tridiagonal!(@real, f32, lapack::sgttrf, lapack::sgtcon, lapack::sgttrs);
246+
impl_tridiagonal!(@complex, c64, lapack::zgttrf, lapack::zgtcon, lapack::zgttrs);
247+
impl_tridiagonal!(@complex, c32, lapack::cgttrf, lapack::cgtcon, lapack::cgttrs);

0 commit comments

Comments
 (0)