|
| 1 | +//! Assertions for value and array |
| 2 | +
|
| 3 | +use ndarray::{Array, Dimension, IntoDimension}; |
| 4 | +use float_cmp::ApproxEqRatio; |
| 5 | +use num_complex::Complex; |
| 6 | + |
| 7 | +/// test two values are close in relative tolerance sense |
| 8 | +pub trait AssertClose: Sized + Copy { |
| 9 | + type Tol; |
| 10 | + fn assert_close(self, truth: Self, rtol: Self::Tol); |
| 11 | +} |
| 12 | + |
| 13 | +macro_rules! impl_AssertClose { |
| 14 | + ($scalar:ty) => { |
| 15 | +impl AssertClose for $scalar { |
| 16 | + type Tol = $scalar; |
| 17 | + fn assert_close(self, truth: Self, rtol: Self::Tol) { |
| 18 | + if !self.approx_eq_ratio(&truth, rtol) { |
| 19 | + panic!("Not close: val={}, truth={}, rtol={}", self, truth, rtol); |
| 20 | + } |
| 21 | + } |
| 22 | +} |
| 23 | +impl AssertClose for Complex<$scalar> { |
| 24 | + type Tol = $scalar; |
| 25 | + fn assert_close(self, truth: Self, rtol: Self::Tol) { |
| 26 | + if !(self.re.approx_eq_ratio(&truth.re, rtol) && self.im.approx_eq_ratio(&truth.im, rtol)) { |
| 27 | + panic!("Not close: val={}, truth={}, rtol={}", self, truth, rtol); |
| 28 | + } |
| 29 | + } |
| 30 | +} |
| 31 | +}} // impl_AssertClose |
| 32 | +impl_AssertClose!(f64); |
| 33 | +impl_AssertClose!(f32); |
| 34 | + |
| 35 | +/// test two arrays are close |
| 36 | +pub trait AssertAllClose { |
| 37 | + type Tol; |
| 38 | + /// test two arrays are close in L2-norm with relative tolerance |
| 39 | + fn assert_allclose_l2(&self, truth: &Self, rtol: Self::Tol); |
| 40 | + /// test two arrays are close in inf-norm with absolute tolerance |
| 41 | + fn assert_allclose_inf(&self, truth: &Self, atol: Self::Tol); |
| 42 | +} |
| 43 | + |
| 44 | +macro_rules! impl_AssertAllClose { |
| 45 | + ($scalar:ty, $float:ty, $abs:ident) => { |
| 46 | +impl AssertAllClose for [$scalar]{ |
| 47 | + type Tol = $float; |
| 48 | + fn assert_allclose_inf(&self, truth: &Self, atol: Self::Tol) { |
| 49 | + for (x, y) in self.iter().zip(truth.iter()) { |
| 50 | + let tol = (x - y).$abs(); |
| 51 | + if tol > atol { |
| 52 | + panic!("Not close in inf-norm (atol={}): \ntest = \n{:?}\nTruth = \n{:?}", |
| 53 | + atol, self, truth); |
| 54 | + } |
| 55 | + } |
| 56 | + } |
| 57 | + fn assert_allclose_l2(&self, truth: &Self, rtol: Self::Tol) { |
| 58 | + let nrm: Self::Tol = truth.iter().map(|x| x.$abs().powi(2)).sum(); |
| 59 | + let dev: Self::Tol = self.iter().zip(truth.iter()).map(|(x, y)| (x-y).$abs().powi(2)).sum(); |
| 60 | + if dev / nrm > rtol.powi(2) { |
| 61 | + panic!("Not close in L2-norm (rtol={}): \ntest = \n{:?}\nTruth = \n{:?}", |
| 62 | + rtol, self, truth); |
| 63 | + } |
| 64 | + } |
| 65 | +} |
| 66 | + |
| 67 | +impl AssertAllClose for Vec<$scalar> { |
| 68 | + type Tol = $float; |
| 69 | + fn assert_allclose_inf(&self, truth: &Self, atol: Self::Tol) { |
| 70 | + self.as_slice().assert_allclose_inf(&truth, atol); |
| 71 | + } |
| 72 | + fn assert_allclose_l2(&self, truth: &Self, rtol: Self::Tol) { |
| 73 | + self.as_slice().assert_allclose_l2(&truth, rtol); |
| 74 | + } |
| 75 | +} |
| 76 | + |
| 77 | +impl<D: Dimension> AssertAllClose for Array<$scalar, D> { |
| 78 | + type Tol = $float; |
| 79 | + fn assert_allclose_inf(&self, truth: &Self, atol: Self::Tol) { |
| 80 | + if self.shape() != truth.shape() { |
| 81 | + panic!("Shape missmatch: self={:?}, truth={:?}", self.shape(), truth.shape()); |
| 82 | + } |
| 83 | + for (idx, val) in self.indexed_iter() { |
| 84 | + let t = truth[idx.into_dimension()]; |
| 85 | + let tol = (*val - t).$abs(); |
| 86 | + if tol > atol { |
| 87 | + panic!("Not close in inf-norm (atol={}): \ntest = \n{:?}\nTruth = \n{:?}", |
| 88 | + atol, self, truth); |
| 89 | + } |
| 90 | + } |
| 91 | + } |
| 92 | + fn assert_allclose_l2(&self, truth: &Self, rtol: Self::Tol) { |
| 93 | + if self.shape() != truth.shape() { |
| 94 | + panic!("Shape missmatch: self={:?}, truth={:?}", self.shape(), truth.shape()); |
| 95 | + } |
| 96 | + let nrm: Self::Tol = truth.iter().map(|x| x.$abs().powi(2)).sum(); |
| 97 | + let dev: Self::Tol = self.indexed_iter().map(|(idx, val)| (truth[idx.into_dimension()] - val).$abs().powi(2)).sum(); |
| 98 | + if dev / nrm > rtol.powi(2) { |
| 99 | + panic!("Not close in L2-norm (rtol={}): \ntest = \n{:?}\nTruth = \n{:?}", |
| 100 | + rtol, self, truth); |
| 101 | + } |
| 102 | + } |
| 103 | +} |
| 104 | +}} // impl_AssertAllClose |
| 105 | + |
| 106 | +impl_AssertAllClose!(f64, f64, abs); |
| 107 | +impl_AssertAllClose!(f32, f32, abs); |
| 108 | +impl_AssertAllClose!(Complex<f64>, f64, norm); |
| 109 | +impl_AssertAllClose!(Complex<f32>, f32, norm); |
0 commit comments