diff --git a/src/generate.rs b/src/generate.rs index d1d22ef0..c1f1598d 100644 --- a/src/generate.rs +++ b/src/generate.rs @@ -19,40 +19,26 @@ pub fn conjugate(a: &ArrayBase) -> ArrayBase a } -/// Random vector -pub fn random_vector(n: usize) -> ArrayBase +pub fn random(sh: Sh) -> ArrayBase where A: RandNormal, - S: DataOwned + S: DataOwned, + D: Dimension, + Sh: ShapeBuilder { let mut rng = thread_rng(); - let v: Vec = (0..n).map(|_| A::randn(&mut rng)).collect(); - ArrayBase::from_vec(v) -} - -/// Random matrix -pub fn random_matrix(n: usize, m: usize) -> ArrayBase - where A: RandNormal, - S: DataOwned -{ - let mut rng = thread_rng(); - let v: Vec = (0..n * m).map(|_| A::randn(&mut rng)).collect(); - ArrayBase::from_shape_vec((n, m), v).unwrap() -} - -/// Random square matrix -pub fn random_square(n: usize) -> ArrayBase - where A: RandNormal, - S: DataOwned -{ - random_matrix(n, n) + ArrayBase::from_shape_fn(sh, |_| A::randn(&mut rng)) } /// Random Hermite matrix -pub fn random_hermite(n: usize) -> ArrayBase +pub fn random_hermite(n: usize, c_order: bool) -> ArrayBase where A: RandNormal + Conjugate + Add, S: DataOwned + DataMut { - let mut a = random_square(n); + let mut a = if c_order { + random((n, n)) + } else { + random((n, n).f()) + }; for i in 0..n { a[(i, i)] = a[(i, i)] + Conjugate::conj(a[(i, i)]); for j in (i + 1)..n { @@ -67,7 +53,7 @@ pub fn random_hpd(n: usize) -> ArrayBase where A: RandNormal + Conjugate + LinalgScalar, S: DataOwned + DataMut { - let a: Array2 = random_square(n); + let a: Array2 = random((n, n)); let ah: Array2 = conjugate(&a); replicate(&ah.dot(&a)) } diff --git a/tests/header.rs b/tests/header.rs deleted file mode 100644 index ba60580d..00000000 --- a/tests/header.rs +++ /dev/null @@ -1,47 +0,0 @@ - -extern crate rand_extra; -extern crate ndarray; -extern crate ndarray_rand; -#[macro_use] -#[allow(unused_imports)] -extern crate ndarray_linalg; -extern crate num_traits; - -#[allow(unused_imports)] -use ndarray::*; -#[allow(unused_imports)] -use ndarray_linalg::*; -#[allow(unused_imports)] -use rand_extra::*; -#[allow(unused_imports)] -use ndarray_rand::RandomExt; -#[allow(unused_imports)] -use num_traits::Float; - -pub fn random_owned(n: usize, m: usize, reversed: bool) -> Array { - let r_dist = RealNormal::new(0., 1.); - if reversed { - Array::random((m, n), r_dist).reversed_axes() - } else { - Array::random((n, m), r_dist) - } -} -pub fn random_shared(n: usize, m: usize, reversed: bool) -> RcArray { - let r_dist = RealNormal::new(0., 1.); - if reversed { - RcArray::random((m, n), r_dist).reversed_axes() - } else { - RcArray::random((n, m), r_dist) - } -} - -pub fn random_square(n: usize) -> Array { - let r_dist = RealNormal::new(0., 1.); - Array::::random((n, n), r_dist) -} - -pub fn random_hermite(n: usize) -> Array { - let r_dist = RealNormal::new(0., 1.); - let a = Array::::random((n, n), r_dist); - a.dot(&a.t()) -} diff --git a/tests/inv.rs b/tests/inv.rs index 8e705e24..c2050b69 100644 --- a/tests/inv.rs +++ b/tests/inv.rs @@ -1,37 +1,33 @@ -include!("header.rs"); -macro_rules! impl_test{ - ($modname:ident, $clone:ident) => { -mod $modname { - use super::random_square; - use ndarray::*; - use ndarray_linalg::*; - #[test] - fn inv_random() { - let a = random_square(3); - let ai = a.$clone().inv().unwrap(); - let id = Array::eye(3); - assert_close_l2!(&ai.dot(&a), &id, 1e-7); - } +extern crate ndarray; +#[macro_use] +extern crate ndarray_linalg; +extern crate num_traits; - #[test] - fn inv_random_t() { - let a = random_square(3).reversed_axes(); - let ai = a.$clone().inv().unwrap(); - let id = Array::eye(3); - assert_close_l2!(&ai.dot(&a), &id, 1e-7); - } +use ndarray::*; +use ndarray_linalg::*; - #[test] - #[should_panic] - fn inv_error() { - // do not have inverse - let a = Array::::zeros(9).into_shape((3, 3)).unwrap(); - let a_inv = a.$clone().inv().unwrap(); - println!("{:?}", a_inv); - } +#[test] +fn inv_random() { + let a: Array2 = random((3, 3)); + let ai: Array2<_> = (&a).inv().unwrap(); + let id = Array::eye(3); + assert_close_l2!(&ai.dot(&a), &id, 1e-7); } -}} // impl_test -impl_test!(owned, clone); -impl_test!(shared, to_shared); +#[test] +fn inv_random_t() { + let a: Array2 = random((3, 3).f()); + let ai: Array2<_> = (&a).inv().unwrap(); + let id = Array::eye(3); + assert_close_l2!(&ai.dot(&a), &id, 1e-7); +} + +#[test] +#[should_panic] +fn inv_error() { + // do not have inverse + let a = Array::::zeros(9).into_shape((3, 3)).unwrap(); + let a_inv = a.inv().unwrap(); + println!("{:?}", a_inv); +} diff --git a/tests/normalize.rs b/tests/normalize.rs index d5e276f8..2b5c22d4 100644 --- a/tests/normalize.rs +++ b/tests/normalize.rs @@ -1,15 +1,22 @@ -include!("header.rs"); + +extern crate ndarray; +#[macro_use] +extern crate ndarray_linalg; +extern crate num_traits; + +use ndarray::*; +use ndarray_linalg::*; #[test] fn n_columns() { - let a = random_owned(3, 2, true); + let a: Array2 = random((3, 2)); let (n, v) = normalize(a.clone(), NormalizeAxis::Column); assert_close_l2!(&n.dot(&from_diag(&v)), &a, 1e-7); } #[test] fn n_rows() { - let a = random_owned(3, 2, true); + let a: Array2 = random((3, 2)); let (n, v) = normalize(a.clone(), NormalizeAxis::Row); assert_close_l2!(&from_diag(&v).dot(&n), &a, 1e-7); } diff --git a/tests/opnorm.rs b/tests/opnorm.rs index 6c3ec34d..6aa2a8e0 100644 --- a/tests/opnorm.rs +++ b/tests/opnorm.rs @@ -1,41 +1,57 @@ -include!("header.rs"); -macro_rules! impl_test { - ($funcname:ident, $a:expr, $op1:expr, $opi:expr, $opf:expr) => { -#[test] -fn $funcname() { - let a = $a; +extern crate ndarray; +#[macro_use] +extern crate ndarray_linalg; +extern crate num_traits; + +use ndarray::*; +use ndarray_linalg::*; +use num_traits::Float; + +fn test(a: Array2, one: f64, inf: f64, fro: f64) { println!("ONE = {:?}", a.opnorm_one()); println!("INF = {:?}", a.opnorm_inf()); println!("FRO = {:?}", a.opnorm_fro()); - assert_rclose!(a.opnorm_fro().unwrap(), $opf, 1e-7; "Frobenius norm"); - assert_rclose!(a.opnorm_one().unwrap(), $op1, 1e-7; "One norm"); - assert_rclose!(a.opnorm_inf().unwrap(), $opi, 1e-7; "Infinity norm"); -} -}} // impl_test - -macro_rules! impl_test_opnorm { - ($modname:ident, $array:ty, $range:path) => { -mod $modname { - use ndarray::*; - use ndarray_linalg::*; - use num_traits::Float; - fn gen(i: usize, j: usize, rev: bool) -> $array { - let n = (i * j + 1) as f64; - if rev { - $range(1., n, 1.).into_shape((j, i)).unwrap().reversed_axes() - } else { - $range(1., n, 1.).into_shape((i, j)).unwrap() - } + assert_rclose!(a.opnorm_one().unwrap(), one, 1e-7; "One norm"); + assert_rclose!(a.opnorm_inf().unwrap(), inf, 1e-7; "Infinity norm"); + assert_rclose!(a.opnorm_fro().unwrap(), fro, 1e-7; "Frobenius norm"); +} + +fn gen(i: usize, j: usize, rev: bool) -> Array2 { + let n = (i * j + 1) as f64; + if rev { + Array::range(1., n, 1.).into_shape((j, i)).unwrap().reversed_axes() + } else { + Array::range(1., n, 1.).into_shape((i, j)).unwrap() } - impl_test!(opnorm_square, gen(3, 3, false), 18.0, 24.0, 285.0.sqrt()); - impl_test!(opnorm_square_t, gen(3, 3, true), 24.0, 18.0, 285.0.sqrt()); - impl_test!(opnorm_3x4, gen(3, 4, false), 24.0, 42.0, 650.0.sqrt()); - impl_test!(opnorm_4x3_t, gen(4, 3, true), 42.0, 24.0, 650.0.sqrt()); - impl_test!(opnorm_3x4_t, gen(3, 4, true), 33.0, 30.0, 650.0.sqrt()); - impl_test!(opnorm_4x3, gen(4, 3, false), 30.0, 33.0, 650.0.sqrt()); } -}} // impl_test_opnorm -impl_test_opnorm!(owned, Array, Array::range); -impl_test_opnorm!(shared, RcArray, RcArray::range); +#[test] +fn opnorm_square() { + test(gen(3, 3, false), 18.0, 24.0, 285.0.sqrt()); +} + +#[test] +fn opnorm_square_t() { + test(gen(3, 3, true), 24.0, 18.0, 285.0.sqrt()); +} + +#[test] +fn opnorm_3x4() { + test(gen(3, 4, false), 24.0, 42.0, 650.0.sqrt()); +} + +#[test] +fn opnorm_3x4_t() { + test(gen(3, 4, true), 33.0, 30.0, 650.0.sqrt()); +} + +#[test] +fn opnorm_4x3() { + test(gen(4, 3, false), 30.0, 33.0, 650.0.sqrt()); +} + +#[test] +fn opnorm_4x3_t() { + test(gen(4, 3, true), 42.0, 24.0, 650.0.sqrt()); +} diff --git a/tests/qr.rs b/tests/qr.rs index b3d46398..5dc226a0 100644 --- a/tests/qr.rs +++ b/tests/qr.rs @@ -1,35 +1,55 @@ -include!("header.rs"); -macro_rules! impl_test { - ($funcname:ident, $random:path, $n:expr, $m:expr, $t:expr) => { -#[test] -fn $funcname() { - use std::cmp::min; - use ndarray::*; - use ndarray_linalg::*; - let a = $random($n, $m, $t); +extern crate ndarray; +#[macro_use] +extern crate ndarray_linalg; + +use std::cmp::min; +use ndarray::*; +use ndarray_linalg::*; + +fn test(a: Array2, n: usize, m: usize) { let ans = a.clone(); println!("a = \n{:?}", &a); - let (q, r) : (Array2<_>, Array2<_>) = a.qr().unwrap(); + let (q, r): (Array2<_>, Array2<_>) = a.qr().unwrap(); println!("q = \n{:?}", &q); println!("r = \n{:?}", &r); - assert_close_l2!(&q.t().dot(&q), &Array::eye(min($n, $m)), 1e-7); + assert_close_l2!(&q.t().dot(&q), &Array::eye(min(n, m)), 1e-7); assert_close_l2!(&q.dot(&r), &ans, 1e-7); assert_close_l2!(&drop_lower(r.clone()), &r, 1e-7); } -}} // impl_test - -macro_rules! impl_test_qr { - ($modname:ident, $random:path) => { -mod $modname { - impl_test!(qr_square, $random, 3, 3, false); - impl_test!(qr_square_t, $random, 3, 3, true); - impl_test!(qr_3x4, $random, 3, 4, false); - impl_test!(qr_3x4_t, $random, 3, 4, true); - impl_test!(qr_4x3, $random, 4, 3, false); - impl_test!(qr_4x3_t, $random, 4, 3, true); + +#[test] +fn qr_square() { + let a = random((3, 3)); + test(a, 3, 3); +} + +#[test] +fn qr_square_t() { + let a = random((3, 3).f()); + test(a, 3, 3); } -}} // impl_test_qr -impl_test_qr!(owned, super::random_owned); -impl_test_qr!(shared, super::random_shared); +#[test] +fn qr_3x4() { + let a = random((3, 4)); + test(a, 3, 4); +} + +#[test] +fn qr_3x4_t() { + let a = random((3, 4).f()); + test(a, 3, 4); +} + +#[test] +fn qr_4x3() { + let a = random((4, 3)); + test(a, 4, 3); +} + +#[test] +fn qr_4x3_t() { + let a = random((4, 3).f()); + test(a, 4, 3); +} diff --git a/tests/svd.rs b/tests/svd.rs index 1aeeafcb..f0e963d1 100644 --- a/tests/svd.rs +++ b/tests/svd.rs @@ -1,13 +1,15 @@ -include!("header.rs"); -macro_rules! impl_test { - ($funcname:ident, $random:path, $n:expr, $m:expr, $t:expr) => { -#[test] -fn $funcname() { - use std::cmp::min; - use ndarray::*; - use ndarray_linalg::*; - let a = $random($n, $m, $t); +extern crate ndarray; +#[macro_use] +extern crate ndarray_linalg; +extern crate num_traits; + +use std::cmp::min; +use ndarray::*; +use ndarray_linalg::*; +use num_traits::Float; + +fn test(a: Array2, n: usize, m: usize) { let answer = a.clone(); println!("a = \n{:?}", &a); let (u, s, vt): (_, Array1<_>, _) = a.svd(true, true).unwrap(); @@ -16,25 +18,45 @@ fn $funcname() { println!("u = \n{:?}", &u); println!("s = \n{:?}", &s); println!("v = \n{:?}", &vt); - let mut sm = Array::zeros(($n, $m)); - for i in 0..min($n, $m) { + let mut sm = Array::zeros((n, m)); + for i in 0..min(n, m) { sm[(i, i)] = s[i]; } assert_close_l2!(&u.dot(&sm).dot(&vt), &answer, 1e-7); } -}} // impl_test - -macro_rules! impl_test_svd { - ($modname:ident, $random:path) => { -mod $modname { - impl_test!(svd_square, $random, 3, 3, false); - impl_test!(svd_square_t, $random, 3, 3, true); - impl_test!(svd_4x3, $random, 4, 3, false); - impl_test!(svd_4x3_t, $random, 4, 3, true); - impl_test!(svd_3x4, $random, 3, 4, false); - impl_test!(svd_3x4_t, $random, 3, 4, true); + +#[test] +fn svd_square() { + let a = random((3, 3)); + test(a, 3, 3); +} + +#[test] +fn svd_square_t() { + let a = random((3, 3).f()); + test(a, 3, 3); } -}} // impl_test_svd -impl_test_svd!(owned, super::random_owned); -impl_test_svd!(shared, super::random_shared); +#[test] +fn svd_3x4() { + let a = random((3, 4)); + test(a, 3, 4); +} + +#[test] +fn svd_3x4_t() { + let a = random((3, 4).f()); + test(a, 3, 4); +} + +#[test] +fn svd_4x3() { + let a = random((4, 3)); + test(a, 4, 3); +} + +#[test] +fn svd_4x3_t() { + let a = random((4, 3).f()); + test(a, 4, 3); +} diff --git a/tests/trace.rs b/tests/trace.rs index d023194a..726885a3 100644 --- a/tests/trace.rs +++ b/tests/trace.rs @@ -1,8 +1,12 @@ -include!("header.rs"); + +extern crate ndarray; +extern crate ndarray_linalg; + +use ndarray::*; +use ndarray_linalg::*; #[test] fn trace() { - let r_dist = RealNormal::new(0., 1.); - let a = Array::::random((3, 3), r_dist); + let a: Array2 = random((3, 3)); assert_rclose!(a.trace().unwrap(), a[(0, 0)] + a[(1, 1)] + a[(2, 2)], 1e-7); } diff --git a/tests/triangular.rs b/tests/triangular.rs index 818dc54e..a00536fe 100644 --- a/tests/triangular.rs +++ b/tests/triangular.rs @@ -40,87 +40,87 @@ fn test2d(uplo: UPLO, a: ArrayBase, b: ArrayBase = random_vector(n); - let a: Array2 = random_square(n).into_triangular(UPLO::Upper); + let b: Array1 = random(n); + let a: Array2 = random((n, n)).into_triangular(UPLO::Upper); test1d(UPLO::Upper, a, b, 1e-7); } #[test] fn triangular_1d_lower() { let n = 3; - let b: Array1 = random_vector(n); - let a: Array2 = random_square(n).into_triangular(UPLO::Lower); + let b: Array1 = random(n); + let a: Array2 = random((n, n)).into_triangular(UPLO::Lower); test1d(UPLO::Lower, a, b, 1e-7); } #[test] -fn triangular_1d_lower_t() { +fn triangular_1d_upper_t() { let n = 3; - let b: Array1 = random_vector(n); - let a: Array2 = random_square(n).into_triangular(UPLO::Lower).reversed_axes(); + let b: Array1 = random(n); + let a: Array2 = random((n, n).f()).into_triangular(UPLO::Upper); test1d(UPLO::Upper, a, b, 1e-7); } #[test] -fn triangular_1d_upper_t() { +fn triangular_1d_lower_t() { let n = 3; - let b: Array1 = random_vector(n); - let a: Array2 = random_square(n).into_triangular(UPLO::Upper).reversed_axes(); + let b: Array1 = random(n); + let a: Array2 = random((n, n).f()).into_triangular(UPLO::Lower); test1d(UPLO::Lower, a, b, 1e-7); } #[test] fn triangular_2d_upper() { - let b: Array2 = random_matrix(3, 4); - let a: Array2 = random_square(3).into_triangular(UPLO::Upper); + let b: Array2 = random((3, 4)); + let a: Array2 = random((3, 3)).into_triangular(UPLO::Upper); test2d(UPLO::Upper, a, b, 1e-7); } #[test] fn triangular_2d_lower() { - let b: Array2 = random_matrix(3, 4); - let a: Array2 = random_square(3).into_triangular(UPLO::Lower); + let b: Array2 = random((3, 4)); + let a: Array2 = random((3, 3)).into_triangular(UPLO::Lower); test2d(UPLO::Lower, a, b, 1e-7); } #[test] fn triangular_2d_lower_t() { - let b: Array2 = random_matrix(3, 4); - let a: Array2 = random_square(3).into_triangular(UPLO::Lower).reversed_axes(); - test2d(UPLO::Upper, a, b, 1e-7); + let b: Array2 = random((3, 4)); + let a: Array2 = random((3, 3).f()).into_triangular(UPLO::Lower); + test2d(UPLO::Lower, a, b, 1e-7); } #[test] fn triangular_2d_upper_t() { - let b: Array2 = random_matrix(3, 4); - let a: Array2 = random_square(3).into_triangular(UPLO::Upper).reversed_axes(); - test2d(UPLO::Lower, a, b, 1e-7); + let b: Array2 = random((3, 4)); + let a: Array2 = random((3, 3).f()).into_triangular(UPLO::Upper); + test2d(UPLO::Upper, a, b, 1e-7); } #[test] fn triangular_2d_upper_bt() { - let b: Array2 = random_matrix(4, 3).reversed_axes(); - let a: Array2 = random_square(3).into_triangular(UPLO::Upper); + let b: Array2 = random((3, 4).f()); + let a: Array2 = random((3, 3)).into_triangular(UPLO::Upper); test2d(UPLO::Upper, a, b, 1e-7); } #[test] fn triangular_2d_lower_bt() { - let b: Array2 = random_matrix(4, 3).reversed_axes(); - let a: Array2 = random_square(3).into_triangular(UPLO::Lower); + let b: Array2 = random((3, 4).f()); + let a: Array2 = random((3, 3)).into_triangular(UPLO::Lower); test2d(UPLO::Lower, a, b, 1e-7); } #[test] fn triangular_2d_lower_t_bt() { - let b: Array2 = random_matrix(4, 3).reversed_axes(); - let a: Array2 = random_square(3).into_triangular(UPLO::Lower).reversed_axes(); - test2d(UPLO::Upper, a, b, 1e-7); + let b: Array2 = random((3, 4).f()); + let a: Array2 = random((3, 3).f()).into_triangular(UPLO::Lower); + test2d(UPLO::Lower, a, b, 1e-7); } #[test] fn triangular_2d_upper_t_bt() { - let b: Array2 = random_matrix(4, 3).reversed_axes(); - let a: Array2 = random_square(3).into_triangular(UPLO::Upper).reversed_axes(); - test2d(UPLO::Lower, a, b, 1e-7); + let b: Array2 = random((3, 4).f()); + let a: Array2 = random((3, 3).f()).into_triangular(UPLO::Upper); + test2d(UPLO::Upper, a, b, 1e-7); } diff --git a/tests/vector.rs b/tests/vector.rs index 27ce6cf7..07a80366 100644 --- a/tests/vector.rs +++ b/tests/vector.rs @@ -1,4 +1,12 @@ -include!("header.rs"); + +extern crate ndarray; +#[macro_use] +extern crate ndarray_linalg; +extern crate num_traits; + +use ndarray::*; +use ndarray_linalg::*; +use num_traits::Float; #[test] fn vector_norm() {