Skip to content

Commit 4e84acb

Browse files
committed
Add test for householder (and bug found)
1 parent 5d74992 commit 4e84acb

File tree

3 files changed

+173
-66
lines changed

3 files changed

+173
-66
lines changed

src/krylov/mgs.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,26 @@ impl<A: Scalar + Lapack> Orthogonalizer for MGS<A> {
7676
hstack(&self.q).unwrap()
7777
}
7878
}
79+
80+
#[cfg(test)]
81+
mod tests {
82+
use super::*;
83+
use crate::assert::*;
84+
85+
#[test]
86+
fn mgs_append() {
87+
let mut mgs = MGS::new(3);
88+
let coef = mgs.append(array![0.0, 1.0, 0.0], 1e-9).unwrap();
89+
close_l2(&coef, &array![1.0], 1e-9).unwrap();
90+
91+
let coef = mgs.append(array![1.0, 1.0, 0.0], 1e-9).unwrap();
92+
close_l2(&coef, &array![1.0, 1.0], 1e-9).unwrap();
93+
94+
assert!(mgs.append(array![1.0, 2.0, 0.0], 1e-9).is_err());
95+
96+
if let Err(coef) = mgs.append(array![1.0, 2.0, 0.0], 1e-9) {
97+
close_l2(&coef, &array![2.0, 1.0, 0.0], 1e-9).unwrap();
98+
}
99+
}
100+
101+
}

src/krylov/mod.rs

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,18 +88,19 @@ pub enum Strategy {
8888
Full,
8989
}
9090

