Skip to content

Commit a5ee745

Browse files
authored
Merge pull request #47 from termoshtt/generate
Generate specific matrices
2 parents 1bc8a41 + a56fbc4 commit a5ee745

File tree

10 files changed

+191
-94
lines changed

10 files changed

+191
-94
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ openblas = ["blas/openblas", "lapack/openblas"]
1515
netlib = ["blas/netlib", "lapack/netlib"]
1616

1717
[dependencies]
18+
rand = "0.3"
1819
derive-new = "0.4"
1920
enum-error-derive = "0.1"
2021
num-traits = "0.1"

src/assert.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use std::iter::Sum;
44
use num_traits::Float;
55
use ndarray::*;
66

7+
use super::types::*;
78
use super::vector::*;
89

910
pub fn rclose<A, Tol>(test: A, truth: A, rtol: Tol) -> Result<Tol, Tol>

src/generate.rs

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
2+
use ndarray::*;
3+
use std::ops::*;
4+
use rand::*;
5+
6+
use super::layout::*;
7+
use super::types::*;
8+
use super::error::*;
9+
10+
pub fn conjugate<A, Si, So>(a: &ArrayBase<Si, Ix2>) -> ArrayBase<So, Ix2>
11+
where A: Conjugate,
12+
Si: Data<Elem = A>,
13+
So: DataOwned<Elem = A> + DataMut
14+
{
15+
let mut a = replicate(&a.t());
16+
for val in a.iter_mut() {
17+
*val = Conjugate::conj(*val);
18+
}
19+
a
20+
}
21+
22+
/// Random matrix
23+
pub fn random<A, S>(n: usize, m: usize) -> ArrayBase<S, Ix2>
24+
where A: RandNormal,
25+
S: DataOwned<Elem = A>
26+
{
27+
let mut rng = thread_rng();
28+
let v: Vec<A> = (0..n * m).map(|_| A::randn(&mut rng)).collect();
29+
ArrayBase::from_shape_vec((n, m), v).unwrap()
30+
}
31+
32+
/// Random square matrix
33+
pub fn random_square<A, S>(n: usize) -> ArrayBase<S, Ix2>
34+
where A: RandNormal,
35+
S: DataOwned<Elem = A>
36+
{
37+
random(n, n)
38+
}
39+
40+
/// Random Hermite matrix
41+
pub fn random_hermite<A, S>(n: usize) -> ArrayBase<S, Ix2>
42+
where A: RandNormal + Conjugate + Add<Output = A>,
43+
S: DataOwned<Elem = A> + DataMut
44+
{
45+
let mut a = random_square(n);
46+
for i in 0..n {
47+
a[(i, i)] = a[(i, i)] + Conjugate::conj(a[(i, i)]);
48+
for j in (i + 1)..n {
49+
a[(i, j)] = Conjugate::conj(a[(j, i)])
50+
}
51+
}
52+
a
53+
}
54+
55+
/// Random Hermite Positive-definite matrix
56+
pub fn random_hpd<A, S>(n: usize) -> ArrayBase<S, Ix2>
57+
where A: RandNormal + Conjugate + LinalgScalar,
58+
S: DataOwned<Elem = A> + DataMut
59+
{
60+
let a: Array2<A> = random_square(n);
61+
let ah: Array2<A> = conjugate(&a);
62+
replicate(&ah.dot(&a))
63+
}
64+
65+
/// construct matrix from diag
66+
pub fn from_diag<A>(d: &[A]) -> Array2<A>
67+
where A: LinalgScalar
68+
{
69+
let n = d.len();
70+
let mut e = Array::zeros((n, n));
71+
for i in 0..n {
72+
e[(i, i)] = d[i];
73+
}
74+
e
75+
}
76+
77+
/// stack vectors into matrix horizontally
78+
pub fn hstack<A, S>(xs: &[ArrayBase<S, Ix1>]) -> Result<Array<A, Ix2>>
79+
where A: LinalgScalar,
80+
S: Data<Elem = A>
81+
{
82+
let views: Vec<_> = xs.iter()
83+
.map(|x| {
84+
let n = x.len();
85+
x.view().into_shape((n, 1)).unwrap()
86+
})
87+
.collect();
88+
stack(Axis(1), &views).map_err(|e| e.into())
89+
}
90+
91+
/// stack vectors into matrix vertically
92+
pub fn vstack<A, S>(xs: &[ArrayBase<S, Ix1>]) -> Result<Array<A, Ix2>>
93+
where A: LinalgScalar,
94+
S: Data<Elem = A>
95+
{
96+
let views: Vec<_> = xs.iter()
97+
.map(|x| {
98+
let n = x.len();
99+
x.view().into_shape((1, n)).unwrap()
100+
})
101+
.collect();
102+
stack(Axis(0), &views).map_err(|e| e.into())
103+
}

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ extern crate blas;
3535
extern crate lapack;
3636
extern crate num_traits;
3737
extern crate num_complex;
38+
extern crate rand;
3839
#[macro_use(s)]
3940
extern crate ndarray;
4041
#[macro_use]
@@ -61,5 +62,6 @@ pub mod square;
6162
pub mod triangular;
6263

