Skip to content

Commit b4fbce6

Browse files
committed
Merge branch 'trianglar'
2 parents 86f38a7 + a144a54 commit b4fbce6

File tree

6 files changed

+188
-20
lines changed

6 files changed

+188
-20
lines changed

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ pub mod vector;
3535
pub mod matrix;
3636
pub mod square;
3737
pub mod hermite;
38+
pub mod triangular;
3839

3940
pub mod qr;
4041
pub mod svd;

src/prelude.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ pub use vector::Vector;
33
pub use matrix::Matrix;
44
pub use square::SquareMatrix;
55
pub use hermite::HermiteMatrix;
6+
pub use triangular::TriangularMatrix;

src/solve.rs

Lines changed: 54 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,28 +6,29 @@ use std::cmp::min;
66
use error::LapackError;
77

88
pub trait ImplSolve: Sized {
9-
fn inv(layout: Layout, size: usize, a: Vec<Self>) -> Result<Vec<Self>, LapackError>;
9+
/// execute LU decomposition
1010
fn lu(layout: Layout, m: usize, n: usize, a: Vec<Self>) -> Result<(Vec<i32>, Vec<Self>), LapackError>;
11+
/// calc inverse matrix with LU factorized matrix
12+
fn inv(layout: Layout, size: usize, a: Vec<Self>, ipiv: &Vec<i32>) -> Result<Vec<Self>, LapackError>;
13+
/// solve linear problem with LU factorized matrix
14+
fn solve(layout: Layout,
15+
size: usize,
16+
a: &Vec<Self>,
17+
ipiv: &Vec<i32>,
18+
b: Vec<Self>)
19+
-> Result<Vec<Self>, LapackError>;
20+
/// solve triangular linear problem
21+
fn solve_triangle(layout: Layout,
22+
uplo: u8,
23+
size: usize,
24+
a: &[Self],
25+
b: Vec<Self>)
26+
-> Result<Vec<Self>, LapackError>;
1127
}
1228

