Skip to content

Commit d6b2bef

Browse files
committed
Merge branch 'solve_tri'
2 parents fc31024 + 959d5ac commit d6b2bef

File tree

7 files changed

+150
-39
lines changed

7 files changed

+150
-39
lines changed

src/error.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
33
use std::error;
44
use std::fmt;
5-
use ndarray::Ixs;
5+
use ndarray::{Ixs, ShapeError};
66

77
#[derive(Debug)]
88
pub struct LapackError {
@@ -68,6 +68,7 @@ pub enum LinalgError {
6868
NotSquare(NotSquareError),
6969
Lapack(LapackError),
7070
Stride(StrideError),
71+
Shape(ShapeError),
7172
}
7273

7374
impl fmt::Display for LinalgError {
@@ -76,6 +77,7 @@ impl fmt::Display for LinalgError {
7677
LinalgError::NotSquare(ref err) => err.fmt(f),
7778
LinalgError::Lapack(ref err) => err.fmt(f),
7879
LinalgError::Stride(ref err) => err.fmt(f),
80+
LinalgError::Shape(ref err) => err.fmt(f),
7981
}
8082
}
8183
}
@@ -86,6 +88,7 @@ impl error::Error for LinalgError {
8688
LinalgError::NotSquare(ref err) => err.description(),
8789
LinalgError::Lapack(ref err) => err.description(),
8890
LinalgError::Stride(ref err) => err.description(),
91+
LinalgError::Shape(ref err) => err.description(),
8992
}
9093
}
9194
}
@@ -107,3 +110,8 @@ impl From<StrideError> for LinalgError {
107110
LinalgError::Stride(err)
108111
}
109112
}
113+
impl From<ShapeError> for LinalgError {
114+
fn from(err: ShapeError) -> LinalgError {
115+
LinalgError::Shape(err)
116+
}
117+
}

src/lib.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
//!
44
//! They are implemented as traits,
55
//! [Matrix](matrix/trait.Matrix.html), [SquareMatrix](square/trait.SquareMatrix.html),
6-
//! [TriangularMatrix](triangular/trait.TriangularMatrix.html), and
6+
//! [SolveTriangular](triangular/trait.SolveTriangular.html), and
77
//! [HermiteMatrix](hermite/trait.HermiteMatrix.html)
88
//!
99
//! Matrix
@@ -21,10 +21,10 @@
2121
//! - [trace of matrix](square/trait.SquareMatrix.html#tymethod.trace)
2222
//! - [WIP] eigenvalue
2323
//!
24-
//! TriangularMatrix
24+
//! SolveTriangular
2525
//! ------------------
26-
//! - [solve linear problem with upper triangular matrix](triangular/trait.TriangularMatrix.html#tymethod.solve_upper)
27-
//! - [solve linear problem with lower triangular matrix](triangular/trait.TriangularMatrix.html#tymethod.solve_lower)
26+
//! - [solve linear problem with upper triangular matrix](triangular/trait.SolveTriangular.html#tymethod.solve_upper)
27+
//! - [solve linear problem with lower triangular matrix](triangular/trait.SolveTriangular.html#tymethod.solve_lower)
2828
//!
2929
//! HermiteMatrix
3030
//! --------------

src/prelude.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@ pub use vector::Vector;
33
pub use matrix::Matrix;
44
pub use square::SquareMatrix;
55
pub use hermite::HermiteMatrix;
6-
pub use triangular::{TriangularMatrix, drop_lower, drop_upper};
6+
pub use triangular::{SolveTriangular, drop_lower, drop_upper};
77
pub use util::{all_close_l1, all_close_l2, all_close_max};

src/solve.rs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,13 @@ pub trait ImplSolve: Sized {
1818
b: Vec<Self>)
1919
-> Result<Vec<Self>, LapackError>;
2020
/// solve triangular linear problem
21-
fn solve_triangle(layout: Layout,
22-
uplo: u8,
23-
size: usize,
24-
a: &[Self],
25-
b: Vec<Self>)
26-
-> Result<Vec<Self>, LapackError>;
21+
fn solve_triangle<'a, 'b>(layout: Layout,
22+
uplo: u8,
23+
size: usize,
24+
a: &'a [Self],
25+
b: &'b mut [Self],
26+
nrhs: i32)
27+
-> Result<&'b mut [Self], LapackError>;
2728
}
2829

