Skip to content

Commit 5d74992

Browse files
committed
Split into krylov submodule
1 parent d48e215 commit 5d74992

File tree

5 files changed

+186
-82
lines changed

5 files changed

+186
-82
lines changed

src/krylov/householder.rs

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
use super::*;
2+
use crate::{inner::*, norm::*};
3+
4+
/// Iterative orthogonalizer using Householder reflection
5+
#[derive(Debug, Clone)]
6+
pub struct Householder<A> {
7+
dim: usize,
8+
v: Vec<Array1<A>>,
9+
}
10+
11+
impl<A: Scalar> Householder<A> {
12+
pub fn new(dim: usize) -> Self {
13+
Householder { dim, v: Vec::new() }
14+
}
15+
16+
/// Take a Reflection `P = I - 2ww^T`
17+
fn reflect<S: DataMut<Elem = A>>(&self, k: usize, a: &mut ArrayBase<S, Ix1>) {
18+
assert!(k < self.v.len());
19+
assert_eq!(a.len(), self.dim);
20+
let w = self.v[k].slice(s![k..]);
21+
let c = A::from(2.0).unwrap() * w.inner(&a.slice(s![k..]));
22+
for l in k..self.dim {
23+
a[l] -= c * w[l];
24+
}
25+
}
26+
}
27+
28+
impl<A: Scalar + Lapack> Orthogonalizer for Householder<A> {
29+
type Elem = A;
30+
31+
fn new(dim: usize) -> Self {
32+
Self { dim, v: Vec::new() }
33+
}
34+
35+
fn dim(&self) -> usize {
36+
self.dim
37+
}
38+
39+
fn len(&self) -> usize {
40+
self.v.len()
41+
}
42+
43+
fn orthogonalize<S>(&self, a: &mut ArrayBase<S, Ix1>) -> A::Real
44+
where
45+
S: DataMut<Elem = A>,
46+
{
47+
for k in 0..self.len() {
48+
self.reflect(k, a);
49+
}
50+
// residual norm
51+
a.slice(s![self.len()..]).norm_l2()
52+
}
53+
54+
fn append<S>(&mut self, mut a: ArrayBase<S, Ix1>, rtol: A::Real) -> Result<Array1<A>, Array1<A>>
55+
where
56+
S: DataMut<Elem = A>,
57+
{
58+
let residual = self.orthogonalize(&mut a);
59+
let coef = a.slice(s![..self.len()]).into_owned();
60+
if residual < rtol {
61+
return Err(coef);
62+
}
63+
self.v.push(a.into_owned());
64+
Ok(coef)
65+
}
66+
67+
fn get_q(&self) -> Q<A> {
68+
unimplemented!()
69+
}
70+
}

