Skip to content

Commit a46a055

Browse files
committed
Add FactorizedCholesky
This provides additional operations based on Cholesky decomposition, including matrix inverses, determinants, and solutions to linear systems.
1 parent 1b17b65 commit a46a055

File tree

3 files changed

+318
-51
lines changed

3 files changed

+318
-51
lines changed

src/cholesky.rs

Lines changed: 184 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -6,33 +6,45 @@
66
//!
77
//! # Example
88
//!
9-
//! Calculate `L` in the Cholesky decomposition `A = L * L^H`, where `A` is a
10-
//! Hermitian (or real symmetric) positive definite matrix:
9+
//! Using the Cholesky decomposition of `A` for various operations, where `A`
10+
//! is a Hermitian (or real symmetric) positive definite matrix:
1111
//!
1212
//! ```
1313
//! #[macro_use]
1414
//! extern crate ndarray;
1515
//! extern crate ndarray_linalg;
1616
//!
1717
//! use ndarray::prelude::*;
18-
//! use ndarray_linalg::{CholeskyInto, UPLO};
18+
//! use ndarray_linalg::{Cholesky, UPLO};
1919
//! # fn main() {
2020
//!
2121
//! let a: Array2<f64> = array![
2222
//! [ 4., 12., -16.],
2323
//! [ 12., 37., -43.],
2424
//! [-16., -43., 98.]
2525
//! ];
26-
//! let lower = a.cholesky_into(UPLO::Lower).unwrap();
27-
//! assert!(lower.all_close(&array![
26+
//! let chol_lower = a.cholesky(UPLO::Lower).unwrap();
27+
//!
28+
//! // Examine `L`
29+
//! assert!(chol_lower.factor.all_close(&array![
2830
//! [ 2., 0., 0.],
2931
//! [ 6., 1., 0.],
3032
//! [-8., 5., 3.]
3133
//! ], 1e-9));
34+
//!
35+
//! // Find the determinant of `A`
36+
//! let det = chol_lower.det();
37+
//! assert!((det - 36.).abs() < 1e-9);
38+
//!
39+
//! // Solve `A * x = b`
40+
//! let b = array![4., 13., -11.];
41+
//! let x = chol_lower.solve(&b).unwrap();
42+
//! assert!(x.all_close(&array![-2., 1., 0.], 1e-9));
3243
//! # }
3344
//! ```
3445
3546
use ndarray::*;
47+
use num_traits::Float;
3648

3749
use super::convert::*;
3850
use super::error::*;
@@ -42,79 +54,212 @@ use super::types::*;
4254

4355
pub use lapack_traits::UPLO;
4456

