Skip to content

Commit 8734052

Browse files
committed
Rewrite tests/svd.rs
1 parent deaaa42 commit 8734052

File tree

1 file changed

+16
-24
lines changed

1 file changed

+16
-24
lines changed

tests/svd.rs

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,103 +1,95 @@
11

2-
extern crate rand;
32
extern crate ndarray;
43
extern crate ndarray_rand;
54
extern crate ndarray_linalg;
5+
extern crate ndarray_numtest;
66

77
use std::cmp::min;
88
use ndarray::prelude::*;
99
use ndarray_linalg::prelude::*;
10-
use rand::distributions::*;
10+
use ndarray_numtest::prelude::*;
1111
use ndarray_rand::RandomExt;
1212

13-
fn all_close(a: Array<f64, Ix2>, b: Array<f64, Ix2>) {
14-
if !a.all_close(&b, 1.0e-7) {
15-
panic!("\nTwo matrices are not equal:\na = \n{:?}\nb = \n{:?}\n",
16-
a,
17-
b);
18-
}
19-
}
20-
2113
#[test]
2214
fn svd_square() {
23-
let r_dist = Range::new(0., 1.);
15+
let r_dist = RealNormal::new(0., 1.);
2416
let a = Array::<f64, _>::random((3, 3), r_dist);
2517
let (u, s, vt) = a.clone().svd().unwrap();
2618
let mut sm = Array::eye(3);
2719
for i in 0..3 {
2820
sm[(i, i)] = s[i];
2921
}
30-
all_close(u.dot(&sm).dot(&vt), a);
22+
u.dot(&sm).dot(&vt).assert_allclose_l2(&a, 1e-7);
3123
}
3224
#[test]
3325
fn svd_square_t() {
34-
let r_dist = Range::new(0., 1.);
26+
let r_dist = RealNormal::new(0., 1.);
3527
let a = Array::<f64, _>::random((3, 3), r_dist).reversed_axes();
3628
let (u, s, vt) = a.clone().svd().unwrap();
3729
let mut sm = Array::eye(3);
3830
for i in 0..3 {
3931
sm[(i, i)] = s[i];
4032
}
41-
all_close(u.dot(&sm).dot(&vt), a);
33+
u.dot(&sm).dot(&vt).assert_allclose_l2(&a, 1e-7);
4234
}
4335

4436
#[test]
4537
fn svd_4x3() {
46-
let r_dist = Range::new(0., 1.);
38+
let r_dist = RealNormal::new(0., 1.);
4739
let a = Array::<f64, _>::random((4, 3), r_dist);
4840
let (u, s, vt) = a.clone().svd().unwrap();
4941
let mut sm = Array::zeros((4, 3));
5042
for i in 0..3 {
5143
sm[(i, i)] = s[i];
5244
}
53-
all_close(u.dot(&sm).dot(&vt), a);
45+
u.dot(&sm).dot(&vt).assert_allclose_l2(&a, 1e-7);
5446
}
5547
#[test]
5648
fn svd_4x3_t() {
57-
let r_dist = Range::new(0., 1.);
49+
let r_dist = RealNormal::new(0., 1.);
5850
let a = Array::<f64, _>::random((3, 4), r_dist).reversed_axes();
5951
let (u, s, vt) = a.clone().svd().unwrap();
6052
let mut sm = Array::zeros((4, 3));
6153
for i in 0..3 {
6254
sm[(i, i)] = s[i];
6355
}
64-
all_close(u.dot(&sm).dot(&vt), a);
56+
u.dot(&sm).dot(&vt).assert_allclose_l2(&a, 1e-7);
6557
}
6658

6759
#[test]
6860
fn svd_3x4() {
69-
let r_dist = Range::new(0., 1.);
61+
let r_dist = RealNormal::new(0., 1.);
7062
let a = Array::<f64, _>::random((3, 4), r_dist);
7163
let (u, s, vt) = a.clone().svd().unwrap();
7264
let mut sm = Array::zeros((3, 4));
7365
for i in 0..3 {
7466
sm[(i, i)] = s[i];
7567
}
76-
all_close(u.dot(&sm).dot(&vt), a);
68+
u.dot(&sm).dot(&vt).assert_allclose_l2(&a, 1e-7);
7769
}
7870
#[test]
7971
fn svd_3x4_t() {
80-
let r_dist = Range::new(0., 1.);
72+
let r_dist = RealNormal::new(0., 1.);
8173
let a = Array::<f64, _>::random((4, 3), r_dist).reversed_axes();
8274
let (u, s, vt) = a.clone().svd().unwrap();
8375
let mut sm = Array::zeros((3, 4));
8476
for i in 0..3 {
8577
sm[(i, i)] = s[i];
8678
}
87-
all_close(u.dot(&sm).dot(&vt), a);
79+
u.dot(&sm).dot(&vt).assert_allclose_l2(&a, 1e-7);
8880
}
8981

9082
#[test]
9183
#[ignore]
9284
fn svd_large() {
9385
let n = 2480;
9486
let m = 4280;
95-
let r_dist = Range::new(0., 1.);
87+
let r_dist = RealNormal::new(0., 1.);
9688
let a = Array::<f64, _>::random((n, m), r_dist);
9789
let (u, s, vt) = a.clone().svd().unwrap();
9890
let mut sm = Array::zeros((n, m));
9991
for i in 0..min(n, m) {
10092
sm[(i, i)] = s[i];
10193
}
102-
all_close(u.dot(&sm).dot(&vt), a);
94+
u.dot(&sm).dot(&vt).assert_allclose_l2(&a, 1e-7);
10395
}

0 commit comments

Comments
 (0)