Skip to content

Commit 3a21b4d

Browse files
committed
Test eigh for RcArray
1 parent 7e5164c commit 3a21b4d

File tree

3 files changed

+29
-20
lines changed

3 files changed

+29
-20
lines changed

src/prelude.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ pub use matrix::Matrix;
44
pub use square::SquareMatrix;
55
pub use hermite::HermiteMatrix;
66
pub use triangular::{TriangularMatrix, drop_lower, drop_upper};
7+
pub use util::{all_close_l1, all_close_l2, all_close_max};

tests/eigh.rs

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,33 @@
11
include!("header.rs");
22

3-
#[test]
4-
fn eigen_vector_manual() {
5-
let a = arr2(&[[3.0, 1.0, 1.0], [1.0, 3.0, 1.0], [1.0, 1.0, 3.0]]);
6-
let (e, vecs) = a.clone().eigh().unwrap();
7-
assert!(e.all_close(&arr1(&[2.0, 2.0, 5.0]), 1.0e-7));
8-
for (i, v) in vecs.axis_iter(Axis(1)).enumerate() {
9-
let av = a.dot(&v);
10-
let ev = v.mapv(|x| e[i] * x);
11-
assert!(av.all_close(&ev, 1.0e-7));
3+
macro_rules! impl_test_eigh {
4+
($modname:ident, $clone:ident) => {
5+
mod $modname {
6+
use ndarray::prelude::*;
7+
use ndarray_linalg::prelude::*;
8+
use ndarray_numtest::prelude::*;
9+
#[test]
10+
fn eigen_vector_manual() {
11+
let a = arr2(&[[3.0, 1.0, 1.0], [1.0, 3.0, 1.0], [1.0, 1.0, 3.0]]);
12+
let (e, vecs) = a.$clone().eigh().unwrap();
13+
all_close_l2(&e, &arr1(&[2.0, 2.0, 5.0]), 1.0e-7).unwrap();
14+
for (i, v) in vecs.axis_iter(Axis(1)).enumerate() {
15+
let av = a.dot(&v);
16+
let ev = v.mapv(|x| e[i] * x);
17+
all_close_l2(&av, &ev, 1.0e-7).unwrap();
18+
}
1219
}
13-
}
14-
15-
#[test]
16-
fn diagonalize() {
17-
let a = arr2(&[[3.0, 1.0, 1.0], [1.0, 3.0, 1.0], [1.0, 1.0, 3.0]]);
18-
let (e, vecs) = a.clone().eigh().unwrap();
19-
let s = vecs.t().dot(&a).dot(&vecs);
20-
for i in 0..3 {
21-
e[i].assert_close(s[(i, i)], 1e-7);
20+
#[test]
21+
fn diagonalize() {
22+
let a = arr2(&[[3.0, 1.0, 1.0], [1.0, 3.0, 1.0], [1.0, 1.0, 3.0]]);
23+
let (e, vecs) = a.$clone().eigh().unwrap();
24+
let s = vecs.t().dot(&a).dot(&vecs);
25+
for i in 0..3 {
26+
e[i].assert_close(s[(i, i)], 1e-7);
27+
}
2228
}
2329
}
30+
}} // impl_test_eigh
31+
32+
impl_test_eigh!(owned, clone);
33+
impl_test_eigh!(shared, to_shared);

tests/header.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@ use ndarray::prelude::*;
1010
#[allow(unused_imports)]
1111
use ndarray_linalg::prelude::*;
1212
#[allow(unused_imports)]
13-
use ndarray_linalg::util::*;
14-
#[allow(unused_imports)]
1513
use ndarray_numtest::prelude::*;
1614
#[allow(unused_imports)]
1715
use ndarray_rand::RandomExt;

0 commit comments

Comments
 (0)