|
2 | 2 |
|
3 | 3 | use ndarray::*;
|
4 | 4 | use num_traits::Zero;
|
5 |
| -use super::impl2::UPLO; |
6 | 5 |
|
7 |
| -use super::matrix::{Matrix, MFloat}; |
8 |
| -use super::square::SquareMatrix; |
9 |
| -use super::error::LinalgError; |
10 |
| -use super::util::hstack; |
11 |
| -use super::impls::solve::ImplSolve; |
| 6 | +use super::layout::*; |
| 7 | +use super::error::*; |
| 8 | +use super::impl2::*; |
12 | 9 |
|
13 |
| -pub trait SolveTriangular<Rhs>: Matrix + SquareMatrix { |
| 10 | +/// solve a triangular system with upper triangular matrix |
| 11 | +pub trait SolveTriangular<Rhs> { |
14 | 12 | type Output;
|
15 |
| - /// solve a triangular system with upper triangular matrix |
16 |
| - fn solve_upper(&self, Rhs) -> Result<Self::Output, LinalgError>; |
17 |
| - /// solve a triangular system with lower triangular matrix |
18 |
| - fn solve_lower(&self, Rhs) -> Result<Self::Output, LinalgError>; |
| 13 | + fn solve_triangular(&self, UPLO, Diag, Rhs) -> Result<Self::Output>; |
19 | 14 | }
|
20 | 15 |
|
21 |
| -impl<A, S1, S2> SolveTriangular<ArrayBase<S2, Ix1>> for ArrayBase<S1, Ix2> |
22 |
| - where A: MFloat, |
23 |
| - S1: Data<Elem = A>, |
24 |
| - S2: DataMut<Elem = A>, |
25 |
| - ArrayBase<S1, Ix2>: Matrix + SquareMatrix |
| 16 | +impl<A, S, V> SolveTriangular<V> for ArrayBase<S, Ix2> |
| 17 | + where A: LapackScalar, |
| 18 | + S: Data<Elem = A>, |
| 19 | + V: AllocatedArrayMut<Elem = A> |
26 | 20 | {
|
27 |
| - type Output = ArrayBase<S2, Ix1>; |
28 |
| - fn solve_upper(&self, mut b: ArrayBase<S2, Ix1>) -> Result<Self::Output, LinalgError> { |
29 |
| - let n = self.square_size()?; |
30 |
| - let layout = self.layout()?; |
31 |
| - let a = self.as_slice_memory_order().unwrap(); |
32 |
| - ImplSolve::solve_triangle(layout, |
33 |
| - 'U' as u8, |
34 |
| - n, |
35 |
| - a, |
36 |
| - b.as_slice_memory_order_mut().unwrap(), |
37 |
| - 1)?; |
38 |
| - Ok(b) |
39 |
| - } |
40 |
| - fn solve_lower(&self, mut b: ArrayBase<S2, Ix1>) -> Result<Self::Output, LinalgError> { |
41 |
| - let n = self.square_size()?; |
42 |
| - let layout = self.layout()?; |
43 |
| - let a = self.as_slice_memory_order().unwrap(); |
44 |
| - ImplSolve::solve_triangle(layout, |
45 |
| - 'L' as u8, |
46 |
| - n, |
47 |
| - a, |
48 |
| - b.as_slice_memory_order_mut().unwrap(), |
49 |
| - 1)?; |
50 |
| - Ok(b) |
51 |
| - } |
52 |
| -} |
53 |
| - |
54 |
| -impl<'a, S1, S2, A> SolveTriangular<&'a ArrayBase<S2, Ix2>> for ArrayBase<S1, Ix2> |
55 |
| - where A: MFloat, |
56 |
| - S1: Data<Elem = A>, |
57 |
| - S2: Data<Elem = A>, |
58 |
| - ArrayBase<S1, Ix2>: Matrix + SquareMatrix |
59 |
| -{ |
60 |
| - type Output = Array<A, Ix2>; |
61 |
| - fn solve_upper(&self, bs: &ArrayBase<S2, Ix2>) -> Result<Self::Output, LinalgError> { |
62 |
| - let mut xs = Vec::new(); |
63 |
| - for b in bs.axis_iter(Axis(1)) { |
64 |
| - let x = self.solve_upper(b.to_owned())?; |
65 |
| - xs.push(x); |
66 |
| - } |
67 |
| - hstack(&xs).map_err(|e| e.into()) |
68 |
| - } |
69 |
| - fn solve_lower(&self, bs: &ArrayBase<S2, Ix2>) -> Result<Self::Output, LinalgError> { |
70 |
| - let mut xs = Vec::new(); |
71 |
| - for b in bs.axis_iter(Axis(1)) { |
72 |
| - let x = self.solve_lower(b.to_owned())?; |
73 |
| - xs.push(x); |
74 |
| - } |
75 |
| - hstack(&xs).map_err(|e| e.into()) |
76 |
| - } |
77 |
| -} |
| 21 | + type Output = V; |
78 | 22 |
|
79 |
| -impl<A: MFloat> SolveTriangular<RcArray<A, Ix2>> for RcArray<A, Ix2> { |
80 |
| - type Output = RcArray<A, Ix2>; |
81 |
| - fn solve_upper(&self, b: RcArray<A, Ix2>) -> Result<Self::Output, LinalgError> { |
82 |
| - // XXX unnecessary clone |
83 |
| - let x = self.to_owned().solve_upper(&b)?; |
84 |
| - Ok(x.into_shared()) |
85 |
| - } |
86 |
| - fn solve_lower(&self, b: RcArray<A, Ix2>) -> Result<Self::Output, LinalgError> { |
87 |
| - // XXX unnecessary clone |
88 |
| - let x = self.to_owned().solve_lower(&b)?; |
89 |
| - Ok(x.into_shared()) |
| 23 | + fn solve_triangular(&self, uplo: UPLO, diag: Diag, mut b: V) -> Result<Self::Output> { |
| 24 | + let la = self.layout()?; |
| 25 | + let lb = b.layout()?; |
| 26 | + let a_ = self.as_allocated()?; |
| 27 | + A::solve_triangular(la, lb, uplo, diag, a_, b.as_allocated_mut()?)?; |
| 28 | + Ok(b) |
90 | 29 | }
|
91 | 30 | }
|
92 | 31 |
|
93 |
| -pub fn drop_upper<A: Zero, S>(mut a: ArrayBase<S, Ix2>) -> ArrayBase<S, Ix2> |
| 32 | +pub fn drop_upper<A: Zero, S>(a: ArrayBase<S, Ix2>) -> ArrayBase<S, Ix2> |
94 | 33 | where S: DataMut<Elem = A>
|
95 | 34 | {
|
96 |
| - for ((i, j), val) in a.indexed_iter_mut() { |
97 |
| - if i < j { |
98 |
| - *val = A::zero(); |
99 |
| - } |
100 |
| - } |
101 |
| - a |
| 35 | + a.into_triangular(UPLO::Lower) |
102 | 36 | }
|
103 | 37 |
|
104 |
| -pub fn drop_lower<A: Zero, S>(mut a: ArrayBase<S, Ix2>) -> ArrayBase<S, Ix2> |
| 38 | +pub fn drop_lower<A: Zero, S>(a: ArrayBase<S, Ix2>) -> ArrayBase<S, Ix2> |
105 | 39 | where S: DataMut<Elem = A>
|
106 | 40 | {
|
107 |
| - for ((i, j), val) in a.indexed_iter_mut() { |
108 |
| - if i > j { |
109 |
| - *val = A::zero(); |
110 |
| - } |
111 |
| - } |
112 |
| - a |
| 41 | + a.into_triangular(UPLO::Upper) |
113 | 42 | }
|
114 | 43 |
|
115 | 44 | pub trait IntoTriangular<T> {
|
|
0 commit comments