Skip to content

Commit ae2ce6a

Browse files
committed
Remove dependency to ndarray-rand; Fix restarting issue
1 parent f8ad553 commit ae2ce6a

File tree

4 files changed

+19
-24
lines changed

4 files changed

+19
-24
lines changed

Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ num-traits = "0.2"
2828
cauchy = "0.2.1"
2929
num-complex = "0.2.1"
3030
rand = "0.5"
31-
ndarray-rand = "0.11"
3231

3332
[dependencies.ndarray]
3433
version = "0.13"

src/lobpcg/eig.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
use super::lobpcg::{lobpcg, LobpcgResult, Order};
2-
use crate::{Lapack, Scalar};
2+
use crate::{Lapack, Scalar, generate};
33
///! Implements truncated eigenvalue decomposition
44
///
55
use ndarray::prelude::*;
66
use ndarray::stack;
77
use ndarray::ScalarOperand;
8-
use ndarray_rand::rand_distr::Uniform;
9-
use ndarray_rand::RandomExt;
108
use num_traits::{Float, NumCast};
119

1210
/// Truncated eigenproblem solver
@@ -62,8 +60,8 @@ impl<A: Float + Scalar + ScalarOperand + Lapack + PartialOrd + Default> Truncate
6260

6361
// calculate the eigenvalues decompose
6462
pub fn decompose(&self, num: usize) -> LobpcgResult<A> {
65-
let x = Array2::random((self.problem.len_of(Axis(0)), num), Uniform::new(0.0, 1.0))
66-
.mapv(|x| NumCast::from(x).unwrap());
63+
let x: Array2<f64> = generate::random((self.problem.len_of(Axis(0)), num));
64+
let x = x.mapv(|x| NumCast::from(x).unwrap());
6765

6866
if let Some(ref preconditioner) = self.preconditioner {
6967
lobpcg(

src/lobpcg/lobpcg.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ pub fn lobpcg<
304304
};
305305

306306
// mask and orthonormalize P and AP
307-
let p_ap = previous_p_ap
307+
let mut p_ap = previous_p_ap
308308
.as_ref()
309309
.and_then(|(p, ap)| {
310310
let active_p = ndarray_mask(p.view(), &activemask);
@@ -356,6 +356,8 @@ pub fn lobpcg<
356356
)
357357
})
358358
.or_else(|_| {
359+
p_ap = None;
360+
359361
sorted_eig(
360362
stack![Axis(0), stack![Axis(1), xax, xar], stack![Axis(1), xar.t(), rar]],
361363
Some(stack![Axis(0), stack![Axis(1), xx, xr], stack![Axis(1), xr.t(), rr]]),
@@ -431,14 +433,13 @@ mod tests {
431433
use super::Order;
432434
use crate::close_l2;
433435
use crate::qr::*;
436+
use crate::generate;
434437
use ndarray::prelude::*;
435-
use ndarray_rand::rand_distr::Uniform;
436-
use ndarray_rand::RandomExt;
437438

438439
/// Test the `sorted_eigen` function
439440
#[test]
440441
fn test_sorted_eigen() {
441-
let matrix = Array2::random((10, 10), Uniform::new(0., 10.));
442+
let matrix: Array2<f64> = generate::random((10, 10)) * 10.0;
442443
let matrix = matrix.t().dot(&matrix);
443444

444445
// return all eigenvectors with largest first
@@ -454,15 +455,15 @@ mod tests {
454455
/// Test the masking function
455456
#[test]
456457
fn test_masking() {
457-
let matrix = Array2::random((10, 5), Uniform::new(0., 10.));
458+
let matrix: Array2<f64> = generate::random((10, 5)) * 10.0;
458459
let masked_matrix = ndarray_mask(matrix.view(), &[true, true, false, true, false]);
459460
close_l2(&masked_matrix.slice(s![.., 2]), &matrix.slice(s![.., 3]), 1e-12);
460461
}
461462

462463
/// Test orthonormalization of a random matrix
463464
#[test]
464465
fn test_orthonormalize() {
465-
let matrix: Array2<f64> = Array2::random((10, 10), Uniform::new(-10., 10.));
466+
let matrix: Array2<f64> = generate::random((10, 10)) * 10.0;
466467

467468
let (n, l) = orthonormalize(matrix.clone()).unwrap();
468469

@@ -483,10 +484,9 @@ mod tests {
483484
assert_symmetric(a);
484485

485486
let n = a.len_of(Axis(0));
486-
let x: Array2<f64> = Array2::random((n, num), Uniform::new(0.0, 1.0));
487+
let x: Array2<f64> = generate::random((n, num));
487488

488489
let result = lobpcg(|y| a.dot(&y), x, |_| {}, None, 1e-5, n * 2, order);
489-
dbg!(&result);
490490
match result {
491491
LobpcgResult::Ok(vals, _, r_norms) | LobpcgResult::Err(vals, _, r_norms, _) => {
492492
// check convergence
@@ -523,7 +523,7 @@ mod tests {
523523
#[test]
524524
fn test_eigsolver_constructed() {
525525
let n = 50;
526-
let tmp = Array2::random((n, n), Uniform::new(0.0, 1.0));
526+
let tmp = generate::random((n, n));
527527
//let (v, _) = tmp.qr_square().unwrap();
528528
let (v, _) = orthonormalize(tmp).unwrap();
529529

@@ -540,7 +540,7 @@ mod tests {
540540
fn test_eigsolver_constrained() {
541541
let diag = arr1(&[1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]);
542542
let a = Array2::from_diag(&diag);
543-
let x: Array2<f64> = Array2::random((10, 1), Uniform::new(0.0, 1.0));
543+
let x: Array2<f64> = generate::random((10, 1));
544544
let y: Array2<f64> = arr2(&[
545545
[1.0, 0., 0., 0., 0., 0., 0., 0., 0., 0.],
546546
[0., 1.0, 0., 0., 0., 0., 0., 0., 0., 0.],

src/lobpcg/svd.rs

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,9 @@
33
///! This module computes the k largest/smallest singular values/vectors for a dense matrix.
44
use super::lobpcg::{lobpcg, LobpcgResult, Order};
55
use crate::error::Result;
6-
use crate::{Lapack, Scalar};
6+
use crate::{Lapack, Scalar, generate};
77
use ndarray::prelude::*;
88
use ndarray::ScalarOperand;
9-
use ndarray_rand::rand_distr::Uniform;
10-
use ndarray_rand::RandomExt;
119
use num_traits::{Float, NumCast};
1210
use std::ops::DivAssign;
1311

@@ -129,7 +127,8 @@ impl<A: Float + Scalar + ScalarOperand + Lapack + PartialOrd + Default> Truncate
129127
let (n, m) = (self.problem.nrows(), self.problem.ncols());
130128

131129
// generate initial matrix
132-
let x = Array2::random((usize::min(n, m), num), Uniform::new(0.0, 1.0)).mapv(|x| NumCast::from(x).unwrap());
130+
let x: Array2<f32> = generate::random((usize::min(n, m), num));
131+
let x = x.mapv(|x| NumCast::from(x).unwrap());
133132

134133
// square precision because the SVD squares the eigenvalue as well
135134
let precision = self.precision * self.precision;
@@ -190,10 +189,9 @@ impl MagnitudeCorrection for f64 {
190189
mod tests {
191190
use super::Order;
192191
use super::TruncatedSvd;
193-
use crate::close_l2;
192+
use crate::{close_l2, generate};
193+
194194
use ndarray::{arr1, arr2, Array2};
195-
use ndarray_rand::rand_distr::Uniform;
196-
use ndarray_rand::RandomExt;
197195

198196
#[test]
199197
fn test_truncated_svd() {
@@ -212,7 +210,7 @@ mod tests {
212210

213211
#[test]
214212
fn test_truncated_svd_random() {
215-
let a: Array2<f64> = Array2::random((50, 10), Uniform::new(0.0, 1.0));
213+
let a: Array2<f64> = generate::random((50, 10));
216214

217215
let res = TruncatedSvd::new(a.clone(), Order::Largest)
218216
.precision(1e-5)

0 commit comments

Comments
 (0)