Skip to content

Commit 06a6c65

Browse files
committed
square_transpose
1 parent df396b5 commit 06a6c65

File tree

2 files changed

+60
-14
lines changed

2 files changed

+60
-14
lines changed

lax/src/cholesky.rs

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
//! Cholesky decomposition
22
33
use super::*;
4-
use crate::{error::*, layout::MatrixLayout};
4+
use crate::{error::*, layout::*};
55
use cauchy::*;
66

77
pub trait Cholesky_: Sized {
@@ -24,45 +24,48 @@ macro_rules! impl_cholesky {
2424
impl Cholesky_ for $scalar {
2525
fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> {
2626
let (n, _) = l.size();
27+
if matches!(l, MatrixLayout::C { .. }) {
28+
square_transpose(l, a);
29+
}
2730
let mut info = 0;
28-
let uplo = match l {
29-
MatrixLayout::F { .. } => uplo,
30-
MatrixLayout::C { .. } => uplo.t(),
31-
};
3231
unsafe {
3332
$trf(uplo as u8, n, a, n, &mut info);
3433
}
3534
info.as_lapack_result()?;
35+
if matches!(l, MatrixLayout::C { .. }) {
36+
square_transpose(l, a);
37+
}
3638
Ok(())
3739
}
3840

3941
fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> {
4042
let (n, _) = l.size();
43+
if matches!(l, MatrixLayout::C { .. }) {
44+
square_transpose(l, a);
45+
}
4146
let mut info = 0;
42-
let uplo = match l {
43-
MatrixLayout::F { .. } => uplo,
44-
MatrixLayout::C { .. } => uplo.t(),
45-
};
4647
unsafe {
4748
$tri(uplo as u8, n, a, l.lda(), &mut info);
4849
}
4950
info.as_lapack_result()?;
51+
if matches!(l, MatrixLayout::C { .. }) {
52+
square_transpose(l, a);
53+
}
5054
Ok(())
5155
}
5256

5357
fn solve_cholesky(
5458
l: MatrixLayout,
55-
uplo: UPLO,
59+
mut uplo: UPLO,
5660
a: &[Self],
5761
b: &mut [Self],
5862
) -> Result<()> {
5963
let (n, _) = l.size();
6064
let nrhs = 1;
61-
let uplo = match l {
62-
MatrixLayout::F { .. } => uplo,
63-
MatrixLayout::C { .. } => uplo.t(),
64-
};
6565
let mut info = 0;
66+
if matches!(l, MatrixLayout::C { .. }) {
67+
uplo = uplo.t();
68+
}
6669
unsafe {
6770
$trs(uplo as u8, n, nrhs, a, l.lda(), b, n, &mut info);
6871
}

lax/src/layout.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
//! This `S` for a matrix `A` is called "leading dimension of the array A" in LAPACK document, and denoted by `lda`.
3838
//!
3939
40+
use cauchy::Scalar;
41+
4042
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
4143
pub enum MatrixLayout {
4244
C { row: i32, lda: i32 },
@@ -96,3 +98,44 @@ impl MatrixLayout {
9698
}
9799
}
98100
}
101+
102+
/// In-place transpose of a square matrix by keeping F/C layout
103+
///
104+
/// Transpose for C-continuous array
105+
///
106+
/// ```rust
107+
/// # use lax::layout::*;
108+
/// let layout = MatrixLayout::C { row: 2, lda: 2 };
109+
/// let mut a = vec![1., 2., 3., 4.];
110+
/// square_transpose(layout, &mut a);
111+
/// assert_eq!(a, &[1., 3., 2., 4.]);
112+
/// ```
113+
///
114+
/// Transpose for F-continuous array
115+
///
116+
/// ```rust
117+
/// # use lax::layout::*;
118+
/// let layout = MatrixLayout::F { col: 2, lda: 2 };
119+
/// let mut a = vec![1., 3., 2., 4.];
120+
/// square_transpose(layout, &mut a);
121+
/// assert_eq!(a, &[1., 2., 3., 4.]);
122+
/// ```
123+
///
124+
/// Panics
125+
/// ------
126+
/// - If size of `a` and `layout` size mismatch
127+
///
128+
pub fn square_transpose<T: Scalar>(layout: MatrixLayout, a: &mut [T]) {
129+
let (m, n) = layout.size();
130+
let n = n as usize;
131+
let m = m as usize;
132+
assert_eq!(a.len(), n * m);
133+
for i in 0..m {
134+
for j in (i + 1)..n {
135+
let a_ij = a[i * n + j];
136+
let a_ji = a[j * m + i];
137+
a[i * n + j] = a_ji.conj();
138+
a[j * m + i] = a_ij.conj();
139+
}
140+
}
141+
}

0 commit comments

Comments
 (0)