Skip to content

Commit 53ccc1e

Browse files
committed
Sync
1 parent 161cb13 commit 53ccc1e

File tree

2 files changed

+51
-6
lines changed

2 files changed

+51
-6
lines changed

src/solve.rs

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,17 @@ pub trait ImplSolve: Sized {
1717
ipiv: &Vec<i32>,
1818
b: Vec<Self>)
1919
-> Result<Vec<Self>, LapackError>;
20+
/// solve triangular linear problem
21+
fn solve_triangle(layout: Layout,
22+
uplo: u8,
23+
size: usize,
24+
a: &Vec<Self>,
25+
b: Vec<Self>)
26+
-> Result<Vec<Self>, LapackError>;
2027
}
2128

2229
macro_rules! impl_solve {
23-
($scalar:ty, $getrf:path, $getri:path, $getrs:path) => {
30+
($scalar:ty, $getrf:path, $getri:path, $getrs:path, $trtrs:path) => {
2431
impl ImplSolve for $scalar {
2532
fn lu(layout: Layout, m: usize, n: usize, mut a: Vec<Self>) -> Result<(Vec<i32>, Vec<Self>), LapackError> {
2633
let m = m as i32;
@@ -51,7 +58,17 @@ impl ImplSolve for $scalar {
5158
fn solve(layout: Layout, size: usize, a: &Vec<Self>, ipiv: &Vec<i32>, mut b: Vec<Self>) -> Result<Vec<Self>, LapackError> {
5259
let n = size as i32;
5360
let lda = n;
54-
let info = $getrs(layout, 'N' as u8, n, 1, a, lda, &ipiv, &mut b, n);
61+
let info = $getrs(layout, 'N' as u8, n, 1, a, lda, ipiv, &mut b, n);
62+
if info == 0 {
63+
Ok(b)
64+
} else {
65+
Err(From::from(info))
66+
}
67+
}
68+
fn solve_triangle(layout: Layout, uplo: u8, size: usize, a: &Vec<Self>, mut b: Vec<Self>) -> Result<Vec<Self>, LapackError> {
69+
let n = size as i32;
70+
let lda = n;
71+
let info = $trtrs(layout, uplo, 'N' as u8, 'N' as u8, n, 1, a, lda, &mut b, n);
5572
if info == 0 {
5673
Ok(b)
5774
} else {
@@ -61,5 +78,5 @@ impl ImplSolve for $scalar {
6178
}
6279
}} // end macro_rules
6380

64-
impl_solve!(f64, dgetrf, dgetri, dgetrs);
65-
impl_solve!(f32, sgetrf, sgetri, sgetrs);
81+
impl_solve!(f64, dgetrf, dgetri, dgetrs, dtrtrs);
82+
impl_solve!(f32, sgetrf, sgetri, sgetrs, strtrs);

src/triangular.rs

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,36 @@
11

2+
use ndarray::{Ix2, Array, LinalgScalar};
3+
use std::fmt::Debug;
4+
use num_traits::float::Float;
5+
26
use matrix::Matrix;
37
use square::SquareMatrix;
8+
use error::LinalgError;
9+
use qr::ImplQR;
10+
use svd::ImplSVD;
11+
use norm::ImplNorm;
12+
use solve::ImplSolve;
413

514
pub trait TriangularMatrix: Matrix + SquareMatrix {
6-
fn solve_upper(&self, Self::Vector) -> Self::Vector;
7-
fn solve_lower(&self, Self::Vector) -> Self::Vector;
15+
/// solve a triangular system with upper triangular matrix
16+
fn solve_upper(&self, Self::Vector) -> Result<Self::Vector, LinalgError>;
17+
/// solve a triangular system with lower triangular matrix
18+
fn solve_lower(&self, Self::Vector) -> Result<Self::Vector, LinalgError>;
19+
}
20+
21+
impl<A> TriangularMatrix for Array<A, Ix2>
22+
where A: ImplQR + ImplNorm + ImplSVD + ImplSolve + LinalgScalar + Float + Debug
23+
{
24+
fn solve_upper(&self, b: Self::Vector) -> Result<Self::Vector, LinalgError> {
25+
self.check_square()?;
26+
let (n, _) = self.size();
27+
let layout = self.layout()?;
28+
let x = ImplSolve::solve_triangle(layout, 'U' as u8, n, self.as_slice().unwrap(), b)?;
29+
}
30+
fn solve_lower(&self, b: Self::Vector) -> Result<Self::Vector, LinalgError> {
31+
self.check_square()?;
32+
let (n, _) = self.size();
33+
let layout = self.layout()?;
34+
let x = ImplSolve::solve_triangle(layout, 'U' as u8, n, self.as_slice().unwrap(), b)?;
35+
}
836
}

0 commit comments

Comments
 (0)