Skip to content

Commit d6684d7

Browse files
committed
Use random_regular matrix for det test
1 parent 2cd31e3 commit d6684d7

File tree

3 files changed

+63
-38
lines changed

3 files changed

+63
-38
lines changed

src/generate.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,19 @@ where
5858
q.dot(&r)
5959
}
6060

61+
/// Generate random regular matrix
62+
pub fn random_regular_t<A>(n: usize) -> Array2<A>
63+
where
64+
A: Scalar + RandNormal,
65+
{
66+
let a: Array2<A> = random((n, n).f());
67+
let (q, mut r) = a.qr_into().unwrap();
68+
for i in 0..n {
69+
r[(i, i)] = A::from_f64(1.0) + AssociatedReal::inject(r[(i, i)].abs());
70+
}
71+
q.dot(&r).t().to_owned()
72+
}
73+
6174
/// Random Hermite matrix
6275
pub fn random_hermite<A, S>(n: usize) -> ArrayBase<S, Ix2>
6376
where

tests/det.rs

Lines changed: 40 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -100,46 +100,48 @@ fn det_zero_nonsquare() {
100100

101101
#[test]
102102
fn det() {
103-
macro_rules! det {
104-
($elem:ty, $shape:expr, $rtol:expr) => {
105-
let a: Array2<$elem> = random($shape);
106-
println!("a = \n{:?}", a);
107-
let det = det_naive(&a);
108-
let sign = det.div_real(det.abs());
109-
let ln_det = det.abs().ln();
110-
assert_rclose!(a.factorize().unwrap().det().unwrap(), det, $rtol);
111-
{
112-
let result = a.factorize().unwrap().sln_det().unwrap();
113-
assert_rclose!(result.0, sign, $rtol);
114-
assert_rclose!(result.1, ln_det, $rtol);
115-
}
116-
assert_rclose!(a.factorize().unwrap().det_into().unwrap(), det, $rtol);
117-
{
118-
let result = a.factorize().unwrap().sln_det_into().unwrap();
119-
assert_rclose!(result.0, sign, $rtol);
120-
assert_rclose!(result.1, ln_det, $rtol);
121-
}
122-
assert_rclose!(a.det().unwrap(), det, $rtol);
123-
{
124-
let result = a.sln_det().unwrap();
125-
assert_rclose!(result.0, sign, $rtol);
126-
assert_rclose!(result.1, ln_det, $rtol);
127-
}
128-
assert_rclose!(a.clone().det_into().unwrap(), det, $rtol);
129-
{
130-
let result = a.sln_det_into().unwrap();
131-
assert_rclose!(result.0, sign, $rtol);
132-
assert_rclose!(result.1, ln_det, $rtol);
133-
}
134-
};
103+
fn det_impl<A, Tol>(a: Array2<A>, rtol: Tol)
104+
where
105+
A: Scalar<Real = Tol>,
106+
Tol: RealScalar<Real = Tol>,
107+
{
108+
let det = det_naive(&a);
109+
let sign = det.div_real(det.abs());
110+
let ln_det = det.abs().ln();
111+
assert_rclose!(a.factorize().unwrap().det().unwrap(), det, rtol);
112+
{
113+
let result = a.factorize().unwrap().sln_det().unwrap();
114+
assert_rclose!(result.0, sign, rtol);
115+
assert_rclose!(result.1, ln_det, rtol);
116+
}
117+
assert_rclose!(a.factorize().unwrap().det_into().unwrap(), det, rtol);
118+
{
119+
let result = a.factorize().unwrap().sln_det_into().unwrap();
120+
assert_rclose!(result.0, sign, rtol);
121+
assert_rclose!(result.1, ln_det, rtol);
122+
}
123+
assert_rclose!(a.det().unwrap(), det, rtol);
124+
{
125+
let result = a.sln_det().unwrap();
126+
assert_rclose!(result.0, sign, rtol);
127+
assert_rclose!(result.1, ln_det, rtol);
128+
}
129+
assert_rclose!(a.clone().det_into().unwrap(), det, rtol);
130+
{
131+
let result = a.sln_det_into().unwrap();
132+
assert_rclose!(result.0, sign, rtol);
133+
assert_rclose!(result.1, ln_det, rtol);
134+
}
135135
}
136136
for rows in 1..5 {
137-
for &shape in &[(rows, rows).into_shape(), (rows, rows).f()] {
138-
det!(f64, shape, 1e-9);
139-
det!(f32, shape, 1e-4);
140-
det!(c64, shape, 1e-9);
141-
det!(c32, shape, 1e-4);
142-
}
137+
det_impl(random_regular::<f64>(rows), 1e-9);
138+
det_impl(random_regular::<f32>(rows), 1e-4);
139+
det_impl(random_regular::<c64>(rows), 1e-9);
140+
det_impl(random_regular::<c32>(rows), 1e-4);
141+
det_impl(random_regular_t::<f64>(rows), 1e-9);
142+
det_impl(random_regular_t::<f32>(rows), 1e-4);
143+
det_impl(random_regular_t::<c64>(rows), 1e-9);
144+
det_impl(random_regular_t::<c32>(rows), 1e-4);
143145
}
144146
}
145147

tests/generate.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
use ndarray::*;
2+
use ndarray_linalg::*;
3+
4+
#[test]
5+
fn random_regular_transpose() {
6+
let a: Array2<f32> = random_regular(3);
7+
assert!(a.is_standard_layout());
8+
let a: Array2<f32> = random_regular_t(3);
9+
assert!(!a.is_standard_layout());
10+
}

0 commit comments

Comments
 (0)