57+
/// Cholesky decomposition of Hermitian (or real symmetric) positive definite matrix
58+
pub struct FactorizedCholesky<S>
59+
where
60+
S: Data,
61+
{
62+
/// `L` from the decomposition `A = L * L^H` or `U` from the decomposition
63+
/// `A = U^H * U`.
64+
pub factor: ArrayBase<S, Ix2>,
65+
/// If this is `UPLO::Lower`, then `self.factor` is `L`. If this is
66+
/// `UPLO::Upper`, then `self.factor` is `U`.
67+
pub uplo: UPLO,
68+
}
69+
70+
impl<A, S> FactorizedCholesky<S>
71+
where
72+
A: Scalar,
73+
S: DataMut<Elem = A>,
74+
{
75+
/// Returns `L` from the Cholesky decomposition `A = L * L^H`.
76+
///
77+
/// If `self.uplo == UPLO::Lower`, then no computations need to be
78+
/// performed; otherwise, the conjugate transpose of `self.factor` is
79+
/// calculated.
80+
pub fn into_lower(self) -> ArrayBase<S, Ix2> {
81+
match self.uplo {
82+
UPLO::Lower => self.factor,
83+
UPLO::Upper => self.factor.reversed_axes().mapv_into(|elem| elem.conj()),
84+
}
85+
}
86+
87+
/// Returns `U` from the Cholesky decomposition `A = U^H * U`.
88+
///
89+
/// If `self.uplo == UPLO::Upper`, then no computations need to be
90+
/// performed; otherwise, the conjugate transpose of `self.factor` is
91+
/// calculated.
92+
pub fn into_upper(self) -> ArrayBase<S, Ix2> {
93+
match self.uplo {
94+
UPLO::Lower => self.factor.reversed_axes().mapv_into(|elem| elem.conj()),
95+
UPLO::Upper => self.factor,
96+
}
97+
}
98+
99+
/// Computes the inverse of the Cholesky-factored matrix.
100+
///
101+
/// **Warning: The inverse is stored only in the triangular portion of the
102+
/// result matrix corresponding to `self.uplo`!** If you want the other
103+
/// triangular portion to be correct, you must fill it in yourself.
104+
pub fn into_inverse(mut self) -> Result<ArrayBase<S, Ix2>> {
105+
unsafe {
106+
A::inv_cholesky(
107+
self.factor.square_layout()?,
108+
self.uplo,
109+
self.factor.as_allocated_mut()?,
110+
)?
111+
};
112+
Ok(self.factor)
113+
}
114+
}
115+
116+
impl<A, S> FactorizedCholesky<S>
117+
where
118+
A: Absolute,
119+
S: Data<Elem = A>,
120+
{
121+
/// Computes the natural log of the determinant of the Cholesky-factored
122+
/// matrix.
123+
pub fn ln_det(&self) -> <A as AssociatedReal>::Real {
124+
self.factor
125+
.diag()
126+
.iter()
127+
.map(|elem| elem.abs_sqr().ln())
128+
.sum()
129+
}
130+
131+
/// Computes the determinant of the Cholesky-factored matrix.
132+
pub fn det(&self) -> <A as AssociatedReal>::Real {
133+
self.ln_det().exp()
134+
}
135+
}
136+
137+
impl<A, S> FactorizedCholesky<S>
138+
where
139+
A: Scalar,
140+
S: Data<Elem = A>,
141+
{
142+
/// Solves a system of linear equations `A * x = b`, where `self` is the
143+
/// Cholesky factorization of `A`, `b` is the argument, and `x` is the
144+
/// successful result.
145+
pub fn solve<Sb>(&self, b: &ArrayBase<Sb, Ix1>) -> Result<Array1<A>>
146+
where
147+
Sb: Data<Elem = A>,
148+
{
149+
let mut b = replicate(b);
150+
self.solve_mut(&mut b)?;
151+
Ok(b)
152+
}
153+
154+
/// Solves a system of linear equations `A * x = b`, where `self` is the
155+
/// Cholesky factorization `A`, `b` is the argument, and `x` is the
156+
/// successful result.
157+
pub fn solve_into<Sb>(&self, mut b: ArrayBase<Sb, Ix1>) -> Result<ArrayBase<Sb, Ix1>>
158+
where
159+
Sb: DataMut<Elem = A>,
160+
{
161+
self.solve_mut(&mut b)?;
162+
Ok(b)
163+
}
164+
165+
/// Solves a system of linear equations `A * x = b`, where `self` is the
166+
/// Cholesky factorization of `A`, `b` is the argument, and `x` is the
167+
/// successful result. The value of `x` is also assigned to the argument.
168+
pub fn solve_mut<'a, Sb>(&self, b: &'a mut ArrayBase<Sb, Ix1>) -> Result<&'a mut ArrayBase<Sb, Ix1>>
169+
where
170+
Sb: DataMut<Elem = A>,
171+
{
172+
unsafe {
173+
A::solve_cholesky(
174+
self.factor.square_layout()?,
175+
self.uplo,
176+
self.factor.as_allocated()?,
177+
b.as_slice_mut().unwrap(),
178+
)?
179+
};
180+
Ok(b)
181+
}
182+
}
183+
45184
/// Cholesky decomposition of Hermitian (or real symmetric) positive definite matrix reference
46-
pub trait Cholesky {
47-
type Output;
185+
pub trait Cholesky<S: Data> {
48186
/// Computes the Cholesky decomposition of the Hermitian (or real
49187
/// symmetric) positive definite matrix.
50188
///
51189
/// If the argument is `UPLO::Upper`, then computes the decomposition `A =
52-
/// U^H * U` using the upper triangular portion of `A` and returns `U`.
53-
/// Otherwise, if the argument is `UPLO::Lower`, computes the decomposition
54-
/// `A = L * L^H` using the lower triangular portion of `A` and returns
55-
/// `L`.
56-
fn cholesky(&self, UPLO) -> Result<Self::Output>;
190+
/// U^H * U` using the upper triangular portion of `A` and returns the
191+
/// factorization containing `U`. Otherwise, if the argument is
192+
/// `UPLO::Lower`, computes the decomposition `A = L * L^H` using the lower
193+
/// triangular portion of `A` and returns the factorization containing `L`.
194+
fn cholesky(&self, UPLO) -> Result<FactorizedCholesky<S>>;
57195
}
58196

59197
/// Cholesky decomposition of Hermitian (or real symmetric) positive definite matrix
60-
pub trait CholeskyInto: Sized {
198+
pub trait CholeskyInto<S: Data> {
61199
/// Computes the Cholesky decomposition of the Hermitian (or real
62200
/// symmetric) positive definite matrix.
63201
///
64202
/// If the argument is `UPLO::Upper`, then computes the decomposition `A =
65-
/// U^H * U` using the upper triangular portion of `A` and returns `U`.
66-
/// Otherwise, if the argument is `UPLO::Lower`, computes the decomposition
67-
/// `A = L * L^H` using the lower triangular portion of `A` and returns
68-
/// `L`.
69-
fn cholesky_into(self, UPLO) -> Result<Self>;
203+
/// U^H * U` using the upper triangular portion of `A` and returns the
204+
/// factorization containing `U`. Otherwise, if the argument is
205+
/// `UPLO::Lower`, computes the decomposition `A = L * L^H` using the lower
206+
/// triangular portion of `A` and returns the factorization containing `L`.
207+
fn cholesky_into(self, UPLO) -> Result<FactorizedCholesky<S>>;
70208
}
71209

72210
/// Cholesky decomposition of Hermitian (or real symmetric) positive definite mutable reference of matrix
73-
pub trait CholeskyMut {
211+
pub trait CholeskyMut<'a, S: Data> {
74212
/// Computes the Cholesky decomposition of the Hermitian (or real
75-
/// symmetric) positive definite matrix, storing the result in `self` and
76-
/// returning it.
213+
/// symmetric) positive definite matrix, storing the result (`L` or `U`
214+
/// according to the argument) in `self` and returning the factorization.
77215
///
78216
/// If the argument is `UPLO::Upper`, then computes the decomposition `A =
79-
/// U^H * U` using the upper triangular portion of `A` and returns `U`.
80-
/// Otherwise, if the argument is `UPLO::Lower`, computes the decomposition
81-
/// `A = L * L^H` using the lower triangular portion of `A` and returns
82-
/// `L`.
83-
fn cholesky_mut(&mut self, UPLO) -> Result<&mut Self>;
217+
/// U^H * U` using the upper triangular portion of `A` and returns the
218+
/// factorization containing `U`. Otherwise, if the argument is
219+
/// `UPLO::Lower`, computes the decomposition `A = L * L^H` using the lower
220+
/// triangular portion of `A` and returns the factorization containing `L`.
221+
fn cholesky_mut(&'a mut self, UPLO) -> Result<FactorizedCholesky<S>>;
84222
}
85223

86-
impl<A, S> CholeskyInto for ArrayBase<S, Ix2>
224+
impl<A, S> CholeskyInto<S> for ArrayBase<S, Ix2>
87225
where
88226
A: Scalar,
89227
S: DataMut<Elem = A>,
90228
{
91-
fn cholesky_into(mut self, uplo: UPLO) -> Result<Self> {
229+
fn cholesky_into(mut self, uplo: UPLO) -> Result<FactorizedCholesky<S>> {
92230
unsafe { A::cholesky(self.square_layout()?, uplo, self.as_allocated_mut()?)? };
93-
Ok(self.into_triangular(uplo))
231+
Ok(FactorizedCholesky {
232+
factor: self.into_triangular(uplo),
233+
uplo: uplo,
234+
})
94235
}
95236
}
96237

97-
impl<A, S> CholeskyMut for ArrayBase<S, Ix2>
238+
impl<'a, A, Si> CholeskyMut<'a, ViewRepr<&'a mut A>> for ArrayBase<Si, Ix2>
98239
where
99240
A: Scalar,
100-
S: DataMut<Elem = A>,
241+
Si: DataMut<Elem = A>,
101242
{
102-
fn cholesky_mut(&mut self, uplo: UPLO) -> Result<&mut Self> {
243+
fn cholesky_mut(&'a mut self, uplo: UPLO) -> Result<FactorizedCholesky<ViewRepr<&'a mut A>>> {
103244
unsafe { A::cholesky(self.square_layout()?, uplo, self.as_allocated_mut()?)? };
104-
Ok(self.into_triangular(uplo))
245+
Ok(FactorizedCholesky {
246+
factor: self.into_triangular(uplo).view_mut(),
247+
uplo: uplo,
248+
})
105249
}
106250
}
107251

108-
impl<A, S> Cholesky for ArrayBase<S, Ix2>
252+
impl<A, Si> Cholesky<OwnedRepr<A>> for ArrayBase<Si, Ix2>
109253
where
110254
A: Scalar,
111-
S: Data<Elem = A>,
255+
Si: Data<Elem = A>,
112256
{
113-
type Output = Array2<A>;
114-
115-
fn cholesky(&self, uplo: UPLO) -> Result<Self::Output> {
257+
fn cholesky(&self, uplo: UPLO) -> Result<FactorizedCholesky<OwnedRepr<A>>> {
116258
let mut a = replicate(self);
117259
unsafe { A::cholesky(a.square_layout()?, uplo, a.as_allocated_mut()?)? };
118-
Ok(a.into_triangular(uplo))
260+
Ok(FactorizedCholesky {
261+
factor: a.into_triangular(uplo),
262+
uplo: uplo,
263+
})
119264
}
120265
}

src/lapack_traits/cholesky.rs

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,21 +9,40 @@ use types::*;
99
use super::{UPLO, into_result};
1010

1111
pub trait Cholesky_: Sized {
12+
/// Cholesky: wrapper of `*potrf`
1213
unsafe fn cholesky(MatrixLayout, UPLO, a: &mut [Self]) -> Result<()>;
14+
/// Wrapper of `*potri`
15+
unsafe fn inv_cholesky(MatrixLayout, UPLO, a: &mut [Self]) -> Result<()>;
16+
/// Wrapper of `*potrs`
17+
unsafe fn solve_cholesky(MatrixLayout, UPLO, a: &[Self], b: &mut [Self]) -> Result<()>;
1318
}
1419

1520
macro_rules! impl_cholesky {
16-
($scalar:ty, $potrf:path) => {
21+
($scalar:ty, $trf:path, $tri:path, $trs:path) => {
1722
impl Cholesky_ for $scalar {
1823
unsafe fn cholesky(l: MatrixLayout, uplo: UPLO, mut a: &mut [Self]) -> Result<()> {
1924
let (n, _) = l.size();
20-
let info = $potrf(l.lapacke_layout(), uplo as u8, n, &mut a, n);
25+
let info = $trf(l.lapacke_layout(), uplo as u8, n, a, n);
26+
into_result(info, ())
27+
}
28+
29+
unsafe fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> {
30+
let (n, _) = l.size();
31+
let info = $tri(l.lapacke_layout(), uplo as u8, n, a, l.lda());
32+
into_result(info, ())
33+
}
34+
35+
unsafe fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) -> Result<()> {
36+
let (n, _) = l.size();
37+
let nrhs = 1;
38+
let ldb = 1;
39+
let info = $trs(l.lapacke_layout(), uplo as u8, n, nrhs, a, l.lda(), b, ldb);
2140
into_result(info, ())
2241
}
2342
}
2443
}} // end macro_rules
2544

26-
impl_cholesky!(f64, c::dpotrf);
27-
impl_cholesky!(f32, c::spotrf);
28-
impl_cholesky!(c64, c::zpotrf);
29-
impl_cholesky!(c32, c::cpotrf);
45+
impl_cholesky!(f64, c::dpotrf, c::dpotri, c::dpotrs);
46+
impl_cholesky!(f32, c::spotrf, c::spotri, c::spotrs);
47+
impl_cholesky!(c64, c::zpotrf, c::zpotri, c::zpotrs);
48+
impl_cholesky!(c32, c::cpotrf, c::cpotri, c::cpotrs);

0 commit comments

Comments
 (0)