Skip to content

Commit 3c508c4

Browse files
committed
Compress duplicated codes for testing LU
1 parent cd66a3d commit 3c508c4

File tree

1 file changed

+46
-160
lines changed

1 file changed

+46
-160
lines changed

tests/lu.rs

Lines changed: 46 additions & 160 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@ fn $testname() {
2828
println!("permutated = \n{:?}", &pa);
2929
all_close(pa, arr2($answer))
3030
}
31-
}
32-
}
31+
}} // end test_permutate
3332

3433
macro_rules! test_permutate_t {
3534
($testname:ident, $permutate:expr, $input:expr, $answer:expr) => {
@@ -42,8 +41,7 @@ fn $testname() {
4241
println!("permutated = \n{:?}", &pa);
4342
all_close(pa, arr2($answer))
4443
}
45-
}
46-
}
44+
}} // end test_permutate_t
4745

4846
test_permutate!(permutate,
4947
vec![2, 2, 3],
@@ -70,201 +68,89 @@ test_permutate_t!(permutate_4x3_t,
7068
&[[1., 4., 7., 10.], [2., 5., 8., 11.], [3., 6., 9., 12.]],
7169
&[[10., 11., 12.], [4., 5., 6.], [7., 8., 9.], [1., 2., 3.]]);
7270

73-
#[test]
74-
fn lu_square_upper() {
75-
let r_dist = Range::new(0., 1.);
76-
let mut a = Array::<f64, _>::random((3, 3), r_dist);
77-
for ((i, j), val) in a.indexed_iter_mut() {
78-
if i > j {
79-
*val = 0.0;
80-
}
81-
}
71+
fn test_lu(a: Array<f64, (Ix, Ix)>) {
8272
println!("a = \n{:?}", &a);
8373
let (p, l, u) = a.clone().lu().unwrap();
8474
println!("P = \n{:?}", &p);
8575
println!("L = \n{:?}", &l);
8676
println!("U = \n{:?}", &u);
77+
println!("LU = \n{:?}", l.dot(&u));
8778
all_close(l.dot(&u).permutated(&p), a);
8879
}
8980

81+
macro_rules! test_lu_upper {
82+
($testname:ident, $testname_t:ident, $n:expr, $m:expr) => {
9083
#[test]
91-
fn lu_square_upper_t() {
84+
fn $testname() {
9285
let r_dist = Range::new(0., 1.);
93-
let mut a = Array::<f64, _>::random((3, 3), r_dist).reversed_axes();
86+
let mut a = Array::<f64, _>::random(($n, $m), r_dist);
9487
for ((i, j), val) in a.indexed_iter_mut() {
9588
if i > j {
9689
*val = 0.0;
9790
}
9891
}
99-
println!("a = \n{:?}", &a);
100-
let (p, l, u) = a.clone().lu().unwrap();
101-
println!("P = \n{:?}", &p);
102-
println!("L = \n{:?}", &l);
103-
println!("U = \n{:?}", &u);
104-
all_close(l.dot(&u).permutated(&p), a);
92+
test_lu(a);
10593
}
106-
10794
#[test]
108-
fn lu_square_lower() {
95+
fn $testname_t() {
10996
let r_dist = Range::new(0., 1.);
110-
let mut a = Array::<f64, _>::random((3, 3), r_dist);
97+
let mut a = Array::<f64, _>::random(($m, $n), r_dist).reversed_axes();
11198
for ((i, j), val) in a.indexed_iter_mut() {
112-
if i < j {
99+
if i > j {
113100
*val = 0.0;
114101
}
115102
}
116-
println!("a = \n{:?}", &a);
117-
let (p, l, u) = a.clone().lu().unwrap();
118-
println!("P = \n{:?}", &p);
119-
println!("L = \n{:?}", &l);
120-
println!("U = \n{:?}", &u);
121-
println!("LU = \n{:?}", l.dot(&u));
122-
all_close(l.dot(&u).permutated(&p), a);
103+
test_lu(a);
123104
}
105+
}} // end test_lu_upper
106+
test_lu_upper!(lu_square_upper, lu_square_upper_t, 3, 3);
107+
test_lu_upper!(lu_3x4_upper, lu_3x4_upper_t, 3, 4);
108+
test_lu_upper!(lu_4x3_upper, lu_4x3_upper_t, 4, 3);
124109

110+
macro_rules! test_lu_lower {
111+
($testname:ident, $testname_t:ident, $n:expr, $m:expr) => {
125112
#[test]
126-
fn lu_square_lower_t() {
113+
fn $testname() {
127114
let r_dist = Range::new(0., 1.);
128-
let mut a = Array::<f64, _>::random((3, 3), r_dist).reversed_axes();
115+
let mut a = Array::<f64, _>::random(($n, $m), r_dist);
129116
for ((i, j), val) in a.indexed_iter_mut() {
130117
if i < j {
131118
*val = 0.0;
132119
}
133120
}
134-
println!("a = \n{:?}", &a);
135-
let (p, l, u) = a.clone().lu().unwrap();
136-
println!("P = \n{:?}", &p);
137-
println!("L = \n{:?}", &l);
138-
println!("U = \n{:?}", &u);
139-
println!("LU = \n{:?}", l.dot(&u));
140-
all_close(l.dot(&u).permutated(&p), a);
121+
test_lu(a);
141122
}
142-
143123
#[test]
144-
fn lu_square() {
124+
fn $testname_t() {
145125
let r_dist = Range::new(0., 1.);
146-
let a = Array::<f64, _>::random((3, 3), r_dist);
147-
println!("a = \n{:?}", &a);
148-
let (p, l, u) = a.clone().lu().unwrap();
149-
println!("P = \n{:?}", &p);
150-
println!("L = \n{:?}", &l);
151-
println!("U = \n{:?}", &u);
152-
println!("LU = \n{:?}", l.dot(&u));
153-
all_close(l.dot(&u).permutated(&p), a);
126+
let mut a = Array::<f64, _>::random(($m, $n), r_dist).reversed_axes();
127+
for ((i, j), val) in a.indexed_iter_mut() {
128+
if i < j {
129+
*val = 0.0;
130+
}
131+
}
132+
test_lu(a);
154133
}
134+
}} // end test_lu_lower
135+
test_lu_lower!(lu_square_lower, lu_square_lower_t, 3, 3);
136+
test_lu_lower!(lu_3x4_lower, lu_3x4_lower_t, 3, 4);
137+
test_lu_lower!(lu_4x3_lower, lu_4x3_lower_t, 4, 3);
155138

139+
macro_rules! test_lu {
140+
($testname:ident, $testname_t:ident, $n:expr, $m:expr) => {
156141
#[test]
157-
fn lu_square_t() {
142+
fn $testname() {
158143
let r_dist = Range::new(0., 1.);
159-
let a = Array::<f64, _>::random((3, 3), r_dist).reversed_axes();
160-
println!("a = \n{:?}", &a);
161-
let (p, l, u) = a.clone().lu().unwrap();
162-
println!("P = \n{:?}", &p);
163-
println!("L = \n{:?}", &l);
164-
println!("U = \n{:?}", &u);
165-
all_close(l.dot(&u).permutated(&p), a);
144+
let a = Array::<f64, _>::random(($n, $m), r_dist);
145+
test_lu(a);
166146
}
167-
168-
// #[test]
169-
// fn lu_3x4() {
170-
// let r_dist = Range::new(0., 1.);
171-
// let a = Array::<f64, _>::random((3, 4), r_dist);
172-
// println!("a = \n{:?}", &a);
173-
// let (p, l, u) = a.clone().lu().unwrap();
174-
// println!("P = \n{:?}", &p);
175-
// println!("L = \n{:?}", &l);
176-
// println!("U = \n{:?}", &u);
177-
// println!("LU = \n{:?}", l.dot(&u));
178-
// all_close(l.dot(&u).permutated(&p), a);
179-
// }
180-
//
181-
// #[test]
182-
// fn lu_3x4_t() {
183-
// let r_dist = Range::new(0., 1.);
184-
// let a = Array::<f64, _>::random((4, 3), r_dist).reversed_axes();
185-
// println!("a = \n{:?}", &a);
186-
// let (p, l, u) = a.clone().lu().unwrap();
187-
// println!("P = \n{:?}", &p);
188-
// println!("L = \n{:?}", &l);
189-
// println!("U = \n{:?}", &u);
190-
// all_close(l.dot(&u).permutated(&p), a);
191-
// }
192-
193-
// #[test]
194-
// fn lu_4x3_upper() {
195-
// let r_dist = Range::new(0., 1.);
196-
// let mut a = Array::<f64, _>::random((4, 3), r_dist);
197-
// for ((i, j), val) in a.indexed_iter_mut() {
198-
// if i > j {
199-
// *val = 0.0;
200-
// }
201-
// }
202-
// println!("a = \n{:?}", &a);
203-
// let (p, l, u) = a.clone().lu().unwrap();
204-
// println!("P = \n{:?}", &p);
205-
// println!("L = \n{:?}", &l);
206-
// println!("U = \n{:?}", &u);
207-
// println!("LU = \n{:?}", l.dot(&u));
208-
// all_close(l.dot(&u).permutated(&p), a);
209-
// }
210-
//
211-
// #[test]
212-
// fn lu_4x3_lower() {
213-
// let r_dist = Range::new(0., 1.);
214-
// let mut a = Array::<f64, _>::random((4, 3), r_dist);
215-
// for ((i, j), val) in a.indexed_iter_mut() {
216-
// if i < j {
217-
// *val = 0.0;
218-
// }
219-
// }
220-
// println!("a = \n{:?}", &a);
221-
// let (p, l, u) = a.clone().lu().unwrap();
222-
// println!("P = \n{:?}", &p);
223-
// println!("L = \n{:?}", &l);
224-
// println!("U = \n{:?}", &u);
225-
// println!("LU = \n{:?}", l.dot(&u));
226-
// all_close(l.dot(&u).permutated(&p), a);
227-
// }
228-
229147
#[test]
230-
fn lu_4x3_upper_t() {
148+
fn $testname_t() {
231149
let r_dist = Range::new(0., 1.);
232-
let mut a = Array::<f64, _>::random((3, 4), r_dist).reversed_axes();
233-
for ((i, j), val) in a.indexed_iter_mut() {
234-
if i > j {
235-
*val = 0.0;
236-
}
237-
}
238-
println!("a = \n{:?}", &a);
239-
let (p, l, u) = a.clone().lu().unwrap();
240-
println!("P = \n{:?}", &p);
241-
println!("L = \n{:?}", &l);
242-
println!("U = \n{:?}", &u);
243-
println!("LU = \n{:?}", l.dot(&u));
244-
all_close(l.dot(&u).permutated(&p), a);
150+
let a = Array::<f64, _>::random(($m, $n), r_dist).reversed_axes();
151+
test_lu(a);
245152
}
246-
247-
// #[test]
248-
// fn lu_4x3() {
249-
// let r_dist = Range::new(0., 1.);
250-
// let a = Array::<f64, _>::random((4, 3), r_dist);
251-
// println!("a = \n{:?}", &a);
252-
// let (p, l, u) = a.clone().lu().unwrap();
253-
// println!("P = \n{:?}", &p);
254-
// println!("L = \n{:?}", &l);
255-
// println!("U = \n{:?}", &u);
256-
// println!("LU = \n{:?}", l.dot(&u));
257-
// all_close(l.dot(&u).permutated(&p), a);
258-
// }
259-
//
260-
// #[test]
261-
// fn lu_4x3_t() {
262-
// let r_dist = Range::new(0., 1.);
263-
// let a = Array::<f64, _>::random((3, 4), r_dist).reversed_axes();
264-
// println!("a = \n{:?}", &a);
265-
// let (p, l, u) = a.clone().lu().unwrap();
266-
// println!("P = \n{:?}", &p);
267-
// println!("L = \n{:?}", &l);
268-
// println!("U = \n{:?}", &u);
269-
// all_close(l.dot(&u).permutated(&p), a);
270-
// }
153+
}} // end test_lu
154+
test_lu!(lu_square, lu_square_t, 3, 3);
155+
test_lu!(lu_3x4, lu_3x4_t, 3, 4);
156+
test_lu!(lu_4x3, lu_4x3_t, 4, 3);

0 commit comments

Comments
 (0)