Skip to content

Commit 278a7bf

Browse files
authored
Merge pull request #67 from termoshtt/solveh
Solve Hermite/Symmetric linear problems
2 parents ad7624e + 4638183 commit 278a7bf

File tree

6 files changed

+248
-4
lines changed

6 files changed

+248
-4
lines changed

examples/solveh.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
2+
extern crate ndarray;
3+
extern crate ndarray_linalg;
4+
5+
use ndarray::*;
6+
use ndarray_linalg::*;
7+
8+
// Solve `Ax=b` for Hermite matrix A
9+
fn solve() -> Result<(), error::LinalgError> {
10+
let a: Array2<c64> = random_hermite(3); // complex Hermite positive definite matrix
11+
let b: Array1<c64> = random(3);
12+
println!("b = {:?}", &b);
13+
let x = a.solveh(&b)?;
14+
println!("Ax = {:?}", a.dot(&x));;
15+
Ok(())
16+
}
17+
18+
// Solve `Ax=b` for many b with fixed A
19+
fn factorize() -> Result<(), error::LinalgError> {
20+
let a: Array2<f64> = random_hpd(3);
21+
let f = a.factorizeh_into()?;
22+
// once factorized, you can use it several times:
23+
for _ in 0..10 {
24+
let b: Array1<f64> = random(3);
25+
let _x = f.solveh_into(b)?;
26+
}
27+
Ok(())
28+
}
29+
30+
fn main() {
31+
solve().unwrap();
32+
factorize().unwrap();
33+
}

src/lapack_traits/mod.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ pub mod opnorm;
44
pub mod qr;
55
pub mod svd;
66
pub mod solve;
7+
pub mod solveh;
78
pub mod cholesky;
89
pub mod eigh;
910
pub mod triangular;
@@ -13,14 +14,18 @@ pub use self::eigh::*;
1314
pub use self::opnorm::*;
1415
pub use self::qr::*;
1516
pub use self::solve::*;
17+
pub use self::solveh::*;
1618
pub use self::svd::*;
1719
pub use self::triangular::*;
1820

1921
use super::error::*;
2022
use super::types::*;
2123

