Skip to content

Commit b979050

Browse files
committed
Test for QR and refactoring SVD
1 parent d91a9d2 commit b979050

File tree

2 files changed

+37
-186
lines changed

2 files changed

+37
-186
lines changed

tests/qr.rs

Lines changed: 25 additions & 169 deletions
Original file line numberDiff line numberDiff line change
@@ -1,179 +1,35 @@
11
include!("header.rs");
22

3+
macro_rules! impl_test {
4+
($funcname:ident, $random:path, $n:expr, $m:expr, $t:expr) => {
35
#[test]
4-
fn qr_square_upper() {
5-
let r_dist = RealNormal::new(0., 1.);
6-
let mut a = Array::<f64, _>::random((3, 3), r_dist);
7-
for ((i, j), val) in a.indexed_iter_mut() {
8-
if i > j {
9-
*val = 0.0;
10-
}
11-
}
6+
fn $funcname() {
7+
use std::cmp::min;
8+
use ndarray::prelude::*;
9+
use ndarray_linalg::prelude::*;
10+
let a = $random($n, $m, $t);
11+
let ans = a.clone();
1212
println!("a = \n{:?}", &a);
13-
let (q, r) = a.clone().qr().unwrap();
13+
let (q, r) = a.qr().unwrap();
1414
println!("q = \n{:?}", &q);
1515
println!("r = \n{:?}", &r);
16-
all_close_l2(&q.clone(), &Array::eye(3), 1e-7).unwrap();
17-
all_close_l2(&q.dot(&q.t()), &Array::eye(3), 1e-7).unwrap();
18-
all_close_l2(&r, &a, 1e-7).unwrap();
16+
all_close_l2(&q.t().dot(&q), &Array::eye(min($n, $m)), 1e-7).unwrap();
17+
all_close_l2(&q.dot(&r), &ans, 1e-7).unwrap();
18+
all_close_l2(&drop_lower(r.clone()), &r, 1e-7).unwrap();
1919
}
20+
}} // impl_test
2021

21-
#[test]
22-
fn qr_square_upper_t() {
23-
let r_dist = RealNormal::new(0., 1.);
24-
let mut a = Array::<f64, _>::random((3, 3), r_dist).reversed_axes();
25-
for ((i, j), val) in a.indexed_iter_mut() {
26-
if i > j {
27-
*val = 0.0;
28-
}
29-
}
30-
println!("a = \n{:?}", &a);
31-
let (q, r) = a.clone().qr().unwrap();
32-
println!("q = \n{:?}", &q);
33-
println!("r = \n{:?}", &r);
34-
all_close_l2(&q.clone(), &Array::eye(3), 1e-7).unwrap();
35-
all_close_l2(&q.dot(&q.t()), &Array::eye(3), 1e-7).unwrap();
36-
all_close_l2(&r, &a, 1e-7).unwrap();
37-
}
38-
39-
#[test]
40-
fn qr_square() {
41-
let r_dist = RealNormal::new(0., 1.);
42-
let a = Array::<f64, _>::random((3, 3), r_dist);
43-
println!("a = \n{:?}", &a);
44-
let (q, r) = a.clone().qr().unwrap();
45-
println!("q = \n{:?}", &q);
46-
println!("r = \n{:?}", &r);
47-
all_close_l2(&q.dot(&q.t()), &Array::eye(3), 1e-7).unwrap();
48-
all_close_l2(&q.dot(&r), &a, 1e-7).unwrap();
49-
}
50-
51-
#[test]
52-
fn qr_square_t() {
53-
let r_dist = RealNormal::new(0., 1.);
54-
let a = Array::<f64, _>::random((3, 3), r_dist).reversed_axes();
55-
println!("a = \n{:?}", &a);
56-
let (q, r) = a.clone().qr().unwrap();
57-
println!("q = \n{:?}", &q);
58-
println!("r = \n{:?}", &r);
59-
all_close_l2(&q.dot(&q.t()), &Array::eye(3), 1e-7).unwrap();
60-
all_close_l2(&q.dot(&r), &a, 1e-7).unwrap();
22+
macro_rules! impl_test_qr {
23+
($modname:ident, $random:path) => {
24+
mod $modname {
25+
impl_test!(qr_square, $random, 3, 3, false);
26+
impl_test!(qr_square_t, $random, 3, 3, true);
27+
impl_test!(qr_3x4, $random, 3, 4, false);
28+
impl_test!(qr_3x4_t, $random, 3, 4, true);
29+
impl_test!(qr_4x3, $random, 4, 3, false);
30+
impl_test!(qr_4x3_t, $random, 4, 3, true);
6131
}
32+
}} // impl_test_qr
6233

63-
#[test]
64-
fn qr_3x4_upper() {
65-
let r_dist = RealNormal::new(0., 1.);
66-
let mut a = Array::<f64, _>::random((3, 4), r_dist);
67-
for ((i, j), val) in a.indexed_iter_mut() {
68-
if i > j {
69-
*val = 0.0;
70-
}
71-
}
72-
println!("a = \n{:?}", &a);
73-
let (q, r) = a.clone().qr().unwrap();
74-
println!("q = \n{:?}", &q);
75-
println!("r = \n{:?}", &r);
76-
all_close_l2(&q.clone(), &Array::eye(3), 1e-7).unwrap();
77-
all_close_l2(&q.dot(&q.t()), &Array::eye(3), 1e-7).unwrap();
78-
all_close_l2(&q.dot(&r), &a, 1e-7).unwrap();
79-
}
80-
81-
#[test]
82-
fn qr_3x4_upper_t() {
83-
let r_dist = RealNormal::new(0., 1.);
84-
let mut a = Array::<f64, _>::random((4, 3), r_dist).reversed_axes();
85-
for ((i, j), val) in a.indexed_iter_mut() {
86-
if i > j {
87-
*val = 0.0;
88-
}
89-
}
90-
println!("a = \n{:?}", &a);
91-
let (q, r) = a.clone().qr().unwrap();
92-
println!("q = \n{:?}", &q);
93-
println!("r = \n{:?}", &r);
94-
all_close_l2(&q.clone(), &Array::eye(3), 1e-7).unwrap();
95-
all_close_l2(&q.dot(&q.t()), &Array::eye(3), 1e-7).unwrap();
96-
all_close_l2(&q.dot(&r), &a, 1e-7).unwrap();
97-
}
98-
99-
#[test]
100-
fn qr_3x4() {
101-
let r_dist = RealNormal::new(0., 1.);
102-
let a = Array::<f64, _>::random((3, 4), r_dist);
103-
println!("a = \n{:?}", &a);
104-
let (q, r) = a.clone().qr().unwrap();
105-
println!("q = \n{:?}", &q);
106-
println!("r = \n{:?}", &r);
107-
all_close_l2(&q.dot(&q.t()), &Array::eye(3), 1e-7).unwrap();
108-
all_close_l2(&q.dot(&r), &a, 1e-7).unwrap();
109-
}
110-
111-
#[test]
112-
fn qr_3x4_t() {
113-
let r_dist = RealNormal::new(0., 1.);
114-
let a = Array::<f64, _>::random((4, 3), r_dist).reversed_axes();
115-
println!("a = \n{:?}", &a);
116-
let (q, r) = a.clone().qr().unwrap();
117-
println!("q = \n{:?}", &q);
118-
println!("r = \n{:?}", &r);
119-
all_close_l2(&q.dot(&q.t()), &Array::eye(3), 1e-7).unwrap();
120-
all_close_l2(&q.dot(&r), &a, 1e-7).unwrap();
121-
}
122-
123-
#[test]
124-
fn qr_4x3_upper() {
125-
let r_dist = RealNormal::new(0., 1.);
126-
let mut a = Array::<f64, _>::random((4, 3), r_dist);
127-
for ((i, j), val) in a.indexed_iter_mut() {
128-
if i > j {
129-
*val = 0.0;
130-
}
131-
}
132-
println!("a = \n{:?}", &a);
133-
let (q, r) = a.clone().qr().unwrap();
134-
println!("q = \n{:?}", &q);
135-
println!("r = \n{:?}", &r);
136-
all_close_l2(&q.t().dot(&q), &Array::eye(3), 1e-7).unwrap();
137-
all_close_l2(&q.dot(&r), &a, 1e-7).unwrap();
138-
}
139-
140-
#[test]
141-
fn qr_4x3_upper_t() {
142-
let r_dist = RealNormal::new(0., 1.);
143-
let mut a = Array::<f64, _>::random((3, 4), r_dist).reversed_axes();
144-
for ((i, j), val) in a.indexed_iter_mut() {
145-
if i > j {
146-
*val = 0.0;
147-
}
148-
}
149-
println!("a = \n{:?}", &a);
150-
let (q, r) = a.clone().qr().unwrap();
151-
println!("q = \n{:?}", &q);
152-
println!("r = \n{:?}", &r);
153-
all_close_l2(&q.t().dot(&q), &Array::eye(3), 1e-7).unwrap();
154-
all_close_l2(&q.dot(&r), &a, 1e-7).unwrap();
155-
}
156-
157-
#[test]
158-
fn qr_4x3() {
159-
let r_dist = RealNormal::new(0., 1.);
160-
let a = Array::<f64, _>::random((4, 3), r_dist);
161-
println!("a = \n{:?}", &a);
162-
let (q, r) = a.clone().qr().unwrap();
163-
println!("q = \n{:?}", &q);
164-
println!("r = \n{:?}", &r);
165-
all_close_l2(&q.t().dot(&q), &Array::eye(3), 1e-7).unwrap();
166-
all_close_l2(&q.dot(&r), &a, 1e-7).unwrap();
167-
}
168-
169-
#[test]
170-
fn qr_4x3_t() {
171-
let r_dist = RealNormal::new(0., 1.);
172-
let a = Array::<f64, _>::random((3, 4), r_dist).reversed_axes();
173-
println!("a = \n{:?}", &a);
174-
let (q, r) = a.clone().qr().unwrap();
175-
println!("q = \n{:?}", &q);
176-
println!("r = \n{:?}", &r);
177-
all_close_l2(&q.t().dot(&q), &Array::eye(3), 1e-7).unwrap();
178-
all_close_l2(&q.dot(&r), &a, 1e-7).unwrap();
179-
}
34+
impl_test_qr!(owned, super::random_owned);
35+
impl_test_qr!(shared, super::random_shared);

tests/svd.rs

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,22 +22,17 @@ fn $funcname() {
2222
}
2323
}} // impl_test
2424

