Skip to content

Commit 92e1266

Browse files
authored
MAINT - use AndersonAcceleration class in cd_solver (#33)
1 parent fe3bedd commit 92e1266

File tree

1 file changed

+22
-55
lines changed

1 file changed

+22
-55
lines changed

skglm/solvers/cd_solver.py

Lines changed: 22 additions & 55 deletions
Original file line number | Diff line number | Diff line change
@@ -4,10 +4,12 @@
44
from sklearn.utils import check_array
55
from skglm.solvers.common import construct_grad, construct_grad_sparse, dist_fix_point
66

7+
from skglm.utils import AndersonAcceleration
8+
79

810
def cd_solver_path(X, y, datafit, penalty, alphas=None,
911
coef_init=None, max_iter=20, max_epochs=50_000,
10-
p0=10, tol=1e-4, use_acc=True, return_n_iter=False,
12+
p0=10, tol=1e-4, return_n_iter=False,
1113
ws_strategy="subdiff", verbose=0):
1214
r"""Compute optimization path with Anderson accelerated coordinate descent.
1315
@@ -47,9 +49,6 @@ def cd_solver_path(X, y, datafit, penalty, alphas=None,
4749
tol : float, optional
4850
The tolerance for the optimization.
4951
50-
use_acc : bool, optional
51-
Usage of Anderson acceleration for faster convergence.
52-
5352
return_n_iter : bool, optional
5453
If True, number of iterations along the path are returned.
5554
@@ -148,7 +147,7 @@ def cd_solver_path(X, y, datafit, penalty, alphas=None,
148147
sol = cd_solver(
149148
X, y, datafit, penalty, w, Xw,
150149
max_iter=max_iter, max_epochs=max_epochs, p0=p0, tol=tol,
151-
use_acc=use_acc, verbose=verbose, ws_strategy=ws_strategy)
150+
verbose=verbose, ws_strategy=ws_strategy)
152151

153152
coefs[:, t] = w
154153
stop_crits[t] = sol[-1]
@@ -165,7 +164,7 @@ def cd_solver_path(X, y, datafit, penalty, alphas=None,
165164

166165
def cd_solver(
167166
X, y, datafit, penalty, w, Xw, max_iter=50, max_epochs=50_000, p0=10,
168-
tol=1e-4, use_acc=True, K=5, ws_strategy="subdiff", verbose=0):
167+
tol=1e-4, ws_strategy="subdiff", verbose=0):
169168
r"""Run a coordinate descent solver.
170169
171170
Parameters
@@ -201,12 +200,6 @@ def cd_solver(
201200
tol : float, optional
202201
The tolerance for the optimization.
203202
204-
use_acc : bool, optional
205-
Usage of Anderson acceleration for faster convergence.
206-
207-
K : int, optional
208-
The number of past primal iterates used to build an extrapolated point.
209-
210203
ws_strategy : ('subdiff'|'fixpoint'), optional
211204
The score used to build the working set.
212205
@@ -226,13 +219,14 @@ def cd_solver(
226219
"""
227220
if ws_strategy not in ("subdiff", "fixpoint"):
228221
raise ValueError(f'Unsupported value for ws_strategy: {ws_strategy}')
229-
n_features = X.shape[1]
222+
n_samples, n_features = X.shape
230223
pen = penalty.is_penalized(n_features)
231224
unpen = ~pen
232225
n_unpen = unpen.sum()
233226
obj_out = []
234227
all_feats = np.arange(n_features)
235228
stop_crit = np.inf # initialize for case n_iter=0
229+
w_acc, Xw_acc = np.zeros(n_features), np.zeros(n_samples)
236230

237231
is_sparse = sparse.issparse(X)
238232
for t in range(max_iter):
@@ -259,14 +253,12 @@ def cd_solver(
259253
opt[unpen] = np.inf # always include unpenalized features
260254
opt[penalty.generalized_support(w)] = np.inf
261255

262-
# here use topk instead of sorting the full array
263-
# ie the following line
256+
# select the top-k scores with argpartition instead of a full sort, i.e. np.argsort(opt)[-ws_size:]
264257
ws = np.argpartition(opt, -ws_size)[-ws_size:]
265-
# is equivalent to ws = np.argsort(opt)[-ws_size:]
266258

267-
if use_acc:
268-
last_K_w = np.zeros([K + 1, ws_size])
269-
U = np.zeros([K, ws_size])
259+
# re-initialize Anderson acceleration at every iteration to account for the new working set (ws)
260+
accelerator = AndersonAcceleration(K=5)
261+
w_acc[:] = 0.
270262

271263
if verbose:
272264
print(f'Iteration {t + 1}, {ws_size} feats in subpb.')
@@ -283,45 +275,18 @@ def cd_solver(
283275

284276
# 3) do Anderson acceleration on smaller problem
285277
# TODO optimize computation using ws
286-
if use_acc:
287-
last_K_w[epoch % (K + 1)] = w[ws]
288-
289-
if epoch % (K + 1) == K:
290-
for k in range(K):
291-
U[k] = last_K_w[k + 1] - last_K_w[k]
292-
C = np.dot(U, U.T)
293-
294-
try:
295-
z = np.linalg.solve(C, np.ones(K))
296-
# When C is ill-conditioned, z can take very large finite
297-
# positive and negative values (1e35 and -1e35), which leads
298-
# to z.sum() being null.
299-
if z.sum() == 0:
300-
raise np.linalg.LinAlgError
301-
except np.linalg.LinAlgError:
302-
if max(verbose - 1, 0):
303-
print("----------Linalg error")
304-
else:
305-
c = z / z.sum()
306-
w_acc = np.zeros(n_features)
307-
w_acc[ws] = np.sum(
308-
last_K_w[:-1] * c[:, None], axis=0)
309-
# TODO create a p_obj function ?
310-
# TODO : managed penalty.value(w[ws])
311-
p_obj = datafit.value(y, w, Xw) + penalty.value(w)
312-
# p_obj = datafit.value(y, w, Xw) +penalty.value(w[ws])
313-
Xw_acc = X[:, ws] @ w_acc[ws]
314-
# TODO : managed penalty.value(w[ws])
315-
p_obj_acc = datafit.value(
316-
y, w_acc, Xw_acc) + penalty.value(w_acc)
317-
if p_obj_acc < p_obj:
318-
w[:] = w_acc
319-
Xw[:] = Xw_acc
278+
w_acc[ws], Xw_acc[:], is_extrapolated = accelerator.extrapolate(w[ws], Xw)
320279

321-
if epoch % 10 == 0:
280+
if is_extrapolated: # avoid computing p_obj for un-extrapolated w, Xw
322281
# TODO : manage penalty.value(w, ws) for weighted Lasso
323-
p_obj = datafit.value(y, w[ws], Xw) + penalty.value(w)
282+
p_obj = datafit.value(y, w, Xw) + penalty.value(w)
283+
p_obj_acc = datafit.value(y, w_acc, Xw_acc) + penalty.value(w_acc)
324284

285+
if p_obj_acc < p_obj:
286+
w[:], Xw[:] = w_acc, Xw_acc
287+
p_obj = p_obj_acc
288+
289+
if epoch % 10 == 0:
325290
if is_sparse:
326291
grad_ws = construct_grad_sparse(
327292
X.data, X.indptr, X.indices, y, w, Xw, datafit, ws)
@@ -334,6 +299,7 @@ def cd_solver(
334299

335300
stop_crit_in = np.max(opt_ws)
336301
if max(verbose - 1, 0):
302+
p_obj = datafit.value(y, w, Xw) + penalty.value(w)
337303
print(f"Epoch {epoch + 1}, objective {p_obj:.10f}, "
338304
f"stopping crit {stop_crit_in:.2e}")
339305
if ws_size == n_features:
@@ -344,6 +310,7 @@ def cd_solver(
344310
if max(verbose - 1, 0):
345311
print("Early exit")
346312
break
313+
p_obj = datafit.value(y, w, Xw) + penalty.value(w)
347314
obj_out.append(p_obj)
348315
return w, np.array(obj_out), stop_crit
349316

0 commit comments

Comments
 (0)