2
2
3
3
use std:: cmp:: min;
4
4
use ndarray:: prelude:: * ;
5
+ use ndarray:: DataMut ;
5
6
use lapack:: c:: Layout ;
6
7
7
8
use error:: { LinalgError , StrideError } ;
@@ -43,6 +44,35 @@ pub trait Matrix: Sized {
43
44
}
44
45
}
45
46
47
+ fn check_layout ( strides : & [ Ixs ] ) -> Result < Layout , StrideError > {
48
+ if min ( strides[ 0 ] , strides[ 1 ] ) != 1 {
49
+ return Err ( StrideError {
50
+ s0 : strides[ 0 ] ,
51
+ s1 : strides[ 1 ] ,
52
+ } ) ; ;
53
+ }
54
+ if strides[ 0 ] < strides[ 1 ] {
55
+ Ok ( Layout :: ColumnMajor )
56
+ } else {
57
+ Ok ( Layout :: RowMajor )
58
+ }
59
+ }
60
+
61
+ fn permutate < A : NdFloat , S > ( mut a : & mut ArrayBase < S , Ix2 > , ipiv : & Vec < i32 > )
62
+ where S : DataMut < Elem = A >
63
+ {
64
+ let m = a. cols ( ) ;
65
+ for ( i, j_) in ipiv. iter ( ) . enumerate ( ) . rev ( ) {
66
+ let j = ( j_ - 1 ) as usize ;
67
+ if i == j {
68
+ continue ;
69
+ }
70
+ for k in 0 ..m {
71
+ a. swap ( ( i, k) , ( j, k) ) ;
72
+ }
73
+ }
74
+ }
75
+
46
76
impl < A : MFloat > Matrix for Array < A , Ix2 > {
47
77
type Scalar = A ;
48
78
type Vector = Array < A , Ix1 > ;
@@ -52,18 +82,7 @@ impl<A: MFloat> Matrix for Array<A, Ix2> {
52
82
( self . rows ( ) , self . cols ( ) )
53
83
}
54
84
fn layout ( & self ) -> Result < Layout , StrideError > {
55
- let strides = self . strides ( ) ;
56
- if min ( strides[ 0 ] , strides[ 1 ] ) != 1 {
57
- return Err ( StrideError {
58
- s0 : strides[ 0 ] ,
59
- s1 : strides[ 1 ] ,
60
- } ) ; ;
61
- }
62
- if strides[ 0 ] < strides[ 1 ] {
63
- Ok ( Layout :: ColumnMajor )
64
- } else {
65
- Ok ( Layout :: RowMajor )
66
- }
85
+ check_layout ( self . strides ( ) )
67
86
}
68
87
fn norm_1 ( & self ) -> Self :: Scalar {
69
88
let ( m, n) = self . size ( ) ;
@@ -159,15 +178,45 @@ impl<A: MFloat> Matrix for Array<A, Ix2> {
159
178
Ok ( ( p, lm, am) )
160
179
}
161
180
fn permutate ( & mut self , ipiv : & Self :: Permutator ) {
162
- let ( _, m) = self . size ( ) ;
163
- for ( i, j_) in ipiv. iter ( ) . enumerate ( ) . rev ( ) {
164
- let j = ( j_ - 1 ) as usize ;
165
- if i == j {
166
- continue ;
167
- }
168
- for k in 0 ..m {
169
- self . swap ( ( i, k) , ( j, k) ) ;
170
- }
171
- }
181
+ permutate ( self , ipiv) ;
182
+ }
183
+ }
184
+
185
+ impl < A : MFloat > Matrix for RcArray < A , Ix2 > {
186
+ type Scalar = A ;
187
+ type Vector = RcArray < A , Ix1 > ;
188
+ type Permutator = Vec < i32 > ;
189
+ fn size ( & self ) -> ( usize , usize ) {
190
+ ( self . rows ( ) , self . cols ( ) )
191
+ }
192
+ fn layout ( & self ) -> Result < Layout , StrideError > {
193
+ check_layout ( self . strides ( ) )
194
+ }
195
+ fn norm_1 ( & self ) -> Self :: Scalar {
196
+ // TODO remove clone by into_owned()
197
+ self . to_owned ( ) . norm_1 ( )
198
+ }
199
+ fn norm_i ( & self ) -> Self :: Scalar {
200
+ // TODO remove clone by into_owned()
201
+ self . to_owned ( ) . norm_i ( )
202
+ }
203
+ fn norm_f ( & self ) -> Self :: Scalar {
204
+ // TODO remove clone by into_owned()
205
+ self . to_owned ( ) . norm_f ( )
206
+ }
207
+ fn svd ( self ) -> Result < ( Self , Self :: Vector , Self ) , LinalgError > {
208
+ let ( u, s, v) = self . to_owned ( ) . svd ( ) ?;
209
+ Ok ( ( u. into_shared ( ) , s. into_shared ( ) , v. into_shared ( ) ) )
210
+ }
211
+ fn qr ( self ) -> Result < ( Self , Self ) , LinalgError > {
212
+ let ( q, r) = self . to_owned ( ) . qr ( ) ?;
213
+ Ok ( ( q. into_shared ( ) , r. into_shared ( ) ) )
214
+ }
215
+ fn lu ( self ) -> Result < ( Self :: Permutator , Self , Self ) , LinalgError > {
216
+ let ( p, l, u) = self . to_owned ( ) . lu ( ) ?;
217
+ Ok ( ( p, l. into_shared ( ) , u. into_shared ( ) ) )
218
+ }
219
+ fn permutate ( & mut self , ipiv : & Self :: Permutator ) {
220
+ permutate ( self , ipiv) ;
172
221
}
173
222
}
0 commit comments