6364
pub mod util;
65+
pub mod generate;
6466
pub mod assert;
6567
pub mod prelude;

src/prelude.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ pub use matrix::Matrix;
33
pub use square::SquareMatrix;
44
pub use triangular::*;
55
pub use util::*;
6+
pub use types::*;
7+
pub use generate::*;
68
pub use assert::*;
79

810
pub use qr::*;

src/triangular.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use super::impl2::UPLO;
77
use super::matrix::{Matrix, MFloat};
88
use super::square::SquareMatrix;
99
use super::error::LinalgError;
10-
use super::util::hstack;
10+
use super::generate::hstack;
1111
use super::impls::solve::ImplSolve;
1212

1313
pub trait SolveTriangular<Rhs>: Matrix + SquareMatrix {

src/types.rs

Lines changed: 77 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11

22
pub use num_complex::Complex32 as c32;
33
pub use num_complex::Complex64 as c64;
4+
use num_complex::Complex;
45
use num_traits::Float;
56
use std::ops::*;
7+
use rand::Rng;
8+
use rand::distributions::*;
69

710
pub trait AssociatedReal: Sized {
811
type Real: Float + Mul<Self, Output = Self>;
@@ -11,21 +14,91 @@ pub trait AssociatedComplex: Sized {
1114
type Complex;
1215
}
1316

14-
macro_rules! impl_assoc {
17+
/// Field with norm
18+
pub trait Absolute {
19+
type Output: Float;
20+
fn squared(&self) -> Self::Output;
21+
fn abs(&self) -> Self::Output {
22+
self.squared().sqrt()
23+
}
24+
}
25+
26+
pub trait Conjugate: Copy {
27+
fn conj(self) -> Self;
28+
}
29+
30+
pub trait RandNormal {
31+
fn randn<R: Rng>(&mut R) -> Self;
32+
}
33+
34+
macro_rules! impl_traits {
1535
($real:ty, $complex:ty) => {
36+
1637
impl AssociatedReal for $real {
1738
type Real = $real;
1839
}
40+
1941
impl AssociatedReal for $complex {
2042
type Real = $real;
2143
}
44+
2245
impl AssociatedComplex for $real {
2346
type Complex = $complex;
2447
}
48+
2549
impl AssociatedComplex for $complex {
2650
type Complex = $complex;
2751
}
28-
}} // impl_assoc!
2952

30-
impl_assoc!(f64, c64);
31-
impl_assoc!(f32, c32);
53+
impl Absolute for $real {
54+
type Output = Self;
55+
fn squared(&self) -> Self::Output {
56+
*self * *self
57+
}
58+
fn abs(&self) -> Self::Output {
59+
Float::abs(*self)
60+
}
61+
}
62+
63+
impl Absolute for $complex {
64+
type Output = $real;
65+
fn squared(&self) -> Self::Output {
66+
self.norm_sqr()
67+
}
68+
fn abs(&self) -> Self::Output {
69+
self.norm()
70+
}
71+
}
72+
73+
impl Conjugate for $real {
74+
fn conj(self) -> Self {
75+
self
76+
}
77+
}
78+
79+
impl Conjugate for $complex {
80+
fn conj(self) -> Self {
81+
Complex::conj(&self)
82+
}
83+
}
84+
85+
impl RandNormal for $real {
86+
fn randn<R: Rng>(rng: &mut R) -> Self {
87+
let dist = Normal::new(0., 1.);
88+
dist.ind_sample(rng) as $real
89+
}
90+
}
91+
92+
impl RandNormal for $complex {
93+
fn randn<R: Rng>(rng: &mut R) -> Self {
94+
let dist = Normal::new(0., 1.);
95+
let re = dist.ind_sample(rng) as $real;
96+
let im = dist.ind_sample(rng) as $real;
97+
Self::new(re, im)
98+
}
99+
}
100+
101+
}} // impl_traits!
102+
103+
impl_traits!(f64, c64);
104+
impl_traits!(f32, c32);

src/util.rs

Lines changed: 2 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -3,48 +3,10 @@
33
use std::iter::Sum;
44
use ndarray::*;
55
use num_traits::Float;
6-
use super::vector::*;
76
use std::ops::Div;
87

9-
/// construct matrix from diag
10-
pub fn from_diag<A>(d: &[A]) -> Array2<A>
11-
where A: LinalgScalar
12-
{
13-
let n = d.len();
14-
let mut e = Array::zeros((n, n));
15-
for i in 0..n {
16-
e[(i, i)] = d[i];
17-
}
18-
e
19-
}
20-
21-
/// stack vectors into matrix horizontally
22-
pub fn hstack<A, S>(xs: &[ArrayBase<S, Ix1>]) -> Result<Array<A, Ix2>, ShapeError>
23-
where A: LinalgScalar,
24-
S: Data<Elem = A>
25-
{
26-
let views: Vec<_> = xs.iter()
27-
.map(|x| {
28-
let n = x.len();
29-
x.view().into_shape((n, 1)).unwrap()
30-
})
31-
.collect();
32-
stack(Axis(1), &views)
33-
}
34-
35-
/// stack vectors into matrix vertically
36-
pub fn vstack<A, S>(xs: &[ArrayBase<S, Ix1>]) -> Result<Array<A, Ix2>, ShapeError>
37-
where A: LinalgScalar,
38-
S: Data<Elem = A>
39-
{
40-
let views: Vec<_> = xs.iter()
41-
.map(|x| {
42-
let n = x.len();
43-
x.view().into_shape((1, n)).unwrap()
44-
})
45-
.collect();
46-
stack(Axis(0), &views)
47-
}
8+
use super::types::*;
9+
use super::vector::*;
4810

4911
pub enum NormalizeAxis {
5012
Row = 0,

src/vector.rs

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -43,40 +43,3 @@ impl<A, S, D, T> Norm for ArrayBase<S, D>
4343
})
4444
}
4545
}
46-
47-
/// Field with norm
48-
pub trait Absolute {
49-
type Output: Float;
50-
fn squared(&self) -> Self::Output;
51-
fn abs(&self) -> Self::Output {
52-
self.squared().sqrt()
53-
}
54-
}
55-
56-
macro_rules! impl_abs {
57-
($f:ty, $c:ty) => {
58-
59-
impl Absolute for $f {
60-
type Output = Self;
61-
fn squared(&self) -> Self::Output {
62-
*self * *self
63-
}
64-
fn abs(&self) -> Self::Output {
65-
Float::abs(*self)
66-
}
67-
}
68-
69-
impl Absolute for $c {
70-
type Output = $f;
71-
fn squared(&self) -> Self::Output {
72-
self.norm_sqr()
73-
}
74-
fn abs(&self) -> Self::Output {
75-
self.norm()
76-
}
77-
}
78-
79-
}} // impl_abs!
80-
81-
impl_abs!(f64, c64);
82-
impl_abs!(f32, c32);

