Skip to content

Commit 3f42447

Browse files
committed
Add tests of msg
1 parent 6fce47a commit 3f42447

File tree

2 files changed

+91
-4
lines changed

2 files changed

+91
-4
lines changed

src/arnoldi.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ pub enum Strategy {
130130
Full,
131131
}
132132

133+
/// Online QR decomposition of vectors using modified Gram-Schmit algorithm
133134
pub fn mgs<A, S>(
134135
iter: impl Iterator<Item = ArrayBase<S, Ix1>>,
135136
dim: usize,
@@ -152,11 +153,14 @@ where
152153
},
153154
}
154155
}
156+
let n = ortho.len();
155157
let m = coefs.len();
156-
let mut r = Array2::zeros((m, m));
157-
for i in 0..m {
158-
for j in 0..=i {
159-
r[(j, i)] = coefs[i][j];
158+
let mut r = Array2::zeros((n, m).f());
159+
for j in 0..m {
160+
for i in 0..n {
161+
if i < coefs[j].len() {
162+
r[(i, j)] = coefs[j][i];
163+
}
160164
}
161165
}
162166
(ortho.get_q(), r)

tests/arnoldi.rs

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
use ndarray::*;
2+
use ndarray_linalg::{arnoldi::*, *};
3+
4+
fn qr_full<A: Scalar + Lapack>() {
5+
const N: usize = 5;
6+
let rtol: A::Real = A::real(1e-9);
7+
8+
let a: Array2<A> = random((N, N));
9+
let (q, r) = mgs(a.axis_iter(Axis(1)), N, rtol, Strategy::Terminate);
10+
assert_close_l2!(&q.dot(&r), &a, rtol);
11+
12+
let qc: Array2<A> = conjugate(&q);
13+
assert_close_l2!(&qc.dot(&q), &Array::eye(N), rtol);
14+
}
15+
16+
#[test]
17+
fn qr_full_real() {
18+
qr_full::<f64>();
19+
}
20+
21+
#[test]
22+
fn qr_full_complex() {
23+
qr_full::<c64>();
24+
}
25+
26+
fn qr<A: Scalar + Lapack>() {
27+
const N: usize = 4;
28+
let rtol: A::Real = A::real(1e-9);
29+
30+
let a: Array2<A> = random((N, N / 2));
31+
let (q, r) = mgs(a.axis_iter(Axis(1)), N, rtol, Strategy::Terminate);
32+
assert_close_l2!(&q.dot(&r), &a, rtol);
33+
34+
let qc: Array2<A> = conjugate(&q);
35+
assert_close_l2!(&qc.dot(&q), &Array::eye(N / 2), rtol);
36+
}
37+
38+
#[test]
39+
fn qr_real() {
40+
qr::<f64>();
41+
}
42+
43+
#[test]
44+
fn qr_complex() {
45+
qr::<c64>();
46+
}
47+
48+
fn qr_over<A: Scalar + Lapack>() {
49+
const N: usize = 4;
50+
let rtol: A::Real = A::real(1e-9);
51+
52+
let a: Array2<A> = random((N, N * 2));
53+
54+
// Terminate
55+
let (q, r) = mgs(a.axis_iter(Axis(1)), N, rtol, Strategy::Terminate);
56+
let a_sub = a.slice(s![.., 0..N]);
57+
assert_close_l2!(&q.dot(&r), &a_sub, rtol);
58+
let qc: Array2<A> = conjugate(&q);
59+
assert_close_l2!(&qc.dot(&q), &Array::eye(N), rtol);
60+
61+
// Skip
62+
let (q, r) = mgs(a.axis_iter(Axis(1)), N, rtol, Strategy::Skip);
63+
let a_sub = a.slice(s![.., 0..N]);
64+
assert_close_l2!(&q.dot(&r), &a_sub, rtol);
65+
let qc: Array2<A> = conjugate(&q);
66+
assert_close_l2!(&qc.dot(&q), &Array::eye(N), rtol);
67+
68+
// Full
69+
let (q, r) = mgs(a.axis_iter(Axis(1)), N, rtol, Strategy::Full);
70+
assert_close_l2!(&q.dot(&r), &a, rtol);
71+
let qc: Array2<A> = conjugate(&q);
72+
assert_close_l2!(&qc.dot(&q), &Array::eye(N), rtol);
73+
}
74+
75+
#[test]
76+
fn qr_over_real() {
77+
qr_over::<f64>();
78+
}
79+
80+
#[test]
81+
fn qr_over_complex() {
82+
qr_over::<c64>();
83+
}

0 commit comments

Comments
 (0)