Skip to content

Commit b7c0ecf

Browse files
authored
Merge pull request #49 from termoshtt/field
Field trait
2 parents a5ee745 + fbfef26 commit b7c0ecf

File tree

7 files changed

+93
-57
lines changed

7 files changed

+93
-57
lines changed

src/assert.rs

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,30 @@
11
//! Assertions for array
22
3-
use std::iter::Sum;
4-
use num_traits::Float;
53
use ndarray::*;
64

75
use super::types::*;
8-
use super::vector::*;
6+
use super::norm::*;
97

108
pub fn rclose<A, Tol>(test: A, truth: A, rtol: Tol) -> Result<Tol, Tol>
11-
where A: LinalgScalar + Absolute<Output = Tol>,
12-
Tol: Float
9+
where A: Field + Absolute<Output = Tol>,
10+
Tol: RealField
1311
{
1412
let dev = (test - truth).abs() / truth.abs();
1513
if dev < rtol { Ok(dev) } else { Err(dev) }
1614
}
1715

1816
pub fn aclose<A, Tol>(test: A, truth: A, atol: Tol) -> Result<Tol, Tol>
19-
where A: LinalgScalar + Absolute<Output = Tol>,
20-
Tol: Float
17+
where A: Field + Absolute<Output = Tol>,
18+
Tol: RealField
2119
{
2220
let dev = (test - truth).abs();
2321
if dev < atol { Ok(dev) } else { Err(dev) }
2422
}
2523

2624
/// check two arrays are close in maximum norm
2725
pub fn close_max<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2, D>, atol: Tol) -> Result<Tol, Tol>
28-
where A: LinalgScalar + Absolute<Output = Tol>,
29-
Tol: Float + Sum,
26+
where A: Field + Absolute<Output = Tol>,
27+
Tol: RealField,
3028
S1: Data<Elem = A>,
3129
S2: Data<Elem = A>,
3230
D: Dimension
@@ -37,8 +35,8 @@ pub fn close_max<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S
3735

3836
/// check two arrays are close in L1 norm
3937
pub fn close_l1<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2, D>, rtol: Tol) -> Result<Tol, Tol>
40-
where A: LinalgScalar + Absolute<Output = Tol>,
41-
Tol: Float + Sum,
38+
where A: Field + Absolute<Output = Tol>,
39+
Tol: RealField,
4240
S1: Data<Elem = A>,
4341
S2: Data<Elem = A>,
4442
D: Dimension
@@ -49,8 +47,8 @@ pub fn close_l1<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2
4947

5048
/// check two arrays are close in L2 norm
5149
pub fn close_l2<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2, D>, rtol: Tol) -> Result<Tol, Tol>
52-
where A: LinalgScalar + Absolute<Output = Tol>,
53-
Tol: Float + Sum,
50+
where A: Field + Absolute<Output = Tol>,
51+
Tol: RealField,
5452
S1: Data<Elem = A>,
5553
S2: Data<Elem = A>,
5654
D: Dimension

src/impl2/mod.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,12 @@ pub use self::eigh::*;
1515

1616
use super::error::*;
1717

18-
pub trait LapackScalar: OperatorNorm_ + QR_ + SVD_ + Solve_ + Cholesky_ + Eigh_ {}
19-
impl<A> LapackScalar for A where A: OperatorNorm_ + QR_ + SVD_ + Solve_ + Cholesky_ + Eigh_ {}
18+
trait_alias!(LapackScalar: OperatorNorm_,
19+
QR_,
20+
SVD_,
21+
Solve_,
22+
Cholesky_,
23+
Eigh_);
2024

