|
1 | 1 |
|
2 |
| -extern crate rand; |
3 | 2 | extern crate ndarray;
|
4 | 3 | extern crate ndarray_rand;
|
5 | 4 | extern crate ndarray_linalg;
|
| 5 | +extern crate ndarray_numtest; |
6 | 6 |
|
7 | 7 | use std::cmp::min;
|
8 | 8 | use ndarray::prelude::*;
|
9 | 9 | use ndarray_linalg::prelude::*;
|
10 |
| -use rand::distributions::*; |
| 10 | +use ndarray_numtest::prelude::*; |
11 | 11 | use ndarray_rand::RandomExt;
|
12 | 12 |
|
13 |
| -fn all_close(a: Array<f64, Ix2>, b: Array<f64, Ix2>) { |
14 |
| - if !a.all_close(&b, 1.0e-7) { |
15 |
| - panic!("\nTwo matrices are not equal:\na = \n{:?}\nb = \n{:?}\n", |
16 |
| - a, |
17 |
| - b); |
18 |
| - } |
19 |
| -} |
20 |
| - |
21 | 13 | #[test]
|
22 | 14 | fn svd_square() {
|
23 |
| - let r_dist = Range::new(0., 1.); |
| 15 | + let r_dist = RealNormal::new(0., 1.); |
24 | 16 | let a = Array::<f64, _>::random((3, 3), r_dist);
|
25 | 17 | let (u, s, vt) = a.clone().svd().unwrap();
|
26 | 18 | let mut sm = Array::eye(3);
|
27 | 19 | for i in 0..3 {
|
28 | 20 | sm[(i, i)] = s[i];
|
29 | 21 | }
|
30 |
| - all_close(u.dot(&sm).dot(&vt), a); |
| 22 | + u.dot(&sm).dot(&vt).assert_allclose_l2(&a, 1e-7); |
31 | 23 | }
|
32 | 24 | #[test]
|
33 | 25 | fn svd_square_t() {
|
34 |
| - let r_dist = Range::new(0., 1.); |
| 26 | + let r_dist = RealNormal::new(0., 1.); |
35 | 27 | let a = Array::<f64, _>::random((3, 3), r_dist).reversed_axes();
|
36 | 28 | let (u, s, vt) = a.clone().svd().unwrap();
|
37 | 29 | let mut sm = Array::eye(3);
|
38 | 30 | for i in 0..3 {
|
39 | 31 | sm[(i, i)] = s[i];
|
40 | 32 | }
|
41 |
| - all_close(u.dot(&sm).dot(&vt), a); |
| 33 | + u.dot(&sm).dot(&vt).assert_allclose_l2(&a, 1e-7); |
42 | 34 | }
|
43 | 35 |
|
44 | 36 | #[test]
|
45 | 37 | fn svd_4x3() {
|
46 |
| - let r_dist = Range::new(0., 1.); |
| 38 | + let r_dist = RealNormal::new(0., 1.); |
47 | 39 | let a = Array::<f64, _>::random((4, 3), r_dist);
|
48 | 40 | let (u, s, vt) = a.clone().svd().unwrap();
|
49 | 41 | let mut sm = Array::zeros((4, 3));
|
50 | 42 | for i in 0..3 {
|
51 | 43 | sm[(i, i)] = s[i];
|
52 | 44 | }
|
53 |
| - all_close(u.dot(&sm).dot(&vt), a); |
| 45 | + u.dot(&sm).dot(&vt).assert_allclose_l2(&a, 1e-7); |
54 | 46 | }
|
55 | 47 | #[test]
|
56 | 48 | fn svd_4x3_t() {
|
57 |
| - let r_dist = Range::new(0., 1.); |
| 49 | + let r_dist = RealNormal::new(0., 1.); |
58 | 50 | let a = Array::<f64, _>::random((3, 4), r_dist).reversed_axes();
|
59 | 51 | let (u, s, vt) = a.clone().svd().unwrap();
|
60 | 52 | let mut sm = Array::zeros((4, 3));
|
61 | 53 | for i in 0..3 {
|
62 | 54 | sm[(i, i)] = s[i];
|
63 | 55 | }
|
64 |
| - all_close(u.dot(&sm).dot(&vt), a); |
| 56 | + u.dot(&sm).dot(&vt).assert_allclose_l2(&a, 1e-7); |
65 | 57 | }
|
66 | 58 |
|
67 | 59 | #[test]
|
68 | 60 | fn svd_3x4() {
|
69 |
| - let r_dist = Range::new(0., 1.); |
| 61 | + let r_dist = RealNormal::new(0., 1.); |
70 | 62 | let a = Array::<f64, _>::random((3, 4), r_dist);
|
71 | 63 | let (u, s, vt) = a.clone().svd().unwrap();
|
72 | 64 | let mut sm = Array::zeros((3, 4));
|
73 | 65 | for i in 0..3 {
|
74 | 66 | sm[(i, i)] = s[i];
|
75 | 67 | }
|
76 |
| - all_close(u.dot(&sm).dot(&vt), a); |
| 68 | + u.dot(&sm).dot(&vt).assert_allclose_l2(&a, 1e-7); |
77 | 69 | }
|
78 | 70 | #[test]
|
79 | 71 | fn svd_3x4_t() {
|
80 |
| - let r_dist = Range::new(0., 1.); |
| 72 | + let r_dist = RealNormal::new(0., 1.); |
81 | 73 | let a = Array::<f64, _>::random((4, 3), r_dist).reversed_axes();
|
82 | 74 | let (u, s, vt) = a.clone().svd().unwrap();
|
83 | 75 | let mut sm = Array::zeros((3, 4));
|
84 | 76 | for i in 0..3 {
|
85 | 77 | sm[(i, i)] = s[i];
|
86 | 78 | }
|
87 |
| - all_close(u.dot(&sm).dot(&vt), a); |
| 79 | + u.dot(&sm).dot(&vt).assert_allclose_l2(&a, 1e-7); |
88 | 80 | }
|
89 | 81 |
|
90 | 82 | #[test]
|
91 | 83 | #[ignore]
|
92 | 84 | fn svd_large() {
|
93 | 85 | let n = 2480;
|
94 | 86 | let m = 4280;
|
95 |
| - let r_dist = Range::new(0., 1.); |
| 87 | + let r_dist = RealNormal::new(0., 1.); |
96 | 88 | let a = Array::<f64, _>::random((n, m), r_dist);
|
97 | 89 | let (u, s, vt) = a.clone().svd().unwrap();
|
98 | 90 | let mut sm = Array::zeros((n, m));
|
99 | 91 | for i in 0..min(n, m) {
|
100 | 92 | sm[(i, i)] = s[i];
|
101 | 93 | }
|
102 |
| - all_close(u.dot(&sm).dot(&vt), a); |
| 94 | + u.dot(&sm).dot(&vt).assert_allclose_l2(&a, 1e-7); |
103 | 95 | }
|
0 commit comments