Skip to content

Commit d91a9d2

Browse files
committed
Test for SVD
1 parent 3e43ea7 commit d91a9d2

File tree

2 files changed

+50
-77
lines changed

2 files changed

+50
-77
lines changed

tests/header.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,23 @@ use ndarray_rand::RandomExt;
1616
#[allow(unused_imports)]
1717
use num_traits::Float;
1818

19+
pub fn random_owned(n: usize, m: usize, reversed: bool) -> Array<f64, Ix2> {
20+
let r_dist = RealNormal::new(0., 1.);
21+
if reversed {
22+
Array::random((m, n), r_dist).reversed_axes()
23+
} else {
24+
Array::random((n, m), r_dist)
25+
}
26+
}
27+
pub fn random_shared(n: usize, m: usize, reversed: bool) -> RcArray<f64, Ix2> {
28+
let r_dist = RealNormal::new(0., 1.);
29+
if reversed {
30+
RcArray::random((m, n), r_dist).reversed_axes()
31+
} else {
32+
RcArray::random((n, m), r_dist)
33+
}
34+
}
35+
1936
pub fn random_square(n: usize) -> Array<f64, Ix2> {
2037
let r_dist = RealNormal::new(0., 1.);
2138
Array::<f64, _>::random((n, n), r_dist)

tests/svd.rs

Lines changed: 33 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,87 +1,43 @@
11
include!("header.rs");
22

3-
use std::cmp::min;
4-
5-
#[test]
6-
fn svd_square() {
7-
let r_dist = RealNormal::new(0., 1.);
8-
let a = Array::<f64, _>::random((3, 3), r_dist);
9-
let (u, s, vt) = a.clone().svd().unwrap();
10-
let mut sm = Array::eye(3);
11-
for i in 0..3 {
12-
sm[(i, i)] = s[i];
13-
}
14-
all_close_l2(&u.dot(&sm).dot(&vt), &a, 1e-7).unwrap();
15-
}
16-
#[test]
17-
fn svd_square_t() {
18-
let r_dist = RealNormal::new(0., 1.);
19-
let a = Array::<f64, _>::random((3, 3), r_dist).reversed_axes();
20-
let (u, s, vt) = a.clone().svd().unwrap();
21-
let mut sm = Array::eye(3);
22-
for i in 0..3 {
23-
sm[(i, i)] = s[i];
24-
}
25-
all_close_l2(&u.dot(&sm).dot(&vt), &a, 1e-7).unwrap();
26-
}
27-
28-
#[test]
29-
fn svd_4x3() {
30-
let r_dist = RealNormal::new(0., 1.);
31-
let a = Array::<f64, _>::random((4, 3), r_dist);
32-
let (u, s, vt) = a.clone().svd().unwrap();
33-
let mut sm = Array::zeros((4, 3));
34-
for i in 0..3 {
35-
sm[(i, i)] = s[i];
36-
}
37-
all_close_l2(&u.dot(&sm).dot(&vt), &a, 1e-7).unwrap();
38-
}
3+
macro_rules! impl_test {
4+
($funcname:ident, $random:path, $n:expr, $m:expr, $t:expr) => {
395
#[test]
40-
fn svd_4x3_t() {
41-
let r_dist = RealNormal::new(0., 1.);
42-
let a = Array::<f64, _>::random((3, 4), r_dist).reversed_axes();
43-
let (u, s, vt) = a.clone().svd().unwrap();
44-
let mut sm = Array::zeros((4, 3));
45-
for i in 0..3 {
6+
fn $funcname() {
7+
use std::cmp::min;
8+
use ndarray::prelude::*;
9+
use ndarray_linalg::prelude::*;
10+
let a = $random($n, $m, $t);
11+
let answer = a.clone();
12+
println!("a = \n{}", &a);
13+
let (u, s, vt) = a.svd().unwrap();
14+
println!("u = \n{}", &u);
15+
println!("s = \n{}", &s);
16+
println!("v = \n{}", &vt);
17+
let mut sm = Array::zeros(($n, $m));
18+
for i in 0..min($n, $m) {
4619
sm[(i, i)] = s[i];
4720
}
48-
all_close_l2(&u.dot(&sm).dot(&vt), &a, 1e-7).unwrap();
21+
all_close_l2(&u.dot(&sm).dot(&vt), &answer, 1e-7).unwrap();
4922
}
23+
}} // impl_test
5024

51-
#[test]
52-
fn svd_3x4() {
53-
let r_dist = RealNormal::new(0., 1.);
54-
let a = Array::<f64, _>::random((3, 4), r_dist);
55-
let (u, s, vt) = a.clone().svd().unwrap();
56-
let mut sm = Array::zeros((3, 4));
57-
for i in 0..3 {
58-
sm[(i, i)] = s[i];
59-
}
60-
all_close_l2(&u.dot(&sm).dot(&vt), &a, 1e-7).unwrap();
61-
}
62-
#[test]
63-
fn svd_3x4_t() {
64-
let r_dist = RealNormal::new(0., 1.);
65-
let a = Array::<f64, _>::random((4, 3), r_dist).reversed_axes();
66-
let (u, s, vt) = a.clone().svd().unwrap();
67-
let mut sm = Array::zeros((3, 4));
68-
for i in 0..3 {
69-
sm[(i, i)] = s[i];
70-
}
71-
all_close_l2(&u.dot(&sm).dot(&vt), &a, 1e-7).unwrap();
25+
mod owned {
26+
use super::random_owned;
27+
impl_test!(svd_square, random_owned, 3, 3, false);
28+
impl_test!(svd_square_t, random_owned, 3, 3, true);
29+
impl_test!(svd_4x3, random_owned, 4, 3, false);
30+
impl_test!(svd_4x3_t, random_owned, 4, 3, true);
31+
impl_test!(svd_3x4, random_owned, 3, 4, false);
32+
impl_test!(svd_3x4_t, random_owned, 3, 4, true);
7233
}
7334

74-
#[test]
75-
#[ignore]
76-
fn svd_large() {
77-
let n = 2480;
78-
let m = 4280;
79-
let r_dist = RealNormal::new(0., 1.);
80-
let a = Array::<f64, _>::random((n, m), r_dist);
81-
let (u, s, vt) = a.clone().svd().unwrap();
82-
let mut sm = Array::zeros((n, m));
83-
for i in 0..min(n, m) {
84-
sm[(i, i)] = s[i];
85-
}
86-
all_close_l2(&u.dot(&sm).dot(&vt), &a, 1e-7).unwrap();
35+
mod shared {
36+
use super::random_shared;
37+
impl_test!(svd_square, random_shared, 3, 3, false);
38+
impl_test!(svd_square_t, random_shared, 3, 3, true);
39+
impl_test!(svd_4x3, random_shared, 4, 3, false);
40+
impl_test!(svd_4x3_t, random_shared, 4, 3, true);
41+
impl_test!(svd_3x4, random_shared, 3, 4, false);
42+
impl_test!(svd_3x4_t, random_shared, 3, 4, true);
8743
}

0 commit comments

Comments
 (0)