1329
macro_rules! impl_solve {
14-
($scalar:ty, $getrf:path, $getri:path, $laswp:path) => {
30+
($scalar:ty, $getrf:path, $getri:path, $getrs:path, $trtrs:path) => {
1531
impl ImplSolve for $scalar {
16-
fn inv(layout: Layout, size: usize, mut a: Vec<Self>) -> Result<Vec<Self>, LapackError> {
17-
let n = size as i32;
18-
let lda = n;
19-
let mut ipiv = vec![0; size];
20-
let info = $getrf(layout, n, n, &mut a, lda, &mut ipiv);
21-
if info != 0 {
22-
return Err(From::from(info));
23-
}
24-
let info = $getri(layout, n, &mut a, lda, &mut ipiv);
25-
if info == 0 {
26-
Ok(a)
27-
} else {
28-
Err(From::from(info))
29-
}
30-
}
3132
fn lu(layout: Layout, m: usize, n: usize, mut a: Vec<Self>) -> Result<(Vec<i32>, Vec<Self>), LapackError> {
3233
let m = m as i32;
3334
let n = n as i32;
@@ -44,8 +45,42 @@ impl ImplSolve for $scalar {
4445
Err(From::from(info))
4546
}
4647
}
48+
fn inv(layout: Layout, size: usize, mut a: Vec<Self>, ipiv: &Vec<i32>) -> Result<Vec<Self>, LapackError> {
49+
let n = size as i32;
50+
let lda = n;
51+
let info = $getri(layout, n, &mut a, lda, &ipiv);
52+
if info == 0 {
53+
Ok(a)
54+
} else {
55+
Err(From::from(info))
56+
}
57+
}
58+
fn solve(layout: Layout, size: usize, a: &Vec<Self>, ipiv: &Vec<i32>, mut b: Vec<Self>) -> Result<Vec<Self>, LapackError> {
59+
let n = size as i32;
60+
let lda = n;
61+
let info = $getrs(layout, 'N' as u8, n, 1, a, lda, ipiv, &mut b, n);
62+
if info == 0 {
63+
Ok(b)
64+
} else {
65+
Err(From::from(info))
66+
}
67+
}
68+
fn solve_triangle(layout: Layout, uplo: u8, size: usize, a: &[Self], mut b: Vec<Self>) -> Result<Vec<Self>, LapackError> {
69+
let n = size as i32;
70+
let lda = n;
71+
let ldb = match layout {
72+
Layout::ColumnMajor => n,
73+
Layout::RowMajor => 1,
74+
};
75+
let info = $trtrs(layout, uplo, 'N' as u8, 'N' as u8, n, 1, a, lda, &mut b, ldb);
76+
if info == 0 {
77+
Ok(b)
78+
} else {
79+
Err(From::from(info))
80+
}
81+
}
4782
}
4883
}} // end macro_rules
4984

50-
impl_solve!(f64, dgetrf, dgetri, dlaswp);
51-
impl_solve!(f32, sgetrf, sgetri, slaswp);
85+
impl_solve!(f64, dgetrf, dgetri, dgetrs, dtrtrs);
86+
impl_solve!(f32, sgetrf, sgetri, sgetrs, strtrs);

src/square.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ impl<A> SquareMatrix for Array<A, Ix2>
4444
self.check_square()?;
4545
let (n, _) = self.size();
4646
let layout = self.layout()?;
47-
let a = ImplSolve::inv(layout, n, self.into_raw_vec())?;
47+
let (ipiv, a) = ImplSolve::lu(layout, n, n, self.into_raw_vec())?;
48+
let a = ImplSolve::inv(layout, n, a, &ipiv)?;
4849
let m = Array::from_vec(a).into_shape((n, n)).unwrap();
4950
match layout {
5051
Layout::RowMajor => Ok(m),

src/triangular.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
2+
use ndarray::{Ix2, Array, LinalgScalar};
3+
use std::fmt::Debug;
4+
use num_traits::float::Float;
5+
6+
use matrix::Matrix;
7+
use square::SquareMatrix;
8+
use error::LinalgError;
9+
use qr::ImplQR;
10+
use svd::ImplSVD;
11+
use norm::ImplNorm;
12+
use solve::ImplSolve;
13+
14+
pub trait TriangularMatrix: Matrix + SquareMatrix {
15+
/// solve a triangular system with upper triangular matrix
16+
fn solve_upper(&self, Self::Vector) -> Result<Self::Vector, LinalgError>;
17+
/// solve a triangular system with lower triangular matrix
18+
fn solve_lower(&self, Self::Vector) -> Result<Self::Vector, LinalgError>;
19+
}
20+
21+
impl<A> TriangularMatrix for Array<A, Ix2>
22+
where A: ImplQR + ImplNorm + ImplSVD + ImplSolve + LinalgScalar + Float + Debug
23+
{
24+
fn solve_upper(&self, b: Self::Vector) -> Result<Self::Vector, LinalgError> {
25+
self.check_square()?;
26+
let (n, _) = self.size();
27+
let layout = self.layout()?;
28+
let a = self.as_slice_memory_order().unwrap();
29+
let x = ImplSolve::solve_triangle(layout, 'U' as u8, n, a, b.into_raw_vec())?;
30+
Ok(Array::from_vec(x))
31+
}
32+
fn solve_lower(&self, b: Self::Vector) -> Result<Self::Vector, LinalgError> {
33+
self.check_square()?;
34+
let (n, _) = self.size();
35+
let layout = self.layout()?;
36+
let a = self.as_slice_memory_order().unwrap();
37+
let x = ImplSolve::solve_triangle(layout, 'L' as u8, n, a, b.into_raw_vec())?;
38+
Ok(Array::from_vec(x))
39+
}
40+
}

tests/triangular.rs

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
2+
extern crate rand;
3+
extern crate ndarray;
4+
extern crate ndarray_rand;
5+
extern crate ndarray_linalg;
6+
7+
use ndarray::prelude::*;
8+
use ndarray_linalg::prelude::*;
9+
use rand::distributions::*;
10+
use ndarray_rand::RandomExt;
11+
12+
fn all_close<D: Dimension>(a: Array<f64, D>, b: Array<f64, D>) {
13+
if !a.all_close(&b, 1.0e-7) {
14+
panic!("\nTwo matrices are not equal:\na = \n{:?}\nb = \n{:?}\n",
15+
a,
16+
b);
17+
}
18+
}
19+
20+
#[test]
21+
fn solve_upper() {
22+
let r_dist = Range::new(0., 1.);
23+
let mut a = Array::<f64, _>::random((3, 3), r_dist);
24+
for ((i, j), val) in a.indexed_iter_mut() {
25+
if i > j {
26+
*val = 0.0;
27+
}
28+
}
29+
println!("a = \n{:?}", &a);
30+
let b = Array::<f64, _>::random(3, r_dist);
31+
println!("b = \n{:?}", &b);
32+
let x = a.solve_upper(b.clone()).unwrap();
33+
println!("x = \n{:?}", &x);
34+
println!("Ax = \n{:?}", a.dot(&x));
35+
all_close(a.dot(&x), b);
36+
}
37+
38+
#[test]
39+
fn solve_upper_t() {
40+
let r_dist = Range::new(0., 1.);
41+
let mut a = Array::<f64, _>::random((3, 3), r_dist).reversed_axes();
42+
for ((i, j), val) in a.indexed_iter_mut() {
43+
if i > j {
44+
*val = 0.0;
45+
}
46+
}
47+
println!("a = \n{:?}", &a);
48+
let b = Array::<f64, _>::random(3, r_dist);
49+
println!("b = \n{:?}", &b);
50+
let x = a.solve_upper(b.clone()).unwrap();
51+
println!("x = \n{:?}", &x);
52+
println!("Ax = \n{:?}", a.dot(&x));
53+
all_close(a.dot(&x), b);
54+
}
55+
56+
#[test]
57+
fn solve_lower() {
58+
let r_dist = Range::new(0., 1.);
59+
let mut a = Array::<f64, _>::random((3, 3), r_dist);
60+
for ((i, j), val) in a.indexed_iter_mut() {
61+
if i < j {
62+
*val = 0.0;
63+
}
64+
}
65+
println!("a = \n{:?}", &a);
66+
let b = Array::<f64, _>::random(3, r_dist);
67+
println!("b = \n{:?}", &b);
68+
let x = a.solve_lower(b.clone()).unwrap();
69+
println!("x = \n{:?}", &x);
70+
println!("Ax = \n{:?}", a.dot(&x));
71+
all_close(a.dot(&x), b);
72+
}
73+
74+
#[test]
75+
fn solve_lower_t() {
76+
let r_dist = Range::new(0., 1.);
77+
let mut a = Array::<f64, _>::random((3, 3), r_dist).reversed_axes();
78+
for ((i, j), val) in a.indexed_iter_mut() {
79+
if i < j {
80+
*val = 0.0;
81+
}
82+
}
83+
println!("a = \n{:?}", &a);
84+
let b = Array::<f64, _>::random(3, r_dist);
85+
println!("b = \n{:?}", &b);
86+
let x = a.solve_lower(b.clone()).unwrap();
87+
println!("x = \n{:?}", &x);
88+
println!("Ax = \n{:?}", a.dot(&x));
89+
all_close(a.dot(&x), b);
90+
}

0 commit comments

Comments
 (0)