Skip to content

Commit 3e912da

Browse files
committed
impl Orthogonalizer for MGS
1 parent e79d4b9 commit 3e912da

File tree

2 files changed

+31
-75
lines changed

2 files changed

+31
-75
lines changed

src/krylov/mgs.rs

Lines changed: 30 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,25 @@ use super::*;
44
use crate::{generate::*, inner::*, norm::Norm};
55

66
/// Iterative orthogonalizer using modified Gram-Schmit procedure
7+
///
8+
/// ```rust
9+
/// # use ndarray::*;
10+
/// # use ndarray_linalg::{mgs::*, krylov::*, *};
11+
/// let mut mgs = MGS::new(3);
12+
/// let coef = mgs.append(array![0.0, 1.0, 0.0], 1e-9).unwrap();
13+
/// close_l2(&coef, &array![1.0], 1e-9);
14+
///
15+
/// let coef = mgs.append(array![1.0, 1.0, 0.0], 1e-9).unwrap();
16+
/// close_l2(&coef, &array![1.0, 1.0], 1e-9);
17+
///
18+
/// // Fail if the vector is linearly dependent
19+
/// assert!(mgs.append(array![1.0, 2.0, 0.0], 1e-9).is_err());
20+
///
21+
/// // You can get coefficients of dependent vector
22+
/// if let Err(coef) = mgs.append(array![1.0, 2.0, 0.0], 1e-9) {
23+
/// close_l2(&coef, &array![2.0, 1.0, 0.0], 1e-9);
24+
/// }
25+
/// ```
726
#[derive(Debug, Clone)]
827
pub struct MGS<A> {
928
/// Dimension of base space
@@ -20,36 +39,20 @@ impl<A: Scalar> MGS<A> {
2039
q: Vec::new(),
2140
}
2241
}
42+
}
43+
44+
impl<A: Scalar + Lapack> Orthogonalizer for MGS<A> {
45+
type Elem = A;
2346

24-
/// Dimension of input array
25-
pub fn dim(&self) -> usize {
47+
fn dim(&self) -> usize {
2648
self.dimension
2749
}
2850

29-
/// Number of cached basis
30-
///
31-
/// ```rust
32-
/// # use ndarray::*;
33-
/// # use ndarray_linalg::{mgs::*, *};
34-
/// const N: usize = 3;
35-
/// let mut mgs = MGS::<f32>::new(N);
36-
/// assert_eq!(mgs.dim(), N);
37-
/// assert_eq!(mgs.len(), 0);
38-
///
39-
/// mgs.append(array![0.0, 1.0, 0.0], 1e-9).unwrap();
40-
/// assert_eq!(mgs.len(), 1);
41-
/// ```
42-
pub fn len(&self) -> usize {
51+
fn len(&self) -> usize {
4352
self.q.len()
4453
}
4554

46-
/// Orthogonalize given vector using current basis
47-
///
48-
/// Panic
49-
/// -------
50-
/// - if the size of the input array mismatches to the dimension
51-
///
52-
pub fn orthogonalize<S>(&self, a: &mut ArrayBase<S, Ix1>) -> Array1<A>
55+
fn orthogonalize<S>(&self, a: &mut ArrayBase<S, Ix1>) -> Array1<A>
5356
where
5457
A: Lapack,
5558
S: DataMut<Elem = A>,
@@ -67,32 +70,7 @@ impl<A: Scalar> MGS<A> {
6770
coef
6871
}
6972

70-
/// Add new vector if the residual is larger than relative tolerance
71-
///
72-
/// ```rust
73-
/// # use ndarray::*;
74-
/// # use ndarray_linalg::{mgs::*, *};
75-
/// let mut mgs = MGS::new(3);
76-
/// let coef = mgs.append(array![0.0, 1.0, 0.0], 1e-9).unwrap();
77-
/// close_l2(&coef, &array![1.0], 1e-9);
78-
///
79-
/// let coef = mgs.append(array![1.0, 1.0, 0.0], 1e-9).unwrap();
80-
/// close_l2(&coef, &array![1.0, 1.0], 1e-9);
81-
///
82-
/// // Fail if the vector is linearly dependent
83-
/// assert!(mgs.append(array![1.0, 2.0, 0.0], 1e-9).is_err());
84-
///
85-
/// // You can get coefficients of dependent vector
86-
/// if let Err(coef) = mgs.append(array![1.0, 2.0, 0.0], 1e-9) {
87-
/// close_l2(&coef, &array![2.0, 1.0, 0.0], 1e-9);
88-
/// }
89-
/// ```
90-
///
91-
/// Panic
92-
/// -------
93-
/// - if the size of the input array mismatches to the dimension
94-
///
95-
pub fn append<S>(&mut self, a: ArrayBase<S, Ix1>, rtol: A::Real) -> Result<Array1<A>, Array1<A>>
73+
fn append<S>(&mut self, a: ArrayBase<S, Ix1>, rtol: A::Real) -> Result<Array1<A>, Array1<A>>
9674
where
9775
A: Lapack,
9876
S: Data<Elem = A>,
@@ -109,8 +87,7 @@ impl<A: Scalar> MGS<A> {
10987
Ok(coef)
11088
}
11189

112-
/// Get orthogonal basis as Q matrix
113-
pub fn get_q(&self) -> Q<A> {
90+
fn get_q(&self) -> Q<A> {
11491
hstack(&self.q).unwrap()
11592
}
11693
}
@@ -126,27 +103,6 @@ where
126103
A: Scalar + Lapack,
127104
S: Data<Elem = A>,
128105
{
129-
let mut ortho = MGS::new(dim);
130-
let mut coefs = Vec::new();
131-
for a in iter {
132-
match ortho.append(a, rtol) {
133-
Ok(coef) => coefs.push(coef),
134-
Err(coef) => match strategy {
135-
Strategy::Terminate => break,
136-
Strategy::Skip => continue,
137-
Strategy::Full => coefs.push(coef),
138-
},
139-
}
140-
}
141-
let n = ortho.len();
142-
let m = coefs.len();
143-
let mut r = Array2::zeros((n, m).f());
144-
for j in 0..m {
145-
for i in 0..n {
146-
if i < coefs[j].len() {
147-
r[(i, j)] = coefs[j][i];
148-
}
149-
}
150-
}
151-
(ortho.get_q(), r)
106+
let mgs = MGS::new(dim);
107+
qr(iter, mgs, rtol, strategy)
152108
}

tests/mgs.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use ndarray::*;
2-
use ndarray_linalg::{mgs::*, *};
2+
use ndarray_linalg::{krylov::*, mgs::*, *};
33

44
fn qr_full<A: Scalar + Lapack>() {
55
const N: usize = 5;

0 commit comments

Comments
 (0)