Skip to content

Commit 69c94ff

Browse files
committed
Add assert_close_{l1,l2,max}
1 parent a231937 commit 69c94ff

File tree

11 files changed

+60
-78
lines changed

11 files changed

+60
-78
lines changed

src/assert.rs

Lines changed: 34 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -6,61 +6,24 @@ use ndarray::*;
66

77
use super::vector::*;
88

9-
pub trait Close: Absolute {
10-
fn rclose(self, truth: Self, relative_tol: Self::Output) -> Result<Self::Output, Self::Output>;
11-
fn aclose(self, truth: Self, absolute_tol: Self::Output) -> Result<Self::Output, Self::Output>;
12-
}
13-
14-
macro_rules! impl_AssertClose {
15-
($scalar:ty) => {
16-
impl Close for $scalar {
17-
fn rclose(self, truth: Self, rtol: Self::Output) -> Result<Self::Output, Self::Output> {
18-
let dev = (self - truth).abs() / truth.abs();
19-
if dev < rtol {
20-
Ok(dev)
21-
} else {
22-
Err(dev)
23-
}
24-
}
25-
26-
fn aclose(self, truth: Self, atol: Self::Output) -> Result<Self::Output, Self::Output> {
27-
let dev = (self - truth).abs();
28-
if dev < atol {
29-
Ok(dev)
30-
} else {
31-
Err(dev)
32-
}
33-
}
34-
}
35-
}} // impl_AssertClose
36-
impl_AssertClose!(f64);
37-
impl_AssertClose!(f32);
38-
39-
#[macro_export]
40-
macro_rules! assert_rclose {
41-
($test:expr, $truth:expr, $tol:expr) => {
42-
$test.rclose($truth, $tol).unwrap();
43-
};
44-
($test:expr, $truth:expr, $tol:expr; $comment:expr) => {
45-
$test.rclose($truth, $tol).expect($comment);
46-
};
9+
pub fn rclose<A, Tol>(test: A, truth: A, rtol: Tol) -> Result<Tol, Tol>
10+
where A: LinalgScalar + Absolute<Output = Tol>,
11+
Tol: Float
12+
{
13+
let dev = (test - truth).abs() / truth.abs();
14+
if dev < rtol { Ok(dev) } else { Err(dev) }
4715
}
4816

49-
#[macro_export]
50-
macro_rules! assert_aclose {
51-
($test:expr, $truth:expr, $tol:expr) => {
52-
$test.aclose($truth, $tol).unwrap();
53-
};
54-
($test:expr, $truth:expr, $tol:expr; $comment:expr) => {
55-
$test.aclose($truth, $tol).expect($comment);
56-
};
17+
pub fn aclose<A, Tol>(test: A, truth: A, atol: Tol) -> Result<Tol, Tol>
18+
where A: LinalgScalar + Absolute<Output = Tol>,
19+
Tol: Float
20+
{
21+
let dev = (test - truth).abs();
22+
if dev < atol { Ok(dev) } else { Err(dev) }
5723
}
5824

5925
/// check two arrays are close in maximum norm
60-
pub fn all_close_max<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>,
61-
truth: &ArrayBase<S2, D>,
62-
atol: Tol)
63-
-> Result<Tol, Tol>
26+
pub fn close_max<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2, D>, atol: Tol) -> Result<Tol, Tol>
6427
where A: LinalgScalar + Absolute<Output = Tol>,
6528
Tol: Float + Sum,
6629
S1: Data<Elem = A>,
@@ -72,7 +35,7 @@ pub fn all_close_max<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>,
7235
}
7336

7437
/// check two arrays are close in L1 norm
75-
pub fn all_close_l1<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2, D>, rtol: Tol) -> Result<Tol, Tol>
38+
pub fn close_l1<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2, D>, rtol: Tol) -> Result<Tol, Tol>
7639
where A: LinalgScalar + Absolute<Output = Tol>,
7740
Tol: Float + Sum,
7841
S1: Data<Elem = A>,
@@ -84,7 +47,7 @@ pub fn all_close_l1<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBas
8447
}
8548

