Skip to content

Commit 0f5f87f

Browse files
committed
Add ImplSolve::solve to reuse LU factorized matrix
1 parent 86f38a7 commit 0f5f87f

File tree

2 files changed

+35
-20
lines changed

2 files changed

+35
-20
lines changed

src/solve.rs

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,28 +6,22 @@ use std::cmp::min;
66
use error::LapackError;
77

88
pub trait ImplSolve: Sized {
9-
fn inv(layout: Layout, size: usize, a: Vec<Self>) -> Result<Vec<Self>, LapackError>;
9+
/// execute LU decomposition
1010
fn lu(layout: Layout, m: usize, n: usize, a: Vec<Self>) -> Result<(Vec<i32>, Vec<Self>), LapackError>;
11+
/// calc inverse matrix with LU factorized matrix
12+
fn inv(layout: Layout, size: usize, a: Vec<Self>, ipiv: &Vec<i32>) -> Result<Vec<Self>, LapackError>;
13+
/// solve linear problem with LU factorized matrix
14+
fn solve(layout: Layout,
15+
size: usize,
16+
a: &Vec<Self>,
17+
ipiv: &Vec<i32>,
18+
b: Vec<Self>)
19+
-> Result<Vec<Self>, LapackError>;
1120
}
1221

1322
macro_rules! impl_solve {
14-
($scalar:ty, $getrf:path, $getri:path, $laswp:path) => {
23+
($scalar:ty, $getrf:path, $getri:path, $getrs:path) => {
1524
impl ImplSolve for $scalar {
16-
fn inv(layout: Layout, size: usize, mut a: Vec<Self>) -> Result<Vec<Self>, LapackError> {
17-
let n = size as i32;
18-
let lda = n;
19-
let mut ipiv = vec![0; size];
20-
let info = $getrf(layout, n, n, &mut a, lda, &mut ipiv);
21-
if info != 0 {
22-
return Err(From::from(info));
23-
}
24-
let info = $getri(layout, n, &mut a, lda, &mut ipiv);
25-
if info == 0 {
26-
Ok(a)
27-
} else {
28-
Err(From::from(info))
29-
}
30-
}
3125
fn lu(layout: Layout, m: usize, n: usize, mut a: Vec<Self>) -> Result<(Vec<i32>, Vec<Self>), LapackError> {
3226
let m = m as i32;
3327
let n = n as i32;
@@ -44,8 +38,28 @@ impl ImplSolve for $scalar {
4438
Err(From::from(info))
4539
}
4640
}
41+
fn inv(layout: Layout, size: usize, mut a: Vec<Self>, ipiv: &Vec<i32>) -> Result<Vec<Self>, LapackError> {
42+
let n = size as i32;
43+
let lda = n;
44+
let info = $getri(layout, n, &mut a, lda, &ipiv);
45+
if info == 0 {
46+
Ok(a)
47+
} else {
48+
Err(From::from(info))
49+
}
50+
}
51+
fn solve(layout: Layout, size: usize, a: &Vec<Self>, ipiv: &Vec<i32>, mut b: Vec<Self>) -> Result<Vec<Self>, LapackError> {
52+
let n = size as i32;
53+
let lda = n;
54+
let info = $getrs(layout, 'N' as u8, n, 1, a, lda, &ipiv, &mut b, n);
55+
if info == 0 {
56+
Ok(b)
57+
} else {
58+
Err(From::from(info))
59+
}
60+
}
4761
}
4862
}} // end macro_rules
4963

50-
impl_solve!(f64, dgetrf, dgetri, dlaswp);
51-
impl_solve!(f32, sgetrf, sgetri, slaswp);
64+
impl_solve!(f64, dgetrf, dgetri, dgetrs);
65+
impl_solve!(f32, sgetrf, sgetri, sgetrs);

src/square.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ impl<A> SquareMatrix for Array<A, Ix2>
4444
self.check_square()?;
4545
let (n, _) = self.size();
4646
let layout = self.layout()?;
47-
let a = ImplSolve::inv(layout, n, self.into_raw_vec())?;
47+
let (ipiv, a) = ImplSolve::lu(layout, n, n, self.into_raw_vec())?;
48+
let a = ImplSolve::inv(layout, n, a, &ipiv)?;
4849
let m = Array::from_vec(a).into_shape((n, n)).unwrap();
4950
match layout {
5051
Layout::RowMajor => Ok(m),

0 commit comments

Comments
 (0)