Skip to content

Commit 549e6d6

Browse files
committed
Add test for eigh
1 parent bcaf959 commit 549e6d6

File tree

1 file changed

+78
-6
lines changed

1 file changed

+78
-6
lines changed

tests/eigh.rs

Lines changed: 78 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,15 @@ use ndarray::*;
22
use ndarray_linalg::*;
33

44
#[test]
5-
fn eigen_vector_manual() {
5+
fn fixed() {
66
let a = arr2(&[[3.0, 1.0, 1.0], [1.0, 3.0, 1.0], [1.0, 1.0, 3.0]]);
77
let (e, vecs): (Array1<_>, Array2<_>) = (&a).eigh(UPLO::Upper).unwrap();
88
assert_close_l2!(&e, &arr1(&[2.0, 2.0, 5.0]), 1.0e-7);
9+
10+
// Check eigenvectors are orthogonalized
11+
let s = vecs.t().dot(&vecs);
12+
assert_close_l2!(&s, &Array::eye(3), 1.0e-7);
13+
914
for (i, v) in vecs.axis_iter(Axis(1)).enumerate() {
1015
let av = a.dot(&v);
1116
let ev = v.mapv(|x| e[i] * x);
@@ -14,12 +19,53 @@ fn eigen_vector_manual() {
1419
}
1520

1621
#[test]
17-
fn diagonalize() {
18-
let a = arr2(&[[3.0, 1.0, 1.0], [1.0, 3.0, 1.0], [1.0, 1.0, 3.0]]);
22+
fn fixed_t() {
23+
let a = arr2(&[[3.0, 1.0, 1.0], [1.0, 3.0, 1.0], [1.0, 1.0, 3.0]]).reversed_axes();
1924
let (e, vecs): (Array1<_>, Array2<_>) = (&a).eigh(UPLO::Upper).unwrap();
20-
let s = vecs.t().dot(&a).dot(&vecs);
21-
for i in 0..3 {
22-
assert_rclose!(e[i], s[(i, i)], 1e-7);
25+
assert_close_l2!(&e, &arr1(&[2.0, 2.0, 5.0]), 1.0e-7);
26+
27+
// Check eigenvectors are orthogonalized
28+
let s = vecs.t().dot(&vecs);
29+
assert_close_l2!(&s, &Array::eye(3), 1.0e-7);
30+
31+
for (i, v) in vecs.axis_iter(Axis(1)).enumerate() {
32+
let av = a.dot(&v);
33+
let ev = v.mapv(|x| e[i] * x);
34+
assert_close_l2!(&av, &ev, 1.0e-7);
35+
}
36+
}
37+
38+
#[test]
39+
fn fixed_lower() {
40+
let a = arr2(&[[3.0, 1.0, 1.0], [1.0, 3.0, 1.0], [1.0, 1.0, 3.0]]);
41+
let (e, vecs): (Array1<_>, Array2<_>) = (&a).eigh(UPLO::Lower).unwrap();
42+
assert_close_l2!(&e, &arr1(&[2.0, 2.0, 5.0]), 1.0e-7);
43+
44+
// Check eigenvectors are orthogonalized
45+
let s = vecs.t().dot(&vecs);
46+
assert_close_l2!(&s, &Array::eye(3), 1.0e-7);
47+
48+
for (i, v) in vecs.axis_iter(Axis(1)).enumerate() {
49+
let av = a.dot(&v);
50+
let ev = v.mapv(|x| e[i] * x);
51+
assert_close_l2!(&av, &ev, 1.0e-7);
52+
}
53+
}
54+
55+
#[test]
56+
fn fixed_t_lower() {
57+
let a = arr2(&[[3.0, 1.0, 1.0], [1.0, 3.0, 1.0], [1.0, 1.0, 3.0]]).reversed_axes();
58+
let (e, vecs): (Array1<_>, Array2<_>) = (&a).eigh(UPLO::Lower).unwrap();
59+
assert_close_l2!(&e, &arr1(&[2.0, 2.0, 5.0]), 1.0e-7);
60+
61+
// Check eigenvectors are orthogonalized
62+
let s = vecs.t().dot(&vecs);
63+
assert_close_l2!(&s, &Array::eye(3), 1.0e-7);
64+
65+
for (i, v) in vecs.axis_iter(Axis(1)).enumerate() {
66+
let av = a.dot(&v);
67+
let ev = v.mapv(|x| e[i] * x);
68+
assert_close_l2!(&av, &ev, 1.0e-7);
2369
}
2470
}
2571

@@ -48,3 +94,29 @@ fn ssqrt_t() {
4894
println!("ss = {:?}", &ss);
4995
assert_close_l2!(&ss, &ans, 1e-7);
5096
}
97+
98+
#[test]
99+
fn ssqrt_lower() {
100+
let a: Array2<f64> = random_hpd(3);
101+
let ans = a.clone();
102+
let s = a.ssqrt(UPLO::Lower).unwrap();
103+
println!("a = {:?}", &ans);
104+
println!("s = {:?}", &s);
105+
assert_close_l2!(&s.t(), &s, 1e-7);
106+
let ss = s.dot(&s);
107+
println!("ss = {:?}", &ss);
108+
assert_close_l2!(&ss, &ans, 1e-7);
109+
}
110+
111+
#[test]
112+
fn ssqrt_t_lower() {
113+
let a: Array2<f64> = random_hpd(3).reversed_axes();
114+
let ans = a.clone();
115+
let s = a.ssqrt(UPLO::Lower).unwrap();
116+
println!("a = {:?}", &ans);
117+
println!("s = {:?}", &s);
118+
assert_close_l2!(&s.t(), &s, 1e-7);
119+
let ss = s.dot(&s);
120+
println!("ss = {:?}", &ss);
121+
assert_close_l2!(&ss, &ans, 1e-7);
122+
}

0 commit comments

Comments
 (0)