8649
/// check two arrays are close in L2 norm
87-
pub fn all_close_l2<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2, D>, rtol: Tol) -> Result<Tol, Tol>
50+
pub fn close_l2<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2, D>, rtol: Tol) -> Result<Tol, Tol>
8851
where A: LinalgScalar + Absolute<Output = Tol>,
8952
Tol: Float + Sum,
9053
S1: Data<Elem = A>,
@@ -94,3 +57,22 @@ pub fn all_close_l2<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBas
9457
let tol = (test - truth).norm_l2() / truth.norm_l2();
9558
if tol < rtol { Ok(tol) } else { Err(tol) }
9659
}
60+
61+
macro_rules! generate_assert {
62+
($assert:ident, $close:path) => {
63+
#[macro_export]
64+
macro_rules! $assert {
65+
($test:expr, $truth:expr, $tol:expr) => {
66+
$close($test, $truth, $tol).unwrap();
67+
};
68+
($test:expr, $truth:expr, $tol:expr; $comment:expr) => {
69+
$close($test, $truth, $tol).expect($comment);
70+
};
71+
}
72+
}} // generate_assert!
73+
74+
generate_assert!(assert_rclose, rclose);
75+
generate_assert!(assert_aclose, aclose);
76+
generate_assert!(assert_close_max, close_max);
77+
generate_assert!(assert_close_l1, close_l1);
78+
generate_assert!(assert_close_l2, close_l2);

tests/cholesky.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ mod $modname {
1212
let c = a.$clone().cholesky().unwrap();
1313
println!("c = \n{:?}", c);
1414
println!("cc = \n{:?}", c.t().dot(&c));
15-
all_close_l2(&c.t().dot(&c), &a, 1e-7).unwrap();
15+
assert_close_l2!(&c.t().dot(&c), &a, 1e-7);
1616
}
1717
#[test]
1818
fn cholesky_t() {
@@ -21,7 +21,7 @@ mod $modname {
2121
let c = a.$clone().cholesky().unwrap();
2222
println!("c = \n{:?}", c);
2323
println!("cc = \n{:?}", c.t().dot(&c));
24-
all_close_l2(&c.t().dot(&c), &a, 1e-7).unwrap();
24+
assert_close_l2!(&c.t().dot(&c), &a, 1e-7);
2525
}
2626
}
2727
}} // impl_test

tests/eigh.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@ mod $modname {
99
fn eigen_vector_manual() {
1010
let a = arr2(&[[3.0, 1.0, 1.0], [1.0, 3.0, 1.0], [1.0, 1.0, 3.0]]);
1111
let (e, vecs) = a.$clone().eigh().unwrap();
12-
all_close_l2(&e, &arr1(&[2.0, 2.0, 5.0]), 1.0e-7).unwrap();
12+
assert_close_l2!(&e, &arr1(&[2.0, 2.0, 5.0]), 1.0e-7);
1313
for (i, v) in vecs.axis_iter(Axis(1)).enumerate() {
1414
let av = a.dot(&v);
1515
let ev = v.mapv(|x| e[i] * x);
16-
all_close_l2(&av, &ev, 1.0e-7).unwrap();
16+
assert_close_l2!(&av, &ev, 1.0e-7);
1717
}
1818
}
1919
#[test]

tests/inv.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,15 @@ mod $modname {
1111
let a = random_square(3);
1212
let ai = a.$clone().inv().unwrap();
1313
let id = Array::eye(3);
14-
all_close_l2(&ai.dot(&a), &id, 1e-7).unwrap();
14+
assert_close_l2!(&ai.dot(&a), &id, 1e-7);
1515
}
1616

1717
#[test]
1818
fn inv_random_t() {
1919
let a = random_square(3).reversed_axes();
2020
let ai = a.$clone().inv().unwrap();
2121
let id = Array::eye(3);
22-
all_close_l2(&ai.dot(&a), &id, 1e-7).unwrap();
22+
assert_close_l2!(&ai.dot(&a), &id, 1e-7);
2323
}
2424

2525
#[test]

tests/lu.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ fn $funcname() {
1212
println!("L = \n{:?}", &l);
1313
println!("U = \n{:?}", &u);
1414
println!("LU = \n{:?}", l.dot(&u));
15-
all_close_l2(&l.dot(&u).permutated(&p), &ans, 1e-7).unwrap();
15+
assert_close_l2!(&l.dot(&u).permutated(&p), &ans, 1e-7);
1616
}
1717
}} // impl_test
1818

