Skip to content

Commit 0c28823

Browse files
committed
Add randn(&mut rng)
1 parent 7917fd3 commit 0c28823

File tree

2 files changed

+27
-5
lines changed

2 files changed

+27
-5
lines changed

src/generate.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11

22
use ndarray::*;
3-
use rand::*;
43
use std::ops::*;
4+
use rand::*;
55

66
use super::layout::*;
77
use super::types::*;
@@ -21,17 +21,17 @@ pub fn conjugate<A, Si, So>(a: &ArrayBase<Si, Ix2>) -> ArrayBase<So, Ix2>
2121

2222
/// Random square matrix
2323
pub fn random_square<A, S>(n: usize) -> ArrayBase<S, Ix2>
24-
where A: Rand,
24+
where A: RandNormal,
2525
S: DataOwned<Elem = A>
2626
{
2727
let mut rng = thread_rng();
28-
let v: Vec<A> = (0..n * n).map(|_| rng.gen()).collect();
28+
let v: Vec<A> = (0..n * n).map(|_| A::randn(&mut rng)).collect();
2929
ArrayBase::from_shape_vec((n, n), v).unwrap()
3030
}
3131

3232
/// Random Hermite matrix
3333
pub fn random_hermite<A, S>(n: usize) -> ArrayBase<S, Ix2>
34-
where A: Rand + Conjugate + Add<Output = A>,
34+
where A: RandNormal + Conjugate + Add<Output = A>,
3535
S: DataOwned<Elem = A> + DataMut
3636
{
3737
let mut a = random_square(n);
@@ -46,7 +46,7 @@ pub fn random_hermite<A, S>(n: usize) -> ArrayBase<S, Ix2>
4646

4747
/// Random Hermite Positive-definite matrix
4848
pub fn random_hpd<A, S>(n: usize) -> ArrayBase<S, Ix2>
49-
where A: Rand + Conjugate + LinalgScalar,
49+
where A: RandNormal + Conjugate + LinalgScalar,
5050
S: DataOwned<Elem = A> + DataMut
5151
{
5252
let a: Array2<A> = random_square(n);

src/types.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ pub use num_complex::Complex64 as c64;
44
use num_complex::Complex;
55
use num_traits::Float;
66
use std::ops::*;
7+
use rand::Rng;
8+
use rand::distributions::*;
79

810
pub trait AssociatedReal: Sized {
911
type Real: Float + Mul<Self, Output = Self>;
@@ -25,6 +27,10 @@ pub trait Conjugate: Copy {
2527
fn conj(self) -> Self;
2628
}
2729

30+
pub trait RandNormal {
31+
fn randn<R: Rng>(&mut R) -> Self;
32+
}
33+
2834
macro_rules! impl_traits {
2935
($real:ty, $complex:ty) => {
3036

@@ -76,6 +82,22 @@ impl Conjugate for $complex {
7682
}
7783
}
7884

85+
impl RandNormal for $real {
86+
fn randn<R: Rng>(rng: &mut R) -> Self {
87+
let dist = Normal::new(0., 1.);
88+
dist.ind_sample(rng) as $real
89+
}
90+
}
91+
92+
impl RandNormal for $complex {
93+
fn randn<R: Rng>(rng: &mut R) -> Self {
94+
let dist = Normal::new(0., 1.);
95+
let re = dist.ind_sample(rng) as $real;
96+
let im = dist.ind_sample(rng) as $real;
97+
Self::new(re, im)
98+
}
99+
}
100+
79101
}} // impl_traits!
80102

81103
impl_traits!(f64, c64);

0 commit comments

Comments
 (0)