2125
pub fn into_result<T>(info: i32, val: T) -> Result<T> {
2226
if info == 0 {

src/lib.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ extern crate enum_error_derive;
4343
#[macro_use]
4444
extern crate derive_new;
4545

46+
#[macro_use]
4647
pub mod types;
4748
pub mod error;
4849
pub mod layout;
@@ -56,12 +57,12 @@ pub mod solve;
5657
pub mod cholesky;
5758
pub mod eigh;
5859

59-
pub mod vector;
6060
pub mod matrix;
6161
pub mod square;
6262
pub mod triangular;
6363

64-
pub mod util;
6564
pub mod generate;
6665
pub mod assert;
66+
pub mod norm;
67+
6768
pub mod prelude;

src/vector.rs renamed to src/norm.rs

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
//! Define trait for vectors
22
3-
use std::iter::Sum;
3+
use std::ops::*;
44
use ndarray::*;
5-
use num_traits::Float;
65

76
use super::types::*;
87

@@ -24,8 +23,8 @@ pub trait Norm {
2423
}
2524

2625
impl<A, S, D, T> Norm for ArrayBase<S, D>
27-
where A: LinalgScalar + Absolute<Output = T>,
28-
T: Float + Sum,
26+
where A: Field + Absolute<Output = T>,
27+
T: RealField,
2928
S: Data<Elem = A>,
3029
D: Dimension
3130
{
@@ -43,3 +42,23 @@ impl<A, S, D, T> Norm for ArrayBase<S, D>
4342
})
4443
}
4544
}
45+
46+
pub enum NormalizeAxis {
47+
Row = 0,
48+
Column = 1,
49+
}
50+
51+
/// normalize in L2 norm
52+
pub fn normalize<A, S, T>(mut m: ArrayBase<S, Ix2>, axis: NormalizeAxis) -> (ArrayBase<S, Ix2>, Vec<T>)
53+
where A: Field + Absolute<Output = T> + Div<T, Output = A>,
54+
S: DataMut<Elem = A>,
55+
T: RealField
56+
{
57+
let mut ms = Vec::new();
58+
for mut v in m.axis_iter_mut(Axis(axis as usize)) {
59+
let n = v.norm();
60+
ms.push(n);
61+
v.map_inplace(|x| *x = *x / n)
62+
}
63+
(m, ms)
64+
}

src/prelude.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
pub use vector::Norm;
21
pub use matrix::Matrix;
32
pub use square::SquareMatrix;
43
pub use triangular::*;
5-
pub use util::*;
4+
pub use norm::*;
65
pub use types::*;
76
pub use generate::*;
87
pub use assert::*;

src/types.rs

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,39 @@
11

2-
pub use num_complex::Complex32 as c32;
3-
pub use num_complex::Complex64 as c64;
4-
use num_complex::Complex;
5-
use num_traits::Float;
62
use std::ops::*;
3+
use std::fmt::Debug;
4+
use std::iter::Sum;
5+
use num_complex::Complex;
6+
use num_traits::*;
77
use rand::Rng;
88
use rand::distributions::*;
9+
use ndarray::LinalgScalar;
10+
11+
use super::impl2::LapackScalar;
12+
13+
pub use num_complex::Complex32 as c32;
14+
pub use num_complex::Complex64 as c64;
15+
16+
macro_rules! trait_alias {
17+
($name:ident: $($t:ident),*) => {
18+
19+
pub trait $name : $($t +)* {}
20+
21+
impl<T> $name for T where T: $($t +)* {}
22+
23+
}} // trait_alias!
24+
25+
trait_alias!(Field: LapackScalar,
26+
LinalgScalar,
27+
AssociatedReal,
28+
AssociatedComplex,
29+
Absolute,
30+
SquareRoot,
31+
Conjugate,
32+
RandNormal,
33+
Sum,
34+
Debug);
35+
36+
trait_alias!(RealField: Field, Float);
937

1038
pub trait AssociatedReal: Sized {
1139
type Real: Float + Mul<Self, Output = Self>;
@@ -16,13 +44,17 @@ pub trait AssociatedComplex: Sized {
1644

1745
/// Field with norm
1846
pub trait Absolute {
19-
type Output: Float;
47+
type Output: RealField;
2048
fn squared(&self) -> Self::Output;
2149
fn abs(&self) -> Self::Output {
2250
self.squared().sqrt()
2351
}
2452
}
2553

54+
pub trait SquareRoot {
55+
fn sqrt(&self) -> Self;
56+
}
57+
2658
pub trait Conjugate: Copy {
2759
fn conj(self) -> Self;
2860
}
@@ -70,6 +102,18 @@ impl Absolute for $complex {
70102
}
71103
}
72104

105+
impl SquareRoot for $real {
106+
fn sqrt(&self) -> Self {
107+
Float::sqrt(*self)
108+
}
109+
}
110+
111+
impl SquareRoot for $complex {
112+
fn sqrt(&self) -> Self {
113+
Complex::sqrt(self)
114+
}
115+
}
116+
73117
impl Conjugate for $real {
74118
fn conj(self) -> Self {
75119
self

src/util.rs

Lines changed: 0 additions & 29 deletions
This file was deleted.

0 commit comments

Comments
 (0)