diff --git a/src/impl2/mod.rs b/src/impl2/mod.rs
index 2c6fcbcc..a05cf59c 100644
--- a/src/impl2/mod.rs
+++ b/src/impl2/mod.rs
@@ -2,15 +2,17 @@
pub mod opnorm;
pub mod qr;
pub mod svd;
+pub mod solve;
pub use self::opnorm::*;
pub use self::qr::*;
pub use self::svd::*;
+pub use self::solve::*;
use super::error::*;
-pub trait LapackScalar: OperatorNorm_ + QR_ + SVD_ {}
-impl LapackScalar for A where A: OperatorNorm_ + QR_ + SVD_ {}
+pub trait LapackScalar: OperatorNorm_ + QR_ + SVD_ + Solve_ {}
+impl LapackScalar for A where A: OperatorNorm_ + QR_ + SVD_ + Solve_ {}
pub fn into_result(info: i32, val: T) -> Result {
if info == 0 {
diff --git a/src/impl2/solve.rs b/src/impl2/solve.rs
new file mode 100644
index 00000000..d024786d
--- /dev/null
+++ b/src/impl2/solve.rs
@@ -0,0 +1,58 @@
+
+use lapack::c;
+
+use types::*;
+use error::*;
+use layout::Layout;
+
+use super::into_result;
+
+pub type Pivot = Vec;
+
+#[derive(Debug, Clone, Copy)]
+#[repr(u8)]
+pub enum Transpose {
+ No = b'N',
+ Transpose = b'T',
+ Hermite = b'C',
+}
+
+pub trait Solve_: Sized {
+ fn lu(Layout, a: &mut [Self]) -> Result;
+ fn inv(Layout, a: &mut [Self], &Pivot) -> Result<()>;
+ fn solve(Layout, Transpose, a: &[Self], &Pivot, b: &mut [Self]) -> Result<()>;
+}
+
+macro_rules! impl_solve {
+ ($scalar:ty, $getrf:path, $getri:path, $getrs:path) => {
+
+impl Solve_ for $scalar {
+ fn lu(l: Layout, a: &mut [Self]) -> Result {
+ let (row, col) = l.size();
+ let k = ::std::cmp::min(row, col);
+ let mut ipiv = vec![0; k as usize];
+ let info = $getrf(l.lapacke_layout(), row, col, a, l.lda(), &mut ipiv);
+ into_result(info, ipiv)
+ }
+
+ fn inv(l: Layout, a: &mut [Self], ipiv: &Pivot) -> Result<()> {
+ let (n, _) = l.size();
+ let info = $getri(l.lapacke_layout(), n, a, l.lda(), ipiv);
+ into_result(info, ())
+ }
+
+ fn solve(l: Layout, t: Transpose, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()> {
+ let (n, _) = l.size();
+ let nrhs = 1;
+ let ldb = 1;
+ let info = $getrs(l.lapacke_layout(), t as u8, n, nrhs, a, l.lda(), ipiv, b, ldb);
+ into_result(info, ())
+ }
+}
+
+}} // impl_solve!
+
+impl_solve!(f64, c::dgetrf, c::dgetri, c::dgetrs);
+impl_solve!(f32, c::sgetrf, c::sgetri, c::sgetrs);
+impl_solve!(c64, c::zgetrf, c::zgetri, c::zgetrs);
+impl_solve!(c32, c::cgetrf, c::cgetri, c::cgetrs);
diff --git a/src/lib.rs b/src/lib.rs
index bda5683e..2de86470 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -51,6 +51,7 @@ pub mod impl2;
pub mod qr;
pub mod svd;
pub mod opnorm;
+pub mod solve;
pub mod vector;
pub mod matrix;
diff --git a/src/prelude.rs b/src/prelude.rs
index 06e15551..e18421f4 100644
--- a/src/prelude.rs
+++ b/src/prelude.rs
@@ -9,3 +9,4 @@ pub use assert::*;
pub use qr::*;
pub use svd::*;
pub use opnorm::*;
+pub use solve::*;
diff --git a/src/solve.rs b/src/solve.rs
new file mode 100644
index 00000000..3ac60ad3
--- /dev/null
+++ b/src/solve.rs
@@ -0,0 +1,92 @@
+
+use ndarray::*;
+use super::layout::*;
+use super::error::*;
+use super::impl2::*;
+
+pub use impl2::{Pivot, Transpose};
+
+pub struct Factorized {
+ pub a: ArrayBase,
+ pub ipiv: Pivot,
+}
+
+impl Factorized
+ where A: LapackScalar,
+ S: Data
+{
+ pub fn solve(&self, t: Transpose, mut rhs: ArrayBase) -> Result>
+ where Sb: DataMut
+ {
+ A::solve(self.a.square_layout()?,
+ t,
+ self.a.as_allocated()?,
+ &self.ipiv,
+ rhs.as_slice_mut().unwrap())?;
+ Ok(rhs)
+ }
+}
+
+impl Factorized
+ where A: LapackScalar,
+ S: DataMut
+{
+ pub fn into_inverse(mut self) -> Result> {
+ A::inv(self.a.square_layout()?,
+ self.a.as_allocated_mut()?,
+ &self.ipiv)?;
+ Ok(self.a)
+ }
+}
+
+pub trait Factorize {
+ fn factorize(self) -> Result>;
+}
+
+impl Factorize for ArrayBase
+ where A: LapackScalar,
+ S: DataMut
+{
+ fn factorize(mut self) -> Result> {
+ let ipiv = A::lu(self.layout()?, self.as_allocated_mut()?)?;
+ Ok(Factorized {
+ a: self,
+ ipiv: ipiv,
+ })
+ }
+}
+
+impl<'a, A, S> Factorize> for &'a ArrayBase
+ where A: LapackScalar + Clone,
+ S: Data
+{
+ fn factorize(self) -> Result>> {
+ let mut a = self.to_owned();
+ let ipiv = A::lu(a.layout()?, a.as_allocated_mut()?)?;
+ Ok(Factorized { a: a, ipiv: ipiv })
+ }
+}
+
+pub trait Inverse {
+ fn inv(self) -> Result;
+}
+
+impl Inverse> for ArrayBase
+ where A: LapackScalar,
+ S: DataMut
+{
+ fn inv(self) -> Result> {
+ let f = self.factorize()?;
+ f.into_inverse()
+ }
+}
+
+impl<'a, A, S> Inverse> for &'a ArrayBase
+ where A: LapackScalar + Clone,
+ S: Data
+{
+ fn inv(self) -> Result> {
+ let f = self.factorize()?;
+ f.into_inverse()
+ }
+}
diff --git a/src/square.rs b/src/square.rs
index 709abd9a..78a12bda 100644
--- a/src/square.rs
+++ b/src/square.rs
@@ -1,11 +1,9 @@
//! Define trait for Hermite matrices
use ndarray::{Ix2, Array, RcArray, ArrayBase, Data};
-use lapack::c::Layout;
use super::matrix::{Matrix, MFloat};
use super::error::{LinalgError, NotSquareError};
-use super::impls::solve::ImplSolve;
/// Methods for square matrices
///
@@ -13,9 +11,6 @@ use super::impls::solve::ImplSolve;
/// but does not assure that the matrix is square.
/// If not square, `NotSquareError` will be thrown.
pub trait SquareMatrix: Matrix {
- // fn eig(self) -> (Self::Vector, Self);
- /// inverse matrix
- fn inv(self) -> Result;
/// trace of matrix
fn trace(&self) -> Result;
#[doc(hidden)]
@@ -46,18 +41,6 @@ fn trace(a: &ArrayBase) -> A
}
impl SquareMatrix for Array {
- fn inv(self) -> Result {
- self.check_square()?;
- let (n, _) = self.size();
- let layout = self.layout()?;
- let (ipiv, a) = ImplSolve::lu(layout, n, n, self.into_raw_vec())?;
- let a = ImplSolve::inv(layout, n, a, &ipiv)?;
- let m = Array::from_vec(a).into_shape((n, n)).unwrap();
- match layout {
- Layout::RowMajor => Ok(m),
- Layout::ColumnMajor => Ok(m.reversed_axes()),
- }
- }
fn trace(&self) -> Result {
self.check_square()?;
Ok(trace(self))
@@ -65,11 +48,6 @@ impl SquareMatrix for Array {
}
impl SquareMatrix for RcArray {
- fn inv(self) -> Result {
- // XXX unnecessary clone (should use into_owned())
- let i = self.to_owned().inv()?;
- Ok(i.into_shared())
- }
fn trace(&self) -> Result {
self.check_square()?;
Ok(trace(self))