From c116ea1addc410dbc4c241fb8d1d6a2ad3c9603f Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Tue, 13 Jun 2017 17:36:47 +0900 Subject: [PATCH 01/18] Add impl2/triangular.rs --- src/impl2/mod.rs | 9 +++++++ src/impl2/solve.rs | 10 +------- src/impl2/triangular.rs | 52 +++++++++++++++++++++++++++++++++++++++++ src/layout.rs | 8 +++++++ 4 files changed, 70 insertions(+), 9 deletions(-) create mode 100644 src/impl2/triangular.rs diff --git a/src/impl2/mod.rs b/src/impl2/mod.rs index fd5dba5e..d8682d4e 100644 --- a/src/impl2/mod.rs +++ b/src/impl2/mod.rs @@ -5,6 +5,7 @@ pub mod svd; pub mod solve; pub mod cholesky; pub mod eigh; +pub mod triangular; pub use self::opnorm::*; pub use self::qr::*; @@ -36,3 +37,11 @@ pub enum UPLO { Upper = b'U', Lower = b'L', } + +#[derive(Debug, Clone, Copy)] +#[repr(u8)] +pub enum Transpose { + No = b'N', + Transpose = b'T', + Hermite = b'C', +} diff --git a/src/impl2/solve.rs b/src/impl2/solve.rs index d024786d..8a2df165 100644 --- a/src/impl2/solve.rs +++ b/src/impl2/solve.rs @@ -5,18 +5,10 @@ use types::*; use error::*; use layout::Layout; -use super::into_result; +use super::{Transpose, into_result}; pub type Pivot = Vec; -#[derive(Debug, Clone, Copy)] -#[repr(u8)] -pub enum Transpose { - No = b'N', - Transpose = b'T', - Hermite = b'C', -} - pub trait Solve_: Sized { fn lu(Layout, a: &mut [Self]) -> Result; fn inv(Layout, a: &mut [Self], &Pivot) -> Result<()>; diff --git a/src/impl2/triangular.rs b/src/impl2/triangular.rs new file mode 100644 index 00000000..6f5a881d --- /dev/null +++ b/src/impl2/triangular.rs @@ -0,0 +1,52 @@ +//! Implement linear solver and inverse matrix + +use lapack::c; + +use error::*; +use layout::Layout; +use super::{UPLO, Transpose, into_result}; + +#[derive(Debug, Clone, Copy)] +#[repr(u8)] +pub enum TriangularDiag { + Unit = b'U', + NonUnit = b'N', +} + +pub trait Triangular_: Sized { + fn inv_triangular(l: Layout, UPLO, TriangularDiag, a: &mut [Self]) -> Result<()>; + fn solve_triangular(al: Layout, bl: Layout, UPLO, TriangularDiag, a: &[Self], b: &mut [Self]) -> Result<()>; +} + +impl Triangular_ for f64 { + fn inv_triangular(l: Layout, uplo: UPLO, diag: TriangularDiag, a: &mut [Self]) -> Result<()> { + let (n, _) = l.size(); + let lda = l.lda(); + let info = c::dtrtri(l.lapacke_layout(), uplo as u8, diag as u8, n, a, lda); + into_result(info, ()) + } + + fn solve_triangular(al: Layout, + bl: Layout, + uplo: UPLO, + diag: TriangularDiag, + a: &[Self], + mut b: &mut [Self]) + -> Result<()> { + let (n, _) = al.size(); + let lda = al.lda(); + let nrhs = bl.len(); + let ldb = bl.lda(); + let info = c::dtrtrs(al.lapacke_layout(), + uplo as u8, + Transpose::No as u8, + diag as u8, + n, + nrhs, + a, + lda, + &mut b, + ldb); + into_result(info, ()) + } +} diff --git a/src/layout.rs b/src/layout.rs index 8ea71138..49d761ce 100644 --- a/src/layout.rs +++ b/src/layout.rs @@ -5,6 +5,7 @@ use lapack::c; use super::error::*; pub type LDA = i32; +pub type LEN = i32; pub type Col = i32; pub type Row = i32; @@ -36,6 +37,13 @@ impl Layout { } } + pub fn len(&self) -> LEN { + match *self { + Layout::C((row, _)) => row, + Layout::F((col, _)) => col, + } + } + pub fn lapacke_layout(&self) -> c::Layout { match *self { Layout::C(_) => c::Layout::RowMajor, From 16904aeba523d0b16525d8691314ff51b27cf0ca Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Tue, 13 Jun 2017 19:17:43 +0900 Subject: [PATCH 02/18] impl AllocatedArrayMut for ArrayBase1 to use triangular --- src/impl2/mod.rs | 4 +- src/impl2/triangular.rs | 16 ++---- src/layout.rs | 41 ++++++++++++--- src/triangular.rs | 113 ++++++++-------------------------------- 4 files changed, 62 insertions(+), 112 deletions(-) diff --git a/src/impl2/mod.rs b/src/impl2/mod.rs index d8682d4e..efd80e7c 100644 --- a/src/impl2/mod.rs +++ b/src/impl2/mod.rs @@ -13,6 +13,7 @@ pub use self::svd::*; pub use self::solve::*; pub use self::cholesky::*; pub use self::eigh::*; +pub use self::triangular::*; use super::error::*; @@ -21,7 +22,8 @@ trait_alias!(LapackScalar: OperatorNorm_, SVD_, Solve_, Cholesky_, - Eigh_); + Eigh_, + Triangular_); pub fn into_result(info: i32, val: T) -> Result { if info == 0 { diff --git a/src/impl2/triangular.rs b/src/impl2/triangular.rs index 6f5a881d..807f6c80 100644 --- a/src/impl2/triangular.rs +++ b/src/impl2/triangular.rs @@ -8,31 +8,25 @@ use super::{UPLO, Transpose, into_result}; #[derive(Debug, Clone, Copy)] #[repr(u8)] -pub enum TriangularDiag { +pub enum Diag { Unit = b'U', NonUnit = b'N', } pub trait Triangular_: Sized { - fn inv_triangular(l: Layout, UPLO, TriangularDiag, a: &mut [Self]) -> Result<()>; - fn solve_triangular(al: Layout, bl: Layout, UPLO, TriangularDiag, a: &[Self], b: &mut [Self]) -> Result<()>; + fn inv_triangular(l: Layout, UPLO, Diag, a: &mut [Self]) -> Result<()>; + fn solve_triangular(al: Layout, bl: Layout, UPLO, Diag, a: &[Self], b: &mut [Self]) -> Result<()>; } impl Triangular_ for f64 { - fn inv_triangular(l: Layout, uplo: UPLO, diag: TriangularDiag, a: &mut [Self]) -> Result<()> { + fn inv_triangular(l: Layout, uplo: UPLO, diag: Diag, a: &mut [Self]) -> Result<()> { let (n, _) = l.size(); let lda = l.lda(); let info = c::dtrtri(l.lapacke_layout(), uplo as u8, diag as u8, n, a, lda); into_result(info, ()) } - fn solve_triangular(al: Layout, - bl: Layout, - uplo: UPLO, - diag: TriangularDiag, - a: &[Self], - mut b: &mut [Self]) - -> Result<()> { + fn solve_triangular(al: Layout, bl: Layout, uplo: UPLO, diag: Diag, a: &[Self], mut b: &mut [Self]) -> Result<()> { let (n, _) = al.size(); let lda = al.lda(); let nrhs = bl.len(); diff --git a/src/layout.rs b/src/layout.rs index 49d761ce..8ee566f8 100644 --- a/src/layout.rs +++ b/src/layout.rs @@ -53,20 +53,20 @@ impl Layout { } pub trait AllocatedArray { - type Scalar; + type Elem; fn layout(&self) -> Result; fn square_layout(&self) -> Result; - fn as_allocated(&self) -> Result<&[Self::Scalar]>; + fn as_allocated(&self) -> Result<&[Self::Elem]>; } pub trait AllocatedArrayMut: AllocatedArray { - fn as_allocated_mut(&mut self) -> Result<&mut [Self::Scalar]>; + fn as_allocated_mut(&mut self) -> Result<&mut [Self::Elem]>; } impl AllocatedArray for ArrayBase where S: Data { - type Scalar = A; + type Elem = A; fn layout(&self) -> Result { let strides = self.strides(); @@ -91,8 +91,7 @@ impl AllocatedArray for ArrayBase } fn as_allocated(&self) -> Result<&[A]> { - let slice = self.as_slice_memory_order().ok_or(MemoryContError::new())?; - Ok(slice) + Ok(self.as_slice_memory_order().ok_or(MemoryContError::new())?) } } @@ -100,11 +99,37 @@ impl AllocatedArrayMut for ArrayBase where S: DataMut { fn as_allocated_mut(&mut self) -> Result<&mut [A]> { - let slice = self.as_slice_memory_order_mut().ok_or(MemoryContError::new())?; - Ok(slice) + Ok(self.as_slice_memory_order_mut().ok_or(MemoryContError::new())?) } } +impl AllocatedArray for ArrayBase + where S: Data +{ + type Elem = A; + + fn layout(&self) -> Result { + Ok(Layout::F((self.len() as i32, 1))) + } + + fn square_layout(&self) -> Result { + Err(NotSquareError::new(self.len() as i32, 1).into()) + } + + fn as_allocated(&self) -> Result<&[A]> { + Ok(self.as_slice_memory_order().ok_or(MemoryContError::new())?) + } +} + +impl AllocatedArrayMut for ArrayBase + where S: DataMut +{ + fn as_allocated_mut(&mut self) -> Result<&mut [A]> { + Ok(self.as_slice_memory_order_mut().ok_or(MemoryContError::new())?) + } +} + + pub fn reconstruct(l: Layout, a: Vec) -> Result> where S: DataOwned { diff --git a/src/triangular.rs b/src/triangular.rs index cae6b64a..2d378d62 100644 --- a/src/triangular.rs +++ b/src/triangular.rs @@ -2,114 +2,43 @@ use ndarray::*; use num_traits::Zero; -use super::impl2::UPLO; -use super::matrix::{Matrix, MFloat}; -use super::square::SquareMatrix; -use super::error::LinalgError; -use super::generate::hstack; -use super::impls::solve::ImplSolve; +use super::layout::*; +use super::error::*; +use super::impl2::*; -pub trait SolveTriangular: Matrix + SquareMatrix { +/// solve a triangular system with upper triangular matrix +pub trait SolveTriangular { type Output; - /// solve a triangular system with upper triangular matrix - fn solve_upper(&self, Rhs) -> Result; - /// solve a triangular system with lower triangular matrix - fn solve_lower(&self, Rhs) -> Result; + fn solve_triangular(&self, UPLO, Diag, Rhs) -> Result; } -impl SolveTriangular> for ArrayBase - where A: MFloat, - S1: Data, - S2: DataMut, - ArrayBase: Matrix + SquareMatrix +impl SolveTriangular for ArrayBase + where A: LapackScalar, + S: Data, + V: AllocatedArrayMut { - type Output = ArrayBase; - fn solve_upper(&self, mut b: ArrayBase) -> Result { - let n = self.square_size()?; - let layout = self.layout()?; - let a = self.as_slice_memory_order().unwrap(); - ImplSolve::solve_triangle(layout, - 'U' as u8, - n, - a, - b.as_slice_memory_order_mut().unwrap(), - 1)?; - Ok(b) - } - fn solve_lower(&self, mut b: ArrayBase) -> Result { - let n = self.square_size()?; - let layout = self.layout()?; - let a = self.as_slice_memory_order().unwrap(); - ImplSolve::solve_triangle(layout, - 'L' as u8, - n, - a, - b.as_slice_memory_order_mut().unwrap(), - 1)?; - Ok(b) - } -} - -impl<'a, S1, S2, A> SolveTriangular<&'a ArrayBase> for ArrayBase - where A: MFloat, - S1: Data, - S2: Data, - ArrayBase: Matrix + SquareMatrix -{ - type Output = Array; - fn solve_upper(&self, bs: &ArrayBase) -> Result { - let mut xs = Vec::new(); - for b in bs.axis_iter(Axis(1)) { - let x = self.solve_upper(b.to_owned())?; - xs.push(x); - } - hstack(&xs).map_err(|e| e.into()) - } - fn solve_lower(&self, bs: &ArrayBase) -> Result { - let mut xs = Vec::new(); - for b in bs.axis_iter(Axis(1)) { - let x = self.solve_lower(b.to_owned())?; - xs.push(x); - } - hstack(&xs).map_err(|e| e.into()) - } -} + type Output = V; -impl SolveTriangular> for RcArray { - type Output = RcArray; - fn solve_upper(&self, b: RcArray) -> Result { - // XXX unnecessary clone - let x = self.to_owned().solve_upper(&b)?; - Ok(x.into_shared()) - } - fn solve_lower(&self, b: RcArray) -> Result { - // XXX unnecessary clone - let x = self.to_owned().solve_lower(&b)?; - Ok(x.into_shared()) + fn solve_triangular(&self, uplo: UPLO, diag: Diag, mut b: V) -> Result { + let la = self.layout()?; + let lb = b.layout()?; + let a_ = self.as_allocated()?; + A::solve_triangular(la, lb, uplo, diag, a_, b.as_allocated_mut()?)?; + Ok(b) } } -pub fn drop_upper(mut a: ArrayBase) -> ArrayBase +pub fn drop_upper(a: ArrayBase) -> ArrayBase where S: DataMut { - for ((i, j), val) in a.indexed_iter_mut() { - if i < j { - *val = A::zero(); - } - } - a + a.into_triangular(UPLO::Lower) } -pub fn drop_lower(mut a: ArrayBase) -> ArrayBase +pub fn drop_lower(a: ArrayBase) -> ArrayBase where S: DataMut { - for ((i, j), val) in a.indexed_iter_mut() { - if i > j { - *val = A::zero(); - } - } - a + a.into_triangular(UPLO::Upper) } pub trait IntoTriangular { From e13ca645a7f79db5d8ea6e9ec62ab6dc9e5b8f61 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 14 Jun 2017 18:27:54 +0900 Subject: [PATCH 03/18] Introduce Field, RealField --- src/impl2/triangular.rs | 26 ++++++++++++++------------ src/prelude.rs | 1 + src/triangular.rs | 2 ++ 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/src/impl2/triangular.rs b/src/impl2/triangular.rs index 807f6c80..b83a4f05 100644 --- a/src/impl2/triangular.rs +++ b/src/impl2/triangular.rs @@ -3,6 +3,7 @@ use lapack::c; use error::*; +use types::*; use layout::Layout; use super::{UPLO, Transpose, into_result}; @@ -18,11 +19,14 @@ pub trait Triangular_: Sized { fn solve_triangular(al: Layout, bl: Layout, UPLO, Diag, a: &[Self], b: &mut [Self]) -> Result<()>; } -impl Triangular_ for f64 { +macro_rules! impl_triangular { + ($scalar:ty, $trtri:path, $trtrs:path) => { + +impl Triangular_ for $scalar { fn inv_triangular(l: Layout, uplo: UPLO, diag: Diag, a: &mut [Self]) -> Result<()> { let (n, _) = l.size(); let lda = l.lda(); - let info = c::dtrtri(l.lapacke_layout(), uplo as u8, diag as u8, n, a, lda); + let info = $trtri(l.lapacke_layout(), uplo as u8, diag as u8, n, a, lda); into_result(info, ()) } @@ -31,16 +35,14 @@ impl Triangular_ for f64 { let lda = al.lda(); let nrhs = bl.len(); let ldb = bl.lda(); - let info = c::dtrtrs(al.lapacke_layout(), - uplo as u8, - Transpose::No as u8, - diag as u8, - n, - nrhs, - a, - lda, - &mut b, - ldb); + let info = $trtrs(al.lapacke_layout(), uplo as u8, Transpose::No as u8, diag as u8, n, nrhs, a, lda, &mut b, ldb); into_result(info, ()) } } + +}} // impl_triangular! + +impl_triangular!(f64, c::dtrtri, c::dtrtrs); +impl_triangular!(f32, c::strtri, c::strtrs); +impl_triangular!(c64, c::ztrtri, c::ztrtrs); +impl_triangular!(c32, c::ctrtri, c::ctrtrs); diff --git a/src/prelude.rs b/src/prelude.rs index f34cb921..32623834 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -12,3 +12,4 @@ pub use opnorm::*; pub use solve::*; pub use eigh::*; pub use cholesky::*; +pub use impl2::LapackScalar; diff --git a/src/triangular.rs b/src/triangular.rs index 2d378d62..813cc1d9 100644 --- a/src/triangular.rs +++ b/src/triangular.rs @@ -7,6 +7,8 @@ use super::layout::*; use super::error::*; use super::impl2::*; +pub use super::impl2::Diag; + /// solve a triangular system with upper triangular matrix pub trait SolveTriangular { type Output; From 0e75875207db6ddcaff4a8aff2d138385501b26b Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 14 Jun 2017 18:28:22 +0900 Subject: [PATCH 04/18] Remove old test --- tests/triangular.rs | 133 ++++++-------------------------------------- 1 file changed, 17 insertions(+), 116 deletions(-) diff --git a/tests/triangular.rs b/tests/triangular.rs index 11f16a70..219acfce 100644 --- a/tests/triangular.rs +++ b/tests/triangular.rs @@ -1,120 +1,21 @@ -include!("header.rs"); -macro_rules! impl_test { - ($modname:ident, $random:path) => { -mod $modname { - use ndarray::prelude::*; - use ndarray_linalg::prelude::*; - use ndarray_rand::RandomExt; - use rand_extra::*; - #[test] - fn solve_upper() { - let r_dist = RealNormal::new(0.0, 1.0); - let a = drop_lower($random((3, 3), r_dist)); - println!("a = \n{:?}", &a); - let b = $random(3, r_dist); - println!("b = \n{:?}", &b); - let x = a.solve_upper(b.clone()).unwrap(); - println!("x = \n{:?}", &x); - println!("Ax = \n{:?}", a.dot(&x)); - assert_close_l2!(&a.dot(&x), &b, 1e-7); - } +extern crate ndarray; +#[macro_use] +extern crate ndarray_linalg; - #[test] - fn solve_upper_t() { - let r_dist = RealNormal::new(0., 1.); - let a = drop_lower($random((3, 3), r_dist).reversed_axes()); - println!("a = \n{:?}", &a); - let b = $random(3, r_dist); - println!("b = \n{:?}", &b); - let x = a.solve_upper(b.clone()).unwrap(); - println!("x = \n{:?}", &x); - println!("Ax = \n{:?}", a.dot(&x)); - assert_close_l2!(&a.dot(&x), &b, 1e-7); - } +use ndarray::*; +use ndarray_linalg::prelude::*; - #[test] - fn solve_lower() { - let r_dist = RealNormal::new(0., 1.); - let a = drop_upper($random((3, 3), r_dist)); - println!("a = \n{:?}", &a); - let b = $random(3, r_dist); - println!("b = \n{:?}", &b); - let x = a.solve_lower(b.clone()).unwrap(); - println!("x = \n{:?}", &x); - println!("Ax = \n{:?}", a.dot(&x)); - assert_close_l2!(&a.dot(&x), &b, 1e-7); - } - - #[test] - fn solve_lower_t() { - let r_dist = RealNormal::new(0., 1.); - let a = drop_upper($random((3, 3), r_dist).reversed_axes()); - println!("a = \n{:?}", &a); - let b = $random(3, r_dist); - println!("b = \n{:?}", &b); - let x = a.solve_lower(b.clone()).unwrap(); - println!("x = \n{:?}", &x); - println!("Ax = \n{:?}", a.dot(&x)); - assert_close_l2!(&a.dot(&x), &b, 1e-7); - } +fn test1d(uplo: UPLO, a: ArrayBase, b: ArrayBase, tol: Tol) + where A: Field + Absolute, + Sa: Data, + Sb: DataMut + DataClone, + Tol: RealField +{ + println!("a = {:?}", &a); + println!("b = {:?}", &b); + let ans = b.clone(); + let x = a.solve_triangular(uplo, Diag::NonUnit, b).unwrap(); + let b_ = a.dot(&x); + assert_close_l2!(&b_, &ans, tol); } -}} // impl_test_opnorm - -impl_test!(owned, Array::random); -impl_test!(shared, RcArray::random); - -macro_rules! impl_test_2d { - ($modname:ident, $drop:path, $solve:ident) => { -mod $modname { - use super::random_owned; - use ndarray_linalg::prelude::*; - #[test] - fn solve_tt() { - let a = $drop(random_owned(3, 3, true)); - println!("a = \n{:?}", &a); - let b = random_owned(3, 2, true); - println!("b = \n{:?}", &b); - let x = a.$solve(&b).unwrap(); - println!("x = \n{:?}", &x); - println!("Ax = \n{:?}", a.dot(&x)); - assert_close_l2!(&a.dot(&x), &b, 1e-7); - } - #[test] - fn solve_tf() { - let a = $drop(random_owned(3, 3, true)); - println!("a = \n{:?}", &a); - let b = random_owned(3, 2, false); - println!("b = \n{:?}", &b); - let x = a.$solve(&b).unwrap(); - println!("x = \n{:?}", &x); - println!("Ax = \n{:?}", a.dot(&x)); - assert_close_l2!(&a.dot(&x), &b, 1e-7); - } - #[test] - fn solve_ft() { - let a = $drop(random_owned(3, 3, false)); - println!("a = \n{:?}", &a); - let b = random_owned(3, 2, true); - println!("b = \n{:?}", &b); - let x = a.$solve(&b).unwrap(); - println!("x = \n{:?}", &x); - println!("Ax = \n{:?}", a.dot(&x)); - assert_close_l2!(&a.dot(&x), &b, 1e-7); - } - #[test] - fn solve_ff() { - let a = $drop(random_owned(3, 3, false)); - println!("a = \n{:?}", &a); - let b = random_owned(3, 2, false); - println!("b = \n{:?}", &b); - let x = a.$solve(&b).unwrap(); - println!("x = \n{:?}", &x); - println!("Ax = \n{:?}", a.dot(&x)); - assert_close_l2!(&a.dot(&x), &b, 1e-7); - } -} -}} // impl_test_2d - -impl_test_2d!(lower2d, drop_upper, solve_lower); -impl_test_2d!(upper2d, drop_lower, solve_upper); From d246487fcc8b3102890e36e23ee5a13eaf0256c9 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 14 Jun 2017 18:55:11 +0900 Subject: [PATCH 05/18] Remove impls --- src/impls/mod.rs | 2 -- src/impls/solve.rs | 87 ---------------------------------------------- src/lib.rs | 1 - 3 files changed, 90 deletions(-) delete mode 100644 src/impls/mod.rs delete mode 100644 src/impls/solve.rs diff --git a/src/impls/mod.rs b/src/impls/mod.rs deleted file mode 100644 index ce926498..00000000 --- a/src/impls/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -//! Implement trait bindings of LAPACK -pub mod solve; diff --git a/src/impls/solve.rs b/src/impls/solve.rs deleted file mode 100644 index 364bc468..00000000 --- a/src/impls/solve.rs +++ /dev/null @@ -1,87 +0,0 @@ -//! Implement linear solver and inverse matrix - -use lapack::c::*; -use std::cmp::min; - -use error::LapackError; - -pub trait ImplSolve: Sized { - /// execute LU decomposition - fn lu(layout: Layout, m: usize, n: usize, a: Vec) -> Result<(Vec, Vec), LapackError>; - /// calc inverse matrix with LU factorized matrix - fn inv(layout: Layout, size: usize, a: Vec, ipiv: &Vec) -> Result, LapackError>; - /// solve linear problem with LU factorized matrix - fn solve(layout: Layout, - size: usize, - a: &Vec, - ipiv: &Vec, - b: Vec) - -> Result, LapackError>; - /// solve triangular linear problem - fn solve_triangle<'a, 'b>(layout: Layout, - uplo: u8, - size: usize, - a: &'a [Self], - b: &'b mut [Self], - nrhs: i32) - -> Result<&'b mut [Self], LapackError>; -} - -macro_rules! impl_solve { - ($scalar:ty, $getrf:path, $getri:path, $getrs:path, $trtrs:path) => { -impl ImplSolve for $scalar { - fn lu(layout: Layout, m: usize, n: usize, mut a: Vec) -> Result<(Vec, Vec), LapackError> { - let m = m as i32; - let n = n as i32; - let k = min(m, n); - let lda = match layout { - Layout::ColumnMajor => m, - Layout::RowMajor => n, - }; - let mut ipiv = vec![0; k as usize]; - let info = $getrf(layout, m, n, &mut a, lda, &mut ipiv); - if info == 0 { - Ok((ipiv, a)) - } else { - Err(From::from(info)) - } - } - fn inv(layout: Layout, size: usize, mut a: Vec, ipiv: &Vec) -> Result, LapackError> { - let n = size as i32; - let lda = n; - let info = $getri(layout, n, &mut a, lda, &ipiv); - if info == 0 { - Ok(a) - } else { - Err(From::from(info)) - } - } - fn solve(layout: Layout, size: usize, a: &Vec, ipiv: &Vec, mut b: Vec) -> Result, LapackError> { - let n = size as i32; - let lda = n; - let info = $getrs(layout, 'N' as u8, n, 1, a, lda, ipiv, &mut b, n); - if info == 0 { - Ok(b) - } else { - Err(From::from(info)) - } - } - fn solve_triangle<'a, 'b>(layout: Layout, uplo: u8, size: usize, a: &'a [Self], mut b: &'b mut [Self], nrhs: i32) -> Result<&'b mut [Self], LapackError> { - let n = size as i32; - let lda = n; - let ldb = match layout { - Layout::ColumnMajor => n, - Layout::RowMajor => 1, - }; - let info = $trtrs(layout, uplo, 'N' as u8, 'N' as u8, n, nrhs, a, lda, &mut b, ldb); - if info == 0 { - Ok(b) - } else { - Err(From::from(info)) - } - } -} -}} // end macro_rules - -impl_solve!(f64, dgetrf, dgetri, dgetrs, dtrtrs); -impl_solve!(f32, sgetrf, sgetri, sgetrs, strtrs); diff --git a/src/lib.rs b/src/lib.rs index f654d8a9..72f998bc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -47,7 +47,6 @@ extern crate derive_new; pub mod types; pub mod error; pub mod layout; -pub mod impls; pub mod impl2; pub mod qr; From 76f42370675cb6f5378e3d59f01de8158b2e69b1 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Thu, 15 Jun 2017 02:58:27 +0900 Subject: [PATCH 06/18] Drop matrix and square --- src/lib.rs | 3 -- src/matrix.rs | 124 -------------------------------------------- src/prelude.rs | 8 ++- src/square.rs | 55 -------------------- tests/lu.rs | 32 ------------ tests/permutate.rs | 49 ----------------- tests/triangular.rs | 5 ++ 7 files changed, 8 insertions(+), 268 deletions(-) delete mode 100644 src/matrix.rs delete mode 100644 src/square.rs delete mode 100644 tests/lu.rs delete mode 100644 tests/permutate.rs diff --git a/src/lib.rs b/src/lib.rs index 72f998bc..547e5e6a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -55,9 +55,6 @@ pub mod opnorm; pub mod solve; pub mod cholesky; pub mod eigh; - -pub mod matrix; -pub mod square; pub mod triangular; pub mod generate; diff --git a/src/matrix.rs b/src/matrix.rs deleted file mode 100644 index 82a8c665..00000000 --- a/src/matrix.rs +++ /dev/null @@ -1,124 +0,0 @@ -//! Define trait for general matrix - -use std::cmp::min; -use ndarray::*; -use ndarray::DataMut; -use lapack::c::Layout; - -use super::error::{LinalgError, StrideError}; -use super::impls::solve::ImplSolve; - -pub trait MFloat: ImplSolve + NdFloat {} -impl MFloat for A {} - -/// Methods for general matrices -pub trait Matrix: Sized { - type Scalar; - type Vector; - type Permutator; - /// number of (rows, columns) - fn size(&self) -> (usize, usize); - /// Layout (C/Fortran) of matrix - fn layout(&self) -> Result; - /// LU decomposition - fn lu(self) -> Result<(Self::Permutator, Self, Self), LinalgError>; - /// permutate matrix (inplace) - fn permutate(&mut self, p: &Self::Permutator); - /// permutate matrix (outplace) - fn permutated(mut self, p: &Self::Permutator) -> Self { - self.permutate(p); - self - } -} - -fn check_layout(strides: &[Ixs]) -> Result { - if min(strides[0], strides[1]) != 1 { - return Err(StrideError { - s0: strides[0], - s1: strides[1], - });; - } - if strides[0] < strides[1] { - Ok(Layout::ColumnMajor) - } else { - Ok(Layout::RowMajor) - } -} - -fn permutate(mut a: &mut ArrayBase, ipiv: &Vec) - where S: DataMut -{ - let m = a.cols(); - for (i, j_) in ipiv.iter().enumerate().rev() { - let j = (j_ - 1) as usize; - if i == j { - continue; - } - for k in 0..m { - a.swap((i, k), (j, k)); - } - } -} - -impl Matrix for Array { - type Scalar = A; - type Vector = Array; - type Permutator = Vec; - - fn size(&self) -> (usize, usize) { - (self.rows(), self.cols()) - } - fn layout(&self) -> Result { - check_layout(self.strides()) - } - fn lu(self) -> Result<(Self::Permutator, Self, Self), LinalgError> { - let (n, m) = self.size(); - let k = min(n, m); - let (p, l) = ImplSolve::lu(self.layout()?, n, m, self.clone().into_raw_vec())?; - let mut a = match self.layout()? { - Layout::ColumnMajor => Array::from_vec(l).into_shape((m, n)).unwrap().reversed_axes(), - Layout::RowMajor => Array::from_vec(l).into_shape((n, m)).unwrap(), - }; - let mut lm = Array::zeros((n, k)); - for ((i, j), val) in lm.indexed_iter_mut() { - if i > j { - *val = a[(i, j)]; - } else if i == j { - *val = A::one(); - } - } - for ((i, j), val) in a.indexed_iter_mut() { - if i > j { - *val = A::zero(); - } - } - let am = if n > k { - a.slice(s![0..k as isize, ..]).to_owned() - } else { - a - }; - Ok((p, lm, am)) - } - fn permutate(&mut self, ipiv: &Self::Permutator) { - permutate(self, ipiv); - } -} - -impl Matrix for RcArray { - type Scalar = A; - type Vector = RcArray; - type Permutator = Vec; - fn size(&self) -> (usize, usize) { - (self.rows(), self.cols()) - } - fn layout(&self) -> Result { - check_layout(self.strides()) - } - fn lu(self) -> Result<(Self::Permutator, Self, Self), LinalgError> { - let (p, l, u) = self.into_owned().lu()?; - Ok((p, l.into_shared(), u.into_shared())) - } - fn permutate(&mut self, ipiv: &Self::Permutator) { - permutate(self, ipiv); - } -} diff --git a/src/prelude.rs b/src/prelude.rs index 32623834..b551c85e 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -1,7 +1,4 @@ -pub use matrix::Matrix; -pub use square::SquareMatrix; -pub use triangular::*; -pub use norm::*; + pub use types::*; pub use generate::*; pub use assert::*; @@ -12,4 +9,5 @@ pub use opnorm::*; pub use solve::*; pub use eigh::*; pub use cholesky::*; -pub use impl2::LapackScalar; +pub use triangular::*; +pub use norm::*; diff --git a/src/square.rs b/src/square.rs deleted file mode 100644 index 78a12bda..00000000 --- a/src/square.rs +++ /dev/null @@ -1,55 +0,0 @@ -//! Define trait for Hermite matrices - -use ndarray::{Ix2, Array, RcArray, ArrayBase, Data}; - -use super::matrix::{Matrix, MFloat}; -use super::error::{LinalgError, NotSquareError}; - -/// Methods for square matrices -/// -/// This trait defines method for square matrices, -/// but does not assure that the matrix is square. -/// If not square, `NotSquareError` will be thrown. -pub trait SquareMatrix: Matrix { - /// trace of matrix - fn trace(&self) -> Result; - #[doc(hidden)] - fn check_square(&self) -> Result<(), NotSquareError> { - let (rows, cols) = self.size(); - if rows == cols { - Ok(()) - } else { - Err(NotSquareError { - rows: rows as i32, - cols: cols as i32, - }) - } - } - /// test matrix is square and return its size - fn square_size(&self) -> Result { - self.check_square()?; - let (n, _) = self.size(); - Ok(n) - } -} - -fn trace(a: &ArrayBase) -> A - where S: Data -{ - let n = a.rows(); - (0..n).fold(A::zero(), |sum, i| sum + a[(i, i)]) -} - -impl SquareMatrix for Array { - fn trace(&self) -> Result { - self.check_square()?; - Ok(trace(self)) - } -} - -impl SquareMatrix for RcArray { - fn trace(&self) -> Result { - self.check_square()?; - Ok(trace(self)) - } -} diff --git a/tests/lu.rs b/tests/lu.rs deleted file mode 100644 index e62c47cd..00000000 --- a/tests/lu.rs +++ /dev/null @@ -1,32 +0,0 @@ -include!("header.rs"); - -macro_rules! impl_test { - ($funcname:ident, $random:path, $n:expr, $m:expr, $t:expr) => { -#[test] -fn $funcname() { - use ndarray_linalg::prelude::*; - let a = $random($n, $m, $t); - let ans = a.clone(); - let (p, l, u) = a.lu().unwrap(); - println!("P = \n{:?}", &p); - println!("L = \n{:?}", &l); - println!("U = \n{:?}", &u); - println!("LU = \n{:?}", l.dot(&u)); - assert_close_l2!(&l.dot(&u).permutated(&p), &ans, 1e-7); -} -}} // impl_test - -macro_rules! impl_test_lu { - ($modname:ident, $random:path) => { -mod $modname { - impl_test!(lu_square, $random, 3, 3, false); - impl_test!(lu_square_t, $random, 3, 3, true); - impl_test!(lu_3x4, $random, 3, 4, false); - impl_test!(lu_3x4_t, $random, 3, 4, true); - impl_test!(lu_4x3, $random, 4, 3, false); - impl_test!(lu_4x3_t, $random, 4, 3, true); -} -}} // impl_test_lu - -impl_test_lu!(owned, super::random_owned); -impl_test_lu!(shared, super::random_shared); diff --git a/tests/permutate.rs b/tests/permutate.rs deleted file mode 100644 index d6d64e00..00000000 --- a/tests/permutate.rs +++ /dev/null @@ -1,49 +0,0 @@ -include!("header.rs"); - -macro_rules! impl_test { - ($testname:ident, $permutate:expr, $input:expr, $answer:expr) => { -#[test] -fn $testname() { - use ndarray_linalg::prelude::*; - let a = $input; - println!("a= \n{:?}", &a); - let p = $permutate; // replace 1-2 - let pa = a.permutated(&p); - println!("permutated = \n{:?}", &pa); - assert_close_l2!(&pa, &$answer, 1e-7); -} -}} // impl_test - -macro_rules! impl_test_permuate { - ($modname:ident, $array:path) => { -mod $modname { - use ndarray; - impl_test!(permutate, - vec![2, 2, 3], - $array(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]), - $array(&[[4., 5., 6.], [1., 2., 3.], [7., 8., 9.]])); - impl_test!(permutate_t, - vec![2, 2, 3], - $array(&[[1., 4., 7.], [2., 5., 8.], [3., 6., 9.]]).reversed_axes(), - $array(&[[4., 5., 6.], [1., 2., 3.], [7., 8., 9.]])); - impl_test!(permutate_3x4, - vec![1, 3, 3], - $array(&[[1., 4., 7., 10.], [2., 5., 8., 11.], [3., 6., 9., 12.]]), - $array(&[[1., 4., 7., 10.], [3., 6., 9., 12.], [2., 5., 8., 11.]])); - impl_test!(permutate_3x4_t, - vec![1, 3, 3], - $array(&[[1., 5., 9.], [2., 6., 10.], [3., 7., 11.], [4., 8., 12.]]).reversed_axes(), - $array(&[[1., 2., 3., 4.], [9., 10., 11., 12.], [5., 6., 7., 8.]])); - impl_test!(permutate_4x3, - vec![4, 2, 3, 4], - $array(&[[1., 5., 9.], [2., 6., 10.], [3., 7., 11.], [4., 8., 12.]]), - $array(&[[4., 8., 12.], [2., 6., 10.], [3., 7., 11.], [1., 5., 9.]])); - impl_test!(permutate_4x3_t, - vec![4, 2, 3, 4], - $array(&[[1., 4., 7., 10.], [2., 5., 8., 11.], [3., 6., 9., 12.]]).reversed_axes(), - $array(&[[10., 11., 12.], [4., 5., 6.], [7., 8., 9.], [1., 2., 3.]])); -} -}} // impl_test_permuate - -impl_test_permuate!(owned, ndarray::arr2); -impl_test_permuate!(shared, ndarray::rcarr2); diff --git a/tests/triangular.rs b/tests/triangular.rs index 219acfce..ea48f8b1 100644 --- a/tests/triangular.rs +++ b/tests/triangular.rs @@ -19,3 +19,8 @@ fn test1d(uplo: UPLO, a: ArrayBase, b: ArrayBase Date: Thu, 15 Jun 2017 03:47:25 +0900 Subject: [PATCH 07/18] Restore trace --- src/lib.rs | 1 + src/prelude.rs | 15 ++++++++------- src/trace.rs | 23 +++++++++++++++++++++++ 3 files changed, 32 insertions(+), 7 deletions(-) create mode 100644 src/trace.rs diff --git a/src/lib.rs b/src/lib.rs index 547e5e6a..981475dd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -60,5 +60,6 @@ pub mod triangular; pub mod generate; pub mod assert; pub mod norm; +pub mod trace; pub mod prelude; diff --git a/src/prelude.rs b/src/prelude.rs index b551c85e..41a8e689 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -1,13 +1,14 @@ -pub use types::*; -pub use generate::*; pub use assert::*; +pub use generate::*; +pub use types::*; -pub use qr::*; -pub use svd::*; +pub use cholesky::*; +pub use eigh::*; +pub use norm::*; pub use opnorm::*; +pub use qr::*; pub use solve::*; -pub use eigh::*; -pub use cholesky::*; +pub use svd::*; +pub use trace::*; pub use triangular::*; -pub use norm::*; diff --git a/src/trace.rs b/src/trace.rs new file mode 100644 index 00000000..86995161 --- /dev/null +++ b/src/trace.rs @@ -0,0 +1,23 @@ + +use ndarray::*; + +use super::types::*; +use super::error::*; +use super::layout::*; + +pub trait Trace { + type Output; + fn trace(&self) -> Result; +} + +impl Trace for ArrayBase + where A: Field, + S: Data +{ + type Output = A; + + fn trace(&self) -> Result { + let (n, _) = self.square_layout()?.size(); + Ok((0..n as usize).map(|i| self[(i, i)]).sum()) + } +} From 3abe98f69f539b058a7f8b86cd48df6e0d915e84 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Thu, 15 Jun 2017 03:47:56 +0900 Subject: [PATCH 08/18] Update impl of SolveTriangular --- src/generate.rs | 14 ++++++++++-- src/triangular.rs | 58 ++++++++++++++++++++++++++++++++++++----------- 2 files changed, 57 insertions(+), 15 deletions(-) diff --git a/src/generate.rs b/src/generate.rs index 54211c9b..d1d22ef0 100644 --- a/src/generate.rs +++ b/src/generate.rs @@ -19,8 +19,18 @@ pub fn conjugate(a: &ArrayBase) -> ArrayBase a } +/// Random vector +pub fn random_vector(n: usize) -> ArrayBase + where A: RandNormal, + S: DataOwned +{ + let mut rng = thread_rng(); + let v: Vec = (0..n).map(|_| A::randn(&mut rng)).collect(); + ArrayBase::from_vec(v) +} + /// Random matrix -pub fn random(n: usize, m: usize) -> ArrayBase +pub fn random_matrix(n: usize, m: usize) -> ArrayBase where A: RandNormal, S: DataOwned { @@ -34,7 +44,7 @@ pub fn random_square(n: usize) -> ArrayBase where A: RandNormal, S: DataOwned { - random(n, n) + random_matrix(n, n) } /// Random Hermite matrix diff --git a/src/triangular.rs b/src/triangular.rs index 813cc1d9..8a1bb22c 100644 --- a/src/triangular.rs +++ b/src/triangular.rs @@ -15,14 +15,31 @@ pub trait SolveTriangular { fn solve_triangular(&self, UPLO, Diag, Rhs) -> Result; } -impl SolveTriangular for ArrayBase +impl SolveTriangular> for ArrayBase where A: LapackScalar, - S: Data, - V: AllocatedArrayMut + Si: Data, + So: DataMut, + D: Dimension, + ArrayBase: AllocatedArrayMut { - type Output = V; + type Output = ArrayBase; - fn solve_triangular(&self, uplo: UPLO, diag: Diag, mut b: V) -> Result { + fn solve_triangular(&self, uplo: UPLO, diag: Diag, mut b: ArrayBase) -> Result { + self.solve_triangular(uplo, diag, &mut b)?; + Ok(b) + } +} + +impl<'a, A, Si, So, D> SolveTriangular<&'a mut ArrayBase> for ArrayBase + where A: LapackScalar, + Si: Data, + So: DataMut, + D: Dimension, + ArrayBase: AllocatedArrayMut +{ + type Output = &'a mut ArrayBase; + + fn solve_triangular(&self, uplo: UPLO, diag: Diag, mut b: &'a mut ArrayBase) -> Result { let la = self.layout()?; let lb = b.layout()?; let a_ = self.as_allocated()?; @@ -31,16 +48,19 @@ impl SolveTriangular for ArrayBase } } -pub fn drop_upper(a: ArrayBase) -> ArrayBase - where S: DataMut +impl<'a, A, Si, So, D> SolveTriangular<&'a ArrayBase> for ArrayBase + where A: LapackScalar + Copy, + Si: Data, + So: DataMut + DataOwned, + D: Dimension, + ArrayBase: AllocatedArrayMut { - a.into_triangular(UPLO::Lower) -} + type Output = ArrayBase; -pub fn drop_lower(a: ArrayBase) -> ArrayBase - where S: DataMut -{ - a.into_triangular(UPLO::Upper) + fn solve_triangular(&self, uplo: UPLO, diag: Diag, b: &'a ArrayBase) -> Result { + let b = replicate(b); + self.solve_triangular(uplo, diag, b) + } } pub trait IntoTriangular { @@ -81,3 +101,15 @@ impl IntoTriangular> for ArrayBase self } } + +pub fn drop_upper(a: ArrayBase) -> ArrayBase + where S: DataMut +{ + a.into_triangular(UPLO::Lower) +} + +pub fn drop_lower(a: ArrayBase) -> ArrayBase + where S: DataMut +{ + a.into_triangular(UPLO::Upper) +} From f98ef76ee7bec7fe9695709a828b32e8d4311921 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Mon, 19 Jun 2017 15:49:13 +0900 Subject: [PATCH 09/18] Debuging solve_triangular (not solved) --- src/impl2/triangular.rs | 11 +++- src/layout.rs | 4 +- tests/triangular.rs | 122 ++++++++++++++++++++++++++++++++++++++-- 3 files changed, 128 insertions(+), 9 deletions(-) diff --git a/src/impl2/triangular.rs b/src/impl2/triangular.rs index b83a4f05..49ab92bc 100644 --- a/src/impl2/triangular.rs +++ b/src/impl2/triangular.rs @@ -34,7 +34,16 @@ impl Triangular_ for $scalar { let (n, _) = al.size(); let lda = al.lda(); let nrhs = bl.len(); - let ldb = bl.lda(); + let ldb = match al { + Layout::C(_) => bl.len() as i32, + Layout::F(_) => bl.lda() as i32, + }; + println!("al = {:?}", al); + println!("bl = {:?}", bl); + println!("n = {}", n); + println!("lda = {}", lda); + println!("nrhs = {}", nrhs); + println!("ldb = {}", ldb); let info = $trtrs(al.lapacke_layout(), uplo as u8, Transpose::No as u8, diag as u8, n, nrhs, a, lda, &mut b, ldb); into_result(info, ()) } diff --git a/src/layout.rs b/src/layout.rs index 8ee566f8..9d326e45 100644 --- a/src/layout.rs +++ b/src/layout.rs @@ -109,11 +109,11 @@ impl AllocatedArray for ArrayBase type Elem = A; fn layout(&self) -> Result { - Ok(Layout::F((self.len() as i32, 1))) + Ok(Layout::F((1, self.len() as i32))) } fn square_layout(&self) -> Result { - Err(NotSquareError::new(self.len() as i32, 1).into()) + Err(NotSquareError::new(1, self.len() as i32).into()) } fn as_allocated(&self) -> Result<&[A]> { diff --git a/tests/triangular.rs b/tests/triangular.rs index ea48f8b1..827b6f96 100644 --- a/tests/triangular.rs +++ b/tests/triangular.rs @@ -9,18 +9,128 @@ use ndarray_linalg::prelude::*; fn test1d(uplo: UPLO, a: ArrayBase, b: ArrayBase, tol: Tol) where A: Field + Absolute, Sa: Data, - Sb: DataMut + DataClone, + Sb: DataMut + DataOwned, Tol: RealField { println!("a = {:?}", &a); println!("b = {:?}", &b); - let ans = b.clone(); - let x = a.solve_triangular(uplo, Diag::NonUnit, b).unwrap(); + let x = a.solve_triangular(uplo, Diag::NonUnit, &b).unwrap(); + println!("x = {:?}", &x); let b_ = a.dot(&x); - assert_close_l2!(&b_, &ans, tol); + println!("Ax = {:?}", &b_); + assert_close_l2!(&b_, &b, tol); +} + +fn test2d(uplo: UPLO, a: ArrayBase, b: ArrayBase, tol: Tol) + where A: Field + Absolute, + Sa: Data, + Sb: DataMut + DataOwned, + Tol: RealField +{ + println!("a = {:?}", &a); + println!("b = {:?}", &b); + let x = a.solve_triangular(uplo, Diag::NonUnit, &b).unwrap(); + println!("x = {:?}", &x); + let b_ = a.dot(&x); + println!("Ax = {:?}", &b_); + println!("A^Tx = {:?}", a.t().dot(&x)); + println!("Ax^T = {:?}", a.dot(&x.t())); + println!("(Ax^T)^T = {:?}", a.dot(&x.t()).t()); + assert_close_l2!(&b_, &b, tol); +} + +#[test] +fn triangular_1d_upper() { + let n = 3; + let b: Array1 = random_vector(n); + let a: Array2 = random_square(n).into_triangular(UPLO::Upper); + test1d(UPLO::Upper, a, b.clone(), 1e-7); +} + +#[test] +fn triangular_1d_lower() { + let n = 3; + let b: Array1 = random_vector(n); + let a: Array2 = random_square(n).into_triangular(UPLO::Lower); + test1d(UPLO::Lower, a, b.clone(), 1e-7); +} + +#[test] +fn triangular_1d_lower_t() { + let n = 3; + let b: Array1 = random_vector(n); + let a: Array2 = random_square(n).into_triangular(UPLO::Lower).reversed_axes(); + test1d(UPLO::Upper, a, b.clone(), 1e-7); +} + +#[test] +fn triangular_1d_upper_t() { + let n = 3; + let b: Array1 = random_vector(n); + let a: Array2 = random_square(n).into_triangular(UPLO::Upper).reversed_axes(); + test1d(UPLO::Lower, a, b.clone(), 1e-7); +} + +#[test] +fn triangular_2d_upper() { + let n = 3; + let b: Array2 = random_square(n); + let a: Array2 = random_square(n).into_triangular(UPLO::Upper); + test2d(UPLO::Upper, a, b.clone(), 1e-7); +} + +#[test] +fn triangular_2d_lower() { + let n = 3; + let b: Array2 = random_square(n); + let a: Array2 = random_square(n).into_triangular(UPLO::Lower); + test2d(UPLO::Lower, a, b.clone(), 1e-7); +} + +#[test] +fn triangular_2d_lower_t() { + let n = 3; + let b: Array2 = random_square(n); + let a: Array2 = random_square(n).into_triangular(UPLO::Lower).reversed_axes(); + test2d(UPLO::Upper, a, b.clone(), 1e-7); +} + +#[test] +fn triangular_2d_upper_t() { + let n = 3; + let b: Array2 = random_square(n); + let a: Array2 = random_square(n).into_triangular(UPLO::Upper).reversed_axes(); + test2d(UPLO::Lower, a, b.clone(), 1e-7); +} + +#[test] +fn triangular_2d_upper_bt() { + let n = 3; + let b: Array2 = random_square(n).reversed_axes(); + let a: Array2 = random_square(n).into_triangular(UPLO::Upper); + test2d(UPLO::Upper, a, b.clone(), 1e-7); +} + +#[test] +fn triangular_2d_lower_bt() { + let n = 3; + let b: Array2 = random_square(n).reversed_axes(); + let a: Array2 = random_square(n).into_triangular(UPLO::Lower); + test2d(UPLO::Lower, a, b.clone(), 1e-7); +} + +#[test] +fn triangular_2d_lower_t_bt() { + let n = 3; + let b: Array2 = random_square(n).reversed_axes(); + let a: Array2 = random_square(n).into_triangular(UPLO::Lower).reversed_axes(); + test2d(UPLO::Upper, a, b.clone(), 1e-7); } #[test] -fn triangular_rand() { - let a = random_square(n); +fn triangular_2d_upper_t_bt() { + let n = 3; + let b: Array2 = random_square(n).reversed_axes(); + let a: Array2 = random_square(n).into_triangular(UPLO::Upper).reversed_axes(); + test2d(UPLO::Lower, a, b.clone(), 1e-7); } From 22b3a4bd6b4c4c097a96d5bbc02d211199d4bf6c Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Tue, 20 Jun 2017 16:10:07 +0900 Subject: [PATCH 10/18] Fix test not to use clone --- tests/triangular.rs | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/tests/triangular.rs b/tests/triangular.rs index 827b6f96..499211b9 100644 --- a/tests/triangular.rs +++ b/tests/triangular.rs @@ -24,19 +24,20 @@ fn test1d(uplo: UPLO, a: ArrayBase, b: ArrayBase(uplo: UPLO, a: ArrayBase, b: ArrayBase, tol: Tol) where A: Field + Absolute, Sa: Data, - Sb: DataMut + DataOwned, + Sb: DataMut + DataOwned + DataClone, Tol: RealField { println!("a = {:?}", &a); println!("b = {:?}", &b); - let x = a.solve_triangular(uplo, Diag::NonUnit, &b).unwrap(); + let ans = b.clone(); + let x = a.solve_triangular(uplo, Diag::NonUnit, b).unwrap(); println!("x = {:?}", &x); let b_ = a.dot(&x); println!("Ax = {:?}", &b_); println!("A^Tx = {:?}", a.t().dot(&x)); println!("Ax^T = {:?}", a.dot(&x.t())); println!("(Ax^T)^T = {:?}", a.dot(&x.t()).t()); - assert_close_l2!(&b_, &b, tol); + assert_close_l2!(&b_, &ans, tol); } #[test] @@ -44,7 +45,7 @@ fn triangular_1d_upper() { let n = 3; let b: Array1 = random_vector(n); let a: Array2 = random_square(n).into_triangular(UPLO::Upper); - test1d(UPLO::Upper, a, b.clone(), 1e-7); + test1d(UPLO::Upper, a, b, 1e-7); } #[test] @@ -52,7 +53,7 @@ fn triangular_1d_lower() { let n = 3; let b: Array1 = random_vector(n); let a: Array2 = random_square(n).into_triangular(UPLO::Lower); - test1d(UPLO::Lower, a, b.clone(), 1e-7); + test1d(UPLO::Lower, a, b, 1e-7); } #[test] @@ -60,7 +61,7 @@ fn triangular_1d_lower_t() { let n = 3; let b: Array1 = random_vector(n); let a: Array2 = random_square(n).into_triangular(UPLO::Lower).reversed_axes(); - test1d(UPLO::Upper, a, b.clone(), 1e-7); + test1d(UPLO::Upper, a, b, 1e-7); } #[test] @@ -68,7 +69,7 @@ fn triangular_1d_upper_t() { let n = 3; let b: Array1 = random_vector(n); let a: Array2 = random_square(n).into_triangular(UPLO::Upper).reversed_axes(); - test1d(UPLO::Lower, a, b.clone(), 1e-7); + test1d(UPLO::Lower, a, b, 1e-7); } #[test] @@ -76,7 +77,7 @@ fn triangular_2d_upper() { let n = 3; let b: Array2 = random_square(n); let a: Array2 = random_square(n).into_triangular(UPLO::Upper); - test2d(UPLO::Upper, a, b.clone(), 1e-7); + test2d(UPLO::Upper, a, b, 1e-7); } #[test] @@ -84,7 +85,7 @@ fn triangular_2d_lower() { let n = 3; let b: Array2 = random_square(n); let a: Array2 = random_square(n).into_triangular(UPLO::Lower); - test2d(UPLO::Lower, a, b.clone(), 1e-7); + test2d(UPLO::Lower, a, b, 1e-7); } #[test] @@ -92,7 +93,7 @@ fn triangular_2d_lower_t() { let n = 3; let b: Array2 = random_square(n); let a: Array2 = random_square(n).into_triangular(UPLO::Lower).reversed_axes(); - test2d(UPLO::Upper, a, b.clone(), 1e-7); + test2d(UPLO::Upper, a, b, 1e-7); } #[test] @@ -100,7 +101,7 @@ fn triangular_2d_upper_t() { let n = 3; let b: Array2 = random_square(n); let a: Array2 = random_square(n).into_triangular(UPLO::Upper).reversed_axes(); - test2d(UPLO::Lower, a, b.clone(), 1e-7); + test2d(UPLO::Lower, a, b, 1e-7); } #[test] @@ -108,7 +109,7 @@ fn triangular_2d_upper_bt() { let n = 3; let b: Array2 = random_square(n).reversed_axes(); let a: Array2 = random_square(n).into_triangular(UPLO::Upper); - test2d(UPLO::Upper, a, b.clone(), 1e-7); + test2d(UPLO::Upper, a, b, 1e-7); } #[test] @@ -116,7 +117,7 @@ fn triangular_2d_lower_bt() { let n = 3; let b: Array2 = random_square(n).reversed_axes(); let a: Array2 = random_square(n).into_triangular(UPLO::Lower); - test2d(UPLO::Lower, a, b.clone(), 1e-7); + test2d(UPLO::Lower, a, b, 1e-7); } #[test] @@ -124,7 +125,7 @@ fn triangular_2d_lower_t_bt() { let n = 3; let b: Array2 = random_square(n).reversed_axes(); let a: Array2 = random_square(n).into_triangular(UPLO::Lower).reversed_axes(); - test2d(UPLO::Upper, a, b.clone(), 1e-7); + test2d(UPLO::Upper, a, b, 1e-7); } #[test] @@ -132,5 +133,5 @@ fn triangular_2d_upper_t_bt() { let n = 3; let b: Array2 = random_square(n).reversed_axes(); let a: Array2 = random_square(n).into_triangular(UPLO::Upper).reversed_axes(); - test2d(UPLO::Lower, a, b.clone(), 1e-7); + test2d(UPLO::Lower, a, b, 1e-7); } From af0bc7448334262571394ddc32044f12d1b15bdd Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Tue, 20 Jun 2017 16:13:17 +0900 Subject: [PATCH 11/18] Add utlities for Layout --- src/layout.rs | 53 +++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 49 insertions(+), 4 deletions(-) diff --git a/src/layout.rs b/src/layout.rs index 9d326e45..c17ca5d1 100644 --- a/src/layout.rs +++ b/src/layout.rs @@ -50,6 +50,37 @@ impl Layout { Layout::F(_) => c::Layout::ColumnMajor, } } + + pub fn same_order(&self, other: &Layout) -> bool { + match *self { + Layout::C(_) => { + match *other { + Layout::C(_) => true, + Layout::F(_) => false, + } + } + Layout::F(_) => { + match *other { + Layout::C(_) => false, + Layout::F(_) => true, + } + } + } + } + + pub fn as_shape(&self) -> Shape { + match *self { + Layout::C((row, col)) => (row as usize, col as usize).into_shape(), + Layout::F((col, row)) => (row as usize, col as usize).f().into_shape(), + } + } + + pub fn t(&self) -> Self { + match *self { + Layout::C((row, col)) => Layout::F((col, row)), + Layout::F((col, row)) => Layout::C((row, col)), + } + } } pub trait AllocatedArray { @@ -133,10 +164,14 @@ impl AllocatedArrayMut for ArrayBase pub fn reconstruct(l: Layout, a: Vec) -> Result> where S: DataOwned { - Ok(match l { - Layout::C((row, col)) => ArrayBase::from_shape_vec((row as usize, col as usize), a)?, - Layout::F((col, row)) => ArrayBase::from_shape_vec((row as usize, col as usize).f(), a)?, - }) + Ok(ArrayBase::from_shape_vec(l.as_shape(), a)?) +} + +pub fn uninitialized(l: Layout) -> ArrayBase + where A: Copy, + S: DataOwned +{ + unsafe { ArrayBase::uninitialized(l.as_shape()) } } pub fn replicate(a: &ArrayBase) -> ArrayBase @@ -149,3 +184,13 @@ pub fn replicate(a: &ArrayBase) -> ArrayBase b.assign(a); b } + +pub fn clone_with_layout(l: Layout, a: &ArrayBase) -> ArrayBase + where A: Copy, + Si: Data, + So: DataOwned + DataMut +{ + let mut b = uninitialized(l); + b.assign(a); + b +} From 2984b2402a90bad36af830773e37c19df03fe6a5 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Tue, 20 Jun 2017 16:19:06 +0900 Subject: [PATCH 12/18] Simplify same_order --- src/layout.rs | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/src/layout.rs b/src/layout.rs index c17ca5d1..638c1fc2 100644 --- a/src/layout.rs +++ b/src/layout.rs @@ -52,20 +52,7 @@ impl Layout { } pub fn same_order(&self, other: &Layout) -> bool { - match *self { - Layout::C(_) => { - match *other { - Layout::C(_) => true, - Layout::F(_) => false, - } - } - Layout::F(_) => { - match *other { - Layout::C(_) => false, - Layout::F(_) => true, - } - } - } + self.lapacke_layout() == other.lapacke_layout() } pub fn as_shape(&self) -> Shape { From 2040369fc2860145142426b9005614be9501c596 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 21 Jun 2017 13:52:21 +0900 Subject: [PATCH 13/18] Drop impl AllocatedArray for Array1 --- src/layout.rs | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/src/layout.rs b/src/layout.rs index 638c1fc2..08196710 100644 --- a/src/layout.rs +++ b/src/layout.rs @@ -121,31 +121,7 @@ impl AllocatedArrayMut for ArrayBase } } -impl AllocatedArray for ArrayBase - where S: Data -{ - type Elem = A; - - fn layout(&self) -> Result { - Ok(Layout::F((1, self.len() as i32))) - } - - fn square_layout(&self) -> Result { - Err(NotSquareError::new(1, self.len() as i32).into()) - } - - fn as_allocated(&self) -> Result<&[A]> { - Ok(self.as_slice_memory_order().ok_or(MemoryContError::new())?) - } -} -impl AllocatedArrayMut for ArrayBase - where S: DataMut -{ - fn as_allocated_mut(&mut self) -> Result<&mut [A]> { - Ok(self.as_slice_memory_order_mut().ok_or(MemoryContError::new())?) - } -} pub fn reconstruct(l: Layout, a: Vec) -> Result> From 4b36f972e58722a85dcd75c68e2a8bdaf692886a Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 21 Jun 2017 13:53:28 +0900 Subject: [PATCH 14/18] Add data_transpose --- src/layout.rs | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/src/layout.rs b/src/layout.rs index 08196710..861f2725 100644 --- a/src/layout.rs +++ b/src/layout.rs @@ -62,7 +62,7 @@ impl Layout { } } - pub fn t(&self) -> Self { + pub fn toggle_order(&self) -> Self { match *self { Layout::C((row, col)) => Layout::F((col, row)), Layout::F((col, row)) => Layout::C((row, col)), @@ -121,8 +121,26 @@ impl AllocatedArrayMut for ArrayBase } } +pub fn into_col_vec(a: ArrayBase) -> ArrayBase + where S: Data +{ + let n = a.len(); + a.into_shape((n, 1)).unwrap() +} +pub fn into_row_vec(a: ArrayBase) -> ArrayBase + where S: Data +{ + let n = a.len(); + a.into_shape((1, n)).unwrap() +} +pub fn into_vec(a: ArrayBase) -> ArrayBase + where S: Data +{ + let n = a.len(); + a.into_shape((n)).unwrap() +} pub fn reconstruct(l: Layout, a: Vec) -> Result> where S: DataOwned @@ -157,3 +175,13 @@ pub fn clone_with_layout(l: Layout, a: &ArrayBase) -> ArrayB b.assign(a); b } + +pub fn data_transpose(a: &mut ArrayBase) -> Result<&mut ArrayBase> + where A: Copy, + S: DataOwned + DataMut +{ + let l = a.layout()?.toggle_order(); + let new = clone_with_layout(l, a); + ::std::mem::replace(a, new); + Ok(a) +} From da0d5d1c7ccc892c82679df785d2d7b1841d34ee Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 21 Jun 2017 13:53:53 +0900 Subject: [PATCH 15/18] Revise impl SolveTriangular --- src/triangular.rs | 67 ++++++++++++++++++++++++++++++++--------------- 1 file changed, 46 insertions(+), 21 deletions(-) diff --git a/src/triangular.rs b/src/triangular.rs index 8a1bb22c..7c8002d4 100644 --- a/src/triangular.rs +++ b/src/triangular.rs @@ -15,49 +15,74 @@ pub trait SolveTriangular { fn solve_triangular(&self, UPLO, Diag, Rhs) -> Result; } -impl SolveTriangular> for ArrayBase - where A: LapackScalar, +impl SolveTriangular> for ArrayBase + where A: LapackScalar + Copy, Si: Data, - So: DataMut, - D: Dimension, - ArrayBase: AllocatedArrayMut + So: DataMut + DataOwned { - type Output = ArrayBase; + type Output = ArrayBase; - fn solve_triangular(&self, uplo: UPLO, diag: Diag, mut b: ArrayBase) -> Result { + fn solve_triangular(&self, uplo: UPLO, diag: Diag, mut b: ArrayBase) -> Result { self.solve_triangular(uplo, diag, &mut b)?; Ok(b) } } -impl<'a, A, Si, So, D> SolveTriangular<&'a mut ArrayBase> for ArrayBase - where A: LapackScalar, +impl<'a, A, Si, So> SolveTriangular<&'a mut ArrayBase> for ArrayBase + where A: LapackScalar + Copy, Si: Data, - So: DataMut, - D: Dimension, - ArrayBase: AllocatedArrayMut + So: DataMut + DataOwned { - type Output = &'a mut ArrayBase; + type Output = &'a mut ArrayBase; - fn solve_triangular(&self, uplo: UPLO, diag: Diag, mut b: &'a mut ArrayBase) -> Result { + fn solve_triangular(&self, uplo: UPLO, diag: Diag, mut b: &'a mut ArrayBase) -> Result { let la = self.layout()?; - let lb = b.layout()?; let a_ = self.as_allocated()?; + let lb = b.layout()?; + if !la.same_order(&lb) { + data_transpose(b)?; + } + let lb = b.layout()?; A::solve_triangular(la, lb, uplo, diag, a_, b.as_allocated_mut()?)?; Ok(b) } } -impl<'a, A, Si, So, D> SolveTriangular<&'a ArrayBase> for ArrayBase +impl<'a, A, Si, So> SolveTriangular<&'a ArrayBase> for ArrayBase + where A: LapackScalar + Copy, + Si: Data, + So: DataMut + DataOwned +{ + type Output = ArrayBase; + + fn solve_triangular(&self, uplo: UPLO, diag: Diag, b: &'a ArrayBase) -> Result { + let b = replicate(b); + self.solve_triangular(uplo, diag, b) + } +} + +impl SolveTriangular> for ArrayBase + where A: LapackScalar + Copy, + Si: Data, + So: DataMut + DataOwned +{ + type Output = ArrayBase; + + fn solve_triangular(&self, uplo: UPLO, diag: Diag, b: ArrayBase) -> Result { + let b = into_col_vec(b); + let b = self.solve_triangular(uplo, diag, b)?; + Ok(into_vec(b)) + } +} + +impl<'a, A, Si, So> SolveTriangular<&'a ArrayBase> for ArrayBase where A: LapackScalar + Copy, Si: Data, - So: DataMut + DataOwned, - D: Dimension, - ArrayBase: AllocatedArrayMut + So: DataMut + DataOwned { - type Output = ArrayBase; + type Output = ArrayBase; - fn solve_triangular(&self, uplo: UPLO, diag: Diag, b: &'a ArrayBase) -> Result { + fn solve_triangular(&self, uplo: UPLO, diag: Diag, b: &'a ArrayBase) -> Result { let b = replicate(b); self.solve_triangular(uplo, diag, b) } From 0d58102caeb9b6bbf851f592cbc13c71903554a7 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 21 Jun 2017 14:03:45 +0900 Subject: [PATCH 16/18] Update test for non-square b --- tests/triangular.rs | 43 ++++++++++++++++--------------------------- 1 file changed, 16 insertions(+), 27 deletions(-) diff --git a/tests/triangular.rs b/tests/triangular.rs index 499211b9..9c9893c7 100644 --- a/tests/triangular.rs +++ b/tests/triangular.rs @@ -34,9 +34,6 @@ fn test2d(uplo: UPLO, a: ArrayBase, b: ArrayBase = random_square(n); - let a: Array2 = random_square(n).into_triangular(UPLO::Upper); + let b: Array2 = random_matrix(3, 4); + let a: Array2 = random_square(3).into_triangular(UPLO::Upper); test2d(UPLO::Upper, a, b, 1e-7); } #[test] fn triangular_2d_lower() { - let n = 3; - let b: Array2 = random_square(n); - let a: Array2 = random_square(n).into_triangular(UPLO::Lower); + let b: Array2 = random_matrix(3, 4); + let a: Array2 = random_square(3).into_triangular(UPLO::Lower); test2d(UPLO::Lower, a, b, 1e-7); } #[test] fn triangular_2d_lower_t() { - let n = 3; - let b: Array2 = random_square(n); - let a: Array2 = random_square(n).into_triangular(UPLO::Lower).reversed_axes(); + let b: Array2 = random_matrix(3, 4); + let a: Array2 = random_square(3).into_triangular(UPLO::Lower).reversed_axes(); test2d(UPLO::Upper, a, b, 1e-7); } #[test] fn triangular_2d_upper_t() { - let n = 3; - let b: Array2 = random_square(n); - let a: Array2 = random_square(n).into_triangular(UPLO::Upper).reversed_axes(); + let b: Array2 = random_matrix(3, 4); + let a: Array2 = random_square(3).into_triangular(UPLO::Upper).reversed_axes(); test2d(UPLO::Lower, a, b, 1e-7); } #[test] fn triangular_2d_upper_bt() { - let n = 3; - let b: Array2 = random_square(n).reversed_axes(); - let a: Array2 = random_square(n).into_triangular(UPLO::Upper); + let b: Array2 = random_matrix(4, 3).reversed_axes(); + let a: Array2 = random_square(3).into_triangular(UPLO::Upper); test2d(UPLO::Upper, a, b, 1e-7); } #[test] fn triangular_2d_lower_bt() { - let n = 3; - let b: Array2 = random_square(n).reversed_axes(); - let a: Array2 = random_square(n).into_triangular(UPLO::Lower); + let b: Array2 = random_matrix(4, 3).reversed_axes(); + let a: Array2 = random_square(3).into_triangular(UPLO::Lower); test2d(UPLO::Lower, a, b, 1e-7); } #[test] fn triangular_2d_lower_t_bt() { - let n = 3; - let b: Array2 = random_square(n).reversed_axes(); - let a: Array2 = random_square(n).into_triangular(UPLO::Lower).reversed_axes(); + let b: Array2 = random_matrix(4, 3).reversed_axes(); + let a: Array2 = random_square(3).into_triangular(UPLO::Lower).reversed_axes(); test2d(UPLO::Upper, a, b, 1e-7); } #[test] fn triangular_2d_upper_t_bt() { - let n = 3; - let b: Array2 = random_square(n).reversed_axes(); - let a: Array2 = random_square(n).into_triangular(UPLO::Upper).reversed_axes(); + let b: Array2 = random_matrix(4, 3).reversed_axes(); + let a: Array2 = random_square(3).into_triangular(UPLO::Upper).reversed_axes(); test2d(UPLO::Lower, a, b, 1e-7); } From f50a6b82ebd6f4d71b8f4c4497a55c9d54070e5d Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 21 Jun 2017 15:49:17 +0900 Subject: [PATCH 17/18] Revise solve_triangular algorithm --- src/impl2/triangular.rs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/impl2/triangular.rs b/src/impl2/triangular.rs index 49ab92bc..ed37ba8e 100644 --- a/src/impl2/triangular.rs +++ b/src/impl2/triangular.rs @@ -33,11 +33,8 @@ impl Triangular_ for $scalar { fn solve_triangular(al: Layout, bl: Layout, uplo: UPLO, diag: Diag, a: &[Self], mut b: &mut [Self]) -> Result<()> { let (n, _) = al.size(); let lda = al.lda(); - let nrhs = bl.len(); - let ldb = match al { - Layout::C(_) => bl.len() as i32, - Layout::F(_) => bl.lda() as i32, - }; + let (_, nrhs) = bl.size(); + let ldb = bl.lda(); println!("al = {:?}", al); println!("bl = {:?}", bl); println!("n = {}", n); From e53a328c613e144a0521523779ad57589a947ac3 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Wed, 21 Jun 2017 16:09:50 +0900 Subject: [PATCH 18/18] Revise layout func --- src/layout.rs | 14 +++++++------- src/prelude.rs | 1 + tests/layout.rs | 35 +++++++++++++++++++++++++++++++++++ 3 files changed, 43 insertions(+), 7 deletions(-) create mode 100644 tests/layout.rs diff --git a/src/layout.rs b/src/layout.rs index 861f2725..93783853 100644 --- a/src/layout.rs +++ b/src/layout.rs @@ -9,7 +9,7 @@ pub type LEN = i32; pub type Col = i32; pub type Row = i32; -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq)] pub enum Layout { C((Row, LDA)), F((Col, LDA)), @@ -87,15 +87,15 @@ impl AllocatedArray for ArrayBase type Elem = A; fn layout(&self) -> Result { + let shape = self.shape(); let strides = self.strides(); - if ::std::cmp::min(strides[0], strides[1]) != 1 { - return Err(StrideError::new(strides[0], strides[1]).into()); + if shape[0] == strides[1] as usize { + return Ok(Layout::F((self.cols() as i32, self.rows() as i32))); } - if strides[0] > strides[1] { - Ok(Layout::C((self.rows() as i32, self.cols() as i32))) - } else { - Ok(Layout::F((self.cols() as i32, self.rows() as i32))) + if shape[1] == strides[0] as usize { + return Ok(Layout::C((self.rows() as i32, self.cols() as i32))); } + Err(StrideError::new(strides[0], strides[1]).into()) } fn square_layout(&self) -> Result { diff --git a/src/prelude.rs b/src/prelude.rs index 41a8e689..9f1910d4 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -2,6 +2,7 @@ pub use assert::*; pub use generate::*; pub use types::*; +pub use layout::*; pub use cholesky::*; pub use eigh::*; diff --git a/tests/layout.rs b/tests/layout.rs new file mode 100644 index 00000000..f0ff8d91 --- /dev/null +++ b/tests/layout.rs @@ -0,0 +1,35 @@ + +extern crate ndarray; +extern crate ndarray_linalg; + +use ndarray::*; +use ndarray_linalg::prelude::*; +use ndarray_linalg::layout::Layout; + +#[test] +fn layout_c_3x1() { + let a: Array2 = Array::zeros((3, 1)); + println!("a = {:?}", &a); + assert_eq!(a.layout().unwrap(), Layout::C((3, 1))); +} + +#[test] +fn layout_f_3x1() { + let a: Array2 = Array::zeros((3, 1).f()); + println!("a = {:?}", &a); + assert_eq!(a.layout().unwrap(), Layout::F((1, 3))); +} + +#[test] +fn layout_c_3x2() { + let a: Array2 = Array::zeros((3, 2)); + println!("a = {:?}", &a); + assert_eq!(a.layout().unwrap(), Layout::C((3, 2))); +} + +#[test] +fn layout_f_3x2() { + let a: Array2 = Array::zeros((3, 2).f()); + println!("a = {:?}", &a); + assert_eq!(a.layout().unwrap(), Layout::F((2, 3))); +}