Skip to content

Commit 3a8f0c5

Browse files
authored
Merge pull request #296 from jturner314/fix-solve_h-complex
Fix Solve::solve_h_* for complex inputs with standard layout
2 parents 15bda91 + ffe65cb commit 3a8f0c5

File tree

2 files changed

+209
-32
lines changed

2 files changed

+209
-32
lines changed

lax/src/solve.rs

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,18 +75,49 @@ macro_rules! impl_solve {
7575
ipiv: &Pivot,
7676
b: &mut [Self],
7777
) -> Result<()> {
78-
let t = match l {
78+
// If the array has C layout, then it needs to be handled
79+
// specially, since LAPACK expects a Fortran-layout array.
80+
// Reinterpreting a C layout array as Fortran layout is
81+
// equivalent to transposing it. So, we can handle the "no
82+
// transpose" and "transpose" cases by swapping to "transpose"
83+
// or "no transpose", respectively. For the "Hermite" case, we
84+
// can take advantage of the following:
85+
//
86+
// ```text
87+
// A^H x = b
88+
// ⟺ conj(A^T) x = b
89+
// ⟺ conj(conj(A^T) x) = conj(b)
90+
// ⟺ conj(conj(A^T)) conj(x) = conj(b)
91+
// ⟺ A^T conj(x) = conj(b)
92+
// ```
93+
//
94+
// So, we can handle this case by switching to "no transpose"
95+
// (which is equivalent to transposing the array since it will
96+
// be reinterpreted as Fortran layout) and applying the
97+
// elementwise conjugate to `x` and `b`.
98+
let (t, conj) = match l {
7999
MatrixLayout::C { .. } => match t {
80-
Transpose::No => Transpose::Transpose,
81-
Transpose::Transpose | Transpose::Hermite => Transpose::No,
100+
Transpose::No => (Transpose::Transpose, false),
101+
Transpose::Transpose => (Transpose::No, false),
102+
Transpose::Hermite => (Transpose::No, true),
82103
},
83-
_ => t,
104+
MatrixLayout::F { .. } => (t, false),
84105
};
85106
let (n, _) = l.size();
86107
let nrhs = 1;
87108
let ldb = l.lda();
88109
let mut info = 0;
110+
if conj {
111+
for b_elem in &mut *b {
112+
*b_elem = b_elem.conj();
113+
}
114+
}
89115
unsafe { $getrs(t as u8, n, nrhs, a, l.lda(), ipiv, b, ldb, &mut info) };
116+
if conj {
117+
for b_elem in &mut *b {
118+
*b_elem = b_elem.conj();
119+
}
120+
}
90121
info.as_lapack_result()?;
91122
Ok(())
92123
}

ndarray-linalg/tests/solve.rs

Lines changed: 174 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,188 @@
1-
use ndarray::*;
2-
use ndarray_linalg::*;
1+
use ndarray::prelude::*;
2+
use ndarray_linalg::{
3+
assert_aclose, assert_close_l2, c32, c64, random, random_hpd, solve::*, OperationNorm, Scalar,
4+
};
5+
6+
macro_rules! test_solve {
7+
(
8+
[$($elem_type:ty => $rtol:expr),*],
9+
$a_ident:ident = $a:expr,
10+
$x_ident:ident = $x:expr,
11+
b = $b:expr,
12+
$solve:ident,
13+
) => {
14+
$({
15+
let $a_ident: Array2<$elem_type> = $a;
16+
let $x_ident: Array1<$elem_type> = $x;
17+
let b: Array1<$elem_type> = $b;
18+
let a = $a_ident;
19+
let x = $x_ident;
20+
let rtol = $rtol;
21+
assert_close_l2!(&a.$solve(&b).unwrap(), &x, rtol);
22+
assert_close_l2!(&a.factorize().unwrap().$solve(&b).unwrap(), &x, rtol);
23+
assert_close_l2!(&a.factorize_into().unwrap().$solve(&b).unwrap(), &x, rtol);
24+
})*
25+
};
26+
}
27+
28+
macro_rules! test_solve_into {
29+
(
30+
[$($elem_type:ty => $rtol:expr),*],
31+
$a_ident:ident = $a:expr,
32+
$x_ident:ident = $x:expr,
33+
b = $b:expr,
34+
$solve_into:ident,
35+
) => {
36+
$({
37+
let $a_ident: Array2<$elem_type> = $a;
38+
let $x_ident: Array1<$elem_type> = $x;
39+
let b: Array1<$elem_type> = $b;
40+
let a = $a_ident;
41+
let x = $x_ident;
42+
let rtol = $rtol;
43+
assert_close_l2!(&a.$solve_into(b.clone()).unwrap(), &x, rtol);
44+
assert_close_l2!(&a.factorize().unwrap().$solve_into(b.clone()).unwrap(), &x, rtol);
45+
assert_close_l2!(&a.factorize_into().unwrap().$solve_into(b.clone()).unwrap(), &x, rtol);
46+
})*
47+
};
48+
}
49+
50+
macro_rules! test_solve_inplace {
51+
(
52+
[$($elem_type:ty => $rtol:expr),*],
53+
$a_ident:ident = $a:expr,
54+
$x_ident:ident = $x:expr,
55+
b = $b:expr,
56+
$solve_inplace:ident,
57+
) => {
58+
$({
59+
let $a_ident: Array2<$elem_type> = $a;
60+
let $x_ident: Array1<$elem_type> = $x;
61+
let b: Array1<$elem_type> = $b;
62+
let a = $a_ident;
63+
let x = $x_ident;
64+
let rtol = $rtol;
65+
{
66+
let mut b = b.clone();
67+
assert_close_l2!(&a.$solve_inplace(&mut b).unwrap(), &x, rtol);
68+
assert_close_l2!(&b, &x, rtol);
69+
}
70+
{
71+
let mut b = b.clone();
72+
assert_close_l2!(&a.factorize().unwrap().$solve_inplace(&mut b).unwrap(), &x, rtol);
73+
assert_close_l2!(&b, &x, rtol);
74+
}
75+
{
76+
let mut b = b.clone();
77+
assert_close_l2!(&a.factorize_into().unwrap().$solve_inplace(&mut b).unwrap(), &x, rtol);
78+
assert_close_l2!(&b, &x, rtol);
79+
}
80+
})*
81+
};
82+
}
83+
84+
macro_rules! test_solve_all {
85+
(
86+
[$($elem_type:ty => $rtol:expr),*],
87+
$a_ident:ident = $a:expr,
88+
$x_ident:ident = $x:expr,
89+
b = $b:expr,
90+
[$solve:ident, $solve_into:ident, $solve_inplace:ident],
91+
) => {
92+
test_solve!([$($elem_type => $rtol),*], $a_ident = $a, $x_ident = $x, b = $b, $solve,);
93+
test_solve_into!([$($elem_type => $rtol),*], $a_ident = $a, $x_ident = $x, b = $b, $solve_into,);
94+
test_solve_inplace!([$($elem_type => $rtol),*], $a_ident = $a, $x_ident = $x, b = $b, $solve_inplace,);
95+
};
96+
}
97+
98+
#[test]
99+
fn solve_random_float() {
100+
for n in 0..=8 {
101+
for &set_f in &[false, true] {
102+
test_solve_all!(
103+
[f32 => 1e-3, f64 => 1e-9],
104+
a = random([n; 2].set_f(set_f)),
105+
x = random(n),
106+
b = a.dot(&x),
107+
[solve, solve_into, solve_inplace],
108+
);
109+
}
110+
}
111+
}
112+
113+
#[test]
114+
fn solve_random_complex() {
115+
for n in 0..=8 {
116+
for &set_f in &[false, true] {
117+
test_solve_all!(
118+
[c32 => 1e-3, c64 => 1e-9],
119+
a = random([n; 2].set_f(set_f)),
120+
x = random(n),
121+
b = a.dot(&x),
122+
[solve, solve_into, solve_inplace],
123+
);
124+
}
125+
}
126+
}
3127

4128
#[test]
5-
fn solve_random() {
6-
let a: Array2<f64> = random((3, 3));
7-
let x: Array1<f64> = random(3);
8-
let b = a.dot(&x);
9-
let y = a.solve_into(b).unwrap();
10-
assert_close_l2!(&x, &y, 1e-7);
129+
fn solve_t_random_float() {
130+
for n in 0..=8 {
131+
for &set_f in &[false, true] {
132+
test_solve_all!(
133+
[f32 => 1e-3, f64 => 1e-9],
134+
a = random([n; 2].set_f(set_f)),
135+
x = random(n),
136+
b = a.t().dot(&x),
137+
[solve_t, solve_t_into, solve_t_inplace],
138+
);
139+
}
140+
}
11141
}
12142

13143
#[test]
14-
fn solve_random_t() {
15-
let a: Array2<f64> = random((3, 3).f());
16-
let x: Array1<f64> = random(3);
17-
let b = a.dot(&x);
18-
let y = a.solve_into(b).unwrap();
19-
assert_close_l2!(&x, &y, 1e-7);
144+
fn solve_t_random_complex() {
145+
for n in 0..=8 {
146+
for &set_f in &[false, true] {
147+
test_solve_all!(
148+
[c32 => 1e-3, c64 => 1e-9],
149+
a = random([n; 2].set_f(set_f)),
150+
x = random(n),
151+
b = a.t().dot(&x),
152+
[solve_t, solve_t_into, solve_t_inplace],
153+
);
154+
}
155+
}
20156
}
21157

22158
#[test]
23-
fn solve_factorized() {
24-
let a: Array2<f64> = random((3, 3));
25-
let ans: Array1<f64> = random(3);
26-
let b = a.dot(&ans);
27-
let f = a.factorize_into().unwrap();
28-
let x = f.solve_into(b).unwrap();
29-
assert_close_l2!(&x, &ans, 1e-7);
159+
fn solve_h_random_float() {
160+
for n in 0..=8 {
161+
for &set_f in &[false, true] {
162+
test_solve_all!(
163+
[f32 => 1e-3, f64 => 1e-9],
164+
a = random([n; 2].set_f(set_f)),
165+
x = random(n),
166+
b = a.t().mapv(|x| x.conj()).dot(&x),
167+
[solve_h, solve_h_into, solve_h_inplace],
168+
);
169+
}
170+
}
30171
}
31172

32173
#[test]
33-
fn solve_factorized_t() {
34-
let a: Array2<f64> = random((3, 3).f());
35-
let ans: Array1<f64> = random(3);
36-
let b = a.dot(&ans);
37-
let f = a.factorize_into().unwrap();
38-
let x = f.solve_into(b).unwrap();
39-
assert_close_l2!(&x, &ans, 1e-7);
174+
fn solve_h_random_complex() {
175+
for n in 0..=8 {
176+
for &set_f in &[false, true] {
177+
test_solve_all!(
178+
[c32 => 1e-3, c64 => 1e-9],
179+
a = random([n; 2].set_f(set_f)),
180+
x = random(n),
181+
b = a.t().mapv(|x| x.conj()).dot(&x),
182+
[solve_h, solve_h_into, solve_h_inplace],
183+
);
184+
}
185+
}
40186
}
41187

42188
#[test]

0 commit comments

Comments
 (0)