Skip to content

Commit c016bc5

Browse files
committed
Move asserts in util into assert submodule
1 parent 0e32c8f commit c016bc5

File tree

3 files changed

+86
-138
lines changed

3 files changed

+86
-138
lines changed

src/assert.rs

Lines changed: 73 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,109 +1,96 @@
1-
//! Assertions for value and array
1+
//! Assertions for array
22
3-
use ndarray::{Array, Dimension, IntoDimension};
4-
use float_cmp::ApproxEqRatio;
5-
use num_complex::Complex;
3+
use std::iter::Sum;
4+
use num_traits::Float;
5+
use ndarray::*;
66

7-
/// test two values are close in relative tolerance sense
8-
pub trait AssertClose: Sized + Copy {
9-
type Tol;
10-
fn assert_close(self, truth: Self, rtol: Self::Tol);
7+
use super::vector::*;
8+
9+
pub trait Close: Absolute {
10+
fn rclose(self, truth: Self, relative_tol: Self::Output) -> Result<Self::Output, Self::Output>;
11+
fn aclose(self, truth: Self, absolute_tol: Self::Output) -> Result<Self::Output, Self::Output>;
1112
}
1213

1314
macro_rules! impl_AssertClose {
1415
($scalar:ty) => {
15-
impl AssertClose for $scalar {
16-
type Tol = $scalar;
17-
fn assert_close(self, truth: Self, rtol: Self::Tol) {
18-
if !self.approx_eq_ratio(&truth, rtol) {
19-
panic!("Not close: val={}, truth={}, rtol={}", self, truth, rtol);
16+
impl Close for $scalar {
17+
fn rclose(self, truth: Self, rtol: Self::Output) -> Result<Self::Output, Self::Output> {
18+
let dev = (self - truth).abs() / truth.abs();
19+
if dev < rtol {
20+
Ok(dev)
21+
} else {
22+
Err(dev)
2023
}
2124
}
22-
}
23-
impl AssertClose for Complex<$scalar> {
24-
type Tol = $scalar;
25-
fn assert_close(self, truth: Self, rtol: Self::Tol) {
26-
if !(self.re.approx_eq_ratio(&truth.re, rtol) && self.im.approx_eq_ratio(&truth.im, rtol)) {
27-
panic!("Not close: val={}, truth={}, rtol={}", self, truth, rtol);
25+
26+
fn aclose(self, truth: Self, atol: Self::Output) -> Result<Self::Output, Self::Output> {
27+
let dev = (self - truth).abs();
28+
if dev < atol {
29+
Ok(dev)
30+
} else {
31+
Err(dev)
2832
}
2933
}
3034
}
3135
}} // impl_AssertClose
3236
impl_AssertClose!(f64);
3337
impl_AssertClose!(f32);
3438

35-
/// test two arrays are close
36-
pub trait AssertAllClose {
37-
type Tol;
38-
/// test two arrays are close in L2-norm with relative tolerance
39-
fn assert_allclose_l2(&self, truth: &Self, rtol: Self::Tol);
40-
/// test two arrays are close in inf-norm with absolute tolerance
41-
fn assert_allclose_inf(&self, truth: &Self, atol: Self::Tol);
39+
#[macro_export]
40+
macro_rules! assert_rclose {
41+
($test:expr, $truth:expr, $tol:expr) => {
42+
$test.rclose($truth, $tol).unwrap();
43+
};
44+
($test:expr, $truth:expr, $tol:expr; $comment:expr) => {
45+
$test.rclose($truth, $tol).expect($comment);
46+
};
4247
}
4348

44-
macro_rules! impl_AssertAllClose {
45-
($scalar:ty, $float:ty, $abs:ident) => {
46-
impl AssertAllClose for [$scalar]{
47-
type Tol = $float;
48-
fn assert_allclose_inf(&self, truth: &Self, atol: Self::Tol) {
49-
for (x, y) in self.iter().zip(truth.iter()) {
50-
let tol = (x - y).$abs();
51-
if tol > atol {
52-
panic!("Not close in inf-norm (atol={}): \ntest = \n{:?}\nTruth = \n{:?}",
53-
atol, self, truth);
54-
}
55-
}
56-
}
57-
fn assert_allclose_l2(&self, truth: &Self, rtol: Self::Tol) {
58-
let nrm: Self::Tol = truth.iter().map(|x| x.$abs().powi(2)).sum();
59-
let dev: Self::Tol = self.iter().zip(truth.iter()).map(|(x, y)| (x-y).$abs().powi(2)).sum();
60-
if dev / nrm > rtol.powi(2) {
61-
panic!("Not close in L2-norm (rtol={}): \ntest = \n{:?}\nTruth = \n{:?}",
62-
rtol, self, truth);
63-
}
64-
}
49+
#[macro_export]
50+
macro_rules! assert_aclose {
51+
($test:expr, $truth:expr, $tol:expr) => {
52+
$test.aclose($truth, $tol).unwrap();
53+
};
54+
($test:expr, $truth:expr, $tol:expr; $comment:expr) => {
55+
$test.aclose($truth, $tol).expect($comment);
56+
};
6557
}
6658

67-
impl AssertAllClose for Vec<$scalar> {
68-
type Tol = $float;
69-
fn assert_allclose_inf(&self, truth: &Self, atol: Self::Tol) {
70-
self.as_slice().assert_allclose_inf(&truth, atol);
71-
}
72-
fn assert_allclose_l2(&self, truth: &Self, rtol: Self::Tol) {
73-
self.as_slice().assert_allclose_l2(&truth, rtol);
74-
}
59+
/// check two arrays are close in maximum norm
60+
pub fn all_close_max<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>,
61+
truth: &ArrayBase<S2, D>,
62+
atol: Tol)
63+
-> Result<Tol, Tol>
64+
where A: LinalgScalar + Absolute<Output = Tol>,
65+
Tol: Float + Sum,
66+
S1: Data<Elem = A>,
67+
S2: Data<Elem = A>,
68+
D: Dimension
69+
{
70+
let tol = (test - truth).norm_max();
71+
if tol < atol { Ok(tol) } else { Err(tol) }
7572
}
7673

77-
impl<D: Dimension> AssertAllClose for Array<$scalar, D> {
78-
type Tol = $float;
79-
fn assert_allclose_inf(&self, truth: &Self, atol: Self::Tol) {
80-
if self.shape() != truth.shape() {
81-
panic!("Shape missmatch: self={:?}, truth={:?}", self.shape(), truth.shape());
82-
}
83-
for (idx, val) in self.indexed_iter() {
84-
let t = truth[idx.into_dimension()];
85-
let tol = (*val - t).$abs();
86-
if tol > atol {
87-
panic!("Not close in inf-norm (atol={}): \ntest = \n{:?}\nTruth = \n{:?}",
88-
atol, self, truth);
89-
}
90-
}
91-
}
92-
fn assert_allclose_l2(&self, truth: &Self, rtol: Self::Tol) {
93-
if self.shape() != truth.shape() {
94-
panic!("Shape missmatch: self={:?}, truth={:?}", self.shape(), truth.shape());
95-
}
96-
let nrm: Self::Tol = truth.iter().map(|x| x.$abs().powi(2)).sum();
97-
let dev: Self::Tol = self.indexed_iter().map(|(idx, val)| (truth[idx.into_dimension()] - val).$abs().powi(2)).sum();
98-
if dev / nrm > rtol.powi(2) {
99-
panic!("Not close in L2-norm (rtol={}): \ntest = \n{:?}\nTruth = \n{:?}",
100-
rtol, self, truth);
101-
}
102-
}
74+
/// check two arrays are close in L1 norm
75+
pub fn all_close_l1<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2, D>, rtol: Tol) -> Result<Tol, Tol>
76+
where A: LinalgScalar + Absolute<Output = Tol>,
77+
Tol: Float + Sum,
78+
S1: Data<Elem = A>,
79+
S2: Data<Elem = A>,
80+
D: Dimension
81+
{
82+
let tol = (test - truth).norm_l1() / truth.norm_l1();
83+
if tol < rtol { Ok(tol) } else { Err(tol) }
10384
}
104-
}} // impl_AssertAllClose
10585

106-
impl_AssertAllClose!(f64, f64, abs);
107-
impl_AssertAllClose!(f32, f32, abs);
108-
impl_AssertAllClose!(Complex<f64>, f64, norm);
109-
impl_AssertAllClose!(Complex<f32>, f32, norm);
86+
/// check two arrays are close in L2 norm
87+
pub fn all_close_l2<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2, D>, rtol: Tol) -> Result<Tol, Tol>
88+
where A: LinalgScalar + Absolute<Output = Tol>,
89+
Tol: Float + Sum,
90+
S1: Data<Elem = A>,
91+
S2: Data<Elem = A>,
92+
D: Dimension
93+
{
94+
let tol = (test - truth).norm_l2() / truth.norm_l2();
95+
if tol < rtol { Ok(tol) } else { Err(tol) }
96+
}

src/util.rs

Lines changed: 1 addition & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ pub enum NormalizeAxis {
5353

5454
/// normalize in L2 norm
5555
pub fn normalize<A, S, T>(mut m: ArrayBase<S, Ix2>, axis: NormalizeAxis) -> (ArrayBase<S, Ix2>, Vec<T>)
56-
where A: LinalgScalar + NormedField<Output = T> + Div<T, Output = A>,
56+
where A: LinalgScalar + Absolute<Output = T> + Div<T, Output = A>,
5757
S: DataMut<Elem = A>,
5858
T: Float + Sum
5959
{
@@ -65,42 +65,3 @@ pub fn normalize<A, S, T>(mut m: ArrayBase<S, Ix2>, axis: NormalizeAxis) -> (Arr
6565
}
6666
(m, ms)
6767
}
68-
69-
/// check two arrays are close in maximum norm
70-
pub fn all_close_max<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>,
71-
truth: &ArrayBase<S2, D>,
72-
atol: Tol)
73-
-> Result<Tol, Tol>
74-
where A: LinalgScalar + NormedField<Output = Tol>,
75-
Tol: Float + Sum,
76-
S1: Data<Elem = A>,
77-
S2: Data<Elem = A>,
78-
D: Dimension
79-
{
80-
let tol = (test - truth).norm_max();
81-
if tol < atol { Ok(tol) } else { Err(tol) }
82-
}
83-
84-
/// check two arrays are close in L1 norm
85-
pub fn all_close_l1<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2, D>, rtol: Tol) -> Result<Tol, Tol>
86-
where A: LinalgScalar + NormedField<Output = Tol>,
87-
Tol: Float + Sum,
88-
S1: Data<Elem = A>,
89-
S2: Data<Elem = A>,
90-
D: Dimension
91-
{
92-
let tol = (test - truth).norm_l1() / truth.norm_l1();
93-
if tol < rtol { Ok(tol) } else { Err(tol) }
94-
}
95-
96-
/// check two arrays are close in L2 norm
97-
pub fn all_close_l2<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2, D>, rtol: Tol) -> Result<Tol, Tol>
98-
where A: LinalgScalar + NormedField<Output = Tol>,
99-
Tol: Float + Sum,
100-
S1: Data<Elem = A>,
101-
S2: Data<Elem = A>,
102-
D: Dimension
103-
{
104-
let tol = (test - truth).norm_l2() / truth.norm_l2();
105-
if tol < rtol { Ok(tol) } else { Err(tol) }
106-
}

src/vector.rs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,42 +21,42 @@ pub trait Norm {
2121
}
2222

2323
impl<A, S, D, T> Norm for ArrayBase<S, D>
24-
where A: LinalgScalar + NormedField<Output = T>,
24+
where A: LinalgScalar + Absolute<Output = T>,
2525
T: Float + Sum,
2626
S: Data<Elem = A>,
2727
D: Dimension
2828
{
2929
type Output = T;
3030
fn norm_l1(&self) -> Self::Output {
31-
self.iter().map(|x| x.norm()).sum()
31+
self.iter().map(|x| x.abs()).sum()
3232
}
3333
fn norm_l2(&self) -> Self::Output {
34-
self.iter().map(|x| x.squared()).sum::<T>().sqrt()
34+
self.iter().map(|x| x.sq_abs()).sum::<T>().sqrt()
3535
}
3636
fn norm_max(&self) -> Self::Output {
3737
self.iter().fold(T::zero(), |f, &val| {
38-
let v = val.norm();
38+
let v = val.abs();
3939
if f > v { f } else { v }
4040
})
4141
}
4242
}
4343

4444
/// Field with norm
45-
pub trait NormedField {
45+
pub trait Absolute {
4646
type Output: Float;
47-
fn squared(&self) -> Self::Output;
48-
fn norm(&self) -> Self::Output {
49-
self.squared().sqrt()
47+
fn sq_abs(&self) -> Self::Output;
48+
fn abs(&self) -> Self::Output {
49+
self.sq_abs().sqrt()
5050
}
5151
}
5252

53-
impl<A: Float> NormedField for A {
53+
impl<A: Float> Absolute for A {
5454
type Output = A;
55-
fn squared(&self) -> A {
55+
fn sq_abs(&self) -> A {
5656
*self * *self
5757
}
58-
fn norm(&self) -> A {
59-
self.abs()
58+
fn abs(&self) -> A {
59+
Float::abs(*self)
6060
}
6161
}
6262

0 commit comments

Comments
 (0)