diff --git a/src/impl2/mod.rs b/src/impl2/mod.rs
index 89a38823..2c6fcbcc 100644
--- a/src/impl2/mod.rs
+++ b/src/impl2/mod.rs
@@ -1,6 +1,21 @@
pub mod opnorm;
+pub mod qr;
+pub mod svd;
+
pub use self::opnorm::*;
+pub use self::qr::*;
+pub use self::svd::*;
+
+use super::error::*;
+
+pub trait LapackScalar: OperatorNorm_ + QR_ + SVD_ {}
+impl LapackScalar for A where A: OperatorNorm_ + QR_ + SVD_ {}
-pub trait LapackScalar: OperatorNorm_ {}
-impl LapackScalar for A where A: OperatorNorm_ {}
+pub fn into_result(info: i32, val: T) -> Result {
+ if info == 0 {
+ Ok(val)
+ } else {
+ Err(LapackError::new(info).into())
+ }
+}
diff --git a/src/impl2/opnorm.rs b/src/impl2/opnorm.rs
index e2687ebf..b1efd9f7 100644
--- a/src/impl2/opnorm.rs
+++ b/src/impl2/opnorm.rs
@@ -4,7 +4,7 @@ use lapack::c;
use lapack::c::Layout::ColumnMajor as cm;
use types::*;
-use layout::*;
+use layout::Layout;
#[repr(u8)]
pub enum NormType {
diff --git a/src/impl2/qr.rs b/src/impl2/qr.rs
new file mode 100644
index 00000000..714135d0
--- /dev/null
+++ b/src/impl2/qr.rs
@@ -0,0 +1,49 @@
+//! Implement QR decomposition
+
+use std::cmp::min;
+use num_traits::Zero;
+use lapack::c;
+
+use types::*;
+use error::*;
+use layout::Layout;
+
+use super::into_result;
+
+pub trait QR_: Sized {
+ fn householder(Layout, a: &mut [Self]) -> Result>;
+ fn q(Layout, a: &mut [Self], tau: &[Self]) -> Result<()>;
+ fn qr(Layout, a: &mut [Self]) -> Result>;
+}
+
+macro_rules! impl_qr {
+ ($scalar:ty, $qrf:path, $gqr:path) => {
+impl QR_ for $scalar {
+ fn householder(l: Layout, mut a: &mut [Self]) -> Result> {
+ let (row, col) = l.size();
+ let k = min(row, col);
+ let mut tau = vec![Self::zero(); k as usize];
+ let info = $qrf(l.lapacke_layout(), row, col, &mut a, l.lda(), &mut tau);
+ into_result(info, tau)
+ }
+
+ fn q(l: Layout, mut a: &mut [Self], tau: &[Self]) -> Result<()> {
+ let (row, col) = l.size();
+ let k = min(row, col);
+ let info = $gqr(l.lapacke_layout(), row, k, k, &mut a, l.lda(), &tau);
+ into_result(info, ())
+ }
+
+ fn qr(l: Layout, mut a: &mut [Self]) -> Result> {
+ let tau = Self::householder(l, a)?;
+ let r = Vec::from(&*a);
+ Self::q(l, a, &tau)?;
+ Ok(r)
+ }
+}
+}} // endmacro
+
+impl_qr!(f64, c::dgeqrf, c::dorgqr);
+impl_qr!(f32, c::sgeqrf, c::sorgqr);
+impl_qr!(c64, c::zgeqrf, c::zungqr);
+impl_qr!(c32, c::cgeqrf, c::cungqr);
diff --git a/src/impl2/svd.rs b/src/impl2/svd.rs
new file mode 100644
index 00000000..1151f8c8
--- /dev/null
+++ b/src/impl2/svd.rs
@@ -0,0 +1,64 @@
+//! Implement Operator norms for matrices
+
+use lapack::c;
+use num_traits::Zero;
+
+use types::*;
+use error::*;
+use layout::Layout;
+
+use super::into_result;
+
+#[repr(u8)]
+enum FlagSVD {
+ All = b'A',
+ // OverWrite = b'O',
+ // Separately = b'S',
+ No = b'N',
+}
+
+pub struct SVDOutput {
+ pub s: Vec,
+ pub u: Option>,
+ pub vt: Option>,
+}
+
+pub trait SVD_: AssociatedReal {
+ fn svd(Layout, calc_u: bool, calc_vt: bool, a: &mut [Self]) -> Result>;
+}
+
+macro_rules! impl_svd {
+ ($scalar:ty, $gesvd:path) => {
+
+impl SVD_ for $scalar {
+ fn svd(l: Layout, calc_u: bool, calc_vt: bool, mut a: &mut [Self]) -> Result> {
+ let (m, n) = l.size();
+ let k = ::std::cmp::min(n, m);
+ let lda = l.lda();
+ let (ju, ldu, mut u) = if calc_u {
+ (FlagSVD::All, m, vec![Self::zero(); (m*m) as usize])
+ } else {
+ (FlagSVD::No, 0, Vec::new())
+ };
+ let (jvt, ldvt, mut vt) = if calc_vt {
+ (FlagSVD::All, n, vec![Self::zero(); (n*n) as usize])
+ } else {
+ (FlagSVD::No, 0, Vec::new())
+ };
+ let mut s = vec![Self::Real::zero(); k as usize];
+ let mut superb = vec![Self::Real::zero(); (k-2) as usize];
+ let info = $gesvd(l.lapacke_layout(), ju as u8, jvt as u8, m, n, &mut a, lda, &mut s, &mut u, ldu, &mut vt, ldvt, &mut superb);
+ into_result(info, SVDOutput {
+ s: s,
+ u: if ldu > 0 { Some(u) } else { None },
+ vt: if ldvt > 0 { Some(vt) } else { None },
+ })
+ }
+}
+
+}} // impl_svd!
+
+impl_svd!(f64, c::dgesvd);
+impl_svd!(f32, c::sgesvd);
+impl_svd!(c64, c::zgesvd);
+impl_svd!(c32, c::cgesvd);
diff --git a/src/layout.rs b/src/layout.rs
index 66e56c19..6c098115 100644
--- a/src/layout.rs
+++ b/src/layout.rs
@@ -1,5 +1,6 @@
use ndarray::*;
+use lapack::c;
use super::error::*;
@@ -7,6 +8,7 @@ pub type LDA = i32;
pub type Col = i32;
pub type Row = i32;
+#[derive(Debug, Clone, Copy)]
pub enum Layout {
C((Row, LDA)),
F((Col, LDA)),
@@ -19,6 +21,27 @@ impl Layout {
Layout::F((col, lda)) => (lda, col),
}
}
+
+ pub fn resized(&self, row: Row, col: Col) -> Layout {
+ match *self {
+ Layout::C(_) => Layout::C((row, col)),
+ Layout::F(_) => Layout::F((col, row)),
+ }
+ }
+
+ pub fn lda(&self) -> LDA {
+ match *self {
+ Layout::C((_, lda)) => lda,
+ Layout::F((_, lda)) => lda,
+ }
+ }
+
+ pub fn lapacke_layout(&self) -> c::Layout {
+ match *self {
+ Layout::C(_) => c::Layout::RowMajor,
+ Layout::F(_) => c::Layout::ColumnMajor,
+ }
+ }
}
pub trait AllocatedArray {
@@ -28,6 +51,10 @@ pub trait AllocatedArray {
fn as_allocated(&self) -> Result<&[Self::Scalar]>;
}
+pub trait AllocatedArrayMut: AllocatedArray {
+ fn as_allocated_mut(&mut self) -> Result<&mut [Self::Scalar]>;
+}
+
impl AllocatedArray for ArrayBase
where S: Data
{
@@ -60,3 +87,21 @@ impl AllocatedArray for ArrayBase
Ok(slice)
}
}
+
+impl AllocatedArrayMut for ArrayBase
+ where S: DataMut
+{
+ fn as_allocated_mut(&mut self) -> Result<&mut [A]> {
+ let slice = self.as_slice_memory_order_mut().ok_or(MemoryContError::new())?;
+ Ok(slice)
+ }
+}
+
+pub fn reconstruct(l: Layout, a: Vec) -> Result>
+ where S: DataOwned
+{
+ Ok(match l {
+ Layout::C((row, col)) => ArrayBase::from_shape_vec((row as usize, col as usize), a)?,
+ Layout::F((col, row)) => ArrayBase::from_shape_vec((row as usize, col as usize).f(), a)?,
+ })
+}
diff --git a/src/lib.rs b/src/lib.rs
index 1d4f537e..bda5683e 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -48,7 +48,9 @@ pub mod layout;
pub mod impls;
pub mod impl2;
-pub mod traits;
+pub mod qr;
+pub mod svd;
+pub mod opnorm;
pub mod vector;
pub mod matrix;
diff --git a/src/matrix.rs b/src/matrix.rs
index ea6c5f79..b135b6f7 100644
--- a/src/matrix.rs
+++ b/src/matrix.rs
@@ -6,12 +6,11 @@ use ndarray::DataMut;
use lapack::c::Layout;
use super::error::{LinalgError, StrideError};
-use super::impls::qr::ImplQR;
use super::impls::svd::ImplSVD;
use super::impls::solve::ImplSolve;
-pub trait MFloat: ImplQR + ImplSVD + ImplSolve + NdFloat {}
-impl MFloat for A {}
+pub trait MFloat: ImplSVD + ImplSolve + NdFloat {}
+impl MFloat for A {}
/// Methods for general matrices
pub trait Matrix: Sized {
@@ -22,10 +21,6 @@ pub trait Matrix: Sized {
fn size(&self) -> (usize, usize);
/// Layout (C/Fortran) of matrix
fn layout(&self) -> Result;
- /// singular-value decomposition (SVD)
- fn svd(self) -> Result<(Self, Self::Vector, Self), LinalgError>;
- /// QR decomposition
- fn qr(self) -> Result<(Self, Self), LinalgError>;
/// LU decomposition
fn lu(self) -> Result<(Self::Permutator, Self, Self), LinalgError>;
/// permutate matrix (inplace)
@@ -77,49 +72,6 @@ impl Matrix for Array {
fn layout(&self) -> Result {
check_layout(self.strides())
}
- fn svd(self) -> Result<(Self, Self::Vector, Self), LinalgError> {
- let (n, m) = self.size();
- let layout = self.layout()?;
- let (u, s, vt) = ImplSVD::svd(layout, m, n, self.clone().into_raw_vec())?;
- let sv = Array::from_vec(s);
- let ua = Array::from_vec(u).into_shape((n, n)).unwrap();
- let va = Array::from_vec(vt).into_shape((m, m)).unwrap();
- match layout {
- Layout::RowMajor => Ok((ua, sv, va)),
- Layout::ColumnMajor => Ok((ua.reversed_axes(), sv, va.reversed_axes())),
- }
- }
- fn qr(self) -> Result<(Self, Self), LinalgError> {
- let (n, m) = self.size();
- let strides = self.strides();
- let k = min(n, m);
- let layout = self.layout()?;
- let (q, r) = ImplQR::qr(layout, m, n, self.clone().into_raw_vec())?;
- let (qa, ra) = if strides[0] < strides[1] {
- (Array::from_vec(q).into_shape((m, n)).unwrap().reversed_axes(),
- Array::from_vec(r).into_shape((m, n)).unwrap().reversed_axes())
- } else {
- (Array::from_vec(q).into_shape((n, m)).unwrap(), Array::from_vec(r).into_shape((n, m)).unwrap())
- };
- let qm = if m > k {
- let (qsl, _) = qa.view().split_at(Axis(1), k);
- qsl.to_owned()
- } else {
- qa
- };
- let mut rm = if n > k {
- let (rsl, _) = ra.view().split_at(Axis(0), k);
- rsl.to_owned()
- } else {
- ra
- };
- for ((i, j), val) in rm.indexed_iter_mut() {
- if i > j {
- *val = A::zero();
- }
- }
- Ok((qm, rm))
- }
fn lu(self) -> Result<(Self::Permutator, Self, Self), LinalgError> {
let (n, m) = self.size();
let k = min(n, m);
@@ -163,14 +115,6 @@ impl Matrix for RcArray {
fn layout(&self) -> Result {
check_layout(self.strides())
}
- fn svd(self) -> Result<(Self, Self::Vector, Self), LinalgError> {
- let (u, s, v) = self.into_owned().svd()?;
- Ok((u.into_shared(), s.into_shared(), v.into_shared()))
- }
- fn qr(self) -> Result<(Self, Self), LinalgError> {
- let (q, r) = self.into_owned().qr()?;
- Ok((q.into_shared(), r.into_shared()))
- }
fn lu(self) -> Result<(Self::Permutator, Self, Self), LinalgError> {
let (p, l, u) = self.into_owned().lu()?;
Ok((p, l.into_shared(), u.into_shared()))
diff --git a/src/traits.rs b/src/opnorm.rs
similarity index 96%
rename from src/traits.rs
rename to src/opnorm.rs
index 5dada9c6..e2996d22 100644
--- a/src/traits.rs
+++ b/src/opnorm.rs
@@ -1,13 +1,13 @@
-pub use impl2::LapackScalar;
-pub use impl2::NormType;
-
use ndarray::*;
use super::types::*;
use super::error::*;
use super::layout::*;
+pub use impl2::NormType;
+use impl2::LapackScalar;
+
pub trait OperationNorm {
type Output;
fn opnorm(&self, t: NormType) -> Self::Output;
diff --git a/src/prelude.rs b/src/prelude.rs
index 9fb296b2..06e15551 100644
--- a/src/prelude.rs
+++ b/src/prelude.rs
@@ -5,4 +5,7 @@ pub use hermite::HermiteMatrix;
pub use triangular::*;
pub use util::*;
pub use assert::*;
-pub use traits::*;
+
+pub use qr::*;
+pub use svd::*;
+pub use opnorm::*;
diff --git a/src/qr.rs b/src/qr.rs
new file mode 100644
index 00000000..032e1f72
--- /dev/null
+++ b/src/qr.rs
@@ -0,0 +1,83 @@
+
+use num_traits::Zero;
+use ndarray::*;
+
+use super::error::*;
+use super::layout::*;
+
+use impl2::LapackScalar;
+
+pub trait QR {
+ fn qr(self) -> Result<(Q, R)>;
+}
+
+impl QR, ArrayBase> for ArrayBase
+ where A: LapackScalar + Copy + Zero,
+ S: DataMut,
+ Sq: DataOwned + DataMut,
+ Sr: DataOwned + DataMut
+{
+ fn qr(mut self) -> Result<(ArrayBase, ArrayBase)> {
+ (&mut self).qr()
+ }
+}
+
+fn take_slice(a: &ArrayBase, n: usize, m: usize) -> ArrayBase
+ where A: Copy,
+ S1: Data,
+ S2: DataMut + DataOwned
+{
+ let av = a.slice(s![..n as isize, ..m as isize]);
+ let mut a = unsafe { ArrayBase::uninitialized((n, m)) };
+ a.assign(&av);
+ a
+}
+
+fn take_slice_upper(a: &ArrayBase, n: usize, m: usize) -> ArrayBase
+ where A: Copy + Zero,
+ S1: Data,
+ S2: DataMut + DataOwned
+{
+ let av = a.slice(s![..n as isize, ..m as isize]);
+ let mut a = unsafe { ArrayBase::uninitialized((n, m)) };
+ for ((i, j), val) in a.indexed_iter_mut() {
+ *val = if i <= j { av[(i, j)] } else { A::zero() };
+ }
+ a
+}
+
+impl<'a, A, S, Sq, Sr> QR, ArrayBase> for &'a mut ArrayBase
+ where A: LapackScalar + Copy + Zero,
+ S: DataMut,
+ Sq: DataOwned + DataMut,
+ Sr: DataOwned + DataMut
+{
+ fn qr(mut self) -> Result<(ArrayBase, ArrayBase)> {
+ let n = self.rows();
+ let m = self.cols();
+ let k = ::std::cmp::min(n, m);
+ let l = self.layout()?;
+ let r = A::qr(l, self.as_allocated_mut()?)?;
+ let r: Array2<_> = reconstruct(l, r)?;
+ let q = self;
+ Ok((take_slice(q, n, k), take_slice_upper(&r, k, m)))
+ }
+}
+
+impl<'a, A, S, Sq, Sr> QR, ArrayBase> for &'a ArrayBase
+ where A: LapackScalar + Copy + Zero,
+ S: Data,
+ Sq: DataOwned + DataMut,
+ Sr: DataOwned + DataMut
+{
+ fn qr(self) -> Result<(ArrayBase, ArrayBase)> {
+ let n = self.rows();
+ let m = self.cols();
+ let k = ::std::cmp::min(n, m);
+ let l = self.layout()?;
+ let mut q = self.to_owned();
+ let r = A::qr(l, q.as_allocated_mut()?)?;
+ let r: Array2<_> = reconstruct(l, r)?;
+ Ok((take_slice(&q, n, k), take_slice_upper(&r, k, m)))
+ }
+}
diff --git a/src/svd.rs b/src/svd.rs
new file mode 100644
index 00000000..1aa1488c
--- /dev/null
+++ b/src/svd.rs
@@ -0,0 +1,63 @@
+
+use ndarray::*;
+
+use super::error::*;
+use super::layout::*;
+use impl2::LapackScalar;
+
+pub trait SVD {
+ fn svd(self, calc_u: bool, calc_vt: bool) -> Result<(Option, S, Option)>;
+}
+
+impl SVD, ArrayBase, ArrayBase> for ArrayBase
+ where A: LapackScalar,
+ S: DataMut,
+ Su: DataOwned,
+ Svt: DataOwned,
+ Ss: DataOwned
+{
+ fn svd(mut self,
+ calc_u: bool,
+ calc_vt: bool)
+ -> Result<(Option>, ArrayBase, Option>)> {
+ (&mut self).svd(calc_u, calc_vt)
+ }
+}
+
+impl<'a, A, S, Su, Svt, Ss> SVD, ArrayBase, ArrayBase> for &'a ArrayBase
+ where A: LapackScalar + Clone,
+ S: Data,
+ Su: DataOwned,
+ Svt: DataOwned,
+ Ss: DataOwned
+{
+ fn svd(self,
+ calc_u: bool,
+ calc_vt: bool)
+ -> Result<(Option>, ArrayBase, Option>)> {
+ let a = self.to_owned();
+ a.svd(calc_u, calc_vt)
+ }
+}
+
+impl<'a, A, S, Su, Svt, Ss> SVD, ArrayBase, ArrayBase>
+ for &'a mut ArrayBase
+ where A: LapackScalar,
+ S: DataMut,
+ Su: DataOwned,
+ Svt: DataOwned,
+ Ss: DataOwned
+{
+ fn svd(mut self,
+ calc_u: bool,
+ calc_vt: bool)
+ -> Result<(Option>, ArrayBase, Option>)> {
+ let l = self.layout()?;
+ let svd_res = A::svd(l, calc_u, calc_vt, self.as_allocated_mut()?)?;
+ let (n, m) = l.size();
+ let u = svd_res.u.map(|u| reconstruct(l.resized(n, n), u).unwrap());
+ let vt = svd_res.vt.map(|vt| reconstruct(l.resized(m, m), vt).unwrap());
+ let s = ArrayBase::from_vec(svd_res.s);
+ Ok((u, s, vt))
+ }
+}
diff --git a/tests/qr.rs b/tests/qr.rs
index c298f30b..3232782f 100644
--- a/tests/qr.rs
+++ b/tests/qr.rs
@@ -10,7 +10,7 @@ fn $funcname() {
let a = $random($n, $m, $t);
let ans = a.clone();
println!("a = \n{:?}", &a);
- let (q, r) = a.qr().unwrap();
+ let (q, r) : (Array2<_>, Array2<_>) = a.qr().unwrap();
println!("q = \n{:?}", &q);
println!("r = \n{:?}", &r);
assert_close_l2!(&q.t().dot(&q), &Array::eye(min($n, $m)), 1e-7);
diff --git a/tests/svd.rs b/tests/svd.rs
index 1a4d8b2c..ccd16048 100644
--- a/tests/svd.rs
+++ b/tests/svd.rs
@@ -9,11 +9,13 @@ fn $funcname() {
use ndarray_linalg::prelude::*;
let a = $random($n, $m, $t);
let answer = a.clone();
- println!("a = \n{}", &a);
- let (u, s, vt) = a.svd().unwrap();
- println!("u = \n{}", &u);
- println!("s = \n{}", &s);
- println!("v = \n{}", &vt);
+ println!("a = \n{:?}", &a);
+ let (u, s, vt): (_, Array1<_>, _) = a.svd(true, true).unwrap();
+ let u: Array2<_> = u.unwrap();
+ let vt: Array2<_> = vt.unwrap();
+ println!("u = \n{:?}", &u);
+ println!("s = \n{:?}", &s);
+ println!("v = \n{:?}", &vt);
let mut sm = Array::zeros(($n, $m));
for i in 0..min($n, $m) {
sm[(i, i)] = s[i];