Skip to content

Commit 55d0e8c

Browse files
authored
Merge pull request #149 from rust-ndarray/mgs
modified Gram-Schmit
2 parents 081fec3 + 8cd17de commit 55d0e8c

File tree

3 files changed

+271
-0
lines changed

3 files changed

+271
-0
lines changed

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ pub mod generate;
4646
pub mod inner;
4747
pub mod lapack;
4848
pub mod layout;
49+
pub mod mgs;
4950
pub mod norm;
5051
pub mod operator;
5152
pub mod opnorm;

src/mgs.rs

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
//! Modified Gram-Schmit orthogonalizer
2+
3+
use crate::{generate::*, inner::*, norm::Norm, types::*};
4+
use ndarray::*;
5+
6+
/// Iterative orthogonalizer using modified Gram-Schmit procedure
7+
#[derive(Debug, Clone)]
8+
pub struct MGS<A> {
9+
/// Dimension of base space
10+
dimension: usize,
11+
/// Basis of spanned space
12+
q: Vec<Array1<A>>,
13+
}
14+
15+
/// Q-matrix
16+
///
17+
/// - Maybe **NOT** square
18+
/// - Unitary for existing columns
19+
///
20+
pub type Q<A> = Array2<A>;
21+
22+
/// R-matrix
23+
///
24+
/// - Maybe **NOT** square
25+
/// - Upper triangle
26+
///
27+
pub type R<A> = Array2<A>;
28+
29+
impl<A: Scalar> MGS<A> {
30+
/// Create an empty orthogonalizer
31+
pub fn new(dimension: usize) -> Self {
32+
Self {
33+
dimension,
34+
q: Vec::new(),
35+
}
36+
}
37+
38+
/// Dimension of input array
39+
pub fn dim(&self) -> usize {
40+
self.dimension
41+
}
42+
43+
/// Number of cached basis
44+
///
45+
/// ```rust
46+
/// # use ndarray::*;
47+
/// # use ndarray_linalg::{mgs::*, *};
48+
/// const N: usize = 3;
49+
/// let mut mgs = MGS::<f32>::new(N);
50+
/// assert_eq!(mgs.dim(), N);
51+
/// assert_eq!(mgs.len(), 0);
52+
///
53+
/// mgs.append(array![0.0, 1.0, 0.0], 1e-9).unwrap();
54+
/// assert_eq!(mgs.len(), 1);
55+
/// ```
56+
pub fn len(&self) -> usize {
57+
self.q.len()
58+
}
59+
60+
/// Orthogonalize given vector using current basis
61+
///
62+
/// Panic
63+
/// -------
64+
/// - if the size of the input array mismatches to the dimension
65+
///
66+
pub fn orthogonalize<S>(&self, a: &mut ArrayBase<S, Ix1>) -> Array1<A>
67+
where
68+
A: Lapack,
69+
S: DataMut<Elem = A>,
70+
{
71+
assert_eq!(a.len(), self.dim());
72+
let mut coef = Array1::zeros(self.len() + 1);
73+
for i in 0..self.len() {
74+
let q = &self.q[i];
75+
let c = q.inner(&a);
76+
azip!(mut a (&mut *a), q (q) in { *a = *a - c * q } );
77+
coef[i] = c;
78+
}
79+
let nrm = a.norm_l2();
80+
coef[self.len()] = A::from_real(nrm);
81+
coef
82+
}
83+
84+
/// Add new vector if the residual is larger than relative tolerance
85+
///
86+
/// ```rust
87+
/// # use ndarray::*;
88+
/// # use ndarray_linalg::{mgs::*, *};
89+
/// let mut mgs = MGS::new(3);
90+
/// let coef = mgs.append(array![0.0, 1.0, 0.0], 1e-9).unwrap();
91+
/// close_l2(&coef, &array![1.0], 1e-9);
92+
///
93+
/// let coef = mgs.append(array![1.0, 1.0, 0.0], 1e-9).unwrap();
94+
/// close_l2(&coef, &array![1.0, 1.0], 1e-9);
95+
///
96+
/// // Fail if the vector is linearly dependent
97+
/// assert!(mgs.append(array![1.0, 2.0, 0.0], 1e-9).is_err());
98+
///
99+
/// // You can get coefficients of dependent vector
100+
/// if let Err(coef) = mgs.append(array![1.0, 2.0, 0.0], 1e-9) {
101+
/// close_l2(&coef, &array![2.0, 1.0, 0.0], 1e-9);
102+
/// }
103+
/// ```
104+
///
105+
/// Panic
106+
/// -------
107+
/// - if the size of the input array mismatches to the dimension
108+
///
109+
pub fn append<S>(&mut self, a: ArrayBase<S, Ix1>, rtol: A::Real) -> Result<Array1<A>, Array1<A>>
110+
where
111+
A: Lapack,
112+
S: Data<Elem = A>,
113+
{
114+
let mut a = a.into_owned();
115+
let coef = self.orthogonalize(&mut a);
116+
let nrm = coef[coef.len() - 1].re();
117+
if nrm < rtol {
118+
// Linearly dependent
119+
return Err(coef);
120+
}
121+
azip!(mut a in { *a = *a / A::from_real(nrm) });
122+
self.q.push(a);
123+
Ok(coef)
124+
}
125+
126+
/// Get orthogonal basis as Q matrix
127+
pub fn get_q(&self) -> Q<A> {
128+
hstack(&self.q).unwrap()
129+
}
130+
}
131+
132+
/// Strategy for linearly dependent vectors appearing in iterative QR decomposition
133+
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
134+
pub enum Strategy {
135+
/// Terminate iteration if dependent vector comes
136+
Terminate,
137+
138+
/// Skip dependent vector
139+
Skip,
140+
141+
/// Orthogonalize dependent vector without adding to Q,
142+
/// i.e. R must be non-square like following:
143+
///
144+
/// ```text
145+
/// x x x x x
146+
/// 0 x x x x
147+
/// 0 0 0 x x
148+
/// 0 0 0 0 x
149+
/// ```
150+
Full,
151+
}
152+
153+
/// Online QR decomposition of vectors using modified Gram-Schmit algorithm
154+
pub fn mgs<A, S>(
155+
iter: impl Iterator<Item = ArrayBase<S, Ix1>>,
156+
dim: usize,
157+
rtol: A::Real,
158+
strategy: Strategy,
159+
) -> (Q<A>, R<A>)
160+
where
161+
A: Scalar + Lapack,
162+
S: Data<Elem = A>,
163+
{
164+
let mut ortho = MGS::new(dim);
165+
let mut coefs = Vec::new();
166+
for a in iter {
167+
match ortho.append(a, rtol) {
168+
Ok(coef) => coefs.push(coef),
169+
Err(coef) => match strategy {
170+
Strategy::Terminate => break,
171+
Strategy::Skip => continue,
172+
Strategy::Full => coefs.push(coef),
173+
},
174+
}
175+
}
176+
let n = ortho.len();
177+
let m = coefs.len();
178+
let mut r = Array2::zeros((n, m).f());
179+
for j in 0..m {
180+
for i in 0..n {
181+
if i < coefs[j].len() {
182+
r[(i, j)] = coefs[j][i];
183+
}
184+
}
185+
}
186+
(ortho.get_q(), r)
187+
}