24+
pub type Pivot = Vec<i32>;
25+
2226
pub trait LapackScalar
23-
: OperatorNorm_ + QR_ + SVD_ + Solve_ + Cholesky_ + Eigh_ + Triangular_ {
27+
: OperatorNorm_ + QR_ + SVD_ + Solve_ + Solveh_ + Cholesky_ + Eigh_ + Triangular_
28+
{
2429
}
2530

2631
impl LapackScalar for f32 {}

src/lapack_traits/solve.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,7 @@ use error::*;
66
use layout::MatrixLayout;
77
use types::*;
88

9-
use super::{Transpose, into_result};
10-
11-
pub type Pivot = Vec<i32>;
9+
use super::{Pivot, Transpose, into_result};
1210

1311
/// Wraps `*getrf`, `*getri`, and `*getrs`
1412
pub trait Solve_: Sized {

src/lapack_traits/solveh.rs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
//! Solve symmetric linear problem using the Bunch-Kaufman diagonal pivoting method.
2+
//!
3+
//! See also [the manual of dsytrf](http://www.netlib.org/lapack/lapack-3.1.1/html/dsytrf.f.html)
4+
5+
use lapack::c;
6+
7+
use error::*;
8+
use layout::MatrixLayout;
9+
use types::*;
10+
11+
use super::{Pivot, UPLO, into_result};
12+
13+
pub trait Solveh_: Sized {
14+
/// Bunch-Kaufman: wrapper of `*sytrf` and `*hetrf`
15+
unsafe fn bk(MatrixLayout, UPLO, a: &mut [Self]) -> Result<Pivot>;
16+
/// Wrapper of `*sytri` and `*hetri`
17+
unsafe fn invh(MatrixLayout, UPLO, a: &mut [Self], &Pivot) -> Result<()>;
18+
/// Wrapper of `*sytrs` and `*hetrs`
19+
unsafe fn solveh(MatrixLayout, UPLO, a: &[Self], &Pivot, b: &mut [Self]) -> Result<()>;
20+
}
21+
22+
macro_rules! impl_solveh {
23+
($scalar:ty, $trf:path, $tri:path, $trs:path) => {
24+
25+
impl Solveh_ for $scalar {
26+
unsafe fn bk(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<Pivot> {
27+
let (n, _) = l.size();
28+
let mut ipiv = vec![0; n as usize];
29+
let info = $trf(l.lapacke_layout(), uplo as u8, n, a, l.lda(), &mut ipiv);
30+
into_result(info, ipiv)
31+
}
32+
33+
unsafe fn invh(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()> {
34+
let (n, _) = l.size();
35+
let info = $tri(l.lapacke_layout(), uplo as u8, n, a, l.lda(), ipiv);
36+
into_result(info, ())
37+
}
38+
39+
unsafe fn solveh(l: MatrixLayout, uplo: UPLO, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()> {
40+
let (n, _) = l.size();
41+
let nrhs = 1;
42+
let ldb = 1;
43+
let info = $trs(l.lapacke_layout(), uplo as u8, n, nrhs, a, l.lda(), ipiv, b, ldb);
44+
into_result(info, ())
45+
}
46+
}
47+
48+
}} // impl_solveh!
49+
50+
impl_solveh!(f64, c::dsytrf, c::dsytri, c::dsytrs);
51+
impl_solveh!(f32, c::ssytrf, c::ssytri, c::ssytrs);
52+
impl_solveh!(c64, c::zhetrf, c::zhetri, c::zhetrs);
53+
impl_solveh!(c32, c::chetrf, c::chetri, c::chetrs);

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ pub mod operator;
4242
pub mod opnorm;
4343
pub mod qr;
4444
pub mod solve;
45+
pub mod solveh;
4546
pub mod svd;
4647
pub mod trace;
4748
pub mod triangular;
@@ -59,6 +60,7 @@ pub use operator::*;
5960
pub use opnorm::*;
6061
pub use qr::*;
6162
pub use solve::*;
63+
pub use solveh::*;
6264
pub use svd::*;
6365
pub use trace::*;
6466
pub use triangular::*;

src/solveh.rs

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
//! Solve Hermite/Symmetric linear problems
2+
3+
use ndarray::*;
4+
5+
use super::convert::*;
6+
use super::error::*;
7+
use super::layout::*;
8+
use super::types::*;
9+
10+
pub use lapack_traits::{Pivot, UPLO};
11+
12+
pub trait SolveH<A: Scalar> {
13+
fn solveh<S: Data<Elem = A>>(&self, a: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
14+
let mut a = replicate(a);
15+
self.solveh_mut(&mut a)?;
16+
Ok(a)
17+
}
18+
fn solveh_into<S: DataMut<Elem = A>>(&self, mut a: ArrayBase<S, Ix1>) -> Result<ArrayBase<S, Ix1>> {
19+
self.solveh_mut(&mut a)?;
20+
Ok(a)
21+
}
22+
fn solveh_mut<'a, S: DataMut<Elem = A>>(&self, &'a mut ArrayBase<S, Ix1>) -> Result<&'a mut ArrayBase<S, Ix1>>;
23+
}
24+
25+
pub struct FactorizedH<S: Data> {
26+
pub a: ArrayBase<S, Ix2>,
27+
pub ipiv: Pivot,
28+
}
29+
30+
impl<A, S> SolveH<A> for FactorizedH<S>
31+
where
32+
A: Scalar,
33+
S: Data<Elem = A>,
34+
{
35+
fn solveh_mut<'a, Sb>(&self, rhs: &'a mut ArrayBase<Sb, Ix1>) -> Result<&'a mut ArrayBase<Sb, Ix1>>
36+
where
37+
Sb: DataMut<Elem = A>,
38+
{
39+
unsafe {
40+
A::solveh(
41+
self.a.square_layout()?,
42+
UPLO::Upper,
43+
self.a.as_allocated()?,
44+
&self.ipiv,
45+
rhs.as_slice_mut().unwrap(),
46+
)?
47+
};
48+
Ok(rhs)
49+
}
50+
}
51+
52+
impl<A, S> SolveH<A> for ArrayBase<S, Ix2>
53+
where
54+
A: Scalar,
55+
S: Data<Elem = A>,
56+
{
57+
fn solveh_mut<'a, Sb>(&self, mut rhs: &'a mut ArrayBase<Sb, Ix1>) -> Result<&'a mut ArrayBase<Sb, Ix1>>
58+
where
59+
Sb: DataMut<Elem = A>,
60+
{
61+
let f = self.factorizeh()?;
62+
f.solveh_mut(rhs)
63+
}
64+
}
65+
66+
67+
impl<A, S> FactorizedH<S>
68+
where
69+
A: Scalar,
70+
S: DataMut<Elem = A>,
71+
{
72+
pub fn into_inverseh(mut self) -> Result<ArrayBase<S, Ix2>> {
73+
unsafe {
74+
A::invh(
75+
self.a.square_layout()?,
76+
UPLO::Upper,
77+
self.a.as_allocated_mut()?,
78+
&self.ipiv,
79+
)?
80+
};
81+
Ok(self.a)
82+
}
83+
}
84+
85+
pub trait FactorizeH<S: Data> {
86+
fn factorizeh(&self) -> Result<FactorizedH<S>>;
87+
}
88+
89+
pub trait FactorizeHInto<S: Data> {
90+
fn factorizeh_into(self) -> Result<FactorizedH<S>>;
91+
}
92+
93+
impl<A, S> FactorizeHInto<S> for ArrayBase<S, Ix2>
94+
where
95+
A: Scalar,
96+
S: DataMut<Elem = A>,
97+
{
98+
fn factorizeh_into(mut self) -> Result<FactorizedH<S>> {
99+
let ipiv = unsafe { A::bk(self.layout()?, UPLO::Upper, self.as_allocated_mut()?)? };
100+
Ok(FactorizedH {
101+
a: self,
102+
ipiv: ipiv,
103+
})
104+
}
105+
}
106+
107+
impl<A, Si> FactorizeH<OwnedRepr<A>> for ArrayBase<Si, Ix2>
108+
where
109+
A: Scalar,
110+
Si: Data<Elem = A>,
111+
{
112+
fn factorizeh(&self) -> Result<FactorizedH<OwnedRepr<A>>> {
113+
let mut a: Array2<A> = replicate(self);
114+
let ipiv = unsafe { A::bk(a.layout()?, UPLO::Upper, a.as_allocated_mut()?)? };
115+
Ok(FactorizedH { a: a, ipiv: ipiv })
116+
}
117+
}
118+
119+
pub trait InverseH {
120+
type Output;
121+
fn invh(&self) -> Result<Self::Output>;
122+
}
123+
124+
pub trait InverseHInto {
125+
type Output;
126+
fn invh_into(self) -> Result<Self::Output>;
127+
}
128+
129+
impl<A, S> InverseHInto for ArrayBase<S, Ix2>
130+
where
131+
A: Scalar,
132+
S: DataMut<Elem = A>,
133+
{
134+
type Output = Self;
135+
136+
fn invh_into(self) -> Result<Self::Output> {
137+
let f = self.factorizeh_into()?;
138+
f.into_inverseh()
139+
}
140+
}
141+
142+
impl<A, Si> InverseH for ArrayBase<Si, Ix2>
143+
where
144+
A: Scalar,
145+
Si: Data<Elem = A>,
146+
{
147+
type Output = Array2<A>;
148+
149+
fn invh(&self) -> Result<Self::Output> {
150+
let f = self.factorizeh()?;
151+
f.into_inverseh()
152+
}
153+
}

0 commit comments

Comments
 (0)