Skip to content

Commit a6640e2

Browse files
committed
Import assert module from ndarray-numtest
1 parent d6c8612 commit a6640e2

File tree

3 files changed

+114
-0
lines changed

3 files changed

+114
-0
lines changed

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@ openblas = ["blas/openblas", "lapack/openblas"]
1515
netlib = ["blas/netlib", "lapack/netlib"]
1616

1717
[dependencies]
18+
float-cmp = "*"
1819
num-traits = "0.1"
20+
num-complex = "0.1"
1921
ndarray = { version = "0.9", default-features = false, features = ["blas"] }
2022
lapack = { version = "0.11", default-features = false }
2123
blas = { version = "0.15", default-features = false }

src/assert.rs

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
//! Assertions for value and array
2+
3+
use ndarray::{Array, Dimension, IntoDimension};
4+
use float_cmp::ApproxEqRatio;
5+
use num_complex::Complex;
6+
7+
/// test two values are close in relative tolerance sense
8+
pub trait AssertClose: Sized + Copy {
9+
type Tol;
10+
fn assert_close(self, truth: Self, rtol: Self::Tol);
11+
}
12+
13+
macro_rules! impl_AssertClose {
14+
($scalar:ty) => {
15+
impl AssertClose for $scalar {
16+
type Tol = $scalar;
17+
fn assert_close(self, truth: Self, rtol: Self::Tol) {
18+
if !self.approx_eq_ratio(&truth, rtol) {
19+
panic!("Not close: val={}, truth={}, rtol={}", self, truth, rtol);
20+
}
21+
}
22+
}
23+
impl AssertClose for Complex<$scalar> {
24+
type Tol = $scalar;
25+
fn assert_close(self, truth: Self, rtol: Self::Tol) {
26+
if !(self.re.approx_eq_ratio(&truth.re, rtol) && self.im.approx_eq_ratio(&truth.im, rtol)) {
27+
panic!("Not close: val={}, truth={}, rtol={}", self, truth, rtol);
28+
}
29+
}
30+
}
31+
}} // impl_AssertClose
32+
impl_AssertClose!(f64);
33+
impl_AssertClose!(f32);
34+
35+
/// test two arrays are close
36+
pub trait AssertAllClose {
37+
type Tol;
38+
/// test two arrays are close in L2-norm with relative tolerance
39+
fn assert_allclose_l2(&self, truth: &Self, rtol: Self::Tol);
40+
/// test two arrays are close in inf-norm with absolute tolerance
41+
fn assert_allclose_inf(&self, truth: &Self, atol: Self::Tol);
42+
}
43+
44+
macro_rules! impl_AssertAllClose {
45+
($scalar:ty, $float:ty, $abs:ident) => {
46+
impl AssertAllClose for [$scalar]{
47+
type Tol = $float;
48+
fn assert_allclose_inf(&self, truth: &Self, atol: Self::Tol) {
49+
for (x, y) in self.iter().zip(truth.iter()) {
50+
let tol = (x - y).$abs();
51+
if tol > atol {
52+
panic!("Not close in inf-norm (atol={}): \ntest = \n{:?}\nTruth = \n{:?}",
53+
atol, self, truth);
54+
}
55+
}
56+
}
57+
fn assert_allclose_l2(&self, truth: &Self, rtol: Self::Tol) {
58+
let nrm: Self::Tol = truth.iter().map(|x| x.$abs().powi(2)).sum();
59+
let dev: Self::Tol = self.iter().zip(truth.iter()).map(|(x, y)| (x-y).$abs().powi(2)).sum();
60+
if dev / nrm > rtol.powi(2) {
61+
panic!("Not close in L2-norm (rtol={}): \ntest = \n{:?}\nTruth = \n{:?}",
62+
rtol, self, truth);
63+
}
64+
}
65+
}
66+
67+
impl AssertAllClose for Vec<$scalar> {
68+
type Tol = $float;
69+
fn assert_allclose_inf(&self, truth: &Self, atol: Self::Tol) {
70+
self.as_slice().assert_allclose_inf(&truth, atol);
71+
}
72+
fn assert_allclose_l2(&self, truth: &Self, rtol: Self::Tol) {
73+
self.as_slice().assert_allclose_l2(&truth, rtol);
74+
}
75+
}
76+
77+
impl<D: Dimension> AssertAllClose for Array<$scalar, D> {
78+
type Tol = $float;
79+
fn assert_allclose_inf(&self, truth: &Self, atol: Self::Tol) {
80+
if self.shape() != truth.shape() {
81+
panic!("Shape missmatch: self={:?}, truth={:?}", self.shape(), truth.shape());
82+
}
83+
for (idx, val) in self.indexed_iter() {
84+
let t = truth[idx.into_dimension()];
85+
let tol = (*val - t).$abs();
86+
if tol > atol {
87+
panic!("Not close in inf-norm (atol={}): \ntest = \n{:?}\nTruth = \n{:?}",
88+
atol, self, truth);
89+
}
90+
}
91+
}
92+
fn assert_allclose_l2(&self, truth: &Self, rtol: Self::Tol) {
93+
if self.shape() != truth.shape() {
94+
panic!("Shape missmatch: self={:?}, truth={:?}", self.shape(), truth.shape());
95+
}
96+
let nrm: Self::Tol = truth.iter().map(|x| x.$abs().powi(2)).sum();
97+
let dev: Self::Tol = self.indexed_iter().map(|(idx, val)| (truth[idx.into_dimension()] - val).$abs().powi(2)).sum();
98+
if dev / nrm > rtol.powi(2) {
99+
panic!("Not close in L2-norm (rtol={}): \ntest = \n{:?}\nTruth = \n{:?}",
100+
rtol, self, truth);
101+
}
102+
}
103+
}
104+
}} // impl_AssertAllClose
105+
106+
impl_AssertAllClose!(f64, f64, abs);
107+
impl_AssertAllClose!(f32, f32, abs);
108+
impl_AssertAllClose!(Complex<f64>, f64, norm);
109+
impl_AssertAllClose!(Complex<f32>, f32, norm);

src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
extern crate blas;
3535
extern crate lapack;
3636
extern crate num_traits;
37+
extern crate num_complex;
38+
extern crate float_cmp;
3739
#[macro_use(s)]
3840
extern crate ndarray;
3941

@@ -47,4 +49,5 @@ pub mod hermite;
4749
pub mod triangular;
4850

4951
pub mod util;
52+
pub mod assert;
5053
pub mod prelude;

0 commit comments

Comments
 (0)