Skip to content

Commit 87e5e28

Browse files
committed
Rewrite test for LU
1 parent 64ca50f commit 87e5e28

File tree

1 file changed

+22
-78
lines changed

1 file changed

+22
-78
lines changed

tests/lu.rs

Lines changed: 22 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1,88 +1,32 @@
11
include!("header.rs");
22

3-
fn test_lu(a: Array<f64, Ix2>) {
4-
println!("a = \n{:?}", &a);
5-
let (p, l, u) = a.clone().lu().unwrap();
3+
macro_rules! impl_test {
4+
($funcname:ident, $random:path, $n:expr, $m:expr, $t:expr) => {
5+
#[test]
6+
fn $funcname() {
7+
use ndarray_linalg::prelude::*;
8+
let a = $random($n, $m, $t);
9+
let ans = a.clone();
10+
let (p, l, u) = a.lu().unwrap();
611
println!("P = \n{:?}", &p);
712
println!("L = \n{:?}", &l);
813
println!("U = \n{:?}", &u);
914
println!("LU = \n{:?}", l.dot(&u));
10-
all_close_l2(&l.dot(&u).permutated(&p), &a, 1e-7).unwrap();
15+
all_close_l2(&l.dot(&u).permutated(&p), &ans, 1e-7).unwrap();
1116
}
17+
}} // impl_test
1218

13-
macro_rules! test_lu_upper {
14-
($testname:ident, $testname_t:ident, $n:expr, $m:expr) => {
15-
#[test]
16-
fn $testname() {
17-
let r_dist = RealNormal::new(0., 1.);
18-
let mut a = Array::<f64, _>::random(($n, $m), r_dist);
19-
for ((i, j), val) in a.indexed_iter_mut() {
20-
if i > j {
21-
*val = 0.0;
22-
}
23-
}
24-
test_lu(a);
25-
}
26-
#[test]
27-
fn $testname_t() {
28-
let r_dist = RealNormal::new(0., 1.);
29-
let mut a = Array::<f64, _>::random(($m, $n), r_dist).reversed_axes();
30-
for ((i, j), val) in a.indexed_iter_mut() {
31-
if i > j {
32-
*val = 0.0;
33-
}
34-
}
35-
test_lu(a);
19+
macro_rules! impl_test_lu {
20+
($modname:ident, $random:path) => {
21+
mod $modname {
22+
impl_test!(lu_square, $random, 3, 3, false);
23+
impl_test!(lu_square_t, $random, 3, 3, true);
24+
impl_test!(lu_3x4, $random, 3, 4, false);
25+
impl_test!(lu_3x4_t, $random, 3, 4, true);
26+
impl_test!(lu_4x3, $random, 4, 3, false);
27+
impl_test!(lu_4x3_t, $random, 4, 3, true);
3628
}
37-
}} // end test_lu_upper
38-
test_lu_upper!(lu_square_upper, lu_square_upper_t, 3, 3);
39-
test_lu_upper!(lu_3x4_upper, lu_3x4_upper_t, 3, 4);
40-
test_lu_upper!(lu_4x3_upper, lu_4x3_upper_t, 4, 3);
29+
}} // impl_test_lu
4130

42-
macro_rules! test_lu_lower {
43-
($testname:ident, $testname_t:ident, $n:expr, $m:expr) => {
44-
#[test]
45-
fn $testname() {
46-
let r_dist = RealNormal::new(0., 1.);
47-
let mut a = Array::<f64, _>::random(($n, $m), r_dist);
48-
for ((i, j), val) in a.indexed_iter_mut() {
49-
if i < j {
50-
*val = 0.0;
51-
}
52-
}
53-
test_lu(a);
54-
}
55-
#[test]
56-
fn $testname_t() {
57-
let r_dist = RealNormal::new(0., 1.);
58-
let mut a = Array::<f64, _>::random(($m, $n), r_dist).reversed_axes();
59-
for ((i, j), val) in a.indexed_iter_mut() {
60-
if i < j {
61-
*val = 0.0;
62-
}
63-
}
64-
test_lu(a);
65-
}
66-
}} // end test_lu_lower
67-
test_lu_lower!(lu_square_lower, lu_square_lower_t, 3, 3);
68-
test_lu_lower!(lu_3x4_lower, lu_3x4_lower_t, 3, 4);
69-
test_lu_lower!(lu_4x3_lower, lu_4x3_lower_t, 4, 3);
70-
71-
macro_rules! test_lu {
72-
($testname:ident, $testname_t:ident, $n:expr, $m:expr) => {
73-
#[test]
74-
fn $testname() {
75-
let r_dist = RealNormal::new(0., 1.);
76-
let a = Array::<f64, _>::random(($n, $m), r_dist);
77-
test_lu(a);
78-
}
79-
#[test]
80-
fn $testname_t() {
81-
let r_dist = RealNormal::new(0., 1.);
82-
let a = Array::<f64, _>::random(($m, $n), r_dist).reversed_axes();
83-
test_lu(a);
84-
}
85-
}} // end test_lu
86-
test_lu!(lu_square, lu_square_t, 3, 3);
87-
test_lu!(lu_3x4, lu_3x4_t, 3, 4);
88-
test_lu!(lu_4x3, lu_4x3_t, 4, 3);
31+
impl_test_lu!(owned, super::random_owned);
32+
impl_test_lu!(shared, super::random_shared);

0 commit comments

Comments
 (0)