2
2
//! for tridiagonal matrix
3
3
4
4
use super :: * ;
5
- use crate :: { error:: * , layout:: MatrixLayout } ;
5
+ use crate :: { error:: * , layout:: * } ;
6
6
use cauchy:: * ;
7
7
use num_traits:: Zero ;
8
8
use std:: ops:: { Index , IndexMut } ;
@@ -130,11 +130,11 @@ impl<A: Scalar> IndexMut<[i32; 2]> for Tridiagonal<A> {
130
130
pub trait Tridiagonal_ : Scalar + Sized {
131
131
/// Computes the LU factorization of a tridiagonal `m x n` matrix `a` using
132
132
/// partial pivoting with row interchanges.
133
- unsafe fn lu_tridiagonal ( a : Tridiagonal < Self > ) -> Result < LUFactorizedTridiagonal < Self > > ;
133
+ fn lu_tridiagonal ( a : Tridiagonal < Self > ) -> Result < LUFactorizedTridiagonal < Self > > ;
134
134
135
- unsafe fn rcond_tridiagonal ( lu : & LUFactorizedTridiagonal < Self > ) -> Result < Self :: Real > ;
135
+ fn rcond_tridiagonal ( lu : & LUFactorizedTridiagonal < Self > ) -> Result < Self :: Real > ;
136
136
137
- unsafe fn solve_tridiagonal (
137
+ fn solve_tridiagonal (
138
138
lu : & LUFactorizedTridiagonal < Self > ,
139
139
bl : MatrixLayout ,
140
140
t : Transpose ,
@@ -143,18 +143,23 @@ pub trait Tridiagonal_: Scalar + Sized {
143
143
}
144
144
145
145
macro_rules! impl_tridiagonal {
146
- ( $scalar: ty, $gttrf: path, $gtcon: path, $gttrs: path) => {
146
+ ( @real, $scalar: ty, $gttrf: path, $gtcon: path, $gttrs: path) => {
147
+ impl_tridiagonal!( @body, $scalar, $gttrf, $gtcon, $gttrs, iwork) ;
148
+ } ;
149
+ ( @complex, $scalar: ty, $gttrf: path, $gtcon: path, $gttrs: path) => {
150
+ impl_tridiagonal!( @body, $scalar, $gttrf, $gtcon, $gttrs, ) ;
151
+ } ;
152
+ ( @body, $scalar: ty, $gttrf: path, $gtcon: path, $gttrs: path, $( $iwork: ident) * ) => {
147
153
impl Tridiagonal_ for $scalar {
148
- unsafe fn lu_tridiagonal(
149
- mut a: Tridiagonal <Self >,
150
- ) -> Result <LUFactorizedTridiagonal <Self >> {
154
+ fn lu_tridiagonal( mut a: Tridiagonal <Self >) -> Result <LUFactorizedTridiagonal <Self >> {
151
155
let ( n, _) = a. l. size( ) ;
152
156
let mut du2 = vec![ Zero :: zero( ) ; ( n - 2 ) as usize ] ;
153
157
let mut ipiv = vec![ 0 ; n as usize ] ;
154
158
// We have to calc one-norm before LU factorization
155
159
let a_opnorm_one = a. opnorm_one( ) ;
156
- $gttrf( n, & mut a. dl, & mut a. d, & mut a. du, & mut du2, & mut ipiv)
157
- . as_lapack_result( ) ?;
160
+ let mut info = 0 ;
161
+ unsafe { $gttrf( n, & mut a. dl, & mut a. d, & mut a. du, & mut du2, & mut ipiv, & mut info, ) } ;
162
+ info. as_lapack_result( ) ?;
158
163
Ok ( LUFactorizedTridiagonal {
159
164
a,
160
165
du2,
@@ -163,56 +168,80 @@ macro_rules! impl_tridiagonal {
163
168
} )
164
169
}
165
170
166
- unsafe fn rcond_tridiagonal( lu: & LUFactorizedTridiagonal <Self >) -> Result <Self :: Real > {
171
+ fn rcond_tridiagonal( lu: & LUFactorizedTridiagonal <Self >) -> Result <Self :: Real > {
167
172
let ( n, _) = lu. a. l. size( ) ;
168
173
let ipiv = & lu. ipiv;
174
+ let mut work = vec![ Self :: zero( ) ; 2 * n as usize ] ;
175
+ $(
176
+ let mut $iwork = vec![ 0 ; n as usize ] ;
177
+ ) *
169
178
let mut rcond = Self :: Real :: zero( ) ;
170
- $gtcon(
171
- NormType :: One as u8 ,
172
- n,
173
- & lu. a. dl,
174
- & lu. a. d,
175
- & lu. a. du,
176
- & lu. du2,
177
- ipiv,
178
- lu. a_opnorm_one,
179
- & mut rcond,
180
- )
181
- . as_lapack_result( ) ?;
179
+ let mut info = 0 ;
180
+ unsafe {
181
+ $gtcon(
182
+ NormType :: One as u8 ,
183
+ n,
184
+ & lu. a. dl,
185
+ & lu. a. d,
186
+ & lu. a. du,
187
+ & lu. du2,
188
+ ipiv,
189
+ lu. a_opnorm_one,
190
+ & mut rcond,
191
+ & mut work,
192
+ $( & mut $iwork, ) *
193
+ & mut info,
194
+ ) ;
195
+ }
196
+ info. as_lapack_result( ) ?;
182
197
Ok ( rcond)
183
198
}
184
199
185
- unsafe fn solve_tridiagonal(
200
+ fn solve_tridiagonal(
186
201
lu: & LUFactorizedTridiagonal <Self >,
187
- bl : MatrixLayout ,
202
+ b_layout : MatrixLayout ,
188
203
t: Transpose ,
189
204
b: & mut [ Self ] ,
190
205
) -> Result <( ) > {
191
206
let ( n, _) = lu. a. l. size( ) ;
192
- let ( _, nrhs) = bl. size( ) ;
193
207
let ipiv = & lu. ipiv;
194
- let ldb = bl. lda( ) ;
195
- $gttrs(
196
- lu. a. l. lapacke_layout( ) ,
197
- t as u8 ,
198
- n,
199
- nrhs,
200
- & lu. a. dl,
201
- & lu. a. d,
202
- & lu. a. du,
203
- & lu. du2,
204
- ipiv,
205
- b,
206
- ldb,
207
- )
208
- . as_lapack_result( ) ?;
208
+ // Transpose if b is C-continuous
209
+ let mut b_t = None ;
210
+ let b_layout = match b_layout {
211
+ MatrixLayout :: C { .. } => {
212
+ b_t = Some ( vec![ Self :: zero( ) ; b. len( ) ] ) ;
213
+ transpose( b_layout, b, b_t. as_mut( ) . unwrap( ) )
214
+ }
215
+ MatrixLayout :: F { .. } => b_layout,
216
+ } ;
217
+ let ( ldb, nrhs) = b_layout. size( ) ;
218
+ let mut info = 0 ;
219
+ unsafe {
220
+ $gttrs(
221
+ t as u8 ,
222
+ n,
223
+ nrhs,
224
+ & lu. a. dl,
225
+ & lu. a. d,
226
+ & lu. a. du,
227
+ & lu. du2,
228
+ ipiv,
229
+ b_t. as_mut( ) . map( |v| v. as_mut_slice( ) ) . unwrap_or( b) ,
230
+ ldb,
231
+ & mut info,
232
+ ) ;
233
+ }
234
+ info. as_lapack_result( ) ?;
235
+ if let Some ( b_t) = b_t {
236
+ transpose( b_layout, & b_t, b) ;
237
+ }
209
238
Ok ( ( ) )
210
239
}
211
240
}
212
241
} ;
213
242
} // impl_tridiagonal!
214
243
215
- impl_tridiagonal ! ( f64 , lapacke :: dgttrf, lapacke :: dgtcon, lapacke :: dgttrs) ;
216
- impl_tridiagonal ! ( f32 , lapacke :: sgttrf, lapacke :: sgtcon, lapacke :: sgttrs) ;
217
- impl_tridiagonal ! ( c64, lapacke :: zgttrf, lapacke :: zgtcon, lapacke :: zgttrs) ;
218
- impl_tridiagonal ! ( c32, lapacke :: cgttrf, lapacke :: cgtcon, lapacke :: cgttrs) ;
244
+ impl_tridiagonal ! ( @real , f64 , lapack :: dgttrf, lapack :: dgtcon, lapack :: dgttrs) ;
245
+ impl_tridiagonal ! ( @real , f32 , lapack :: sgttrf, lapack :: sgtcon, lapack :: sgttrs) ;
246
+ impl_tridiagonal ! ( @complex , c64, lapack :: zgttrf, lapack :: zgtcon, lapack :: zgttrs) ;
247
+ impl_tridiagonal ! ( @complex , c32, lapack :: cgttrf, lapack :: cgtcon, lapack :: cgttrs) ;
0 commit comments