Commit e3efa30

ENH add Anderson acceleration to group_bcd_solver (#29)

1 parent 77a6c0a commit e3efa30

File tree

4 files changed: +112 -29 lines changed

skglm/estimators.py

Lines changed: 1 addition & 1 deletion

@@ -940,7 +940,7 @@ def path(self, X, y, alphas, coef_init=None, return_n_iter=True, **params):
             Target vector relative to X.

         alphas : array
-            Values of regularization strenghts for which solutions are
+            Values of regularization strengths for which solutions are
             computed.

         coef_init : array, shape (n_features,), optional

skglm/solvers/group_bcd_solver.py

Lines changed: 24 additions & 24 deletions

@@ -1,7 +1,7 @@
 import numpy as np
 from numba import njit

-from skglm.utils import check_group_compatible
+from skglm.utils import AndersonAcceleration, check_group_compatible


 def bcd_solver(X, y, datafit, penalty, w_init=None, p0=10,
@@ -65,53 +65,54 @@ def bcd_solver(X, y, datafit, penalty, w_init=None, p0=10,
     all_groups = np.arange(n_groups)
     p_objs_out = np.zeros(max_iter)
     stop_crit = 0.  # prevent ref before assign when max_iter == 0
+    accelerator = AndersonAcceleration(K=5)

     for t in range(max_iter):
-        if t == 0:  # avoid computing grad and opt twice
-            grad = _construct_grad(X, y, w, Xw, datafit, all_groups)
-            opt = penalty.subdiff_distance(w, grad, all_groups)
-            stop_crit = np.max(opt)
+        grad = _construct_grad(X, y, w, Xw, datafit, all_groups)
+        opt = penalty.subdiff_distance(w, grad, all_groups)
+        stop_crit = np.max(opt)
+
+        if verbose:
+            p_obj = datafit.value(y, w, Xw) + penalty.value(w)
+            print(
+                f"Iteration {t+1}: {p_obj:.10f}, "
+                f"stopping crit: {stop_crit:.2e}"
+            )

-            if stop_crit <= tol:
-                break
+        if stop_crit <= tol:
+            break

         gsupp_size = penalty.generalized_support(w).sum()
         ws_size = max(min(p0, n_groups),
                       min(n_groups, 2 * gsupp_size))
         ws = np.argpartition(opt, -ws_size)[-ws_size:]  # k-largest items (no sort)

         for epoch in range(max_epochs):
+            # inplace update of w and Xw
             _bcd_epoch(X, y, w, Xw, datafit, penalty, ws)

+            w_acc, Xw_acc = accelerator.extrapolate(w, Xw)
+            p_obj = datafit.value(y, w, Xw) + penalty.value(w)
+            p_obj_acc = datafit.value(y, w_acc, Xw_acc) + penalty.value(w_acc)
+
+            if p_obj_acc < p_obj:
+                w, Xw = w_acc, Xw_acc
+                p_obj = p_obj_acc
+
+            # check sub-optimality every 10 epochs
             if epoch % 10 == 0:
                 grad_ws = _construct_grad(X, y, w, Xw, datafit, ws)
                 opt_in = penalty.subdiff_distance(w, grad_ws, ws)
                 stop_crit_in = np.max(opt_in)

                 if max(verbose - 1, 0):
-                    p_obj = datafit.value(y, w, Xw) + penalty.value(w)
                     print(
                         f"Epoch {epoch+1}: {p_obj:.10f} "
                         f"obj. variation: {stop_crit_in:.2e}"
                     )

                 if stop_crit_in <= 0.3 * stop_crit:
                     break
-
-        p_obj = datafit.value(y, w, Xw) + penalty.value(w)
-        grad = _construct_grad(X, y, w, Xw, datafit, all_groups)
-        opt = penalty.subdiff_distance(w, grad, all_groups)
-        stop_crit = np.max(opt)
-
-        if verbose:
-            print(
-                f"Iteration {t+1}: {p_obj:.10f}, "
-                f"stopping crit: {stop_crit:.2e}"
-            )
-
-        if stop_crit <= tol:
-            break
-
         p_objs_out[t] = p_obj

     return w, p_objs_out, stop_crit
@@ -137,7 +138,6 @@ def _bcd_epoch(X, y, w, Xw, datafit, penalty, ws):
         for idx, j in enumerate(grp_g_indices):
             if old_w_g[idx] != w[j]:
                 Xw += (w[j] - old_w_g[idx]) * X[:, j]
-    return


 @njit
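The pattern the diff introduces is worth isolating: after each epoch the solver asks the accelerator for an extrapolated pair (w_acc, Xw_acc) and keeps it only when it strictly lowers the objective, so a bad extrapolation can never hurt. Below is a minimal standalone sketch of that guarded pattern (not code from the diff), using plain gradient descent on a least-squares objective as the fixed-point map in place of _bcd_epoch; it assumes a skglm install that includes this commit so AndersonAcceleration is importable.

# Sketch only: guarded Anderson acceleration around a fixed-point map,
# mirroring the accept-if-better logic added to bcd_solver.
import numpy as np
from skglm.utils import AndersonAcceleration

rng = np.random.default_rng(0)
X = rng.standard_normal((50, 10))
y = X @ rng.standard_normal(10)
n_samples = X.shape[0]

def objective(Xw):
    return 0.5 * np.sum((Xw - y) ** 2) / n_samples

lipschitz = np.linalg.norm(X, ord=2) ** 2 / n_samples  # step size 1 / L
accelerator = AndersonAcceleration(K=5)
w = np.zeros(10)
Xw = X @ w

for it in range(200):
    # fixed-point map: one gradient step, standing in for _bcd_epoch
    w = w - X.T @ (Xw - y) / (n_samples * lipschitz)
    Xw = X @ w

    # guarded extrapolation: keep the extrapolated point only if it is better
    w_acc, Xw_acc = accelerator.extrapolate(w, Xw)
    if objective(Xw_acc) < objective(Xw):
        w, Xw = w_acc, Xw_acc

print(f"final objective: {objective(Xw):.2e}")

The guard costs one extra objective evaluation per epoch; since the datafit and penalty values are cheap relative to an epoch, accepting extrapolation only on strict descent is a low-risk way to get the speedup.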

skglm/tests/test_group.py

Lines changed: 43 additions & 4 deletions

@@ -8,7 +8,7 @@
 from skglm.datafits.group import QuadraticGroup
 from skglm.solvers.group_bcd_solver import bcd_solver

-from skglm.utils import grp_converter, make_correlated_data
+from skglm.utils import grp_converter, make_correlated_data, AndersonAcceleration
 from celer import GroupLasso, Lasso


@@ -60,8 +60,7 @@ def test_alpha_max(n_groups, n_features, shuffle):
         alpha=alpha_max, grp_ptr=grp_ptr,
         grp_indices=grp_indices, weights=weights)

-    w = bcd_solver(
-        X, y, quad_group, group_penalty, max_iter=10000, tol=0)[0]
+    w = bcd_solver(X, y, quad_group, group_penalty, tol=1e-12)[0]

     np.testing.assert_allclose(norm(w), 0, atol=1e-14)

@@ -82,7 +81,7 @@ def test_equivalence_lasso():
         alpha=alpha, grp_ptr=grp_ptr,
         grp_indices=grp_indices, weights=weights)

-    w = bcd_solver(X, y, quad_group, group_penalty, max_iter=10000, tol=1e-12)[0]
+    w = bcd_solver(X, y, quad_group, group_penalty, tol=1e-12)[0]

     celer_lasso = Lasso(
         alpha=alpha, fit_intercept=False, tol=1e-12, weights=weights).fit(X, y)
@@ -123,5 +122,45 @@ def test_vs_celer_grouplasso(n_groups, n_features, shuffle):
     np.testing.assert_allclose(model.coef_, w, atol=1e-5)


+def test_anderson_acceleration():
+    # VAR: w = rho * w + 1 with |rho| < 1
+    # converges to w_star = 1 / (1 - rho)
+    max_iter, tol = 1000, 1e-9
+    n_features = 2
+    rho = np.array([0.5, 0.8])
+    w_star = 1 / (1 - rho)
+    X = np.diag([2, 5])
+
+    # with acceleration
+    acc = AndersonAcceleration(K=5)
+    n_iter_acc = 0
+    w = np.ones(n_features)
+    Xw = X @ w
+    for i in range(max_iter):
+        w, Xw = acc.extrapolate(w, Xw)
+        w = rho * w + 1
+        Xw = X @ w
+
+        if norm(w - w_star, ord=np.inf) < tol:
+            n_iter_acc = i
+            break
+
+    # without acceleration
+    n_iter = 0
+    w = np.ones(n_features)
+    for i in range(max_iter):
+        w = rho * w + 1
+
+        if norm(w - w_star, ord=np.inf) < tol:
+            n_iter = i
+            break
+
+    np.testing.assert_allclose(w, w_star)
+    np.testing.assert_allclose(Xw, X @ w_star)
+
+    np.testing.assert_array_equal(n_iter_acc, 13)
+    np.testing.assert_array_equal(n_iter, 99)
+
+
 if __name__ == '__main__':
     pass
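A quick check that the asserted counts are not arbitrary (standard geometric-convergence arithmetic, not stated in the diff): without acceleration the error contracts by a factor max|rho| = 0.8 per update, and the largest initial error is |1 - 1/(1 - 0.8)| = 4, so reaching tol = 1e-9 takes about

    ln(4 / 1e-9) / ln(1 / 0.8) ≈ 99.1  ->  100 updates,

which the zero-based loop counter records as n_iter == 99. The accelerated run stops at n_iter_acc == 13 (14 updates): the iteration map here is exactly affine, the regime where Anderson extrapolation is most effective, so the test cleanly separates the two behaviors.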

skglm/utils.py

Lines changed: 44 additions & 0 deletions

@@ -254,3 +254,47 @@ def check_group_compatible(obj):
             f"'{obj_name}' is not block-separable. "
             f"Missing '{attr}' attribute."
         )
+
+
+class AndersonAcceleration:
+    """Abstraction of Anderson Acceleration.
+
+    Extrapolate the asymptotic VAR ``w`` and ``Xw``
+    based on ``K`` previous iterations.
+
+    Parameters
+    ----------
+    K : int
+        Number of previous iterates to consider for extrapolation.
+    """
+
+    def __init__(self, K):
+        self.K, self.current_iter = K, 0
+        self.arr_w_, self.arr_Xw_ = None, None
+
+    def extrapolate(self, w, Xw):
+        """Return ``w`` and ``Xw`` extrapolated."""
+        if self.arr_w_ is None or self.arr_Xw_ is None:
+            self.arr_w_ = np.zeros((w.shape[0], self.K+1))
+            self.arr_Xw_ = np.zeros((Xw.shape[0], self.K+1))
+
+        if self.current_iter <= self.K:
+            self.arr_w_[:, self.current_iter] = w
+            self.arr_Xw_[:, self.current_iter] = Xw
+            self.current_iter += 1
+            return w, Xw
+
+        U = np.diff(self.arr_w_, axis=1)  # compute residuals
+
+        # compute extrapolation coefs
+        try:
+            inv_UTU_ones = np.linalg.solve(U.T @ U, np.ones(self.K))
+        except np.linalg.LinAlgError:
+            return w, Xw
+        finally:
+            self.current_iter = 0
+
+        # extrapolate
+        C = inv_UTU_ones / np.sum(inv_UTU_ones)
+        # floating point errors may cause w and Xw to disagree
+        return self.arr_w_[:, 1:] @ C, self.arr_Xw_[:, 1:] @ C
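For context, a standard derivation the code implements but does not spell out: with U the matrix of successive differences built by np.diff, columns w_k - w_{k-1} for k = 1..K, the extrapolation coefficients are the minimizer of the residual combination under an affine constraint,

\min_{c \in \mathbb{R}^K,\ \mathbf{1}^\top c = 1} \|U c\|_2^2,
\qquad
c^\star = \frac{(U^\top U)^{-1}\mathbf{1}}{\mathbf{1}^\top (U^\top U)^{-1}\mathbf{1}},
\qquad
w_{\mathrm{acc}} = \sum_{k=1}^{K} c^\star_k\, w_k .

Solving (U^T U) z = 1 with np.linalg.solve and normalizing z gives exactly c*, and the returned pair is the same convex-like combination applied to the stored w's and Xw's; since X @ w_acc is never recomputed, the two can drift apart only by floating point error, which is what the final comment warns about. The try/finally also guarantees the iterate buffer restarts after every extrapolation attempt, whether or not U^T U was singular.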
