Skip to content

Commit e0219a1

Browse files
authored
Merge pull request #31 from termoshtt/assert
Assertion for tests
2 parents d6c8612 + 69c94ff commit e0219a1

22 files changed

+146
-104
lines changed

Cargo.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,12 @@ openblas = ["blas/openblas", "lapack/openblas"]
1515
netlib = ["blas/netlib", "lapack/netlib"]
1616

1717
[dependencies]
18-
num-traits = "0.1"
18+
num-traits = "0.1"
19+
num-complex = "0.1"
1920
ndarray = { version = "0.9", default-features = false, features = ["blas"] }
2021
lapack = { version = "0.11", default-features = false }
2122
blas = { version = "0.15", default-features = false }
2223

2324
[dev-dependencies]
24-
ndarray-rand = "0.5"
25-
ndarray-numtest = "0.2"
25+
ndarray-rand = "0.5"
26+
rand-extra = "0.1"

src/assert.rs

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
//! Assertions for array
2+
3+
use std::iter::Sum;
4+
use num_traits::Float;
5+
use ndarray::*;
6+
7+
use super::vector::*;
8+
9+
pub fn rclose<A, Tol>(test: A, truth: A, rtol: Tol) -> Result<Tol, Tol>
10+
where A: LinalgScalar + Absolute<Output = Tol>,
11+
Tol: Float
12+
{
13+
let dev = (test - truth).abs() / truth.abs();
14+
if dev < rtol { Ok(dev) } else { Err(dev) }
15+
}
16+
17+
pub fn aclose<A, Tol>(test: A, truth: A, atol: Tol) -> Result<Tol, Tol>
18+
where A: LinalgScalar + Absolute<Output = Tol>,
19+
Tol: Float
20+
{
21+
let dev = (test - truth).abs();
22+
if dev < atol { Ok(dev) } else { Err(dev) }
23+
}
24+
25+
/// check two arrays are close in maximum norm
26+
pub fn close_max<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2, D>, atol: Tol) -> Result<Tol, Tol>
27+
where A: LinalgScalar + Absolute<Output = Tol>,
28+
Tol: Float + Sum,
29+
S1: Data<Elem = A>,
30+
S2: Data<Elem = A>,
31+
D: Dimension
32+
{
33+
let tol = (test - truth).norm_max();
34+
if tol < atol { Ok(tol) } else { Err(tol) }
35+
}
36+
37+
/// check two arrays are close in L1 norm
38+
pub fn close_l1<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2, D>, rtol: Tol) -> Result<Tol, Tol>
39+
where A: LinalgScalar + Absolute<Output = Tol>,
40+
Tol: Float + Sum,
41+
S1: Data<Elem = A>,
42+
S2: Data<Elem = A>,
43+
D: Dimension
44+
{
45+
let tol = (test - truth).norm_l1() / truth.norm_l1();
46+
if tol < rtol { Ok(tol) } else { Err(tol) }
47+
}
48+
49+
/// check two arrays are close in L2 norm
50+
pub fn close_l2<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2, D>, rtol: Tol) -> Result<Tol, Tol>
51+
where A: LinalgScalar + Absolute<Output = Tol>,
52+
Tol: Float + Sum,
53+
S1: Data<Elem = A>,
54+
S2: Data<Elem = A>,
55+
D: Dimension
56+
{
57+
let tol = (test - truth).norm_l2() / truth.norm_l2();
58+
if tol < rtol { Ok(tol) } else { Err(tol) }
59+
}
60+
61+
macro_rules! generate_assert {
62+
($assert:ident, $close:path) => {
63+
#[macro_export]
64+
macro_rules! $assert {
65+
($test:expr, $truth:expr, $tol:expr) => {
66+
$close($test, $truth, $tol).unwrap();
67+
};
68+
($test:expr, $truth:expr, $tol:expr; $comment:expr) => {
69+
$close($test, $truth, $tol).expect($comment);
70+
};
71+
}
72+
}} // generate_assert!
73+
74+
generate_assert!(assert_rclose, rclose);
75+
generate_assert!(assert_aclose, aclose);
76+
generate_assert!(assert_close_max, close_max);
77+
generate_assert!(assert_close_l1, close_l1);
78+
generate_assert!(assert_close_l2, close_l2);

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
extern crate blas;
3535
extern crate lapack;
3636
extern crate num_traits;
37+
extern crate num_complex;
3738
#[macro_use(s)]
3839
extern crate ndarray;
3940

