From dd495ff53914a605a856e03a9418d0c990ca26a9 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Fri, 26 Apr 2019 01:45:39 +0900 Subject: [PATCH 01/10] Use cauchy 0.2 --- Cargo.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Cargo.toml b/Cargo.toml index 4e5b92c2..4aef8017 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,7 @@ serde-1 = ["ndarray/serde-1", "num-complex/serde"] lapacke = "0.2" num-traits = "0.2" rand = "0.6" +cauchy = "0.2" [dependencies.num-complex] version = "0.2" From 32e56e08ccbc83f7db34917ecc1bd9f3423a3456 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Thu, 25 Apr 2019 03:29:16 +0900 Subject: [PATCH 02/10] Rename trait LapackScalar -> Lapack --- src/lapack_traits/mod.rs | 10 +++++----- src/types.rs | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/lapack_traits/mod.rs b/src/lapack_traits/mod.rs index 693cfc3d..a29c4de8 100644 --- a/src/lapack_traits/mod.rs +++ b/src/lapack_traits/mod.rs @@ -23,12 +23,12 @@ use super::types::*; pub type Pivot = Vec; -pub trait LapackScalar: OperatorNorm_ + QR_ + SVD_ + Solve_ + Solveh_ + Cholesky_ + Eigh_ + Triangular_ {} +pub trait Lapack: OperatorNorm_ + QR_ + SVD_ + Solve_ + Solveh_ + Cholesky_ + Eigh_ + Triangular_ {} -impl LapackScalar for f32 {} -impl LapackScalar for f64 {} -impl LapackScalar for c32 {} -impl LapackScalar for c64 {} +impl Lapack for f32 {} +impl Lapack for f64 {} +impl Lapack for c32 {} +impl Lapack for c64 {} pub fn into_result(return_code: i32, val: T) -> Result { if return_code == 0 { diff --git a/src/types.rs b/src/types.rs index 197a9ab8..bd6c52d9 100644 --- a/src/types.rs +++ b/src/types.rs @@ -9,7 +9,7 @@ use std::fmt::Debug; use std::iter::Sum; use std::ops::Neg; -use super::lapack_traits::LapackScalar; +use super::lapack_traits::Lapack; pub use num_complex::Complex32 as c32; pub use num_complex::Complex64 as c64; @@ -27,7 +27,7 @@ pub use num_complex::Complex64 as c64; /// - [`randn`](trait.RandNormal.html#tymethod.randn) /// pub trait Scalar: - LapackScalar + Lapack + LinalgScalar + AssociatedReal + AssociatedComplex From 5be2ed28818100dd4ac1ac377f29aff0a4180ad6 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Thu, 25 Apr 2019 04:18:30 +0900 Subject: [PATCH 03/10] Use cauchy::Scalar --- src/assert.rs | 6 +- src/cholesky.rs | 42 ++--- src/convert.rs | 4 +- src/eigh.rs | 18 +-- src/generate.rs | 12 +- src/lapack_traits/eigh.rs | 2 +- src/lapack_traits/opnorm.rs | 2 +- src/lapack_traits/solve.rs | 2 +- src/lapack_traits/svd.rs | 4 +- src/lib.rs | 4 +- src/norm.rs | 8 +- src/operator.rs | 4 +- src/opnorm.rs | 4 +- src/qr.rs | 10 +- src/solve.rs | 44 ++--- src/solveh.rs | 60 +++---- src/svd.rs | 6 +- src/triangular.rs | 12 +- src/types.rs | 309 +----------------------------------- 19 files changed, 119 insertions(+), 434 deletions(-) diff --git a/src/assert.rs b/src/assert.rs index 63896a11..27a6ece4 100644 --- a/src/assert.rs +++ b/src/assert.rs @@ -32,7 +32,7 @@ pub fn close_max( atol: A::Real, ) -> Result where - A: Scalar, + A: Scalar + Lapack, S1: Data, S2: Data, D: Dimension, @@ -52,7 +52,7 @@ pub fn close_l1( rtol: A::Real, ) -> Result where - A: Scalar, + A: Scalar + Lapack, S1: Data, S2: Data, D: Dimension, @@ -72,7 +72,7 @@ pub fn close_l2( rtol: A::Real, ) -> Result where - A: Scalar, + A: Scalar + Lapack, S1: Data, S2: Data, D: Dimension, diff --git a/src/cholesky.rs b/src/cholesky.rs index bf7820ce..49cc467a 100644 --- a/src/cholesky.rs +++ b/src/cholesky.rs @@ -66,7 +66,7 @@ pub struct CholeskyFactorized { impl CholeskyFactorized where - A: Scalar, + A: Scalar + Lapack, S: DataMut, { /// Returns `L` from the Cholesky decomposition `A = L * L^H`. @@ -96,10 +96,10 @@ where impl DeterminantC for CholeskyFactorized where - A: Absolute, + A: Scalar + Lapack, S: Data, { - type Output = ::Real; + type Output = ::Real; fn detc(&self) -> Self::Output { self.ln_detc().exp() @@ -109,17 +109,17 @@ where self.factor .diag() .iter() - .map(|elem| elem.abs_sqr().ln()) + .map(|elem| elem.square().ln()) .sum::() } } impl DeterminantCInto for CholeskyFactorized where - A: Absolute, + A: Scalar + Lapack, S: Data, { - type Output = ::Real; + type Output = ::Real; fn detc_into(self) -> Self::Output { self.detc() @@ -132,7 +132,7 @@ where impl InverseC for CholeskyFactorized where - A: Scalar, + A: Scalar + Lapack, S: Data, { type Output = Array2; @@ -148,7 +148,7 @@ where impl InverseCInto for CholeskyFactorized where - A: Scalar, + A: Scalar + Lapack, S: DataMut, { type Output = ArrayBase; @@ -163,7 +163,7 @@ where impl SolveC for CholeskyFactorized where - A: Scalar, + A: Scalar + Lapack, S: Data, { fn solvec_inplace<'a, Sb>(&self, b: &'a mut ArrayBase) -> Result<&'a mut ArrayBase> @@ -226,7 +226,7 @@ pub trait CholeskyInplace { impl Cholesky for ArrayBase where - A: Scalar, + A: Scalar + Lapack, S: Data, { type Output = Array2; @@ -239,7 +239,7 @@ where impl CholeskyInto for ArrayBase where - A: Scalar, + A: Scalar + Lapack, S: DataMut, { type Output = Self; @@ -252,7 +252,7 @@ where impl CholeskyInplace for ArrayBase where - A: Scalar, + A: Scalar + Lapack, S: DataMut, { fn cholesky_inplace(&mut self, uplo: UPLO) -> Result<&mut Self> { @@ -289,7 +289,7 @@ pub trait FactorizeCInto { impl FactorizeCInto for ArrayBase where - A: Scalar, + A: Scalar + Lapack, S: DataMut, { fn factorizec_into(self, uplo: UPLO) -> Result> { @@ -302,7 +302,7 @@ where impl FactorizeC> for ArrayBase where - A: Scalar, + A: Scalar + Lapack, Si: Data, { fn factorizec(&self, uplo: UPLO) -> Result>> { @@ -343,7 +343,7 @@ pub trait SolveC { impl SolveC for ArrayBase where - A: Scalar, + A: Scalar + Lapack, S: Data, { fn solvec_inplace<'a, Sb>(&self, b: &'a mut ArrayBase) -> Result<&'a mut ArrayBase> @@ -372,7 +372,7 @@ pub trait InverseCInto { impl InverseC for ArrayBase where - A: Scalar, + A: Scalar + Lapack, S: Data, { type Output = Array2; @@ -384,7 +384,7 @@ where impl InverseCInto for ArrayBase where - A: Scalar, + A: Scalar + Lapack, S: DataMut, { type Output = Self; @@ -430,10 +430,10 @@ pub trait DeterminantCInto { impl DeterminantC for ArrayBase where - A: Scalar, + A: Scalar + Lapack, S: Data, { - type Output = Result<::Real>; + type Output = Result<::Real>; fn detc(&self) -> Self::Output { Ok(self.ln_detc()?.exp()) @@ -446,10 +446,10 @@ where impl DeterminantCInto for ArrayBase where - A: Scalar, + A: Scalar + Lapack, S: DataMut, { - type Output = Result<::Real>; + type Output = Result<::Real>; fn detc_into(self) -> Self::Output { Ok(self.ln_detc_into()?.exp()) diff --git a/src/convert.rs b/src/convert.rs index 1465aa4e..5f761f53 100644 --- a/src/convert.rs +++ b/src/convert.rs @@ -5,7 +5,7 @@ use ndarray::*; use super::error::*; use super::lapack_traits::UPLO; use super::layout::*; -use super::types::Conjugate; +use super::types::*; pub fn into_col(a: ArrayBase) -> ArrayBase where @@ -107,7 +107,7 @@ where /// ***Panics*** if `a` is not square. pub(crate) fn triangular_fill_hermitian(a: &mut ArrayBase, uplo: UPLO) where - A: Conjugate, + A: Scalar + Lapack, S: DataMut, { assert!(a.is_square()); diff --git a/src/eigh.rs b/src/eigh.rs index 2ab8b438..a3ce912f 100644 --- a/src/eigh.rs +++ b/src/eigh.rs @@ -30,7 +30,7 @@ pub trait EighInto: Sized { impl EighInto for ArrayBase where - A: Scalar, + A: Scalar + Lapack, S: DataMut, { type EigVal = Array1; @@ -43,7 +43,7 @@ where impl Eigh for ArrayBase where - A: Scalar, + A: Scalar + Lapack, S: Data, { type EigVal = Array1; @@ -57,7 +57,7 @@ where impl EighInplace for ArrayBase where - A: Scalar, + A: Scalar + Lapack, S: DataMut, { type EigVal = Array1; @@ -88,7 +88,7 @@ pub trait EigValshInplace { impl EigValshInto for ArrayBase where - A: Scalar, + A: Scalar + Lapack, S: DataMut, { type EigVal = Array1; @@ -100,7 +100,7 @@ where impl EigValsh for ArrayBase where - A: Scalar, + A: Scalar + Lapack, S: Data, { type EigVal = Array1; @@ -113,7 +113,7 @@ where impl EigValshInplace for ArrayBase where - A: Scalar, + A: Scalar + Lapack, S: DataMut, { type EigVal = Array1; @@ -132,7 +132,7 @@ pub trait SymmetricSqrt { impl SymmetricSqrt for ArrayBase where - A: Scalar, + A: Scalar + Lapack, S: Data, { type Output = Array2; @@ -151,14 +151,14 @@ pub trait SymmetricSqrtInto { impl SymmetricSqrtInto for ArrayBase where - A: Scalar, + A: Scalar + Lapack, S: DataMut + DataOwned, { type Output = Array2; fn ssqrt_into(self, uplo: UPLO) -> Result { let (e, v) = self.eigh_into(uplo)?; - let e_sqrt = Array1::from_iter(e.iter().map(|r| AssociatedReal::inject(r.sqrt()))); + let e_sqrt = Array1::from_iter(e.iter().map(|r| Scalar::from_real(r.sqrt()))); let ev = e_sqrt.into_diagonal().op(&v.t()); Ok(v.op(&ev)) } diff --git a/src/generate.rs b/src/generate.rs index 2f5afe15..bf177386 100644 --- a/src/generate.rs +++ b/src/generate.rs @@ -11,13 +11,13 @@ use super::types::*; /// Hermite conjugate matrix pub fn conjugate(a: &ArrayBase) -> ArrayBase where - A: Conjugate, + A: Scalar + Lapack Si: Data, So: DataOwned + DataMut, { let mut a = replicate(&a.t()); for val in a.iter_mut() { - *val = Conjugate::conj(*val); + *val = Scalar::conj(*val); } a } @@ -37,14 +37,14 @@ where /// Random Hermite matrix pub fn random_hermite(n: usize) -> ArrayBase where - A: RandNormal + Conjugate + Add, + A: RandNormal + Scalar + Add, S: DataOwned + DataMut, { let mut a = random((n, n)); for i in 0..n { - a[(i, i)] = a[(i, i)] + Conjugate::conj(a[(i, i)]); + a[(i, i)] = a[(i, i)] + Scalar::conj(a[(i, i)]); for j in (i + 1)..n { - a[(i, j)] = Conjugate::conj(a[(j, i)]) + a[(i, j)] = Scalar::conj(a[(j, i)]) } } a @@ -56,7 +56,7 @@ where /// pub fn random_hpd(n: usize) -> ArrayBase where - A: RandNormal + Conjugate + LinalgScalar, + A: RandNormal + Scalar + LinalgScalar, S: DataOwned + DataMut, { let a: Array2 = random((n, n)); diff --git a/src/lapack_traits/eigh.rs b/src/lapack_traits/eigh.rs index 7f0beb1f..610564d2 100644 --- a/src/lapack_traits/eigh.rs +++ b/src/lapack_traits/eigh.rs @@ -10,7 +10,7 @@ use crate::types::*; use super::{into_result, UPLO}; /// Wraps `*syev` for real and `*heev` for complex -pub trait Eigh_: AssociatedReal { +pub trait Eigh_: Scalar { unsafe fn eigh(calc_eigenvec: bool, l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result>; } diff --git a/src/lapack_traits/opnorm.rs b/src/lapack_traits/opnorm.rs index 08e1c663..0cdc9640 100644 --- a/src/lapack_traits/opnorm.rs +++ b/src/lapack_traits/opnorm.rs @@ -8,7 +8,7 @@ use crate::types::*; pub use super::NormType; -pub trait OperatorNorm_: AssociatedReal { +pub trait OperatorNorm_: Scalar { unsafe fn opnorm(t: NormType, l: MatrixLayout, a: &[Self]) -> Self::Real; } diff --git a/src/lapack_traits/solve.rs b/src/lapack_traits/solve.rs index 150ab91f..d71836af 100644 --- a/src/lapack_traits/solve.rs +++ b/src/lapack_traits/solve.rs @@ -11,7 +11,7 @@ use super::NormType; use super::{into_result, Pivot, Transpose}; /// Wraps `*getrf`, `*getri`, and `*getrs` -pub trait Solve_: AssociatedReal + Sized { +pub trait Solve_: Scalar + Sized { /// Computes the LU factorization of a general `m x n` matrix `a` using /// partial pivoting with row interchanges. /// diff --git a/src/lapack_traits/svd.rs b/src/lapack_traits/svd.rs index 9c7d9926..8766ce52 100644 --- a/src/lapack_traits/svd.rs +++ b/src/lapack_traits/svd.rs @@ -18,7 +18,7 @@ enum FlagSVD { } /// Result of SVD -pub struct SVDOutput { +pub struct SVDOutput { /// diagonal values pub s: Vec, /// Unitary matrix for destination space @@ -28,7 +28,7 @@ pub struct SVDOutput { } /// Wraps `*gesvd` -pub trait SVD_: AssociatedReal { +pub trait SVD_: Scalar { unsafe fn svd(l: MatrixLayout, calc_u: bool, calc_vt: bool, a: &mut [Self]) -> Result>; } diff --git a/src/lib.rs b/src/lib.rs index 0d8eb73a..3f252a00 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,7 +23,7 @@ pub mod convert; pub mod diagonal; pub mod eigh; pub mod error; -pub mod generate; +// pub mod generate; pub mod lapack_traits; pub mod layout; pub mod norm; @@ -42,7 +42,7 @@ pub use cholesky::*; pub use convert::*; pub use diagonal::*; pub use eigh::*; -pub use generate::*; +// pub use generate::*; pub use layout::*; pub use norm::*; pub use operator::*; diff --git a/src/norm.rs b/src/norm.rs index 15aed934..55259409 100644 --- a/src/norm.rs +++ b/src/norm.rs @@ -24,7 +24,7 @@ pub trait Norm { impl Norm for ArrayBase where - A: Scalar, + A: Scalar + Lapack, S: Data, D: Dimension, { @@ -33,7 +33,7 @@ where self.iter().map(|x| x.abs()).sum() } fn norm_l2(&self) -> Self::Output { - self.iter().map(|x| x.abs_sqr()).sum::().sqrt() + self.iter().map(|x| x.square()).sum::().sqrt() } fn norm_max(&self) -> Self::Output { self.iter().fold(A::Real::zero(), |f, &val| { @@ -55,14 +55,14 @@ pub enum NormalizeAxis { /// normalize in L2 norm pub fn normalize(mut m: ArrayBase, axis: NormalizeAxis) -> (ArrayBase, Vec) where - A: Scalar, + A: Scalar + Lapack, S: DataMut, { let mut ms = Vec::new(); for mut v in m.axis_iter_mut(Axis(axis as usize)) { let n = v.norm(); ms.push(n); - v.map_inplace(|x| *x = x.div_real(n)) + v.map_inplace(|x| *x = *x / A::from_real(n)) } (m, ms) } diff --git a/src/operator.rs b/src/operator.rs index 26b28666..2693d7bc 100644 --- a/src/operator.rs +++ b/src/operator.rs @@ -30,7 +30,7 @@ where impl Operator for T where - A: Scalar, + A: Scalar + Lapack, S: Data, D: Dimension, T: linalg::Dot, Output = Array>, @@ -50,7 +50,7 @@ where impl OperatorMulti for T where - A: Scalar, + A: Scalar + Lapack, S: DataMut, D: Dimension + RemoveAxis, for<'a> T: OperatorInplace, D::Smaller>, diff --git a/src/opnorm.rs b/src/opnorm.rs index 57371538..b00aaa9f 100644 --- a/src/opnorm.rs +++ b/src/opnorm.rs @@ -13,7 +13,7 @@ pub use crate::lapack_traits::NormType; /// [Wikipedia article on operator norm](https://en.wikipedia.org/wiki/Operator_norm) pub trait OperationNorm { /// the value of norm - type Output: RealScalar; + type Output: Scalar; fn opnorm(&self, t: NormType) -> Result; @@ -35,7 +35,7 @@ pub trait OperationNorm { impl OperationNorm for ArrayBase where - A: Scalar, + A: Scalar + Lapack, S: Data, { type Output = A::Real; diff --git a/src/qr.rs b/src/qr.rs index cf470f3b..fa3a749d 100644 --- a/src/qr.rs +++ b/src/qr.rs @@ -54,7 +54,7 @@ pub trait QRSquareInplace: Sized { impl QRSquareInplace for ArrayBase where - A: Scalar, + A: Scalar + Lapack, S: DataMut, { type R = Array2; @@ -69,7 +69,7 @@ where impl QRSquareInto for ArrayBase where - A: Scalar, + A: Scalar + Lapack, S: DataMut, { type R = Array2; @@ -82,7 +82,7 @@ where impl QRSquare for ArrayBase where - A: Scalar, + A: Scalar + Lapack, S: Data, { type Q = Array2; @@ -96,7 +96,7 @@ where impl QRInto for ArrayBase where - A: Scalar, + A: Scalar + Lapack, S: DataMut, { type Q = Array2; @@ -116,7 +116,7 @@ where impl QR for ArrayBase where - A: Scalar, + A: Scalar + Lapack, S: Data, { type Q = Array2; diff --git a/src/solve.rs b/src/solve.rs index 0de9ff60..06cf2ed2 100644 --- a/src/solve.rs +++ b/src/solve.rs @@ -147,7 +147,7 @@ pub struct LUFactorized { impl Solve for LUFactorized where - A: Scalar, + A: Scalar + Lapack, S: Data, { fn solve_inplace<'a, Sb>(&self, rhs: &'a mut ArrayBase) -> Result<&'a mut ArrayBase> @@ -199,7 +199,7 @@ where impl Solve for ArrayBase where - A: Scalar, + A: Scalar + Lapack, S: Data, { fn solve_inplace<'a, Sb>(&self, rhs: &'a mut ArrayBase) -> Result<&'a mut ArrayBase> @@ -241,7 +241,7 @@ pub trait FactorizeInto { impl FactorizeInto for ArrayBase where - A: Scalar, + A: Scalar + Lapack, S: DataMut, { fn factorize_into(mut self) -> Result> { @@ -252,7 +252,7 @@ where impl Factorize> for ArrayBase where - A: Scalar, + A: Scalar + Lapack, Si: Data, { fn factorize(&self) -> Result>> { @@ -278,7 +278,7 @@ pub trait InverseInto { impl InverseInto for LUFactorized where - A: Scalar, + A: Scalar + Lapack, S: DataMut, { type Output = ArrayBase; @@ -291,7 +291,7 @@ where impl Inverse for LUFactorized where - A: Scalar, + A: Scalar + Lapack, S: Data, { type Output = Array2; @@ -307,7 +307,7 @@ where impl InverseInto for ArrayBase where - A: Scalar, + A: Scalar + Lapack, S: DataMut, { type Output = Self; @@ -320,7 +320,7 @@ where impl Inverse for ArrayBase where - A: Scalar, + A: Scalar + Lapack, Si: Data, { type Output = Array2; @@ -336,7 +336,7 @@ pub trait Determinant { /// Computes the determinant of the matrix. fn det(&self) -> Result { let (sign, ln_det) = self.sln_det()?; - Ok(sign.mul_real(ln_det.exp())) + Ok(sign * A::from_real(ln_det.exp())) } /// Computes the `(sign, natural_log)` of the determinant of the matrix. @@ -361,7 +361,7 @@ pub trait DeterminantInto: Sized { /// Computes the determinant of the matrix. fn det_into(self) -> Result { let (sign, ln_det) = self.sln_det_into()?; - Ok(sign.mul_real(ln_det.exp())) + Ok(sign * A::from_real(ln_det.exp())) } /// Computes the `(sign, natural_log)` of the determinant of the matrix. @@ -383,7 +383,7 @@ pub trait DeterminantInto: Sized { fn lu_sln_det<'a, A, P, U>(ipiv_iter: P, u_diag_iter: U) -> (A, A::Real) where - A: Scalar, + A: Scalar + Lapack, P: Iterator, U: Iterator, { @@ -400,14 +400,14 @@ where }; let (upper_sign, ln_det) = u_diag_iter.fold((A::one(), A::Real::zero()), |(upper_sign, ln_det), &elem| { let abs_elem: A::Real = elem.abs(); - (upper_sign * elem.div_real(abs_elem), ln_det + abs_elem.ln()) + (upper_sign * elem / A::from_real(abs_elem), ln_det + abs_elem.ln()) }); (pivot_sign * upper_sign, ln_det) } impl Determinant for LUFactorized where - A: Scalar, + A: Scalar + Lapack, S: Data, { fn sln_det(&self) -> Result<(A, A::Real)> { @@ -418,7 +418,7 @@ where impl DeterminantInto for LUFactorized where - A: Scalar, + A: Scalar + Lapack, S: Data, { fn sln_det_into(self) -> Result<(A, A::Real)> { @@ -429,7 +429,7 @@ where impl Determinant for ArrayBase where - A: Scalar, + A: Scalar + Lapack, S: Data, { fn sln_det(&self) -> Result<(A, A::Real)> { @@ -438,7 +438,7 @@ where Ok(fac) => fac.sln_det(), Err(LinalgError::Lapack { return_code }) if return_code > 0 => { // The determinant is zero. - Ok((A::zero(), A::Real::neg_infinity())) + Ok((A::zero(), A::Real::NEG_INFINITY)) } Err(err) => Err(err), } @@ -447,7 +447,7 @@ where impl DeterminantInto for ArrayBase where - A: Scalar, + A: Scalar + Lapack, S: DataMut, { fn sln_det_into(self) -> Result<(A, A::Real)> { @@ -456,7 +456,7 @@ where Ok(fac) => fac.sln_det_into(), Err(LinalgError::Lapack { return_code }) if return_code > 0 => { // The determinant is zero. - Ok((A::zero(), A::Real::neg_infinity())) + Ok((A::zero(), A::Real::NEG_INFINITY)) } Err(err) => Err(err), } @@ -493,7 +493,7 @@ pub trait ReciprocalConditionNumInto { impl ReciprocalConditionNum for LUFactorized where - A: Scalar, + A: Scalar + Lapack, S: Data, { fn rcond(&self) -> Result { @@ -503,7 +503,7 @@ where impl ReciprocalConditionNumInto for LUFactorized where - A: Scalar, + A: Scalar + Lapack, S: Data, { fn rcond_into(self) -> Result { @@ -513,7 +513,7 @@ where impl ReciprocalConditionNum for ArrayBase where - A: Scalar, + A: Scalar + Lapack, S: Data, { fn rcond(&self) -> Result { @@ -523,7 +523,7 @@ where impl ReciprocalConditionNumInto for ArrayBase where - A: Scalar, + A: Scalar + Lapack, S: DataMut, { fn rcond_into(self) -> Result { diff --git a/src/solveh.rs b/src/solveh.rs index 8dfeec74..189e7056 100644 --- a/src/solveh.rs +++ b/src/solveh.rs @@ -100,7 +100,7 @@ pub struct BKFactorized { impl SolveH for BKFactorized where - A: Scalar, + A: Scalar + Lapack, S: Data, { fn solveh_inplace<'a, Sb>(&self, rhs: &'a mut ArrayBase) -> Result<&'a mut ArrayBase> @@ -122,7 +122,7 @@ where impl SolveH for ArrayBase where - A: Scalar, + A: Scalar + Lapack, S: Data, { fn solveh_inplace<'a, Sb>(&self, rhs: &'a mut ArrayBase) -> Result<&'a mut ArrayBase> @@ -152,7 +152,7 @@ pub trait FactorizeHInto { impl FactorizeHInto for ArrayBase where - A: Scalar, + A: Scalar + Lapack, S: DataMut, { fn factorizeh_into(mut self) -> Result> { @@ -163,7 +163,7 @@ where impl FactorizeH> for ArrayBase where - A: Scalar, + A: Scalar + Lapack, Si: Data, { fn factorizeh(&self) -> Result>> { @@ -189,7 +189,7 @@ pub trait InverseHInto { impl InverseHInto for BKFactorized where - A: Scalar, + A: Scalar + Lapack, S: DataMut, { type Output = ArrayBase; @@ -210,7 +210,7 @@ where impl InverseH for BKFactorized where - A: Scalar, + A: Scalar + Lapack, S: Data, { type Output = Array2; @@ -226,7 +226,7 @@ where impl InverseHInto for ArrayBase where - A: Scalar, + A: Scalar + Lapack, S: DataMut, { type Output = Self; @@ -239,7 +239,7 @@ where impl InverseH for ArrayBase where - A: Scalar, + A: Scalar + Lapack, Si: Data, { type Output = Array2; @@ -256,7 +256,7 @@ pub trait DeterminantH { type Elem: Scalar; /// Computes the determinant of the Hermitian (or real symmetric) matrix. - fn deth(&self) -> Result<::Real>; + fn deth(&self) -> Result<::Real>; /// Computes the `(sign, natural_log)` of the determinant of the Hermitian /// (or real symmetric) matrix. @@ -271,12 +271,7 @@ pub trait DeterminantH { /// This method is more robust than `.deth()` to very small or very large /// determinants since it returns the natural logarithm of the determinant /// rather than the determinant itself. - fn sln_deth( - &self, - ) -> Result<( - ::Real, - ::Real, - )>; + fn sln_deth(&self) -> Result<(::Real, ::Real)>; } /// An interface for calculating determinants of Hermitian (or real symmetric) matrices. @@ -285,7 +280,7 @@ pub trait DeterminantHInto { type Elem: Scalar; /// Computes the determinant of the Hermitian (or real symmetric) matrix. - fn deth_into(self) -> Result<::Real>; + fn deth_into(self) -> Result<::Real>; /// Computes the `(sign, natural_log)` of the determinant of the Hermitian /// (or real symmetric) matrix. @@ -300,12 +295,7 @@ pub trait DeterminantHInto { /// This method is more robust than `.deth_into()` to very small or very /// large determinants since it returns the natural logarithm of the /// determinant rather than the determinant itself. - fn sln_deth_into( - self, - ) -> Result<( - ::Real, - ::Real, - )>; + fn sln_deth_into(self) -> Result<(::Real, ::Real)>; } /// Returns the sign and natural log of the determinant. @@ -313,7 +303,7 @@ fn bk_sln_det(uplo: UPLO, ipiv_iter: P, a: &ArrayBase) -> (A::R where P: Iterator, S: Data, - A: Scalar, + A: Scalar + Lapack, { let mut sign = A::Real::one(); let mut ln_det = A::Real::zero(); @@ -322,20 +312,20 @@ where debug_assert!(k < a.rows() && k < a.cols()); if ipiv_k > 0 { // 1x1 block at k, must be real. - let elem = unsafe { a.uget((k, k)) }.real(); - debug_assert_eq!(elem.imag(), Zero::zero()); + let elem = unsafe { a.uget((k, k)) }.re(); + debug_assert_eq!(elem.im(), Zero::zero()); sign = sign * elem.signum(); ln_det = ln_det + elem.abs().ln(); } else { // 2x2 block at k..k+2. // Upper left diagonal elem, must be real. - let upper_diag = unsafe { a.uget((k, k)) }.real(); - debug_assert_eq!(upper_diag.imag(), Zero::zero()); + let upper_diag = unsafe { a.uget((k, k)) }.re(); + debug_assert_eq!(upper_diag.im(), Zero::zero()); // Lower right diagonal elem, must be real. - let lower_diag = unsafe { a.uget((k + 1, k + 1)) }.real(); - debug_assert_eq!(lower_diag.imag(), Zero::zero()); + let lower_diag = unsafe { a.uget((k + 1, k + 1)) }.re(); + debug_assert_eq!(lower_diag.im(), Zero::zero()); // Off-diagonal elements, can be complex. let off_diag = match uplo { @@ -344,7 +334,7 @@ where }; // Determinant of 2x2 block. - let block_det = upper_diag * lower_diag - off_diag.abs_sqr(); + let block_det = upper_diag * lower_diag - off_diag.square(); sign = sign * block_det.signum(); ln_det = ln_det + block_det.abs().ln(); @@ -357,7 +347,7 @@ where impl BKFactorized where - A: Scalar, + A: Scalar + Lapack, S: Data, { /// Computes the determinant of the factorized Hermitian (or real @@ -411,7 +401,7 @@ where impl DeterminantH for ArrayBase where - A: Scalar, + A: Scalar + Lapack, S: Data, { type Elem = A; @@ -426,7 +416,7 @@ where Ok(fac) => Ok(fac.sln_deth()), Err(LinalgError::Lapack { return_code }) if return_code > 0 => { // Determinant is zero. - Ok((A::Real::zero(), A::Real::neg_infinity())) + Ok((A::Real::zero(), A::Real::NEG_INFINITY)) } Err(err) => Err(err), } @@ -435,7 +425,7 @@ where impl DeterminantHInto for ArrayBase where - A: Scalar, + A: Scalar + Lapack, S: DataMut, { type Elem = A; @@ -450,7 +440,7 @@ where Ok(fac) => Ok(fac.sln_deth_into()), Err(LinalgError::Lapack { return_code }) if return_code > 0 => { // Determinant is zero. - Ok((A::Real::zero(), A::Real::neg_infinity())) + Ok((A::Real::zero(), A::Real::NEG_INFINITY)) } Err(err) => Err(err), } diff --git a/src/svd.rs b/src/svd.rs index cba3a0ea..bbe0804d 100644 --- a/src/svd.rs +++ b/src/svd.rs @@ -35,7 +35,7 @@ pub trait SVDInplace { impl SVDInto for ArrayBase where - A: Scalar, + A: Scalar + Lapack, S: DataMut, { type U = Array2; @@ -49,7 +49,7 @@ where impl SVD for ArrayBase where - A: Scalar, + A: Scalar + Lapack, S: Data, { type U = Array2; @@ -64,7 +64,7 @@ where impl SVDInplace for ArrayBase where - A: Scalar, + A: Scalar + Lapack, S: DataMut, { type U = Array2; diff --git a/src/triangular.rs b/src/triangular.rs index 1cacd70e..581d3160 100644 --- a/src/triangular.rs +++ b/src/triangular.rs @@ -14,7 +14,7 @@ pub use super::lapack_traits::Diag; /// solve a triangular system with upper triangular matrix pub trait SolveTriangular where - A: Scalar, + A: Scalar + Lapack, S: Data, D: Dimension, { @@ -46,7 +46,7 @@ where impl SolveTriangularInto for ArrayBase where - A: Scalar, + A: Scalar + Lapack, Si: Data, So: DataMut + DataOwned, { @@ -58,7 +58,7 @@ where impl SolveTriangularInplace for ArrayBase where - A: Scalar, + A: Scalar + Lapack, Si: Data, So: DataMut + DataOwned, { @@ -82,7 +82,7 @@ where impl SolveTriangular for ArrayBase where - A: Scalar, + A: Scalar + Lapack, Si: Data, So: DataMut + DataOwned, { @@ -94,7 +94,7 @@ where impl SolveTriangularInto for ArrayBase where - A: Scalar, + A: Scalar + Lapack, Si: Data, So: DataMut + DataOwned, { @@ -107,7 +107,7 @@ where impl SolveTriangular for ArrayBase where - A: Scalar, + A: Scalar + Lapack, Si: Data, So: DataMut + DataOwned, { diff --git a/src/types.rs b/src/types.rs index bd6c52d9..e1705a6c 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,312 +1,7 @@ //! Basic types and their methods for linear algebra -use ndarray::LinalgScalar; -use num_complex::Complex; -use num_traits::*; -use rand::distributions::*; -use rand::Rng; -use std::fmt::Debug; -use std::iter::Sum; -use std::ops::Neg; - -use super::lapack_traits::Lapack; +pub use super::lapack_traits::Lapack; +pub use cauchy::{RealScalar, Scalar}; pub use num_complex::Complex32 as c32; pub use num_complex::Complex64 as c64; - -/// General Scalar trait. This generalizes complex and real number. -/// -/// You can use the following operations with `A: Scalar`: -/// -/// - [`abs`](trait.Absolute.html#method.abs) -/// - [`abs_sqr`](trait.Absolute.html#tymethod.abs_sqr) -/// - [`sqrt`](trait.SquareRoot.html#tymethod.sqrt) -/// - [`exp`](trait.Exponential.html#tymethod.exp) -/// - [`ln`](trait.NaturalLogarithm.html#tymethod.ln) -/// - [`conj`](trait.Conjugate.html#tymethod.conj) -/// - [`randn`](trait.RandNormal.html#tymethod.randn) -/// -pub trait Scalar: - Lapack - + LinalgScalar - + AssociatedReal - + AssociatedComplex - + Absolute - + SquareRoot - + Exponential - + NaturalLogarithm - + Conjugate - + RandNormal - + Neg - + Debug -{ - fn from_f64(a: f64) -> Self; -} - -impl Scalar for f32 { - fn from_f64(f: f64) -> Self { - f as f32 - } -} - -impl Scalar for f64 { - fn from_f64(f: f64) -> Self { - f - } -} - -impl Scalar for c32 { - fn from_f64(f: f64) -> Self { - Self::new(f as f32, 0.0) - } -} - -impl Scalar for c64 { - fn from_f64(f: f64) -> Self { - Self::new(f, 0.0) - } -} - -pub trait RealScalar: Scalar + Float + Sum {} -impl RealScalar for f32 {} -impl RealScalar for f64 {} - -/// Convert `f64` into `Scalar` -/// -/// ```rust -/// use ndarray_linalg::*; -/// fn mult(a: A) -> A { -/// // a * 2.0 // Error! -/// a * into_scalar(2.0) -/// } -/// ``` -pub fn into_scalar(f: f64) -> T { - T::from_f64(f) -} - -/// Define associating real float type -pub trait AssociatedReal: Sized { - type Real: RealScalar; - fn inject(x: Self::Real) -> Self; - /// Returns the real part of `self`. - fn real(self) -> Self::Real; - /// Returns the imaginary part of `self`. - fn imag(self) -> Self::Real; - fn add_real(self, re: Self::Real) -> Self; - fn sub_real(self, re: Self::Real) -> Self; - fn mul_real(self, re: Self::Real) -> Self; - fn div_real(self, re: Self::Real) -> Self; -} - -/// Define associating complex type -pub trait AssociatedComplex: Sized { - type Complex; - fn inject(self) -> Self::Complex; - fn add_complex(self, c: Self::Complex) -> Self::Complex; - fn sub_complex(self, c: Self::Complex) -> Self::Complex; - fn mul_complex(self, c: Self::Complex) -> Self::Complex; -} - -/// Define `abs()` more generally -pub trait Absolute: AssociatedReal { - fn abs_sqr(&self) -> Self::Real; - fn abs(&self) -> Self::Real { - self.abs_sqr().sqrt() - } -} - -/// Define `sqrt()` more generally -pub trait SquareRoot { - fn sqrt(&self) -> Self; -} - -/// Define `exp()` more generally -pub trait Exponential { - fn exp(&self) -> Self; -} - -/// Define `ln()` more generally -pub trait NaturalLogarithm { - fn ln(&self) -> Self; -} - -/// Complex conjugate value -pub trait Conjugate: Copy { - fn conj(self) -> Self; -} - -/// Scalars which can be initialized from Gaussian random number -pub trait RandNormal { - fn randn(rng: &mut R) -> Self; -} - -macro_rules! impl_traits { - ($real:ty, $complex:ty) => { - impl AssociatedReal for $real { - type Real = $real; - fn inject(r: Self::Real) -> Self { - r - } - fn real(self) -> Self::Real { - self - } - fn imag(self) -> Self::Real { - 0. - } - fn add_real(self, r: Self::Real) -> Self { - self + r - } - fn sub_real(self, r: Self::Real) -> Self { - self - r - } - fn mul_real(self, r: Self::Real) -> Self { - self * r - } - fn div_real(self, r: Self::Real) -> Self { - self / r - } - } - - impl AssociatedReal for $complex { - type Real = $real; - fn inject(r: Self::Real) -> Self { - Self::new(r, 0.0) - } - fn real(self) -> Self::Real { - self.re - } - fn imag(self) -> Self::Real { - self.im - } - fn add_real(self, r: Self::Real) -> Self { - self + r - } - fn sub_real(self, r: Self::Real) -> Self { - self - r - } - fn mul_real(self, r: Self::Real) -> Self { - self * r - } - fn div_real(self, r: Self::Real) -> Self { - self / r - } - } - - impl AssociatedComplex for $real { - type Complex = $complex; - fn inject(self) -> Self::Complex { - Self::Complex::new(self, 0.0) - } - fn add_complex(self, c: Self::Complex) -> Self::Complex { - self + c - } - fn sub_complex(self, c: Self::Complex) -> Self::Complex { - self - c - } - fn mul_complex(self, c: Self::Complex) -> Self::Complex { - self * c - } - } - - impl AssociatedComplex for $complex { - type Complex = $complex; - fn inject(self) -> Self::Complex { - self - } - fn add_complex(self, c: Self::Complex) -> Self::Complex { - self + c - } - fn sub_complex(self, c: Self::Complex) -> Self::Complex { - self - c - } - fn mul_complex(self, c: Self::Complex) -> Self::Complex { - self * c - } - } - - impl Absolute for $real { - fn abs_sqr(&self) -> Self::Real { - *self * *self - } - fn abs(&self) -> Self::Real { - Float::abs(*self) - } - } - - impl Absolute for $complex { - fn abs_sqr(&self) -> Self::Real { - self.norm_sqr() - } - fn abs(&self) -> Self::Real { - self.norm() - } - } - - impl SquareRoot for $real { - fn sqrt(&self) -> Self { - Float::sqrt(*self) - } - } - - impl SquareRoot for $complex { - fn sqrt(&self) -> Self { - Complex::sqrt(self) - } - } - - impl Exponential for $real { - fn exp(&self) -> Self { - Float::exp(*self) - } - } - - impl Exponential for $complex { - fn exp(&self) -> Self { - Complex::exp(self) - } - } - - impl NaturalLogarithm for $real { - fn ln(&self) -> Self { - Float::ln(*self) - } - } - - impl NaturalLogarithm for $complex { - fn ln(&self) -> Self { - Complex::ln(self) - } - } - - impl Conjugate for $real { - fn conj(self) -> Self { - self - } - } - - impl Conjugate for $complex { - fn conj(self) -> Self { - Complex::conj(&self) - } - } - - impl RandNormal for $real { - fn randn(rng: &mut R) -> Self { - let dist = Normal::new(0., 1.); - dist.sample(rng) as $real - } - } - - impl RandNormal for $complex { - fn randn(rng: &mut R) -> Self { - let dist = Normal::new(0., 1.); - let re = dist.sample(rng) as $real; - let im = dist.sample(rng) as $real; - Self::new(re, im) - } - } - }; -} // impl_traits! - -impl_traits!(f64, c64); -impl_traits!(f32, c32); From 1417f39f149db71b4640a110717260f033c2c33a Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Fri, 26 Apr 2019 01:14:58 +0900 Subject: [PATCH 04/10] Revise cauchy support --- src/generate.rs | 31 ++++++++++++++++--------------- src/lib.rs | 4 ++-- src/solve.rs | 4 ++-- src/solveh.rs | 4 ++-- src/types.rs | 2 +- 5 files changed, 23 insertions(+), 22 deletions(-) diff --git a/src/generate.rs b/src/generate.rs index bf177386..1c8c4e08 100644 --- a/src/generate.rs +++ b/src/generate.rs @@ -1,8 +1,7 @@ //! Generator functions for matrices use ndarray::*; -use rand::*; -use std::ops::*; +use rand::{distributions::Standard, prelude::*}; use super::convert::*; use super::error::*; @@ -11,13 +10,13 @@ use super::types::*; /// Hermite conjugate matrix pub fn conjugate(a: &ArrayBase) -> ArrayBase where - A: Scalar + Lapack + A: Scalar, Si: Data, So: DataOwned + DataMut, { - let mut a = replicate(&a.t()); + let mut a: ArrayBase = replicate(&a.t()); for val in a.iter_mut() { - *val = Scalar::conj(*val); + *val = val.conj(); } a } @@ -25,26 +24,27 @@ where /// Generate random array pub fn random(sh: Sh) -> ArrayBase where - A: RandNormal, S: DataOwned, D: Dimension, Sh: ShapeBuilder, + Standard: Distribution, { let mut rng = thread_rng(); - ArrayBase::from_shape_fn(sh, |_| A::randn(&mut rng)) + ArrayBase::from_shape_fn(sh, |_| rng.sample(Standard)) } /// Random Hermite matrix pub fn random_hermite(n: usize) -> ArrayBase where - A: RandNormal + Scalar + Add, + A: Scalar, S: DataOwned + DataMut, + Standard: Distribution, { - let mut a = random((n, n)); + let mut a: ArrayBase = random((n, n)); for i in 0..n { - a[(i, i)] = a[(i, i)] + Scalar::conj(a[(i, i)]); + a[(i, i)] = a[(i, i)] + a[(i, i)].conj(); for j in (i + 1)..n { - a[(i, j)] = Scalar::conj(a[(j, i)]) + a[(i, j)] = a[(j, i)].conj(); } } a @@ -56,8 +56,9 @@ where /// pub fn random_hpd(n: usize) -> ArrayBase where - A: RandNormal + Scalar + LinalgScalar, + A: Scalar, S: DataOwned + DataMut, + Standard: Distribution, { let a: Array2 = random((n, n)); let ah: Array2 = conjugate(&a); @@ -67,7 +68,7 @@ where /// construct matrix from diag pub fn from_diag(d: &[A]) -> Array2 where - A: LinalgScalar, + A: Scalar, { let n = d.len(); let mut e = Array::zeros((n, n)); @@ -80,7 +81,7 @@ where /// stack vectors into matrix horizontally pub fn hstack(xs: &[ArrayBase]) -> Result> where - A: LinalgScalar, + A: Scalar, S: Data, { let views: Vec<_> = xs @@ -96,7 +97,7 @@ where /// stack vectors into matrix vertically pub fn vstack(xs: &[ArrayBase]) -> Result> where - A: LinalgScalar, + A: Scalar, S: Data, { let views: Vec<_> = xs diff --git a/src/lib.rs b/src/lib.rs index 3f252a00..0d8eb73a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,7 +23,7 @@ pub mod convert; pub mod diagonal; pub mod eigh; pub mod error; -// pub mod generate; +pub mod generate; pub mod lapack_traits; pub mod layout; pub mod norm; @@ -42,7 +42,7 @@ pub use cholesky::*; pub use convert::*; pub use diagonal::*; pub use eigh::*; -// pub use generate::*; +pub use generate::*; pub use layout::*; pub use norm::*; pub use operator::*; diff --git a/src/solve.rs b/src/solve.rs index 06cf2ed2..6082d529 100644 --- a/src/solve.rs +++ b/src/solve.rs @@ -438,7 +438,7 @@ where Ok(fac) => fac.sln_det(), Err(LinalgError::Lapack { return_code }) if return_code > 0 => { // The determinant is zero. - Ok((A::zero(), A::Real::NEG_INFINITY)) + Ok((A::zero(), A::Real::neg_infinity())) } Err(err) => Err(err), } @@ -456,7 +456,7 @@ where Ok(fac) => fac.sln_det_into(), Err(LinalgError::Lapack { return_code }) if return_code > 0 => { // The determinant is zero. - Ok((A::zero(), A::Real::NEG_INFINITY)) + Ok((A::zero(), A::Real::neg_infinity())) } Err(err) => Err(err), } diff --git a/src/solveh.rs b/src/solveh.rs index 189e7056..86e369b0 100644 --- a/src/solveh.rs +++ b/src/solveh.rs @@ -416,7 +416,7 @@ where Ok(fac) => Ok(fac.sln_deth()), Err(LinalgError::Lapack { return_code }) if return_code > 0 => { // Determinant is zero. - Ok((A::Real::zero(), A::Real::NEG_INFINITY)) + Ok((A::Real::zero(), A::Real::neg_infinity())) } Err(err) => Err(err), } @@ -440,7 +440,7 @@ where Ok(fac) => Ok(fac.sln_deth_into()), Err(LinalgError::Lapack { return_code }) if return_code > 0 => { // Determinant is zero. - Ok((A::Real::zero(), A::Real::NEG_INFINITY)) + Ok((A::Real::zero(), A::Real::neg_infinity())) } Err(err) => Err(err), } diff --git a/src/types.rs b/src/types.rs index e1705a6c..8025ac42 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,7 +1,7 @@ //! Basic types and their methods for linear algebra pub use super::lapack_traits::Lapack; -pub use cauchy::{RealScalar, Scalar}; +pub use cauchy::Scalar; pub use num_complex::Complex32 as c32; pub use num_complex::Complex64 as c64; From 622000df5cf39ebb9a7d23715206420d3a2f8b0b Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Fri, 26 Apr 2019 01:21:17 +0900 Subject: [PATCH 05/10] Use rand feature of num-complex --- Cargo.toml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4aef8017..b2eed749 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,12 +23,13 @@ serde-1 = ["ndarray/serde-1", "num-complex/serde"] [dependencies] lapacke = "0.2" num-traits = "0.2" -rand = "0.6" +rand = "0.5" cauchy = "0.2" [dependencies.num-complex] -version = "0.2" +version = "0.2.1" default-features = false +features = ["rand"] [dependencies.ndarray] version = "0.12" From be2c794a0e4bf9374fed59afb4fd26e0e70bbe21 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Fri, 26 Apr 2019 01:45:46 +0900 Subject: [PATCH 06/10] Support cauchy in tests --- tests/deth.rs | 5 +---- tests/triangular.rs | 4 ++-- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/deth.rs b/tests/deth.rs index 1b6a9002..30155864 100644 --- a/tests/deth.rs +++ b/tests/deth.rs @@ -71,10 +71,7 @@ fn deth() { // Compute determinant from eigenvalues. let (sign, ln_det) = a.eigvalsh(UPLO::Upper).unwrap().iter().fold( - ( - <$elem as AssociatedReal>::Real::one(), - <$elem as AssociatedReal>::Real::zero(), - ), + (<$elem as Scalar>::Real::one(), <$elem as Scalar>::Real::zero()), |(sign, ln_det), eigval| (sign * eigval.signum(), ln_det + eigval.abs().ln()), ); let det = sign * ln_det.exp(); diff --git a/tests/triangular.rs b/tests/triangular.rs index 0f563517..21f12c71 100644 --- a/tests/triangular.rs +++ b/tests/triangular.rs @@ -3,7 +3,7 @@ use ndarray_linalg::*; fn test1d(uplo: UPLO, a: &ArrayBase, b: &ArrayBase, tol: A::Real) where - A: Scalar, + A: Scalar + Lapack, Sa: Data, Sb: DataMut + DataOwned, { @@ -18,7 +18,7 @@ where fn test2d(uplo: UPLO, a: &ArrayBase, b: &ArrayBase, tol: A::Real) where - A: Scalar, + A: Scalar + Lapack, Sa: Data, Sb: DataMut + DataOwned + DataClone, { From 9cfe4fb1e9bd4d0f1f7e94d99f6b077db56d3ac8 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Fri, 26 Apr 2019 04:22:06 +0900 Subject: [PATCH 07/10] Rename lapack_traits -> lapack --- src/cholesky.rs | 2 +- src/convert.rs | 2 +- src/{lapack_traits => lapack}/cholesky.rs | 0 src/{lapack_traits => lapack}/eigh.rs | 0 src/{lapack_traits => lapack}/mod.rs | 0 src/{lapack_traits => lapack}/opnorm.rs | 0 src/{lapack_traits => lapack}/qr.rs | 0 src/{lapack_traits => lapack}/solve.rs | 0 src/{lapack_traits => lapack}/solveh.rs | 0 src/{lapack_traits => lapack}/svd.rs | 0 src/{lapack_traits => lapack}/triangular.rs | 0 src/lib.rs | 2 +- src/opnorm.rs | 2 +- src/qr.rs | 2 +- src/solve.rs | 2 +- src/solveh.rs | 2 +- src/triangular.rs | 4 ++-- src/types.rs | 2 +- 18 files changed, 10 insertions(+), 10 deletions(-) rename src/{lapack_traits => lapack}/cholesky.rs (100%) rename src/{lapack_traits => lapack}/eigh.rs (100%) rename src/{lapack_traits => lapack}/mod.rs (100%) rename src/{lapack_traits => lapack}/opnorm.rs (100%) rename src/{lapack_traits => lapack}/qr.rs (100%) rename src/{lapack_traits => lapack}/solve.rs (100%) rename src/{lapack_traits => lapack}/solveh.rs (100%) rename src/{lapack_traits => lapack}/svd.rs (100%) rename src/{lapack_traits => lapack}/triangular.rs (100%) diff --git a/src/cholesky.rs b/src/cholesky.rs index 49cc467a..efc135bc 100644 --- a/src/cholesky.rs +++ b/src/cholesky.rs @@ -52,7 +52,7 @@ use crate::layout::*; use crate::triangular::IntoTriangular; use crate::types::*; -pub use crate::lapack_traits::UPLO; +pub use crate::lapack::UPLO; /// Cholesky decomposition of Hermitian (or real symmetric) positive definite matrix pub struct CholeskyFactorized { diff --git a/src/convert.rs b/src/convert.rs index 5f761f53..1562650f 100644 --- a/src/convert.rs +++ b/src/convert.rs @@ -3,7 +3,7 @@ use ndarray::*; use super::error::*; -use super::lapack_traits::UPLO; +use super::lapack::UPLO; use super::layout::*; use super::types::*; diff --git a/src/lapack_traits/cholesky.rs b/src/lapack/cholesky.rs similarity index 100% rename from src/lapack_traits/cholesky.rs rename to src/lapack/cholesky.rs diff --git a/src/lapack_traits/eigh.rs b/src/lapack/eigh.rs similarity index 100% rename from src/lapack_traits/eigh.rs rename to src/lapack/eigh.rs diff --git a/src/lapack_traits/mod.rs b/src/lapack/mod.rs similarity index 100% rename from src/lapack_traits/mod.rs rename to src/lapack/mod.rs diff --git a/src/lapack_traits/opnorm.rs b/src/lapack/opnorm.rs similarity index 100% rename from src/lapack_traits/opnorm.rs rename to src/lapack/opnorm.rs diff --git a/src/lapack_traits/qr.rs b/src/lapack/qr.rs similarity index 100% rename from src/lapack_traits/qr.rs rename to src/lapack/qr.rs diff --git a/src/lapack_traits/solve.rs b/src/lapack/solve.rs similarity index 100% rename from src/lapack_traits/solve.rs rename to src/lapack/solve.rs diff --git a/src/lapack_traits/solveh.rs b/src/lapack/solveh.rs similarity index 100% rename from src/lapack_traits/solveh.rs rename to src/lapack/solveh.rs diff --git a/src/lapack_traits/svd.rs b/src/lapack/svd.rs similarity index 100% rename from src/lapack_traits/svd.rs rename to src/lapack/svd.rs diff --git a/src/lapack_traits/triangular.rs b/src/lapack/triangular.rs similarity index 100% rename from src/lapack_traits/triangular.rs rename to src/lapack/triangular.rs diff --git a/src/lib.rs b/src/lib.rs index 0d8eb73a..b22b32df 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -24,7 +24,7 @@ pub mod diagonal; pub mod eigh; pub mod error; pub mod generate; -pub mod lapack_traits; +pub mod lapack; pub mod layout; pub mod norm; pub mod operator; diff --git a/src/opnorm.rs b/src/opnorm.rs index b00aaa9f..37d38115 100644 --- a/src/opnorm.rs +++ b/src/opnorm.rs @@ -6,7 +6,7 @@ use crate::error::*; use crate::layout::*; use crate::types::*; -pub use crate::lapack_traits::NormType; +pub use crate::lapack::NormType; /// Operator norm using `*lange` LAPACK routines /// diff --git a/src/qr.rs b/src/qr.rs index fa3a749d..65443ddc 100644 --- a/src/qr.rs +++ b/src/qr.rs @@ -11,7 +11,7 @@ use crate::layout::*; use crate::triangular::*; use crate::types::*; -pub use crate::lapack_traits::UPLO; +pub use crate::lapack::UPLO; /// QR decomposition for matrix reference /// diff --git a/src/solve.rs b/src/solve.rs index 6082d529..52a58dde 100644 --- a/src/solve.rs +++ b/src/solve.rs @@ -55,7 +55,7 @@ use crate::layout::*; use crate::opnorm::OperationNorm; use crate::types::*; -pub use crate::lapack_traits::{Pivot, Transpose}; +pub use crate::lapack::{Pivot, Transpose}; /// An interface for solving systems of linear equations. /// diff --git a/src/solveh.rs b/src/solveh.rs index 86e369b0..a54caa54 100644 --- a/src/solveh.rs +++ b/src/solveh.rs @@ -57,7 +57,7 @@ use crate::error::*; use crate::layout::*; use crate::types::*; -pub use crate::lapack_traits::{Pivot, UPLO}; +pub use crate::lapack::{Pivot, UPLO}; /// An interface for solving systems of Hermitian (or real symmetric) linear equations. /// diff --git a/src/triangular.rs b/src/triangular.rs index 581d3160..4b4dc7bc 100644 --- a/src/triangular.rs +++ b/src/triangular.rs @@ -5,11 +5,11 @@ use num_traits::Zero; use super::convert::*; use super::error::*; -use super::lapack_traits::*; +use super::lapack::*; use super::layout::*; use super::types::*; -pub use super::lapack_traits::Diag; +pub use super::lapack::Diag; /// solve a triangular system with upper triangular matrix pub trait SolveTriangular diff --git a/src/types.rs b/src/types.rs index 8025ac42..d42f6e24 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,6 +1,6 @@ //! Basic types and their methods for linear algebra -pub use super::lapack_traits::Lapack; +pub use super::lapack::Lapack; pub use cauchy::Scalar; pub use num_complex::Complex32 as c32; From 8e7dea5e994e0d52841627a8fe5fe350f7d63c2d Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sat, 27 Apr 2019 20:20:56 +0900 Subject: [PATCH 08/10] Fix random_{unitary,regular} --- src/generate.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/generate.rs b/src/generate.rs index ecc1d0b4..bc2ebebf 100644 --- a/src/generate.rs +++ b/src/generate.rs @@ -39,7 +39,8 @@ where /// Be sure that this it **NOT** a uniform distribution. Use it only for test purpose. pub fn random_unitary(n: usize) -> Array2 where - A: Scalar + RandNormal, + A: Scalar + Lapack, + Standard: Distribution, { let a: Array2 = random((n, n)); let (q, _r) = a.qr_into().unwrap(); @@ -51,12 +52,13 @@ where /// Be sure that this it **NOT** a uniform distribution. Use it only for test purpose. pub fn random_regular(n: usize) -> Array2 where - A: Scalar + RandNormal, + A: Scalar + Lapack, + Standard: Distribution, { let a: Array2 = random((n, n)); let (q, mut r) = a.qr_into().unwrap(); for i in 0..n { - r[(i, i)] = A::from_f64(1.0) + AssociatedReal::inject(r[(i, i)].abs()); + r[(i, i)] = A::one() + A::from_real(r[(i, i)].abs()); } q.dot(&r) } From c12e7f163e45c4868b394baa9eeafcef05983fa5 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sat, 27 Apr 2019 20:22:51 +0900 Subject: [PATCH 09/10] Fix test/det --- tests/det.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/det.rs b/tests/det.rs index e0ce5e1f..065685bc 100644 --- a/tests/det.rs +++ b/tests/det.rs @@ -100,10 +100,9 @@ fn det_zero_nonsquare() { #[test] fn det() { - fn det_impl(a: Array2, rtol: Tol) + fn det_impl(a: Array2, rtol: A::Real) where - A: Scalar, - Tol: RealScalar, + A: Scalar + Lapack, { let det = det_naive(&a); let sign = det.div_real(det.abs()); From c870f1856cfee6f60651c719ddc0fc341aafd0ea Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Sat, 27 Apr 2019 21:11:32 +0900 Subject: [PATCH 10/10] Add short doc of Lapack trait --- src/lapack/mod.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lapack/mod.rs b/src/lapack/mod.rs index a29c4de8..c07dbfa7 100644 --- a/src/lapack/mod.rs +++ b/src/lapack/mod.rs @@ -23,6 +23,7 @@ use super::types::*; pub type Pivot = Vec; +/// Trait for primitive types which implements LAPACK subroutines pub trait Lapack: OperatorNorm_ + QR_ + SVD_ + Solve_ + Solveh_ + Cholesky_ + Eigh_ + Triangular_ {} impl Lapack for f32 {}