Skip to content

Commit d28b856

Browse files
committed
Test for inv()
1 parent 417c7ad commit d28b856

File tree

2 files changed

+36
-22
lines changed

2 files changed

+36
-22
lines changed

tests/header.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@ use ndarray_rand::RandomExt;
1616
#[allow(unused_imports)]
1717
use num_traits::Float;
1818

19+
pub fn random_square(n: usize) -> Array<f64, Ix2> {
20+
let r_dist = RealNormal::new(0., 1.);
21+
Array::<f64, _>::random((n, n), r_dist)
22+
}
23+
1924
pub fn random_hermite(n: usize) -> Array<f64, Ix2> {
2025
let r_dist = RealNormal::new(0., 1.);
2126
let a = Array::<f64, _>::random((n, n), r_dist);

tests/inv.rs

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,36 @@
11
include!("header.rs");
22

3-
#[test]
4-
fn inv_random() {
5-
let r_dist = RealNormal::new(0., 1.);
6-
let a = Array::<f64, _>::random((3, 3), r_dist);
7-
let ai = a.clone().inv().unwrap();
8-
let id = Array::eye(3);
9-
all_close_l2(&ai.dot(&a), &id, 1e-7).unwrap();
10-
}
3+
macro_rules! impl_test{
4+
($modname:ident, $clone:ident) => {
5+
mod $modname {
6+
use super::random_square;
7+
use ndarray::prelude::*;
8+
use ndarray_linalg::prelude::*;
9+
#[test]
10+
fn inv_random() {
11+
let a = random_square(3);
12+
let ai = a.$clone().inv().unwrap();
13+
let id = Array::eye(3);
14+
all_close_l2(&ai.dot(&a), &id, 1e-7).unwrap();
15+
}
1116

12-
#[test]
13-
fn inv_random_t() {
14-
let r_dist = RealNormal::new(0., 1.);
15-
let a = Array::<f64, _>::random((3, 3), r_dist).reversed_axes();
16-
let ai = a.clone().inv().unwrap();
17-
let id = Array::eye(3);
18-
all_close_l2(&ai.dot(&a), &id, 1e-7).unwrap();
19-
}
17+
#[test]
18+
fn inv_random_t() {
19+
let a = random_square(3).reversed_axes();
20+
let ai = a.$clone().inv().unwrap();
21+
let id = Array::eye(3);
22+
all_close_l2(&ai.dot(&a), &id, 1e-7).unwrap();
23+
}
2024

21-
#[test]
22-
#[should_panic]
23-
fn inv_error() {
24-
// do not have inverse
25-
let a = Array::range(1., 10., 1.).into_shape((3, 3)).unwrap();
26-
let _ = a.clone().inv().unwrap();
25+
#[test]
26+
#[should_panic]
27+
fn inv_error() {
28+
// do not have inverse
29+
let a = Array::range(1., 10., 1.).into_shape((3, 3)).unwrap();
30+
let _ = a.$clone().inv().unwrap();
31+
}
2732
}
33+
}} // impl_test
34+
35+
impl_test!(owned, clone);
36+
impl_test!(shared, to_shared);

0 commit comments

Comments
 (0)