4
4
from sklearn .utils import check_array
5
5
from skglm .solvers .common import construct_grad , construct_grad_sparse , dist_fix_point
6
6
7
+ from skglm .utils import AndersonAcceleration
8
+
7
9
8
10
def cd_solver_path (X , y , datafit , penalty , alphas = None ,
9
11
coef_init = None , max_iter = 20 , max_epochs = 50_000 ,
10
- p0 = 10 , tol = 1e-4 , use_acc = True , return_n_iter = False ,
12
+ p0 = 10 , tol = 1e-4 , return_n_iter = False ,
11
13
ws_strategy = "subdiff" , verbose = 0 ):
12
14
r"""Compute optimization path with Anderson accelerated coordinate descent.
13
15
@@ -47,9 +49,6 @@ def cd_solver_path(X, y, datafit, penalty, alphas=None,
47
49
tol : float, optional
48
50
The tolerance for the optimization.
49
51
50
- use_acc : bool, optional
51
- Usage of Anderson acceleration for faster convergence.
52
-
53
52
return_n_iter : bool, optional
54
53
If True, number of iterations along the path are returned.
55
54
@@ -148,7 +147,7 @@ def cd_solver_path(X, y, datafit, penalty, alphas=None,
148
147
sol = cd_solver (
149
148
X , y , datafit , penalty , w , Xw ,
150
149
max_iter = max_iter , max_epochs = max_epochs , p0 = p0 , tol = tol ,
151
- use_acc = use_acc , verbose = verbose , ws_strategy = ws_strategy )
150
+ verbose = verbose , ws_strategy = ws_strategy )
152
151
153
152
coefs [:, t ] = w
154
153
stop_crits [t ] = sol [- 1 ]
@@ -165,7 +164,7 @@ def cd_solver_path(X, y, datafit, penalty, alphas=None,
165
164
166
165
def cd_solver (
167
166
X , y , datafit , penalty , w , Xw , max_iter = 50 , max_epochs = 50_000 , p0 = 10 ,
168
- tol = 1e-4 , use_acc = True , K = 5 , ws_strategy = "subdiff" , verbose = 0 ):
167
+ tol = 1e-4 , ws_strategy = "subdiff" , verbose = 0 ):
169
168
r"""Run a coordinate descent solver.
170
169
171
170
Parameters
@@ -201,12 +200,6 @@ def cd_solver(
201
200
tol : float, optional
202
201
The tolerance for the optimization.
203
202
204
- use_acc : bool, optional
205
- Usage of Anderson acceleration for faster convergence.
206
-
207
- K : int, optional
208
- The number of past primal iterates used to build an extrapolated point.
209
-
210
203
ws_strategy : ('subdiff'|'fixpoint'), optional
211
204
The score used to build the working set.
212
205
@@ -226,13 +219,14 @@ def cd_solver(
226
219
"""
227
220
if ws_strategy not in ("subdiff" , "fixpoint" ):
228
221
raise ValueError (f'Unsupported value for ws_strategy: { ws_strategy } ' )
229
- n_features = X .shape [ 1 ]
222
+ n_samples , n_features = X .shape
230
223
pen = penalty .is_penalized (n_features )
231
224
unpen = ~ pen
232
225
n_unpen = unpen .sum ()
233
226
obj_out = []
234
227
all_feats = np .arange (n_features )
235
228
stop_crit = np .inf # initialize for case n_iter=0
229
+ w_acc , Xw_acc = np .zeros (n_features ), np .zeros (n_samples )
236
230
237
231
is_sparse = sparse .issparse (X )
238
232
for t in range (max_iter ):
@@ -259,14 +253,12 @@ def cd_solver(
259
253
opt [unpen ] = np .inf # always include unpenalized features
260
254
opt [penalty .generalized_support (w )] = np .inf
261
255
262
- # here use topk instead of sorting the full array
263
- # ie the following line
256
+ # select the ws_size largest scores with argpartition (top-k) instead of fully sorting with np.argsort(opt)[-ws_size:]
264
257
ws = np .argpartition (opt , - ws_size )[- ws_size :]
265
- # is equivalent to ws = np.argsort(opt)[-ws_size:]
266
258
267
- if use_acc :
268
- last_K_w = np . zeros ([ K + 1 , ws_size ] )
269
- U = np . zeros ([ K , ws_size ])
259
+ # re-initialize Anderson acceleration at every outer iteration to account for the new working set
260
+ accelerator = AndersonAcceleration ( K = 5 )
261
+ w_acc [:] = 0.
270
262
271
263
if verbose :
272
264
print (f'Iteration { t + 1 } , { ws_size } feats in subpb.' )
@@ -283,45 +275,18 @@ def cd_solver(
283
275
284
276
# 3) do Anderson acceleration on smaller problem
285
277
# TODO optimize computation using ws
286
- if use_acc :
287
- last_K_w [epoch % (K + 1 )] = w [ws ]
288
-
289
- if epoch % (K + 1 ) == K :
290
- for k in range (K ):
291
- U [k ] = last_K_w [k + 1 ] - last_K_w [k ]
292
- C = np .dot (U , U .T )
293
-
294
- try :
295
- z = np .linalg .solve (C , np .ones (K ))
296
- # When C is ill-conditioned, z can take very large finite
297
- # positive and negative values (1e35 and -1e35), which leads
298
- # to z.sum() being null.
299
- if z .sum () == 0 :
300
- raise np .linalg .LinAlgError
301
- except np .linalg .LinAlgError :
302
- if max (verbose - 1 , 0 ):
303
- print ("----------Linalg error" )
304
- else :
305
- c = z / z .sum ()
306
- w_acc = np .zeros (n_features )
307
- w_acc [ws ] = np .sum (
308
- last_K_w [:- 1 ] * c [:, None ], axis = 0 )
309
- # TODO create a p_obj function ?
310
- # TODO : managed penalty.value(w[ws])
311
- p_obj = datafit .value (y , w , Xw ) + penalty .value (w )
312
- # p_obj = datafit.value(y, w, Xw) +penalty.value(w[ws])
313
- Xw_acc = X [:, ws ] @ w_acc [ws ]
314
- # TODO : managed penalty.value(w[ws])
315
- p_obj_acc = datafit .value (
316
- y , w_acc , Xw_acc ) + penalty .value (w_acc )
317
- if p_obj_acc < p_obj :
318
- w [:] = w_acc
319
- Xw [:] = Xw_acc
278
+ w_acc [ws ], Xw_acc [:], is_extrapolated = accelerator .extrapolate (w [ws ], Xw )
320
279
321
- if epoch % 10 == 0 :
280
+ if is_extrapolated : # avoid computing p_obj for un-extrapolated w, Xw
322
281
# TODO : manage penalty.value(w, ws) for weighted Lasso
323
- p_obj = datafit .value (y , w [ws ], Xw ) + penalty .value (w )
282
+ p_obj = datafit .value (y , w , Xw ) + penalty .value (w )
283
+ p_obj_acc = datafit .value (y , w_acc , Xw_acc ) + penalty .value (w_acc )
324
284
285
+ if p_obj_acc < p_obj :
286
+ w [:], Xw [:] = w_acc , Xw_acc
287
+ p_obj = p_obj_acc
288
+
289
+ if epoch % 10 == 0 :
325
290
if is_sparse :
326
291
grad_ws = construct_grad_sparse (
327
292
X .data , X .indptr , X .indices , y , w , Xw , datafit , ws )
@@ -334,6 +299,7 @@ def cd_solver(
334
299
335
300
stop_crit_in = np .max (opt_ws )
336
301
if max (verbose - 1 , 0 ):
302
+ p_obj = datafit .value (y , w , Xw ) + penalty .value (w )
337
303
print (f"Epoch { epoch + 1 } , objective { p_obj :.10f} , "
338
304
f"stopping crit { stop_crit_in :.2e} " )
339
305
if ws_size == n_features :
@@ -344,6 +310,7 @@ def cd_solver(
344
310
if max (verbose - 1 , 0 ):
345
311
print ("Early exit" )
346
312
break
313
+ p_obj = datafit .value (y , w , Xw ) + penalty .value (w )
347
314
obj_out .append (p_obj )
348
315
return w , np .array (obj_out ), stop_crit
349
316
0 commit comments