Skip to content

Commit 67c7327

Browse files
committed
append_if
1 parent c795b20 commit 67c7327

File tree

1 file changed

+52
-13
lines changed

1 file changed

+52
-13
lines changed

src/arnoldi.rs

Lines changed: 52 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use crate::{generate::*, inner::*, norm::Norm, types::*};
22
use ndarray::*;
3+
use num_traits::Zero;
34

45
#[derive(Debug, Clone)]
56
pub struct MGS<A> {
@@ -8,6 +9,11 @@ pub struct MGS<A> {
89
r: Vec<Array1<A>>,
910
}
1011

12+
pub enum Dependence<A: Scalar> {
13+
Orthogonal(A::Real),
14+
Coefficient(Array1<A>),
15+
}
16+
1117
impl<A: Scalar + Lapack> MGS<A> {
1218
pub fn new(dim: usize) -> Self {
1319
Self {
@@ -25,7 +31,16 @@ impl<A: Scalar + Lapack> MGS<A> {
2531
self.q.len()
2632
}
2733

34+
/// Add new vector, return residual norm
2835
pub fn append<S>(&mut self, a: ArrayBase<S, Ix1>) -> A::Real
36+
where
37+
S: Data<Elem = A>,
38+
{
39+
self.append_if(a, A::Real::zero()).unwrap()
40+
}
41+
42+
/// Add new vector if the residual is larger than relative tolerance
43+
pub fn append_if<S>(&mut self, a: ArrayBase<S, Ix1>, rtol: A::Real) -> Option<A::Real>
2944
where
3045
S: Data<Elem = A>,
3146
{
@@ -39,11 +54,14 @@ impl<A: Scalar + Lapack> MGS<A> {
3954
coef[i] = c;
4055
}
4156
let nrm = a.norm_l2();
57+
if nrm < rtol {
58+
return None;
59+
}
4260
coef[self.len()] = A::from_real(nrm);
4361
self.r.push(coef);
4462
azip!(mut a in { *a = *a / A::from_real(nrm) });
4563
self.q.push(a);
46-
nrm
64+
Some(nrm)
4765
}
4866

4967
/// Get orthogonal basis as Q matrix
@@ -78,7 +96,7 @@ mod tests {
7896
assert_eq!(mgs.len(), 0);
7997
}
8098

81-
fn test<A: Scalar + Lapack>(rtol: A::Real) {
99+
fn test_append<A: Scalar + Lapack>(rtol: A::Real) {
82100
let mut mgs: MGS<A> = MGS::new(N);
83101
let a: Array2<A> = crate::generate::random((N, 3));
84102
dbg!(&a);
@@ -100,22 +118,43 @@ mod tests {
100118
}
101119

102120
#[test]
103-
fn test_f32() {
104-
test::<f32>(1e-5);
121+
fn append() {
122+
test_append::<f32>(1e-5);
123+
test_append::<c32>(1e-5);
124+
test_append::<f64>(1e-9);
125+
test_append::<c64>(1e-9);
105126
}
106127

107-
#[test]
108-
fn test_c32() {
109-
test::<c32>(1e-5);
110-
}
128+
fn test_append_if<A: Scalar + Lapack>(rtol: A::Real) {
129+
let mut mgs: MGS<A> = MGS::new(N);
130+
let a: Array2<A> = crate::generate::random((N, 8));
131+
dbg!(&a);
132+
for col in a.axis_iter(Axis(1)) {
133+
match mgs.append_if(col, rtol) {
134+
Some(res) => {
135+
dbg!(res);
136+
}
137+
None => break,
138+
}
139+
}
140+
let q = mgs.get_q();
141+
dbg!(&q);
142+
let r = mgs.get_r();
143+
dbg!(&r);
111144

112-
#[test]
113-
fn test_f64() {
114-
test::<f64>(1e-9);
145+
dbg!(q.dot(&r));
146+
close_l2(&q.dot(&r), &a.slice(s![.., 0..N]), rtol).unwrap();
147+
148+
let qt: Array2<_> = conjugate(&q);
149+
dbg!(qt.dot(&q));
150+
close_l2(&qt.dot(&q), &Array2::eye(N), rtol).unwrap();
115151
}
116152

117153
#[test]
118-
fn test_c64() {
119-
test::<c64>(1e-9);
154+
fn append_if() {
155+
test_append_if::<f32>(1e-5);
156+
test_append_if::<c32>(1e-5);
157+
test_append_if::<f64>(1e-9);
158+
test_append_if::<c64>(1e-9);
120159
}
121160
}

0 commit comments

Comments
 (0)