Skip to content

Commit 964e987

Browse files
committed
Rewrite TriangularMatrix trait
1 parent d087fb7 commit 964e987

File tree

1 file changed

+25
-14
lines changed

1 file changed

+25
-14
lines changed

src/triangular.rs

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,57 @@
11

2-
use ndarray::{Ix2, Array, RcArray, NdFloat, ArrayBase, DataMut};
2+
use ndarray::{Data, Ix1, Ix2, Array, RcArray, NdFloat, ArrayBase, DataMut};
33

44
use matrix::{Matrix, MFloat};
55
use square::SquareMatrix;
66
use error::LinalgError;
77
use solve::ImplSolve;
88

9-
pub trait TriangularMatrix: Matrix + SquareMatrix {
9+
pub trait TriangularMatrix<Rhs>: Matrix + SquareMatrix {
10+
type Output;
1011
/// solve a triangular system with upper triangular matrix
11-
fn solve_upper(&self, Self::Vector) -> Result<Self::Vector, LinalgError>;
12+
fn solve_upper(&self, &Rhs) -> Result<Self::Output, LinalgError>;
1213
/// solve a triangular system with lower triangular matrix
13-
fn solve_lower(&self, Self::Vector) -> Result<Self::Vector, LinalgError>;
14+
fn solve_lower(&self, &Rhs) -> Result<Self::Output, LinalgError>;
1415
}
1516

16-
impl<A: MFloat> TriangularMatrix for Array<A, Ix2> {
17-
fn solve_upper(&self, b: Self::Vector) -> Result<Self::Vector, LinalgError> {
17+
impl<A, S> TriangularMatrix<ArrayBase<S, Ix1>> for Array<A, Ix2>
18+
where A: MFloat,
19+
S: Data<Elem = A>
20+
{
21+
type Output = Array<A, Ix1>;
22+
23+
fn solve_upper(&self, b: &ArrayBase<S, Ix1>) -> Result<Self::Output, LinalgError> {
1824
self.check_square()?;
1925
let (n, _) = self.size();
2026
let layout = self.layout()?;
2127
let a = self.as_slice_memory_order().unwrap();
22-
let x = ImplSolve::solve_triangle(layout, 'U' as u8, n, a, b.into_raw_vec(), 1)?;
28+
let x = ImplSolve::solve_triangle(layout, 'U' as u8, n, a, b.to_owned().into_raw_vec(), 1)?;
2329
Ok(Array::from_vec(x))
2430
}
25-
fn solve_lower(&self, b: Self::Vector) -> Result<Self::Vector, LinalgError> {
31+
fn solve_lower(&self, b: &ArrayBase<S, Ix1>) -> Result<Self::Output, LinalgError> {
2632
self.check_square()?;
2733
let (n, _) = self.size();
2834
let layout = self.layout()?;
2935
let a = self.as_slice_memory_order().unwrap();
30-
let x = ImplSolve::solve_triangle(layout, 'L' as u8, n, a, b.into_raw_vec(), 1)?;
36+
let x = ImplSolve::solve_triangle(layout, 'L' as u8, n, a, b.to_owned().into_raw_vec(), 1)?;
3137
Ok(Array::from_vec(x))
3238
}
3339
}
3440

35-
impl<A: MFloat> TriangularMatrix for RcArray<A, Ix2> {
36-
fn solve_upper(&self, b: Self::Vector) -> Result<Self::Vector, LinalgError> {
41+
impl<A, S> TriangularMatrix<ArrayBase<S, Ix1>> for RcArray<A, Ix2>
42+
where A: MFloat,
43+
S: Data<Elem = A>
44+
{
45+
type Output = RcArray<A, Ix1>;
46+
47+
fn solve_upper(&self, b: &ArrayBase<S, Ix1>) -> Result<Self::Output, LinalgError> {
3748
// XXX unnecessary clone
38-
let x = self.to_owned().solve_upper(b.to_owned())?;
49+
let x = self.to_owned().solve_upper(&b)?;
3950
Ok(x.into_shared())
4051
}
41-
fn solve_lower(&self, b: Self::Vector) -> Result<Self::Vector, LinalgError> {
52+
fn solve_lower(&self, b: &ArrayBase<S, Ix1>) -> Result<Self::Output, LinalgError> {
4253
// XXX unnecessary clone
43-
let x = self.to_owned().solve_lower(b.to_owned())?;
54+
let x = self.to_owned().solve_lower(&b)?;
4455
Ok(x.into_shared())
4556
}
4657
}

0 commit comments

Comments
 (0)