tests/cholesky.rs

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,14 @@
11

2-
extern crate rand_extra;
32
extern crate ndarray;
4-
extern crate ndarray_rand;
53
#[macro_use]
64
extern crate ndarray_linalg;
75

8-
use rand_extra::*;
96
use ndarray::*;
10-
use ndarray_rand::RandomExt;
117
use ndarray_linalg::prelude::*;
128

13-
pub fn random_hermite(n: usize) -> Array<f64, Ix2> {
14-
let r_dist = RealNormal::new(0., 1.);
15-
let a = Array::<f64, _>::random((n, n), r_dist);
16-
a.dot(&a.t())
17-
}
18-
199
#[test]
2010
fn cholesky() {
21-
let a = random_hermite(3);
11+
let a: Array2<f64> = random_hpd(3);
2212
println!("a = \n{:?}", a);
2313
let c: Array2<_> = (&a).cholesky(UPLO::Upper).unwrap();
2414
println!("c = \n{:?}", c);
@@ -28,7 +18,7 @@ fn cholesky() {
2818

2919
#[test]
3020
fn cholesky_t() {
31-
let a = random_hermite(3);
21+
let a: Array2<f64> = random_hpd(3);
3222
println!("a = \n{:?}", a);
3323
let c: Array2<_> = (&a).cholesky(UPLO::Upper).unwrap();
3424
println!("c = \n{:?}", c);

0 commit comments

Comments
 (0)