diff --git a/src/assert.rs b/src/assert.rs index 880dba1b..da9d6ffa 100644 --- a/src/assert.rs +++ b/src/assert.rs @@ -1,23 +1,21 @@ //! Assertions for array -use std::iter::Sum; -use num_traits::Float; use ndarray::*; use super::types::*; -use super::vector::*; +use super::norm::*; pub fn rclose(test: A, truth: A, rtol: Tol) -> Result - where A: LinalgScalar + Absolute, - Tol: Float + where A: Field + Absolute, + Tol: RealField { let dev = (test - truth).abs() / truth.abs(); if dev < rtol { Ok(dev) } else { Err(dev) } } pub fn aclose(test: A, truth: A, atol: Tol) -> Result - where A: LinalgScalar + Absolute, - Tol: Float + where A: Field + Absolute, + Tol: RealField { let dev = (test - truth).abs(); if dev < atol { Ok(dev) } else { Err(dev) } @@ -25,8 +23,8 @@ pub fn aclose(test: A, truth: A, atol: Tol) -> Result /// check two arrays are close in maximum norm pub fn close_max(test: &ArrayBase, truth: &ArrayBase, atol: Tol) -> Result - where A: LinalgScalar + Absolute, - Tol: Float + Sum, + where A: Field + Absolute, + Tol: RealField, S1: Data, S2: Data, D: Dimension @@ -37,8 +35,8 @@ pub fn close_max(test: &ArrayBase, truth: &ArrayBase(test: &ArrayBase, truth: &ArrayBase, rtol: Tol) -> Result - where A: LinalgScalar + Absolute, - Tol: Float + Sum, + where A: Field + Absolute, + Tol: RealField, S1: Data, S2: Data, D: Dimension @@ -49,8 +47,8 @@ pub fn close_l1(test: &ArrayBase, truth: &ArrayBase(test: &ArrayBase, truth: &ArrayBase, rtol: Tol) -> Result - where A: LinalgScalar + Absolute, - Tol: Float + Sum, + where A: Field + Absolute, + Tol: RealField, S1: Data, S2: Data, D: Dimension diff --git a/src/impl2/mod.rs b/src/impl2/mod.rs index 91b0bcf0..fd5dba5e 100644 --- a/src/impl2/mod.rs +++ b/src/impl2/mod.rs @@ -15,8 +15,12 @@ pub use self::eigh::*; use super::error::*; -pub trait LapackScalar: OperatorNorm_ + QR_ + SVD_ + Solve_ + Cholesky_ + Eigh_ {} -impl LapackScalar for A where A: OperatorNorm_ + QR_ + SVD_ + Solve_ + Cholesky_ + Eigh_ {} +trait_alias!(LapackScalar: OperatorNorm_, + QR_, + SVD_, + Solve_, + Cholesky_, + Eigh_); pub fn into_result(info: i32, val: T) -> Result { if info == 0 { diff --git a/src/lib.rs b/src/lib.rs index b45c369b..f654d8a9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -43,6 +43,7 @@ extern crate enum_error_derive; #[macro_use] extern crate derive_new; +#[macro_use] pub mod types; pub mod error; pub mod layout; @@ -56,12 +57,12 @@ pub mod solve; pub mod cholesky; pub mod eigh; -pub mod vector; pub mod matrix; pub mod square; pub mod triangular; -pub mod util; pub mod generate; pub mod assert; +pub mod norm; + pub mod prelude; diff --git a/src/vector.rs b/src/norm.rs similarity index 62% rename from src/vector.rs rename to src/norm.rs index e770e7ed..63877b07 100644 --- a/src/vector.rs +++ b/src/norm.rs @@ -1,8 +1,7 @@ //! Define trait for vectors -use std::iter::Sum; +use std::ops::*; use ndarray::*; -use num_traits::Float; use super::types::*; @@ -24,8 +23,8 @@ pub trait Norm { } impl Norm for ArrayBase - where A: LinalgScalar + Absolute, - T: Float + Sum, + where A: Field + Absolute, + T: RealField, S: Data, D: Dimension { @@ -43,3 +42,23 @@ impl Norm for ArrayBase }) } } + +pub enum NormalizeAxis { + Row = 0, + Column = 1, +} + +/// normalize in L2 norm +pub fn normalize(mut m: ArrayBase, axis: NormalizeAxis) -> (ArrayBase, Vec) + where A: Field + Absolute + Div, + S: DataMut, + T: RealField +{ + let mut ms = Vec::new(); + for mut v in m.axis_iter_mut(Axis(axis as usize)) { + let n = v.norm(); + ms.push(n); + v.map_inplace(|x| *x = *x / n) + } + (m, ms) +} diff --git a/src/prelude.rs b/src/prelude.rs index d3f0da73..f34cb921 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -1,8 +1,7 @@ -pub use vector::Norm; pub use matrix::Matrix; pub use square::SquareMatrix; pub use triangular::*; -pub use util::*; +pub use norm::*; pub use types::*; pub use generate::*; pub use assert::*; diff --git a/src/types.rs b/src/types.rs index d8118b69..d2ce0759 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,11 +1,39 @@ -pub use num_complex::Complex32 as c32; -pub use num_complex::Complex64 as c64; -use num_complex::Complex; -use num_traits::Float; use std::ops::*; +use std::fmt::Debug; +use std::iter::Sum; +use num_complex::Complex; +use num_traits::*; use rand::Rng; use rand::distributions::*; +use ndarray::LinalgScalar; + +use super::impl2::LapackScalar; + +pub use num_complex::Complex32 as c32; +pub use num_complex::Complex64 as c64; + +macro_rules! trait_alias { + ($name:ident: $($t:ident),*) => { + +pub trait $name : $($t +)* {} + +impl $name for T where T: $($t +)* {} + +}} // trait_alias! + +trait_alias!(Field: LapackScalar, + LinalgScalar, + AssociatedReal, + AssociatedComplex, + Absolute, + SquareRoot, + Conjugate, + RandNormal, + Sum, + Debug); + +trait_alias!(RealField: Field, Float); pub trait AssociatedReal: Sized { type Real: Float + Mul; @@ -16,13 +44,17 @@ pub trait AssociatedComplex: Sized { /// Field with norm pub trait Absolute { - type Output: Float; + type Output: RealField; fn squared(&self) -> Self::Output; fn abs(&self) -> Self::Output { self.squared().sqrt() } } +pub trait SquareRoot { + fn sqrt(&self) -> Self; +} + pub trait Conjugate: Copy { fn conj(self) -> Self; } @@ -70,6 +102,18 @@ impl Absolute for $complex { } } +impl SquareRoot for $real { + fn sqrt(&self) -> Self { + Float::sqrt(*self) + } +} + +impl SquareRoot for $complex { + fn sqrt(&self) -> Self { + Complex::sqrt(self) + } +} + impl Conjugate for $real { fn conj(self) -> Self { self diff --git a/src/util.rs b/src/util.rs deleted file mode 100644 index 03cb9482..00000000 --- a/src/util.rs +++ /dev/null @@ -1,29 +0,0 @@ -//! misc utilities - -use std::iter::Sum; -use ndarray::*; -use num_traits::Float; -use std::ops::Div; - -use super::types::*; -use super::vector::*; - -pub enum NormalizeAxis { - Row = 0, - Column = 1, -} - -/// normalize in L2 norm -pub fn normalize(mut m: ArrayBase, axis: NormalizeAxis) -> (ArrayBase, Vec) - where A: LinalgScalar + Absolute + Div, - S: DataMut, - T: Float + Sum -{ - let mut ms = Vec::new(); - for mut v in m.axis_iter_mut(Axis(axis as usize)) { - let n = v.norm(); - ms.push(n); - v.map_inplace(|x| *x = *x / n) - } - (m, ms) -}