
Commit 34da16f

ENH - Add modular Group Prox Newton solver (#103)
1 parent 2cef37c commit 34da16f

4 files changed: +379 -10 lines changed

doc/api.rst

Lines changed: 1 addition & 0 deletions

@@ -74,6 +74,7 @@ Solvers
    FISTA
    GramCD
    GroupBCD
+   GroupProxNewton
    MultiTaskBCD
    ProxNewton

skglm/solvers/__init__.py

Lines changed: 3 additions & 1 deletion

@@ -5,6 +5,8 @@
 from .group_bcd import GroupBCD
 from .multitask_bcd import MultiTaskBCD
 from .prox_newton import ProxNewton
+from .group_prox_newton import GroupProxNewton


-__all__ = [AndersonCD, BaseSolver, FISTA, GramCD, GroupBCD, MultiTaskBCD, ProxNewton]
+__all__ = [AndersonCD, BaseSolver, FISTA, GramCD, GroupBCD, MultiTaskBCD, ProxNewton,
+           GroupProxNewton]
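After this change, the new solver is importable from the package namespace like the existing ones:

    from skglm.solvers import GroupProxNewton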

skglm/solvers/group_prox_newton.py

Lines changed: 359 additions & 0 deletions

@@ -0,0 +1,359 @@
import numpy as np
from numba import njit
from numpy.linalg import norm
from skglm.solvers.base import BaseSolver
from skglm.utils import check_group_compatible

EPS_TOL = 0.3
MAX_CD_ITER = 20
MAX_BACKTRACK_ITER = 20

class GroupProxNewton(BaseSolver):
    """Group Prox Newton solver combined with working sets.

    Parameters
    ----------
    p0 : int, default 10
        Minimum number of features to be included in the working set.

    max_iter : int, default 20
        Maximum number of outer iterations.

    max_pn_iter : int, default 1000
        Maximum number of prox Newton iterations on each subproblem.

    tol : float, default 1e-4
        Tolerance for convergence.

    fit_intercept : bool, default False
        Whether or not to fit an intercept.

    warm_start : bool, default False
        Whether to reuse the solution of a previous call as initialization.

    verbose : bool, default False
        Amount of verbosity. 0/False is silent.

    References
    ----------
    .. [1] Massias, M. and Vaiter, S. and Gramfort, A. and Salmon, J.
        "Dual Extrapolation for Sparse Generalized Linear Models", JMLR, 2020,
        https://arxiv.org/abs/1907.05830
        code: https://github.com/mathurinm/celer

    .. [2] Johnson, T. B. and Guestrin, C.
        "Blitz: A principled meta-algorithm for scaling sparse optimization",
        ICML, 2015.
        https://proceedings.mlr.press/v37/johnson15.html
        code: https://github.com/tbjohns/BlitzL1
    """

    def __init__(self, p0=10, max_iter=20, max_pn_iter=1000, tol=1e-4,
                 fit_intercept=False, warm_start=False, verbose=0):
        self.p0 = p0
        self.max_iter = max_iter
        self.max_pn_iter = max_pn_iter
        self.tol = tol
        self.fit_intercept = fit_intercept
        self.warm_start = warm_start
        self.verbose = verbose

    def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
        check_group_compatible(datafit)
        check_group_compatible(penalty)

        fit_intercept = self.fit_intercept
        n_samples, n_features = X.shape
        grp_ptr, grp_indices = penalty.grp_ptr, penalty.grp_indices
        n_groups = len(grp_ptr) - 1

        w = np.zeros(n_features + fit_intercept) if w_init is None else w_init
        Xw = np.zeros(n_samples) if Xw_init is None else Xw_init
        all_groups = np.arange(n_groups)
        stop_crit = 0.
        p_objs_out = []

        for iter in range(self.max_iter):
            grad = _construct_grad(X, y, w, Xw, datafit, all_groups)

            # check convergence
            opt = penalty.subdiff_distance(w, grad, all_groups)
            stop_crit = np.max(opt)

            # optimality of intercept
            if fit_intercept:
                # gradient w.r.t. intercept (constant feature of ones)
                intercept_opt = np.abs(np.sum(datafit.raw_grad(y, Xw)))
            else:
                intercept_opt = 0.

            stop_crit = max(stop_crit, intercept_opt)

            if self.verbose:
                p_obj = datafit.value(y, w, Xw) + penalty.value(w[:n_features])
                print(
                    f"Iteration {iter+1}: {p_obj:.10f}, "
                    f"stopping crit: {stop_crit:.2e}"
                )

            if stop_crit <= self.tol:
                break

            # build working set ws
            gsupp_size = penalty.generalized_support(w).sum()
            ws_size = max(min(self.p0, n_groups),
                          min(n_groups, 2 * gsupp_size))
            ws = np.argpartition(opt, -ws_size)[-ws_size:]  # k-largest items (no sort)

            grad_ws = _slice_array(grad, ws, grp_ptr, grp_indices)
            tol_in = EPS_TOL * stop_crit

            # solve subproblem restricted to ws
            for pn_iter in range(self.max_pn_iter):
                # find descent direction
                delta_w_ws, X_delta_w_ws = _descent_direction(
                    X, y, w, Xw, fit_intercept, grad_ws, datafit, penalty,
                    ws, tol=EPS_TOL*tol_in)

                # find a suitable step size and update w, Xw in place
                grad_ws[:] = _backtrack_line_search(
                    X, y, w, Xw, fit_intercept, datafit, penalty,
                    delta_w_ws, X_delta_w_ws, ws)

                # check convergence
                opt_in = penalty.subdiff_distance(w, grad_ws, ws)
                stop_crit_in = np.max(opt_in)

                # optimality of intercept
                if fit_intercept:
                    # gradient w.r.t. intercept (constant feature of ones)
                    intercept_opt_in = np.abs(np.sum(datafit.raw_grad(y, Xw)))
                else:
                    intercept_opt_in = 0.

                stop_crit_in = max(stop_crit_in, intercept_opt_in)

                if max(self.verbose-1, 0):
                    p_obj = datafit.value(y, w, Xw) + penalty.value(w[:n_features])
                    print(
                        f"PN iteration {pn_iter+1}: {p_obj:.10f}, "
                        f"stopping crit in: {stop_crit_in:.2e}"
                    )

                if stop_crit_in <= tol_in:
                    if max(self.verbose-1, 0):
                        print("Early exit")
                    break

            p_obj = datafit.value(y, w, Xw) + penalty.value(w[:n_features])
            p_objs_out.append(p_obj)
        return w, np.asarray(p_objs_out), stop_crit

@njit
def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit,
                       penalty, ws, tol):
    # given:
    #   1) b = \nabla F(X w_epoch)
    #   2) D = \nabla^2 F(X w_epoch)  <------> raw_hess
    # minimize the quadratic approximation for delta_w = w - w_epoch:
    #   b.T @ X @ delta_w + 1/2 * delta_w.T @ (X.T @ D @ X) @ delta_w + penalty(w)
    # In BCD, we leverage the inequality (with H_g = X_g.T @ D @ X_g):
    #   penalty_g(w_g) + 1/2 * ||delta_w_g||^2_{H_g} <=
    #       penalty_g(w_g) + 1/2 * ||H_g||_2 * ||delta_w_g||^2
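    # concretely, each group g below takes proximal gradient steps on this
    # quadratic model with stepsize 1 / L_g, where L_g = ||X_g.T @ D @ X_g||_2
    # is the group-wise Lipschitz constant stored in ``lipschitz``:
    #   w_g <- prox_{stepsize * penalty_g}(w_g - stepsize * model_grad_g)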
    grp_ptr, grp_indices = penalty.grp_ptr, penalty.grp_indices
    n_features_ws = sum([penalty.grp_ptr[g+1] - penalty.grp_ptr[g] for g in ws])
    raw_hess = datafit.raw_hessian(y, Xw_epoch)

    lipschitz = np.zeros(len(ws))
    for idx, g in enumerate(ws):
        grp_g_indices = grp_indices[grp_ptr[g]:grp_ptr[g+1]]
        # compute efficiently (few multiplications and avoid copying the cols of X)
        # norm(X[:, grp_g_indices].T @ np.diag(raw_hess) @ X[:, grp_g_indices], ord=2)
        lipschitz[idx] = norm(_diag_times_X_g(
            np.sqrt(raw_hess), X, grp_g_indices), ord=2)**2

    if fit_intercept:
        lipschitz_intercept = np.sum(raw_hess)
        grad_intercept = np.sum(datafit.raw_grad(y, Xw_epoch))

    # for a less costly stopping criterion, we do not compute the exact gradient,
    # but store each coordinate-wise gradient every time we update a coordinate:
    past_grads = np.zeros(n_features_ws)
    X_delta_w_ws = np.zeros(X.shape[0])
    w_ws = _slice_array(w_epoch, ws, grp_ptr, grp_indices, fit_intercept)

    for cd_iter in range(MAX_CD_ITER):
        ptr = 0
        for idx, g in enumerate(ws):
            # skip when X[:, grp_g_indices] == 0
            if lipschitz[idx] == 0.:
                continue

            grp_g_indices = grp_indices[grp_ptr[g]:grp_ptr[g+1]]
            range_grp_g = slice(ptr, ptr + len(grp_g_indices))

            past_grads[range_grp_g] = grad_ws[range_grp_g]
            # += X[:, grp_g_indices].T @ (raw_hess * X_delta_w_ws)
            past_grads[range_grp_g] += _X_g_T_dot_vec(
                X, raw_hess * X_delta_w_ws, grp_g_indices)

            old_w_ws_g = w_ws[range_grp_g].copy()
            stepsize = 1 / lipschitz[idx]

            w_ws[range_grp_g] = penalty.prox_1group(
                old_w_ws_g - stepsize * past_grads[range_grp_g], stepsize, g)

            # update X_delta_w_ws without copying the cols of X
            # X_delta_w_ws += X[:, grp_g_indices] @ (w_ws[range_grp_g] - old_w_ws_g)
            _update_X_delta_w_ws(X, X_delta_w_ws, w_ws[range_grp_g], old_w_ws_g,
                                 grp_g_indices)

            ptr += len(grp_g_indices)

        # intercept update
        if fit_intercept:
            past_grads_intercept = grad_intercept + raw_hess @ X_delta_w_ws
            old_intercept = w_ws[-1]
            w_ws[-1] -= past_grads_intercept / lipschitz_intercept

            if w_ws[-1] != old_intercept:
                X_delta_w_ws += w_ws[-1] - old_intercept

        if cd_iter % 5 == 0:
            # TODO: can be improved by passing in w_ws
            current_w = w_epoch.copy()

            # for g in ws: current_w[ws_g] = w_ws_g
            ptr = 0
            for g in ws:
                grp_g_indices = grp_indices[grp_ptr[g]:grp_ptr[g+1]]
                current_w[grp_g_indices] = w_ws[ptr:ptr+len(grp_g_indices)]
                ptr += len(grp_g_indices)

            opt = penalty.subdiff_distance(current_w, past_grads, ws)
            stop_crit = np.max(opt)
            if fit_intercept:
                stop_crit = max(stop_crit, np.abs(past_grads_intercept))

            if stop_crit <= tol:
                break

    # descent direction
    delta_w_ws = w_ws - _slice_array(w_epoch, ws, grp_ptr, grp_indices, fit_intercept)
    return delta_w_ws, X_delta_w_ws

@njit
def _backtrack_line_search(X, y, w, Xw, fit_intercept, datafit, penalty, delta_w_ws,
                           X_delta_w_ws, ws):
    # 1) find a step in [0, 1] such that:
    #       penalty(w + step * delta_w) - penalty(w) +
    #       step * \nabla datafit(w + step * delta_w) @ delta_w < 0
    #    ref: https://www.di.ens.fr/~aspremon/PDF/ENSAE/Newton.pdf
    # 2) update w and Xw in place and return grad_ws of the last w and Xw
    grp_ptr, grp_indices = penalty.grp_ptr, penalty.grp_indices
    step, prev_step = 1., 0.
    n_features = X.shape[1]
    n_features_ws = sum([grp_ptr[g+1] - grp_ptr[g] for g in ws])

    # TODO: could be improved by passing in w[ws]
    old_penalty_val = penalty.value(w[:n_features])

    # try step = 1, 1/2, 1/4, ...
    for _ in range(MAX_BACKTRACK_ITER):
        # for g in ws: w[ws_g] += (step - prev_step) * delta_w_ws_g
        ptr = 0
        for g in ws:
            grp_g_indices = grp_indices[grp_ptr[g]:grp_ptr[g+1]]
            w[grp_g_indices] += ((step - prev_step) *
                                 delta_w_ws[ptr:ptr+len(grp_g_indices)])
            ptr += len(grp_g_indices)

        if fit_intercept:
            w[-1] += (step - prev_step) * delta_w_ws[-1]

        Xw += (step - prev_step) * X_delta_w_ws
        grad_ws = _construct_grad(X, y, w[:n_features], Xw, datafit, ws)

        # TODO: could be improved by passing in w[ws]
        stop_crit = penalty.value(w[:n_features]) - old_penalty_val
        stop_crit += step * grad_ws @ delta_w_ws[:n_features_ws]

        if fit_intercept:
            stop_crit += step * delta_w_ws[-1] * np.sum(datafit.raw_grad(y, Xw))

        if stop_crit < 0:
            break
        else:
            prev_step = step
            step /= 2
    else:
        pass  # TODO: this case is not handled yet

    return grad_ws

@njit
def _construct_grad(X, y, w, Xw, datafit, ws):
    # compute grad of datafit restricted to ws. This function avoids
    # recomputing raw_grad for every j, which is costly for logreg
    grp_ptr, grp_indices = datafit.grp_ptr, datafit.grp_indices
    n_features_ws = sum([grp_ptr[g+1] - grp_ptr[g] for g in ws])

    raw_grad = datafit.raw_grad(y, Xw)
    grad = np.zeros(n_features_ws)

    ptr = 0
    for g in ws:
        # compute grad_g
        grp_g_indices = grp_indices[grp_ptr[g]:grp_ptr[g+1]]
        for j in grp_g_indices:
            grad[ptr] = X[:, j] @ raw_grad
            ptr += 1

    return grad

@njit
def _slice_array(arr, ws, grp_ptr, grp_indices, fit_intercept=False):
    # returns the horizontally stacked (arr[ws_1], arr[ws_2], ...)
    # includes the last element when fit_intercept=True
    n_features_ws = sum([grp_ptr[g+1] - grp_ptr[g] for g in ws])
    sliced_arr = np.zeros(n_features_ws + fit_intercept)

    ptr = 0
    for g in ws:
        grp_g_indices = grp_indices[grp_ptr[g]:grp_ptr[g+1]]
        sliced_arr[ptr: ptr+len(grp_g_indices)] = arr[grp_g_indices]
        ptr += len(grp_g_indices)

    if fit_intercept:
        sliced_arr[-1] = arr[-1]

    return sliced_arr

@njit
def _update_X_delta_w_ws(X, X_delta_w_ws, w_ws_g, old_w_ws_g, grp_g_indices):
    # X_delta_w_ws += X[:, grp_g_indices] @ (w_ws_g - old_w_ws_g)
    # but without copying the cols of X
    for idx, j in enumerate(grp_g_indices):
        delta_w_j = w_ws_g[idx] - old_w_ws_g[idx]
        if w_ws_g[idx] != old_w_ws_g[idx]:
            X_delta_w_ws += delta_w_j * X[:, j]

@njit
def _X_g_T_dot_vec(X, vec, grp_g_indices):
    # X[:, grp_g_indices].T @ vec
    # but without copying the cols of X
    result = np.zeros(len(grp_g_indices))
    for idx, j in enumerate(grp_g_indices):
        result[idx] = X[:, j] @ vec
    return result

@njit
def _diag_times_X_g(diag, X, grp_g_indices):
    # np.diag(diag) @ X[:, grp_g_indices]
    # but without copying the cols of X
    result = np.zeros((len(diag), len(grp_g_indices)))
    for idx, j in enumerate(grp_g_indices):
        result[:, idx] = diag * X[:, j]
    return result
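For context, here is a minimal usage sketch of the new solver (not part of the commit). It assumes a group datafit implementing raw_grad and raw_hessian, as required by the prox-Newton updates above; LogisticGroup, WeightedGroupL2, grp_converter, compiled_clone, and make_correlated_data are assumed to be available at these import paths, which may differ across skglm versions.

    import numpy as np
    from skglm.datafits import LogisticGroup        # assumed available
    from skglm.penalties import WeightedGroupL2
    from skglm.solvers import GroupProxNewton
    from skglm.utils import grp_converter, compiled_clone, make_correlated_data

    n_samples, n_features, groups = 100, 30, 10     # 10 groups of 3 contiguous features
    X, y, _ = make_correlated_data(n_samples, n_features, random_state=0)
    y = np.sign(y)                                  # binary labels for the logistic datafit

    grp_indices, grp_ptr = grp_converter(groups, n_features)
    alpha = 0.01 * np.linalg.norm(X.T @ y, np.inf) / n_samples  # arbitrary reg. strength
    weights = np.ones(groups)

    # datafit and penalty are numba jitclass-based and must be compiled first
    datafit = compiled_clone(LogisticGroup(grp_ptr, grp_indices))
    penalty = compiled_clone(WeightedGroupL2(alpha, weights, grp_ptr, grp_indices))

    solver = GroupProxNewton(tol=1e-8, verbose=1)
    w, p_objs, stop_crit = solver.solve(X, y, datafit, penalty)

With verbose=1, the solver prints the objective and stopping criterion of each outer iteration, matching the print statements in solve above.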
