Skip to content

Commit 91266e4

Browse files
committed
Rewrite all tests
1 parent 4ef22a7 commit 91266e4

File tree

8 files changed

+208
-116
lines changed

8 files changed

+208
-116
lines changed

src/generate.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,16 @@ pub fn conjugate<A, Si, So>(a: &ArrayBase<Si, Ix2>) -> ArrayBase<So, Ix2>
1919
a
2020
}
2121

22+
pub fn random<A, S, Sh, D>(sh: Sh) -> ArrayBase<S, D>
23+
where A: RandNormal,
24+
S: DataOwned<Elem = A>,
25+
D: Dimension,
26+
Sh: ShapeBuilder<Dim = D>
27+
{
28+
let mut rng = thread_rng();
29+
ArrayBase::from_shape_fn(sh, |_| A::randn(&mut rng))
30+
}
31+
2232
/// Random vector
2333
pub fn random_vector<A, S>(n: usize) -> ArrayBase<S, Ix1>
2434
where A: RandNormal,

tests/inv.rs

Lines changed: 28 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,33 @@
11

2-
macro_rules! impl_test{
3-
($modname:ident, $clone:ident) => {
4-
mod $modname {
5-
use super::random_square;
6-
use ndarray::*;
7-
use ndarray_linalg::*;
8-
#[test]
9-
fn inv_random() {
10-
let a = random_square(3);
11-
let ai = a.$clone().inv().unwrap();
12-
let id = Array::eye(3);
13-
assert_close_l2!(&ai.dot(&a), &id, 1e-7);
14-
}
2+
extern crate ndarray;
3+
#[macro_use]
4+
extern crate ndarray_linalg;
5+
extern crate num_traits;
156

16-
#[test]
17-
fn inv_random_t() {
18-
let a = random_square(3).reversed_axes();
19-
let ai = a.$clone().inv().unwrap();
20-
let id = Array::eye(3);
21-
assert_close_l2!(&ai.dot(&a), &id, 1e-7);
22-
}
7+
use ndarray::*;
8+
use ndarray_linalg::*;
239

24-
#[test]
25-
#[should_panic]
26-
fn inv_error() {
27-
// do not have inverse
28-
let a = Array::<f64, _>::zeros(9).into_shape((3, 3)).unwrap();
29-
let a_inv = a.$clone().inv().unwrap();
30-
println!("{:?}", a_inv);
31-
}
10+
#[test]
11+
fn inv_random() {
12+
let a: Array2<f64> = random((3, 3));
13+
let ai: Array2<_> = (&a).inv().unwrap();
14+
let id = Array::eye(3);
15+
assert_close_l2!(&ai.dot(&a), &id, 1e-7);
3216
}
33-
}} // impl_test
3417

35-
impl_test!(owned, clone);
36-
impl_test!(shared, to_shared);
18+
#[test]
19+
fn inv_random_t() {
20+
let a: Array2<f64> = random((3, 3).f());
21+
let ai: Array2<_> = (&a).inv().unwrap();
22+
let id = Array::eye(3);
23+
assert_close_l2!(&ai.dot(&a), &id, 1e-7);
24+
}
25+
26+
#[test]
27+
#[should_panic]
28+
fn inv_error() {
29+
// do not have inverse
30+
let a = Array::<f64, _>::zeros(9).into_shape((3, 3)).unwrap();
31+
let a_inv = a.inv().unwrap();
32+
println!("{:?}", a_inv);
33+
}

tests/normalize.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,22 @@
11

2+
extern crate ndarray;
3+
#[macro_use]
4+
extern crate ndarray_linalg;
5+
extern crate num_traits;
6+
7+
use ndarray::*;
8+
use ndarray_linalg::*;
9+
210
#[test]
311
fn n_columns() {
4-
let a = random_owned(3, 2, true);
12+
let a: Array2<f64> = random((3, 2));
513
let (n, v) = normalize(a.clone(), NormalizeAxis::Column);
614
assert_close_l2!(&n.dot(&from_diag(&v)), &a, 1e-7);
715
}
816

917
#[test]
1018
fn n_rows() {
11-
let a = random_owned(3, 2, true);
19+
let a: Array2<f64> = random((3, 2));
1220
let (n, v) = normalize(a.clone(), NormalizeAxis::Row);
1321
assert_close_l2!(&from_diag(&v).dot(&n), &a, 1e-7);
1422
}

tests/opnorm.rs

Lines changed: 50 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,57 @@
11

2-
macro_rules! impl_test {
3-
($funcname:ident, $a:expr, $op1:expr, $opi:expr, $opf:expr) => {
4-
#[test]
5-
fn $funcname() {
6-
let a = $a;
2+
extern crate ndarray;
3+
#[macro_use]
4+
extern crate ndarray_linalg;
5+
extern crate num_traits;
6+
7+
use ndarray::*;
8+
use ndarray_linalg::*;
9+
use num_traits::Float;
10+
11+
fn test(a: Array2<f64>, one: f64, inf: f64, fro: f64) {
712
println!("ONE = {:?}", a.opnorm_one());
813
println!("INF = {:?}", a.opnorm_inf());
914
println!("FRO = {:?}", a.opnorm_fro());
10-
assert_rclose!(a.opnorm_fro().unwrap(), $opf, 1e-7; "Frobenius norm");
11-
assert_rclose!(a.opnorm_one().unwrap(), $op1, 1e-7; "One norm");
12-
assert_rclose!(a.opnorm_inf().unwrap(), $opi, 1e-7; "Infinity norm");
13-
}
14-
}} // impl_test
15-
16-
macro_rules! impl_test_opnorm {
17-
($modname:ident, $array:ty, $range:path) => {
18-
mod $modname {
19-
use ndarray::*;
20-
use ndarray_linalg::*;
21-
use num_traits::Float;
22-
fn gen(i: usize, j: usize, rev: bool) -> $array {
23-
let n = (i * j + 1) as f64;
24-
if rev {
25-
$range(1., n, 1.).into_shape((j, i)).unwrap().reversed_axes()
26-
} else {
27-
$range(1., n, 1.).into_shape((i, j)).unwrap()
28-
}
15+
assert_rclose!(a.opnorm_one().unwrap(), one, 1e-7; "One norm");
16+
assert_rclose!(a.opnorm_inf().unwrap(), inf, 1e-7; "Infinity norm");
17+
assert_rclose!(a.opnorm_fro().unwrap(), fro, 1e-7; "Frobenius norm");
18+
}
19+
20+
fn gen(i: usize, j: usize, rev: bool) -> Array2<f64> {
21+
let n = (i * j + 1) as f64;
22+
if rev {
23+
Array::range(1., n, 1.).into_shape((j, i)).unwrap().reversed_axes()
24+
} else {
25+
Array::range(1., n, 1.).into_shape((i, j)).unwrap()
2926
}
30-
impl_test!(opnorm_square, gen(3, 3, false), 18.0, 24.0, 285.0.sqrt());
31-
impl_test!(opnorm_square_t, gen(3, 3, true), 24.0, 18.0, 285.0.sqrt());
32-
impl_test!(opnorm_3x4, gen(3, 4, false), 24.0, 42.0, 650.0.sqrt());
33-
impl_test!(opnorm_4x3_t, gen(4, 3, true), 42.0, 24.0, 650.0.sqrt());
34-
impl_test!(opnorm_3x4_t, gen(3, 4, true), 33.0, 30.0, 650.0.sqrt());
35-
impl_test!(opnorm_4x3, gen(4, 3, false), 30.0, 33.0, 650.0.sqrt());
3627
}
37-
}} // impl_test_opnorm
3828

39-
impl_test_opnorm!(owned, Array<f64, Ix2>, Array::range);
40-
impl_test_opnorm!(shared, RcArray<f64, Ix2>, RcArray::range);
29+
#[test]
30+
fn opnorm_square() {
31+
test(gen(3, 3, false), 18.0, 24.0, 285.0.sqrt());
32+
}
33+
34+
#[test]
35+
fn opnorm_square_t() {
36+
test(gen(3, 3, true), 24.0, 18.0, 285.0.sqrt());
37+
}
38+
39+
#[test]
40+
fn opnorm_3x4() {
41+
test(gen(3, 4, false), 24.0, 42.0, 650.0.sqrt());
42+
}
43+
44+
#[test]
45+
fn opnorm_3x4_t() {
46+
test(gen(3, 4, true), 33.0, 30.0, 650.0.sqrt());
47+
}
48+
49+
#[test]
50+
fn opnorm_4x3() {
51+
test(gen(4, 3, false), 30.0, 33.0, 650.0.sqrt());
52+
}
53+
54+
#[test]
55+
fn opnorm_4x3_t() {
56+
test(gen(4, 3, true), 42.0, 24.0, 650.0.sqrt());
57+
}

tests/qr.rs

Lines changed: 47 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,57 @@
11

2-
macro_rules! impl_test {
3-
($funcname:ident, $random:path, $n:expr, $m:expr, $t:expr) => {
4-
#[test]
5-
fn $funcname() {
6-
use std::cmp::min;
7-
use ndarray::*;
8-
use ndarray_linalg::*;
9-
let a = $random($n, $m, $t);
2+
extern crate ndarray;
3+
#[macro_use]
4+
extern crate ndarray_linalg;
5+
extern crate num_traits;
6+
7+
use std::cmp::min;
8+
use ndarray::*;
9+
use ndarray_linalg::*;
10+
use num_traits::Float;
11+
12+
fn test(a: Array2<f64>, n: usize, m: usize) {
1013
let ans = a.clone();
1114
println!("a = \n{:?}", &a);
12-
let (q, r) : (Array2<_>, Array2<_>) = a.qr().unwrap();
15+
let (q, r): (Array2<_>, Array2<_>) = a.qr().unwrap();
1316
println!("q = \n{:?}", &q);
1417
println!("r = \n{:?}", &r);
15-
assert_close_l2!(&q.t().dot(&q), &Array::eye(min($n, $m)), 1e-7);
18+
assert_close_l2!(&q.t().dot(&q), &Array::eye(min(n, m)), 1e-7);
1619
assert_close_l2!(&q.dot(&r), &ans, 1e-7);
1720
assert_close_l2!(&drop_lower(r.clone()), &r, 1e-7);
1821
}
19-
}} // impl_test
20-
21-
macro_rules! impl_test_qr {
22-
($modname:ident, $random:path) => {
23-
mod $modname {
24-
impl_test!(qr_square, $random, 3, 3, false);
25-
impl_test!(qr_square_t, $random, 3, 3, true);
26-
impl_test!(qr_3x4, $random, 3, 4, false);
27-
impl_test!(qr_3x4_t, $random, 3, 4, true);
28-
impl_test!(qr_4x3, $random, 4, 3, false);
29-
impl_test!(qr_4x3_t, $random, 4, 3, true);
22+
23+
#[test]
24+
fn qr_square() {
25+
let a = random((3, 3));
26+
test(a, 3, 3);
27+
}
28+
29+
#[test]
30+
fn qr_square_t() {
31+
let a = random((3, 3).f());
32+
test(a, 3, 3);
3033
}
31-
}} // impl_test_qr
3234

33-
impl_test_qr!(owned, super::random_owned);
34-
impl_test_qr!(shared, super::random_shared);
35+
#[test]
36+
fn qr_3x4() {
37+
let a = random((3, 4));
38+
test(a, 3, 4);
39+
}
40+
41+
#[test]
42+
fn qr_3x4_t() {
43+
let a = random((3, 4).f());
44+
test(a, 3, 4);
45+
}
46+
47+
#[test]
48+
fn qr_4x3() {
49+
let a = random((4, 3));
50+
test(a, 4, 3);
51+
}
52+
53+
#[test]
54+
fn qr_4x3_t() {
55+
let a = random((4, 3).f());
56+
test(a, 4, 3);
57+
}

tests/svd.rs

Lines changed: 47 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11

2-
macro_rules! impl_test {
3-
($funcname:ident, $random:path, $n:expr, $m:expr, $t:expr) => {
4-
#[test]
5-
fn $funcname() {
6-
use std::cmp::min;
7-
use ndarray::*;
8-
use ndarray_linalg::*;
9-
let a = $random($n, $m, $t);
2+
extern crate ndarray;
3+
#[macro_use]
4+
extern crate ndarray_linalg;
5+
extern crate num_traits;
6+
7+
use std::cmp::min;
8+
use ndarray::*;
9+
use ndarray_linalg::*;
10+
use num_traits::Float;
11+
12+
fn test(a: Array2<f64>, n: usize, m: usize) {
1013
let answer = a.clone();
1114
println!("a = \n{:?}", &a);
1215
let (u, s, vt): (_, Array1<_>, _) = a.svd(true, true).unwrap();
@@ -15,25 +18,45 @@ fn $funcname() {
1518
println!("u = \n{:?}", &u);
1619
println!("s = \n{:?}", &s);
1720
println!("v = \n{:?}", &vt);
18-
let mut sm = Array::zeros(($n, $m));
19-
for i in 0..min($n, $m) {
21+
let mut sm = Array::zeros((n, m));
22+
for i in 0..min(n, m) {
2023
sm[(i, i)] = s[i];
2124
}
2225
assert_close_l2!(&u.dot(&sm).dot(&vt), &answer, 1e-7);
2326
}
24-
}} // impl_test
25-
26-
macro_rules! impl_test_svd {
27-
($modname:ident, $random:path) => {
28-
mod $modname {
29-
impl_test!(svd_square, $random, 3, 3, false);
30-
impl_test!(svd_square_t, $random, 3, 3, true);
31-
impl_test!(svd_4x3, $random, 4, 3, false);
32-
impl_test!(svd_4x3_t, $random, 4, 3, true);
33-
impl_test!(svd_3x4, $random, 3, 4, false);
34-
impl_test!(svd_3x4_t, $random, 3, 4, true);
27+
28+
#[test]
29+
fn svd_square() {
30+
let a = random((3, 3));
31+
test(a, 3, 3);
32+
}
33+
34+
#[test]
35+
fn svd_square_t() {
36+
let a = random((3, 3).f());
37+
test(a, 3, 3);
3538
}
36-
}} // impl_test_svd
3739

38-
impl_test_svd!(owned, super::random_owned);
39-
impl_test_svd!(shared, super::random_shared);
40+
#[test]
41+
fn svd_3x4() {
42+
let a = random((3, 4));
43+
test(a, 3, 4);
44+
}
45+
46+
#[test]
47+
fn svd_3x4_t() {
48+
let a = random((3, 4).f());
49+
test(a, 3, 4);
50+
}
51+
52+
#[test]
53+
fn svd_4x3() {
54+
let a = random((4, 3));
55+
test(a, 4, 3);
56+
}
57+
58+
#[test]
59+
fn svd_4x3_t() {
60+
let a = random((4, 3).f());
61+
test(a, 4, 3);
62+
}

tests/trace.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11

2+
extern crate ndarray;
3+
extern crate ndarray_linalg;
4+
5+
use ndarray::*;
6+
use ndarray_linalg::*;
7+
28
#[test]
39
fn trace() {
4-
let r_dist = RealNormal::new(0., 1.);
5-
let a = Array::<f64, _>::random((3, 3), r_dist);
10+
let a: Array2<f64> = random((3, 3));
611
assert_rclose!(a.trace().unwrap(), a[(0, 0)] + a[(1, 1)] + a[(2, 2)], 1e-7);
712
}

tests/vector.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,13 @@
11

2+
extern crate ndarray;
3+
#[macro_use]
4+
extern crate ndarray_linalg;
5+
extern crate num_traits;
6+
7+
use ndarray::*;
8+
use ndarray_linalg::*;
9+
use num_traits::Float;
10+
211
#[test]
312
fn vector_norm() {
413
let a = Array::range(1., 10., 1.);

0 commit comments

Comments
 (0)