2930
macro_rules! impl_solve {
@@ -65,14 +66,14 @@ impl ImplSolve for $scalar {
6566
Err(From::from(info))
6667
}
6768
}
68-
fn solve_triangle(layout: Layout, uplo: u8, size: usize, a: &[Self], mut b: Vec<Self>) -> Result<Vec<Self>, LapackError> {
69+
fn solve_triangle<'a, 'b>(layout: Layout, uplo: u8, size: usize, a: &'a [Self], mut b: &'b mut [Self], nrhs: i32) -> Result<&'b mut [Self], LapackError> {
6970
let n = size as i32;
7071
let lda = n;
7172
let ldb = match layout {
7273
Layout::ColumnMajor => n,
7374
Layout::RowMajor => 1,
7475
};
75-
let info = $trtrs(layout, uplo, 'N' as u8, 'N' as u8, n, 1, a, lda, &mut b, ldb);
76+
let info = $trtrs(layout, uplo, 'N' as u8, 'N' as u8, n, nrhs, a, lda, &mut b, ldb);
7677
if info == 0 {
7778
Ok(b)
7879
} else {

src/square.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@ pub trait SquareMatrix: Matrix {
3030
})
3131
}
3232
}
33+
/// test matrix is square and return its size
34+
fn square_size(&self) -> Result<usize, NotSquareError> {
35+
self.check_square()?;
36+
let (n, _) = self.size();
37+
Ok(n)
38+
}
3339
}
3440

3541
fn trace<A: MFloat, S>(a: &ArrayBase<S, Ix2>) -> A

src/triangular.rs

Lines changed: 66 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,87 @@
11

2-
use ndarray::{Ix2, Array, RcArray, NdFloat, ArrayBase, DataMut};
2+
use ndarray::*;
3+
use super::matrix::{Matrix, MFloat};
4+
use super::square::SquareMatrix;
5+
use super::error::LinalgError;
6+
use super::solve::ImplSolve;
7+
use super::util::hstack;
38

