Skip to content

Commit 4c44542

Browse files
authored
Merge pull request #21 from termoshtt/numtest
Refactoring tests using ndarray-numtest
2 parents fa74eed + e883fd6 commit 4c44542

File tree

19 files changed

+150
-314
lines changed

19 files changed

+150
-314
lines changed

Cargo.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,5 @@ version = "0.7"
1818
features = ["blas"]
1919

2020
[dev-dependencies]
21-
rand = "0.3.14"
2221
ndarray-rand = "0.3"
23-
float-cmp = "0.2.3"
22+
ndarray-numtest = "0.1.4"

src/eigh.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use num_traits::Zero;
66
use error::LapackError;
77

88
pub trait ImplEigh: Sized {
9-
fn eigh(layout: Layout, n: usize, mut a: Vec<Self>) -> Result<(Vec<Self>, Vec<Self>), LapackError>;
9+
fn eigh(layout: Layout, n: usize, a: Vec<Self>) -> Result<(Vec<Self>, Vec<Self>), LapackError>;
1010
}
1111

1212
macro_rules! impl_eigh {

src/norm.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
use lapack::c::*;
44

55
pub trait ImplNorm: Sized {
6-
fn norm_1(m: usize, n: usize, mut a: Vec<Self>) -> Self;
7-
fn norm_i(m: usize, n: usize, mut a: Vec<Self>) -> Self;
8-
fn norm_f(m: usize, n: usize, mut a: Vec<Self>) -> Self;
6+
fn norm_1(m: usize, n: usize, a: Vec<Self>) -> Self;
7+
fn norm_i(m: usize, n: usize, a: Vec<Self>) -> Self;
8+
fn norm_f(m: usize, n: usize, a: Vec<Self>) -> Self;
99
}
1010

1111
macro_rules! impl_norm {

src/prelude.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@ pub use vector::Vector;
33
pub use matrix::Matrix;
44
pub use square::SquareMatrix;
55
pub use hermite::HermiteMatrix;
6-
pub use triangular::TriangularMatrix;
6+
pub use triangular::{TriangularMatrix, drop_lower, drop_upper};

src/qr.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use num_traits::Zero;
77
use error::LapackError;
88

99
pub trait ImplQR: Sized {
10-
fn qr(layout: Layout, n: usize, m: usize, mut a: Vec<Self>) -> Result<(Vec<Self>, Vec<Self>), LapackError>;
10+
fn qr(layout: Layout, n: usize, m: usize, a: Vec<Self>) -> Result<(Vec<Self>, Vec<Self>), LapackError>;
1111
}
1212

1313
macro_rules! impl_qr {

src/svd.rs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,7 @@ use num_traits::Zero;
77
use error::LapackError;
88

99
pub trait ImplSVD: Sized {
10-
fn svd(layout: Layout,
11-
n: usize,
12-
m: usize,
13-
mut a: Vec<Self>)
14-
-> Result<(Vec<Self>, Vec<Self>, Vec<Self>), LapackError>;
10+
fn svd(layout: Layout, n: usize, m: usize, a: Vec<Self>) -> Result<(Vec<Self>, Vec<Self>, Vec<Self>), LapackError>;
1511
}
1612

1713
macro_rules! impl_svd {

src/triangular.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,21 @@ impl<A> TriangularMatrix for Array<A, Ix2>
3838
Ok(Array::from_vec(x))
3939
}
4040
}
41+
42+
pub fn drop_upper(mut a: Array<f64, Ix2>) -> Array<f64, Ix2> {
43+
for ((i, j), val) in a.indexed_iter_mut() {
44+
if i < j {
45+
*val = 0.0;
46+
}
47+
}
48+
a
49+
}
50+
51+
pub fn drop_lower(mut a: Array<f64, Ix2>) -> Array<f64, Ix2> {
52+
for ((i, j), val) in a.indexed_iter_mut() {
53+
if i > j {
54+
*val = 0.0;
55+
}
56+
}
57+
a
58+
}

tests/cholesky.rs

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,25 @@
1-
2-
extern crate rand;
3-
extern crate ndarray;
4-
extern crate ndarray_rand;
5-
extern crate ndarray_linalg;
6-
7-
use rand::distributions::*;
8-
use ndarray::prelude::*;
9-
use ndarray_linalg::prelude::*;
10-
use ndarray_rand::RandomExt;
11-
12-
fn all_close(a: Array<f64, Ix2>, b: Array<f64, Ix2>) {
13-
if !a.all_close(&b, 1.0e-7) {
14-
panic!("\nTwo matrices are not equal:\na = \n{:?}\nb = \n{:?}\n",
15-
a,
16-
b);
17-
}
18-
}
1+
include!("header.rs");
192

203
#[test]
214
fn cholesky() {
22-
let r_dist = Range::new(0., 1.);
5+
let r_dist = RealNormal::new(0., 1.);
236
let mut a = Array::<f64, _>::random((3, 3), r_dist);
247
a = a.dot(&a.t());
258
println!("a = \n{:?}", a);
269
let c = a.clone().cholesky().unwrap();
2710
println!("c = \n{:?}", c);
2811
println!("cc = \n{:?}", c.t().dot(&c));
29-
all_close(c.t().dot(&c), a);
12+
c.t().dot(&c).assert_allclose_l2(&a, 1e-7);
3013
}
3114

3215
#[test]
3316
fn cholesky_t() {
34-
let r_dist = Range::new(0., 1.);
17+
let r_dist = RealNormal::new(0., 1.);
3518
let mut a = Array::<f64, _>::random((3, 3), r_dist);
3619
a = a.dot(&a.t()).reversed_axes();
3720
println!("a = \n{:?}", a);
3821
let c = a.clone().cholesky().unwrap();
3922
println!("c = \n{:?}", c);
4023
println!("cc = \n{:?}", c.t().dot(&c));
41-
all_close(c.t().dot(&c), a);
24+
c.t().dot(&c).assert_allclose_l2(&a, 1e-7);
4225
}

tests/det.rs

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,7 @@
1-
2-
extern crate ndarray;
3-
extern crate ndarray_linalg;
4-
extern crate ndarray_rand;
5-
extern crate rand;
6-
extern crate float_cmp;
7-
8-
use ndarray::prelude::*;
9-
use ndarray_linalg::prelude::*;
10-
use rand::distributions::*;
11-
use ndarray_rand::RandomExt;
12-
use float_cmp::ApproxEqRatio;
13-
14-
fn approx_eq(val: f64, truth: f64, ratio: f64) {
15-
if !val.approx_eq_ratio(&truth, ratio) {
16-
panic!("Not almost equal! val={:?}, truth={:?}, ratio={:?}",
17-
val,
18-
truth,
19-
ratio);
20-
}
21-
}
1+
include!("header.rs");
222

233
fn random_hermite(n: usize) -> Array<f64, Ix2> {
24-
let r_dist = Range::new(0., 1.);
4+
let r_dist = RealNormal::new(0., 1.);
255
let a = Array::<f64, _>::random((n, n), r_dist);
266
a.dot(&a.t())
277
}
@@ -32,5 +12,5 @@ fn deth() {
3212
let (e, _) = a.clone().eigh().unwrap();
3313
let deth = a.clone().deth().unwrap();
3414
let det_eig = e.iter().fold(1.0, |x, y| x * y);
35-
approx_eq(deth, det_eig, 1.0e-7);
15+
deth.assert_close(det_eig, 1.0e-7);
3616
}

tests/eigh.rs

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,4 @@
1-
2-
extern crate ndarray;
3-
extern crate ndarray_linalg;
4-
5-
use ndarray::prelude::*;
6-
use ndarray_linalg::prelude::*;
7-
8-
fn assert_almost_eq(a: f64, b: f64) {
9-
let rel_dev = (a - b).abs() / (a.abs() + b.abs());
10-
if rel_dev > 1.0e-7 {
11-
panic!("a={:?}, b={:?} are not almost equal", a, b);
12-
}
13-
}
14-
1+
include!("header.rs");
152

163
#[test]
174
fn eigen_vector_manual() {
@@ -31,6 +18,6 @@ fn diagonalize() {
3118
let (e, vecs) = a.clone().eigh().unwrap();
3219
let s = vecs.t().dot(&a).dot(&vecs);
3320
for i in 0..3 {
34-
assert_almost_eq(e[i], s[(i, i)]);
21+
e[i].assert_close(s[(i, i)], 1e-7);
3522
}
3623
}

0 commit comments

Comments
 (0)