src/krylov/mgs.rs

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
use super::*;
2+
use crate::{generate::*, inner::*, norm::Norm};
3+
4+
/// Iterative orthogonalizer using modified Gram-Schmit procedure
5+
#[derive(Debug, Clone)]
6+
pub struct MGS<A> {
7+
/// Dimension of base space
8+
dim: usize,
9+
/// Basis of spanned space
10+
q: Vec<Array1<A>>,
11+
}
12+
13+
impl<A: Scalar> MGS<A> {
14+
fn ortho<S>(&self, a: &mut ArrayBase<S, Ix1>) -> Array1<A>
15+
where
16+
A: Lapack,
17+
S: DataMut<Elem = A>,
18+
{
19+
assert_eq!(a.len(), self.dim);
20+
let mut coef = Array1::zeros(self.q.len() + 1);
21+
for i in 0..self.q.len() {
22+
let q = &self.q[i];
23+
let c = q.inner(&a);
24+
azip!(mut a (&mut *a), q (q) in { *a = *a - c * q } );
25+
coef[i] = c;
26+
}
27+
let nrm = a.norm_l2();
28+
coef[self.q.len()] = A::from(nrm).unwrap();
29+
coef
30+
}
31+
}
32+
33+
impl<A: Scalar + Lapack> Orthogonalizer for MGS<A> {
34+
type Elem = A;
35+
36+
fn new(dim: usize) -> Self {
37+
Self { dim, q: Vec::new() }
38+
}
39+
40+
fn dim(&self) -> usize {
41+
self.dim
42+
}
43+
44+
fn len(&self) -> usize {
45+
self.q.len()
46+
}
47+
48+
fn orthogonalize<S>(&self, a: &mut ArrayBase<S, Ix1>) -> A::Real
49+
where
50+
S: DataMut<Elem = A>,
51+
{
52+
let coef = self.ortho(a);
53+
// Write coefficients into `a`
54+
azip!(mut a (a.slice_mut(s![0..self.len()])), coef in { *a = coef });
55+
// 0-fill for remaining
56+
azip!(mut a (a.slice_mut(s![self.len()..])) in { *a = A::zero() });
57+
coef[self.len()].re()
58+
}
59+
60+
fn append<S>(&mut self, mut a: ArrayBase<S, Ix1>, rtol: A::Real) -> Result<Array1<A>, Array1<A>>
61+
where
62+
S: DataMut<Elem = A>,
63+
{
64+
let coef = self.ortho(&mut a);
65+
let nrm = coef[self.len()].re();
66+
if nrm < rtol {
67+
// Linearly dependent
68+
return Err(coef);
69+
}
70+
azip!(mut a in { *a = *a / A::from_real(nrm) });
71+
self.q.push(a.into_owned());
72+
Ok(coef)
73+
}
74+
75+
fn get_q(&self) -> Q<A> {
76+
hstack(&self.q).unwrap()
77+
}
78+
}

src/arnoldi.rs renamed to src/krylov/mod.rs

Lines changed: 28 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,74 +1,42 @@
1-
use crate::{generate::*, inner::*, norm::Norm, types::*};
1+
use crate::types::*;
22
use ndarray::*;
33

4-
/// Iterative orthogonalizer using modified Gram-Schmit procedure
5-
#[derive(Debug, Clone)]
6-
pub struct MGS<A> {
7-
/// Dimension of base space
8-
dimension: usize,
9-
/// Basis of spanned space
10-
q: Vec<Array1<A>>,
11-
}
4+
mod householder;
5+
mod mgs;
6+
7+
pub use householder::Householder;
8+
pub use mgs::MGS;
129

1310
/// Q-matrix (unitary)
1411
pub type Q<A> = Array2<A>;
1512
/// R-matrix (upper triangle)
1613
pub type R<A> = Array2<A>;
1714

