diff --git a/src/generate.rs b/src/generate.rs index 54211c9b..d1d22ef0 100644 --- a/src/generate.rs +++ b/src/generate.rs @@ -19,8 +19,18 @@ pub fn conjugate(a: &ArrayBase) -> ArrayBase a } +/// Random vector +pub fn random_vector(n: usize) -> ArrayBase + where A: RandNormal, + S: DataOwned +{ + let mut rng = thread_rng(); + let v: Vec = (0..n).map(|_| A::randn(&mut rng)).collect(); + ArrayBase::from_vec(v) +} + /// Random matrix -pub fn random(n: usize, m: usize) -> ArrayBase +pub fn random_matrix(n: usize, m: usize) -> ArrayBase where A: RandNormal, S: DataOwned { @@ -34,7 +44,7 @@ pub fn random_square(n: usize) -> ArrayBase where A: RandNormal, S: DataOwned { - random(n, n) + random_matrix(n, n) } /// Random Hermite matrix diff --git a/src/impl2/mod.rs b/src/impl2/mod.rs index fd5dba5e..efd80e7c 100644 --- a/src/impl2/mod.rs +++ b/src/impl2/mod.rs @@ -5,6 +5,7 @@ pub mod svd; pub mod solve; pub mod cholesky; pub mod eigh; +pub mod triangular; pub use self::opnorm::*; pub use self::qr::*; @@ -12,6 +13,7 @@ pub use self::svd::*; pub use self::solve::*; pub use self::cholesky::*; pub use self::eigh::*; +pub use self::triangular::*; use super::error::*; @@ -20,7 +22,8 @@ trait_alias!(LapackScalar: OperatorNorm_, SVD_, Solve_, Cholesky_, - Eigh_); + Eigh_, + Triangular_); pub fn into_result(info: i32, val: T) -> Result { if info == 0 { @@ -36,3 +39,11 @@ pub enum UPLO { Upper = b'U', Lower = b'L', } + +#[derive(Debug, Clone, Copy)] +#[repr(u8)] +pub enum Transpose { + No = b'N', + Transpose = b'T', + Hermite = b'C', +} diff --git a/src/impl2/solve.rs b/src/impl2/solve.rs index d024786d..8a2df165 100644 --- a/src/impl2/solve.rs +++ b/src/impl2/solve.rs @@ -5,18 +5,10 @@ use types::*; use error::*; use layout::Layout; -use super::into_result; +use super::{Transpose, into_result}; pub type Pivot = Vec; -#[derive(Debug, Clone, Copy)] -#[repr(u8)] -pub enum Transpose { - No = b'N', - Transpose = b'T', - Hermite = b'C', -} - pub trait Solve_: Sized { fn lu(Layout, a: &mut [Self]) -> Result; fn inv(Layout, a: &mut [Self], &Pivot) -> Result<()>; diff --git a/src/impl2/triangular.rs b/src/impl2/triangular.rs new file mode 100644 index 00000000..ed37ba8e --- /dev/null +++ b/src/impl2/triangular.rs @@ -0,0 +1,54 @@ +//! Implement linear solver and inverse matrix + +use lapack::c; + +use error::*; +use types::*; +use layout::Layout; +use super::{UPLO, Transpose, into_result}; + +#[derive(Debug, Clone, Copy)] +#[repr(u8)] +pub enum Diag { + Unit = b'U', + NonUnit = b'N', +} + +pub trait Triangular_: Sized { + fn inv_triangular(l: Layout, UPLO, Diag, a: &mut [Self]) -> Result<()>; + fn solve_triangular(al: Layout, bl: Layout, UPLO, Diag, a: &[Self], b: &mut [Self]) -> Result<()>; +} + +macro_rules! impl_triangular { + ($scalar:ty, $trtri:path, $trtrs:path) => { + +impl Triangular_ for $scalar { + fn inv_triangular(l: Layout, uplo: UPLO, diag: Diag, a: &mut [Self]) -> Result<()> { + let (n, _) = l.size(); + let lda = l.lda(); + let info = $trtri(l.lapacke_layout(), uplo as u8, diag as u8, n, a, lda); + into_result(info, ()) + } + + fn solve_triangular(al: Layout, bl: Layout, uplo: UPLO, diag: Diag, a: &[Self], mut b: &mut [Self]) -> Result<()> { + let (n, _) = al.size(); + let lda = al.lda(); + let (_, nrhs) = bl.size(); + let ldb = bl.lda(); + println!("al = {:?}", al); + println!("bl = {:?}", bl); + println!("n = {}", n); + println!("lda = {}", lda); + println!("nrhs = {}", nrhs); + println!("ldb = {}", ldb); + let info = $trtrs(al.lapacke_layout(), uplo as u8, Transpose::No as u8, diag as u8, n, nrhs, a, lda, &mut b, ldb); + into_result(info, ()) + } +} + +}} // impl_triangular! + +impl_triangular!(f64, c::dtrtri, c::dtrtrs); +impl_triangular!(f32, c::strtri, c::strtrs); +impl_triangular!(c64, c::ztrtri, c::ztrtrs); +impl_triangular!(c32, c::ctrtri, c::ctrtrs); diff --git a/src/impls/mod.rs b/src/impls/mod.rs deleted file mode 100644 index ce926498..00000000 --- a/src/impls/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -//! Implement trait bindings of LAPACK -pub mod solve; diff --git a/src/impls/solve.rs b/src/impls/solve.rs deleted file mode 100644 index 364bc468..00000000 --- a/src/impls/solve.rs +++ /dev/null @@ -1,87 +0,0 @@ -//! Implement linear solver and inverse matrix - -use lapack::c::*; -use std::cmp::min; - -use error::LapackError; - -pub trait ImplSolve: Sized { - /// execute LU decomposition - fn lu(layout: Layout, m: usize, n: usize, a: Vec) -> Result<(Vec, Vec), LapackError>; - /// calc inverse matrix with LU factorized matrix - fn inv(layout: Layout, size: usize, a: Vec, ipiv: &Vec) -> Result, LapackError>; - /// solve linear problem with LU factorized matrix - fn solve(layout: Layout, - size: usize, - a: &Vec, - ipiv: &Vec, - b: Vec) - -> Result, LapackError>; - /// solve triangular linear problem - fn solve_triangle<'a, 'b>(layout: Layout, - uplo: u8, - size: usize, - a: &'a [Self], - b: &'b mut [Self], - nrhs: i32) - -> Result<&'b mut [Self], LapackError>; -} - -macro_rules! impl_solve { - ($scalar:ty, $getrf:path, $getri:path, $getrs:path, $trtrs:path) => { -impl ImplSolve for $scalar { - fn lu(layout: Layout, m: usize, n: usize, mut a: Vec) -> Result<(Vec, Vec), LapackError> { - let m = m as i32; - let n = n as i32; - let k = min(m, n); - let lda = match layout { - Layout::ColumnMajor => m, - Layout::RowMajor => n, - }; - let mut ipiv = vec![0; k as usize]; - let info = $getrf(layout, m, n, &mut a, lda, &mut ipiv); - if info == 0 { - Ok((ipiv, a)) - } else { - Err(From::from(info)) - } - } - fn inv(layout: Layout, size: usize, mut a: Vec, ipiv: &Vec) -> Result, LapackError> { - let n = size as i32; - let lda = n; - let info = $getri(layout, n, &mut a, lda, &ipiv); - if info == 0 { - Ok(a) - } else { - Err(From::from(info)) - } - } - fn solve(layout: Layout, size: usize, a: &Vec, ipiv: &Vec, mut b: Vec) -> Result, LapackError> { - let n = size as i32; - let lda = n; - let info = $getrs(layout, 'N' as u8, n, 1, a, lda, ipiv, &mut b, n); - if info == 0 { - Ok(b) - } else { - Err(From::from(info)) - } - } - fn solve_triangle<'a, 'b>(layout: Layout, uplo: u8, size: usize, a: &'a [Self], mut b: &'b mut [Self], nrhs: i32) -> Result<&'b mut [Self], LapackError> { - let n = size as i32; - let lda = n; - let ldb = match layout { - Layout::ColumnMajor => n, - Layout::RowMajor => 1, - }; - let info = $trtrs(layout, uplo, 'N' as u8, 'N' as u8, n, nrhs, a, lda, &mut b, ldb); - if info == 0 { - Ok(b) - } else { - Err(From::from(info)) - } - } -} -}} // end macro_rules - -impl_solve!(f64, dgetrf, dgetri, dgetrs, dtrtrs); -impl_solve!(f32, sgetrf, sgetri, sgetrs, strtrs); diff --git a/src/layout.rs b/src/layout.rs index 8ea71138..93783853 100644 --- a/src/layout.rs +++ b/src/layout.rs @@ -5,10 +5,11 @@ use lapack::c; use super::error::*; pub type LDA = i32; +pub type LEN = i32; pub type Col = i32; pub type Row = i32; -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq)] pub enum Layout { C((Row, LDA)), F((Col, LDA)), @@ -36,40 +37,65 @@ impl Layout { } } + pub fn len(&self) -> LEN { + match *self { + Layout::C((row, _)) => row, + Layout::F((col, _)) => col, + } + } + pub fn lapacke_layout(&self) -> c::Layout { match *self { Layout::C(_) => c::Layout::RowMajor, Layout::F(_) => c::Layout::ColumnMajor, } } + + pub fn same_order(&self, other: &Layout) -> bool { + self.lapacke_layout() == other.lapacke_layout() + } + + pub fn as_shape(&self) -> Shape { + match *self { + Layout::C((row, col)) => (row as usize, col as usize).into_shape(), + Layout::F((col, row)) => (row as usize, col as usize).f().into_shape(), + } + } + + pub fn toggle_order(&self) -> Self { + match *self { + Layout::C((row, col)) => Layout::F((col, row)), + Layout::F((col, row)) => Layout::C((row, col)), + } + } } pub trait AllocatedArray { - type Scalar; + type Elem; fn layout(&self) -> Result; fn square_layout(&self) -> Result; - fn as_allocated(&self) -> Result<&[Self::Scalar]>; + fn as_allocated(&self) -> Result<&[Self::Elem]>; } pub trait AllocatedArrayMut: AllocatedArray { - fn as_allocated_mut(&mut self) -> Result<&mut [Self::Scalar]>; + fn as_allocated_mut(&mut self) -> Result<&mut [Self::Elem]>; } impl AllocatedArray for ArrayBase where S: Data { - type Scalar = A; + type Elem = A; fn layout(&self) -> Result { + let shape = self.shape(); let strides = self.strides(); - if ::std::cmp::min(strides[0], strides[1]) != 1 { - return Err(StrideError::new(strides[0], strides[1]).into()); + if shape[0] == strides[1] as usize { + return Ok(Layout::F((self.cols() as i32, self.rows() as i32))); } - if strides[0] > strides[1] { - Ok(Layout::C((self.rows() as i32, self.cols() as i32))) - } else { - Ok(Layout::F((self.cols() as i32, self.rows() as i32))) + if shape[1] == strides[0] as usize { + return Ok(Layout::C((self.rows() as i32, self.cols() as i32))); } + Err(StrideError::new(strides[0], strides[1]).into()) } fn square_layout(&self) -> Result { @@ -83,8 +109,7 @@ impl AllocatedArray for ArrayBase } fn as_allocated(&self) -> Result<&[A]> { - let slice = self.as_slice_memory_order().ok_or(MemoryContError::new())?; - Ok(slice) + Ok(self.as_slice_memory_order().ok_or(MemoryContError::new())?) } } @@ -92,18 +117,42 @@ impl AllocatedArrayMut for ArrayBase where S: DataMut { fn as_allocated_mut(&mut self) -> Result<&mut [A]> { - let slice = self.as_slice_memory_order_mut().ok_or(MemoryContError::new())?; - Ok(slice) + Ok(self.as_slice_memory_order_mut().ok_or(MemoryContError::new())?) } } +pub fn into_col_vec(a: ArrayBase) -> ArrayBase + where S: Data +{ + let n = a.len(); + a.into_shape((n, 1)).unwrap() +} + +pub fn into_row_vec(a: ArrayBase) -> ArrayBase + where S: Data +{ + let n = a.len(); + a.into_shape((1, n)).unwrap() +} + +pub fn into_vec(a: ArrayBase) -> ArrayBase + where S: Data +{ + let n = a.len(); + a.into_shape((n)).unwrap() +} + pub fn reconstruct(l: Layout, a: Vec) -> Result> where S: DataOwned { - Ok(match l { - Layout::C((row, col)) => ArrayBase::from_shape_vec((row as usize, col as usize), a)?, - Layout::F((col, row)) => ArrayBase::from_shape_vec((row as usize, col as usize).f(), a)?, - }) + Ok(ArrayBase::from_shape_vec(l.as_shape(), a)?) +} + +pub fn uninitialized(l: Layout) -> ArrayBase + where A: Copy, + S: DataOwned +{ + unsafe { ArrayBase::uninitialized(l.as_shape()) } } pub fn replicate(a: &ArrayBase) -> ArrayBase @@ -116,3 +165,23 @@ pub fn replicate(a: &ArrayBase) -> ArrayBase b.assign(a); b } + +pub fn clone_with_layout(l: Layout, a: &ArrayBase) -> ArrayBase + where A: Copy, + Si: Data, + So: DataOwned + DataMut +{ + let mut b = uninitialized(l); + b.assign(a); + b +} + +pub fn data_transpose(a: &mut ArrayBase) -> Result<&mut ArrayBase> + where A: Copy, + S: DataOwned + DataMut +{ + let l = a.layout()?.toggle_order(); + let new = clone_with_layout(l, a); + ::std::mem::replace(a, new); + Ok(a) +} diff --git a/src/lib.rs b/src/lib.rs index f654d8a9..981475dd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -47,7 +47,6 @@ extern crate derive_new; pub mod types; pub mod error; pub mod layout; -pub mod impls; pub mod impl2; pub mod qr; @@ -56,13 +55,11 @@ pub mod opnorm; pub mod solve; pub mod cholesky; pub mod eigh; - -pub mod matrix; -pub mod square; pub mod triangular; pub mod generate; pub mod assert; pub mod norm; +pub mod trace; pub mod prelude; diff --git a/src/matrix.rs b/src/matrix.rs deleted file mode 100644 index 82a8c665..00000000 --- a/src/matrix.rs +++ /dev/null @@ -1,124 +0,0 @@ -//! Define trait for general matrix - -use std::cmp::min; -use ndarray::*; -use ndarray::DataMut; -use lapack::c::Layout; - -use super::error::{LinalgError, StrideError}; -use super::impls::solve::ImplSolve; - -pub trait MFloat: ImplSolve + NdFloat {} -impl MFloat for A {} - -/// Methods for general matrices -pub trait Matrix: Sized { - type Scalar; - type Vector; - type Permutator; - /// number of (rows, columns) - fn size(&self) -> (usize, usize); - /// Layout (C/Fortran) of matrix - fn layout(&self) -> Result; - /// LU decomposition - fn lu(self) -> Result<(Self::Permutator, Self, Self), LinalgError>; - /// permutate matrix (inplace) - fn permutate(&mut self, p: &Self::Permutator); - /// permutate matrix (outplace) - fn permutated(mut self, p: &Self::Permutator) -> Self { - self.permutate(p); - self - } -} - -fn check_layout(strides: &[Ixs]) -> Result { - if min(strides[0], strides[1]) != 1 { - return Err(StrideError { - s0: strides[0], - s1: strides[1], - });; - } - if strides[0] < strides[1] { - Ok(Layout::ColumnMajor) - } else { - Ok(Layout::RowMajor) - } -} - -fn permutate(mut a: &mut ArrayBase, ipiv: &Vec) - where S: DataMut -{ - let m = a.cols(); - for (i, j_) in ipiv.iter().enumerate().rev() { - let j = (j_ - 1) as usize; - if i == j { - continue; - } - for k in 0..m { - a.swap((i, k), (j, k)); - } - } -} - -impl Matrix for Array { - type Scalar = A; - type Vector = Array; - type Permutator = Vec; - - fn size(&self) -> (usize, usize) { - (self.rows(), self.cols()) - } - fn layout(&self) -> Result { - check_layout(self.strides()) - } - fn lu(self) -> Result<(Self::Permutator, Self, Self), LinalgError> { - let (n, m) = self.size(); - let k = min(n, m); - let (p, l) = ImplSolve::lu(self.layout()?, n, m, self.clone().into_raw_vec())?; - let mut a = match self.layout()? { - Layout::ColumnMajor => Array::from_vec(l).into_shape((m, n)).unwrap().reversed_axes(), - Layout::RowMajor => Array::from_vec(l).into_shape((n, m)).unwrap(), - }; - let mut lm = Array::zeros((n, k)); - for ((i, j), val) in lm.indexed_iter_mut() { - if i > j { - *val = a[(i, j)]; - } else if i == j { - *val = A::one(); - } - } - for ((i, j), val) in a.indexed_iter_mut() { - if i > j { - *val = A::zero(); - } - } - let am = if n > k { - a.slice(s![0..k as isize, ..]).to_owned() - } else { - a - }; - Ok((p, lm, am)) - } - fn permutate(&mut self, ipiv: &Self::Permutator) { - permutate(self, ipiv); - } -} - -impl Matrix for RcArray { - type Scalar = A; - type Vector = RcArray; - type Permutator = Vec; - fn size(&self) -> (usize, usize) { - (self.rows(), self.cols()) - } - fn layout(&self) -> Result { - check_layout(self.strides()) - } - fn lu(self) -> Result<(Self::Permutator, Self, Self), LinalgError> { - let (p, l, u) = self.into_owned().lu()?; - Ok((p, l.into_shared(), u.into_shared())) - } - fn permutate(&mut self, ipiv: &Self::Permutator) { - permutate(self, ipiv); - } -} diff --git a/src/prelude.rs b/src/prelude.rs index f34cb921..9f1910d4 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -1,14 +1,15 @@ -pub use matrix::Matrix; -pub use square::SquareMatrix; -pub use triangular::*; -pub use norm::*; -pub use types::*; -pub use generate::*; + pub use assert::*; +pub use generate::*; +pub use types::*; +pub use layout::*; -pub use qr::*; -pub use svd::*; +pub use cholesky::*; +pub use eigh::*; +pub use norm::*; pub use opnorm::*; +pub use qr::*; pub use solve::*; -pub use eigh::*; -pub use cholesky::*; +pub use svd::*; +pub use trace::*; +pub use triangular::*; diff --git a/src/square.rs b/src/square.rs deleted file mode 100644 index 78a12bda..00000000 --- a/src/square.rs +++ /dev/null @@ -1,55 +0,0 @@ -//! Define trait for Hermite matrices - -use ndarray::{Ix2, Array, RcArray, ArrayBase, Data}; - -use super::matrix::{Matrix, MFloat}; -use super::error::{LinalgError, NotSquareError}; - -/// Methods for square matrices -/// -/// This trait defines method for square matrices, -/// but does not assure that the matrix is square. -/// If not square, `NotSquareError` will be thrown. -pub trait SquareMatrix: Matrix { - /// trace of matrix - fn trace(&self) -> Result; - #[doc(hidden)] - fn check_square(&self) -> Result<(), NotSquareError> { - let (rows, cols) = self.size(); - if rows == cols { - Ok(()) - } else { - Err(NotSquareError { - rows: rows as i32, - cols: cols as i32, - }) - } - } - /// test matrix is square and return its size - fn square_size(&self) -> Result { - self.check_square()?; - let (n, _) = self.size(); - Ok(n) - } -} - -fn trace(a: &ArrayBase) -> A - where S: Data -{ - let n = a.rows(); - (0..n).fold(A::zero(), |sum, i| sum + a[(i, i)]) -} - -impl SquareMatrix for Array { - fn trace(&self) -> Result { - self.check_square()?; - Ok(trace(self)) - } -} - -impl SquareMatrix for RcArray { - fn trace(&self) -> Result { - self.check_square()?; - Ok(trace(self)) - } -} diff --git a/src/trace.rs b/src/trace.rs new file mode 100644 index 00000000..86995161 --- /dev/null +++ b/src/trace.rs @@ -0,0 +1,23 @@ + +use ndarray::*; + +use super::types::*; +use super::error::*; +use super::layout::*; + +pub trait Trace { + type Output; + fn trace(&self) -> Result; +} + +impl Trace for ArrayBase + where A: Field, + S: Data +{ + type Output = A; + + fn trace(&self) -> Result { + let (n, _) = self.square_layout()?.size(); + Ok((0..n as usize).map(|i| self[(i, i)]).sum()) + } +} diff --git a/src/triangular.rs b/src/triangular.rs index cae6b64a..7c8002d4 100644 --- a/src/triangular.rs +++ b/src/triangular.rs @@ -2,114 +2,90 @@ use ndarray::*; use num_traits::Zero; -use super::impl2::UPLO; -use super::matrix::{Matrix, MFloat}; -use super::square::SquareMatrix; -use super::error::LinalgError; -use super::generate::hstack; -use super::impls::solve::ImplSolve; +use super::layout::*; +use super::error::*; +use super::impl2::*; -pub trait SolveTriangular: Matrix + SquareMatrix { +pub use super::impl2::Diag; + +/// solve a triangular system with upper triangular matrix +pub trait SolveTriangular { type Output; - /// solve a triangular system with upper triangular matrix - fn solve_upper(&self, Rhs) -> Result; - /// solve a triangular system with lower triangular matrix - fn solve_lower(&self, Rhs) -> Result; + fn solve_triangular(&self, UPLO, Diag, Rhs) -> Result; } -impl SolveTriangular> for ArrayBase - where A: MFloat, - S1: Data, - S2: DataMut, - ArrayBase: Matrix + SquareMatrix +impl SolveTriangular> for ArrayBase + where A: LapackScalar + Copy, + Si: Data, + So: DataMut + DataOwned { - type Output = ArrayBase; - fn solve_upper(&self, mut b: ArrayBase) -> Result { - let n = self.square_size()?; - let layout = self.layout()?; - let a = self.as_slice_memory_order().unwrap(); - ImplSolve::solve_triangle(layout, - 'U' as u8, - n, - a, - b.as_slice_memory_order_mut().unwrap(), - 1)?; - Ok(b) - } - fn solve_lower(&self, mut b: ArrayBase) -> Result { - let n = self.square_size()?; - let layout = self.layout()?; - let a = self.as_slice_memory_order().unwrap(); - ImplSolve::solve_triangle(layout, - 'L' as u8, - n, - a, - b.as_slice_memory_order_mut().unwrap(), - 1)?; + type Output = ArrayBase; + + fn solve_triangular(&self, uplo: UPLO, diag: Diag, mut b: ArrayBase) -> Result { + self.solve_triangular(uplo, diag, &mut b)?; Ok(b) } } -impl<'a, S1, S2, A> SolveTriangular<&'a ArrayBase> for ArrayBase - where A: MFloat, - S1: Data, - S2: Data, - ArrayBase: Matrix + SquareMatrix +impl<'a, A, Si, So> SolveTriangular<&'a mut ArrayBase> for ArrayBase + where A: LapackScalar + Copy, + Si: Data, + So: DataMut + DataOwned { - type Output = Array; - fn solve_upper(&self, bs: &ArrayBase) -> Result { - let mut xs = Vec::new(); - for b in bs.axis_iter(Axis(1)) { - let x = self.solve_upper(b.to_owned())?; - xs.push(x); - } - hstack(&xs).map_err(|e| e.into()) - } - fn solve_lower(&self, bs: &ArrayBase) -> Result { - let mut xs = Vec::new(); - for b in bs.axis_iter(Axis(1)) { - let x = self.solve_lower(b.to_owned())?; - xs.push(x); + type Output = &'a mut ArrayBase; + + fn solve_triangular(&self, uplo: UPLO, diag: Diag, mut b: &'a mut ArrayBase) -> Result { + let la = self.layout()?; + let a_ = self.as_allocated()?; + let lb = b.layout()?; + if !la.same_order(&lb) { + data_transpose(b)?; } - hstack(&xs).map_err(|e| e.into()) + let lb = b.layout()?; + A::solve_triangular(la, lb, uplo, diag, a_, b.as_allocated_mut()?)?; + Ok(b) } } -impl SolveTriangular> for RcArray { - type Output = RcArray; - fn solve_upper(&self, b: RcArray) -> Result { - // XXX unnecessary clone - let x = self.to_owned().solve_upper(&b)?; - Ok(x.into_shared()) - } - fn solve_lower(&self, b: RcArray) -> Result { - // XXX unnecessary clone - let x = self.to_owned().solve_lower(&b)?; - Ok(x.into_shared()) +impl<'a, A, Si, So> SolveTriangular<&'a ArrayBase> for ArrayBase + where A: LapackScalar + Copy, + Si: Data, + So: DataMut + DataOwned +{ + type Output = ArrayBase; + + fn solve_triangular(&self, uplo: UPLO, diag: Diag, b: &'a ArrayBase) -> Result { + let b = replicate(b); + self.solve_triangular(uplo, diag, b) } } -pub fn drop_upper(mut a: ArrayBase) -> ArrayBase - where S: DataMut +impl SolveTriangular> for ArrayBase + where A: LapackScalar + Copy, + Si: Data, + So: DataMut + DataOwned { - for ((i, j), val) in a.indexed_iter_mut() { - if i < j { - *val = A::zero(); - } + type Output = ArrayBase; + + fn solve_triangular(&self, uplo: UPLO, diag: Diag, b: ArrayBase) -> Result { + let b = into_col_vec(b); + let b = self.solve_triangular(uplo, diag, b)?; + Ok(into_vec(b)) } - a } -pub fn drop_lower(mut a: ArrayBase) -> ArrayBase - where S: DataMut +impl<'a, A, Si, So> SolveTriangular<&'a ArrayBase> for ArrayBase + where A: LapackScalar + Copy, + Si: Data, + So: DataMut + DataOwned { - for ((i, j), val) in a.indexed_iter_mut() { - if i > j { - *val = A::zero(); - } + type Output = ArrayBase; + + fn solve_triangular(&self, uplo: UPLO, diag: Diag, b: &'a ArrayBase) -> Result { + let b = replicate(b); + self.solve_triangular(uplo, diag, b) } - a } pub trait IntoTriangular { @@ -150,3 +126,15 @@ impl IntoTriangular> for ArrayBase self } } + +pub fn drop_upper(a: ArrayBase) -> ArrayBase + where S: DataMut +{ + a.into_triangular(UPLO::Lower) +} + +pub fn drop_lower(a: ArrayBase) -> ArrayBase + where S: DataMut +{ + a.into_triangular(UPLO::Upper) +} diff --git a/tests/layout.rs b/tests/layout.rs new file mode 100644 index 00000000..f0ff8d91 --- /dev/null +++ b/tests/layout.rs @@ -0,0 +1,35 @@ + +extern crate ndarray; +extern crate ndarray_linalg; + +use ndarray::*; +use ndarray_linalg::prelude::*; +use ndarray_linalg::layout::Layout; + +#[test] +fn layout_c_3x1() { + let a: Array2 = Array::zeros((3, 1)); + println!("a = {:?}", &a); + assert_eq!(a.layout().unwrap(), Layout::C((3, 1))); +} + +#[test] +fn layout_f_3x1() { + let a: Array2 = Array::zeros((3, 1).f()); + println!("a = {:?}", &a); + assert_eq!(a.layout().unwrap(), Layout::F((1, 3))); +} + +#[test] +fn layout_c_3x2() { + let a: Array2 = Array::zeros((3, 2)); + println!("a = {:?}", &a); + assert_eq!(a.layout().unwrap(), Layout::C((3, 2))); +} + +#[test] +fn layout_f_3x2() { + let a: Array2 = Array::zeros((3, 2).f()); + println!("a = {:?}", &a); + assert_eq!(a.layout().unwrap(), Layout::F((2, 3))); +} diff --git a/tests/lu.rs b/tests/lu.rs deleted file mode 100644 index e62c47cd..00000000 --- a/tests/lu.rs +++ /dev/null @@ -1,32 +0,0 @@ -include!("header.rs"); - -macro_rules! impl_test { - ($funcname:ident, $random:path, $n:expr, $m:expr, $t:expr) => { -#[test] -fn $funcname() { - use ndarray_linalg::prelude::*; - let a = $random($n, $m, $t); - let ans = a.clone(); - let (p, l, u) = a.lu().unwrap(); - println!("P = \n{:?}", &p); - println!("L = \n{:?}", &l); - println!("U = \n{:?}", &u); - println!("LU = \n{:?}", l.dot(&u)); - assert_close_l2!(&l.dot(&u).permutated(&p), &ans, 1e-7); -} -}} // impl_test - -macro_rules! impl_test_lu { - ($modname:ident, $random:path) => { -mod $modname { - impl_test!(lu_square, $random, 3, 3, false); - impl_test!(lu_square_t, $random, 3, 3, true); - impl_test!(lu_3x4, $random, 3, 4, false); - impl_test!(lu_3x4_t, $random, 3, 4, true); - impl_test!(lu_4x3, $random, 4, 3, false); - impl_test!(lu_4x3_t, $random, 4, 3, true); -} -}} // impl_test_lu - -impl_test_lu!(owned, super::random_owned); -impl_test_lu!(shared, super::random_shared); diff --git a/tests/permutate.rs b/tests/permutate.rs deleted file mode 100644 index d6d64e00..00000000 --- a/tests/permutate.rs +++ /dev/null @@ -1,49 +0,0 @@ -include!("header.rs"); - -macro_rules! impl_test { - ($testname:ident, $permutate:expr, $input:expr, $answer:expr) => { -#[test] -fn $testname() { - use ndarray_linalg::prelude::*; - let a = $input; - println!("a= \n{:?}", &a); - let p = $permutate; // replace 1-2 - let pa = a.permutated(&p); - println!("permutated = \n{:?}", &pa); - assert_close_l2!(&pa, &$answer, 1e-7); -} -}} // impl_test - -macro_rules! impl_test_permuate { - ($modname:ident, $array:path) => { -mod $modname { - use ndarray; - impl_test!(permutate, - vec![2, 2, 3], - $array(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]), - $array(&[[4., 5., 6.], [1., 2., 3.], [7., 8., 9.]])); - impl_test!(permutate_t, - vec![2, 2, 3], - $array(&[[1., 4., 7.], [2., 5., 8.], [3., 6., 9.]]).reversed_axes(), - $array(&[[4., 5., 6.], [1., 2., 3.], [7., 8., 9.]])); - impl_test!(permutate_3x4, - vec![1, 3, 3], - $array(&[[1., 4., 7., 10.], [2., 5., 8., 11.], [3., 6., 9., 12.]]), - $array(&[[1., 4., 7., 10.], [3., 6., 9., 12.], [2., 5., 8., 11.]])); - impl_test!(permutate_3x4_t, - vec![1, 3, 3], - $array(&[[1., 5., 9.], [2., 6., 10.], [3., 7., 11.], [4., 8., 12.]]).reversed_axes(), - $array(&[[1., 2., 3., 4.], [9., 10., 11., 12.], [5., 6., 7., 8.]])); - impl_test!(permutate_4x3, - vec![4, 2, 3, 4], - $array(&[[1., 5., 9.], [2., 6., 10.], [3., 7., 11.], [4., 8., 12.]]), - $array(&[[4., 8., 12.], [2., 6., 10.], [3., 7., 11.], [1., 5., 9.]])); - impl_test!(permutate_4x3_t, - vec![4, 2, 3, 4], - $array(&[[1., 4., 7., 10.], [2., 5., 8., 11.], [3., 6., 9., 12.]]).reversed_axes(), - $array(&[[10., 11., 12.], [4., 5., 6.], [7., 8., 9.], [1., 2., 3.]])); -} -}} // impl_test_permuate - -impl_test_permuate!(owned, ndarray::arr2); -impl_test_permuate!(shared, ndarray::rcarr2); diff --git a/tests/triangular.rs b/tests/triangular.rs index 11f16a70..9c9893c7 100644 --- a/tests/triangular.rs +++ b/tests/triangular.rs @@ -1,120 +1,126 @@ -include!("header.rs"); - -macro_rules! impl_test { - ($modname:ident, $random:path) => { -mod $modname { - use ndarray::prelude::*; - use ndarray_linalg::prelude::*; - use ndarray_rand::RandomExt; - use rand_extra::*; - #[test] - fn solve_upper() { - let r_dist = RealNormal::new(0.0, 1.0); - let a = drop_lower($random((3, 3), r_dist)); - println!("a = \n{:?}", &a); - let b = $random(3, r_dist); - println!("b = \n{:?}", &b); - let x = a.solve_upper(b.clone()).unwrap(); - println!("x = \n{:?}", &x); - println!("Ax = \n{:?}", a.dot(&x)); - assert_close_l2!(&a.dot(&x), &b, 1e-7); - } - - #[test] - fn solve_upper_t() { - let r_dist = RealNormal::new(0., 1.); - let a = drop_lower($random((3, 3), r_dist).reversed_axes()); - println!("a = \n{:?}", &a); - let b = $random(3, r_dist); - println!("b = \n{:?}", &b); - let x = a.solve_upper(b.clone()).unwrap(); - println!("x = \n{:?}", &x); - println!("Ax = \n{:?}", a.dot(&x)); - assert_close_l2!(&a.dot(&x), &b, 1e-7); - } - - #[test] - fn solve_lower() { - let r_dist = RealNormal::new(0., 1.); - let a = drop_upper($random((3, 3), r_dist)); - println!("a = \n{:?}", &a); - let b = $random(3, r_dist); - println!("b = \n{:?}", &b); - let x = a.solve_lower(b.clone()).unwrap(); - println!("x = \n{:?}", &x); - println!("Ax = \n{:?}", a.dot(&x)); - assert_close_l2!(&a.dot(&x), &b, 1e-7); - } - - #[test] - fn solve_lower_t() { - let r_dist = RealNormal::new(0., 1.); - let a = drop_upper($random((3, 3), r_dist).reversed_axes()); - println!("a = \n{:?}", &a); - let b = $random(3, r_dist); - println!("b = \n{:?}", &b); - let x = a.solve_lower(b.clone()).unwrap(); - println!("x = \n{:?}", &x); - println!("Ax = \n{:?}", a.dot(&x)); - assert_close_l2!(&a.dot(&x), &b, 1e-7); - } + +extern crate ndarray; +#[macro_use] +extern crate ndarray_linalg; + +use ndarray::*; +use ndarray_linalg::prelude::*; + +fn test1d(uplo: UPLO, a: ArrayBase, b: ArrayBase, tol: Tol) + where A: Field + Absolute, + Sa: Data, + Sb: DataMut + DataOwned, + Tol: RealField +{ + println!("a = {:?}", &a); + println!("b = {:?}", &b); + let x = a.solve_triangular(uplo, Diag::NonUnit, &b).unwrap(); + println!("x = {:?}", &x); + let b_ = a.dot(&x); + println!("Ax = {:?}", &b_); + assert_close_l2!(&b_, &b, tol); +} + +fn test2d(uplo: UPLO, a: ArrayBase, b: ArrayBase, tol: Tol) + where A: Field + Absolute, + Sa: Data, + Sb: DataMut + DataOwned + DataClone, + Tol: RealField +{ + println!("a = {:?}", &a); + println!("b = {:?}", &b); + let ans = b.clone(); + let x = a.solve_triangular(uplo, Diag::NonUnit, b).unwrap(); + println!("x = {:?}", &x); + let b_ = a.dot(&x); + println!("Ax = {:?}", &b_); + assert_close_l2!(&b_, &ans, tol); +} + +#[test] +fn triangular_1d_upper() { + let n = 3; + let b: Array1 = random_vector(n); + let a: Array2 = random_square(n).into_triangular(UPLO::Upper); + test1d(UPLO::Upper, a, b, 1e-7); } -}} // impl_test_opnorm - -impl_test!(owned, Array::random); -impl_test!(shared, RcArray::random); - -macro_rules! impl_test_2d { - ($modname:ident, $drop:path, $solve:ident) => { -mod $modname { - use super::random_owned; - use ndarray_linalg::prelude::*; - #[test] - fn solve_tt() { - let a = $drop(random_owned(3, 3, true)); - println!("a = \n{:?}", &a); - let b = random_owned(3, 2, true); - println!("b = \n{:?}", &b); - let x = a.$solve(&b).unwrap(); - println!("x = \n{:?}", &x); - println!("Ax = \n{:?}", a.dot(&x)); - assert_close_l2!(&a.dot(&x), &b, 1e-7); - } - #[test] - fn solve_tf() { - let a = $drop(random_owned(3, 3, true)); - println!("a = \n{:?}", &a); - let b = random_owned(3, 2, false); - println!("b = \n{:?}", &b); - let x = a.$solve(&b).unwrap(); - println!("x = \n{:?}", &x); - println!("Ax = \n{:?}", a.dot(&x)); - assert_close_l2!(&a.dot(&x), &b, 1e-7); - } - #[test] - fn solve_ft() { - let a = $drop(random_owned(3, 3, false)); - println!("a = \n{:?}", &a); - let b = random_owned(3, 2, true); - println!("b = \n{:?}", &b); - let x = a.$solve(&b).unwrap(); - println!("x = \n{:?}", &x); - println!("Ax = \n{:?}", a.dot(&x)); - assert_close_l2!(&a.dot(&x), &b, 1e-7); - } - #[test] - fn solve_ff() { - let a = $drop(random_owned(3, 3, false)); - println!("a = \n{:?}", &a); - let b = random_owned(3, 2, false); - println!("b = \n{:?}", &b); - let x = a.$solve(&b).unwrap(); - println!("x = \n{:?}", &x); - println!("Ax = \n{:?}", a.dot(&x)); - assert_close_l2!(&a.dot(&x), &b, 1e-7); - } + +#[test] +fn triangular_1d_lower() { + let n = 3; + let b: Array1 = random_vector(n); + let a: Array2 = random_square(n).into_triangular(UPLO::Lower); + test1d(UPLO::Lower, a, b, 1e-7); +} + +#[test] +fn triangular_1d_lower_t() { + let n = 3; + let b: Array1 = random_vector(n); + let a: Array2 = random_square(n).into_triangular(UPLO::Lower).reversed_axes(); + test1d(UPLO::Upper, a, b, 1e-7); +} + +#[test] +fn triangular_1d_upper_t() { + let n = 3; + let b: Array1 = random_vector(n); + let a: Array2 = random_square(n).into_triangular(UPLO::Upper).reversed_axes(); + test1d(UPLO::Lower, a, b, 1e-7); +} + +#[test] +fn triangular_2d_upper() { + let b: Array2 = random_matrix(3, 4); + let a: Array2 = random_square(3).into_triangular(UPLO::Upper); + test2d(UPLO::Upper, a, b, 1e-7); } -}} // impl_test_2d -impl_test_2d!(lower2d, drop_upper, solve_lower); -impl_test_2d!(upper2d, drop_lower, solve_upper); +#[test] +fn triangular_2d_lower() { + let b: Array2 = random_matrix(3, 4); + let a: Array2 = random_square(3).into_triangular(UPLO::Lower); + test2d(UPLO::Lower, a, b, 1e-7); +} + +#[test] +fn triangular_2d_lower_t() { + let b: Array2 = random_matrix(3, 4); + let a: Array2 = random_square(3).into_triangular(UPLO::Lower).reversed_axes(); + test2d(UPLO::Upper, a, b, 1e-7); +} + +#[test] +fn triangular_2d_upper_t() { + let b: Array2 = random_matrix(3, 4); + let a: Array2 = random_square(3).into_triangular(UPLO::Upper).reversed_axes(); + test2d(UPLO::Lower, a, b, 1e-7); +} + +#[test] +fn triangular_2d_upper_bt() { + let b: Array2 = random_matrix(4, 3).reversed_axes(); + let a: Array2 = random_square(3).into_triangular(UPLO::Upper); + test2d(UPLO::Upper, a, b, 1e-7); +} + +#[test] +fn triangular_2d_lower_bt() { + let b: Array2 = random_matrix(4, 3).reversed_axes(); + let a: Array2 = random_square(3).into_triangular(UPLO::Lower); + test2d(UPLO::Lower, a, b, 1e-7); +} + +#[test] +fn triangular_2d_lower_t_bt() { + let b: Array2 = random_matrix(4, 3).reversed_axes(); + let a: Array2 = random_square(3).into_triangular(UPLO::Lower).reversed_axes(); + test2d(UPLO::Upper, a, b, 1e-7); +} + +#[test] +fn triangular_2d_upper_t_bt() { + let b: Array2 = random_matrix(4, 3).reversed_axes(); + let a: Array2 = random_square(3).into_triangular(UPLO::Upper).reversed_axes(); + test2d(UPLO::Lower, a, b, 1e-7); +}