Skip to content

Commit 28f79ef

Browse files
authored
Merge pull request #90 from jturner314/add-rcond
Add .rcond() and .rcond_into()
2 parents 4635131 + 78554cc commit 28f79ef

File tree

5 files changed

+161
-22
lines changed

5 files changed

+161
-22
lines changed

src/lapack_traits/mod.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,21 @@ pub enum Transpose {
5656
Transpose = b'T',
5757
Hermite = b'C',
5858
}
59+
60+
#[derive(Debug, Clone, Copy)]
61+
#[repr(u8)]
62+
pub enum NormType {
63+
One = b'O',
64+
Infinity = b'I',
65+
Frobenius = b'F',
66+
}
67+
68+
impl NormType {
69+
pub(crate) fn transpose(self) -> Self {
70+
match self {
71+
NormType::One => NormType::Infinity,
72+
NormType::Infinity => NormType::One,
73+
NormType::Frobenius => NormType::Frobenius,
74+
}
75+
}
76+
}

src/lapack_traits/opnorm.rs

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,7 @@ use lapack::c::Layout::ColumnMajor as cm;
66
use layout::MatrixLayout;
77
use types::*;
88

9-
#[repr(u8)]
10-
pub enum NormType {
11-
One = b'o',
12-
Infinity = b'i',
13-
Frobenius = b'f',
14-
}
15-
16-
impl NormType {
17-
fn transpose(self) -> Self {
18-
match self {
19-
NormType::One => NormType::Infinity,
20-
NormType::Infinity => NormType::One,
21-
NormType::Frobenius => NormType::Frobenius,
22-
}
23-
}
24-
}
9+
use super::NormType;
2510

