
Commit 3d1f524

ENH add Square root Lasso (#57)
Co-authored-by: mathurinm <[email protected]>
1 parent 182cae8 commit 3d1f524

8 files changed: +353 -2 lines changed

.github/workflows/main.yml

Lines changed: 1 addition & 0 deletions
@@ -24,5 +24,6 @@ jobs:
           pip install pytest
           pip install numpydoc
           pip install .
+          pip install statsmodels cvxopt
       - name: Test with pytest
         run: pytest -v skglm/

doc/api.rst

Lines changed: 11 additions & 1 deletion
@@ -73,4 +73,14 @@ Solvers
    GroupBCD
    MultiTaskBCD
    ProxNewton
-
+
+
+Experimental
+============
+
+.. currentmodule:: skglm.experimental
+
+.. autosummary::
+   :toctree: generated/
+
+   SqrtLasso

doc/changes/0.2.rst

Lines changed: 3 additions & 1 deletion
@@ -1,7 +1,9 @@
 .. _changes_0_2:
 
 Version 0.2 (in progress)
------------------------
+-------------------------
+
+- Experimental :ref:`Square root Lasso <skglm.experimental.SqrtLasso>` class with ProxNewton or Chambolle-Pock solver (PR :gh:`57`)
 
 - Accelerated block coordinate descent solver :ref:`GroupBCD <skglm.solvers.GroupBCD>` with working sets for problems with group penalties (PR :gh:`29`, :gh:`28`, and :gh:`26`)
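To make the changelog entry concrete, here is a minimal sketch of the two solver entry points this PR adds. It uses only names introduced in the diff below (SqrtLasso, _chambolle_pock_sqrt) plus the existing make_correlated_data helper; the alpha value is illustrative.

    from skglm.utils import make_correlated_data
    from skglm.experimental.sqrt_lasso import SqrtLasso, _chambolle_pock_sqrt

    X, y, _ = make_correlated_data(n_samples=50, n_features=10, random_state=0)

    # ProxNewton-based estimator (the solver behind SqrtLasso.fit)
    w_pn = SqrtLasso(alpha=0.01).fit(X, y).coef_

    # low-level Chambolle-Pock routine: returns primal, dual and objectives
    w_cp, z, objs = _chambolle_pock_sqrt(X, y, 0.01, max_iter=1000)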

skglm/experimental/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+from .sqrt_lasso import SqrtLasso
+
+__all__ = [
+    "SqrtLasso",  # __all__ entries must be strings for star-imports to work
+]
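With this re-export in place, the estimator can be imported from the package level; a one-line sketch (the alpha value is illustrative):

    from skglm.experimental import SqrtLasso

    clf = SqrtLasso(alpha=0.1)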
Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+import numpy as np
+from numpy.linalg import norm
+import matplotlib.pyplot as plt
+from skglm.utils import make_correlated_data
+from skglm.experimental.sqrt_lasso import SqrtLasso, _chambolle_pock_sqrt
+
+X, y, _ = make_correlated_data(n_samples=200, n_features=100, random_state=24)
+
+n_samples, n_features = X.shape
+alpha_max = norm(X.T @ y, ord=np.inf) / (norm(y) * np.sqrt(n_samples))
+alpha = alpha_max / 10
+
+max_iter = 1000
+obj_freq = 10
+w, _, objs = _chambolle_pock_sqrt(X, y, alpha, max_iter=max_iter, obj_freq=obj_freq)
+
+# ProxNewton has no convergence issue here since n_features < n_samples,
+# so the residuals do not vanish
+# clf = SqrtLasso(alpha=alpha / np.sqrt(n_samples), verbose=2, tol=1e-10)
+clf = SqrtLasso(alpha=alpha, verbose=2, tol=1e-10)
+clf.fit(X, y)
+
+# treat the ProxNewton solution as the optimum
+w_star = clf.coef_
+p_star = norm(X @ w_star - y) / np.sqrt(n_samples) + alpha * norm(w_star, ord=1)
+
+plt.close("all")
+plt.semilogy(np.arange(1, max_iter + 1, obj_freq), np.array(objs) - p_star)
+plt.xlabel("CP iteration")
+plt.ylabel("$F(x) - F(x^*)$")
+plt.show(block=False)

skglm/experimental/sqrt_lasso.py

Lines changed: 231 additions & 0 deletions
@@ -0,0 +1,231 @@
+import warnings
+import numpy as np
+from numpy.linalg import norm
+from sklearn.exceptions import ConvergenceWarning
+from sklearn.linear_model._base import LinearModel, RegressorMixin
+
+from skglm.penalties import L1
+from skglm.utils import compiled_clone, ST_vec, proj_L2ball
+from skglm.datafits.base import BaseDatafit
+from skglm.solvers.prox_newton import ProxNewton
+
+
+class SqrtQuadratic(BaseDatafit):
+    """Square root quadratic datafit.
+
+    The datafit reads::
+
+        ||y - Xw||_2 / sqrt(n_samples)
+    """
+
+    def __init__(self):
+        pass
+
+    def get_spec(self):
+        spec = ()
+        return spec
+
+    def params_to_dict(self):
+        return dict()
+
+    def value(self, y, w, Xw):
+        return np.linalg.norm(y - Xw) / np.sqrt(len(y))
+
+    def raw_grad(self, y, Xw):
+        """Compute gradient of datafit w.r.t. ``Xw``.
+
+        Raises
+        ------
+        ValueError
+            If the norm of the residuals is less than ``1e-2 * ||y||``,
+            as the gradient and Hessian blow up when residuals vanish.
+        """
+        minus_residual = Xw - y
+        norm_residuals = norm(minus_residual)
+
+        if norm_residuals < 1e-2 * norm(y):
+            raise ValueError("SmallResidualException")
+
+        return minus_residual / (norm_residuals * np.sqrt(len(y)))
+
+    def raw_hessian(self, y, Xw):
+        """Diagonal matrix upper bounding the Hessian."""
+        n_samples = len(y)
+        fill_value = 1 / (np.sqrt(n_samples) * norm(y - Xw))
+        return np.full(n_samples, fill_value)
+
+
+class SqrtLasso(LinearModel, RegressorMixin):
+    """Square root Lasso estimator based on a prox Newton solver.
+
+    The optimization objective for the square root Lasso is::
+
+        ||y - Xw||_2 / sqrt(n_samples) + alpha * ||w||_1
+
+    Parameters
+    ----------
+    alpha : float, default 1.
+        Penalty strength.
+
+    max_iter : int, default 100
+        Maximum number of outer iterations.
+
+    max_pn_iter : int, default 100
+        Maximum number of prox Newton iterations on each subproblem.
+
+    p0 : int, default 10
+        Minimum number of features to be included in the working set.
+
+    tol : float, default 1e-4
+        Tolerance for convergence.
+
+    verbose : bool or int, default False
+        Amount of verbosity. 0/False is silent.
+    """
+
+    def __init__(self, alpha=1., max_iter=100, max_pn_iter=100, p0=10,
+                 tol=1e-4, verbose=0):
+        super().__init__()
+        self.alpha = alpha
+        self.max_iter = max_iter
+        self.max_pn_iter = max_pn_iter
+        self.p0 = p0
+        self.tol = tol
+        self.verbose = verbose
+
+    def fit(self, X, y):
+        """Fit the model according to the given training data.
+
+        Parameters
+        ----------
+        X : array or sparse CSC matrix, shape (n_samples, n_features)
+            Training data, where ``n_samples`` is the number of samples and
+            ``n_features`` is the number of features.
+
+        y : array-like, shape (n_samples,)
+            Target vector relative to X.
+
+        Returns
+        -------
+        self : SqrtLasso
+            The fitted estimator.
+        """
+        # fitting is a path computation with a single alpha
+        self.coef_ = self.path(X, y, alphas=[self.alpha])[1][0]
+        self.intercept_ = 0.  # TODO handle fit_intercept
+        return self
+
+    def path(self, X, y, alphas=None, eps=1e-3, n_alphas=10):
+        """Compute the square root Lasso path.
+
+        Parameters
+        ----------
+        X : array, shape (n_samples, n_features)
+            Design matrix.
+
+        y : array, shape (n_samples,)
+            Target vector.
+
+        alphas : array, shape (n_alphas,), default None
+            Grid of alphas. If None, a geometric grid of ``n_alphas`` values
+            is built, going from ``alpha_max`` down to ``eps * alpha_max``.
+
+        eps : float, default 1e-3
+            Length of the path. ``eps=1e-3`` means that
+            ``alpha_min = 1e-3 * alpha_max``.
+
+        n_alphas : int, default 10
+            Number of alphas along the path. This argument is
+            ignored if ``alphas`` was provided.
+
+        Returns
+        -------
+        alphas : array, shape (n_alphas,)
+            The alphas along the path where models are computed.
+
+        coefs : array, shape (n_alphas, n_features)
+            Coefficients along the path.
+        """
+        if not hasattr(self, "solver_"):
+            self.solver_ = ProxNewton(
+                tol=self.tol, max_iter=self.max_iter, verbose=self.verbose)
+        # build path
+        if alphas is None:
+            alpha_max = norm(X.T @ y, ord=np.inf) / (np.sqrt(len(y)) * norm(y))
+            alphas = alpha_max * np.geomspace(1, eps, n_alphas)
+        else:
+            n_alphas = len(alphas)
+            alphas = np.sort(alphas)[::-1]
+
+        n_features = X.shape[1]
+        sqrt_quadratic = compiled_clone(SqrtQuadratic())
+        l1_penalty = compiled_clone(L1(1.))  # alpha is set along the path
+
+        coefs = np.zeros((n_alphas, n_features))
+
+        for i in range(n_alphas):
+            if self.verbose:
+                to_print = "##### Computing alpha %d/%d" % (i + 1, n_alphas)
+                print("#" * len(to_print))
+                print(to_print)
+                print("#" * len(to_print))
+
+            l1_penalty.alpha = alphas[i]
+            # warm start from the previous solution; no warm start for the
+            # first alpha
+            coef_init = coefs[i - 1].copy() if i else np.zeros(n_features)
+
+            try:
+                coef, _, _ = self.solver_.solve(
+                    X, y, sqrt_quadratic, l1_penalty,
+                    w_init=coef_init, Xw_init=X @ coef_init)
+                coefs[i] = coef
+            except ValueError as val_exception:
+                # make sure to catch only the small residuals error;
+                # it's implemented this way as Numba doesn't support custom
+                # exceptions
+                if not str(val_exception) == "SmallResidualException":
+                    raise
+
+                # save coef despite not converging:
+                # the solver updates coef_init in place, so it holds the
+                # last iterate
+                coef = coef_init
+                res_norm = norm(y - X @ coef)
+                warnings.warn(
+                    f"Small residuals prevented the solver from converging "
+                    f"at alpha={alphas[i]:.2e} (residuals' norm: {res_norm:.4e}). "
+                    "Consider fitting with a higher alpha.",
+                    ConvergenceWarning
+                )
+                coefs[i] = coef
+                break
+
+        return alphas, coefs
+
+
+def _chambolle_pock_sqrt(X, y, alpha, max_iter=1000, obj_freq=10, verbose=False):
+    """Apply the Chambolle-Pock algorithm to solve the square root Lasso.
+
+    The objective function is::
+
+        min_w ||Xw - y||_2 / sqrt(n_samples) + alpha * ||w||_1
+    """
+    n_samples, n_features = X.shape
+    # dual variable is z, primal is w
+    z_old = np.zeros(n_samples)
+    z = z_old.copy()
+    w = np.zeros(n_features)
+
+    objs = []
+
+    L = norm(X, ord=2)
+    # take equal primal and dual stepsizes; convergence requires
+    # tau * sigma * L ** 2 < 1
+    tau = 0.99 / L
+    sigma = 0.99 / L
+
+    for t in range(max_iter):
+        # primal update: the prox of the (rescaled) L1 norm is
+        # soft-thresholding
+        w = ST_vec(w - tau * X.T @ (2 * z - z_old), alpha * np.sqrt(n_samples) * tau)
+        z_old = z.copy()
+        # dual update: the prox of the conjugate of the L2 norm is the
+        # projection onto the L2 unit ball
+        z[:] = proj_L2ball(z + sigma * (X @ w - y))
+
+        if t % obj_freq == 0:
+            objs.append(norm(X @ w - y) / np.sqrt(n_samples) + alpha * norm(w, ord=1))
+            if verbose:
+                print(f"Iter {t}, obj {objs[-1]: .10f}")
+
+    return w, z, objs
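For reference, a short usage sketch of the new estimator and its path method, using only the API added above plus the existing make_correlated_data helper (problem sizes and the alpha grid are illustrative):

    import numpy as np
    from numpy.linalg import norm
    from skglm.utils import make_correlated_data
    from skglm.experimental import SqrtLasso

    X, y, _ = make_correlated_data(n_samples=50, n_features=10, random_state=0)
    n_samples = X.shape[0]

    # alpha_max is the smallest alpha for which the solution is all zeros
    alpha_max = norm(X.T @ y, ord=np.inf) / (np.sqrt(n_samples) * norm(y))

    # single fit at a tenth of alpha_max
    clf = SqrtLasso(alpha=alpha_max / 10).fit(X, y)
    print(clf.coef_)

    # full path on a geometric grid from alpha_max down to eps * alpha_max
    alphas, coefs = SqrtLasso(tol=1e-8).path(X, y, eps=1e-2, n_alphas=5)
    print(coefs.shape)  # (5, 10): one row of coefficients per alpha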
Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+import pytest
+import numpy as np
+from numpy.linalg import norm
+
+from skglm.utils import make_correlated_data
+from skglm.experimental.sqrt_lasso import SqrtLasso, _chambolle_pock_sqrt
+
+
+def test_alpha_max():
+    # at alpha >= alpha_max the solution is exactly 0
+    n_samples, n_features = 50, 10
+    X, y, _ = make_correlated_data(n_samples, n_features, random_state=0)
+    alpha_max = norm(X.T @ y, ord=np.inf) / (np.sqrt(n_samples) * norm(y))
+
+    sqrt_lasso = SqrtLasso(alpha=alpha_max).fit(X, y)
+
+    np.testing.assert_equal(sqrt_lasso.coef_, 0)
+
+
+def test_vs_statsmodels():
+    try:
+        from statsmodels.regression import linear_model  # noqa
+    except ImportError:
+        pytest.xfail("This test requires statsmodels to run.")
+    n_samples, n_features = 50, 10
+    X, y, _ = make_correlated_data(n_samples, n_features, random_state=0)
+
+    alpha_max = norm(X.T @ y, ord=np.inf) / (np.sqrt(n_samples) * norm(y))
+    n_alphas = 3
+    alphas = alpha_max * np.geomspace(1, 1e-2, n_alphas + 1)[1:]
+
+    sqrt_lasso = SqrtLasso(tol=1e-9)
+    coefs_skglm = sqrt_lasso.path(X, y, alphas)[1]
+
+    coefs_statsmodels = np.zeros((len(alphas), n_features))
+
+    # fit statsmodels on the path
+    for i in range(n_alphas):
+        alpha = alphas[i]
+        model = linear_model.OLS(y, X)
+        # statsmodels' alpha is scaled by n_samples compared to ours
+        model = model.fit_regularized(method='sqrt_lasso', L1_wt=1.,
+                                      alpha=n_samples * alpha)
+        coefs_statsmodels[i] = model.params
+
+    np.testing.assert_almost_equal(coefs_skglm, coefs_statsmodels, decimal=4)
+
+
+def test_prox_newton_cp():
+    # ProxNewton and Chambolle-Pock should reach the same solution
+    n_samples, n_features = 50, 10
+    X, y, _ = make_correlated_data(n_samples, n_features, random_state=0)
+
+    alpha_max = norm(X.T @ y, ord=np.inf) / (np.sqrt(n_samples) * norm(y))
+    alpha = alpha_max / 10
+    clf = SqrtLasso(alpha=alpha, tol=1e-12).fit(X, y)
+    w, _, _ = _chambolle_pock_sqrt(X, y, alpha, max_iter=1000)
+    np.testing.assert_allclose(clf.coef_, w)
+
+
+if __name__ == '__main__':
+    pass
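A natural complement to these tests is a finite-difference check of SqrtQuadratic.raw_grad against SqrtQuadratic.value. The sketch below is not part of the diff; it assumes SqrtQuadratic can be instantiated directly, as the path method above does (value ignores its w argument, so None is passed):

    import numpy as np
    from skglm.experimental.sqrt_lasso import SqrtQuadratic

    rng = np.random.default_rng(0)
    n_samples = 20
    y = rng.standard_normal(n_samples)
    Xw = rng.standard_normal(n_samples)

    datafit = SqrtQuadratic()
    grad = datafit.raw_grad(y, Xw)

    # central finite differences of value() w.r.t. each coordinate of Xw
    eps = 1e-6
    grad_fd = np.zeros(n_samples)
    for j in range(n_samples):
        e_j = np.zeros(n_samples)
        e_j[j] = eps
        grad_fd[j] = (datafit.value(y, None, Xw + e_j)
                      - datafit.value(y, None, Xw - e_j)) / (2 * eps)

    np.testing.assert_allclose(grad, grad_fd, rtol=1e-5)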

skglm/utils.py

Lines changed: 9 additions & 0 deletions
@@ -105,6 +105,15 @@ def ST_vec(x, u):
     return np.sign(x) * np.maximum(0., np.abs(x) - u)
 
 
+@njit
+def proj_L2ball(u):
+    """Project input onto the L2 unit ball."""
+    norm_u = norm(u)
+    if norm_u <= 1:
+        return u
+    return u / norm_u
+
+
 @njit
 def BST(x, u):
     """Block soft-thresholding of vector x at level u."""
