Skip to content

Commit dafa8bb

Browse files
committed
Test implementation of iterative QR decomposition
1 parent f7ff6a0 commit dafa8bb

File tree

2 files changed

+100
-0
lines changed

2 files changed

+100
-0
lines changed

src/arnoldi.rs

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
use crate::{generate::*, norm::Norm, types::*};
2+
use ndarray::*;
3+
4+
#[derive(Debug, Clone)]
5+
pub struct MGS<A> {
6+
dim: usize,
7+
q: Vec<Array1<A>>,
8+
r: Vec<Array1<A>>,
9+
}
10+
11+
impl<A: Scalar + Lapack> MGS<A> {
12+
pub fn new(dim: usize) -> Self {
13+
Self {
14+
dim,
15+
q: Vec::new(),
16+
r: Vec::new(),
17+
}
18+
}
19+
20+
pub fn dim(&self) -> usize {
21+
self.dim
22+
}
23+
24+
pub fn len(&self) -> usize {
25+
self.q.len()
26+
}
27+
28+
pub fn append<S>(&mut self, a: ArrayBase<S, Ix1>) -> A::Real
29+
where
30+
S: Data<Elem = A>,
31+
{
32+
assert_eq!(a.len(), self.dim());
33+
let mut a = a.into_owned();
34+
let mut coef = Array1::zeros(self.len() + 1);
35+
for i in 0..self.len() {
36+
let q = &self.q[i];
37+
let c = a.dot(q);
38+
azip!(mut a, q (q) in { *a = *a - c * q } );
39+
coef[i] = c;
40+
}
41+
let nrm = a.norm_l2();
42+
coef[self.len()] = A::from_real(nrm);
43+
self.r.push(coef);
44+
azip!(mut a in { *a = *a / A::from_real(nrm) });
45+
self.q.push(a);
46+
nrm
47+
}
48+
49+
pub fn get_q(&self) -> Array2<A> {
50+
hstack(&self.q).unwrap()
51+
}
52+
53+
pub fn get_r(&self) -> Array2<A> {
54+
let len = self.len();
55+
let mut r = Array2::zeros((len, len));
56+
for i in 0..len {
57+
for j in 0..=i {
58+
r[(j, i)] = self.r[i][j];
59+
}
60+
}
61+
r
62+
}
63+
}
64+
65+
#[cfg(test)]
66+
mod tests {
67+
use super::*;
68+
use crate::assert::*;
69+
70+
const N: usize = 5;
71+
72+
#[test]
73+
fn new() {
74+
let mgs: MGS<f32> = MGS::new(N);
75+
assert_eq!(mgs.dim(), N);
76+
assert_eq!(mgs.len(), 0);
77+
}
78+
79+
#[test]
80+
fn append_random() {
81+
let mut mgs: MGS<f64> = MGS::new(N);
82+
let a: Array2<f64> = random((N, 3));
83+
dbg!(&a);
84+
for col in a.axis_iter(Axis(1)) {
85+
let res = mgs.append(col);
86+
dbg!(res);
87+
}
88+
let q = mgs.get_q();
89+
dbg!(&q);
90+
let r = mgs.get_r();
91+
dbg!(&r);
92+
93+
dbg!(q.dot(&r));
94+
close_l2(&q.dot(&r), &a, 1e-9).unwrap();
95+
96+
dbg!(q.t().dot(&q));
97+
close_l2(&q.t().dot(&q), &Array2::eye(3), 1e-9).unwrap();
98+
}
99+
}

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
//! - [Random matrix generators](generate/index.html)
3737
//! - [Scalar trait](types/trait.Scalar.html)
3838
39+
pub mod arnoldi;
3940
pub mod assert;
4041
pub mod cholesky;
4142
pub mod convert;

0 commit comments

Comments
 (0)