4-
use matrix::{Matrix, MFloat};
5-
use square::SquareMatrix;
6-
use error::LinalgError;
7-
use solve::ImplSolve;
8-
9-
pub trait TriangularMatrix: Matrix + SquareMatrix {
9+
pub trait SolveTriangular<Rhs>: Matrix + SquareMatrix {
10+
type Output;
1011
/// solve a triangular system with upper triangular matrix
11-
fn solve_upper(&self, Self::Vector) -> Result<Self::Vector, LinalgError>;
12+
fn solve_upper(&self, Rhs) -> Result<Self::Output, LinalgError>;
1213
/// solve a triangular system with lower triangular matrix
13-
fn solve_lower(&self, Self::Vector) -> Result<Self::Vector, LinalgError>;
14+
fn solve_lower(&self, Rhs) -> Result<Self::Output, LinalgError>;
1415
}
1516

16-
impl<A: MFloat> TriangularMatrix for Array<A, Ix2> {
17-
fn solve_upper(&self, b: Self::Vector) -> Result<Self::Vector, LinalgError> {
18-
self.check_square()?;
19-
let (n, _) = self.size();
17+
impl<A, S1, S2> SolveTriangular<ArrayBase<S2, Ix1>> for ArrayBase<S1, Ix2>
18+
where A: MFloat,
19+
S1: Data<Elem = A>,
20+
S2: DataMut<Elem = A>,
21+
ArrayBase<S1, Ix2>: Matrix + SquareMatrix
22+
{
23+
type Output = ArrayBase<S2, Ix1>;
24+
fn solve_upper(&self, mut b: ArrayBase<S2, Ix1>) -> Result<Self::Output, LinalgError> {
25+
let n = self.square_size()?;
2026
let layout = self.layout()?;
2127
let a = self.as_slice_memory_order().unwrap();
22-
let x = ImplSolve::solve_triangle(layout, 'U' as u8, n, a, b.into_raw_vec())?;
23-
Ok(Array::from_vec(x))
28+
ImplSolve::solve_triangle(layout,
29+
'U' as u8,
30+
n,
31+
a,
32+
b.as_slice_memory_order_mut().unwrap(),
33+
1)?;
34+
Ok(b)
2435
}
25-
fn solve_lower(&self, b: Self::Vector) -> Result<Self::Vector, LinalgError> {
26-
self.check_square()?;
27-
let (n, _) = self.size();
36+
fn solve_lower(&self, mut b: ArrayBase<S2, Ix1>) -> Result<Self::Output, LinalgError> {
37+
let n = self.square_size()?;
2838
let layout = self.layout()?;
2939
let a = self.as_slice_memory_order().unwrap();
30-
let x = ImplSolve::solve_triangle(layout, 'L' as u8, n, a, b.into_raw_vec())?;
31-
Ok(Array::from_vec(x))
40+
ImplSolve::solve_triangle(layout,
41+
'L' as u8,
42+
n,
43+
a,
44+
b.as_slice_memory_order_mut().unwrap(),
45+
1)?;
46+
Ok(b)
47+
}
48+
}
49+
50+
impl<'a, S1, S2, A> SolveTriangular<&'a ArrayBase<S2, Ix2>> for ArrayBase<S1, Ix2>
51+
where A: MFloat,
52+
S1: Data<Elem = A>,
53+
S2: Data<Elem = A>,
54+
ArrayBase<S1, Ix2>: Matrix + SquareMatrix
55+
{
56+
type Output = Array<A, Ix2>;
57+
fn solve_upper(&self, bs: &ArrayBase<S2, Ix2>) -> Result<Self::Output, LinalgError> {
58+
let mut xs = Vec::new();
59+
for b in bs.axis_iter(Axis(1)) {
60+
let x = self.solve_upper(b.to_owned())?;
61+
xs.push(x);
62+
}
63+
hstack(&xs).map_err(|e| e.into())
64+
}
65+
fn solve_lower(&self, bs: &ArrayBase<S2, Ix2>) -> Result<Self::Output, LinalgError> {
66+
let mut xs = Vec::new();
67+
for b in bs.axis_iter(Axis(1)) {
68+
let x = self.solve_lower(b.to_owned())?;
69+
xs.push(x);
70+
}
71+
hstack(&xs).map_err(|e| e.into())
3272
}
3373
}
3474

35-
impl<A: MFloat> TriangularMatrix for RcArray<A, Ix2> {
36-
fn solve_upper(&self, b: Self::Vector) -> Result<Self::Vector, LinalgError> {
75+
impl<A: MFloat> SolveTriangular<RcArray<A, Ix2>> for RcArray<A, Ix2> {
76+
type Output = RcArray<A, Ix2>;
77+
fn solve_upper(&self, b: RcArray<A, Ix2>) -> Result<Self::Output, LinalgError> {
3778
// XXX unnecessary clone
38-
let x = self.to_owned().solve_upper(b.to_owned())?;
79+
let x = self.to_owned().solve_upper(&b)?;
3980
Ok(x.into_shared())
4081
}
41-
fn solve_lower(&self, b: Self::Vector) -> Result<Self::Vector, LinalgError> {
82+
fn solve_lower(&self, b: RcArray<A, Ix2>) -> Result<Self::Output, LinalgError> {
4283
// XXX unnecessary clone
43-
let x = self.to_owned().solve_lower(b.to_owned())?;
84+
let x = self.to_owned().solve_lower(&b)?;
4485
Ok(x.into_shared())
4586
}
4687
}

tests/triangular.rs

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,58 @@ mod $modname {
6363

6464
impl_test!(owned, Array<f64, _>::random);
6565
impl_test!(shared, RcArray<f64, _>::random);
66+
67+
macro_rules! impl_test_2d {
68+
($modname:ident, $drop:path, $solve:ident) => {
69+
mod $modname {
70+
use super::random_owned;
71+
use ndarray_linalg::prelude::*;
72+
#[test]
73+
fn solve_tt() {
74+
let a = $drop(random_owned(3, 3, true));
75+
println!("a = \n{:?}", &a);
76+
let b = random_owned(3, 2, true);
77+
println!("b = \n{:?}", &b);
78+
let x = a.$solve(&b).unwrap();
79+
println!("x = \n{:?}", &x);
80+
println!("Ax = \n{:?}", a.dot(&x));
81+
all_close_l2(&a.dot(&x), &b, 1e-7).unwrap();
82+
}
83+
#[test]
84+
fn solve_tf() {
85+
let a = $drop(random_owned(3, 3, true));
86+
println!("a = \n{:?}", &a);
87+
let b = random_owned(3, 2, false);
88+
println!("b = \n{:?}", &b);
89+
let x = a.$solve(&b).unwrap();
90+
println!("x = \n{:?}", &x);
91+
println!("Ax = \n{:?}", a.dot(&x));
92+
all_close_l2(&a.dot(&x), &b, 1e-7).unwrap();
93+
}
94+
#[test]
95+
fn solve_ft() {
96+
let a = $drop(random_owned(3, 3, false));
97+
println!("a = \n{:?}", &a);
98+
let b = random_owned(3, 2, true);
99+
println!("b = \n{:?}", &b);
100+
let x = a.$solve(&b).unwrap();
101+
println!("x = \n{:?}", &x);
102+
println!("Ax = \n{:?}", a.dot(&x));
103+
all_close_l2(&a.dot(&x), &b, 1e-7).unwrap();
104+
}
105+
#[test]
106+
fn solve_ff() {
107+
let a = $drop(random_owned(3, 3, false));
108+
println!("a = \n{:?}", &a);
109+
let b = random_owned(3, 2, false);
110+
println!("b = \n{:?}", &b);
111+
let x = a.$solve(&b).unwrap();
112+
println!("x = \n{:?}", &x);
113+
println!("Ax = \n{:?}", a.dot(&x));
114+
all_close_l2(&a.dot(&x), &b, 1e-7).unwrap();
115+
}
116+
}
117+
}} // impl_test_2d
118+
119+
impl_test_2d!(lower2d, drop_upper, solve_lower);
120+
impl_test_2d!(upper2d, drop_lower, solve_upper);

0 commit comments

Comments
 (0)