Skip to content

Commit 37c492c

Browse files
committed
Re-implement SolveTriangular
1 parent 369e632 commit 37c492c

File tree

3 files changed

+73
-29
lines changed

3 files changed

+73
-29
lines changed

src/error.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
33
use std::error;
44
use std::fmt;
5-
use ndarray::Ixs;
5+
use ndarray::{Ixs, ShapeError};
66

77
#[derive(Debug)]
88
pub struct LapackError {
@@ -68,6 +68,7 @@ pub enum LinalgError {
6868
NotSquare(NotSquareError),
6969
Lapack(LapackError),
7070
Stride(StrideError),
71+
Shape(ShapeError),
7172
}
7273

7374
impl fmt::Display for LinalgError {
@@ -76,6 +77,7 @@ impl fmt::Display for LinalgError {
7677
LinalgError::NotSquare(ref err) => err.fmt(f),
7778
LinalgError::Lapack(ref err) => err.fmt(f),
7879
LinalgError::Stride(ref err) => err.fmt(f),
80+
LinalgError::Shape(ref err) => err.fmt(f),
7981
}
8082
}
8183
}
@@ -86,6 +88,7 @@ impl error::Error for LinalgError {
8688
LinalgError::NotSquare(ref err) => err.description(),
8789
LinalgError::Lapack(ref err) => err.description(),
8890
LinalgError::Stride(ref err) => err.description(),
91+
LinalgError::Shape(ref err) => err.description(),
8992
}
9093
}
9194
}
@@ -107,3 +110,8 @@ impl From<StrideError> for LinalgError {
107110
LinalgError::Stride(err)
108111
}
109112
}
113+
impl From<ShapeError> for LinalgError {
114+
fn from(err: ShapeError) -> LinalgError {
115+
LinalgError::Shape(err)
116+
}
117+
}

src/solve.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@ pub trait ImplSolve: Sized {
1818
b: Vec<Self>)
1919
-> Result<Vec<Self>, LapackError>;
2020
/// solve triangular linear problem
21-
fn solve_triangle(layout: Layout,
22-
uplo: u8,
23-
size: usize,
24-
a: &[Self],
25-
b: Vec<Self>,
26-
nrhs: i32)
27-
-> Result<Vec<Self>, LapackError>;
21+
fn solve_triangle<'a, 'b>(layout: Layout,
22+
uplo: u8,
23+
size: usize,
24+
a: &'a [Self],
25+
b: &'b mut [Self],
26+
nrhs: i32)
27+
-> Result<&'b mut [Self], LapackError>;
2828
}
2929

