Skip to content

Commit f002e8e

Browse files
authored
Merge pull request #51 from termoshtt/revice_test
Revice test
2 parents 844ea5a + 31a3f72 commit f002e8e

File tree

10 files changed

+238
-226
lines changed

10 files changed

+238
-226
lines changed

src/generate.rs

Lines changed: 12 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -19,40 +19,26 @@ pub fn conjugate<A, Si, So>(a: &ArrayBase<Si, Ix2>) -> ArrayBase<So, Ix2>
1919
a
2020
}
2121

22-
/// Random vector
23-
pub fn random_vector<A, S>(n: usize) -> ArrayBase<S, Ix1>
22+
pub fn random<A, S, Sh, D>(sh: Sh) -> ArrayBase<S, D>
2423
where A: RandNormal,
25-
S: DataOwned<Elem = A>
24+
S: DataOwned<Elem = A>,
25+
D: Dimension,
26+
Sh: ShapeBuilder<Dim = D>
2627
{
2728
let mut rng = thread_rng();
28-
let v: Vec<A> = (0..n).map(|_| A::randn(&mut rng)).collect();
29-
ArrayBase::from_vec(v)
30-
}
31-
32-
/// Random matrix
33-
pub fn random_matrix<A, S>(n: usize, m: usize) -> ArrayBase<S, Ix2>
34-
where A: RandNormal,
35-
S: DataOwned<Elem = A>
36-
{
37-
let mut rng = thread_rng();
38-
let v: Vec<A> = (0..n * m).map(|_| A::randn(&mut rng)).collect();
39-
ArrayBase::from_shape_vec((n, m), v).unwrap()
40-
}
41-
42-
/// Random square matrix
43-
pub fn random_square<A, S>(n: usize) -> ArrayBase<S, Ix2>
44-
where A: RandNormal,
45-
S: DataOwned<Elem = A>
46-
{
47-
random_matrix(n, n)
29+
ArrayBase::from_shape_fn(sh, |_| A::randn(&mut rng))
4830
}
4931