2611
pub trait OperatorNorm_: AssociatedReal {
2712
unsafe fn opnorm(NormType, MatrixLayout, &[Self]) -> Self::Real;

src/lapack_traits/solve.rs

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@ use lapack::c;
44

55
use error::*;
66
use layout::MatrixLayout;
7+
use num_traits::Zero;
78
use types::*;
89

910
use super::{Pivot, Transpose, into_result};
11+
use super::NormType;
1012

1113
/// Wraps `*getrf`, `*getri`, and `*getrs`
12-
pub trait Solve_: Sized {
14+
pub trait Solve_: AssociatedReal + Sized {
1315
/// Computes the LU factorization of a general `m x n` matrix `a` using
1416
/// partial pivoting with row interchanges.
1517
///
@@ -20,11 +22,15 @@ pub trait Solve_: Sized {
2022
/// if it is used to solve a system of equations.
2123
unsafe fn lu(MatrixLayout, a: &mut [Self]) -> Result<Pivot>;
2224
unsafe fn inv(MatrixLayout, a: &mut [Self], &Pivot) -> Result<()>;
25+
/// Estimates the the reciprocal of the condition number of the matrix in 1-norm.
26+
///
27+
/// `anorm` should be the 1-norm of the matrix `a`.
28+
unsafe fn rcond(MatrixLayout, a: &[Self], anorm: Self::Real) -> Result<Self::Real>;
2329
unsafe fn solve(MatrixLayout, Transpose, a: &[Self], &Pivot, b: &mut [Self]) -> Result<()>;
2430
}
2531

2632
macro_rules! impl_solve {
27-
($scalar:ty, $getrf:path, $getri:path, $getrs:path) => {
33+
($scalar:ty, $getrf:path, $getri:path, $gecon:path, $getrs:path) => {
2834

2935
impl Solve_ for $scalar {
3036
unsafe fn lu(l: MatrixLayout, a: &mut [Self]) -> Result<Pivot> {
@@ -41,6 +47,13 @@ impl Solve_ for $scalar {
4147
into_result(info, ())
4248
}
4349

50+
unsafe fn rcond(l: MatrixLayout, a: &[Self], anorm: Self::Real) -> Result<Self::Real> {
51+
let (n, _) = l.size();
52+
let mut rcond = Self::Real::zero();
53+
let info = $gecon(l.lapacke_layout(), NormType::One as u8, n, a, l.lda(), anorm, &mut rcond);
54+
into_result(info, rcond)
55+
}
56+
4457
unsafe fn solve(l: MatrixLayout, t: Transpose, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()> {
4558
let (n, _) = l.size();
4659
let nrhs = 1;
@@ -52,7 +65,7 @@ impl Solve_ for $scalar {
5265

5366
}} // impl_solve!
5467

55-
impl_solve!(f64, c::dgetrf, c::dgetri, c::dgetrs);
56-
impl_solve!(f32, c::sgetrf, c::sgetri, c::sgetrs);
57-
impl_solve!(c64, c::zgetrf, c::zgetri, c::zgetrs);
58-
impl_solve!(c32, c::cgetrf, c::cgetri, c::cgetrs);
68+
impl_solve!(f64, c::dgetrf, c::dgetri, c::dgecon, c::dgetrs);
69+
impl_solve!(f32, c::sgetrf, c::sgetri, c::sgecon, c::sgetrs);
70+
impl_solve!(c64, c::zgetrf, c::zgetri, c::zgecon, c::zgetrs);
71+
impl_solve!(c32, c::cgetrf, c::cgetri, c::cgecon, c::cgetrs);

src/solve.rs

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ use ndarray::*;
5050
use super::convert::*;
5151
use super::error::*;
5252
use super::layout::*;
53+
use super::opnorm::OperationNorm;
5354
use super::types::*;
5455

5556
pub use lapack_traits::{Pivot, Transpose};
@@ -419,3 +420,77 @@ where
419420
}
420421
}
421422
}
423+
424+
/// An interface for *estimating* the reciprocal condition number of matrix refs.
425+
pub trait ReciprocalConditionNum<A: Scalar> {
426+
/// *Estimates* the reciprocal of the condition number of the matrix in
427+
/// 1-norm.
428+
///
429+
/// This method uses the LAPACK `*gecon` routines, which *estimate*
430+
/// `self.inv().opnorm_one()` and then compute `rcond = 1. /
431+
/// (self.opnorm_one() * self.inv().opnorm_one())`.
432+
///
433+
/// * If `rcond` is near `0.`, the matrix is badly conditioned.
434+
/// * If `rcond` is near `1.`, the matrix is well conditioned.
435+
fn rcond(&self) -> Result<A::Real>;
436+
}
437+
438+
/// An interface for *estimating* the reciprocal condition number of matrices.
439+
pub trait ReciprocalConditionNumInto<A: Scalar> {
440+
/// *Estimates* the reciprocal of the condition number of the matrix in
441+
/// 1-norm.
442+
///
443+
/// This method uses the LAPACK `*gecon` routines, which *estimate*
444+
/// `self.inv().opnorm_one()` and then compute `rcond = 1. /
445+
/// (self.opnorm_one() * self.inv().opnorm_one())`.
446+
///
447+
/// * If `rcond` is near `0.`, the matrix is badly conditioned.
448+
/// * If `rcond` is near `1.`, the matrix is well conditioned.
449+
fn rcond_into(self) -> Result<A::Real>;
450+
}
451+
452+
impl<A, S> ReciprocalConditionNum<A> for LUFactorized<S>
453+
where
454+
A: Scalar,
455+
S: Data<Elem = A>,
456+
{
457+
fn rcond(&self) -> Result<A::Real> {
458+
unsafe {
459+
A::rcond(
460+
self.a.layout()?,
461+
self.a.as_allocated()?,
462+
self.a.opnorm_one()?,
463+
)
464+
}
465+
}
466+
}
467+
468+
impl<A, S> ReciprocalConditionNumInto<A> for LUFactorized<S>
469+
where
470+
A: Scalar,
471+
S: Data<Elem = A>,
472+
{
473+
fn rcond_into(self) -> Result<A::Real> {
474+
self.rcond()
475+
}
476+
}
477+
478+
impl<A, S> ReciprocalConditionNum<A> for ArrayBase<S, Ix2>
479+
where
480+
A: Scalar,
481+
S: Data<Elem = A>,
482+
{
483+
fn rcond(&self) -> Result<A::Real> {
484+
self.factorize()?.rcond_into()
485+
}
486+
}
487+
488+
impl<A, S> ReciprocalConditionNumInto<A> for ArrayBase<S, Ix2>
489+
where
490+
A: Scalar,
491+
S: DataMut<Elem = A>,
492+
{
493+
fn rcond_into(self) -> Result<A::Real> {
494+
self.factorize_into()?.rcond_into()
495+
}
496+
}

tests/solve.rs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,51 @@ fn solve_random_t() {
2323
let y = a.solve_into(b).unwrap();
2424
assert_close_l2!(&x, &y, 1e-7);
2525
}
26+
27+
#[test]
28+
fn rcond() {
29+
macro_rules! rcond {
30+
($elem:ty, $rows:expr, $atol:expr) => {
31+
let a: Array2<$elem> = random(($rows, $rows));
32+
let rcond = 1. / (a.opnorm_one().unwrap() * a.inv().unwrap().opnorm_one().unwrap());
33+
assert_aclose!(a.rcond().unwrap(), rcond, $atol);
34+
assert_aclose!(a.rcond_into().unwrap(), rcond, $atol);
35+
}
36+
}
37+
for rows in 1..6 {
38+
rcond!(f64, rows, 0.2);
39+
rcond!(f32, rows, 0.5);
40+
rcond!(c64, rows, 0.2);
41+
rcond!(c32, rows, 0.5);
42+
}
43+
}
44+
45+
#[test]
46+
fn rcond_hilbert() {
47+
macro_rules! rcond_hilbert {
48+
($elem:ty, $rows:expr, $atol:expr) => {
49+
let a = Array2::<$elem>::from_shape_fn(($rows, $rows), |(i, j)| 1. / (i as $elem + j as $elem - 1.));
50+
assert_aclose!(a.rcond().unwrap(), 0., $atol);
51+
assert_aclose!(a.rcond_into().unwrap(), 0., $atol);
52+
}
53+
}
54+
rcond_hilbert!(f64, 10, 1e-9);
55+
rcond_hilbert!(f32, 10, 1e-3);
56+
}
57+
58+
#[test]
59+
fn rcond_identity() {
60+
macro_rules! rcond_identity {
61+
($elem:ty, $rows:expr, $atol:expr) => {
62+
let a = Array2::<$elem>::eye($rows);
63+
assert_aclose!(a.rcond().unwrap(), 1., $atol);
64+
assert_aclose!(a.rcond_into().unwrap(), 1., $atol);
65+
}
66+
}
67+
for rows in 1..6 {
68+
rcond_identity!(f64, rows, 1e-9);
69+
rcond_identity!(f32, rows, 1e-3);
70+
rcond_identity!(c64, rows, 1e-9);
71+
rcond_identity!(c32, rows, 1e-3);
72+
}
73+
}

0 commit comments

Comments
 (0)