Skip to content

Commit cb43489

Browse files
authored
ENH - add Primal-Dual Coordinate Descent solver (#131)
1 parent cb715d2 commit cb43489

File tree

3 files changed

+268
-2
lines changed

3 files changed

+268
-2
lines changed

skglm/experimental/pdcd_ws.py

Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
import warnings
2+
3+
import numpy as np
4+
from numpy.linalg import norm
5+
from scipy.sparse import issparse
6+
7+
from numba import njit
8+
from skglm.utils.jit_compilation import compiled_clone
9+
from sklearn.exceptions import ConvergenceWarning
10+
11+
12+
class PDCD_WS:
13+
r"""Primal-Dual Coordinate Descent solver with working sets.
14+
15+
It solves::
16+
17+
\min_w F(Xw) + G(w)
18+
19+
using a primal-dual method on the saddle point problem::
20+
21+
\min_w \max_z <Xw, z> + G(w) - F^*(z)
22+
23+
where :math:`F` is the datafit term (:math:`F^*` its Fenchel conjugate)
24+
and :math:`G` is the penalty term.
25+
26+
The datafit is required to be convex and proximable. Also, the penalty
27+
is required to be convex, separable, and proximable.
28+
29+
The solver is an adaptation of algorithm [1] to working sets [2].
30+
The working sets are built using a fixed point distance strategy
31+
where each feature is assigned a score based how much its coefficient varies
32+
when performing a primal update::
33+
34+
\text{score}_j = \abs{w_j - prox_{\tau_j, G_j}(w_j - \tau_j <X_j, z>)}
35+
36+
where :maths:`\tau_j` is the primal step associated with the j-th feature.
37+
38+
Parameters
39+
----------
40+
max_iter : int, optional
41+
The maximum number of iterations or equivalently the
42+
the maximum number of solved subproblems.
43+
44+
max_epochs : int, optional
45+
Maximum number of primal CD epochs on each subproblem.
46+
47+
dual_init : array, shape (n_samples,) default None
48+
The initialization of dual variables.
49+
If None, they are initialized as the 0 vector ``np.zeros(n_samples)``.
50+
51+
p0 : int, optional
52+
First working set size.
53+
54+
tol : float, optional
55+
The tolerance for the optimization.
56+
57+
verbose : bool or int, default False
58+
Amount of verbosity. 0/False is silent.
59+
60+
References
61+
----------
62+
.. [1] Olivier Fercoq and Pascal Bianchi,
63+
"A Coordinate-Descent Primal-Dual Algorithm with Large Step Size and Possibly
64+
Nonseparable Functions", SIAM Journal on Optimization, 2020,
65+
https://epubs.siam.org/doi/10.1137/18M1168480,
66+
code: https://github.com/Badr-MOUFAD/Fercoq-Bianchi-solver
67+
68+
.. [2] Bertrand, Q. and Klopfenstein, Q. and Bannier, P.-A. and Gidel, G.
69+
and Massias, M.
70+
"Beyond L1: Faster and Better Sparse Models with skglm", NeurIPS, 2022
71+
https://arxiv.org/abs/2204.07826
72+
"""
73+
74+
def __init__(self, max_iter=1000, max_epochs=1000, dual_init=None,
75+
p0=100, tol=1e-6, verbose=False):
76+
self.max_iter = max_iter
77+
self.max_epochs = max_epochs
78+
self.dual_init = dual_init
79+
self.p0 = p0
80+
self.tol = tol
81+
self.verbose = verbose
82+
83+
def solve(self, X, y, datafit_, penalty_, w_init=None, Xw_init=None):
84+
if issparse(X):
85+
raise ValueError("Sparse matrices are not yet support in PDCD_WS solver.")
86+
87+
datafit, penalty = PDCD_WS._validate_init(datafit_, penalty_)
88+
n_samples, n_features = X.shape
89+
90+
# init steps
91+
# Despite violating the conditions mentioned in [1]
92+
# this choice of steps yield in practice a convergent algorithm
93+
# with better speed of convergence
94+
dual_step = 1 / norm(X, ord=2)
95+
primal_steps = 1 / norm(X, axis=0, ord=2)
96+
97+
# primal vars
98+
w = np.zeros(n_features) if w_init is None else w_init
99+
Xw = np.zeros(n_samples) if Xw_init is None else Xw_init
100+
101+
# dual vars
102+
if self.dual_init is None:
103+
z = np.zeros(n_samples)
104+
z_bar = np.zeros(n_samples)
105+
else:
106+
z = self.dual_init.copy()
107+
z_bar = self.dual_init.copy()
108+
109+
p_objs = []
110+
stop_crit = 0.
111+
all_features = np.arange(n_features)
112+
113+
for iteration in range(self.max_iter):
114+
115+
# check convergence using fixed-point criteria on both dual and primal
116+
opts_primal = _scores_primal(X, w, z, penalty, primal_steps, all_features)
117+
opt_dual = _score_dual(y, z, Xw, datafit, dual_step)
118+
119+
stop_crit = max(max(opts_primal), opt_dual)
120+
121+
if self.verbose:
122+
current_p_obj = datafit.value(y, w, Xw) + penalty.value(w)
123+
print(
124+
f"Iteration {iteration+1}: {current_p_obj:.10f}, "
125+
f"stopping crit: {stop_crit:.2e}")
126+
127+
if stop_crit <= self.tol:
128+
break
129+
130+
# build ws
131+
gsupp_size = (w != 0).sum()
132+
ws_size = max(min(self.p0, n_features),
133+
min(n_features, 2 * gsupp_size))
134+
135+
# similar to np.argsort()[-ws_size:] but without full sort
136+
ws = np.argpartition(opts_primal, -ws_size)[-ws_size:]
137+
138+
# solve sub problem
139+
# inplace update of w, Xw, z, z_bar
140+
PDCD_WS._solve_subproblem(
141+
y, X, w, Xw, z, z_bar, datafit, penalty,
142+
primal_steps, dual_step, ws, self.max_epochs, tol_in=0.3*stop_crit)
143+
144+
current_p_obj = datafit.value(y, w, Xw) + penalty.value(w)
145+
p_objs.append(current_p_obj)
146+
else:
147+
warnings.warn(
148+
f"PDCD_WS did not converge for tol={self.tol:.3e} "
149+
f"and max_iter={self.max_iter}.\n"
150+
"Considering increasing `max_iter` or `tol`.",
151+
category=ConvergenceWarning
152+
)
153+
154+
return w, np.asarray(p_objs), stop_crit
155+
156+
@staticmethod
157+
@njit
158+
def _solve_subproblem(y, X, w, Xw, z, z_bar, datafit, penalty,
159+
primal_steps, dual_step, ws, max_epochs, tol_in):
160+
n_features = X.shape[1]
161+
162+
for epoch in range(max_epochs):
163+
164+
for j in ws:
165+
# update primal
166+
old_w_j = w[j]
167+
pseudo_grad = X[:, j] @ (2 * z_bar - z)
168+
w[j] = penalty.prox_1d(
169+
old_w_j - primal_steps[j] * pseudo_grad,
170+
primal_steps[j], j)
171+
172+
# keep Xw syncr with X @ w
173+
delta_w_j = w[j] - old_w_j
174+
if delta_w_j:
175+
Xw += delta_w_j * X[:, j]
176+
177+
# update dual
178+
z_bar[:] = datafit.prox_conjugate(z + dual_step * Xw,
179+
dual_step, y)
180+
z += (z_bar - z) / n_features
181+
182+
# check convergence using fixed-point criteria on both dual and primal
183+
if epoch % 10 == 0:
184+
opts_primal_in = _scores_primal(X, w, z, penalty, primal_steps, ws)
185+
opt_dual_in = _score_dual(y, z, Xw, datafit, dual_step)
186+
187+
stop_crit_in = max(max(opts_primal_in), opt_dual_in)
188+
189+
if stop_crit_in <= tol_in:
190+
break
191+
192+
@staticmethod
193+
def _validate_init(datafit_, penalty_):
194+
# validate datafit
195+
missing_attrs = []
196+
for attr in ('prox_conjugate', 'subdiff_distance'):
197+
if not hasattr(datafit_, attr):
198+
missing_attrs.append(f"`{attr}`")
199+
200+
if len(missing_attrs):
201+
raise AttributeError(
202+
"Datafit is not compatible with PDCD_WS solver.\n"
203+
"Datafit must implement `prox_conjugate` and `subdiff_distance`.\n"
204+
f"Missing {' and '.join(missing_attrs)}."
205+
)
206+
207+
# jit compile classes
208+
compiled_datafit = compiled_clone(datafit_)
209+
compiled_penalty = compiled_clone(penalty_)
210+
211+
return compiled_datafit, compiled_penalty
212+
213+
214+
@njit
215+
def _scores_primal(X, w, z, penalty, primal_steps, ws):
216+
scores_ws = np.zeros(len(ws))
217+
218+
for idx, j in enumerate(ws):
219+
next_w_j = penalty.prox_1d(w[j] - primal_steps[j] * X[:, j] @ z,
220+
primal_steps[j], j)
221+
scores_ws[idx] = abs(w[j] - next_w_j)
222+
223+
return scores_ws
224+
225+
226+
@njit
227+
def _score_dual(y, z, Xw, datafit, dual_step):
228+
next_z = datafit.prox_conjugate(z + dual_step * Xw,
229+
dual_step, y)
230+
return norm(z - next_z, ord=np.inf)

skglm/experimental/sqrt_lasso.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from sklearn.linear_model._base import LinearModel, RegressorMixin
66

77
from skglm.penalties import L1
8-
from skglm.utils.prox_funcs import ST_vec, proj_L2ball
8+
from skglm.utils.prox_funcs import ST_vec, proj_L2ball, BST
99
from skglm.utils.jit_compilation import compiled_clone
1010
from skglm.datafits.base import BaseDatafit
1111
from skglm.solvers.prox_newton import ProxNewton
@@ -54,6 +54,24 @@ def raw_hessian(self, y, Xw):
5454
fill_value = 1 / norm(y - Xw)
5555
return np.full(n_samples, fill_value)
5656

57+
def prox(self, w, step, y):
58+
"""Prox of ``step * ||y - . ||``."""
59+
return y - BST(y - w, step)
60+
61+
def prox_conjugate(self, z, step, y):
62+
"""Prox of ``step * ||y - . ||^*``."""
63+
return proj_L2ball(z - step * y)
64+
65+
def subdiff_distance(self, Xw, z, y):
66+
"""Distance of ``z`` to subdiff of ||y - . || at ``Xw``."""
67+
# computation note: \partial ||y - . ||(Xw) = - \partial || . ||(y - Xw)
68+
y_minus_Xw = y - Xw
69+
70+
if np.any(y_minus_Xw):
71+
return norm(z + y_minus_Xw / norm(y_minus_Xw))
72+
73+
return norm(z - proj_L2ball(z))
74+
5775

5876
class SqrtLasso(LinearModel, RegressorMixin):
5977
"""Square root Lasso estimator based on Prox Newton solver.

skglm/experimental/tests/test_sqrt_lasso.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22
import numpy as np
33
from numpy.linalg import norm
44

5+
from skglm.penalties import L1
56
from skglm.utils.data import make_correlated_data
6-
from skglm.experimental.sqrt_lasso import SqrtLasso, _chambolle_pock_sqrt
7+
from skglm.experimental.sqrt_lasso import (SqrtLasso, SqrtQuadratic,
8+
_chambolle_pock_sqrt)
9+
from skglm.experimental.pdcd_ws import PDCD_WS
710

811

912
def test_alpha_max():
@@ -56,5 +59,20 @@ def test_prox_newton_cp():
5659
np.testing.assert_allclose(clf.coef_, w)
5760

5861

62+
@pytest.mark.parametrize('with_dual_init', [True, False])
63+
def test_PDCD_WS(with_dual_init):
64+
n_samples, n_features = 50, 10
65+
X, y, _ = make_correlated_data(n_samples, n_features, random_state=0)
66+
67+
alpha_max = norm(X.T @ y, ord=np.inf) / norm(y)
68+
alpha = alpha_max / 10
69+
70+
dual_init = y / norm(y) if with_dual_init else None
71+
72+
w = PDCD_WS(dual_init=dual_init).solve(X, y, SqrtQuadratic(), L1(alpha))[0]
73+
clf = SqrtLasso(alpha=alpha, tol=1e-12).fit(X, y)
74+
np.testing.assert_allclose(clf.coef_, w, atol=1e-6)
75+
76+
5977
if __name__ == '__main__':
6078
pass

0 commit comments

Comments
 (0)