Skip to content

Commit 49e12ca

Browse files
committed
Add wrapper for *{sy,he}{trf,tri,trs}
1 parent ad7624e commit 49e12ca

File tree

3 files changed

+58
-3
lines changed

3 files changed

+58
-3
lines changed

src/lapack_traits/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ pub mod opnorm;
44
pub mod qr;
55
pub mod svd;
66
pub mod solve;
7+
pub mod solveh;
78
pub mod cholesky;
89
pub mod eigh;
910
pub mod triangular;
@@ -13,12 +14,15 @@ pub use self::eigh::*;
1314
pub use self::opnorm::*;
1415
pub use self::qr::*;
1516
pub use self::solve::*;
17+
pub use self::solveh::*;
1618
pub use self::svd::*;
1719
pub use self::triangular::*;
1820

1921
use super::error::*;
2022
use super::types::*;
2123

24+
pub type Pivot = Vec<i32>;
25+
2226
pub trait LapackScalar
2327
: OperatorNorm_ + QR_ + SVD_ + Solve_ + Cholesky_ + Eigh_ + Triangular_ {
2428
}

src/lapack_traits/solve.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,7 @@ use error::*;
66
use layout::MatrixLayout;
77
use types::*;
88

9-
use super::{Transpose, into_result};
10-
11-
pub type Pivot = Vec<i32>;
9+
use super::{Pivot, Transpose, into_result};
1210

1311
/// Wraps `*getrf`, `*getri`, and `*getrs`
1412
pub trait Solve_: Sized {

src/lapack_traits/solveh.rs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
//! Solve symmetric linear problem using the Bunch-Kaufman diagonal pivoting method.
2+
//!
3+
//! See also [the manual of dsytrf](http://www.netlib.org/lapack/lapack-3.1.1/html/dsytrf.f.html)
4+
5+
use lapack::c;
6+
7+
use error::*;
8+
use layout::MatrixLayout;
9+
use types::*;
10+
11+
use super::{Pivot, UPLO, into_result};
12+
13+
pub trait Solveh_: Sized {
14+
/// Bunch-Kaufman: wrapper of `*sytrf` and `*hetrf`
15+
unsafe fn bk(MatrixLayout, UPLO, a: &mut [Self]) -> Result<Pivot>;
16+
/// Wrapper of `*sytri` and `*hetri`
17+
unsafe fn inv(MatrixLayout, UPLO, a: &mut [Self], &Pivot) -> Result<()>;
18+
/// Wrapper of `*sytrs` and `*hetrs`
19+
unsafe fn solve(MatrixLayout, UPLO, a: &[Self], &Pivot, b: &mut [Self]) -> Result<()>;
20+
}
21+
22+
macro_rules! impl_solveh {
23+
($scalar:ty, $trf:path, $tri:path, $trs:path) => {
24+
25+
impl Solveh_ for $scalar {
26+
unsafe fn bk(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<Pivot> {
27+
let (n, _) = l.size();
28+
let mut ipiv = vec![0; n as usize];
29+
let info = $trf(l.lapacke_layout(), uplo as u8, n, a, l.lda(), &mut ipiv);
30+
into_result(info, ipiv)
31+
}
32+
33+
unsafe fn inv(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()> {
34+
let (n, _) = l.size();
35+
let info = $tri(l.lapacke_layout(), uplo as u8, n, a, l.lda(), ipiv);
36+
into_result(info, ())
37+
}
38+
39+
unsafe fn solve(l: MatrixLayout, uplo: UPLO, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()> {
40+
let (n, _) = l.size();
41+
let nrhs = 1;
42+
let ldb = 1;
43+
let info = $trs(l.lapacke_layout(), uplo as u8, n, nrhs, a, l.lda(), ipiv, b, ldb);
44+
into_result(info, ())
45+
}
46+
}
47+
48+
}} // impl_solveh!
49+
50+
impl_solveh!(f64, c::dsytrf, c::dsytri, c::dsytrs);
51+
impl_solveh!(f32, c::ssytrf, c::ssytri, c::ssytrs);
52+
impl_solveh!(c64, c::zhetrf, c::zhetri, c::zhetrs);
53+
impl_solveh!(c32, c::chetrf, c::chetri, c::chetrs);

0 commit comments

Comments
 (0)