Skip to content

Commit 18a0066

Browse files
committed
Test for ssqrt
1 parent 1f7cc56 commit 18a0066

File tree

1 file changed

+22
-33
lines changed

1 file changed

+22
-33
lines changed

tests/ssqrt.rs

Lines changed: 22 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,26 @@
11
include!("header.rs");
22

3-
#[test]
4-
fn ssqrt_symmetric_random() {
5-
let r_dist = RealNormal::new(0., 1.);
6-
let mut a = Array::<f64, _>::random((3, 3), r_dist);
7-
a = a.dot(&a.t());
8-
let ar = a.clone().ssqrt().unwrap();
9-
all_close_l2(&ar.clone().reversed_axes(), &ar, 1e-7).unwrap();
3+
macro_rules! impl_test{
4+
($modname:ident, $clone:ident) => {
5+
mod $modname {
6+
use super::random_hermite;
7+
use ndarray_linalg::prelude::*;
8+
#[test]
9+
fn ssqrt() {
10+
let a = random_hermite(3);
11+
let ar = a.$clone().ssqrt().unwrap();
12+
all_close_l2(&ar.clone().t(), &ar, 1e-7).expect("not symmetric");
13+
all_close_l2(&ar.dot(&ar), &a, 1e-7).expect("not sqrt");
14+
}
15+
#[test]
16+
fn ssqrt_t() {
17+
let a = random_hermite(3).reversed_axes();
18+
let ar = a.$clone().ssqrt().unwrap();
19+
all_close_l2(&ar.clone().t(), &ar, 1e-7).expect("not symmetric");
20+
all_close_l2(&ar.dot(&ar), &a, 1e-7).expect("not sqrt");
21+
}
1022
}
23+
}} // impl_test
1124

12-
#[test]
13-
fn ssqrt_symmetric_random_t() {
14-
let r_dist = RealNormal::new(0., 1.);
15-
let mut a = Array::<f64, _>::random((3, 3), r_dist);
16-
a = a.dot(&a.t()).reversed_axes();
17-
let ar = a.clone().ssqrt().unwrap();
18-
all_close_l2(&ar.clone().reversed_axes(), &ar, 1e-7).unwrap();
19-
}
20-
21-
#[test]
22-
fn ssqrt_sqrt_random() {
23-
let r_dist = RealNormal::new(0., 1.);
24-
let mut a = Array::<f64, _>::random((3, 3), r_dist);
25-
a = a.dot(&a.t());
26-
let ar = a.clone().ssqrt().unwrap();
27-
all_close_l2(&ar.clone().reversed_axes(), &ar, 1e-7).unwrap();
28-
}
29-
30-
#[test]
31-
fn ssqrt_sqrt_random_t() {
32-
let r_dist = RealNormal::new(0., 1.);
33-
let mut a = Array::<f64, _>::random((3, 3), r_dist);
34-
a = a.dot(&a.t()).reversed_axes();
35-
let ar = a.clone().ssqrt().unwrap();
36-
all_close_l2(&ar.clone().reversed_axes(), &ar, 1e-7).unwrap();
37-
}
25+
impl_test!(owned, clone);
26+
impl_test!(shared, to_shared);

0 commit comments

Comments
 (0)