Skip to content

Commit a6ac4b3

Browse files
committed
Rename EigResult to LobpcgResult
1 parent 564bb77 commit a6ac4b3

File tree

4 files changed

+63
-61
lines changed

4 files changed

+63
-61
lines changed

src/lobpcg/eig.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use super::lobpcg::{lobpcg, EigResult, Order};
1+
use super::lobpcg::{lobpcg, LobpcgResult, Order};
22
use crate::{Lapack, Scalar};
33
///! Implements truncated eigenvalue decomposition
44
///
@@ -61,7 +61,7 @@ impl<A: Float + Scalar + ScalarOperand + Lapack + PartialOrd + Default> Truncate
6161
}
6262

6363
// calculate the eigenvalues decompose
64-
pub fn decompose(&self, num: usize) -> EigResult<A> {
64+
pub fn decompose(&self, num: usize) -> LobpcgResult<A> {
6565
let x = Array2::random((self.problem.len_of(Axis(0)), num), Uniform::new(0.0, 1.0))
6666
.mapv(|x| NumCast::from(x).unwrap());
6767

@@ -124,7 +124,7 @@ impl<A: Float + Scalar + ScalarOperand + Lapack + PartialOrd + Default> Iterator
124124
let res = self.eig.decompose(step_size);
125125

126126
match res {
127-
EigResult::Ok(vals, vecs, norms) | EigResult::Err(vals, vecs, norms, _) => {
127+
LobpcgResult::Ok(vals, vecs, norms) | LobpcgResult::Err(vals, vecs, norms, _) => {
128128
// abort if any eigenproblem did not converge
129129
for r_norm in norms {
130130
if r_norm > NumCast::from(0.1).unwrap() {
@@ -151,7 +151,7 @@ impl<A: Float + Scalar + ScalarOperand + Lapack + PartialOrd + Default> Iterator
151151

152152
Some((vals, vecs))
153153
}
154-
EigResult::NoResult(_) => None,
154+
LobpcgResult::NoResult(_) => None,
155155
}
156156
}
157157
}

src/lobpcg/lobpcg.rs

Lines changed: 52 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use crate::{Lapack, Scalar};
88
use ndarray::prelude::*;
99
use ndarray::OwnedRepr;
1010
use ndarray::ScalarOperand;
11-
use num_traits::{NumCast, Float};
11+
use num_traits::{Float, NumCast};
1212

1313
/// Find largest or smallest eigenvalues
1414
#[derive(Debug, Clone)]
@@ -20,12 +20,12 @@ pub enum Order {
2020
/// The result of the eigensolver
2121
///
2222
/// In the best case the eigensolver has converged with a result better than the given threshold,
23-
/// then a `EigResult::Ok` gives the eigenvalues, eigenvectors and norms. If an error ocurred
24-
/// during the process, it is returned in `EigResult::Err`, but the best result is still returned,
25-
/// as it could be usable. If there is no result at all, then `EigResult::NoResult` is returned.
23+
/// then a `LobpcgResult::Ok` gives the eigenvalues, eigenvectors and norms. If an error ocurred
24+
/// during the process, it is returned in `LobpcgResult::Err`, but the best result is still returned,
25+
/// as it could be usable. If there is no result at all, then `LobpcgResult::NoResult` is returned.
2626
/// This happens if the algorithm fails in an early stage, for example if the matrix `A` is not SPD
2727
#[derive(Debug)]
28-
pub enum EigResult<A> {
28+
pub enum LobpcgResult<A> {
2929
Ok(Array1<A>, Array2<A>, Vec<A>),
3030
Err(Array1<A>, Array2<A>, Vec<A>, LinalgError),
3131
NoResult(LinalgError),
@@ -61,16 +61,18 @@ fn sorted_eig<A: Scalar + Lapack>(
6161
fn ndarray_mask<A: Scalar>(matrix: ArrayView2<A>, mask: &[bool]) -> Array2<A> {
6262
assert_eq!(mask.len(), matrix.ncols());
6363

64-
let indices = (0..mask.len()).zip(mask.into_iter())
65-
.filter(|(_,b)| **b).map(|(a,_)| a)
64+
let indices = (0..mask.len())
65+
.zip(mask.into_iter())
66+
.filter(|(_, b)| **b)
67+
.map(|(a, _)| a)
6668
.collect::<Vec<usize>>();
6769

6870
matrix.select(Axis(1), &indices)
6971
}
7072

7173
/// Applies constraints ensuring that a matrix is orthogonal to it
7274
///
73-
/// This functions takes a matrix `v` and constraint matrix `y` and orthogonalize the `v` to `y`.
75+
/// This functions takes a matrix `v` and constraint-matrix `y` and orthogonalize `v` to `y`.
7476
fn apply_constraints<A: Scalar + Lapack>(
7577
mut v: ArrayViewMut<A, Ix2>,
7678
cholesky_yy: &CholeskyFactorized<OwnedRepr<A>>,
@@ -132,19 +134,23 @@ fn orthonormalize<T: Scalar + Lapack>(v: Array2<T>) -> Result<(Array2<T>, Array2
132134
/// * `maxiter` - The maximal number of iterations
133135
/// * `order` - Whether to solve for the largest or lowest eigenvalues
134136
///
135-
/// The function returns an `EigResult` with the eigenvalue/eigenvector and achieved residual norm
137+
/// The function returns an `LobpcgResult` with the eigenvalue/eigenvector and achieved residual norm
136138
/// for it. All iterations are tracked and the optimal solution returned. In case of an error a
137-
/// special variant `EigResult::NotConverged` additionally carries the error. This can happen when
139+
/// special variant `LobpcgResult::NotConverged` additionally carries the error. This can happen when
138140
/// the precision of the matrix is too low (switch then from `f32` to `f64` for example).
139-
pub fn lobpcg<A: Float + Scalar + Lapack + ScalarOperand + PartialOrd + Default, F: Fn(ArrayView2<A>) -> Array2<A>, G: Fn(ArrayViewMut2<A>)>(
141+
pub fn lobpcg<
142+
A: Float + Scalar + Lapack + ScalarOperand + PartialOrd + Default,
143+
F: Fn(ArrayView2<A>) -> Array2<A>,
144+
G: Fn(ArrayViewMut2<A>),
145+
>(
140146
a: F,
141147
mut x: Array2<A>,
142148
m: G,
143149
y: Option<Array2<A>>,
144150
tol: A::Real,
145151
maxiter: usize,
146152
order: Order,
147-
) -> EigResult<A> {
153+
) -> LobpcgResult<A> {
148154
// the initital approximation should be maximal square
149155
// n is the dimensionality of the problem
150156
let (n, size_x) = (x.nrows(), x.ncols());
@@ -172,7 +178,7 @@ pub fn lobpcg<A: Float + Scalar + Lapack + ScalarOperand + PartialOrd + Default,
172178
// orthonormalize the initial guess
173179
let (x, _) = match orthonormalize(x) {
174180
Ok(x) => x,
175-
Err(err) => return EigResult::NoResult(err),
181+
Err(err) => return LobpcgResult::NoResult(err),
176182
};
177183

178184
// calculate AX and XAX for Rayleigh quotient
@@ -182,7 +188,7 @@ pub fn lobpcg<A: Float + Scalar + Lapack + ScalarOperand + PartialOrd + Default,
182188
// perform eigenvalue decomposition of XAX
183189
let (mut lambda, eig_block) = match sorted_eig(xax.view(), None, size_x, &order) {
184190
Ok(x) => x,
185-
Err(err) => return EigResult::NoResult(err),
191+
Err(err) => return LobpcgResult::NoResult(err),
186192
};
187193

188194
// initiate approximation of the eigenvector
@@ -219,12 +225,20 @@ pub fn lobpcg<A: Float + Scalar + Lapack + ScalarOperand + PartialOrd + Default,
219225

220226
// compare best result and update if we improved
221227
let sum_rnorm: A::Real = residual_norms.iter().cloned().sum();
222-
if best_result.as_ref().map(|x: &(_,_,Vec<A::Real>)| x.2.iter().cloned().sum::<A::Real>() > sum_rnorm).unwrap_or(true) {
228+
if best_result
229+
.as_ref()
230+
.map(|x: &(_, _, Vec<A::Real>)| x.2.iter().cloned().sum::<A::Real>() > sum_rnorm)
231+
.unwrap_or(true)
232+
{
223233
best_result = Some((lambda.clone(), x.clone(), residual_norms.clone()));
224234
}
225235

226236
// disable eigenvalues which are below the tolerance threshold
227-
activemask = residual_norms.iter().zip(activemask.iter()).map(|(x, a)| *x > tol && *a).collect();
237+
activemask = residual_norms
238+
.iter()
239+
.zip(activemask.iter())
240+
.map(|(x, a)| *x > tol && *a)
241+
.collect();
228242

229243
// resize identity block if necessary
230244
let current_block_size = activemask.iter().filter(|x| **x).count();
@@ -279,23 +293,19 @@ pub fn lobpcg<A: Float + Scalar + Lapack + ScalarOperand + PartialOrd + Default,
279293
rar = (&rar + &rar.t()) / two;
280294
let xax = x.t().dot(&ax);
281295

282-
(
283-
(&xax + &xax.t()) / two,
284-
x.t().dot(&x),
285-
r.t().dot(&r),
286-
x.t().dot(&r)
287-
)
296+
((&xax + &xax.t()) / two, x.t().dot(&x), r.t().dot(&r), x.t().dot(&r))
288297
} else {
289298
(
290299
lambda_diag,
291300
ident0.clone(),
292301
ident.clone(),
293-
Array2::zeros((size_x, current_block_size))
302+
Array2::zeros((size_x, current_block_size)),
294303
)
295304
};
296305

297306
// mask and orthonormalize P and AP
298-
let p_ap = previous_p_ap.as_ref()
307+
let p_ap = previous_p_ap
308+
.as_ref()
299309
.and_then(|(p, ap)| {
300310
let active_p = ndarray_mask(p.view(), &activemask);
301311
let active_ap = ndarray_mask(ap.view(), &activemask);
@@ -318,10 +328,7 @@ pub fn lobpcg<A: Float + Scalar + Lapack + ScalarOperand + PartialOrd + Default,
318328
let xp = x.t().dot(active_p);
319329
let rp = r.t().dot(active_p);
320330
let (pap, pp) = if explicit_gram_flag {
321-
(
322-
(&pap + &pap.t()) / two,
323-
active_p.t().dot(active_p)
324-
)
331+
((&pap + &pap.t()) / two, active_p.t().dot(active_p))
325332
} else {
326333
(pap, ident.clone())
327334
};
@@ -342,16 +349,8 @@ pub fn lobpcg<A: Float + Scalar + Lapack + ScalarOperand + PartialOrd + Default,
342349
)
343350
} else {
344351
(
345-
stack![
346-
Axis(0),
347-
stack![Axis(1), xax, xar],
348-
stack![Axis(1), xar.t(), rar]
349-
],
350-
stack![
351-
Axis(0),
352-
stack![Axis(1), xx, xr],
353-
stack![Axis(1), xr.t(), rr]
354-
],
352+
stack![Axis(0), stack![Axis(1), xax, xar], stack![Axis(1), xar.t(), rar]],
353+
stack![Axis(0), stack![Axis(1), xx, xr], stack![Axis(1), xr.t(), rr]],
355354
)
356355
};
357356

@@ -363,16 +362,16 @@ pub fn lobpcg<A: Float + Scalar + Lapack + ScalarOperand + PartialOrd + Default,
363362
if previous_p_ap.is_some() {
364363
previous_p_ap = None;
365364
continue;
366-
} else { // or break if restart is not possible
365+
} else {
366+
// or break if restart is not possible
367367
break Err(err);
368368
}
369369
}
370370
};
371371
lambda = new_lambda;
372372

373373
// approximate eigenvector X and conjugate vectors P with solution of eigenproblem
374-
let (p, ap, tau) = if let Some((active_p, active_ap)) = p_ap
375-
{
374+
let (p, ap, tau) = if let Some((active_p, active_ap)) = p_ap {
376375
// tau are eigenvalues to basis of X
377376
let tau = eig_vecs.slice(s![..size_x, ..]);
378377
// alpha are eigenvalues to basis of R
@@ -414,8 +413,8 @@ pub fn lobpcg<A: Float + Scalar + Lapack + ScalarOperand + PartialOrd + Default,
414413
dbg!(&residual_norms_history);
415414

416415
match final_norm {
417-
Ok(_) => EigResult::Ok(vals, vecs, rnorm),
418-
Err(err) => EigResult::Err(vals, vecs, rnorm, err)
416+
Ok(_) => LobpcgResult::Ok(vals, vecs, rnorm),
417+
Err(err) => LobpcgResult::Err(vals, vecs, rnorm, err),
419418
}
420419
}
421420

@@ -425,7 +424,7 @@ mod tests {
425424
use super::ndarray_mask;
426425
use super::orthonormalize;
427426
use super::sorted_eig;
428-
use super::EigResult;
427+
use super::LobpcgResult;
429428
use super::Order;
430429
use crate::close_l2;
431430
use crate::qr::*;
@@ -486,7 +485,7 @@ mod tests {
486485
let result = lobpcg(|y| a.dot(&y), x, |_| {}, None, 1e-5, n * 2, order);
487486
dbg!(&result);
488487
match result {
489-
EigResult::Ok(vals, _, r_norms) | EigResult::Err(vals, _, r_norms, _) => {
488+
LobpcgResult::Ok(vals, _, r_norms) | LobpcgResult::Err(vals, _, r_norms, _) => {
490489
// check convergence
491490
for (i, norm) in r_norms.into_iter().enumerate() {
492491
if norm > 1e-5 {
@@ -501,7 +500,7 @@ mod tests {
501500
close_l2(&Array1::from(ground_truth_eigvals.to_vec()), &vals, num as f64 * 5e-4)
502501
}
503502
}
504-
EigResult::NoResult(err) => panic!("Did not converge: {:?}", err),
503+
LobpcgResult::NoResult(err) => panic!("Did not converge: {:?}", err),
505504
}
506505
}
507506

@@ -539,11 +538,15 @@ mod tests {
539538
let diag = arr1(&[1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]);
540539
let a = Array2::from_diag(&diag);
541540
let x: Array2<f64> = Array2::random((10, 1), Uniform::new(0.0, 1.0));
542-
let y: Array2<f64> = arr2(&[[1.0, 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 1.0, 0., 0., 0., 0., 0., 0., 0., 0.]]).reversed_axes();
541+
let y: Array2<f64> = arr2(&[
542+
[1.0, 0., 0., 0., 0., 0., 0., 0., 0., 0.],
543+
[0., 1.0, 0., 0., 0., 0., 0., 0., 0., 0.],
544+
])
545+
.reversed_axes();
543546

544547
let result = lobpcg(|y| a.dot(&y), x, |_| {}, Some(y), 1e-10, 50, Order::Smallest);
545548
match result {
546-
EigResult::Ok(vals, vecs, r_norms) | EigResult::Err(vals, vecs, r_norms, _) => {
549+
LobpcgResult::Ok(vals, vecs, r_norms) | LobpcgResult::Err(vals, vecs, r_norms, _) => {
547550
// check convergence
548551
for (i, norm) in r_norms.into_iter().enumerate() {
549552
if norm > 0.01 {
@@ -561,7 +564,7 @@ mod tests {
561564
1e-5,
562565
);
563566
}
564-
EigResult::NoResult(err) => panic!("Did not converge: {:?}", err),
567+
LobpcgResult::NoResult(err) => panic!("Did not converge: {:?}", err),
565568
}
566569
}
567570
}

src/lobpcg/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@ mod lobpcg;
33
mod svd;
44

55
pub use eig::TruncatedEig;
6-
pub use lobpcg::{lobpcg, EigResult, Order as TruncatedOrder};
6+
pub use lobpcg::{lobpcg, LobpcgResult, Order as TruncatedOrder};
77
pub use svd::TruncatedSvd;

src/lobpcg/svd.rs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
///! Truncated singular value decomposition
2-
///!
3-
///! This module computes the k largest/smallest singular values/vectors for a dense matrix.
4-
use super::lobpcg::{lobpcg, EigResult, Order};
2+
///!
3+
///! This module computes the k largest/smallest singular values/vectors for a dense matrix.
4+
use super::lobpcg::{lobpcg, LobpcgResult, Order};
55
use crate::error::Result;
66
use crate::{Lapack, Scalar};
77
use ndarray::prelude::*;
@@ -129,8 +129,7 @@ impl<A: Float + Scalar + ScalarOperand + Lapack + PartialOrd + Default> Truncate
129129
let (n, m) = (self.problem.nrows(), self.problem.ncols());
130130

131131
// generate initial matrix
132-
let x = Array2::random((usize::min(n, m), num), Uniform::new(0.0, 1.0))
133-
.mapv(|x| NumCast::from(x).unwrap());
132+
let x = Array2::random((usize::min(n, m), num), Uniform::new(0.0, 1.0)).mapv(|x| NumCast::from(x).unwrap());
134133

135134
// square precision because the SVD squares the eigenvalue as well
136135
let precision = self.precision * self.precision;
@@ -160,13 +159,13 @@ impl<A: Float + Scalar + ScalarOperand + Lapack + PartialOrd + Default> Truncate
160159

161160
// convert into TruncatedSvdResult
162161
match res {
163-
EigResult::Ok(vals, vecs, _) | EigResult::Err(vals, vecs, _, _) => Ok(TruncatedSvdResult {
162+
LobpcgResult::Ok(vals, vecs, _) | LobpcgResult::Err(vals, vecs, _, _) => Ok(TruncatedSvdResult {
164163
problem: self.problem.clone(),
165164
eigvals: vals,
166165
eigvecs: vecs,
167166
ngm: n > m,
168167
}),
169-
EigResult::NoResult(err) => Err(err),
168+
LobpcgResult::NoResult(err) => Err(err),
170169
}
171170
}
172171
}

0 commit comments

Comments
 (0)