Skip to content

Commit 2702823

Browse files
authored
Merge pull request #82 from termoshtt/test_solve
Test for solve/solveh modules
2 parents 8988eff + c169853 commit 2702823

File tree

4 files changed

+193
-129
lines changed

4 files changed

+193
-129
lines changed

src/lapack_traits/solveh.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,10 @@ impl Solveh_ for $scalar {
3939
unsafe fn solveh(l: MatrixLayout, uplo: UPLO, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()> {
4040
let (n, _) = l.size();
4141
let nrhs = 1;
42-
let ldb = 1;
42+
let ldb = match l {
43+
MatrixLayout::C(_) => 1,
44+
MatrixLayout::F(_) => n,
45+
};
4346
let info = $trs(l.lapacke_layout(), uplo as u8, n, nrhs, a, l.lda(), ipiv, b, ldb);
4447
into_result(info, ())
4548
}

tests/det.rs

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
extern crate ndarray;
2+
#[macro_use]
3+
extern crate ndarray_linalg;
4+
extern crate num_traits;
5+
6+
use ndarray::*;
7+
use ndarray_linalg::*;
8+
use num_traits::{One, Zero};
9+
10+
/// Returns the matrix with the specified `row` and `col` removed.
11+
fn matrix_minor<A, S>(a: ArrayBase<S, Ix2>, (row, col): (usize, usize)) -> Array2<A>
12+
where
13+
A: Scalar,
14+
S: Data<Elem = A>,
15+
{
16+
let mut select_rows = (0..a.rows()).collect::<Vec<_>>();
17+
select_rows.remove(row);
18+
let mut select_cols = (0..a.cols()).collect::<Vec<_>>();
19+
select_cols.remove(col);
20+
a.select(Axis(0), &select_rows).select(
21+
Axis(1),
22+
&select_cols,
23+
)
24+
}
25+
26+
/// Computes the determinant of matrix `a`.
27+
///
28+
/// Note: This implementation is written to be clearly correct so that it's
29+
/// useful for verification, but it's very inefficient.
30+
fn det_naive<A, S>(a: ArrayBase<S, Ix2>) -> A
31+
where
32+
A: Scalar,
33+
S: Data<Elem = A>,
34+
{
35+
assert_eq!(a.rows(), a.cols());
36+
match a.cols() {
37+
0 => A::one(),
38+
1 => a[(0, 0)],
39+
cols => {
40+
(0..cols)
41+
.map(|col| {
42+
let sign = if col % 2 == 0 { A::one() } else { -A::one() };
43+
sign * a[(0, col)] * det_naive(matrix_minor(a.view(), (0, col)))
44+
})
45+
.fold(A::zero(), |sum, subdet| sum + subdet)
46+
}
47+
}
48+
}
49+
50+
#[test]
51+
fn det_empty() {
52+
macro_rules! det_empty {
53+
($elem:ty) => {
54+
let a: Array2<$elem> = Array2::zeros((0, 0));
55+
assert_eq!(a.factorize().unwrap().det().unwrap(), One::one());
56+
assert_eq!(a.factorize().unwrap().det_into().unwrap(), One::one());
57+
assert_eq!(a.det().unwrap(), One::one());
58+
assert_eq!(a.det_into().unwrap(), One::one());
59+
}
60+
}
61+
det_empty!(f64);
62+
det_empty!(f32);
63+
det_empty!(c64);
64+
det_empty!(c32);
65+
}
66+
67+
#[test]
68+
fn det_zero() {
69+
macro_rules! det_zero {
70+
($elem:ty) => {
71+
let a: Array2<$elem> = Array2::zeros((1, 1));
72+
assert_eq!(a.det().unwrap(), Zero::zero());
73+
assert_eq!(a.det_into().unwrap(), Zero::zero());
74+
}
75+
}
76+
det_zero!(f64);
77+
det_zero!(f32);
78+
det_zero!(c64);
79+
det_zero!(c32);
80+
}
81+
82+
#[test]
83+
fn det_zero_nonsquare() {
84+
macro_rules! det_zero_nonsquare {
85+
($elem:ty, $shape:expr) => {
86+
let a: Array2<$elem> = Array2::zeros($shape);
87+
assert!(a.det().is_err());
88+
assert!(a.det_into().is_err());
89+
}
90+
}
91+
for &shape in &[(1, 2).into_shape(), (1, 2).f()] {
92+
det_zero_nonsquare!(f64, shape);
93+
det_zero_nonsquare!(f32, shape);
94+
det_zero_nonsquare!(c64, shape);
95+
det_zero_nonsquare!(c32, shape);
96+
}
97+
}
98+
99+
#[test]
100+
fn det() {
101+
macro_rules! det {
102+
($elem:ty, $shape:expr, $rtol:expr) => {
103+
let a: Array2<$elem> = random($shape);
104+
println!("a = \n{:?}", a);
105+
let det = det_naive(a.view());
106+
assert_rclose!(a.factorize().unwrap().det().unwrap(), det, $rtol);
107+
assert_rclose!(a.factorize().unwrap().det_into().unwrap(), det, $rtol);
108+
assert_rclose!(a.det().unwrap(), det, $rtol);
109+
assert_rclose!(a.det_into().unwrap(), det, $rtol);
110+
}
111+
}
112+
for rows in 1..5 {
113+
for &shape in &[(rows, rows).into_shape(), (rows, rows).f()] {
114+
det!(f64, shape, 1e-9);
115+
det!(f32, shape, 1e-4);
116+
det!(c64, shape, 1e-9);
117+
det!(c32, shape, 1e-4);
118+
}
119+
}
120+
}
121+
122+
#[test]
123+
fn det_nonsquare() {
124+
macro_rules! det_nonsquare {
125+
($elem:ty, $shape:expr) => {
126+
let a: Array2<$elem> = random($shape);
127+
assert!(a.factorize().unwrap().det().is_err());
128+
assert!(a.factorize().unwrap().det_into().is_err());
129+
assert!(a.det().is_err());
130+
assert!(a.det_into().is_err());
131+
}
132+
}
133+
for &dims in &[(1, 0), (1, 2), (2, 1), (2, 3)] {
134+
for &shape in &[dims.clone().into_shape(), dims.clone().f()] {
135+
det_nonsquare!(f64, shape);
136+
det_nonsquare!(f32, shape);
137+
det_nonsquare!(c64, shape);
138+
det_nonsquare!(c32, shape);
139+
}
140+
}
141+
}

tests/solve.rs

Lines changed: 12 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -5,137 +5,21 @@ extern crate num_traits;
55

66
use ndarray::*;
77
use ndarray_linalg::*;
8-
use num_traits::{One, Zero};
9-
10-
/// Returns the matrix with the specified `row` and `col` removed.
11-
fn matrix_minor<A, S>(a: ArrayBase<S, Ix2>, (row, col): (usize, usize)) -> Array2<A>
12-
where
13-
A: Scalar,
14-
S: Data<Elem = A>,
15-
{
16-
let mut select_rows = (0..a.rows()).collect::<Vec<_>>();
17-
select_rows.remove(row);
18-
let mut select_cols = (0..a.cols()).collect::<Vec<_>>();
19-
select_cols.remove(col);
20-
a.select(Axis(0), &select_rows).select(
21-
Axis(1),
22-
&select_cols,
23-
)
24-
}
25-
26-
/// Computes the determinant of matrix `a`.
27-
///
28-
/// Note: This implementation is written to be clearly correct so that it's
29-
/// useful for verification, but it's very inefficient.
30-
fn det_naive<A, S>(a: ArrayBase<S, Ix2>) -> A
31-
where
32-
A: Scalar,
33-
S: Data<Elem = A>,
34-
{
35-
assert_eq!(a.rows(), a.cols());
36-
match a.cols() {
37-
0 => A::one(),
38-
1 => a[(0, 0)],
39-
cols => {
40-
(0..cols)
41-
.map(|col| {
42-
let sign = if col % 2 == 0 { A::one() } else { -A::one() };
43-
sign * a[(0, col)] * det_naive(matrix_minor(a.view(), (0, col)))
44-
})
45-
.fold(A::zero(), |sum, subdet| sum + subdet)
46-
}
47-
}
48-
}
49-
50-
#[test]
51-
fn det_empty() {
52-
macro_rules! det_empty {
53-
($elem:ty) => {
54-
let a: Array2<$elem> = Array2::zeros((0, 0));
55-
assert_eq!(a.factorize().unwrap().det().unwrap(), One::one());
56-
assert_eq!(a.factorize().unwrap().det_into().unwrap(), One::one());
57-
assert_eq!(a.det().unwrap(), One::one());
58-
assert_eq!(a.det_into().unwrap(), One::one());
59-
}
60-
}
61-
det_empty!(f64);
62-
det_empty!(f32);
63-
det_empty!(c64);
64-
det_empty!(c32);
65-
}
66-
67-
#[test]
68-
fn det_zero() {
69-
macro_rules! det_zero {
70-
($elem:ty) => {
71-
let a: Array2<$elem> = Array2::zeros((1, 1));
72-
assert_eq!(a.det().unwrap(), Zero::zero());
73-
assert_eq!(a.det_into().unwrap(), Zero::zero());
74-
}
75-
}
76-
det_zero!(f64);
77-
det_zero!(f32);
78-
det_zero!(c64);
79-
det_zero!(c32);
80-
}
81-
82-
#[test]
83-
fn det_zero_nonsquare() {
84-
macro_rules! det_zero_nonsquare {
85-
($elem:ty, $shape:expr) => {
86-
let a: Array2<$elem> = Array2::zeros($shape);
87-
assert!(a.det().is_err());
88-
assert!(a.det_into().is_err());
89-
}
90-
}
91-
for &shape in &[(1, 2).into_shape(), (1, 2).f()] {
92-
det_zero_nonsquare!(f64, shape);
93-
det_zero_nonsquare!(f32, shape);
94-
det_zero_nonsquare!(c64, shape);
95-
det_zero_nonsquare!(c32, shape);
96-
}
97-
}
988

999
#[test]
100-
fn det() {
101-
macro_rules! det {
102-
($elem:ty, $shape:expr, $rtol:expr) => {
103-
let a: Array2<$elem> = random($shape);
104-
println!("a = \n{:?}", a);
105-
let det = det_naive(a.view());
106-
assert_rclose!(a.factorize().unwrap().det().unwrap(), det, $rtol);
107-
assert_rclose!(a.factorize().unwrap().det_into().unwrap(), det, $rtol);
108-
assert_rclose!(a.det().unwrap(), det, $rtol);
109-
assert_rclose!(a.det_into().unwrap(), det, $rtol);
110-
}
111-
}
112-
for rows in 1..5 {
113-
for &shape in &[(rows, rows).into_shape(), (rows, rows).f()] {
114-
det!(f64, shape, 1e-9);
115-
det!(f32, shape, 1e-4);
116-
det!(c64, shape, 1e-9);
117-
det!(c32, shape, 1e-4);
118-
}
119-
}
10+
fn solve_random() {
11+
let a: Array2<f64> = random((3, 3));
12+
let x: Array1<f64> = random(3);
13+
let b = a.dot(&x);
14+
let y = a.solve_into(b).unwrap();
15+
assert_close_l2!(&x, &y, 1e-7);
12016
}
12117

12218
#[test]
123-
fn det_nonsquare() {
124-
macro_rules! det_nonsquare {
125-
($elem:ty, $shape:expr) => {
126-
let a: Array2<$elem> = random($shape);
127-
assert!(a.factorize().unwrap().det().is_err());
128-
assert!(a.factorize().unwrap().det_into().is_err());
129-
assert!(a.det().is_err());
130-
assert!(a.det_into().is_err());
131-
}
132-
}
133-
for &dims in &[(1, 0), (1, 2), (2, 1), (2, 3)] {
134-
for &shape in &[dims.clone().into_shape(), dims.clone().f()] {
135-
det_nonsquare!(f64, shape);
136-
det_nonsquare!(f32, shape);
137-
det_nonsquare!(c64, shape);
138-
det_nonsquare!(c32, shape);
139-
}
140-
}
19+
fn solve_random_t() {
20+
let a: Array2<f64> = random((3, 3).f());
21+
let x: Array1<f64> = random(3);
22+
let b = a.dot(&x);
23+
let y = a.solve_into(b).unwrap();
24+
assert_close_l2!(&x, &y, 1e-7);
14125
}

tests/solveh.rs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
2+
extern crate ndarray;
3+
#[macro_use]
4+
extern crate ndarray_linalg;
5+
extern crate num_traits;
6+
7+
use ndarray::*;
8+
use ndarray_linalg::*;
9+
10+
#[test]
11+
fn solveh_random() {
12+
let a: Array2<f64> = random_hpd(3);
13+
let x: Array1<f64> = random(3);
14+
let b = a.dot(&x);
15+
let y = a.solveh_into(b).unwrap();
16+
assert_close_l2!(&x, &y, 1e-7);
17+
18+
let b = a.dot(&x);
19+
let f = a.factorizeh_into().unwrap();
20+
let y = f.solveh_into(b).unwrap();
21+
assert_close_l2!(&x, &y, 1e-7);
22+
}
23+
24+
#[test]
25+
fn solveh_random_t() {
26+
let a: Array2<f64> = random_hpd(3).reversed_axes();
27+
let x: Array1<f64> = random(3);
28+
let b = a.dot(&x);
29+
let y = a.solveh_into(b).unwrap();
30+
assert_close_l2!(&x, &y, 1e-7);
31+
32+
let b = a.dot(&x);
33+
let f = a.factorizeh_into().unwrap();
34+
let y = f.solveh_into(b).unwrap();
35+
assert_close_l2!(&x, &y, 1e-7);
36+
}

0 commit comments

Comments
 (0)