tests/normalize.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@ include!("header.rs");
44
fn n_columns() {
55
let a = random_owned(3, 2, true);
66
let (n, v) = normalize(a.clone(), NormalizeAxis::Column);
7-
all_close_l2(&n.dot(&from_diag(&v)), &a, 1e-7).unwrap();
7+
assert_close_l2!(&n.dot(&from_diag(&v)), &a, 1e-7);
88
}
99

1010
#[test]
1111
fn n_rows() {
1212
let a = random_owned(3, 2, true);
1313
let (n, v) = normalize(a.clone(), NormalizeAxis::Row);
14-
all_close_l2(&from_diag(&v).dot(&n), &a, 1e-7).unwrap();
14+
assert_close_l2!(&from_diag(&v).dot(&n), &a, 1e-7);
1515
}

tests/permutate.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ fn $testname() {
1010
let p = $permutate; // replace 1-2
1111
let pa = a.permutated(&p);
1212
println!("permutated = \n{:?}", &pa);
13-
all_close_l2(&pa, &$answer, 1e-7).unwrap();
13+
assert_close_l2!(&pa, &$answer, 1e-7);
1414
}
1515
}} // impl_test
1616

tests/qr.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ fn $funcname() {
1313
let (q, r) = a.qr().unwrap();
1414
println!("q = \n{:?}", &q);
1515
println!("r = \n{:?}", &r);
16-
all_close_l2(&q.t().dot(&q), &Array::eye(min($n, $m)), 1e-7).unwrap();
17-
all_close_l2(&q.dot(&r), &ans, 1e-7).unwrap();
18-
all_close_l2(&drop_lower(r.clone()), &r, 1e-7).unwrap();
16+
assert_close_l2!(&q.t().dot(&q), &Array::eye(min($n, $m)), 1e-7);
17+
assert_close_l2!(&q.dot(&r), &ans, 1e-7);
18+
assert_close_l2!(&drop_lower(r.clone()), &r, 1e-7);
1919
}
2020
}} // impl_test
2121

tests/ssqrt.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,15 @@ mod $modname {
99
fn ssqrt() {
1010
let a = random_hermite(3);
1111
let ar = a.$clone().ssqrt().unwrap();
12-
all_close_l2(&ar.clone().t(), &ar, 1e-7).expect("not symmetric");
13-
all_close_l2(&ar.dot(&ar), &a, 1e-7).expect("not sqrt");
12+
assert_close_l2!(&ar.clone().t(), &ar, 1e-7; "not symmetric");
13+
assert_close_l2!(&ar.dot(&ar), &a, 1e-7; "not sqrt");
1414
}
1515
#[test]
1616
fn ssqrt_t() {
1717
let a = random_hermite(3).reversed_axes();
1818
let ar = a.$clone().ssqrt().unwrap();
19-
all_close_l2(&ar.clone().t(), &ar, 1e-7).expect("not symmetric");
20-
all_close_l2(&ar.dot(&ar), &a, 1e-7).expect("not sqrt");
19+
assert_close_l2!(&ar.clone().t(), &ar, 1e-7; "not symmetric");
20+
assert_close_l2!(&ar.dot(&ar), &a, 1e-7; "not sqrt");
2121
}
2222
}
2323
}} // impl_test

tests/svd.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ fn $funcname() {
1818
for i in 0..min($n, $m) {
1919
sm[(i, i)] = s[i];
2020
}
21-
all_close_l2(&u.dot(&sm).dot(&vt), &answer, 1e-7).unwrap();
21+
assert_close_l2!(&u.dot(&sm).dot(&vt), &answer, 1e-7);
2222
}
2323
}} // impl_test
2424

0 commit comments

Comments
 (0)