Skip to content

Commit dc1a71c

Browse files
authored
Merge pull request #34 from termoshtt/solve
LU decomposition and inverse matrix
2 parents 403b24c + 3e90e15 commit dc1a71c

File tree

6 files changed

+156
-24
lines changed

6 files changed

+156
-24
lines changed

src/impl2/mod.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,17 @@
22
pub mod opnorm;
33
pub mod qr;
44
pub mod svd;
5+
pub mod solve;
56

67
pub use self::opnorm::*;
78
pub use self::qr::*;
89
pub use self::svd::*;
10+
pub use self::solve::*;
911

1012
use super::error::*;
1113

12-
pub trait LapackScalar: OperatorNorm_ + QR_ + SVD_ {}
13-
impl<A> LapackScalar for A where A: OperatorNorm_ + QR_ + SVD_ {}
14+
pub trait LapackScalar: OperatorNorm_ + QR_ + SVD_ + Solve_ {}
15+
impl<A> LapackScalar for A where A: OperatorNorm_ + QR_ + SVD_ + Solve_ {}
1416

1517
pub fn into_result<T>(info: i32, val: T) -> Result<T> {
1618
if info == 0 {

src/impl2/solve.rs

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
2+
use lapack::c;
3+
4+
use types::*;
5+
use error::*;
6+
use layout::Layout;
7+
8+
use super::into_result;
9+
10+
pub type Pivot = Vec<i32>;
11+
12+
#[derive(Debug, Clone, Copy)]
13+
#[repr(u8)]
14+
pub enum Transpose {
15+
No = b'N',
16+
Transpose = b'T',
17+
Hermite = b'C',
18+
}
19+
20+
pub trait Solve_: Sized {
21+
fn lu(Layout, a: &mut [Self]) -> Result<Pivot>;
22+
fn inv(Layout, a: &mut [Self], &Pivot) -> Result<()>;
23+
fn solve(Layout, Transpose, a: &[Self], &Pivot, b: &mut [Self]) -> Result<()>;
24+
}
25+
26+
macro_rules! impl_solve {
27+
($scalar:ty, $getrf:path, $getri:path, $getrs:path) => {
28+
29+
impl Solve_ for $scalar {
30+
fn lu(l: Layout, a: &mut [Self]) -> Result<Pivot> {
31+
let (row, col) = l.size();
32+
let k = ::std::cmp::min(row, col);
33+
let mut ipiv = vec![0; k as usize];
34+
let info = $getrf(l.lapacke_layout(), row, col, a, l.lda(), &mut ipiv);
35+
into_result(info, ipiv)
36+
}
37+
38+
fn inv(l: Layout, a: &mut [Self], ipiv: &Pivot) -> Result<()> {
39+
let (n, _) = l.size();
40+
let info = $getri(l.lapacke_layout(), n, a, l.lda(), ipiv);
41+
into_result(info, ())
42+
}
43+
44+
fn solve(l: Layout, t: Transpose, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()> {
45+
let (n, _) = l.size();
46+
let nrhs = 1;
47+
let ldb = 1;
48+
let info = $getrs(l.lapacke_layout(), t as u8, n, nrhs, a, l.lda(), ipiv, b, ldb);
49+
into_result(info, ())
50+
}
51+
}
52+
53+
}} // impl_solve!
54+
55+
impl_solve!(f64, c::dgetrf, c::dgetri, c::dgetrs);
56+
impl_solve!(f32, c::sgetrf, c::sgetri, c::sgetrs);
57+
impl_solve!(c64, c::zgetrf, c::zgetri, c::zgetrs);
58+
impl_solve!(c32, c::cgetrf, c::cgetri, c::cgetrs);

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ pub mod impl2;
5151
pub mod qr;
5252
pub mod svd;
5353
pub mod opnorm;
54+
pub mod solve;
5455

5556
pub mod vector;
5657
pub mod matrix;

src/prelude.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ pub use assert::*;
99
pub use qr::*;
1010
pub use svd::*;
1111
pub use opnorm::*;
12+
pub use solve::*;

src/solve.rs

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
2+
use ndarray::*;
3+
use super::layout::*;
4+
use super::error::*;
5+
use super::impl2::*;
6+
7+
pub use impl2::{Pivot, Transpose};
8+
9+
pub struct Factorized<S: Data> {
10+
pub a: ArrayBase<S, Ix2>,
11+
pub ipiv: Pivot,
12+
}
13+
14+
impl<A, S> Factorized<S>
15+
where A: LapackScalar,
16+
S: Data<Elem = A>
17+
{
18+
pub fn solve<Sb>(&self, t: Transpose, mut rhs: ArrayBase<Sb, Ix1>) -> Result<ArrayBase<Sb, Ix1>>
19+
where Sb: DataMut<Elem = A>
20+
{
21+
A::solve(self.a.square_layout()?,
22+
t,
23+
self.a.as_allocated()?,
24+
&self.ipiv,
25+
rhs.as_slice_mut().unwrap())?;
26+
Ok(rhs)
27+
}
28+
}
29+
30+
impl<A, S> Factorized<S>
31+
where A: LapackScalar,
32+
S: DataMut<Elem = A>
33+
{
34+
pub fn into_inverse(mut self) -> Result<ArrayBase<S, Ix2>> {
35+
A::inv(self.a.square_layout()?,
36+
self.a.as_allocated_mut()?,
37+
&self.ipiv)?;
38+
Ok(self.a)
39+
}
40+
}
41+
42+
pub trait Factorize<S: Data> {
43+
fn factorize(self) -> Result<Factorized<S>>;
44+
}
45+
46+
impl<A, S> Factorize<S> for ArrayBase<S, Ix2>
47+
where A: LapackScalar,
48+
S: DataMut<Elem = A>
49+
{
50+
fn factorize(mut self) -> Result<Factorized<S>> {
51+
let ipiv = A::lu(self.layout()?, self.as_allocated_mut()?)?;
52+
Ok(Factorized {
53+
a: self,
54+
ipiv: ipiv,
55+
})
56+
}
57+
}
58+
59+
impl<'a, A, S> Factorize<OwnedRepr<A>> for &'a ArrayBase<S, Ix2>
60+
where A: LapackScalar + Clone,
61+
S: Data<Elem = A>
62+
{
63+
fn factorize(self) -> Result<Factorized<OwnedRepr<A>>> {
64+
let mut a = self.to_owned();
65+
let ipiv = A::lu(a.layout()?, a.as_allocated_mut()?)?;
66+
Ok(Factorized { a: a, ipiv: ipiv })
67+
}
68+
}
69+
70+
pub trait Inverse<Inv> {
71+
fn inv(self) -> Result<Inv>;
72+
}
73+
74+
impl<A, S> Inverse<ArrayBase<S, Ix2>> for ArrayBase<S, Ix2>
75+
where A: LapackScalar,
76+
S: DataMut<Elem = A>
77+
{
78+
fn inv(self) -> Result<ArrayBase<S, Ix2>> {
79+
let f = self.factorize()?;
80+
f.into_inverse()
81+
}
82+
}
83+
84+
impl<'a, A, S> Inverse<Array2<A>> for &'a ArrayBase<S, Ix2>
85+
where A: LapackScalar + Clone,
86+
S: Data<Elem = A>
87+
{
88+
fn inv(self) -> Result<Array2<A>> {
89+
let f = self.factorize()?;
90+
f.into_inverse()
91+
}
92+
}

src/square.rs

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,16 @@
11
//! Define trait for Hermite matrices
22
33
use ndarray::{Ix2, Array, RcArray, ArrayBase, Data};
4-
use lapack::c::Layout;
54

65
use super::matrix::{Matrix, MFloat};
76
use super::error::{LinalgError, NotSquareError};
8-
use super::impls::solve::ImplSolve;
97

108
/// Methods for square matrices
119
///
1210
/// This trait defines method for square matrices,
1311
/// but does not assure that the matrix is square.
1412
/// If not square, `NotSquareError` will be thrown.
1513
pub trait SquareMatrix: Matrix {
16-
// fn eig(self) -> (Self::Vector, Self);
17-
/// inverse matrix
18-
fn inv(self) -> Result<Self, LinalgError>;
1914
/// trace of matrix
2015
fn trace(&self) -> Result<Self::Scalar, LinalgError>;
2116
#[doc(hidden)]
@@ -46,30 +41,13 @@ fn trace<A: MFloat, S>(a: &ArrayBase<S, Ix2>) -> A
4641
}
4742

4843
impl<A: MFloat> SquareMatrix for Array<A, Ix2> {
49-
fn inv(self) -> Result<Self, LinalgError> {
50-
self.check_square()?;
51-
let (n, _) = self.size();
52-
let layout = self.layout()?;
53-
let (ipiv, a) = ImplSolve::lu(layout, n, n, self.into_raw_vec())?;
54-
let a = ImplSolve::inv(layout, n, a, &ipiv)?;
55-
let m = Array::from_vec(a).into_shape((n, n)).unwrap();
56-
match layout {
57-
Layout::RowMajor => Ok(m),
58-
Layout::ColumnMajor => Ok(m.reversed_axes()),
59-
}
60-
}
6144
fn trace(&self) -> Result<Self::Scalar, LinalgError> {
6245
self.check_square()?;
6346
Ok(trace(self))
6447
}
6548
}
6649

6750
impl<A: MFloat> SquareMatrix for RcArray<A, Ix2> {
68-
fn inv(self) -> Result<Self, LinalgError> {
69-
// XXX unnecessary clone (should use into_owned())
70-
let i = self.to_owned().inv()?;
71-
Ok(i.into_shared())
72-
}
7351
fn trace(&self) -> Result<Self::Scalar, LinalgError> {
7452
self.check_square()?;
7553
Ok(trace(self))

0 commit comments

Comments
 (0)