Skip to content

Commit 7917fd3

Browse files
committed
Add random_hpd
1 parent ae27f6f commit 7917fd3

File tree

3 files changed

+31
-14
lines changed

3 files changed

+31
-14
lines changed

src/generate.rs

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,23 @@ use ndarray::*;
33
use rand::*;
44
use std::ops::*;
55

6+
use super::layout::*;
67
use super::types::*;
78
use super::error::*;
89

10+
pub fn conjugate<A, Si, So>(a: &ArrayBase<Si, Ix2>) -> ArrayBase<So, Ix2>
11+
where A: Conjugate,
12+
Si: Data<Elem = A>,
13+
So: DataOwned<Elem = A> + DataMut
14+
{
15+
let mut a = replicate(&a.t());
16+
for val in a.iter_mut() {
17+
*val = Conjugate::conj(*val);
18+
}
19+
a
20+
}
21+
22+
/// Random square matrix
923
pub fn random_square<A, S>(n: usize) -> ArrayBase<S, Ix2>
1024
where A: Rand,
1125
S: DataOwned<Elem = A>
@@ -15,20 +29,31 @@ pub fn random_square<A, S>(n: usize) -> ArrayBase<S, Ix2>
1529
ArrayBase::from_shape_vec((n, n), v).unwrap()
1630
}
1731

32+
/// Random Hermite matrix
1833
pub fn random_hermite<A, S>(n: usize) -> ArrayBase<S, Ix2>
1934
where A: Rand + Conjugate + Add<Output = A>,
2035
S: DataOwned<Elem = A> + DataMut
2136
{
2237
let mut a = random_square(n);
2338
for i in 0..n {
2439
a[(i, i)] = a[(i, i)] + Conjugate::conj(a[(i, i)]);
25-
for j in i..n {
26-
a[(i, j)] = a[(j, i)];
40+
for j in (i + 1)..n {
41+
a[(i, j)] = Conjugate::conj(a[(j, i)])
2742
}
2843
}
2944
a
3045
}
3146

47+
/// Random Hermite Positive-definite matrix
48+
pub fn random_hpd<A, S>(n: usize) -> ArrayBase<S, Ix2>
49+
where A: Rand + Conjugate + LinalgScalar,
50+
S: DataOwned<Elem = A> + DataMut
51+
{
52+
let a: Array2<A> = random_square(n);
53+
let ah: Array2<A> = conjugate(&a);
54+
replicate(&ah.dot(&a))
55+
}
56+
3257
/// construct matrix from diag
3358
pub fn from_diag<A>(d: &[A]) -> Array2<A>
3459
where A: LinalgScalar

src/prelude.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ pub use matrix::Matrix;
33
pub use square::SquareMatrix;
44
pub use triangular::*;
55
pub use util::*;
6+
pub use types::*;
7+
pub use generate::*;
68
pub use assert::*;
79

810
pub use qr::*;

tests/cholesky.rs

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,14 @@
11

2-
extern crate rand_extra;
32
extern crate ndarray;
4-
extern crate ndarray_rand;
53
#[macro_use]
64
extern crate ndarray_linalg;
75

8-
use rand_extra::*;
96
use ndarray::*;
10-
use ndarray_rand::RandomExt;
117
use ndarray_linalg::prelude::*;
128

13-
pub fn random_hermite(n: usize) -> Array<f64, Ix2> {
14-
let r_dist = RealNormal::new(0., 1.);
15-
let a = Array::<f64, _>::random((n, n), r_dist);
16-
a.dot(&a.t())
17-
}
18-
199
#[test]
2010
fn cholesky() {
21-
let a = random_hermite(3);
11+
let a: Array2<f64> = random_hpd(3);
2212
println!("a = \n{:?}", a);
2313
let c: Array2<_> = (&a).cholesky(UPLO::Upper).unwrap();
2414
println!("c = \n{:?}", c);
@@ -28,7 +18,7 @@ fn cholesky() {
2818

2919
#[test]
3020
fn cholesky_t() {
31-
let a = random_hermite(3);
21+
let a: Array2<f64> = random_hpd(3);
3222
println!("a = \n{:?}", a);
3323
let c: Array2<_> = (&a).cholesky(UPLO::Upper).unwrap();
3424
println!("c = \n{:?}", c);

0 commit comments

Comments
 (0)