@@ -47,4 +48,5 @@ pub mod hermite;
4748
pub mod triangular;
4849

4950
pub mod util;
51+
pub mod assert;
5052
pub mod prelude;

src/prelude.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ pub use square::SquareMatrix;
44
pub use hermite::HermiteMatrix;
55
pub use triangular::*;
66
pub use util::*;
7+
pub use assert::*;

src/util.rs

Lines changed: 1 addition & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ pub enum NormalizeAxis {
5353

5454
/// normalize in L2 norm
5555
pub fn normalize<A, S, T>(mut m: ArrayBase<S, Ix2>, axis: NormalizeAxis) -> (ArrayBase<S, Ix2>, Vec<T>)
56-
where A: LinalgScalar + NormedField<Output = T> + Div<T, Output = A>,
56+
where A: LinalgScalar + Absolute<Output = T> + Div<T, Output = A>,
5757
S: DataMut<Elem = A>,
5858
T: Float + Sum
5959
{
@@ -65,42 +65,3 @@ pub fn normalize<A, S, T>(mut m: ArrayBase<S, Ix2>, axis: NormalizeAxis) -> (Arr
6565
}
6666
(m, ms)
6767
}
68-
69-
/// check two arrays are close in maximum norm
70-
pub fn all_close_max<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>,
71-
truth: &ArrayBase<S2, D>,
72-
atol: Tol)
73-
-> Result<Tol, Tol>
74-
where A: LinalgScalar + NormedField<Output = Tol>,
75-
Tol: Float + Sum,
76-
S1: Data<Elem = A>,
77-
S2: Data<Elem = A>,
78-
D: Dimension
79-
{
80-
let tol = (test - truth).norm_max();
81-
if tol < atol { Ok(tol) } else { Err(tol) }
82-
}
83-
84-
/// check two arrays are close in L1 norm
85-
pub fn all_close_l1<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2, D>, rtol: Tol) -> Result<Tol, Tol>
86-
where A: LinalgScalar + NormedField<Output = Tol>,
87-
Tol: Float + Sum,
88-
S1: Data<Elem = A>,
89-
S2: Data<Elem = A>,
90-
D: Dimension
91-
{
92-
let tol = (test - truth).norm_l1() / truth.norm_l1();
93-
if tol < rtol { Ok(tol) } else { Err(tol) }
94-
}
95-
96-
/// check two arrays are close in L2 norm
97-
pub fn all_close_l2<A, Tol, S1, S2, D>(test: &ArrayBase<S1, D>, truth: &ArrayBase<S2, D>, rtol: Tol) -> Result<Tol, Tol>
98-
where A: LinalgScalar + NormedField<Output = Tol>,
99-
Tol: Float + Sum,
100-
S1: Data<Elem = A>,
101-
S2: Data<Elem = A>,
102-
D: Dimension
103-
{
104-
let tol = (test - truth).norm_l2() / truth.norm_l2();
105-
if tol < rtol { Ok(tol) } else { Err(tol) }
106-
}

src/vector.rs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,42 +21,42 @@ pub trait Norm {
2121
}
2222

2323
impl<A, S, D, T> Norm for ArrayBase<S, D>
24-
where A: LinalgScalar + NormedField<Output = T>,
24+
where A: LinalgScalar + Absolute<Output = T>,
2525
T: Float + Sum,
2626
S: Data<Elem = A>,
2727
D: Dimension
2828
{
2929
type Output = T;
3030
fn norm_l1(&self) -> Self::Output {
31-
self.iter().map(|x| x.norm()).sum()
31+
self.iter().map(|x| x.abs()).sum()
3232
}
3333
fn norm_l2(&self) -> Self::Output {
34-
self.iter().map(|x| x.squared()).sum::<T>().sqrt()
34+
self.iter().map(|x| x.sq_abs()).sum::<T>().sqrt()
3535
}
3636
fn norm_max(&self) -> Self::Output {
3737
self.iter().fold(T::zero(), |f, &val| {
38-
let v = val.norm();
38+
let v = val.abs();
3939
if f > v { f } else { v }
4040
})
4141
}
4242
}
4343