5032
/// Random Hermite matrix
51-
pub fn random_hermite<A, S>(n: usize) -> ArrayBase<S, Ix2>
33+
pub fn random_hermite<A, S>(n: usize, c_order: bool) -> ArrayBase<S, Ix2>
5234
where A: RandNormal + Conjugate + Add<Output = A>,
5335
S: DataOwned<Elem = A> + DataMut
5436
{
55-
let mut a = random_square(n);
37+
let mut a = if c_order {
38+
random((n, n))
39+
} else {
40+
random((n, n).f())
41+
};
5642
for i in 0..n {
5743
a[(i, i)] = a[(i, i)] + Conjugate::conj(a[(i, i)]);
5844
for j in (i + 1)..n {
@@ -67,7 +53,7 @@ pub fn random_hpd<A, S>(n: usize) -> ArrayBase<S, Ix2>
6753
where A: RandNormal + Conjugate + LinalgScalar,
6854
S: DataOwned<Elem = A> + DataMut
6955
{
70-
let a: Array2<A> = random_square(n);
56+
let a: Array2<A> = random((n, n));
7157
let ah: Array2<A> = conjugate(&a);
7258
replicate(&ah.dot(&a))
7359
}

tests/header.rs

Lines changed: 0 additions & 47 deletions
This file was deleted.

tests/inv.rs

Lines changed: 28 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,33 @@
1-
include!("header.rs");
21

3-
macro_rules! impl_test{
4-
($modname:ident, $clone:ident) => {
5-
mod $modname {
6-
use super::random_square;
7-
use ndarray::*;
8-
use ndarray_linalg::*;
9-
#[test]
10-
fn inv_random() {
11-
let a = random_square(3);
12-
let ai = a.$clone().inv().unwrap();
13-
let id = Array::eye(3);
14-
assert_close_l2!(&ai.dot(&a), &id, 1e-7);
15-
}
2+
extern crate ndarray;
3+
#[macro_use]
4+
extern crate ndarray_linalg;
5+
extern crate num_traits;
166

17-
#[test]
18-
fn inv_random_t() {
19-
let a = random_square(3).reversed_axes();
20-
let ai = a.$clone().inv().unwrap();
21-
let id = Array::eye(3);
22-
assert_close_l2!(&ai.dot(&a), &id, 1e-7);
23-
}
7+
use ndarray::*;
8+
use ndarray_linalg::*;
249

25-
#[test]
26-
#[should_panic]
27-
fn inv_error() {
28-
// do not have inverse
29-
let a = Array::<f64, _>::zeros(9).into_shape((3, 3)).unwrap();
30-
let a_inv = a.$clone().inv().unwrap();
31-
println!("{:?}", a_inv);
32-
}
10+
#[test]
11+
fn inv_random() {
12+
let a: Array2<f64> = random((3, 3));
13+
let ai: Array2<_> = (&a).inv().unwrap();
14+
let id = Array::eye(3);
15+
assert_close_l2!(&ai.dot(&a), &id, 1e-7);
3316
}
34-
}} // impl_test
3517

36-
impl_test!(owned, clone);
37-
impl_test!(shared, to_shared);
18+
#[test]
19+
fn inv_random_t() {
20+
let a: Array2<f64> = random((3, 3).f());
21+
let ai: Array2<_> = (&a).inv().unwrap();
22+
let id = Array::eye(3);
23+
assert_close_l2!(&ai.dot(&a), &id, 1e-7);
24+
}
25+
26+
#[test]
27+
#[should_panic]
28+
fn inv_error() {
29+
// do not have inverse
30+
let a = Array::<f64, _>::zeros(9).into_shape((3, 3)).unwrap();
31+
let a_inv = a.inv().unwrap();
32+
println!("{:?}", a_inv);
33+
}

tests/normalize.rs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,22 @@
1-
include!("header.rs");
1+
2+
extern crate ndarray;
3+
#[macro_use]
4+
extern crate ndarray_linalg;
5+
extern crate num_traits;
6+
7+
use ndarray::*;
8+
use ndarray_linalg::*;
29

310
#[test]
411
fn n_columns() {
5-
let a = random_owned(3, 2, true);
12+
let a: Array2<f64> = random((3, 2));
613
let (n, v) = normalize(a.clone(), NormalizeAxis::Column);
714
assert_close_l2!(&n.dot(&from_diag(&v)), &a, 1e-7);
815
}
916

1017
#[test]
1118
fn n_rows() {
12-
let a = random_owned(3, 2, true);
19+
let a: Array2<f64> = random((3, 2));
1320
let (n, v) = normalize(a.clone(), NormalizeAxis::Row);
1421
assert_close_l2!(&from_diag(&v).dot(&n), &a, 1e-7);
1522
}

tests/opnorm.rs

Lines changed: 50 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,57 @@
1-
include!("header.rs");
21

3-
macro_rules! impl_test {
4-
($funcname:ident, $a:expr, $op1:expr, $opi:expr, $opf:expr) => {
5-
#[test]
6-
fn $funcname() {
7-
let a = $a;
2+
extern crate ndarray;
3+
#[macro_use]
4+
extern crate ndarray_linalg;
5+
extern crate num_traits;
6+
7+
use ndarray::*;
8+
use ndarray_linalg::*;
9+
use num_traits::Float;
10+
11+
fn test(a: Array2<f64>, one: f64, inf: f64, fro: f64) {
812
println!("ONE = {:?}", a.opnorm_one());
913
println!("INF = {:?}", a.opnorm_inf());
1014
println!("FRO = {:?}", a.opnorm_fro());
11-
assert_rclose!(a.opnorm_fro().unwrap(), $opf, 1e-7; "Frobenius norm");
12-
assert_rclose!(a.opnorm_one().unwrap(), $op1, 1e-7; "One norm");
13-
assert_rclose!(a.opnorm_inf().unwrap(), $opi, 1e-7; "Infinity norm");
14-
}
15-
}} // impl_test
16-
17-
macro_rules! impl_test_opnorm {
18-
($modname:ident, $array:ty, $range:path) => {
19-
mod $modname {
20-
use ndarray::*;
21-
use ndarray_linalg::*;
22-
use num_traits::Float;
23-
fn gen(i: usize, j: usize, rev: bool) -> $array {
24-
let n = (i * j + 1) as f64;
25-
if rev {
26-
$range(1., n, 1.).into_shape((j, i)).unwrap().reversed_axes()
27-
} else {
28-
$range(1., n, 1.).into_shape((i, j)).unwrap()
29-
}
15+
assert_rclose!(a.opnorm_one().unwrap(), one, 1e-7; "One norm");
16+
assert_rclose!(a.opnorm_inf().unwrap(), inf, 1e-7; "Infinity norm");
17+
assert_rclose!(a.opnorm_fro().unwrap(), fro, 1e-7; "Frobenius norm");
18+
}
19+
20+
fn gen(i: usize, j: usize, rev: bool) -> Array2<f64> {
21+
let n = (i * j + 1) as f64;
22+
if rev {
23+
Array::range(1., n, 1.).into_shape((j, i)).unwrap().reversed_axes()
24+
} else {
25+
Array::range(1., n, 1.).into_shape((i, j)).unwrap()
3026
}
31-
impl_test!(opnorm_square, gen(3, 3, false), 18.0, 24.0, 285.0.sqrt());
32-
impl_test!(opnorm_square_t, gen(3, 3, true), 24.0, 18.0, 285.0.sqrt());
33-
impl_test!(opnorm_3x4, gen(3, 4, false), 24.0, 42.0, 650.0.sqrt());
34-
impl_test!(opnorm_4x3_t, gen(4, 3, true), 42.0, 24.0, 650.0.sqrt());
35-
impl_test!(opnorm_3x4_t, gen(3, 4, true), 33.0, 30.0, 650.0.sqrt());
36-
impl_test!(opnorm_4x3, gen(4, 3, false), 30.0, 33.0, 650.0.sqrt());
3727
}
38-
}} // impl_test_opnorm
3928

40-
impl_test_opnorm!(owned, Array<f64, Ix2>, Array::range);
41-
impl_test_opnorm!(shared, RcArray<f64, Ix2>, RcArray::range);
29+
#[test]
30+
fn opnorm_square() {
31+
test(gen(3, 3, false), 18.0, 24.0, 285.0.sqrt());
32+
}
33+
34+
#[test]
35+
fn opnorm_square_t() {
36+
test(gen(3, 3, true), 24.0, 18.0, 285.0.sqrt());
37+
}
38+
39+
#[test]
40+
fn opnorm_3x4() {
41+
test(gen(3, 4, false), 24.0, 42.0, 650.0.sqrt());
42+
}
43+
44+
#[test]
45+
fn opnorm_3x4_t() {
46+
test(gen(3, 4, true), 33.0, 30.0, 650.0.sqrt());
47+
}
48+
49+
#[test]
50+
fn opnorm_4x3() {
51+
test(gen(4, 3, false), 30.0, 33.0, 650.0.sqrt());
52+
}
53+
54+
#[test]
55+
fn opnorm_4x3_t() {
56+
test(gen(4, 3, true), 42.0, 24.0, 650.0.sqrt());
57+
}

tests/qr.rs

Lines changed: 45 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,55 @@
1-
include!("header.rs");
21

3-
macro_rules! impl_test {
4-
($funcname:ident, $random:path, $n:expr, $m:expr, $t:expr) => {
5-
#[test]
6-
fn $funcname() {
7-
use std::cmp::min;
8-
use ndarray::*;
9-
use ndarray_linalg::*;
10-
let a = $random($n, $m, $t);
2+
extern crate ndarray;
3+
#[macro_use]
4+
extern crate ndarray_linalg;
5+
6+
use std::cmp::min;
7+
use ndarray::*;
8+
use ndarray_linalg::*;
9+
10+
fn test(a: Array2<f64>, n: usize, m: usize) {
1111
let ans = a.clone();
1212
println!("a = \n{:?}", &a);
13-
let (q, r) : (Array2<_>, Array2<_>) = a.qr().unwrap();
13+
let (q, r): (Array2<_>, Array2<_>) = a.qr().unwrap();
1414
println!("q = \n{:?}", &q);
1515
println!("r = \n{:?}", &r);
16-
assert_close_l2!(&q.t().dot(&q), &Array::eye(min($n, $m)), 1e-7);
16+
assert_close_l2!(&q.t().dot(&q), &Array::eye(min(n, m)), 1e-7);
1717
assert_close_l2!(&q.dot(&r), &ans, 1e-7);
1818
assert_close_l2!(&drop_lower(r.clone()), &r, 1e-7);
1919
}
20-
}} // impl_test
21-
22-
macro_rules! impl_test_qr {
23-
($modname:ident, $random:path) => {
24-
mod $modname {
25-
impl_test!(qr_square, $random, 3, 3, false);
26-
impl_test!(qr_square_t, $random, 3, 3, true);
27-
impl_test!(qr_3x4, $random, 3, 4, false);
28-
impl_test!(qr_3x4_t, $random, 3, 4, true);
29-
impl_test!(qr_4x3, $random, 4, 3, false);
30-
impl_test!(qr_4x3_t, $random, 4, 3, true);
20+
21+
#[test]
22+
fn qr_square() {
23+
let a = random((3, 3));
24+
test(a, 3, 3);
25+
}
26+
27+
#[test]
28+
fn qr_square_t() {
29+
let a = random((3, 3).f());
30+
test(a, 3, 3);
3131
}
32-
}} // impl_test_qr
3332

34-
impl_test_qr!(owned, super::random_owned);
35-
impl_test_qr!(shared, super::random_shared);
33+
#[test]
34+
fn qr_3x4() {
35+
let a = random((3, 4));
36+
test(a, 3, 4);
37+
}
38+
39+
#[test]
40+
fn qr_3x4_t() {
41+
let a = random((3, 4).f());
42+
test(a, 3, 4);
43+
}
44+
45+
#[test]
46+
fn qr_4x3() {
47+
let a = random((4, 3));
48+
test(a, 4, 3);
49+
}
50+
51+
#[test]
52+
fn qr_4x3_t() {
53+
let a = random((4, 3).f());
54+
test(a, 4, 3);
55+
}

0 commit comments

Comments
 (0)