3030
macro_rules! impl_solve {
@@ -66,7 +66,7 @@ impl ImplSolve for $scalar {
6666
Err(From::from(info))
6767
}
6868
}
69-
fn solve_triangle(layout: Layout, uplo: u8, size: usize, a: &[Self], mut b: Vec<Self>, nrhs: i32) -> Result<Vec<Self>, LapackError> {
69+
fn solve_triangle<'a, 'b>(layout: Layout, uplo: u8, size: usize, a: &'a [Self], mut b: &'b mut [Self], nrhs: i32) -> Result<&'b mut [Self], LapackError> {
7070
let n = size as i32;
7171
let lda = n;
7272
let ldb = match layout {

src/triangular.rs

Lines changed: 56 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11

2-
use ndarray::{Ix1, Ix2, Array, RcArray, NdFloat, ArrayBase, DataMut};
3-
4-
use matrix::{Matrix, MFloat};
5-
use square::SquareMatrix;
6-
use error::LinalgError;
7-
use solve::ImplSolve;
2+
use ndarray::*;
3+
use super::matrix::{Matrix, MFloat};
4+
use super::square::SquareMatrix;
5+
use super::error::LinalgError;
6+
use super::solve::ImplSolve;
7+
use super::util::hstack;
88

99
pub trait SolveTriangular<Rhs>: Matrix + SquareMatrix {
1010
type Output;
@@ -14,34 +14,70 @@ pub trait SolveTriangular<Rhs>: Matrix + SquareMatrix {
1414
fn solve_lower(&self, Rhs) -> Result<Self::Output, LinalgError>;
1515
}
1616

17-
impl<A: MFloat> SolveTriangular<Array<A, Ix1>> for Array<A, Ix2> {
18-
type Output = Array<A, Ix1>;
19-
fn solve_upper(&self, b: Array<A, Ix1>) -> Result<Self::Output, LinalgError> {
17+
impl<A, S1, S2> SolveTriangular<ArrayBase<S2, Ix1>> for ArrayBase<S1, Ix2>
18+
where A: MFloat,
19+
S1: Data<Elem = A>,
20+
S2: DataMut<Elem = A>,
21+
ArrayBase<S1, Ix2>: Matrix + SquareMatrix
22+
{
23+
type Output = ArrayBase<S2, Ix1>;
24+
fn solve_upper(&self, mut b: ArrayBase<S2, Ix1>) -> Result<Self::Output, LinalgError> {
2025
let n = self.square_size()?;
2126
let layout = self.layout()?;
2227
let a = self.as_slice_memory_order().unwrap();
23-
let x = ImplSolve::solve_triangle(layout, 'U' as u8, n, a, b.into_raw_vec(), 1)?;
24-
Ok(Array::from_vec(x))
28+
{
29+
let b_ = b.as_slice_memory_order_mut().unwrap();
30+
ImplSolve::solve_triangle(layout, 'U' as u8, n, a, b_, 1)?;
31+
}
32+
Ok(b)
2533
}
26-
fn solve_lower(&self, b: Array<A, Ix1>) -> Result<Self::Output, LinalgError> {
34+
fn solve_lower(&self, mut b: ArrayBase<S2, Ix1>) -> Result<Self::Output, LinalgError> {
2735
let n = self.square_size()?;
2836
let layout = self.layout()?;
2937
let a = self.as_slice_memory_order().unwrap();
30-
let x = ImplSolve::solve_triangle(layout, 'L' as u8, n, a, b.into_raw_vec(), 1)?;
31-
Ok(Array::from_vec(x))
38+
{
39+
let b_ = b.as_slice_memory_order_mut().unwrap();
40+
ImplSolve::solve_triangle(layout, 'L' as u8, n, a, b_, 1)?;
41+
}
42+
Ok(b)
43+
}
44+
}
45+
46+
impl<'a, S1, S2, A> SolveTriangular<&'a ArrayBase<S2, Ix2>> for ArrayBase<S1, Ix2>
47+
where A: MFloat,
48+
S1: Data<Elem = A>,
49+
S2: Data<Elem = A>,
50+
ArrayBase<S1, Ix2>: Matrix + SquareMatrix
51+
{
52+
type Output = Array<A, Ix2>;
53+
fn solve_upper(&self, bs: &ArrayBase<S2, Ix2>) -> Result<Self::Output, LinalgError> {
54+
let mut xs = Vec::new();
55+
for b in bs.axis_iter(Axis(1)) {
56+
let x = self.solve_upper(b.to_owned())?;
57+
xs.push(x);
58+
}
59+
hstack(&xs).map_err(|e| e.into())
60+
}
61+
fn solve_lower(&self, bs: &ArrayBase<S2, Ix2>) -> Result<Self::Output, LinalgError> {
62+
let mut xs = Vec::new();
63+
for b in bs.axis_iter(Axis(1)) {
64+
let x = self.solve_lower(b.to_owned())?;
65+
xs.push(x);
66+
}
67+
hstack(&xs).map_err(|e| e.into())
3268
}
3369
}
3470

35-
impl<A: MFloat> SolveTriangular<RcArray<A, Ix1>> for RcArray<A, Ix2> {
36-
type Output = RcArray<A, Ix1>;
37-
fn solve_upper(&self, b: RcArray<A, Ix1>) -> Result<Self::Output, LinalgError> {
71+
impl<A: MFloat> SolveTriangular<RcArray<A, Ix2>> for RcArray<A, Ix2> {
72+
type Output = RcArray<A, Ix2>;
73+
fn solve_upper(&self, b: RcArray<A, Ix2>) -> Result<Self::Output, LinalgError> {
3874
// XXX unnecessary clone
39-
let x = self.to_owned().solve_upper(b.into_owned())?;
75+
let x = self.to_owned().solve_upper(&b)?;
4076
Ok(x.into_shared())
4177
}
42-
fn solve_lower(&self, b: RcArray<A, Ix1>) -> Result<Self::Output, LinalgError> {
78+
fn solve_lower(&self, b: RcArray<A, Ix2>) -> Result<Self::Output, LinalgError> {
4379
// XXX unnecessary clone
44-
let x = self.to_owned().solve_lower(b.into_owned())?;
80+
let x = self.to_owned().solve_lower(&b)?;
4581
Ok(x.into_shared())
4682
}
4783
}

0 commit comments

Comments
 (0)