Skip to content

LU decomposition and inverse matrix #34

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 9, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/impl2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,17 @@
pub mod opnorm;
pub mod qr;
pub mod svd;
pub mod solve;

pub use self::opnorm::*;
pub use self::qr::*;
pub use self::svd::*;
pub use self::solve::*;

use super::error::*;

pub trait LapackScalar: OperatorNorm_ + QR_ + SVD_ {}
impl<A> LapackScalar for A where A: OperatorNorm_ + QR_ + SVD_ {}
pub trait LapackScalar: OperatorNorm_ + QR_ + SVD_ + Solve_ {}
impl<A> LapackScalar for A where A: OperatorNorm_ + QR_ + SVD_ + Solve_ {}

pub fn into_result<T>(info: i32, val: T) -> Result<T> {
if info == 0 {
Expand Down
58 changes: 58 additions & 0 deletions src/impl2/solve.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@

use lapack::c;

use types::*;
use error::*;
use layout::Layout;

use super::into_result;

pub type Pivot = Vec<i32>;

#[derive(Debug, Clone, Copy)]
#[repr(u8)]
pub enum Transpose {
No = b'N',
Transpose = b'T',
Hermite = b'C',
}

pub trait Solve_: Sized {
fn lu(Layout, a: &mut [Self]) -> Result<Pivot>;
fn inv(Layout, a: &mut [Self], &Pivot) -> Result<()>;
fn solve(Layout, Transpose, a: &[Self], &Pivot, b: &mut [Self]) -> Result<()>;
}

macro_rules! impl_solve {
($scalar:ty, $getrf:path, $getri:path, $getrs:path) => {

impl Solve_ for $scalar {
fn lu(l: Layout, a: &mut [Self]) -> Result<Pivot> {
let (row, col) = l.size();
let k = ::std::cmp::min(row, col);
let mut ipiv = vec![0; k as usize];
let info = $getrf(l.lapacke_layout(), row, col, a, l.lda(), &mut ipiv);
into_result(info, ipiv)
}

fn inv(l: Layout, a: &mut [Self], ipiv: &Pivot) -> Result<()> {
let (n, _) = l.size();
let info = $getri(l.lapacke_layout(), n, a, l.lda(), ipiv);
into_result(info, ())
}

fn solve(l: Layout, t: Transpose, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()> {
let (n, _) = l.size();
let nrhs = 1;
let ldb = 1;
let info = $getrs(l.lapacke_layout(), t as u8, n, nrhs, a, l.lda(), ipiv, b, ldb);
into_result(info, ())
}
}

}} // impl_solve!

impl_solve!(f64, c::dgetrf, c::dgetri, c::dgetrs);
impl_solve!(f32, c::sgetrf, c::sgetri, c::sgetrs);
impl_solve!(c64, c::zgetrf, c::zgetri, c::zgetrs);
impl_solve!(c32, c::cgetrf, c::cgetri, c::cgetrs);
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ pub mod impl2;
pub mod qr;
pub mod svd;
pub mod opnorm;
pub mod solve;

pub mod vector;
pub mod matrix;
Expand Down
1 change: 1 addition & 0 deletions src/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ pub use assert::*;
pub use qr::*;
pub use svd::*;
pub use opnorm::*;
pub use solve::*;
92 changes: 92 additions & 0 deletions src/solve.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@

use ndarray::*;
use super::layout::*;
use super::error::*;
use super::impl2::*;

pub use impl2::{Pivot, Transpose};

pub struct Factorized<S: Data> {
pub a: ArrayBase<S, Ix2>,
pub ipiv: Pivot,
}

impl<A, S> Factorized<S>
where A: LapackScalar,
S: Data<Elem = A>
{
pub fn solve<Sb>(&self, t: Transpose, mut rhs: ArrayBase<Sb, Ix1>) -> Result<ArrayBase<Sb, Ix1>>
where Sb: DataMut<Elem = A>
{
A::solve(self.a.square_layout()?,
t,
self.a.as_allocated()?,
&self.ipiv,
rhs.as_slice_mut().unwrap())?;
Ok(rhs)
}
}

impl<A, S> Factorized<S>
where A: LapackScalar,
S: DataMut<Elem = A>
{
pub fn into_inverse(mut self) -> Result<ArrayBase<S, Ix2>> {
A::inv(self.a.square_layout()?,
self.a.as_allocated_mut()?,
&self.ipiv)?;
Ok(self.a)
}
}

pub trait Factorize<S: Data> {
fn factorize(self) -> Result<Factorized<S>>;
}

impl<A, S> Factorize<S> for ArrayBase<S, Ix2>
where A: LapackScalar,
S: DataMut<Elem = A>
{
fn factorize(mut self) -> Result<Factorized<S>> {
let ipiv = A::lu(self.layout()?, self.as_allocated_mut()?)?;
Ok(Factorized {
a: self,
ipiv: ipiv,
})
}
}

impl<'a, A, S> Factorize<OwnedRepr<A>> for &'a ArrayBase<S, Ix2>
where A: LapackScalar + Clone,
S: Data<Elem = A>
{
fn factorize(self) -> Result<Factorized<OwnedRepr<A>>> {
let mut a = self.to_owned();
let ipiv = A::lu(a.layout()?, a.as_allocated_mut()?)?;
Ok(Factorized { a: a, ipiv: ipiv })
}
}

pub trait Inverse<Inv> {
fn inv(self) -> Result<Inv>;
}

impl<A, S> Inverse<ArrayBase<S, Ix2>> for ArrayBase<S, Ix2>
where A: LapackScalar,
S: DataMut<Elem = A>
{
fn inv(self) -> Result<ArrayBase<S, Ix2>> {
let f = self.factorize()?;
f.into_inverse()
}
}

impl<'a, A, S> Inverse<Array2<A>> for &'a ArrayBase<S, Ix2>
where A: LapackScalar + Clone,
S: Data<Elem = A>
{
fn inv(self) -> Result<Array2<A>> {
let f = self.factorize()?;
f.into_inverse()
}
}
22 changes: 0 additions & 22 deletions src/square.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,16 @@
//! Define trait for Hermite matrices

use ndarray::{Ix2, Array, RcArray, ArrayBase, Data};
use lapack::c::Layout;

use super::matrix::{Matrix, MFloat};
use super::error::{LinalgError, NotSquareError};
use super::impls::solve::ImplSolve;

/// Methods for square matrices
///
/// This trait defines method for square matrices,
/// but does not assure that the matrix is square.
/// If not square, `NotSquareError` will be thrown.
pub trait SquareMatrix: Matrix {
// fn eig(self) -> (Self::Vector, Self);
/// inverse matrix
fn inv(self) -> Result<Self, LinalgError>;
/// trace of matrix
fn trace(&self) -> Result<Self::Scalar, LinalgError>;
#[doc(hidden)]
Expand Down Expand Up @@ -46,30 +41,13 @@ fn trace<A: MFloat, S>(a: &ArrayBase<S, Ix2>) -> A
}

impl<A: MFloat> SquareMatrix for Array<A, Ix2> {
fn inv(self) -> Result<Self, LinalgError> {
self.check_square()?;
let (n, _) = self.size();
let layout = self.layout()?;
let (ipiv, a) = ImplSolve::lu(layout, n, n, self.into_raw_vec())?;
let a = ImplSolve::inv(layout, n, a, &ipiv)?;
let m = Array::from_vec(a).into_shape((n, n)).unwrap();
match layout {
Layout::RowMajor => Ok(m),
Layout::ColumnMajor => Ok(m.reversed_axes()),
}
}
fn trace(&self) -> Result<Self::Scalar, LinalgError> {
self.check_square()?;
Ok(trace(self))
}
}

impl<A: MFloat> SquareMatrix for RcArray<A, Ix2> {
fn inv(self) -> Result<Self, LinalgError> {
// XXX unnecessary clone (should use into_owned())
let i = self.to_owned().inv()?;
Ok(i.into_shared())
}
fn trace(&self) -> Result<Self::Scalar, LinalgError> {
self.check_square()?;
Ok(trace(self))
Expand Down