@@ -28,6 +28,28 @@ pub struct Tridiagonal<A: Scalar> {
28
28
pub du : Vec < A > ,
29
29
}
30
30
31
+ impl < A : Scalar > Tridiagonal < A > {
32
+ fn opnorm_one ( & self ) -> A :: Real {
33
+ let n = self . l . len ( ) as usize ;
34
+ let mut col_sum: Vec < A :: Real > = self . d . iter ( ) . map ( |val| val. abs ( ) ) . collect ( ) ;
35
+ for i in 0 ..n - 1 {
36
+ if i < n - 1 {
37
+ col_sum[ i] += self . dl [ i + 1 ] . abs ( ) ;
38
+ }
39
+ if i > 0 {
40
+ col_sum[ i] += self . du [ i - 1 ] . abs ( ) ;
41
+ }
42
+ }
43
+ let mut max = A :: Real :: zero ( ) ;
44
+ for & val in & col_sum {
45
+ if max < val {
46
+ max = val;
47
+ }
48
+ }
49
+ max
50
+ }
51
+ }
52
+
31
53
/// Represents the LU factorization of a tridiagonal matrix `A` as `A = P*L*U`.
32
54
#[ derive( Clone , PartialEq ) ]
33
55
pub struct LUFactorizedTridiagonal < A : Scalar > {
@@ -41,6 +63,8 @@ pub struct LUFactorizedTridiagonal<A: Scalar> {
41
63
pub du2 : Vec < A > ,
42
64
/// The pivot indices that define the permutation matrix `P`.
43
65
pub ipiv : Pivot ,
66
+
67
+ a_opnorm_one : A :: Real ,
44
68
}
45
69
46
70
impl < A : Scalar > Index < ( i32 , i32 ) > for Tridiagonal < A > {
@@ -66,6 +90,14 @@ impl<A: Scalar> Index<(i32, i32)> for Tridiagonal<A> {
66
90
}
67
91
}
68
92
93
+ impl < A : Scalar > Index < [ i32 ; 2 ] > for Tridiagonal < A > {
94
+ type Output = A ;
95
+ #[ inline]
96
+ fn index ( & self , [ row, col] : [ i32 ; 2 ] ) -> & A {
97
+ & self [ ( row, col) ]
98
+ }
99
+ }
100
+
69
101
impl < A : Scalar > IndexMut < ( i32 , i32 ) > for Tridiagonal < A > {
70
102
#[ inline]
71
103
fn index_mut ( & mut self , ( row, col) : ( i32 , i32 ) ) -> & mut A {
@@ -88,11 +120,18 @@ impl<A: Scalar> IndexMut<(i32, i32)> for Tridiagonal<A> {
88
120
}
89
121
}
90
122
123
+ impl < A : Scalar > IndexMut < [ i32 ; 2 ] > for Tridiagonal < A > {
124
+ #[ inline]
125
+ fn index_mut ( & mut self , [ row, col] : [ i32 ; 2 ] ) -> & mut A {
126
+ & mut self [ ( row, col) ]
127
+ }
128
+ }
129
+
91
130
/// Wraps `*gttrf`, `*gtcon` and `*gttrs`
92
131
pub trait Tridiagonal_ : Scalar + Sized {
93
132
/// Computes the LU factorization of a tridiagonal `m x n` matrix `a` using
94
133
/// partial pivoting with row interchanges.
95
- unsafe fn lu_tridiagonal ( a : & mut Tridiagonal < Self > ) -> Result < ( Vec < Self > , Pivot ) > ;
134
+ unsafe fn lu_tridiagonal ( a : Tridiagonal < Self > ) -> Result < LUFactorizedTridiagonal < Self > > ;
96
135
97
136
unsafe fn rcond_tridiagonal ( lu : & LUFactorizedTridiagonal < Self > ) -> Result < Self :: Real > ;
98
137
@@ -107,19 +146,27 @@ pub trait Tridiagonal_: Scalar + Sized {
107
146
macro_rules! impl_tridiagonal {
108
147
( $scalar: ty, $gttrf: path, $gtcon: path, $gttrs: path) => {
109
148
impl Tridiagonal_ for $scalar {
110
- unsafe fn lu_tridiagonal( a: & mut Tridiagonal <Self >) -> Result <( Vec <Self >, Pivot ) > {
149
+ unsafe fn lu_tridiagonal(
150
+ mut a: Tridiagonal <Self >,
151
+ ) -> Result <LUFactorizedTridiagonal <Self >> {
111
152
let ( n, _) = a. l. size( ) ;
112
153
let mut du2 = vec![ Zero :: zero( ) ; ( n - 2 ) as usize ] ;
113
154
let mut ipiv = vec![ 0 ; n as usize ] ;
155
+ // We have to calc one-norm before LU factorization
156
+ let a_opnorm_one = a. opnorm_one( ) ;
114
157
$gttrf( n, & mut a. dl, & mut a. d, & mut a. du, & mut du2, & mut ipiv)
115
158
. as_lapack_result( ) ?;
116
- Ok ( ( du2, ipiv) )
159
+ Ok ( LUFactorizedTridiagonal {
160
+ a,
161
+ du2,
162
+ ipiv,
163
+ a_opnorm_one,
164
+ } )
117
165
}
118
166
119
167
unsafe fn rcond_tridiagonal( lu: & LUFactorizedTridiagonal <Self >) -> Result <Self :: Real > {
120
168
let ( n, _) = lu. a. l. size( ) ;
121
169
let ipiv = & lu. ipiv;
122
- let anorm = lu. anom;
123
170
let mut rcond = Self :: Real :: zero( ) ;
124
171
$gtcon(
125
172
NormType :: One as u8 ,
@@ -129,7 +176,7 @@ macro_rules! impl_tridiagonal {
129
176
& lu. a. du,
130
177
& lu. du2,
131
178
ipiv,
132
- anorm ,
179
+ lu . a_opnorm_one ,
133
180
& mut rcond,
134
181
)
135
182
. as_lapack_result( ) ?;
0 commit comments