Skip to content

Commit df396b5

Browse files
committed
WIP
1 parent e352347 commit df396b5

File tree

2 files changed

+34
-9
lines changed

2 files changed

+34
-9
lines changed

lax/src/cholesky.rs

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,29 @@ macro_rules! impl_cholesky {
2424
impl Cholesky_ for $scalar {
2525
fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> {
2626
let (n, _) = l.size();
27+
let mut info = 0;
28+
let uplo = match l {
29+
MatrixLayout::F { .. } => uplo,
30+
MatrixLayout::C { .. } => uplo.t(),
31+
};
2732
unsafe {
28-
$trf(l.lapacke_layout(), uplo as u8, n, a, n).as_lapack_result()?;
33+
$trf(uplo as u8, n, a, n, &mut info);
2934
}
35+
info.as_lapack_result()?;
3036
Ok(())
3137
}
3238

3339
fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> {
3440
let (n, _) = l.size();
41+
let mut info = 0;
42+
let uplo = match l {
43+
MatrixLayout::F { .. } => uplo,
44+
MatrixLayout::C { .. } => uplo.t(),
45+
};
3546
unsafe {
36-
$tri(l.lapacke_layout(), uplo as u8, n, a, l.lda()).as_lapack_result()?;
47+
$tri(uplo as u8, n, a, l.lda(), &mut info);
3748
}
49+
info.as_lapack_result()?;
3850
Ok(())
3951
}
4052

@@ -46,18 +58,22 @@ macro_rules! impl_cholesky {
4658
) -> Result<()> {
4759
let (n, _) = l.size();
4860
let nrhs = 1;
49-
let ldb = 1;
61+
let uplo = match l {
62+
MatrixLayout::F { .. } => uplo,
63+
MatrixLayout::C { .. } => uplo.t(),
64+
};
65+
let mut info = 0;
5066
unsafe {
51-
$trs(l.lapacke_layout(), uplo as u8, n, nrhs, a, l.lda(), b, ldb)
52-
.as_lapack_result()?;
67+
$trs(uplo as u8, n, nrhs, a, l.lda(), b, n, &mut info);
5368
}
69+
info.as_lapack_result()?;
5470
Ok(())
5571
}
5672
}
5773
};
5874
} // end macro_rules
5975

60-
impl_cholesky!(f64, lapacke::dpotrf, lapacke::dpotri, lapacke::dpotrs);
61-
impl_cholesky!(f32, lapacke::spotrf, lapacke::spotri, lapacke::spotrs);
62-
impl_cholesky!(c64, lapacke::zpotrf, lapacke::zpotri, lapacke::zpotrs);
63-
impl_cholesky!(c32, lapacke::cpotrf, lapacke::cpotri, lapacke::cpotrs);
76+
impl_cholesky!(f64, lapack::dpotrf, lapack::dpotri, lapack::dpotrs);
77+
impl_cholesky!(f32, lapack::spotrf, lapack::spotri, lapack::spotrs);
78+
impl_cholesky!(c64, lapack::zpotrf, lapack::zpotri, lapack::zpotrs);
79+
impl_cholesky!(c32, lapack::cpotrf, lapack::cpotri, lapack::cpotrs);

lax/src/lib.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,15 @@ pub enum UPLO {
126126
Lower = b'L',
127127
}
128128

129+
impl UPLO {
130+
pub fn t(self) -> Self {
131+
match self {
132+
UPLO::Upper => UPLO::Lower,
133+
UPLO::Lower => UPLO::Upper,
134+
}
135+
}
136+
}
137+
129138
#[derive(Debug, Clone, Copy)]
130139
#[repr(u8)]
131140
pub enum Transpose {

0 commit comments

Comments
 (0)