Skip to content

Commit 00f2f83

Browse files
committed
Split Field/RealField from #45
1 parent a5ee745 commit 00f2f83

File tree

4 files changed

+63
-23
lines changed

4 files changed

+63
-23
lines changed

src/assert.rs

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,30 @@
11
//! Assertions for array
22
3-
use std::iter::Sum;
4-
use num_traits::Float;
53
use ndarray::*;
64

75
use super::types::*;
86
use super::vector::*;
97

108
pub fn rclose<A, Tol>(test: A, truth: A, rtol: Tol) -> Result<Tol, Tol>
11-
where A: LinalgScalar + Absolute<Output = Tol>,
12-
Tol: Float
9+
where A: Field + Absolute<Output = Tol>,
10+
Tol: RealField
1311
{
1412
let dev = (test - truth).abs() / truth.abs();
1513
if dev < rtol { Ok(dev) } else { Err(dev) }
1614
}
1715

1816
pub fn aclose<A, Tol>(test: A, truth: A, atol: Tol) -> Result<Tol, Tol>
19-
where A: LinalgScalar + Absolute<Output = Tol>,
20-
Tol: Float
17+
where A: Field + Absolute<Output = Tol>,
18+
Tol: RealField
2119
{
2220
let dev = (test - truth).abs();
2321
if dev < atol { Ok(dev) } else { Err(dev) }
2422
}
2523

2624
/// check two arrays are close in maximum norm
2725
pub fn close_max<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2, D>, atol: Tol) -> Result<Tol, Tol>
28-
where A: LinalgScalar + Absolute<Output = Tol>,
29-
Tol: Float + Sum,
26+
where A: Field + Absolute<Output = Tol>,
27+
Tol: RealField,
3028
S1: Data<Elem = A>,
3129
S2: Data<Elem = A>,
3230
D: Dimension
@@ -37,8 +35,8 @@ pub fn close_max<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S
3735

3836
/// check two arrays are close in L1 norm
3937
pub fn close_l1<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2, D>, rtol: Tol) -> Result<Tol, Tol>
40-
where A: LinalgScalar + Absolute<Output = Tol>,
41-
Tol: Float + Sum,
38+
where A: Field + Absolute<Output = Tol>,
39+
Tol: RealField,
4240
S1: Data<Elem = A>,
4341
S2: Data<Elem = A>,
4442
D: Dimension
@@ -49,8 +47,8 @@ pub fn close_l1<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2
4947

5048
/// check two arrays are close in L2 norm
5149
pub fn close_l2<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2, D>, rtol: Tol) -> Result<Tol, Tol>
52-
where A: LinalgScalar + Absolute<Output = Tol>,
53-
Tol: Float + Sum,
50+
where A: Field + Absolute<Output = Tol>,
51+
Tol: RealField,
5452
S1: Data<Elem = A>,
5553
S2: Data<Elem = A>,
5654
D: Dimension

src/types.rs

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,39 @@
11

2-
pub use num_complex::Complex32 as c32;
3-
pub use num_complex::Complex64 as c64;
4-
use num_complex::Complex;
5-
use num_traits::Float;
62
use std::ops::*;
3+
use std::fmt::Debug;
4+
use std::iter::Sum;
5+
use num_complex::Complex;
6+
use num_traits::*;
77
use rand::Rng;
88
use rand::distributions::*;
9+
use ndarray::LinalgScalar;
10+
11+
use super::impl2::LapackScalar;
12+
13+
pub use num_complex::Complex32 as c32;
14+
pub use num_complex::Complex64 as c64;
15+
16+
macro_rules! trait_alias {
17+
($name:ident: $($t:ident),*) => {
18+
19+
pub trait $name : $($t +)* {}
20+
21+
impl<T> $name for T where T: $($t +)* {}
22+
23+
}} // trait_alias!
24+
25+
trait_alias!(Field: LapackScalar,
26+
LinalgScalar,
27+
AssociatedReal,
28+
AssociatedComplex,
29+
Absolute,
30+
SquareRoot,
31+
Conjugate,
32+
RandNormal,
33+
Sum,
34+
Debug);
35+
36+
trait_alias!(RealField: Field, Float);
937

1038
pub trait AssociatedReal: Sized {
1139
type Real: Float + Mul<Self, Output = Self>;
@@ -16,13 +44,17 @@ pub trait AssociatedComplex: Sized {
1644

1745
/// Field with norm
1846
pub trait Absolute {
19-
type Output: Float;
47+
type Output: RealField;
2048
fn squared(&self) -> Self::Output;
2149
fn abs(&self) -> Self::Output {
2250
self.squared().sqrt()
2351
}
2452
}
2553

54+
pub trait SquareRoot {
55+
fn sqrt(&self) -> Self;
56+
}
57+
2658
pub trait Conjugate: Copy {
2759
fn conj(self) -> Self;
2860
}
@@ -70,6 +102,18 @@ impl Absolute for $complex {
70102
}
71103
}
72104

105+
impl SquareRoot for $real {
106+
fn sqrt(&self) -> Self {
107+
Float::sqrt(*self)
108+
}
109+
}
110+
111+
impl SquareRoot for $complex {
112+
fn sqrt(&self) -> Self {
113+
Complex::sqrt(self)
114+
}
115+
}
116+
73117
impl Conjugate for $real {
74118
fn conj(self) -> Self {
75119
self

src/util.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
//! misc utilities
22
3-
use std::iter::Sum;
43
use ndarray::*;
54
use num_traits::Float;
65
use std::ops::Div;
@@ -15,9 +14,9 @@ pub enum NormalizeAxis {
1514

1615
/// normalize in L2 norm
1716
pub fn normalize<A, S, T>(mut m: ArrayBase<S, Ix2>, axis: NormalizeAxis) -> (ArrayBase<S, Ix2>, Vec<T>)
18-
where A: LinalgScalar + Absolute<Output = T> + Div<T, Output = A>,
17+
where A: Field + Absolute<Output = T> + Div<T, Output = A>,
1918
S: DataMut<Elem = A>,
20-
T: Float + Sum
19+
T: Field + Float
2120
{
2221
let mut ms = Vec::new();
2322
for mut v in m.axis_iter_mut(Axis(axis as usize)) {

src/vector.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
//! Define trait for vectors
22
3-
use std::iter::Sum;
43
use ndarray::*;
54
use num_traits::Float;
65

@@ -24,8 +23,8 @@ pub trait Norm {
2423
}
2524

2625
impl<A, S, D, T> Norm for ArrayBase<S, D>
27-
where A: LinalgScalar + Absolute<Output = T>,
28-
T: Float + Sum,
26+
where A: Field + Absolute<Output = T>,
27+
T: Field + Float,
2928
S: Data<Elem = A>,
3029
D: Dimension
3130
{

0 commit comments

Comments
 (0)