@@ -99,11 +99,22 @@ def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
9999        n_samples , n_features  =  X .shape 
100100
101101        # init steps 
102-         # Despite violating the conditions mentioned in [1] 
103-         # this choice of steps yield in practice a convergent algorithm 
104-         # with better speed of convergence 
105-         dual_step  =  1  /  norm (X , ord = 2 )
106-         primal_steps  =  1  /  norm (X , axis = 0 , ord = 2 )
102+         # choose steps to verify condition: Assumption 2.1 e) 
103+         scale  =  np .sqrt (2  *  n_features )
104+         dual_steps  =  1  /  (norm (X , ord = 2 , axis = 1 ) *  scale )
105+         primal_steps  =  1  /  ((dual_steps [:, None ] *  (X  **  2 )).sum (axis = 0 ) *  scale )
106+ 
107+         # NOTE: primal and dual steps verify condition on steps when multiplied/divided 
108+         # by an arbitrary positive constant 
109+         # HACK: balance primal and dual variable: take bigger steps 
110+         # in the space with highest number of variable 
111+         ratio  =  n_samples  /  n_features 
112+         if  n_samples  >  n_features :
113+             dual_steps  *=  ratio 
114+             primal_steps  /=  ratio 
115+         else :
116+             dual_steps  /=  ratio 
117+             primal_steps  *=  ratio 
107118
108119        # primal vars 
109120        w  =  np .zeros (n_features ) if  w_init  is  None  else  w_init 
@@ -125,7 +136,7 @@ def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
125136
126137            # check convergence using fixed-point criteria on both dual and primal 
127138            opts_primal  =  _scores_primal (X , w , z , penalty , primal_steps , all_features )
128-             opt_dual  =  _score_dual (y , z , Xw , datafit , dual_step )
139+             opt_dual  =  _score_dual (y , z , Xw , datafit , dual_steps )
129140
130141            stop_crit  =  max (max (opts_primal ), opt_dual )
131142
@@ -148,13 +159,9 @@ def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
148159
149160            # solve sub problem 
150161            # inplace update of w, Xw, z, z_bar 
151-             if  iteration  ==  0 :
152-                 ep  =  500 
153-             else :
154-                 ep  =  self .max_epochs 
155162            PDCD_WS ._solve_subproblem (
156163                y , X , w , Xw , z , z_bar , datafit , penalty ,
157-                 primal_steps , dual_step , ws , ep , tol_in = 0.3 * stop_crit , verbose = self .verbose - 1 )
164+                 primal_steps , dual_steps , ws , self . max_epochs , tol_in = 0.3 * stop_crit , verbose = self .verbose - 1 )
158165
159166            current_p_obj  =  datafit .value (y , w , Xw ) +  penalty .value (w )
160167            p_objs .append (current_p_obj )
@@ -172,7 +179,7 @@ def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
172179    @njit  
173180    def  _solve_subproblem (
174181            y , X , w , Xw , z , z_bar , datafit , penalty , primal_steps ,
175-             dual_step , ws , max_epochs , tol_in , verbose ):
182+             dual_steps , ws , max_epochs , tol_in , verbose ):
176183        n_features  =  X .shape [1 ]
177184
178185        for  epoch  in  range (max_epochs ):
@@ -191,20 +198,26 @@ def _solve_subproblem(
191198                    Xw  +=  delta_w_j  *  X [:, j ]
192199
193200                # update dual 
194-                 z_bar [:] =  datafit .prox_conjugate (z  +  dual_step  *  Xw ,
195-                                                   dual_step , y )
201+                 z_bar [:] =  datafit .prox_conjugate (z  +  dual_steps  *  Xw ,
202+                                                   dual_steps , y )
196203                z  +=  (z_bar  -  z ) /  n_features 
197204
198205            # check convergence using fixed-point criteria on both dual and primal 
199206            if  epoch  %  1  ==  0 :
200207                opts_primal_in  =  _scores_primal (X , w , z , penalty , primal_steps , ws )
201-                 opt_dual_in  =  _score_dual (y , z , Xw , datafit , dual_step )
208+                 opt_dual_in  =  _score_dual (y , z , Xw , datafit , dual_steps )
202209
203210                stop_crit_in  =  max (max (opts_primal_in ), opt_dual_in )
204-                 if  verbose :
205-                     print (f'  epoch { epoch }  , inner stopping crit: ' , stop_crit_in )
206-                     print (opt_dual_in )
207-                     print (opts_primal_in )
211+                 # if verbose: 
212+                 #     current_p_obj = datafit.value(y, w, X@w) + penalty.value(w) 
213+                 #     print( 
214+                 #         f"|----- epoch {epoch+1}: {current_p_obj:.10f}, " 
215+                 #         f"opt primal: {max(opts_primal_in):.2e}, opt dual: {opt_dual_in:.2e}") 
216+ 
217+                 # print(f'  epoch {epoch}, inner stopping crit: ', stop_crit_in) 
218+                 # # print(opt_dual_in) 
219+                 # # print(opts_primal_in) 
220+ 
208221                if  stop_crit_in  <=  tol_in :
209222                    break 
210223
@@ -228,7 +241,7 @@ def _scores_primal(X, w, z, penalty, primal_steps, ws):
228241
229242
230243@njit  
231- def  _score_dual (y , z , Xw , datafit , dual_step ):
232-     next_z  =  datafit .prox_conjugate (z  +  dual_step  *  Xw ,
233-                                     dual_step , y )
244+ def  _score_dual (y , z , Xw , datafit , dual_steps ):
245+     next_z  =  datafit .prox_conjugate (z  +  dual_steps  *  Xw ,
246+                                     dual_steps , y )
234247    return  norm (z  -  next_z , ord = np .inf )
0 commit comments