diff --git a/src/impl2/mod.rs b/src/impl2/mod.rs index 2c6fcbcc..a05cf59c 100644 --- a/src/impl2/mod.rs +++ b/src/impl2/mod.rs @@ -2,15 +2,17 @@ pub mod opnorm; pub mod qr; pub mod svd; +pub mod solve; pub use self::opnorm::*; pub use self::qr::*; pub use self::svd::*; +pub use self::solve::*; use super::error::*; -pub trait LapackScalar: OperatorNorm_ + QR_ + SVD_ {} -impl LapackScalar for A where A: OperatorNorm_ + QR_ + SVD_ {} +pub trait LapackScalar: OperatorNorm_ + QR_ + SVD_ + Solve_ {} +impl LapackScalar for A where A: OperatorNorm_ + QR_ + SVD_ + Solve_ {} pub fn into_result(info: i32, val: T) -> Result { if info == 0 { diff --git a/src/impl2/solve.rs b/src/impl2/solve.rs new file mode 100644 index 00000000..d024786d --- /dev/null +++ b/src/impl2/solve.rs @@ -0,0 +1,58 @@ + +use lapack::c; + +use types::*; +use error::*; +use layout::Layout; + +use super::into_result; + +pub type Pivot = Vec; + +#[derive(Debug, Clone, Copy)] +#[repr(u8)] +pub enum Transpose { + No = b'N', + Transpose = b'T', + Hermite = b'C', +} + +pub trait Solve_: Sized { + fn lu(Layout, a: &mut [Self]) -> Result; + fn inv(Layout, a: &mut [Self], &Pivot) -> Result<()>; + fn solve(Layout, Transpose, a: &[Self], &Pivot, b: &mut [Self]) -> Result<()>; +} + +macro_rules! impl_solve { + ($scalar:ty, $getrf:path, $getri:path, $getrs:path) => { + +impl Solve_ for $scalar { + fn lu(l: Layout, a: &mut [Self]) -> Result { + let (row, col) = l.size(); + let k = ::std::cmp::min(row, col); + let mut ipiv = vec![0; k as usize]; + let info = $getrf(l.lapacke_layout(), row, col, a, l.lda(), &mut ipiv); + into_result(info, ipiv) + } + + fn inv(l: Layout, a: &mut [Self], ipiv: &Pivot) -> Result<()> { + let (n, _) = l.size(); + let info = $getri(l.lapacke_layout(), n, a, l.lda(), ipiv); + into_result(info, ()) + } + + fn solve(l: Layout, t: Transpose, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()> { + let (n, _) = l.size(); + let nrhs = 1; + let ldb = 1; + let info = $getrs(l.lapacke_layout(), t as u8, n, nrhs, a, l.lda(), ipiv, b, ldb); + into_result(info, ()) + } +} + +}} // impl_solve! + +impl_solve!(f64, c::dgetrf, c::dgetri, c::dgetrs); +impl_solve!(f32, c::sgetrf, c::sgetri, c::sgetrs); +impl_solve!(c64, c::zgetrf, c::zgetri, c::zgetrs); +impl_solve!(c32, c::cgetrf, c::cgetri, c::cgetrs); diff --git a/src/lib.rs b/src/lib.rs index bda5683e..2de86470 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -51,6 +51,7 @@ pub mod impl2; pub mod qr; pub mod svd; pub mod opnorm; +pub mod solve; pub mod vector; pub mod matrix; diff --git a/src/prelude.rs b/src/prelude.rs index 06e15551..e18421f4 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -9,3 +9,4 @@ pub use assert::*; pub use qr::*; pub use svd::*; pub use opnorm::*; +pub use solve::*; diff --git a/src/solve.rs b/src/solve.rs new file mode 100644 index 00000000..3ac60ad3 --- /dev/null +++ b/src/solve.rs @@ -0,0 +1,92 @@ + +use ndarray::*; +use super::layout::*; +use super::error::*; +use super::impl2::*; + +pub use impl2::{Pivot, Transpose}; + +pub struct Factorized { + pub a: ArrayBase, + pub ipiv: Pivot, +} + +impl Factorized + where A: LapackScalar, + S: Data +{ + pub fn solve(&self, t: Transpose, mut rhs: ArrayBase) -> Result> + where Sb: DataMut + { + A::solve(self.a.square_layout()?, + t, + self.a.as_allocated()?, + &self.ipiv, + rhs.as_slice_mut().unwrap())?; + Ok(rhs) + } +} + +impl Factorized + where A: LapackScalar, + S: DataMut +{ + pub fn into_inverse(mut self) -> Result> { + A::inv(self.a.square_layout()?, + self.a.as_allocated_mut()?, + &self.ipiv)?; + Ok(self.a) + } +} + +pub trait Factorize { + fn factorize(self) -> Result>; +} + +impl Factorize for ArrayBase + where A: LapackScalar, + S: DataMut +{ + fn factorize(mut self) -> Result> { + let ipiv = A::lu(self.layout()?, self.as_allocated_mut()?)?; + Ok(Factorized { + a: self, + ipiv: ipiv, + }) + } +} + +impl<'a, A, S> Factorize> for &'a ArrayBase + where A: LapackScalar + Clone, + S: Data +{ + fn factorize(self) -> Result>> { + let mut a = self.to_owned(); + let ipiv = A::lu(a.layout()?, a.as_allocated_mut()?)?; + Ok(Factorized { a: a, ipiv: ipiv }) + } +} + +pub trait Inverse { + fn inv(self) -> Result; +} + +impl Inverse> for ArrayBase + where A: LapackScalar, + S: DataMut +{ + fn inv(self) -> Result> { + let f = self.factorize()?; + f.into_inverse() + } +} + +impl<'a, A, S> Inverse> for &'a ArrayBase + where A: LapackScalar + Clone, + S: Data +{ + fn inv(self) -> Result> { + let f = self.factorize()?; + f.into_inverse() + } +} diff --git a/src/square.rs b/src/square.rs index 709abd9a..78a12bda 100644 --- a/src/square.rs +++ b/src/square.rs @@ -1,11 +1,9 @@ //! Define trait for Hermite matrices use ndarray::{Ix2, Array, RcArray, ArrayBase, Data}; -use lapack::c::Layout; use super::matrix::{Matrix, MFloat}; use super::error::{LinalgError, NotSquareError}; -use super::impls::solve::ImplSolve; /// Methods for square matrices /// @@ -13,9 +11,6 @@ use super::impls::solve::ImplSolve; /// but does not assure that the matrix is square. /// If not square, `NotSquareError` will be thrown. pub trait SquareMatrix: Matrix { - // fn eig(self) -> (Self::Vector, Self); - /// inverse matrix - fn inv(self) -> Result; /// trace of matrix fn trace(&self) -> Result; #[doc(hidden)] @@ -46,18 +41,6 @@ fn trace(a: &ArrayBase) -> A } impl SquareMatrix for Array { - fn inv(self) -> Result { - self.check_square()?; - let (n, _) = self.size(); - let layout = self.layout()?; - let (ipiv, a) = ImplSolve::lu(layout, n, n, self.into_raw_vec())?; - let a = ImplSolve::inv(layout, n, a, &ipiv)?; - let m = Array::from_vec(a).into_shape((n, n)).unwrap(); - match layout { - Layout::RowMajor => Ok(m), - Layout::ColumnMajor => Ok(m.reversed_axes()), - } - } fn trace(&self) -> Result { self.check_square()?; Ok(trace(self)) @@ -65,11 +48,6 @@ impl SquareMatrix for Array { } impl SquareMatrix for RcArray { - fn inv(self) -> Result { - // XXX unnecessary clone (should use into_owned()) - let i = self.to_owned().inv()?; - Ok(i.into_shared()) - } fn trace(&self) -> Result { self.check_square()?; Ok(trace(self))