1
1
2
- use ndarray:: { Ix1 , Ix2 , Array , RcArray , NdFloat , ArrayBase , DataMut } ;
3
-
4
- use matrix :: { Matrix , MFloat } ;
5
- use square :: SquareMatrix ;
6
- use error :: LinalgError ;
7
- use solve :: ImplSolve ;
2
+ use ndarray:: * ;
3
+ use super :: matrix :: { Matrix , MFloat } ;
4
+ use super :: square :: SquareMatrix ;
5
+ use super :: error :: LinalgError ;
6
+ use super :: solve :: ImplSolve ;
7
+ use super :: util :: hstack ;
8
8
9
9
pub trait SolveTriangular < Rhs > : Matrix + SquareMatrix {
10
10
type Output ;
@@ -14,34 +14,70 @@ pub trait SolveTriangular<Rhs>: Matrix + SquareMatrix {
14
14
fn solve_lower ( & self , Rhs ) -> Result < Self :: Output , LinalgError > ;
15
15
}
16
16
17
- impl < A : MFloat > SolveTriangular < Array < A , Ix1 > > for Array < A , Ix2 > {
18
- type Output = Array < A , Ix1 > ;
19
- fn solve_upper ( & self , b : Array < A , Ix1 > ) -> Result < Self :: Output , LinalgError > {
17
+ impl < A , S1 , S2 > SolveTriangular < ArrayBase < S2 , Ix1 > > for ArrayBase < S1 , Ix2 >
18
+ where A : MFloat ,
19
+ S1 : Data < Elem = A > ,
20
+ S2 : DataMut < Elem = A > ,
21
+ ArrayBase < S1 , Ix2 > : Matrix + SquareMatrix
22
+ {
23
+ type Output = ArrayBase < S2 , Ix1 > ;
24
+ fn solve_upper ( & self , mut b : ArrayBase < S2 , Ix1 > ) -> Result < Self :: Output , LinalgError > {
20
25
let n = self . square_size ( ) ?;
21
26
let layout = self . layout ( ) ?;
22
27
let a = self . as_slice_memory_order ( ) . unwrap ( ) ;
23
- let x = ImplSolve :: solve_triangle ( layout, 'U' as u8 , n, a, b. into_raw_vec ( ) , 1 ) ?;
24
- Ok ( Array :: from_vec ( x) )
28
+ {
29
+ let b_ = b. as_slice_memory_order_mut ( ) . unwrap ( ) ;
30
+ ImplSolve :: solve_triangle ( layout, 'U' as u8 , n, a, b_, 1 ) ?;
31
+ }
32
+ Ok ( b)
25
33
}
26
- fn solve_lower ( & self , b : Array < A , Ix1 > ) -> Result < Self :: Output , LinalgError > {
34
+ fn solve_lower ( & self , mut b : ArrayBase < S2 , Ix1 > ) -> Result < Self :: Output , LinalgError > {
27
35
let n = self . square_size ( ) ?;
28
36
let layout = self . layout ( ) ?;
29
37
let a = self . as_slice_memory_order ( ) . unwrap ( ) ;
30
- let x = ImplSolve :: solve_triangle ( layout, 'L' as u8 , n, a, b. into_raw_vec ( ) , 1 ) ?;
31
- Ok ( Array :: from_vec ( x) )
38
+ {
39
+ let b_ = b. as_slice_memory_order_mut ( ) . unwrap ( ) ;
40
+ ImplSolve :: solve_triangle ( layout, 'L' as u8 , n, a, b_, 1 ) ?;
41
+ }
42
+ Ok ( b)
43
+ }
44
+ }
45
+
46
+ impl < ' a , S1 , S2 , A > SolveTriangular < & ' a ArrayBase < S2 , Ix2 > > for ArrayBase < S1 , Ix2 >
47
+ where A : MFloat ,
48
+ S1 : Data < Elem = A > ,
49
+ S2 : Data < Elem = A > ,
50
+ ArrayBase < S1 , Ix2 > : Matrix + SquareMatrix
51
+ {
52
+ type Output = Array < A , Ix2 > ;
53
+ fn solve_upper ( & self , bs : & ArrayBase < S2 , Ix2 > ) -> Result < Self :: Output , LinalgError > {
54
+ let mut xs = Vec :: new ( ) ;
55
+ for b in bs. axis_iter ( Axis ( 1 ) ) {
56
+ let x = self . solve_upper ( b. to_owned ( ) ) ?;
57
+ xs. push ( x) ;
58
+ }
59
+ hstack ( & xs) . map_err ( |e| e. into ( ) )
60
+ }
61
+ fn solve_lower ( & self , bs : & ArrayBase < S2 , Ix2 > ) -> Result < Self :: Output , LinalgError > {
62
+ let mut xs = Vec :: new ( ) ;
63
+ for b in bs. axis_iter ( Axis ( 1 ) ) {
64
+ let x = self . solve_lower ( b. to_owned ( ) ) ?;
65
+ xs. push ( x) ;
66
+ }
67
+ hstack ( & xs) . map_err ( |e| e. into ( ) )
32
68
}
33
69
}
34
70
35
- impl < A : MFloat > SolveTriangular < RcArray < A , Ix1 > > for RcArray < A , Ix2 > {
36
- type Output = RcArray < A , Ix1 > ;
37
- fn solve_upper ( & self , b : RcArray < A , Ix1 > ) -> Result < Self :: Output , LinalgError > {
71
+ impl < A : MFloat > SolveTriangular < RcArray < A , Ix2 > > for RcArray < A , Ix2 > {
72
+ type Output = RcArray < A , Ix2 > ;
73
+ fn solve_upper ( & self , b : RcArray < A , Ix2 > ) -> Result < Self :: Output , LinalgError > {
38
74
// XXX unnecessary clone
39
- let x = self . to_owned ( ) . solve_upper ( b . into_owned ( ) ) ?;
75
+ let x = self . to_owned ( ) . solve_upper ( & b ) ?;
40
76
Ok ( x. into_shared ( ) )
41
77
}
42
- fn solve_lower ( & self , b : RcArray < A , Ix1 > ) -> Result < Self :: Output , LinalgError > {
78
+ fn solve_lower ( & self , b : RcArray < A , Ix2 > ) -> Result < Self :: Output , LinalgError > {
43
79
// XXX unnecessary clone
44
- let x = self . to_owned ( ) . solve_lower ( b . into_owned ( ) ) ?;
80
+ let x = self . to_owned ( ) . solve_lower ( & b ) ?;
45
81
Ok ( x. into_shared ( ) )
46
82
}
47
83
}
0 commit comments