18-
impl<A: Scalar> MGS<A> {
19-
/// Create empty linear space
20-
///
21-
/// ```rust
22-
/// # use ndarray_linalg::*;
23-
/// const N: usize = 5;
24-
/// let mgs = arnoldi::MGS::<f32>::new(N);
25-
/// assert_eq!(mgs.dim(), N);
26-
/// assert_eq!(mgs.len(), 0);
27-
/// ```
28-
pub fn new(dimension: usize) -> Self {
29-
Self {
30-
dimension,
31-
q: Vec::new(),
32-
}
33-
}
15+
pub trait Orthogonalizer {
16+
type Elem: Scalar;
3417

35-
pub fn dim(&self) -> usize {
36-
self.dimension
37-
}
18+
/// Create empty linear space
19+
fn new(dim: usize) -> Self;
3820

39-
pub fn len(&self) -> usize {
40-
self.q.len()
41-
}
21+
fn dim(&self) -> usize;
22+
fn len(&self) -> usize;
4223

43-
/// Orthogonalize given vector using current basis
24+
/// Orthogonalize given vector
4425
///
4526
/// Panic
4627
/// -------
4728
/// - if the size of the input array mismaches to the dimension
48-
pub fn orthogonalize<S>(&self, a: &mut ArrayBase<S, Ix1>) -> Array1<A>
29+
///
30+
fn orthogonalize<S>(&self, a: &mut ArrayBase<S, Ix1>) -> <Self::Elem as Scalar>::Real
4931
where
50-
A: Lapack,
51-
S: DataMut<Elem = A>,
52-
{
53-
assert_eq!(a.len(), self.dim());
54-
let mut coef = Array1::zeros(self.len() + 1);
55-
for i in 0..self.len() {
56-
let q = &self.q[i];
57-
let c = q.inner(&a);
58-
azip!(mut a (&mut *a), q (q) in { *a = *a - c * q } );
59-
coef[i] = c;
60-
}
61-
let nrm = a.norm_l2();
62-
coef[self.len()] = A::from_real(nrm);
63-
coef
64-
}
32+
S: DataMut<Elem = Self::Elem>;
6533

6634
/// Add new vector if the residual is larger than relative tolerance
6735
///
6836
/// ```rust
6937
/// # use ndarray::*;
70-
/// # use ndarray_linalg::*;
71-
/// let mut mgs = arnoldi::MGS::new(3);
38+
/// # use ndarray_linalg::{*, krylov::*};
39+
/// let mut mgs = krylov::MGS::new(3);
7240
/// let coef = mgs.append(array![0.0, 1.0, 0.0], 1e-9).unwrap();
7341
/// close_l2(&coef, &array![1.0], 1e-9).unwrap();
7442
///
@@ -86,27 +54,16 @@ impl<A: Scalar> MGS<A> {
8654
/// -------
8755
/// - if the size of the input array mismaches to the dimension
8856
///
89-
pub fn append<S>(&mut self, a: ArrayBase<S, Ix1>, rtol: A::Real) -> Result<Array1<A>, Array1<A>>
57+
fn append<S>(
58+
&mut self,
59+
a: ArrayBase<S, Ix1>,
60+
rtol: <Self::Elem as Scalar>::Real,
61+
) -> Result<Array1<Self::Elem>, Array1<Self::Elem>>
9062
where
91-
A: Lapack,
92-
S: Data<Elem = A>,
93-
{
94-
let mut a = a.into_owned();
95-
let coef = self.orthogonalize(&mut a);
96-
let nrm = coef[coef.len() - 1].re();
97-
if nrm < rtol {
98-
// Linearly dependent
99-
return Err(coef);
100-
}
101-
azip!(mut a in { *a = *a / A::from_real(nrm) });
102-
self.q.push(a);
103-
Ok(coef)
104-
}
63+
S: DataMut<Elem = Self::Elem>;
10564

106-
/// Get orthogonal basis as Q matrix
107-
pub fn get_q(&self) -> Q<A> {
108-
hstack(&self.q).unwrap()
109-
}
65+
/// Get Q-matrix of generated basis
66+
fn get_q(&self) -> Q<Self::Elem>;
11067
}
11168

11269
/// Strategy for linearly dependent vectors appearing in iterative QR decomposition
@@ -132,7 +89,7 @@ pub enum Strategy {
13289
}
13390

13491
/// Online QR decomposition of vectors using modified Gram-Schmit algorithm
135-
pub fn mgs<A, S>(
92+
pub fn qr<A, S>(
13693
iter: impl Iterator<Item = ArrayBase<S, Ix1>>,
13794
dim: usize,
13895
rtol: A::Real,
@@ -145,7 +102,7 @@ where
145102
let mut ortho = MGS::new(dim);
146103
let mut coefs = Vec::new();
147104
for a in iter {
148-
match ortho.append(a, rtol) {
105+
match ortho.append(a.into_owned(), rtol) {
149106
Ok(coef) => coefs.push(coef),
150107
Err(coef) => match strategy {
151108
Strategy::Terminate => break,

src/lib.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,15 @@
1717
//! - [generator functions](generate/index.html)
1818
//! - [Scalar trait](types/trait.Scalar.html)
1919
20-
pub mod arnoldi;
2120
pub mod assert;
2221
pub mod cholesky;
2322
pub mod convert;
2423
pub mod diagonal;
2524
pub mod eigh;
2625
pub mod error;
2726
pub mod generate;
28-
pub mod householder;
2927
pub mod inner;
28+
pub mod krylov;
3029
pub mod lapack;
3130
pub mod layout;
3231
pub mod norm;

tests/arnoldi.rs renamed to tests/krylov.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
use ndarray::*;
2-
use ndarray_linalg::{arnoldi::*, *};
2+
use ndarray_linalg::{krylov::*, *};
33

44
fn qr_full<A: Scalar + Lapack>() {
55
const N: usize = 5;
66
let rtol: A::Real = A::real(1e-9);
77

88
let a: Array2<A> = random((N, N));
9-
let (q, r) = mgs(a.axis_iter(Axis(1)), N, rtol, Strategy::Terminate);
9+
let (q, r) = qr(a.axis_iter(Axis(1)), N, rtol, Strategy::Terminate);
1010
assert_close_l2!(&q.dot(&r), &a, rtol);
1111

1212
let qc: Array2<A> = conjugate(&q);
@@ -23,12 +23,12 @@ fn qr_full_complex() {
2323
qr_full::<c64>();
2424
}
2525

26-
fn qr<A: Scalar + Lapack>() {
26+
fn qr_<A: Scalar + Lapack>() {
2727
const N: usize = 4;
2828
let rtol: A::Real = A::real(1e-9);
2929

3030
let a: Array2<A> = random((N, N / 2));
31-
let (q, r) = mgs(a.axis_iter(Axis(1)), N, rtol, Strategy::Terminate);
31+
let (q, r) = qr(a.axis_iter(Axis(1)), N, rtol, Strategy::Terminate);
3232
assert_close_l2!(&q.dot(&r), &a, rtol);
3333

3434
let qc: Array2<A> = conjugate(&q);
@@ -37,12 +37,12 @@ fn qr<A: Scalar + Lapack>() {
3737

3838
#[test]
3939
fn qr_real() {
40-
qr::<f64>();
40+
qr_::<f64>();
4141
}
4242

4343
#[test]
4444
fn qr_complex() {
45-
qr::<c64>();
45+
qr_::<c64>();
4646
}
4747

4848
fn qr_over<A: Scalar + Lapack>() {
@@ -52,21 +52,21 @@ fn qr_over<A: Scalar + Lapack>() {
5252
let a: Array2<A> = random((N, N * 2));
5353

5454
// Terminate
55-
let (q, r) = mgs(a.axis_iter(Axis(1)), N, rtol, Strategy::Terminate);
55+
let (q, r) = qr(a.axis_iter(Axis(1)), N, rtol, Strategy::Terminate);
5656
let a_sub = a.slice(s![.., 0..N]);
5757
assert_close_l2!(&q.dot(&r), &a_sub, rtol);
5858
let qc: Array2<A> = conjugate(&q);
5959
assert_close_l2!(&qc.dot(&q), &Array::eye(N), rtol);
6060

6161
// Skip
62-
let (q, r) = mgs(a.axis_iter(Axis(1)), N, rtol, Strategy::Skip);
62+
let (q, r) = qr(a.axis_iter(Axis(1)), N, rtol, Strategy::Skip);
6363
let a_sub = a.slice(s![.., 0..N]);
6464
assert_close_l2!(&q.dot(&r), &a_sub, rtol);
6565
let qc: Array2<A> = conjugate(&q);
6666
assert_close_l2!(&qc.dot(&q), &Array::eye(N), rtol);
6767

6868
// Full
69-
let (q, r) = mgs(a.axis_iter(Axis(1)), N, rtol, Strategy::Full);
69+
let (q, r) = qr(a.axis_iter(Axis(1)), N, rtol, Strategy::Full);
7070
assert_close_l2!(&q.dot(&r), &a, rtol);
7171
let qc: Array2<A> = conjugate(&q);
7272
assert_close_l2!(&qc.dot(&q), &Array::eye(N), rtol);

0 commit comments

Comments
 (0)