91-
/// Online QR decomposition of vectors using modified Gram-Schmit algorithm
91+
/// Online QR decomposition using arbitary orthogonalizer
9292
pub fn qr<A, S>(
9393
iter: impl Iterator<Item = ArrayBase<S, Ix1>>,
94-
dim: usize,
94+
mut ortho: impl Orthogonalizer<Elem = A>,
9595
rtol: A::Real,
9696
strategy: Strategy,
9797
) -> (Q<A>, R<A>)
9898
where
9999
A: Scalar + Lapack,
100100
S: Data<Elem = A>,
101101
{
102-
let mut ortho = MGS::new(dim);
102+
assert_eq!(ortho.len(), 0);
103+
103104
let mut coefs = Vec::new();
104105
for a in iter {
105106
match ortho.append(a.into_owned(), rtol) {
@@ -123,3 +124,33 @@ where
123124
}
124125
(ortho.get_q(), r)
125126
}
127+
128+
/// Online QR decomposition using modified Gram-Schmit
129+
pub fn mgs<A, S>(
130+
iter: impl Iterator<Item = ArrayBase<S, Ix1>>,
131+
dim: usize,
132+
rtol: A::Real,
133+
strategy: Strategy,
134+
) -> (Q<A>, R<A>)
135+
where
136+
A: Scalar + Lapack,
137+
S: Data<Elem = A>,
138+
{
139+
let mgs = MGS::new(dim);
140+
qr(iter, mgs, rtol, strategy)
141+
}
142+
143+
/// Online QR decomposition using modified Gram-Schmit
144+
pub fn householder<A, S>(
145+
iter: impl Iterator<Item = ArrayBase<S, Ix1>>,
146+
dim: usize,
147+
rtol: A::Real,
148+
strategy: Strategy,
149+
) -> (Q<A>, R<A>)
150+
where
151+
A: Scalar + Lapack,
152+
S: Data<Elem = A>,
153+
{
154+
let h = Householder::new(dim);
155+
qr(iter, h, rtol, strategy)
156+
}

tests/krylov.rs

Lines changed: 116 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,83 +1,136 @@
11
use ndarray::*;
22
use ndarray_linalg::{krylov::*, *};
33

4-
fn qr_full<A: Scalar + Lapack>() {
5-
const N: usize = 5;
6-
let rtol: A::Real = A::real(1e-9);
7-
8-
let a: Array2<A> = random((N, N));
9-
let (q, r) = qr(a.axis_iter(Axis(1)), N, rtol, Strategy::Terminate);
10-
assert_close_l2!(&q.dot(&r), &a, rtol);
11-
12-
let qc: Array2<A> = conjugate(&q);
13-
assert_close_l2!(&qc.dot(&q), &Array::eye(N), rtol);
14-
}
15-
164
#[test]
17-
fn qr_full_real() {
18-
qr_full::<f64>();
5+
fn mgs_full() {
6+
fn test<A: Scalar + Lapack>(rtol: A::Real) {
7+
const N: usize = 5;
8+
let a: Array2<A> = random((N, N));
9+
let (q, r) = mgs(a.axis_iter(Axis(1)), N, rtol, Strategy::Terminate);
10+
assert_close_l2!(&q.dot(&r), &a, rtol);
11+
let qc: Array2<A> = conjugate(&q);
12+
assert_close_l2!(&qc.dot(&q), &Array::eye(N), rtol);
13+
}
14+
15+
test::<f32>(1e-5);
16+
test::<f64>(1e-9);
17+
test::<c32>(1e-5);
18+
test::<c64>(1e-9);
1919
}
2020

2121
#[test]
22-
fn qr_full_complex() {
23-
qr_full::<c64>();
24-
}
25-
26-
fn qr_<A: Scalar + Lapack>() {
27-
const N: usize = 4;
28-
let rtol: A::Real = A::real(1e-9);
29-
30-
let a: Array2<A> = random((N, N / 2));
31-
let (q, r) = qr(a.axis_iter(Axis(1)), N, rtol, Strategy::Terminate);
32-
assert_close_l2!(&q.dot(&r), &a, rtol);
33-
34-
let qc: Array2<A> = conjugate(&q);
35-
assert_close_l2!(&qc.dot(&q), &Array::eye(N / 2), rtol);
22+
fn mgs_half() {
23+
fn test<A: Scalar + Lapack>(rtol: A::Real) {
24+
const N: usize = 4;
25+
let a: Array2<A> = random((N, N / 2));
26+
let (q, r) = mgs(a.axis_iter(Axis(1)), N, rtol, Strategy::Terminate);
27+
assert_close_l2!(&q.dot(&r), &a, rtol);
28+
let qc: Array2<A> = conjugate(&q);
29+
assert_close_l2!(&qc.dot(&q), &Array::eye(N / 2), rtol);
30+
}
31+
32+
test::<f32>(1e-5);
33+
test::<f64>(1e-9);
34+
test::<c32>(1e-5);
35+
test::<c64>(1e-9);
3636
}
3737

3838
#[test]
39-
fn qr_real() {
40-
qr_::<f64>();
39+
fn mgs_over() {
40+
fn test<A: Scalar + Lapack>(rtol: A::Real) {
41+
const N: usize = 4;
42+
let a: Array2<A> = random((N, N * 2));
43+
44+
// Terminate
45+
let (q, r) = mgs(a.axis_iter(Axis(1)), N, rtol, Strategy::Terminate);
46+
let a_sub = a.slice(s![.., 0..N]);
47+
assert_close_l2!(&q.dot(&r), &a_sub, rtol);
48+
let qc: Array2<A> = conjugate(&q);
49+
assert_close_l2!(&qc.dot(&q), &Array::eye(N), rtol);
50+
51+
// Skip
52+
let (q, r) = mgs(a.axis_iter(Axis(1)), N, rtol, Strategy::Skip);
53+
let a_sub = a.slice(s![.., 0..N]);
54+
assert_close_l2!(&q.dot(&r), &a_sub, rtol);
55+
let qc: Array2<A> = conjugate(&q);
56+
assert_close_l2!(&qc.dot(&q), &Array::eye(N), rtol);
57+
58+
// Full
59+
let (q, r) = mgs(a.axis_iter(Axis(1)), N, rtol, Strategy::Full);
60+
assert_close_l2!(&q.dot(&r), &a, rtol);
61+
let qc: Array2<A> = conjugate(&q);
62+
assert_close_l2!(&qc.dot(&q), &Array::eye(N), rtol);
63+
}
64+
65+
test::<f32>(1e-5);
66+
test::<f64>(1e-9);
67+
test::<c32>(1e-5);
68+
test::<c64>(1e-9);
4169
}
4270

4371
#[test]
44-
fn qr_complex() {
45-
qr_::<c64>();
46-
}
47-
48-
fn qr_over<A: Scalar + Lapack>() {
49-
const N: usize = 4;
50-
let rtol: A::Real = A::real(1e-9);
51-
52-
let a: Array2<A> = random((N, N * 2));
53-
54-
// Terminate
55-
let (q, r) = qr(a.axis_iter(Axis(1)), N, rtol, Strategy::Terminate);
56-
let a_sub = a.slice(s![.., 0..N]);
57-
assert_close_l2!(&q.dot(&r), &a_sub, rtol);
58-
let qc: Array2<A> = conjugate(&q);
59-
assert_close_l2!(&qc.dot(&q), &Array::eye(N), rtol);
60-
61-
// Skip
62-
let (q, r) = qr(a.axis_iter(Axis(1)), N, rtol, Strategy::Skip);
63-
let a_sub = a.slice(s![.., 0..N]);
64-
assert_close_l2!(&q.dot(&r), &a_sub, rtol);
65-
let qc: Array2<A> = conjugate(&q);
66-
assert_close_l2!(&qc.dot(&q), &Array::eye(N), rtol);
67-
68-
// Full
69-
let (q, r) = qr(a.axis_iter(Axis(1)), N, rtol, Strategy::Full);
70-
assert_close_l2!(&q.dot(&r), &a, rtol);
71-
let qc: Array2<A> = conjugate(&q);
72-
assert_close_l2!(&qc.dot(&q), &Array::eye(N), rtol);
72+
fn householder_full() {
73+
fn test<A: Scalar + Lapack>(rtol: A::Real) {
74+
const N: usize = 5;
75+
let a: Array2<A> = random((N, N));
76+
let (q, r) = householder(a.axis_iter(Axis(1)), N, rtol, Strategy::Terminate);
77+
assert_close_l2!(&q.dot(&r), &a, rtol);
78+
let qc: Array2<A> = conjugate(&q);
79+
assert_close_l2!(&qc.dot(&q), &Array::eye(N), rtol);
80+
}
81+
82+
test::<f32>(1e-5);
83+
test::<f64>(1e-9);
84+
test::<c32>(1e-5);
85+
test::<c64>(1e-9);
7386
}
7487

7588
#[test]
76-
fn qr_over_real() {
77-
qr_over::<f64>();
89+
fn householder_half() {
90+
fn test<A: Scalar + Lapack>(rtol: A::Real) {
91+
const N: usize = 4;
92+
let a: Array2<A> = random((N, N / 2));
93+
let (q, r) = householder(a.axis_iter(Axis(1)), N, rtol, Strategy::Terminate);
94+
assert_close_l2!(&q.dot(&r), &a, rtol);
95+
let qc: Array2<A> = conjugate(&q);
96+
assert_close_l2!(&qc.dot(&q), &Array::eye(N / 2), rtol);
97+
}
98+
99+
test::<f32>(1e-5);
100+
test::<f64>(1e-9);
101+
test::<c32>(1e-5);
102+
test::<c64>(1e-9);
78103
}
79104

80105
#[test]
81-
fn qr_over_complex() {
82-
qr_over::<c64>();
106+
fn householder_over() {
107+
fn test<A: Scalar + Lapack>(rtol: A::Real) {
108+
const N: usize = 4;
109+
let a: Array2<A> = random((N, N * 2));
110+
111+
// Terminate
112+
let (q, r) = householder(a.axis_iter(Axis(1)), N, rtol, Strategy::Terminate);
113+
let a_sub = a.slice(s![.., 0..N]);
114+
assert_close_l2!(&q.dot(&r), &a_sub, rtol);
115+
let qc: Array2<A> = conjugate(&q);
116+
assert_close_l2!(&qc.dot(&q), &Array::eye(N), rtol);
117+
118+
// Skip
119+
let (q, r) = householder(a.axis_iter(Axis(1)), N, rtol, Strategy::Skip);
120+
let a_sub = a.slice(s![.., 0..N]);
121+
assert_close_l2!(&q.dot(&r), &a_sub, rtol);
122+
let qc: Array2<A> = conjugate(&q);
123+
assert_close_l2!(&qc.dot(&q), &Array::eye(N), rtol);
124+
125+
// Full
126+
let (q, r) = householder(a.axis_iter(Axis(1)), N, rtol, Strategy::Full);
127+
assert_close_l2!(&q.dot(&r), &a, rtol);
128+
let qc: Array2<A> = conjugate(&q);
129+
assert_close_l2!(&qc.dot(&q), &Array::eye(N), rtol);
130+
}
131+
132+
test::<f32>(1e-5);
133+
test::<f64>(1e-9);
134+
test::<c32>(1e-5);
135+
test::<c64>(1e-9);
83136
}

0 commit comments

Comments
 (0)