Skip to content

Commit 755f654

Browse files
committed
Rewrite tests for MGS::append
1 parent a7543f5 commit 755f654

File tree

1 file changed

+51
-75
lines changed

1 file changed

+51
-75
lines changed

src/arnoldi.rs

Lines changed: 51 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,25 @@ pub struct MGS<A> {
1010
q: Vec<Array1<A>>,
1111
}
1212

13+
/// Residual vector of orthogonalization
1314
pub type Residual<S> = ArrayBase<S, Ix1>;
15+
/// Residual vector of orthogonalization
1416
pub type Coefficient<A> = Array1<A>;
17+
/// Q-matrix (unitary matrix)
18+
pub type Q<A> = Array2<A>;
19+
/// R-matrix (upper triangle)
20+
pub type R<A> = Array2<A>;
1521

16-
impl<A: Scalar + Lapack> MGS<A> {
22+
impl<A: Scalar> MGS<A> {
23+
/// Create empty linear space
24+
///
25+
/// ```rust
26+
/// # use ndarray_linalg::*;
27+
/// const N: usize = 5;
28+
/// let mgs = arnoldi::MGS::<f32>::new(N);
29+
/// assert_eq!(mgs.dim(), N);
30+
/// assert_eq!(mgs.len(), 0);
31+
/// ```
1732
pub fn new(dimension: usize) -> Self {
1833
Self {
1934
dimension,
@@ -36,6 +51,7 @@ impl<A: Scalar + Lapack> MGS<A> {
3651
/// - if the size of the input array mismaches to the dimension
3752
pub fn orthogonalize<S>(&self, mut a: ArrayBase<S, Ix1>) -> (Residual<S>, Coefficient<A>)
3853
where
54+
A: Lapack,
3955
S: DataMut<Elem = A>,
4056
{
4157
assert_eq!(a.len(), self.dim());
@@ -56,13 +72,27 @@ impl<A: Scalar + Lapack> MGS<A> {
5672
/// Panic
5773
/// -------
5874
/// - if the size of the input array mismaches to the dimension
59-
pub fn append_if_independent<S>(&mut self, a: ArrayBase<S, Ix1>, rtol: A::Real) -> Option<Coefficient<A>>
75+
///
76+
/// ```rust
77+
/// # use ndarray::*;
78+
/// # use ndarray_linalg::*;
79+
/// let mut mgs = arnoldi::MGS::new(3);
80+
/// let coef = mgs.append(array![1.0, 0.0, 0.0], 1e-9).unwrap();
81+
/// close_l2(&coef, &array![1.0], 1e-9).unwrap();
82+
///
83+
/// let coef = mgs.append(array![1.0, 1.0, 0.0], 1e-9).unwrap();
84+
/// close_l2(&coef, &array![1.0, 1.0], 1e-9).unwrap();
85+
///
86+
/// assert!(mgs.append(array![1.0, 1.0, 0.0], 1e-9).is_none()); // Cannot append dependent vector
87+
/// ```
88+
pub fn append<S>(&mut self, a: ArrayBase<S, Ix1>, rtol: A::Real) -> Option<Coefficient<A>>
6089
where
90+
A: Lapack,
6191
S: Data<Elem = A>,
6292
{
6393
let a = a.into_owned();
6494
let (mut a, coef) = self.orthogonalize(a);
65-
let nrm = coef[coef.len()].re();
95+
let nrm = coef[coef.len() - 1].re();
6696
if nrm < rtol {
6797
// Linearly dependent
6898
return None;
@@ -73,84 +103,30 @@ impl<A: Scalar + Lapack> MGS<A> {
73103
}
74104

75105
/// Get orthogonal basis as Q matrix
76-
pub fn get_q(&self) -> Array2<A> {
106+
pub fn get_q(&self) -> Q<A> {
77107
hstack(&self.q).unwrap()
78108
}
79109
}
80110

81-
#[cfg(test)]
82-
mod tests {
83-
use super::*;
84-
use crate::assert::*;
85-
86-
const N: usize = 5;
87-
88-
#[test]
89-
fn new() {
90-
let mgs: MGS<f32> = MGS::new(N);
91-
assert_eq!(mgs.dim(), N);
92-
assert_eq!(mgs.len(), 0);
93-
}
94-
95-
fn test_append<A: Scalar + Lapack>(rtol: A::Real) {
96-
let mut mgs: MGS<A> = MGS::new(N);
97-
let a: Array2<A> = crate::generate::random((N, 3));
98-
dbg!(&a);
99-
for col in a.axis_iter(Axis(1)) {
100-
let res = mgs.append(col);
101-
dbg!(res);
111+
pub fn mgs<A, S>(iter: impl Iterator<Item = ArrayBase<S, Ix1>>, dim: usize, rtol: A::Real) -> (Q<A>, R<A>)
112+
where
113+
A: Scalar + Lapack,
114+
S: Data<Elem = A>,
115+
{
116+
let mut ortho = MGS::new(dim);
117+
let mut coefs = Vec::new();
118+
for a in iter {
119+
match ortho.append(a, rtol) {
120+
Some(coef) => coefs.push(coef),
121+
None => break,
102122
}
103-
let q = mgs.get_q();
104-
dbg!(&q);
105-
let r = mgs.get_r();
106-
dbg!(&r);
107-
108-
dbg!(q.dot(&r));
109-
close_l2(&q.dot(&r), &a, rtol).unwrap();
110-
111-
let qt: Array2<_> = conjugate(&q);
112-
dbg!(qt.dot(&q));
113-
close_l2(&qt.dot(&q), &Array2::eye(3), rtol).unwrap();
114-
}
115-
116-
#[test]
117-
fn append() {
118-
test_append::<f32>(1e-5);
119-
test_append::<c32>(1e-5);
120-
test_append::<f64>(1e-9);
121-
test_append::<c64>(1e-9);
122123
}
123-
124-
fn test_append_if<A: Scalar + Lapack>(rtol: A::Real) {
125-
let mut mgs: MGS<A> = MGS::new(N);
126-
let a: Array2<A> = crate::generate::random((N, 8));
127-
dbg!(&a);
128-
for col in a.axis_iter(Axis(1)) {
129-
match mgs.append_if(col, rtol) {
130-
Some(res) => {
131-
dbg!(res);
132-
}
133-
None => break,
134-
}
124+
let m = coefs.len();
125+
let mut r = Array2::zeros((m, m));
126+
for i in 0..m {
127+
for j in 0..=i {
128+
r[(j, i)] = coefs[i][j];
135129
}
136-
let q = mgs.get_q();
137-
dbg!(&q);
138-
let r = mgs.get_r();
139-
dbg!(&r);
140-
141-
dbg!(q.dot(&r));
142-
close_l2(&q.dot(&r), &a.slice(s![.., 0..N]), rtol).unwrap();
143-
144-
let qt: Array2<_> = conjugate(&q);
145-
dbg!(qt.dot(&q));
146-
close_l2(&qt.dot(&q), &Array2::eye(N), rtol).unwrap();
147-
}
148-
149-
#[test]
150-
fn append_if() {
151-
test_append_if::<f32>(1e-5);
152-
test_append_if::<c32>(1e-5);
153-
test_append_if::<f64>(1e-9);
154-
test_append_if::<c64>(1e-9);
155130
}
131+
(ortho.get_q(), r)
156132
}

0 commit comments

Comments
 (0)