diff --git a/src/bin/main.rs b/src/bin/main.rs index 7e3895f9..1b63a595 100644 --- a/src/bin/main.rs +++ b/src/bin/main.rs @@ -7,7 +7,7 @@ use ndarray_linalg::prelude::*; fn main() { let a = arr2(&[[3.0, 1.0, 1.0], [1.0, 3.0, 1.0], [1.0, 1.0, 3.0]]); - let (e, vecs) = a.clone().eigh().unwrap(); + let (e, vecs): (Array1<_>, Array2<_>) = a.clone().eigh(UPLO::Upper).unwrap(); println!("eigenvalues = \n{:?}", e); println!("V = \n{:?}", vecs); let av = a.dot(&vecs); diff --git a/src/cholesky.rs b/src/cholesky.rs new file mode 100644 index 00000000..2a61f5eb --- /dev/null +++ b/src/cholesky.rs @@ -0,0 +1,46 @@ + +use ndarray::*; +use num_traits::Zero; + +use super::error::*; +use super::layout::*; +use super::triangular::IntoTriangular; + +use impl2::LapackScalar; +pub use impl2::UPLO; + +pub trait Cholesky { + fn cholesky(self, UPLO) -> Result; +} + +impl Cholesky> for ArrayBase + where A: LapackScalar + Zero, + S: DataMut +{ + fn cholesky(mut self, uplo: UPLO) -> Result> { + A::cholesky(self.square_layout()?, uplo, self.as_allocated_mut()?)?; + Ok(self.into_triangular(uplo)) + } +} + +impl<'a, A, S> Cholesky<&'a mut ArrayBase> for &'a mut ArrayBase + where A: LapackScalar + Zero, + S: DataMut +{ + fn cholesky(mut self, uplo: UPLO) -> Result<&'a mut ArrayBase> { + A::cholesky(self.square_layout()?, uplo, self.as_allocated_mut()?)?; + Ok(self.into_triangular(uplo)) + } +} + +impl<'a, A, Si, So> Cholesky> for &'a ArrayBase + where A: LapackScalar + Copy + Zero, + Si: Data, + So: DataMut + DataOwned +{ + fn cholesky(self, uplo: UPLO) -> Result> { + let mut a = replicate(self); + A::cholesky(a.square_layout()?, uplo, a.as_allocated_mut()?)?; + Ok(a.into_triangular(uplo)) + } +} diff --git a/src/eigh.rs b/src/eigh.rs new file mode 100644 index 00000000..869bca2a --- /dev/null +++ b/src/eigh.rs @@ -0,0 +1,85 @@ + +use ndarray::*; + +use super::error::*; +use super::layout::*; + +use impl2::LapackScalar; +pub use impl2::UPLO; + +pub trait Eigh { + fn eigh(self, UPLO) -> Result<(EigVal, EigVec)>; +} + +impl Eigh, ArrayBase> for ArrayBase + where A: LapackScalar, + S: DataMut, + Se: DataOwned +{ + fn eigh(mut self, uplo: UPLO) -> Result<(ArrayBase, ArrayBase)> { + let s = A::eigh(true, self.square_layout()?, uplo, self.as_allocated_mut()?)?; + Ok((ArrayBase::from_vec(s), self)) + } +} + +impl<'a, A, S, Se, So> Eigh, ArrayBase> for &'a ArrayBase + where A: LapackScalar + Copy, + S: Data, + Se: DataOwned, + So: DataOwned + DataMut +{ + fn eigh(self, uplo: UPLO) -> Result<(ArrayBase, ArrayBase)> { + let mut a = replicate(self); + let s = A::eigh(true, a.square_layout()?, uplo, a.as_allocated_mut()?)?; + Ok((ArrayBase::from_vec(s), a)) + } +} + +impl<'a, A, S, Se> Eigh, &'a mut ArrayBase> for &'a mut ArrayBase + where A: LapackScalar, + S: DataMut, + Se: DataOwned +{ + fn eigh(mut self, uplo: UPLO) -> Result<(ArrayBase, &'a mut ArrayBase)> { + let s = A::eigh(true, self.square_layout()?, uplo, self.as_allocated_mut()?)?; + Ok((ArrayBase::from_vec(s), self)) + } +} + +pub trait EigValsh { + fn eigvalsh(self, UPLO) -> Result; +} + +impl EigValsh> for ArrayBase + where A: LapackScalar, + S: DataMut, + Se: DataOwned +{ + fn eigvalsh(mut self, uplo: UPLO) -> Result> { + let s = A::eigh(false, self.square_layout()?, uplo, self.as_allocated_mut()?)?; + Ok(ArrayBase::from_vec(s)) + } +} + +impl<'a, A, S, Se> EigValsh> for &'a ArrayBase + where A: LapackScalar + Copy, + S: Data, + Se: DataOwned +{ + fn eigvalsh(self, uplo: UPLO) -> Result> { + let mut a = self.to_owned(); + let s = A::eigh(false, a.square_layout()?, uplo, a.as_allocated_mut()?)?; + Ok(ArrayBase::from_vec(s)) + } +} + +impl<'a, A, S, Se> EigValsh> for &'a mut ArrayBase + where A: LapackScalar, + S: DataMut, + Se: DataOwned +{ + fn eigvalsh(mut self, uplo: UPLO) -> Result> { + let s = A::eigh(true, self.square_layout()?, uplo, self.as_allocated_mut()?)?; + Ok(ArrayBase::from_vec(s)) + } +} diff --git a/src/hermite.rs b/src/hermite.rs deleted file mode 100644 index c0ecc996..00000000 --- a/src/hermite.rs +++ /dev/null @@ -1,91 +0,0 @@ -//! Define trait for Hermite matrices - -use ndarray::{Ix2, Array, RcArray}; -use lapack::c::Layout; - -use super::matrix::{Matrix, MFloat}; -use super::square::SquareMatrix; -use super::error::LinalgError; -use super::impls::eigh::ImplEigh; -use super::impls::cholesky::ImplCholesky; - -pub trait HMFloat: ImplEigh + ImplCholesky + MFloat {} -impl HMFloat for A {} - -/// Methods for Hermite matrix -pub trait HermiteMatrix: SquareMatrix + Matrix { - /// eigenvalue decomposition - fn eigh(self) -> Result<(Self::Vector, Self), LinalgError>; - /// symmetric square root of Hermite matrix - fn ssqrt(self) -> Result; - /// Cholesky factorization - fn cholesky(self) -> Result; - /// calc determinant using Cholesky factorization - fn deth(self) -> Result; -} - -impl HermiteMatrix for Array { - fn eigh(self) -> Result<(Self::Vector, Self), LinalgError> { - self.check_square()?; - let layout = self.layout()?; - let (rows, cols) = self.size(); - let (w, a) = ImplEigh::eigh(layout, rows, self.into_raw_vec())?; - let ea = Array::from_vec(w); - let va = match layout { - Layout::ColumnMajor => Array::from_vec(a).into_shape((rows, cols)).unwrap().reversed_axes(), - Layout::RowMajor => Array::from_vec(a).into_shape((rows, cols)).unwrap(), - }; - Ok((ea, va)) - } - fn ssqrt(self) -> Result { - let (n, _) = self.size(); - let (e, v) = self.eigh()?; - let mut res = Array::zeros((n, n)); - for i in 0..n { - for j in 0..n { - res[(i, j)] = e[i].sqrt() * v[(j, i)]; - } - } - Ok(v.dot(&res)) - } - fn cholesky(self) -> Result { - self.check_square()?; - let (n, _) = self.size(); - let layout = self.layout()?; - let a = ImplCholesky::cholesky(layout, n, self.into_raw_vec())?; - let mut c = match layout { - Layout::RowMajor => Array::from_vec(a).into_shape((n, n)).unwrap(), - Layout::ColumnMajor => Array::from_vec(a).into_shape((n, n)).unwrap().reversed_axes(), - }; - for ((i, j), val) in c.indexed_iter_mut() { - if i > j { - *val = A::zero(); - } - } - Ok(c) - } - fn deth(self) -> Result { - let (n, _) = self.size(); - let c = self.cholesky()?; - let rt = (0..n).map(|i| c[(i, i)]).fold(A::one(), |det, c| det * c); - Ok(rt * rt) - } -} - -impl HermiteMatrix for RcArray { - fn eigh(self) -> Result<(Self::Vector, Self), LinalgError> { - let (e, v) = self.into_owned().eigh()?; - Ok((e.into_shared(), v.into_shared())) - } - fn ssqrt(self) -> Result { - let s = self.into_owned().ssqrt()?; - Ok(s.into_shared()) - } - fn cholesky(self) -> Result { - let s = self.into_owned().cholesky()?; - Ok(s.into_shared()) - } - fn deth(self) -> Result { - self.into_owned().deth() - } -} diff --git a/src/impl2/cholesky.rs b/src/impl2/cholesky.rs new file mode 100644 index 00000000..9b0c9ecf --- /dev/null +++ b/src/impl2/cholesky.rs @@ -0,0 +1,29 @@ +//! implement Cholesky decomposition + +use lapack::c; + +use types::*; +use error::*; +use layout::Layout; + +use super::{into_result, UPLO}; + +pub trait Cholesky_: Sized { + fn cholesky(Layout, UPLO, a: &mut [Self]) -> Result<()>; +} + +macro_rules! impl_cholesky { + ($scalar:ty, $potrf:path) => { +impl Cholesky_ for $scalar { + fn cholesky(l: Layout, uplo: UPLO, mut a: &mut [Self]) -> Result<()> { + let (n, _) = l.size(); + let info = $potrf(l.lapacke_layout(), uplo as u8, n, &mut a, n); + into_result(info, ()) + } +} +}} // end macro_rules + +impl_cholesky!(f64, c::dpotrf); +impl_cholesky!(f32, c::spotrf); +impl_cholesky!(c64, c::zpotrf); +impl_cholesky!(c32, c::cpotrf); diff --git a/src/impl2/eigh.rs b/src/impl2/eigh.rs new file mode 100644 index 00000000..a9d6c610 --- /dev/null +++ b/src/impl2/eigh.rs @@ -0,0 +1,31 @@ + +use lapack::c; +use num_traits::Zero; + +use types::*; +use error::*; +use layout::Layout; + +use super::{into_result, UPLO}; + +pub trait Eigh_: AssociatedReal { + fn eigh(calc_eigenvec: bool, Layout, UPLO, a: &mut [Self]) -> Result>; +} + +macro_rules! impl_eigh { + ($scalar:ty, $ev:path) => { +impl Eigh_ for $scalar { + fn eigh(calc_v: bool, l: Layout, uplo: UPLO, mut a: &mut [Self]) -> Result> { + let (n, _) = l.size(); + let jobz = if calc_v { b'V' } else { b'N' }; + let mut w = vec![Self::Real::zero(); n as usize]; + let info = $ev(l.lapacke_layout(), jobz, uplo as u8, n, &mut a, n, &mut w); + into_result(info, w) + } +} +}} // impl_eigh! + +impl_eigh!(f64, c::dsyev); +impl_eigh!(f32, c::ssyev); +impl_eigh!(c64, c::zheev); +impl_eigh!(c32, c::cheev); diff --git a/src/impl2/mod.rs b/src/impl2/mod.rs index a05cf59c..91b0bcf0 100644 --- a/src/impl2/mod.rs +++ b/src/impl2/mod.rs @@ -3,16 +3,20 @@ pub mod opnorm; pub mod qr; pub mod svd; pub mod solve; +pub mod cholesky; +pub mod eigh; pub use self::opnorm::*; pub use self::qr::*; pub use self::svd::*; pub use self::solve::*; +pub use self::cholesky::*; +pub use self::eigh::*; use super::error::*; -pub trait LapackScalar: OperatorNorm_ + QR_ + SVD_ + Solve_ {} -impl LapackScalar for A where A: OperatorNorm_ + QR_ + SVD_ + Solve_ {} +pub trait LapackScalar: OperatorNorm_ + QR_ + SVD_ + Solve_ + Cholesky_ + Eigh_ {} +impl LapackScalar for A where A: OperatorNorm_ + QR_ + SVD_ + Solve_ + Cholesky_ + Eigh_ {} pub fn into_result(info: i32, val: T) -> Result { if info == 0 { @@ -21,3 +25,10 @@ pub fn into_result(info: i32, val: T) -> Result { Err(LapackError::new(info).into()) } } + +#[derive(Debug, Clone, Copy)] +#[repr(u8)] +pub enum UPLO { + Upper = b'U', + Lower = b'L', +} diff --git a/src/impls/cholesky.rs b/src/impls/cholesky.rs deleted file mode 100644 index eb1c8b20..00000000 --- a/src/impls/cholesky.rs +++ /dev/null @@ -1,25 +0,0 @@ -//! Implements cholesky decomposition - -use lapack::c::*; -use error::LapackError; - -pub trait ImplCholesky: Sized { - fn cholesky(layout: Layout, n: usize, a: Vec) -> Result, LapackError>; -} - -macro_rules! impl_cholesky { - ($scalar:ty, $potrf:path) => { -impl ImplCholesky for $scalar { - fn cholesky(layout: Layout, n: usize, mut a: Vec) -> Result, LapackError> { - let info = $potrf(layout, b'U', n as i32, &mut a, n as i32); - if info == 0 { - Ok(a) - } else { - Err(From::from(info)) - } - } -} -}} // end macro_rules - -impl_cholesky!(f64, dpotrf); -impl_cholesky!(f32, spotrf); diff --git a/src/impls/eigh.rs b/src/impls/eigh.rs deleted file mode 100644 index c099d097..00000000 --- a/src/impls/eigh.rs +++ /dev/null @@ -1,28 +0,0 @@ -//! Implement eigenvalue decomposition of Hermite matrix - -use lapack::c::*; -use num_traits::Zero; - -use error::LapackError; - -pub trait ImplEigh: Sized { - fn eigh(layout: Layout, n: usize, a: Vec) -> Result<(Vec, Vec), LapackError>; -} - -macro_rules! impl_eigh { - ($scalar:ty, $syev:path) => { -impl ImplEigh for $scalar { - fn eigh(layout: Layout, n: usize, mut a: Vec) -> Result<(Vec, Vec), LapackError> { - let mut w = vec![Self::zero(); n]; - let info = $syev(layout, b'V', b'U', n as i32, &mut a, n as i32, &mut w); - if info == 0 { - Ok((w, a)) - } else { - Err(From::from(info)) - } - } -} -}} // end macro_rules - -impl_eigh!(f64, dsyev); -impl_eigh!(f32, ssyev); diff --git a/src/impls/mod.rs b/src/impls/mod.rs index 7b7be4f2..7865d1a0 100644 --- a/src/impls/mod.rs +++ b/src/impls/mod.rs @@ -1,5 +1,3 @@ //! Implement trait bindings of LAPACK pub mod outer; -pub mod eigh; pub mod solve; -pub mod cholesky; diff --git a/src/layout.rs b/src/layout.rs index 6c098115..8ea71138 100644 --- a/src/layout.rs +++ b/src/layout.rs @@ -105,3 +105,14 @@ pub fn reconstruct(l: Layout, a: Vec) -> Result> Layout::F((col, row)) => ArrayBase::from_shape_vec((row as usize, col as usize).f(), a)?, }) } + +pub fn replicate(a: &ArrayBase) -> ArrayBase + where A: Copy, + Sv: Data, + So: DataOwned + DataMut, + D: Dimension +{ + let mut b = unsafe { ArrayBase::uninitialized(a.dim()) }; + b.assign(a); + b +} diff --git a/src/lib.rs b/src/lib.rs index 2de86470..df9e17be 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -52,11 +52,12 @@ pub mod qr; pub mod svd; pub mod opnorm; pub mod solve; +pub mod cholesky; +pub mod eigh; pub mod vector; pub mod matrix; pub mod square; -pub mod hermite; pub mod triangular; pub mod util; diff --git a/src/prelude.rs b/src/prelude.rs index e18421f4..e9899e16 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -1,7 +1,6 @@ pub use vector::Norm; pub use matrix::Matrix; pub use square::SquareMatrix; -pub use hermite::HermiteMatrix; pub use triangular::*; pub use util::*; pub use assert::*; @@ -10,3 +9,5 @@ pub use qr::*; pub use svd::*; pub use opnorm::*; pub use solve::*; +pub use eigh::*; +pub use cholesky::*; diff --git a/src/solve.rs b/src/solve.rs index 3ac60ad3..0bf85e99 100644 --- a/src/solve.rs +++ b/src/solve.rs @@ -56,12 +56,13 @@ impl Factorize for ArrayBase } } -impl<'a, A, S> Factorize> for &'a ArrayBase - where A: LapackScalar + Clone, - S: Data +impl<'a, A, Si, So> Factorize for &'a ArrayBase + where A: LapackScalar + Copy, + Si: Data, + So: DataOwned + DataMut { - fn factorize(self) -> Result>> { - let mut a = self.to_owned(); + fn factorize(self) -> Result> { + let mut a: ArrayBase = replicate(self); let ipiv = A::lu(a.layout()?, a.as_allocated_mut()?)?; Ok(Factorized { a: a, ipiv: ipiv }) } @@ -81,11 +82,12 @@ impl Inverse> for ArrayBase } } -impl<'a, A, S> Inverse> for &'a ArrayBase - where A: LapackScalar + Clone, - S: Data +impl<'a, A, Si, So> Inverse> for &'a ArrayBase + where A: LapackScalar + Copy, + Si: Data, + So: DataOwned + DataMut { - fn inv(self) -> Result> { + fn inv(self) -> Result> { let f = self.factorize()?; f.into_inverse() } diff --git a/src/triangular.rs b/src/triangular.rs index 2852ce2a..99bf7220 100644 --- a/src/triangular.rs +++ b/src/triangular.rs @@ -1,6 +1,9 @@ //! Define methods for triangular matrices use ndarray::*; +use num_traits::Zero; +use super::impl2::UPLO; + use super::matrix::{Matrix, MFloat}; use super::square::SquareMatrix; use super::error::LinalgError; @@ -87,7 +90,7 @@ impl SolveTriangular> for RcArray { } } -pub fn drop_upper(mut a: ArrayBase) -> ArrayBase +pub fn drop_upper(mut a: ArrayBase) -> ArrayBase where S: DataMut { for ((i, j), val) in a.indexed_iter_mut() { @@ -98,7 +101,7 @@ pub fn drop_upper(mut a: ArrayBase) -> ArrayBase a } -pub fn drop_lower(mut a: ArrayBase) -> ArrayBase +pub fn drop_lower(mut a: ArrayBase) -> ArrayBase where S: DataMut { for ((i, j), val) in a.indexed_iter_mut() { @@ -108,3 +111,42 @@ pub fn drop_lower(mut a: ArrayBase) -> ArrayBase } a } + +pub trait IntoTriangular { + fn into_triangular(self, UPLO) -> T; +} + +impl<'a, A, S> IntoTriangular<&'a mut ArrayBase> for &'a mut ArrayBase + where A: Zero, + S: DataMut +{ + fn into_triangular(self, uplo: UPLO) -> &'a mut ArrayBase { + match uplo { + UPLO::Upper => { + for ((i, j), val) in self.indexed_iter_mut() { + if i > j { + *val = A::zero(); + } + } + } + UPLO::Lower => { + for ((i, j), val) in self.indexed_iter_mut() { + if i < j { + *val = A::zero(); + } + } + } + } + self + } +} + +impl IntoTriangular> for ArrayBase + where A: Zero, + S: DataMut +{ + fn into_triangular(mut self, uplo: UPLO) -> ArrayBase { + (&mut self).into_triangular(uplo); + self + } +} diff --git a/tests/cholesky.rs b/tests/cholesky.rs index 35f63c6f..b74fec0b 100644 --- a/tests/cholesky.rs +++ b/tests/cholesky.rs @@ -1,30 +1,37 @@ -include!("header.rs"); -macro_rules! impl_test { - ($modname:ident, $clone:ident) => { -mod $modname { - use super::random_hermite; - use ndarray_linalg::prelude::*; - #[test] - fn cholesky() { - let a = random_hermite(3); - println!("a = \n{:?}", a); - let c = a.$clone().cholesky().unwrap(); - println!("c = \n{:?}", c); - println!("cc = \n{:?}", c.t().dot(&c)); - assert_close_l2!(&c.t().dot(&c), &a, 1e-7); - } - #[test] - fn cholesky_t() { - let a = random_hermite(3); - println!("a = \n{:?}", a); - let c = a.$clone().cholesky().unwrap(); - println!("c = \n{:?}", c); - println!("cc = \n{:?}", c.t().dot(&c)); - assert_close_l2!(&c.t().dot(&c), &a, 1e-7); - } +extern crate rand_extra; +extern crate ndarray; +extern crate ndarray_rand; +#[macro_use] +extern crate ndarray_linalg; + +use rand_extra::*; +use ndarray::*; +use ndarray_rand::RandomExt; +use ndarray_linalg::prelude::*; + +pub fn random_hermite(n: usize) -> Array { + let r_dist = RealNormal::new(0., 1.); + let a = Array::::random((n, n), r_dist); + a.dot(&a.t()) } -}} // impl_test -impl_test!(owned, clone); -impl_test!(shared, to_shared); +#[test] +fn cholesky() { + let a = random_hermite(3); + println!("a = \n{:?}", a); + let c: Array2<_> = (&a).cholesky(UPLO::Upper).unwrap(); + println!("c = \n{:?}", c); + println!("cc = \n{:?}", c.t().dot(&c)); + assert_close_l2!(&c.t().dot(&c), &a, 1e-7); +} + +#[test] +fn cholesky_t() { + let a = random_hermite(3); + println!("a = \n{:?}", a); + let c: Array2<_> = (&a).cholesky(UPLO::Upper).unwrap(); + println!("c = \n{:?}", c); + println!("cc = \n{:?}", c.t().dot(&c)); + assert_close_l2!(&c.t().dot(&c), &a, 1e-7); +} diff --git a/tests/det.rs b/tests/det.rs deleted file mode 100644 index 672c1221..00000000 --- a/tests/det.rs +++ /dev/null @@ -1,20 +0,0 @@ -include!("header.rs"); - -macro_rules! impl_test{ - ($modname:ident, $clone:ident) => { -mod $modname { - use super::random_hermite; - use ndarray_linalg::prelude::*; - #[test] - fn deth() { - let a = random_hermite(3); - let (e, _) = a.$clone().eigh().unwrap(); - let deth = a.$clone().deth().unwrap(); - let det_eig = e.iter().fold(1.0, |x, y| x * y); - assert_rclose!(deth, det_eig, 1.0e-7); - } -} -}} // impl_test - -impl_test!(owned, clone); -impl_test!(shared, to_shared); diff --git a/tests/eigh.rs b/tests/eigh.rs index e4b317ad..292c1705 100644 --- a/tests/eigh.rs +++ b/tests/eigh.rs @@ -1,32 +1,29 @@ -include!("header.rs"); -macro_rules! impl_test { - ($modname:ident, $clone:ident) => { -mod $modname { - use ndarray::prelude::*; - use ndarray_linalg::prelude::*; - #[test] - fn eigen_vector_manual() { - let a = arr2(&[[3.0, 1.0, 1.0], [1.0, 3.0, 1.0], [1.0, 1.0, 3.0]]); - let (e, vecs) = a.$clone().eigh().unwrap(); - assert_close_l2!(&e, &arr1(&[2.0, 2.0, 5.0]), 1.0e-7); - for (i, v) in vecs.axis_iter(Axis(1)).enumerate() { - let av = a.dot(&v); - let ev = v.mapv(|x| e[i] * x); - assert_close_l2!(&av, &ev, 1.0e-7); - } - } - #[test] - fn diagonalize() { - let a = arr2(&[[3.0, 1.0, 1.0], [1.0, 3.0, 1.0], [1.0, 1.0, 3.0]]); - let (e, vecs) = a.$clone().eigh().unwrap(); - let s = vecs.t().dot(&a).dot(&vecs); - for i in 0..3 { - assert_rclose!(e[i], s[(i, i)], 1e-7); - } +extern crate ndarray; +#[macro_use] +extern crate ndarray_linalg; + +use ndarray::prelude::*; +use ndarray_linalg::prelude::*; + +#[test] +fn eigen_vector_manual() { + let a = arr2(&[[3.0, 1.0, 1.0], [1.0, 3.0, 1.0], [1.0, 1.0, 3.0]]); + let (e, vecs): (Array1<_>, Array2<_>) = (&a).eigh(UPLO::Upper).unwrap(); + assert_close_l2!(&e, &arr1(&[2.0, 2.0, 5.0]), 1.0e-7); + for (i, v) in vecs.axis_iter(Axis(1)).enumerate() { + let av = a.dot(&v); + let ev = v.mapv(|x| e[i] * x); + assert_close_l2!(&av, &ev, 1.0e-7); } } -}} // impl_test -impl_test!(owned, clone); -impl_test!(shared, to_shared); +#[test] +fn diagonalize() { + let a = arr2(&[[3.0, 1.0, 1.0], [1.0, 3.0, 1.0], [1.0, 1.0, 3.0]]); + let (e, vecs): (Array1<_>, Array2<_>) = (&a).eigh(UPLO::Upper).unwrap(); + let s = vecs.t().dot(&a).dot(&vecs); + for i in 0..3 { + assert_rclose!(e[i], s[(i, i)], 1e-7); + } +} diff --git a/tests/ssqrt.rs b/tests/ssqrt.rs deleted file mode 100644 index dcc46e0f..00000000 --- a/tests/ssqrt.rs +++ /dev/null @@ -1,26 +0,0 @@ -include!("header.rs"); - -macro_rules! impl_test{ - ($modname:ident, $clone:ident) => { -mod $modname { - use super::random_hermite; - use ndarray_linalg::prelude::*; - #[test] - fn ssqrt() { - let a = random_hermite(3); - let ar = a.$clone().ssqrt().unwrap(); - assert_close_l2!(&ar.clone().t(), &ar, 1e-7; "not symmetric"); - assert_close_l2!(&ar.dot(&ar), &a, 1e-7; "not sqrt"); - } - #[test] - fn ssqrt_t() { - let a = random_hermite(3).reversed_axes(); - let ar = a.$clone().ssqrt().unwrap(); - assert_close_l2!(&ar.clone().t(), &ar, 1e-7; "not symmetric"); - assert_close_l2!(&ar.dot(&ar), &a, 1e-7; "not sqrt"); - } -} -}} // impl_test - -impl_test!(owned, clone); -impl_test!(shared, to_shared);