Skip to content

Commit c3ec314

Browse files
committed
impl AllocatedArrayMut for ArrayBase1 to use triangular
1 parent e996503 commit c3ec314

File tree

4 files changed

+64
-113
lines changed

4 files changed

+64
-113
lines changed

src/impl2/mod.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,14 @@ pub use self::svd::*;
1313
pub use self::solve::*;
1414
pub use self::cholesky::*;
1515
pub use self::eigh::*;
16+
pub use self::triangular::*;
1617

1718
use super::error::*;
1819

19-
pub trait LapackScalar: OperatorNorm_ + QR_ + SVD_ + Solve_ + Cholesky_ + Eigh_ {}
20-
impl<A> LapackScalar for A where A: OperatorNorm_ + QR_ + SVD_ + Solve_ + Cholesky_ + Eigh_ {}
20+
pub trait LapackScalar
21+
: OperatorNorm_ + QR_ + SVD_ + Solve_ + Cholesky_ + Eigh_ + Triangular_ {
22+
}
23+
impl<A> LapackScalar for A where A: OperatorNorm_ + QR_ + SVD_ + Solve_ + Cholesky_ + Eigh_ + Triangular_ {}
2124

2225
pub fn into_result<T>(info: i32, val: T) -> Result<T> {
2326
if info == 0 {

src/impl2/triangular.rs

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,31 +8,25 @@ use super::{UPLO, Transpose, into_result};
88

99
#[derive(Debug, Clone, Copy)]
1010
#[repr(u8)]
11-
pub enum TriangularDiag {
11+
pub enum Diag {
1212
Unit = b'U',
1313
NonUnit = b'N',
1414
}
1515

1616
pub trait Triangular_: Sized {
17-
fn inv_triangular(l: Layout, UPLO, TriangularDiag, a: &mut [Self]) -> Result<()>;
18-
fn solve_triangular(al: Layout, bl: Layout, UPLO, TriangularDiag, a: &[Self], b: &mut [Self]) -> Result<()>;
17+
fn inv_triangular(l: Layout, UPLO, Diag, a: &mut [Self]) -> Result<()>;
18+
fn solve_triangular(al: Layout, bl: Layout, UPLO, Diag, a: &[Self], b: &mut [Self]) -> Result<()>;
1919
}
2020

2121
impl Triangular_ for f64 {
22-
fn inv_triangular(l: Layout, uplo: UPLO, diag: TriangularDiag, a: &mut [Self]) -> Result<()> {
22+
fn inv_triangular(l: Layout, uplo: UPLO, diag: Diag, a: &mut [Self]) -> Result<()> {
2323
let (n, _) = l.size();
2424
let lda = l.lda();
2525
let info = c::dtrtri(l.lapacke_layout(), uplo as u8, diag as u8, n, a, lda);
2626
into_result(info, ())
2727
}
2828

29-
fn solve_triangular(al: Layout,
30-
bl: Layout,
31-
uplo: UPLO,
32-
diag: TriangularDiag,
33-
a: &[Self],
34-
mut b: &mut [Self])
35-
-> Result<()> {
29+
fn solve_triangular(al: Layout, bl: Layout, uplo: UPLO, diag: Diag, a: &[Self], mut b: &mut [Self]) -> Result<()> {
3630
let (n, _) = al.size();
3731
let lda = al.lda();
3832
let nrhs = bl.len();

src/layout.rs

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,20 +53,20 @@ impl Layout {
5353
}
5454

5555
pub trait AllocatedArray {
56-
type Scalar;
56+
type Elem;
5757
fn layout(&self) -> Result<Layout>;
5858
fn square_layout(&self) -> Result<Layout>;
59-
fn as_allocated(&self) -> Result<&[Self::Scalar]>;
59+
fn as_allocated(&self) -> Result<&[Self::Elem]>;
6060
}
6161

6262
pub trait AllocatedArrayMut: AllocatedArray {
63-
fn as_allocated_mut(&mut self) -> Result<&mut [Self::Scalar]>;
63+
fn as_allocated_mut(&mut self) -> Result<&mut [Self::Elem]>;
6464
}
6565

6666
impl<A, S> AllocatedArray for ArrayBase<S, Ix2>
6767
where S: Data<Elem = A>
6868
{
69-
type Scalar = A;
69+
type Elem = A;
7070

7171
fn layout(&self) -> Result<Layout> {
7272
let strides = self.strides();
@@ -91,20 +91,45 @@ impl<A, S> AllocatedArray for ArrayBase<S, Ix2>
9191
}
9292

9393
fn as_allocated(&self) -> Result<&[A]> {
94-
let slice = self.as_slice_memory_order().ok_or(MemoryContError::new())?;
95-
Ok(slice)
94+
Ok(self.as_slice_memory_order().ok_or(MemoryContError::new())?)
9695
}
9796
}
9897

9998
impl<A, S> AllocatedArrayMut for ArrayBase<S, Ix2>
10099
where S: DataMut<Elem = A>
101100
{
102101
fn as_allocated_mut(&mut self) -> Result<&mut [A]> {
103-
let slice = self.as_slice_memory_order_mut().ok_or(MemoryContError::new())?;
104-
Ok(slice)
102+
Ok(self.as_slice_memory_order_mut().ok_or(MemoryContError::new())?)
105103
}
106104
}
107105

106+
impl<A, S> AllocatedArray for ArrayBase<S, Ix1>
107+
where S: Data<Elem = A>
108+
{
109+
type Elem = A;
110+
111+
fn layout(&self) -> Result<Layout> {
112+
Ok(Layout::F((self.len() as i32, 1)))
113+
}
114+
115+
fn square_layout(&self) -> Result<Layout> {
116+
Err(NotSquareError::new(self.len() as i32, 1).into())
117+
}
118+
119+
fn as_allocated(&self) -> Result<&[A]> {
120+
Ok(self.as_slice_memory_order().ok_or(MemoryContError::new())?)
121+
}
122+
}
123+
124+
impl<A, S> AllocatedArrayMut for ArrayBase<S, Ix1>
125+
where S: DataMut<Elem = A>
126+
{
127+
fn as_allocated_mut(&mut self) -> Result<&mut [A]> {
128+
Ok(self.as_slice_memory_order_mut().ok_or(MemoryContError::new())?)
129+
}
130+
}
131+
132+
108133
pub fn reconstruct<A, S>(l: Layout, a: Vec<A>) -> Result<ArrayBase<S, Ix2>>
109134
where S: DataOwned<Elem = A>
110135
{

src/triangular.rs

Lines changed: 21 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -2,114 +2,43 @@
22
33
use ndarray::*;
44
use num_traits::Zero;
5-
use super::impl2::UPLO;
65

7-
use super::matrix::{Matrix, MFloat};
8-
use super::square::SquareMatrix;
9-
use super::error::LinalgError;
10-
use super::util::hstack;
11-
use super::impls::solve::ImplSolve;
6+
use super::layout::*;
7+
use super::error::*;
8+
use super::impl2::*;
129

13-
pub trait SolveTriangular<Rhs>: Matrix + SquareMatrix {
10+
/// solve a triangular system with upper triangular matrix
11+
pub trait SolveTriangular<Rhs> {
1412
type Output;
15-
/// solve a triangular system with upper triangular matrix
16-
fn solve_upper(&self, Rhs) -> Result<Self::Output, LinalgError>;
17-
/// solve a triangular system with lower triangular matrix
18-
fn solve_lower(&self, Rhs) -> Result<Self::Output, LinalgError>;
13+
fn solve_triangular(&self, UPLO, Diag, Rhs) -> Result<Self::Output>;
1914
}
2015

21-
impl<A, S1, S2> SolveTriangular<ArrayBase<S2, Ix1>> for ArrayBase<S1, Ix2>
22-
where A: MFloat,
23-
S1: Data<Elem = A>,
24-
S2: DataMut<Elem = A>,
25-
ArrayBase<S1, Ix2>: Matrix + SquareMatrix
16+
impl<A, S, V> SolveTriangular<V> for ArrayBase<S, Ix2>
17+
where A: LapackScalar,
18+
S: Data<Elem = A>,
19+
V: AllocatedArrayMut<Elem = A>
2620
{
27-
type Output = ArrayBase<S2, Ix1>;
28-
fn solve_upper(&self, mut b: ArrayBase<S2, Ix1>) -> Result<Self::Output, LinalgError> {
29-
let n = self.square_size()?;
30-
let layout = self.layout()?;
31-
let a = self.as_slice_memory_order().unwrap();
32-
ImplSolve::solve_triangle(layout,
33-
'U' as u8,
34-
n,
35-
a,
36-
b.as_slice_memory_order_mut().unwrap(),
37-
1)?;
38-
Ok(b)
39-
}
40-
fn solve_lower(&self, mut b: ArrayBase<S2, Ix1>) -> Result<Self::Output, LinalgError> {
41-
let n = self.square_size()?;
42-
let layout = self.layout()?;
43-
let a = self.as_slice_memory_order().unwrap();
44-
ImplSolve::solve_triangle(layout,
45-
'L' as u8,
46-
n,
47-
a,
48-
b.as_slice_memory_order_mut().unwrap(),
49-
1)?;
50-
Ok(b)
51-
}
52-
}
53-
54-
impl<'a, S1, S2, A> SolveTriangular<&'a ArrayBase<S2, Ix2>> for ArrayBase<S1, Ix2>
55-
where A: MFloat,
56-
S1: Data<Elem = A>,
57-
S2: Data<Elem = A>,
58-
ArrayBase<S1, Ix2>: Matrix + SquareMatrix
59-
{
60-
type Output = Array<A, Ix2>;
61-
fn solve_upper(&self, bs: &ArrayBase<S2, Ix2>) -> Result<Self::Output, LinalgError> {
62-
let mut xs = Vec::new();
63-
for b in bs.axis_iter(Axis(1)) {
64-
let x = self.solve_upper(b.to_owned())?;
65-
xs.push(x);
66-
}
67-
hstack(&xs).map_err(|e| e.into())
68-
}
69-
fn solve_lower(&self, bs: &ArrayBase<S2, Ix2>) -> Result<Self::Output, LinalgError> {
70-
let mut xs = Vec::new();
71-
for b in bs.axis_iter(Axis(1)) {
72-
let x = self.solve_lower(b.to_owned())?;
73-
xs.push(x);
74-
}
75-
hstack(&xs).map_err(|e| e.into())
76-
}
77-
}
21+
type Output = V;
7822

79-
impl<A: MFloat> SolveTriangular<RcArray<A, Ix2>> for RcArray<A, Ix2> {
80-
type Output = RcArray<A, Ix2>;
81-
fn solve_upper(&self, b: RcArray<A, Ix2>) -> Result<Self::Output, LinalgError> {
82-
// XXX unnecessary clone
83-
let x = self.to_owned().solve_upper(&b)?;
84-
Ok(x.into_shared())
85-
}
86-
fn solve_lower(&self, b: RcArray<A, Ix2>) -> Result<Self::Output, LinalgError> {
87-
// XXX unnecessary clone
88-
let x = self.to_owned().solve_lower(&b)?;
89-
Ok(x.into_shared())
23+
fn solve_triangular(&self, uplo: UPLO, diag: Diag, mut b: V) -> Result<Self::Output> {
24+
let la = self.layout()?;
25+
let lb = b.layout()?;
26+
let a_ = self.as_allocated()?;
27+
A::solve_triangular(la, lb, uplo, diag, a_, b.as_allocated_mut()?)?;
28+
Ok(b)
9029
}
9130
}
9231

93-
pub fn drop_upper<A: Zero, S>(mut a: ArrayBase<S, Ix2>) -> ArrayBase<S, Ix2>
32+
pub fn drop_upper<A: Zero, S>(a: ArrayBase<S, Ix2>) -> ArrayBase<S, Ix2>
9433
where S: DataMut<Elem = A>
9534
{
96-
for ((i, j), val) in a.indexed_iter_mut() {
97-
if i < j {
98-
*val = A::zero();
99-
}
100-
}
101-
a
35+
a.into_triangular(UPLO::Lower)
10236
}
10337

104-
pub fn drop_lower<A: Zero, S>(mut a: ArrayBase<S, Ix2>) -> ArrayBase<S, Ix2>
38+
pub fn drop_lower<A: Zero, S>(a: ArrayBase<S, Ix2>) -> ArrayBase<S, Ix2>
10539
where S: DataMut<Elem = A>
10640
{
107-
for ((i, j), val) in a.indexed_iter_mut() {
108-
if i > j {
109-
*val = A::zero();
110-
}
111-
}
112-
a
41+
a.into_triangular(UPLO::Upper)
11342
}
11443

11544
pub trait IntoTriangular<T> {

0 commit comments

Comments
 (0)