25-
mod owned {
26-
use super::random_owned;
27-
impl_test!(svd_square, random_owned, 3, 3, false);
28-
impl_test!(svd_square_t, random_owned, 3, 3, true);
29-
impl_test!(svd_4x3, random_owned, 4, 3, false);
30-
impl_test!(svd_4x3_t, random_owned, 4, 3, true);
31-
impl_test!(svd_3x4, random_owned, 3, 4, false);
32-
impl_test!(svd_3x4_t, random_owned, 3, 4, true);
25+
macro_rules! impl_test_svd {
26+
($modname:ident, $random:path) => {
27+
mod $modname {
28+
impl_test!(svd_square, $random, 3, 3, false);
29+
impl_test!(svd_square_t, $random, 3, 3, true);
30+
impl_test!(svd_4x3, $random, 4, 3, false);
31+
impl_test!(svd_4x3_t, $random, 4, 3, true);
32+
impl_test!(svd_3x4, $random, 3, 4, false);
33+
impl_test!(svd_3x4_t, $random, 3, 4, true);
3334
}
35+
}} // impl_test_svd
3436

35-
mod shared {
36-
use super::random_shared;
37-
impl_test!(svd_square, random_shared, 3, 3, false);
38-
impl_test!(svd_square_t, random_shared, 3, 3, true);
39-
impl_test!(svd_4x3, random_shared, 4, 3, false);
40-
impl_test!(svd_4x3_t, random_shared, 4, 3, true);
41-
impl_test!(svd_3x4, random_shared, 3, 4, false);
42-
impl_test!(svd_3x4_t, random_shared, 3, 4, true);
43-
}
37+
impl_test_svd!(owned, super::random_owned);
38+
impl_test_svd!(shared, super::random_shared);

0 commit comments

Comments
 (0)