diff --git a/src/impl2/mod.rs b/src/impl2/mod.rs index 89a38823..2c6fcbcc 100644 --- a/src/impl2/mod.rs +++ b/src/impl2/mod.rs @@ -1,6 +1,21 @@ pub mod opnorm; +pub mod qr; +pub mod svd; + pub use self::opnorm::*; +pub use self::qr::*; +pub use self::svd::*; + +use super::error::*; + +pub trait LapackScalar: OperatorNorm_ + QR_ + SVD_ {} +impl LapackScalar for A where A: OperatorNorm_ + QR_ + SVD_ {} -pub trait LapackScalar: OperatorNorm_ {} -impl LapackScalar for A where A: OperatorNorm_ {} +pub fn into_result(info: i32, val: T) -> Result { + if info == 0 { + Ok(val) + } else { + Err(LapackError::new(info).into()) + } +} diff --git a/src/impl2/opnorm.rs b/src/impl2/opnorm.rs index e2687ebf..b1efd9f7 100644 --- a/src/impl2/opnorm.rs +++ b/src/impl2/opnorm.rs @@ -4,7 +4,7 @@ use lapack::c; use lapack::c::Layout::ColumnMajor as cm; use types::*; -use layout::*; +use layout::Layout; #[repr(u8)] pub enum NormType { diff --git a/src/impl2/qr.rs b/src/impl2/qr.rs new file mode 100644 index 00000000..714135d0 --- /dev/null +++ b/src/impl2/qr.rs @@ -0,0 +1,49 @@ +//! Implement QR decomposition + +use std::cmp::min; +use num_traits::Zero; +use lapack::c; + +use types::*; +use error::*; +use layout::Layout; + +use super::into_result; + +pub trait QR_: Sized { + fn householder(Layout, a: &mut [Self]) -> Result>; + fn q(Layout, a: &mut [Self], tau: &[Self]) -> Result<()>; + fn qr(Layout, a: &mut [Self]) -> Result>; +} + +macro_rules! impl_qr { + ($scalar:ty, $qrf:path, $gqr:path) => { +impl QR_ for $scalar { + fn householder(l: Layout, mut a: &mut [Self]) -> Result> { + let (row, col) = l.size(); + let k = min(row, col); + let mut tau = vec![Self::zero(); k as usize]; + let info = $qrf(l.lapacke_layout(), row, col, &mut a, l.lda(), &mut tau); + into_result(info, tau) + } + + fn q(l: Layout, mut a: &mut [Self], tau: &[Self]) -> Result<()> { + let (row, col) = l.size(); + let k = min(row, col); + let info = $gqr(l.lapacke_layout(), row, k, k, &mut a, l.lda(), &tau); + into_result(info, ()) + } + + fn qr(l: Layout, mut a: &mut [Self]) -> Result> { + let tau = Self::householder(l, a)?; + let r = Vec::from(&*a); + Self::q(l, a, &tau)?; + Ok(r) + } +} +}} // endmacro + +impl_qr!(f64, c::dgeqrf, c::dorgqr); +impl_qr!(f32, c::sgeqrf, c::sorgqr); +impl_qr!(c64, c::zgeqrf, c::zungqr); +impl_qr!(c32, c::cgeqrf, c::cungqr); diff --git a/src/impl2/svd.rs b/src/impl2/svd.rs new file mode 100644 index 00000000..1151f8c8 --- /dev/null +++ b/src/impl2/svd.rs @@ -0,0 +1,64 @@ +//! Implement Operator norms for matrices + +use lapack::c; +use num_traits::Zero; + +use types::*; +use error::*; +use layout::Layout; + +use super::into_result; + +#[repr(u8)] +enum FlagSVD { + All = b'A', + // OverWrite = b'O', + // Separately = b'S', + No = b'N', +} + +pub struct SVDOutput { + pub s: Vec, + pub u: Option>, + pub vt: Option>, +} + +pub trait SVD_: AssociatedReal { + fn svd(Layout, calc_u: bool, calc_vt: bool, a: &mut [Self]) -> Result>; +} + +macro_rules! impl_svd { + ($scalar:ty, $gesvd:path) => { + +impl SVD_ for $scalar { + fn svd(l: Layout, calc_u: bool, calc_vt: bool, mut a: &mut [Self]) -> Result> { + let (m, n) = l.size(); + let k = ::std::cmp::min(n, m); + let lda = l.lda(); + let (ju, ldu, mut u) = if calc_u { + (FlagSVD::All, m, vec![Self::zero(); (m*m) as usize]) + } else { + (FlagSVD::No, 0, Vec::new()) + }; + let (jvt, ldvt, mut vt) = if calc_vt { + (FlagSVD::All, n, vec![Self::zero(); (n*n) as usize]) + } else { + (FlagSVD::No, 0, Vec::new()) + }; + let mut s = vec![Self::Real::zero(); k as usize]; + let mut superb = vec![Self::Real::zero(); (k-2) as usize]; + let info = $gesvd(l.lapacke_layout(), ju as u8, jvt as u8, m, n, &mut a, lda, &mut s, &mut u, ldu, &mut vt, ldvt, &mut superb); + into_result(info, SVDOutput { + s: s, + u: if ldu > 0 { Some(u) } else { None }, + vt: if ldvt > 0 { Some(vt) } else { None }, + }) + } +} + +}} // impl_svd! + +impl_svd!(f64, c::dgesvd); +impl_svd!(f32, c::sgesvd); +impl_svd!(c64, c::zgesvd); +impl_svd!(c32, c::cgesvd); diff --git a/src/layout.rs b/src/layout.rs index 66e56c19..6c098115 100644 --- a/src/layout.rs +++ b/src/layout.rs @@ -1,5 +1,6 @@ use ndarray::*; +use lapack::c; use super::error::*; @@ -7,6 +8,7 @@ pub type LDA = i32; pub type Col = i32; pub type Row = i32; +#[derive(Debug, Clone, Copy)] pub enum Layout { C((Row, LDA)), F((Col, LDA)), @@ -19,6 +21,27 @@ impl Layout { Layout::F((col, lda)) => (lda, col), } } + + pub fn resized(&self, row: Row, col: Col) -> Layout { + match *self { + Layout::C(_) => Layout::C((row, col)), + Layout::F(_) => Layout::F((col, row)), + } + } + + pub fn lda(&self) -> LDA { + match *self { + Layout::C((_, lda)) => lda, + Layout::F((_, lda)) => lda, + } + } + + pub fn lapacke_layout(&self) -> c::Layout { + match *self { + Layout::C(_) => c::Layout::RowMajor, + Layout::F(_) => c::Layout::ColumnMajor, + } + } } pub trait AllocatedArray { @@ -28,6 +51,10 @@ pub trait AllocatedArray { fn as_allocated(&self) -> Result<&[Self::Scalar]>; } +pub trait AllocatedArrayMut: AllocatedArray { + fn as_allocated_mut(&mut self) -> Result<&mut [Self::Scalar]>; +} + impl AllocatedArray for ArrayBase where S: Data { @@ -60,3 +87,21 @@ impl AllocatedArray for ArrayBase Ok(slice) } } + +impl AllocatedArrayMut for ArrayBase + where S: DataMut +{ + fn as_allocated_mut(&mut self) -> Result<&mut [A]> { + let slice = self.as_slice_memory_order_mut().ok_or(MemoryContError::new())?; + Ok(slice) + } +} + +pub fn reconstruct(l: Layout, a: Vec) -> Result> + where S: DataOwned +{ + Ok(match l { + Layout::C((row, col)) => ArrayBase::from_shape_vec((row as usize, col as usize), a)?, + Layout::F((col, row)) => ArrayBase::from_shape_vec((row as usize, col as usize).f(), a)?, + }) +} diff --git a/src/lib.rs b/src/lib.rs index 1d4f537e..bda5683e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -48,7 +48,9 @@ pub mod layout; pub mod impls; pub mod impl2; -pub mod traits; +pub mod qr; +pub mod svd; +pub mod opnorm; pub mod vector; pub mod matrix; diff --git a/src/matrix.rs b/src/matrix.rs index ea6c5f79..b135b6f7 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -6,12 +6,11 @@ use ndarray::DataMut; use lapack::c::Layout; use super::error::{LinalgError, StrideError}; -use super::impls::qr::ImplQR; use super::impls::svd::ImplSVD; use super::impls::solve::ImplSolve; -pub trait MFloat: ImplQR + ImplSVD + ImplSolve + NdFloat {} -impl MFloat for A {} +pub trait MFloat: ImplSVD + ImplSolve + NdFloat {} +impl MFloat for A {} /// Methods for general matrices pub trait Matrix: Sized { @@ -22,10 +21,6 @@ pub trait Matrix: Sized { fn size(&self) -> (usize, usize); /// Layout (C/Fortran) of matrix fn layout(&self) -> Result; - /// singular-value decomposition (SVD) - fn svd(self) -> Result<(Self, Self::Vector, Self), LinalgError>; - /// QR decomposition - fn qr(self) -> Result<(Self, Self), LinalgError>; /// LU decomposition fn lu(self) -> Result<(Self::Permutator, Self, Self), LinalgError>; /// permutate matrix (inplace) @@ -77,49 +72,6 @@ impl Matrix for Array { fn layout(&self) -> Result { check_layout(self.strides()) } - fn svd(self) -> Result<(Self, Self::Vector, Self), LinalgError> { - let (n, m) = self.size(); - let layout = self.layout()?; - let (u, s, vt) = ImplSVD::svd(layout, m, n, self.clone().into_raw_vec())?; - let sv = Array::from_vec(s); - let ua = Array::from_vec(u).into_shape((n, n)).unwrap(); - let va = Array::from_vec(vt).into_shape((m, m)).unwrap(); - match layout { - Layout::RowMajor => Ok((ua, sv, va)), - Layout::ColumnMajor => Ok((ua.reversed_axes(), sv, va.reversed_axes())), - } - } - fn qr(self) -> Result<(Self, Self), LinalgError> { - let (n, m) = self.size(); - let strides = self.strides(); - let k = min(n, m); - let layout = self.layout()?; - let (q, r) = ImplQR::qr(layout, m, n, self.clone().into_raw_vec())?; - let (qa, ra) = if strides[0] < strides[1] { - (Array::from_vec(q).into_shape((m, n)).unwrap().reversed_axes(), - Array::from_vec(r).into_shape((m, n)).unwrap().reversed_axes()) - } else { - (Array::from_vec(q).into_shape((n, m)).unwrap(), Array::from_vec(r).into_shape((n, m)).unwrap()) - }; - let qm = if m > k { - let (qsl, _) = qa.view().split_at(Axis(1), k); - qsl.to_owned() - } else { - qa - }; - let mut rm = if n > k { - let (rsl, _) = ra.view().split_at(Axis(0), k); - rsl.to_owned() - } else { - ra - }; - for ((i, j), val) in rm.indexed_iter_mut() { - if i > j { - *val = A::zero(); - } - } - Ok((qm, rm)) - } fn lu(self) -> Result<(Self::Permutator, Self, Self), LinalgError> { let (n, m) = self.size(); let k = min(n, m); @@ -163,14 +115,6 @@ impl Matrix for RcArray { fn layout(&self) -> Result { check_layout(self.strides()) } - fn svd(self) -> Result<(Self, Self::Vector, Self), LinalgError> { - let (u, s, v) = self.into_owned().svd()?; - Ok((u.into_shared(), s.into_shared(), v.into_shared())) - } - fn qr(self) -> Result<(Self, Self), LinalgError> { - let (q, r) = self.into_owned().qr()?; - Ok((q.into_shared(), r.into_shared())) - } fn lu(self) -> Result<(Self::Permutator, Self, Self), LinalgError> { let (p, l, u) = self.into_owned().lu()?; Ok((p, l.into_shared(), u.into_shared())) diff --git a/src/traits.rs b/src/opnorm.rs similarity index 96% rename from src/traits.rs rename to src/opnorm.rs index 5dada9c6..e2996d22 100644 --- a/src/traits.rs +++ b/src/opnorm.rs @@ -1,13 +1,13 @@ -pub use impl2::LapackScalar; -pub use impl2::NormType; - use ndarray::*; use super::types::*; use super::error::*; use super::layout::*; +pub use impl2::NormType; +use impl2::LapackScalar; + pub trait OperationNorm { type Output; fn opnorm(&self, t: NormType) -> Self::Output; diff --git a/src/prelude.rs b/src/prelude.rs index 9fb296b2..06e15551 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -5,4 +5,7 @@ pub use hermite::HermiteMatrix; pub use triangular::*; pub use util::*; pub use assert::*; -pub use traits::*; + +pub use qr::*; +pub use svd::*; +pub use opnorm::*; diff --git a/src/qr.rs b/src/qr.rs new file mode 100644 index 00000000..032e1f72 --- /dev/null +++ b/src/qr.rs @@ -0,0 +1,83 @@ + +use num_traits::Zero; +use ndarray::*; + +use super::error::*; +use super::layout::*; + +use impl2::LapackScalar; + +pub trait QR { + fn qr(self) -> Result<(Q, R)>; +} + +impl QR, ArrayBase> for ArrayBase + where A: LapackScalar + Copy + Zero, + S: DataMut, + Sq: DataOwned + DataMut, + Sr: DataOwned + DataMut +{ + fn qr(mut self) -> Result<(ArrayBase, ArrayBase)> { + (&mut self).qr() + } +} + +fn take_slice(a: &ArrayBase, n: usize, m: usize) -> ArrayBase + where A: Copy, + S1: Data, + S2: DataMut + DataOwned +{ + let av = a.slice(s![..n as isize, ..m as isize]); + let mut a = unsafe { ArrayBase::uninitialized((n, m)) }; + a.assign(&av); + a +} + +fn take_slice_upper(a: &ArrayBase, n: usize, m: usize) -> ArrayBase + where A: Copy + Zero, + S1: Data, + S2: DataMut + DataOwned +{ + let av = a.slice(s![..n as isize, ..m as isize]); + let mut a = unsafe { ArrayBase::uninitialized((n, m)) }; + for ((i, j), val) in a.indexed_iter_mut() { + *val = if i <= j { av[(i, j)] } else { A::zero() }; + } + a +} + +impl<'a, A, S, Sq, Sr> QR, ArrayBase> for &'a mut ArrayBase + where A: LapackScalar + Copy + Zero, + S: DataMut, + Sq: DataOwned + DataMut, + Sr: DataOwned + DataMut +{ + fn qr(mut self) -> Result<(ArrayBase, ArrayBase)> { + let n = self.rows(); + let m = self.cols(); + let k = ::std::cmp::min(n, m); + let l = self.layout()?; + let r = A::qr(l, self.as_allocated_mut()?)?; + let r: Array2<_> = reconstruct(l, r)?; + let q = self; + Ok((take_slice(q, n, k), take_slice_upper(&r, k, m))) + } +} + +impl<'a, A, S, Sq, Sr> QR, ArrayBase> for &'a ArrayBase + where A: LapackScalar + Copy + Zero, + S: Data, + Sq: DataOwned + DataMut, + Sr: DataOwned + DataMut +{ + fn qr(self) -> Result<(ArrayBase, ArrayBase)> { + let n = self.rows(); + let m = self.cols(); + let k = ::std::cmp::min(n, m); + let l = self.layout()?; + let mut q = self.to_owned(); + let r = A::qr(l, q.as_allocated_mut()?)?; + let r: Array2<_> = reconstruct(l, r)?; + Ok((take_slice(&q, n, k), take_slice_upper(&r, k, m))) + } +} diff --git a/src/svd.rs b/src/svd.rs new file mode 100644 index 00000000..1aa1488c --- /dev/null +++ b/src/svd.rs @@ -0,0 +1,63 @@ + +use ndarray::*; + +use super::error::*; +use super::layout::*; +use impl2::LapackScalar; + +pub trait SVD { + fn svd(self, calc_u: bool, calc_vt: bool) -> Result<(Option, S, Option)>; +} + +impl SVD, ArrayBase, ArrayBase> for ArrayBase + where A: LapackScalar, + S: DataMut, + Su: DataOwned, + Svt: DataOwned, + Ss: DataOwned +{ + fn svd(mut self, + calc_u: bool, + calc_vt: bool) + -> Result<(Option>, ArrayBase, Option>)> { + (&mut self).svd(calc_u, calc_vt) + } +} + +impl<'a, A, S, Su, Svt, Ss> SVD, ArrayBase, ArrayBase> for &'a ArrayBase + where A: LapackScalar + Clone, + S: Data, + Su: DataOwned, + Svt: DataOwned, + Ss: DataOwned +{ + fn svd(self, + calc_u: bool, + calc_vt: bool) + -> Result<(Option>, ArrayBase, Option>)> { + let a = self.to_owned(); + a.svd(calc_u, calc_vt) + } +} + +impl<'a, A, S, Su, Svt, Ss> SVD, ArrayBase, ArrayBase> + for &'a mut ArrayBase + where A: LapackScalar, + S: DataMut, + Su: DataOwned, + Svt: DataOwned, + Ss: DataOwned +{ + fn svd(mut self, + calc_u: bool, + calc_vt: bool) + -> Result<(Option>, ArrayBase, Option>)> { + let l = self.layout()?; + let svd_res = A::svd(l, calc_u, calc_vt, self.as_allocated_mut()?)?; + let (n, m) = l.size(); + let u = svd_res.u.map(|u| reconstruct(l.resized(n, n), u).unwrap()); + let vt = svd_res.vt.map(|vt| reconstruct(l.resized(m, m), vt).unwrap()); + let s = ArrayBase::from_vec(svd_res.s); + Ok((u, s, vt)) + } +} diff --git a/tests/qr.rs b/tests/qr.rs index c298f30b..3232782f 100644 --- a/tests/qr.rs +++ b/tests/qr.rs @@ -10,7 +10,7 @@ fn $funcname() { let a = $random($n, $m, $t); let ans = a.clone(); println!("a = \n{:?}", &a); - let (q, r) = a.qr().unwrap(); + let (q, r) : (Array2<_>, Array2<_>) = a.qr().unwrap(); println!("q = \n{:?}", &q); println!("r = \n{:?}", &r); assert_close_l2!(&q.t().dot(&q), &Array::eye(min($n, $m)), 1e-7); diff --git a/tests/svd.rs b/tests/svd.rs index 1a4d8b2c..ccd16048 100644 --- a/tests/svd.rs +++ b/tests/svd.rs @@ -9,11 +9,13 @@ fn $funcname() { use ndarray_linalg::prelude::*; let a = $random($n, $m, $t); let answer = a.clone(); - println!("a = \n{}", &a); - let (u, s, vt) = a.svd().unwrap(); - println!("u = \n{}", &u); - println!("s = \n{}", &s); - println!("v = \n{}", &vt); + println!("a = \n{:?}", &a); + let (u, s, vt): (_, Array1<_>, _) = a.svd(true, true).unwrap(); + let u: Array2<_> = u.unwrap(); + let vt: Array2<_> = vt.unwrap(); + println!("u = \n{:?}", &u); + println!("s = \n{:?}", &s); + println!("v = \n{:?}", &vt); let mut sm = Array::zeros(($n, $m)); for i in 0..min($n, $m) { sm[(i, i)] = s[i];