Skip to content

Commit 7d54e16

Browse files
authored
Merge pull request #18 from termoshtt/lu
LU decomposition
2 parents 6b2549a + e6423ce commit 7d54e16

File tree

6 files changed

+269
-18
lines changed

6 files changed

+269
-18
lines changed

src/hermite.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
//! Define trait for Hermite matrices
22
3+
use std::fmt::Debug;
34
use ndarray::prelude::*;
45
use ndarray::LinalgScalar;
56
use num_traits::float::Float;
@@ -22,7 +23,7 @@ pub trait HermiteMatrix: SquareMatrix + Matrix {
2223
}
2324

2425
impl<A> HermiteMatrix for Array<A, (Ix, Ix)>
25-
where A: ImplQR + ImplSVD + ImplNorm + ImplSolve + ImplEigh + LinalgScalar + Float
26+
where A: ImplQR + ImplSVD + ImplNorm + ImplSolve + ImplEigh + LinalgScalar + Float + Debug
2627
{
2728
fn eigh(self) -> Result<(Self::Vector, Self), LinalgError> {
2829
try!(self.check_square());

src/lib.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,9 @@
2525
//! - [WIP] Cholesky factorization
2626
2727
extern crate lapack;
28-
extern crate ndarray;
2928
extern crate num_traits;
29+
#[macro_use(s)]
30+
extern crate ndarray;
3031

3132
pub mod prelude;
3233
pub mod error;

src/matrix.rs

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,26 @@
11
//! Define trait for general matrix
22
33
use std::cmp::min;
4+
use std::fmt::Debug;
45
use ndarray::prelude::*;
56
use ndarray::LinalgScalar;
7+
use lapack::c::Layout;
68

79
use error::LapackError;
810
use qr::ImplQR;
911
use svd::ImplSVD;
1012
use norm::ImplNorm;
13+
use solve::ImplSolve;
1114

1215
/// Methods for general matrices
1316
pub trait Matrix: Sized {
1417
type Scalar;
1518
type Vector;
19+
type Permutator;
1620
/// number of (rows, columns)
1721
fn size(&self) -> (usize, usize);
22+
/// Layout (C/Fortran) of matrix
23+
fn layout(&self) -> Layout;
1824
/// Operator norm for L-1 norm
1925
fn norm_1(&self) -> Self::Scalar;
2026
/// Operator norm for L-inf norm
@@ -25,16 +31,35 @@ pub trait Matrix: Sized {
2531
fn svd(self) -> Result<(Self, Self::Vector, Self), LapackError>;
2632
/// QR decomposition
2733
fn qr(self) -> Result<(Self, Self), LapackError>;
34+
/// LU decomposition
35+
fn lu(self) -> Result<(Self::Permutator, Self, Self), LapackError>;
36+
/// permutate matrix (inplace)
37+
fn permutate(&mut self, p: &Self::Permutator);
38+
/// permutate matrix (outplace)
39+
fn permutated(mut self, p: &Self::Permutator) -> Self {
40+
self.permutate(p);
41+
self
42+
}
2843
}
2944

3045
impl<A> Matrix for Array<A, (Ix, Ix)>
31-
where A: ImplQR + ImplSVD + ImplNorm + LinalgScalar
46+
where A: ImplQR + ImplSVD + ImplNorm + ImplSolve + LinalgScalar + Debug
3247
{
3348
type Scalar = A;
3449
type Vector = Array<A, Ix>;
50+
type Permutator = Vec<i32>;
51+
3552
fn size(&self) -> (usize, usize) {
3653
(self.rows(), self.cols())
3754
}
55+
fn layout(&self) -> Layout {
56+
let strides = self.strides();
57+
if strides[0] < strides[1] {
58+
Layout::ColumnMajor
59+
} else {
60+
Layout::RowMajor
61+
}
62+
}
3863
fn norm_1(&self) -> Self::Scalar {
3964
let (m, n) = self.size();
4065
let strides = self.strides();
@@ -112,4 +137,54 @@ impl<A> Matrix for Array<A, (Ix, Ix)>
112137
}
113138
Ok((qm, rm))
114139
}
140+
fn lu(self) -> Result<(Self::Permutator, Self, Self), LapackError> {
141+
let (n, m) = self.size();
142+
println!("n={}, m={}", n, m);
143+
let k = min(n, m);
144+
let (p, mut a) = match self.layout() {
145+
Layout::ColumnMajor => {
146+
println!("ColumnMajor");
147+
let (p, l) = ImplSolve::lu(self.layout(), n, m, self.clone().into_raw_vec())?;
148+
(p, Array::from_vec(l).into_shape((m, n)).unwrap().reversed_axes())
149+
}
150+
Layout::RowMajor => {
151+
println!("RowMajor");
152+
let (p, l) = ImplSolve::lu(self.layout(), n, m, self.clone().into_raw_vec())?;
153+
(p, Array::from_vec(l).into_shape((n, m)).unwrap())
154+
}
155+
};
156+
println!("a (after LU) = \n{:?}", &a);
157+
let mut lm = Array::zeros((n, k));
158+
for ((i, j), val) in lm.indexed_iter_mut() {
159+
if i > j {
160+
*val = a[(i, j)];
161+
} else if i == j {
162+
*val = A::one();
163+
}
164+
}
165+
for ((i, j), val) in a.indexed_iter_mut() {
166+
if i > j {
167+
*val = A::zero();
168+
}
169+
}
170+
let am = if n > k {
171+
a.slice(s![0..k as isize, ..]).to_owned()
172+
} else {
173+
a
174+
};
175+
println!("am = \n{:?}", am);
176+
Ok((p, lm, am))
177+
}
178+
fn permutate(&mut self, ipiv: &Self::Permutator) {
179+
let (_, m) = self.size();
180+
for (i, j_) in ipiv.iter().enumerate().rev() {
181+
let j = (j_ - 1) as usize;
182+
if i == j {
183+
continue;
184+
}
185+
for k in 0..m {
186+
self.swap((i, k), (j, k));
187+
}
188+
}
189+
}
115190
}

src/solve.rs

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,55 @@
11
//! Implement linear solver and inverse matrix
22
3-
use lapack::fortran::*;
4-
use num_traits::Zero;
3+
use lapack::c::*;
4+
use std::cmp::min;
55

66
use error::LapackError;
77

88
pub trait ImplSolve: Sized {
9-
fn inv(size: usize, mut a: Vec<Self>) -> Result<Vec<Self>, LapackError>;
9+
fn inv(layout: Layout, size: usize, a: Vec<Self>) -> Result<Vec<Self>, LapackError>;
10+
fn lu(layout: Layout,
11+
m: usize,
12+
n: usize,
13+
a: Vec<Self>)
14+
-> Result<(Vec<i32>, Vec<Self>), LapackError>;
1015
}
1116

1217
macro_rules! impl_solve {
13-
($scalar:ty, $getrf:path, $getri:path) => {
18+
($scalar:ty, $getrf:path, $getri:path, $laswp:path) => {
1419
impl ImplSolve for $scalar {
15-
fn inv(size: usize, mut a: Vec<Self>) -> Result<Vec<Self>, LapackError> {
20+
fn inv(layout: Layout, size: usize, mut a: Vec<Self>) -> Result<Vec<Self>, LapackError> {
1621
let n = size as i32;
1722
let lda = n;
1823
let mut ipiv = vec![0; size];
19-
let mut info = 0;
20-
$getrf(n, n, &mut a, lda, &mut ipiv, &mut info);
24+
let info = $getrf(layout, n, n, &mut a, lda, &mut ipiv);
2125
if info != 0 {
2226
return Err(From::from(info));
2327
}
24-
let lwork = n;
25-
let mut work = vec![Self::zero(); size];
26-
$getri(n, &mut a, lda, &mut ipiv, &mut work, lwork, &mut info);
28+
let info = $getri(layout, n, &mut a, lda, &mut ipiv);
2729
if info == 0 {
2830
Ok(a)
2931
} else {
3032
Err(From::from(info))
3133
}
3234
}
35+
fn lu(layout: Layout, m: usize, n: usize, mut a: Vec<Self>) -> Result<(Vec<i32>, Vec<Self>), LapackError> {
36+
let m = m as i32;
37+
let n = n as i32;
38+
let k = min(m, n);
39+
let lda = match layout {
40+
Layout::ColumnMajor => m,
41+
Layout::RowMajor => n,
42+
};
43+
let mut ipiv = vec![0; k as usize];
44+
let info = $getrf(layout, m, n, &mut a, lda, &mut ipiv);
45+
if info == 0 {
46+
Ok((ipiv, a))
47+
} else {
48+
Err(From::from(info))
49+
}
50+
}
3351
}
3452
}} // end macro_rules
3553

36-
impl_solve!(f64, dgetrf, dgetri);
37-
impl_solve!(f32, sgetrf, sgetri);
54+
impl_solve!(f64, dgetrf, dgetri, dlaswp);
55+
impl_solve!(f32, sgetrf, sgetri, slaswp);

src/square.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
//! Define trait for Hermite matrices
22
3+
use std::fmt::Debug;
34
use ndarray::prelude::*;
45
use ndarray::LinalgScalar;
56
use num_traits::float::Float;
@@ -17,7 +18,6 @@ use solve::ImplSolve;
1718
/// but does not assure that the matrix is square.
1819
/// If not square, `NotSquareError` will be thrown.
1920
pub trait SquareMatrix: Matrix {
20-
// fn lu(self) -> (Self, Self);
2121
// fn eig(self) -> (Self::Vector, Self);
2222
/// inverse matrix
2323
fn inv(self) -> Result<Self, LinalgError>;
@@ -38,13 +38,13 @@ pub trait SquareMatrix: Matrix {
3838
}
3939

4040
impl<A> SquareMatrix for Array<A, (Ix, Ix)>
41-
where A: ImplQR + ImplNorm + ImplSVD + ImplSolve + LinalgScalar + Float
41+
where A: ImplQR + ImplNorm + ImplSVD + ImplSolve + LinalgScalar + Float + Debug
4242
{
4343
fn inv(self) -> Result<Self, LinalgError> {
4444
try!(self.check_square());
4545
let (n, _) = self.size();
4646
let is_fortran_align = self.strides()[0] > self.strides()[1];
47-
let a = try!(ImplSolve::inv(n, self.into_raw_vec()));
47+
let a = try!(ImplSolve::inv(self.layout(), n, self.into_raw_vec()));
4848
let m = Array::from_vec(a).into_shape((n, n)).unwrap();
4949
if is_fortran_align {
5050
Ok(m)

0 commit comments

Comments
 (0)