4444
/// Field with norm
45-
pub trait NormedField {
45+
pub trait Absolute {
4646
type Output: Float;
47-
fn squared(&self) -> Self::Output;
48-
fn norm(&self) -> Self::Output {
49-
self.squared().sqrt()
47+
fn sq_abs(&self) -> Self::Output;
48+
fn abs(&self) -> Self::Output {
49+
self.sq_abs().sqrt()
5050
}
5151
}
5252

53-
impl<A: Float> NormedField for A {
53+
impl<A: Float> Absolute for A {
5454
type Output = A;
55-
fn squared(&self) -> A {
55+
fn sq_abs(&self) -> A {
5656
*self * *self
5757
}
58-
fn norm(&self) -> A {
59-
self.abs()
58+
fn abs(&self) -> A {
59+
Float::abs(*self)
6060
}
6161
}
6262

tests/cholesky.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ mod $modname {
1212
let c = a.$clone().cholesky().unwrap();
1313
println!("c = \n{:?}", c);
1414
println!("cc = \n{:?}", c.t().dot(&c));
15-
all_close_l2(&c.t().dot(&c), &a, 1e-7).unwrap();
15+
assert_close_l2!(&c.t().dot(&c), &a, 1e-7);
1616
}
1717
#[test]
1818
fn cholesky_t() {
@@ -21,7 +21,7 @@ mod $modname {
2121
let c = a.$clone().cholesky().unwrap();
2222
println!("c = \n{:?}", c);
2323
println!("cc = \n{:?}", c.t().dot(&c));
24-
all_close_l2(&c.t().dot(&c), &a, 1e-7).unwrap();
24+
assert_close_l2!(&c.t().dot(&c), &a, 1e-7);
2525
}
2626
}
2727
}} // impl_test

tests/det.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,13 @@ macro_rules! impl_test{
55
mod $modname {
66
use super::random_hermite;
77
use ndarray_linalg::prelude::*;
8-
use ndarray_numtest::prelude::*;
98
#[test]
109
fn deth() {
1110
let a = random_hermite(3);
1211
let (e, _) = a.$clone().eigh().unwrap();
1312
let deth = a.$clone().deth().unwrap();
1413
let det_eig = e.iter().fold(1.0, |x, y| x * y);
15-
deth.assert_close(det_eig, 1.0e-7);
14+
assert_rclose!(deth, det_eig, 1.0e-7);
1615
}
1716
}
1817
}} // impl_test

tests/eigh.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,15 @@ macro_rules! impl_test {
55
mod $modname {
66
use ndarray::prelude::*;
77
use ndarray_linalg::prelude::*;
8-
use ndarray_numtest::prelude::*;
98
#[test]
109
fn eigen_vector_manual() {
1110
let a = arr2(&[[3.0, 1.0, 1.0], [1.0, 3.0, 1.0], [1.0, 1.0, 3.0]]);
1211
let (e, vecs) = a.$clone().eigh().unwrap();
13-
all_close_l2(&e, &arr1(&[2.0, 2.0, 5.0]), 1.0e-7).unwrap();
12+
assert_close_l2!(&e, &arr1(&[2.0, 2.0, 5.0]), 1.0e-7);
1413
for (i, v) in vecs.axis_iter(Axis(1)).enumerate() {
1514
let av = a.dot(&v);
1615
let ev = v.mapv(|x| e[i] * x);
17-
all_close_l2(&av, &ev, 1.0e-7).unwrap();
16+
assert_close_l2!(&av, &ev, 1.0e-7);
1817
}
1918
}
2019
#[test]
@@ -23,7 +22,7 @@ mod $modname {
2322
let (e, vecs) = a.$clone().eigh().unwrap();
2423
let s = vecs.t().dot(&a).dot(&vecs);
2524
for i in 0..3 {
26-
e[i].assert_close(s[(i, i)], 1e-7);
25+
assert_rclose!(e[i], s[(i, i)], 1e-7);
2726
}
2827
}
2928
}

tests/header.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11

2+
extern crate rand_extra;
23
extern crate ndarray;
34
extern crate ndarray_rand;
5+
#[macro_use]
6+
#[allow(unused_imports)]
47
extern crate ndarray_linalg;
5-
extern crate ndarray_numtest;
68
extern crate num_traits;
79

810
#[allow(unused_imports)]
911
use ndarray::*;
1012
#[allow(unused_imports)]
1113
use ndarray_linalg::prelude::*;
1214
#[allow(unused_imports)]
13-
use ndarray_numtest::prelude::*;
15+
use rand_extra::*;
1416
#[allow(unused_imports)]
1517
use ndarray_rand::RandomExt;
1618
#[allow(unused_imports)]

0 commit comments

Comments
 (0)