11use crate :: { generate:: * , inner:: * , norm:: Norm , types:: * } ;
22use ndarray:: * ;
3+ use num_traits:: Zero ;
34
45#[ derive( Debug , Clone ) ]
56pub struct MGS < A > {
@@ -8,6 +9,11 @@ pub struct MGS<A> {
89 r : Vec < Array1 < A > > ,
910}
1011
12+ pub enum Dependence < A : Scalar > {
13+ Orthogonal ( A :: Real ) ,
14+ Coefficient ( Array1 < A > ) ,
15+ }
16+
1117impl < A : Scalar + Lapack > MGS < A > {
1218 pub fn new ( dim : usize ) -> Self {
1319 Self {
@@ -25,7 +31,16 @@ impl<A: Scalar + Lapack> MGS<A> {
2531 self . q . len ( )
2632 }
2733
34+ /// Add new vector, return residual norm
2835 pub fn append < S > ( & mut self , a : ArrayBase < S , Ix1 > ) -> A :: Real
36+ where
37+ S : Data < Elem = A > ,
38+ {
39+ self . append_if ( a, A :: Real :: zero ( ) ) . unwrap ( )
40+ }
41+
42+ /// Add new vector if the residual is larger than relative tolerance
43+ pub fn append_if < S > ( & mut self , a : ArrayBase < S , Ix1 > , rtol : A :: Real ) -> Option < A :: Real >
2944 where
3045 S : Data < Elem = A > ,
3146 {
@@ -39,11 +54,14 @@ impl<A: Scalar + Lapack> MGS<A> {
3954 coef[ i] = c;
4055 }
4156 let nrm = a. norm_l2 ( ) ;
57+ if nrm < rtol {
58+ return None ;
59+ }
4260 coef[ self . len ( ) ] = A :: from_real ( nrm) ;
4361 self . r . push ( coef) ;
4462 azip ! ( mut a in { * a = * a / A :: from_real( nrm) } ) ;
4563 self . q . push ( a) ;
46- nrm
64+ Some ( nrm)
4765 }
4866
4967 /// Get orthogonal basis as Q matrix
@@ -78,7 +96,7 @@ mod tests {
7896 assert_eq ! ( mgs. len( ) , 0 ) ;
7997 }
8098
81- fn test < A : Scalar + Lapack > ( rtol : A :: Real ) {
99+ fn test_append < A : Scalar + Lapack > ( rtol : A :: Real ) {
82100 let mut mgs: MGS < A > = MGS :: new ( N ) ;
83101 let a: Array2 < A > = crate :: generate:: random ( ( N , 3 ) ) ;
84102 dbg ! ( & a) ;
@@ -100,22 +118,43 @@ mod tests {
100118 }
101119
102120 #[ test]
103- fn test_f32 ( ) {
104- test :: < f32 > ( 1e-5 ) ;
121+ fn append ( ) {
122+ test_append :: < f32 > ( 1e-5 ) ;
123+ test_append :: < c32 > ( 1e-5 ) ;
124+ test_append :: < f64 > ( 1e-9 ) ;
125+ test_append :: < c64 > ( 1e-9 ) ;
105126 }
106127
107- #[ test]
108- fn test_c32 ( ) {
109- test :: < c32 > ( 1e-5 ) ;
110- }
128+ fn test_append_if < A : Scalar + Lapack > ( rtol : A :: Real ) {
129+ let mut mgs: MGS < A > = MGS :: new ( N ) ;
130+ let a: Array2 < A > = crate :: generate:: random ( ( N , 8 ) ) ;
131+ dbg ! ( & a) ;
132+ for col in a. axis_iter ( Axis ( 1 ) ) {
133+ match mgs. append_if ( col, rtol) {
134+ Some ( res) => {
135+ dbg ! ( res) ;
136+ }
137+ None => break ,
138+ }
139+ }
140+ let q = mgs. get_q ( ) ;
141+ dbg ! ( & q) ;
142+ let r = mgs. get_r ( ) ;
143+ dbg ! ( & r) ;
111144
112- #[ test]
113- fn test_f64 ( ) {
114- test :: < f64 > ( 1e-9 ) ;
145+ dbg ! ( q. dot( & r) ) ;
146+ close_l2 ( & q. dot ( & r) , & a. slice ( s ! [ .., 0 ..N ] ) , rtol) . unwrap ( ) ;
147+
148+ let qt: Array2 < _ > = conjugate ( & q) ;
149+ dbg ! ( qt. dot( & q) ) ;
150+ close_l2 ( & qt. dot ( & q) , & Array2 :: eye ( N ) , rtol) . unwrap ( ) ;
115151 }
116152
117153 #[ test]
118- fn test_c64 ( ) {
119- test :: < c64 > ( 1e-9 ) ;
154+ fn append_if ( ) {
155+ test_append_if :: < f32 > ( 1e-5 ) ;
156+ test_append_if :: < c32 > ( 1e-5 ) ;
157+ test_append_if :: < f64 > ( 1e-9 ) ;
158+ test_append_if :: < c64 > ( 1e-9 ) ;
120159 }
121160}
0 commit comments