tests/mgs.rs

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
use ndarray::*;
2+
use ndarray_linalg::{mgs::*, *};
3+
4+
fn qr_full<A: Scalar + Lapack>() {
5+
const N: usize = 5;
6+
let rtol: A::Real = A::real(1e-9);
7+
8+
let a: Array2<A> = random((N, N));
9+
let (q, r) = mgs(a.axis_iter(Axis(1)), N, rtol, Strategy::Terminate);
10+
assert_close_l2!(&q.dot(&r), &a, rtol);
11+
12+
let qc: Array2<A> = conjugate(&q);
13+
assert_close_l2!(&qc.dot(&q), &Array::eye(N), rtol);
14+
}
15+
16+
#[test]
17+
fn qr_full_real() {
18+
qr_full::<f64>();
19+
}
20+
21+
#[test]
22+
fn qr_full_complex() {
23+
qr_full::<c64>();
24+
}
25+
26+
fn qr<A: Scalar + Lapack>() {
27+
const N: usize = 4;
28+
let rtol: A::Real = A::real(1e-9);
29+
30+
let a: Array2<A> = random((N, N / 2));
31+
let (q, r) = mgs(a.axis_iter(Axis(1)), N, rtol, Strategy::Terminate);
32+
assert_close_l2!(&q.dot(&r), &a, rtol);
33+
34+
let qc: Array2<A> = conjugate(&q);
35+
assert_close_l2!(&qc.dot(&q), &Array::eye(N / 2), rtol);
36+
}
37+
38+
#[test]
39+
fn qr_real() {
40+
qr::<f64>();
41+
}
42+
43+
#[test]
44+
fn qr_complex() {
45+
qr::<c64>();
46+
}
47+
48+
fn qr_over<A: Scalar + Lapack>() {
49+
const N: usize = 4;
50+
let rtol: A::Real = A::real(1e-9);
51+
52+
let a: Array2<A> = random((N, N * 2));
53+
54+
// Terminate
55+
let (q, r) = mgs(a.axis_iter(Axis(1)), N, rtol, Strategy::Terminate);
56+
let a_sub = a.slice(s![.., 0..N]);
57+
assert_close_l2!(&q.dot(&r), &a_sub, rtol);
58+
let qc: Array2<A> = conjugate(&q);
59+
assert_close_l2!(&qc.dot(&q), &Array::eye(N), rtol);
60+
61+
// Skip
62+
let (q, r) = mgs(a.axis_iter(Axis(1)), N, rtol, Strategy::Skip);
63+
let a_sub = a.slice(s![.., 0..N]);
64+
assert_close_l2!(&q.dot(&r), &a_sub, rtol);
65+
let qc: Array2<A> = conjugate(&q);
66+
assert_close_l2!(&qc.dot(&q), &Array::eye(N), rtol);
67+
68+
// Full
69+
let (q, r) = mgs(a.axis_iter(Axis(1)), N, rtol, Strategy::Full);
70+
assert_close_l2!(&q.dot(&r), &a, rtol);
71+
let qc: Array2<A> = conjugate(&q);
72+
assert_close_l2!(&qc.dot(&q), &Array::eye(N), rtol);
73+
}
74+
75+
#[test]
76+
fn qr_over_real() {
77+
qr_over::<f64>();
78+
}
79+
80+
#[test]
81+
fn qr_over_complex() {
82+
qr_over::<c64>();
83+
}

0 commit comments

Comments
 (0)