
Commit 212e60d

ENH add modular Prox Newton solver (#51)
Co-authored-by: mathurinm <[email protected]>
1 parent 99d0a06 commit 212e60d

4 files changed: +456 -1 lines changed

skglm/datafits/single_task.py

Lines changed: 9 additions & 0 deletions
@@ -126,6 +126,15 @@ def get_spec(self):
     def params_to_dict(self):
         return dict()
 
+    def raw_grad(self, y, Xw):
+        """Compute gradient of datafit w.r.t ``Xw``."""
+        return -y / (1 + np.exp(y * Xw)) / len(y)
+
+    def raw_hessian(self, y, Xw):
+        """Compute Hessian of datafit w.r.t ``Xw``."""
+        exp_minus_yXw = np.exp(-y * Xw)
+        return exp_minus_yXw / (1 + exp_minus_yXw) ** 2 / len(y)
+
     def initialize(self, X, y):
         self.lipschitz = (X ** 2).sum(axis=0) / (len(y) * 4)
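
As a quick check (editorial note, not part of the diff): these two methods implement the analytic derivatives of the logistic datafit with respect to the linear predictor Xw. Assuming labels y_i in {-1, 1} and F(Xw) = (1/n) sum_i log(1 + exp(-y_i (Xw)_i)),

\[
\frac{\partial F}{\partial (Xw)_i} = \frac{-y_i}{n\left(1 + e^{y_i (Xw)_i}\right)},
\qquad
\frac{\partial^2 F}{\partial (Xw)_i^2} = \frac{e^{-y_i (Xw)_i}}{n\left(1 + e^{-y_i (Xw)_i}\right)^2},
\]

which are exactly the vectors returned by ``raw_grad`` and ``raw_hessian`` (both already carry the 1/n = 1/len(y) factor of the datafit).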

skglm/solvers/prox_newton.py

Lines changed: 370 additions & 0 deletions
@@ -0,0 +1,370 @@
import numpy as np
from numba import njit
from scipy.sparse import issparse


EPS_TOL = 0.3
MAX_CD_ITER = 20
MAX_BACKTRACK_ITER = 20


def prox_newton(X, y, datafit, penalty, w_init=None, p0=10,
                max_iter=20, max_pn_iter=1000, tol=1e-4, verbose=0):
    """Run a Prox Newton solver combined with working sets.

    Parameters
    ----------
    X : array or sparse CSC matrix, shape (n_samples, n_features)
        Design matrix.

    y : array, shape (n_samples,)
        Target vector.

    datafit : instance of BaseDatafit
        Datafit object.

    penalty : instance of BasePenalty
        Penalty object.

    w_init : array, shape (n_features,), default None
        Initial value of coefficients.
        If set to None, a zero vector is used instead.

    p0 : int, default 10
        Minimum number of features to be included in the working set.

    max_iter : int, default 20
        Maximum number of outer iterations.

    max_pn_iter : int, default 1000
        Maximum number of prox Newton iterations on each subproblem.

    tol : float, default 1e-4
        Tolerance for convergence.

    verbose : bool or int, default 0
        Amount of verbosity. 0/False is silent.

    Returns
    -------
    w : array, shape (n_features,)
        Solution that minimizes the problem defined by datafit and penalty.

    objs_out : array, shape (n_iter,)
        The objective values at every outer iteration.

    stop_crit : float
        The value of the stopping criterion when the solver stops.

    References
    ----------
    .. [1] Massias, M. and Vaiter, S. and Gramfort, A. and Salmon, J.
        "Dual Extrapolation for Sparse Generalized Linear Models", JMLR, 2020,
        https://arxiv.org/abs/1907.05830
        code: https://github.com/mathurinm/celer

    .. [2] Johnson, T. B. and Guestrin, C.
        "Blitz: A principled meta-algorithm for scaling sparse optimization",
        ICML, 2015.
        https://proceedings.mlr.press/v37/johnson15.html
        code: https://github.com/tbjohns/BlitzL1
    """
    n_samples, n_features = X.shape
    w = np.zeros(n_features) if w_init is None else w_init
    Xw = np.zeros(n_samples) if w_init is None else X @ w_init
    all_features = np.arange(n_features)
    stop_crit = 0.
    p_objs_out = []

    is_sparse = issparse(X)
    if is_sparse:
        X_bundles = (X.data, X.indptr, X.indices)

    for t in range(max_iter):
        # compute scores
        if is_sparse:
            grad = _construct_grad_sparse(*X_bundles, y, w, Xw, datafit, all_features)
        else:
            grad = _construct_grad(X, y, w, Xw, datafit, all_features)

        opt = penalty.subdiff_distance(w, grad, all_features)

        # check convergence
        stop_crit = np.max(opt)
        if verbose:
            p_obj = datafit.value(y, w, Xw) + penalty.value(w)
            print(
                f"Iteration {t+1}: {p_obj:.10f}, "
                f"stopping crit: {stop_crit:.2e}"
            )

        if stop_crit <= tol:
            if verbose:
                print(f"Stopping criterion max violation: {stop_crit:.2e}")
            break

        # build working set
        gsupp_size = penalty.generalized_support(w).sum()
        ws_size = max(min(p0, n_features),
                      min(n_features, 2 * gsupp_size))
        # similar to np.argsort()[-ws_size:] but without sorting
        ws = np.argpartition(opt, -ws_size)[-ws_size:]

        grad_ws = grad[ws]
        tol_in = EPS_TOL * stop_crit

        for pn_iter in range(max_pn_iter):
            # find descent direction
            if is_sparse:
                delta_w_ws, X_delta_w_ws = _descent_direction_s(
                    *X_bundles, y, w, Xw, grad_ws, datafit,
                    penalty, ws, tol=EPS_TOL*tol_in)
            else:
                delta_w_ws, X_delta_w_ws = _descent_direction(
                    X, y, w, Xw, grad_ws, datafit, penalty, ws, tol=EPS_TOL*tol_in)

            # backtracking line search with inplace update of w, Xw
            if is_sparse:
                grad_ws[:] = _backtrack_line_search_s(
                    *X_bundles, y, w, Xw, datafit, penalty, delta_w_ws,
                    X_delta_w_ws, ws)
            else:
                grad_ws[:] = _backtrack_line_search(
                    X, y, w, Xw, datafit, penalty, delta_w_ws, X_delta_w_ws, ws)

            # check convergence
            opt_in = penalty.subdiff_distance(w, grad_ws, ws)
            stop_crit_in = np.max(opt_in)

            if max(verbose-1, 0):
                p_obj = datafit.value(y, w, Xw) + penalty.value(w)
                print(
                    f"PN iteration {pn_iter+1}: {p_obj:.10f}, "
                    f"stopping crit in: {stop_crit_in:.2e}"
                )

            if stop_crit_in <= tol_in:
                if max(verbose-1, 0):
                    print("Early exit")
                break

        p_obj = datafit.value(y, w, Xw) + penalty.value(w)
        p_objs_out.append(p_obj)
    return w, np.asarray(p_objs_out), stop_crit


@njit
def _descent_direction(X, y, w_epoch, Xw_epoch, grad_ws, datafit,
                       penalty, ws, tol):
    # Given:
    #   1) b = \nabla F(X w_epoch)
    #   2) D = \nabla^2 F(X w_epoch)  <------>  raw_hess
    # Minimize quadratic approximation for delta_w = w - w_epoch:
    #   b.T @ X @ delta_w + \
    #   1/2 * delta_w.T @ (X.T @ D @ X) @ delta_w + penalty(w)
    raw_hess = datafit.raw_hessian(y, Xw_epoch)

    lipschitz = np.zeros(len(ws))
    for idx, j in enumerate(ws):
        lipschitz[idx] = raw_hess @ X[:, j] ** 2

    # for a less costly stopping criterion, we do not compute the exact gradient,
    # but store each coordinate-wise gradient every time we update one coordinate:
    past_grads = np.zeros(len(ws))
    X_delta_w_ws = np.zeros(X.shape[0])
    w_ws = w_epoch[ws]

    for cd_iter in range(MAX_CD_ITER):
        for idx, j in enumerate(ws):
            # skip when X[:, j] == 0
            if lipschitz[idx] == 0:
                continue

            past_grads[idx] = grad_ws[idx] + X[:, j] @ (raw_hess * X_delta_w_ws)
            old_w_idx = w_ws[idx]
            stepsize = 1 / lipschitz[idx]

            w_ws[idx] = penalty.prox_1d(
                old_w_idx - stepsize * past_grads[idx], stepsize, j)

            if w_ws[idx] != old_w_idx:
                X_delta_w_ws += (w_ws[idx] - old_w_idx) * X[:, j]

        if cd_iter % 5 == 0:
            # TODO: can be improved by passing in w_ws but breaks for WeightedL1
            current_w = w_epoch.copy()
            current_w[ws] = w_ws
            opt = penalty.subdiff_distance(current_w, past_grads, ws)
            if np.max(opt) <= tol:
                break

    # descent direction
    return w_ws - w_epoch[ws], X_delta_w_ws
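
For reference (editorial note, not part of the committed file): with b = raw_grad evaluated at Xw_epoch and D = diag(raw_hess), the coordinate descent loop in ``_descent_direction`` above minimizes the penalized quadratic model stated in its comment,

\[
\Delta w^\star \in \operatorname*{arg\,min}_{\Delta w}\;
b^\top X \Delta w + \tfrac{1}{2}\, \Delta w^\top X^\top D X\, \Delta w
+ \mathrm{penalty}(w_{\mathrm{epoch}} + \Delta w),
\]

restricted to the working set (coordinates outside it stay at zero), using per-coordinate stepsizes 1 / (raw_hess @ X[:, j] ** 2), i.e. the inverse of the ``lipschitz`` values computed at the top of the function.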

# sparse version of _descent_direction
@njit
def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch,
                         Xw_epoch, grad_ws, datafit, penalty, ws, tol):
    raw_hess = datafit.raw_hessian(y, Xw_epoch)

    lipschitz = np.zeros(len(ws))
    for idx, j in enumerate(ws):
        # equivalent to: lipschitz[idx] = raw_hess @ X[:, j] ** 2
        lipschitz[idx] = _sparse_squared_weighted_norm(
            X_data, X_indptr, X_indices, j, raw_hess)

    # see _descent_direction() comment
    past_grads = np.zeros(len(ws))
    X_delta_w_ws = np.zeros(len(y))
    w_ws = w_epoch[ws]

    for cd_iter in range(MAX_CD_ITER):
        for idx, j in enumerate(ws):
            # skip when X[:, j] == 0
            if lipschitz[idx] == 0:
                continue

            past_grads[idx] = grad_ws[idx]
            # equivalent to: past_grads[idx] += X[:, j] @ (raw_hess * X_delta_w_ws)
            past_grads[idx] += _sparse_weighted_dot(
                X_data, X_indptr, X_indices, j, X_delta_w_ws, raw_hess)

            old_w_idx = w_ws[idx]
            stepsize = 1 / lipschitz[idx]

            w_ws[idx] = penalty.prox_1d(
                old_w_idx - stepsize * past_grads[idx], stepsize, j)

            if w_ws[idx] != old_w_idx:
                _update_X_delta_w(X_data, X_indptr, X_indices, X_delta_w_ws,
                                  w_ws[idx] - old_w_idx, j)

        if cd_iter % 5 == 0:
            # TODO: could be improved by passing in w_ws
            current_w = w_epoch.copy()
            current_w[ws] = w_ws
            opt = penalty.subdiff_distance(current_w, past_grads, ws)
            if np.max(opt) <= tol:
                break

    # descent direction
    return w_ws - w_epoch[ws], X_delta_w_ws


@njit
def _backtrack_line_search(X, y, w, Xw, datafit, penalty, delta_w_ws,
                           X_delta_w_ws, ws):
    # 1) find step in [0, 1] such that:
    #    penalty(w + step * delta_w) - penalty(w) +
    #    step * \nabla datafit(w + step * delta_w) @ delta_w < 0
    #    ref: https://www.di.ens.fr/~aspremon/PDF/ENSAE/Newton.pdf
    # 2) inplace update of w and Xw and return grad_ws of the last w and Xw
    step, prev_step = 1., 0.
    # TODO: could be improved by passing in w[ws]
    old_penalty_val = penalty.value(w)

    # try step = 1, 1/2, 1/4, ...
    for _ in range(MAX_BACKTRACK_ITER):
        w[ws] += (step - prev_step) * delta_w_ws
        Xw += (step - prev_step) * X_delta_w_ws

        grad_ws = _construct_grad(X, y, w, Xw, datafit, ws)
        # TODO: could be improved by passing in w[ws]
        stop_crit = penalty.value(w) - old_penalty_val
        stop_crit += step * grad_ws @ delta_w_ws

        if stop_crit < 0:
            break
        else:
            prev_step = step
            step /= 2
    else:
        pass
        # TODO this case is not handled yet

    return grad_ws


# sparse version of _backtrack_line_search
@njit
def _backtrack_line_search_s(X_data, X_indptr, X_indices, y, w, Xw, datafit,
                             penalty, delta_w_ws, X_delta_w_ws, ws):
    step, prev_step = 1., 0.
    # TODO: could be improved by passing in w[ws]
    old_penalty_val = penalty.value(w)

    for _ in range(MAX_BACKTRACK_ITER):
        w[ws] += (step - prev_step) * delta_w_ws
        Xw += (step - prev_step) * X_delta_w_ws

        grad_ws = _construct_grad_sparse(X_data, X_indptr, X_indices,
                                         y, w, Xw, datafit, ws)
        # TODO: could be improved by passing in w[ws]
        stop_crit = penalty.value(w) - old_penalty_val
        stop_crit += step * grad_ws.T @ delta_w_ws

        if stop_crit < 0:
            break
        else:
            prev_step = step
            step /= 2
    else:
        pass  # TODO

    return grad_ws


@njit
def _construct_grad(X, y, w, Xw, datafit, ws):
    # Compute grad of datafit restricted to ws. This function avoids
    # recomputing raw_grad for every j, which is costly for logreg
    raw_grad = datafit.raw_grad(y, Xw)
    grad = np.zeros(len(ws))
    for idx, j in enumerate(ws):
        grad[idx] = X[:, j] @ raw_grad
    return grad


@njit
def _construct_grad_sparse(X_data, X_indptr, X_indices, y, w, Xw, datafit, ws):
    # Compute grad of datafit restricted to ws in case X sparse
    raw_grad = datafit.raw_grad(y, Xw)
    grad = np.zeros(len(ws))
    for idx, j in enumerate(ws):
        grad[idx] = _sparse_xj_dot(X_data, X_indptr, X_indices, j, raw_grad)
    return grad


@njit(fastmath=True)
def _sparse_xj_dot(X_data, X_indptr, X_indices, j, other):
    # Compute X[:, j] @ other in case X sparse
    res = 0.
    for i in range(X_indptr[j], X_indptr[j+1]):
        res += X_data[i] * other[X_indices[i]]
    return res


@njit(fastmath=True)
def _sparse_weighted_dot(X_data, X_indptr, X_indices, j, other, weights):
    # Compute X[:, j] @ (weights * other) in case X sparse
    res = 0.
    for i in range(X_indptr[j], X_indptr[j+1]):
        res += X_data[i] * other[X_indices[i]] * weights[X_indices[i]]
    return res


@njit(fastmath=True)
def _sparse_squared_weighted_norm(X_data, X_indptr, X_indices, j, weights):
    # Compute weights @ X[:, j]**2 in case X sparse
    res = 0.
    for i in range(X_indptr[j], X_indptr[j+1]):
        res += weights[X_indices[i]] * X_data[i]**2
    return res


@njit(fastmath=True)
def _update_X_delta_w(X_data, X_indptr, X_indices, X_delta_w, diff, j):
    # Compute X_delta_w += diff * X[:, j] in case of X sparse
    for i in range(X_indptr[j], X_indptr[j+1]):
        X_delta_w[X_indices[i]] += diff * X_data[i]
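
To illustrate how the new solver plugs into the datafit/penalty API, here is a minimal sketch of a call (editorial example, not part of the commit: the ``Logistic``, ``L1`` and ``compiled_clone`` names and the import paths are assumptions about the rest of the package; only the ``prox_newton`` signature comes from this diff):

import numpy as np

from skglm.solvers.prox_newton import prox_newton
# hypothetical imports: the names below are assumptions, not shown in this commit
from skglm.datafits import Logistic
from skglm.penalties import L1
from skglm.utils import compiled_clone

rng = np.random.default_rng(0)
X = rng.standard_normal((100, 40))
y = np.sign(rng.standard_normal(100))  # labels in {-1, 1}, as assumed by raw_grad

# jit-compile the datafit and penalty so they can be passed to the @njit helpers
datafit = compiled_clone(Logistic())
penalty = compiled_clone(L1(alpha=0.01))
datafit.initialize(X, y)

w, p_objs, stop_crit = prox_newton(X, y, datafit, penalty, tol=1e-6, verbose=1)
print(np.flatnonzero(w), stop_crit)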

0 commit comments