Skip to content

Commit bc7fa0f

Browse files
committed
Restore rcond_tridiagonal
1 parent ba56561 commit bc7fa0f

File tree

2 files changed

+59
-15
lines changed

2 files changed

+59
-15
lines changed

lax/src/tridiagonal.rs

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,28 @@ pub struct Tridiagonal<A: Scalar> {
2828
pub du: Vec<A>,
2929
}
3030

31+
impl<A: Scalar> Tridiagonal<A> {
32+
fn opnorm_one(&self) -> A::Real {
33+
let n = self.l.len() as usize;
34+
let mut col_sum: Vec<A::Real> = self.d.iter().map(|val| val.abs()).collect();
35+
for i in 0..n - 1 {
36+
if i < n - 1 {
37+
col_sum[i] += self.dl[i + 1].abs();
38+
}
39+
if i > 0 {
40+
col_sum[i] += self.du[i - 1].abs();
41+
}
42+
}
43+
let mut max = A::Real::zero();
44+
for &val in &col_sum {
45+
if max < val {
46+
max = val;
47+
}
48+
}
49+
max
50+
}
51+
}
52+
3153
/// Represents the LU factorization of a tridiagonal matrix `A` as `A = P*L*U`.
3254
#[derive(Clone, PartialEq)]
3355
pub struct LUFactorizedTridiagonal<A: Scalar> {
@@ -41,6 +63,8 @@ pub struct LUFactorizedTridiagonal<A: Scalar> {
4163
pub du2: Vec<A>,
4264
/// The pivot indices that define the permutation matrix `P`.
4365
pub ipiv: Pivot,
66+
67+
a_opnorm_one: A::Real,
4468
}
4569

4670
impl<A: Scalar> Index<(i32, i32)> for Tridiagonal<A> {
@@ -66,6 +90,14 @@ impl<A: Scalar> Index<(i32, i32)> for Tridiagonal<A> {
6690
}
6791
}
6892

93+
impl<A: Scalar> Index<[i32; 2]> for Tridiagonal<A> {
94+
type Output = A;
95+
#[inline]
96+
fn index(&self, [row, col]: [i32; 2]) -> &A {
97+
&self[(row, col)]
98+
}
99+
}
100+
69101
impl<A: Scalar> IndexMut<(i32, i32)> for Tridiagonal<A> {
70102
#[inline]
71103
fn index_mut(&mut self, (row, col): (i32, i32)) -> &mut A {
@@ -88,11 +120,18 @@ impl<A: Scalar> IndexMut<(i32, i32)> for Tridiagonal<A> {
88120
}
89121
}
90122

123+
impl<A: Scalar> IndexMut<[i32; 2]> for Tridiagonal<A> {
124+
#[inline]
125+
fn index_mut(&mut self, [row, col]: [i32; 2]) -> &mut A {
126+
&mut self[(row, col)]
127+
}
128+
}
129+
91130
/// Wraps `*gttrf`, `*gtcon` and `*gttrs`
92131
pub trait Tridiagonal_: Scalar + Sized {
93132
/// Computes the LU factorization of a tridiagonal `m x n` matrix `a` using
94133
/// partial pivoting with row interchanges.
95-
unsafe fn lu_tridiagonal(a: &mut Tridiagonal<Self>) -> Result<(Vec<Self>, Pivot)>;
134+
unsafe fn lu_tridiagonal(a: Tridiagonal<Self>) -> Result<LUFactorizedTridiagonal<Self>>;
96135

97136
unsafe fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal<Self>) -> Result<Self::Real>;
98137

@@ -107,19 +146,27 @@ pub trait Tridiagonal_: Scalar + Sized {
107146
macro_rules! impl_tridiagonal {
108147
($scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path) => {
109148
impl Tridiagonal_ for $scalar {
110-
unsafe fn lu_tridiagonal(a: &mut Tridiagonal<Self>) -> Result<(Vec<Self>, Pivot)> {
149+
unsafe fn lu_tridiagonal(
150+
mut a: Tridiagonal<Self>,
151+
) -> Result<LUFactorizedTridiagonal<Self>> {
111152
let (n, _) = a.l.size();
112153
let mut du2 = vec![Zero::zero(); (n - 2) as usize];
113154
let mut ipiv = vec![0; n as usize];
155+
// We have to calc one-norm before LU factorization
156+
let a_opnorm_one = a.opnorm_one();
114157
$gttrf(n, &mut a.dl, &mut a.d, &mut a.du, &mut du2, &mut ipiv)
115158
.as_lapack_result()?;
116-
Ok((du2, ipiv))
159+
Ok(LUFactorizedTridiagonal {
160+
a,
161+
du2,
162+
ipiv,
163+
a_opnorm_one,
164+
})
117165
}
118166

119167
unsafe fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal<Self>) -> Result<Self::Real> {
120168
let (n, _) = lu.a.l.size();
121169
let ipiv = &lu.ipiv;
122-
let anorm = lu.anom;
123170
let mut rcond = Self::Real::zero();
124171
$gtcon(
125172
NormType::One as u8,
@@ -129,7 +176,7 @@ macro_rules! impl_tridiagonal {
129176
&lu.a.du,
130177
&lu.du2,
131178
ipiv,
132-
anorm,
179+
lu.a_opnorm_one,
133180
&mut rcond,
134181
)
135182
.as_lapack_result()?;

ndarray-linalg/src/tridiagonal.rs

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -565,9 +565,8 @@ impl<A> FactorizeTridiagonalInto<A> for Tridiagonal<A>
565565
where
566566
A: Scalar + Lapack,
567567
{
568-
fn factorize_tridiagonal_into(mut self) -> Result<LUFactorizedTridiagonal<A>> {
569-
let (du2, ipiv) = unsafe { A::lu_tridiagonal(&mut self)? };
570-
Ok(LUFactorizedTridiagonal { a: self, du2, ipiv })
568+
fn factorize_tridiagonal_into(self) -> Result<LUFactorizedTridiagonal<A>> {
569+
Ok(unsafe { A::lu_tridiagonal(self)? })
571570
}
572571
}
573572

@@ -576,9 +575,8 @@ where
576575
A: Scalar + Lapack,
577576
{
578577
fn factorize_tridiagonal(&self) -> Result<LUFactorizedTridiagonal<A>> {
579-
let mut a = self.clone();
580-
let (du2, ipiv) = unsafe { A::lu_tridiagonal(&mut a)? };
581-
Ok(LUFactorizedTridiagonal { a, du2, ipiv })
578+
let a = self.clone();
579+
Ok(unsafe { A::lu_tridiagonal(a)? })
582580
}
583581
}
584582

@@ -588,9 +586,8 @@ where
588586
S: Data<Elem = A>,
589587
{
590588
fn factorize_tridiagonal(&self) -> Result<LUFactorizedTridiagonal<A>> {
591-
let mut a = self.extract_tridiagonal()?;
592-
let (du2, ipiv) = unsafe { A::lu_tridiagonal(&mut a)? };
593-
Ok(LUFactorizedTridiagonal { a, du2, ipiv })
589+
let a = self.extract_tridiagonal()?;
590+
Ok(unsafe { A::lu_tridiagonal(a)? })
594591
}
595592
}
596593

@@ -680,7 +677,7 @@ where
680677
A: Scalar + Lapack,
681678
{
682679
fn rcond_tridiagonal(&self) -> Result<A::Real> {
683-
unsafe { A::rcond_tridiagonal(&self) }
680+
unsafe { Ok(A::rcond_tridiagonal(&self)?) }
684681
}
685682
}
686683

0 commit comments

Comments
 (0)