|
1 | 1 | include!("header.rs");
|
2 | 2 |
|
3 |
| -use std::cmp::min; |
4 |
| - |
5 |
| -#[test] |
6 |
| -fn svd_square() { |
7 |
| - let r_dist = RealNormal::new(0., 1.); |
8 |
| - let a = Array::<f64, _>::random((3, 3), r_dist); |
9 |
| - let (u, s, vt) = a.clone().svd().unwrap(); |
10 |
| - let mut sm = Array::eye(3); |
11 |
| - for i in 0..3 { |
12 |
| - sm[(i, i)] = s[i]; |
13 |
| - } |
14 |
| - all_close_l2(&u.dot(&sm).dot(&vt), &a, 1e-7).unwrap(); |
15 |
| -} |
16 |
| -#[test] |
17 |
| -fn svd_square_t() { |
18 |
| - let r_dist = RealNormal::new(0., 1.); |
19 |
| - let a = Array::<f64, _>::random((3, 3), r_dist).reversed_axes(); |
20 |
| - let (u, s, vt) = a.clone().svd().unwrap(); |
21 |
| - let mut sm = Array::eye(3); |
22 |
| - for i in 0..3 { |
23 |
| - sm[(i, i)] = s[i]; |
24 |
| - } |
25 |
| - all_close_l2(&u.dot(&sm).dot(&vt), &a, 1e-7).unwrap(); |
26 |
| -} |
27 |
| - |
28 |
| -#[test] |
29 |
| -fn svd_4x3() { |
30 |
| - let r_dist = RealNormal::new(0., 1.); |
31 |
| - let a = Array::<f64, _>::random((4, 3), r_dist); |
32 |
| - let (u, s, vt) = a.clone().svd().unwrap(); |
33 |
| - let mut sm = Array::zeros((4, 3)); |
34 |
| - for i in 0..3 { |
35 |
| - sm[(i, i)] = s[i]; |
36 |
| - } |
37 |
| - all_close_l2(&u.dot(&sm).dot(&vt), &a, 1e-7).unwrap(); |
38 |
| -} |
| 3 | +macro_rules! impl_test { |
| 4 | + ($funcname:ident, $random:path, $n:expr, $m:expr, $t:expr) => { |
39 | 5 | #[test]
|
40 |
| -fn svd_4x3_t() { |
41 |
| - let r_dist = RealNormal::new(0., 1.); |
42 |
| - let a = Array::<f64, _>::random((3, 4), r_dist).reversed_axes(); |
43 |
| - let (u, s, vt) = a.clone().svd().unwrap(); |
44 |
| - let mut sm = Array::zeros((4, 3)); |
45 |
| - for i in 0..3 { |
| 6 | +fn $funcname() { |
| 7 | + use std::cmp::min; |
| 8 | + use ndarray::prelude::*; |
| 9 | + use ndarray_linalg::prelude::*; |
| 10 | + let a = $random($n, $m, $t); |
| 11 | + let answer = a.clone(); |
| 12 | + println!("a = \n{}", &a); |
| 13 | + let (u, s, vt) = a.svd().unwrap(); |
| 14 | + println!("u = \n{}", &u); |
| 15 | + println!("s = \n{}", &s); |
| 16 | + println!("v = \n{}", &vt); |
| 17 | + let mut sm = Array::zeros(($n, $m)); |
| 18 | + for i in 0..min($n, $m) { |
46 | 19 | sm[(i, i)] = s[i];
|
47 | 20 | }
|
48 |
| - all_close_l2(&u.dot(&sm).dot(&vt), &a, 1e-7).unwrap(); |
| 21 | + all_close_l2(&u.dot(&sm).dot(&vt), &answer, 1e-7).unwrap(); |
49 | 22 | }
|
| 23 | +}} // impl_test |
50 | 24 |
|
51 |
| -#[test] |
52 |
| -fn svd_3x4() { |
53 |
| - let r_dist = RealNormal::new(0., 1.); |
54 |
| - let a = Array::<f64, _>::random((3, 4), r_dist); |
55 |
| - let (u, s, vt) = a.clone().svd().unwrap(); |
56 |
| - let mut sm = Array::zeros((3, 4)); |
57 |
| - for i in 0..3 { |
58 |
| - sm[(i, i)] = s[i]; |
59 |
| - } |
60 |
| - all_close_l2(&u.dot(&sm).dot(&vt), &a, 1e-7).unwrap(); |
61 |
| -} |
62 |
| -#[test] |
63 |
| -fn svd_3x4_t() { |
64 |
| - let r_dist = RealNormal::new(0., 1.); |
65 |
| - let a = Array::<f64, _>::random((4, 3), r_dist).reversed_axes(); |
66 |
| - let (u, s, vt) = a.clone().svd().unwrap(); |
67 |
| - let mut sm = Array::zeros((3, 4)); |
68 |
| - for i in 0..3 { |
69 |
| - sm[(i, i)] = s[i]; |
70 |
| - } |
71 |
| - all_close_l2(&u.dot(&sm).dot(&vt), &a, 1e-7).unwrap(); |
| 25 | +mod owned { |
| 26 | + use super::random_owned; |
| 27 | + impl_test!(svd_square, random_owned, 3, 3, false); |
| 28 | + impl_test!(svd_square_t, random_owned, 3, 3, true); |
| 29 | + impl_test!(svd_4x3, random_owned, 4, 3, false); |
| 30 | + impl_test!(svd_4x3_t, random_owned, 4, 3, true); |
| 31 | + impl_test!(svd_3x4, random_owned, 3, 4, false); |
| 32 | + impl_test!(svd_3x4_t, random_owned, 3, 4, true); |
72 | 33 | }
|
73 | 34 |
|
74 |
| -#[test] |
75 |
| -#[ignore] |
76 |
| -fn svd_large() { |
77 |
| - let n = 2480; |
78 |
| - let m = 4280; |
79 |
| - let r_dist = RealNormal::new(0., 1.); |
80 |
| - let a = Array::<f64, _>::random((n, m), r_dist); |
81 |
| - let (u, s, vt) = a.clone().svd().unwrap(); |
82 |
| - let mut sm = Array::zeros((n, m)); |
83 |
| - for i in 0..min(n, m) { |
84 |
| - sm[(i, i)] = s[i]; |
85 |
| - } |
86 |
| - all_close_l2(&u.dot(&sm).dot(&vt), &a, 1e-7).unwrap(); |
| 35 | +mod shared { |
| 36 | + use super::random_shared; |
| 37 | + impl_test!(svd_square, random_shared, 3, 3, false); |
| 38 | + impl_test!(svd_square_t, random_shared, 3, 3, true); |
| 39 | + impl_test!(svd_4x3, random_shared, 4, 3, false); |
| 40 | + impl_test!(svd_4x3_t, random_shared, 4, 3, true); |
| 41 | + impl_test!(svd_3x4, random_shared, 3, 4, false); |
| 42 | + impl_test!(svd_3x4_t, random_shared, 3, 4, true); |
87 | 43 | }
|
0 commit comments