Skip to content

Commit d87638c

Browse files
PascalCarrivainBadr-MOUFADmathurinmfloriankozikowski
authored
ENH add support for intercept in SqrtLasso (#298)
Co-authored-by: Badr-MOUFAD <[email protected]> Co-authored-by: mathurinm <[email protected]> Co-authored-by: floriankozikowski <[email protected]>
1 parent 994c5fc commit d87638c

File tree

4 files changed

+45
-12
lines changed

4 files changed

+45
-12
lines changed

doc/changes/0.4.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
.. _changes_0_4:
22

3-
Version 0.4 (2023/04/08)
3+
Version 0.4 (2025/04/08)
44
-------------------------
55
- Add :ref:`GroupLasso Estimator <skglm.GroupLasso>` (PR: :gh:`228`)
66
- Add support and tutorial for positive coefficients to :ref:`Group Lasso Penalty <skglm.penalties.WeightedGroupL2>` (PR: :gh:`221`)

doc/changes/0.5.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22

33
Version 0.5 (in progress)
44
-------------------------
5+
- Add support for fitting an intercept in :ref:`SqrtLasso <skglm.experimental.sqrt_lasso.SqrtLasso>` (PR: :gh:`298`)

skglm/experimental/sqrt_lasso.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,13 @@ class SqrtLasso(LinearModel, RegressorMixin):
100100
101101
verbose : bool, default False
102102
Amount of verbosity. 0/False is silent.
103+
104+
fit_intercept: bool, optional (default=True)
105+
Whether or not to fit an intercept.
103106
"""
104107

105108
def __init__(self, alpha=1., max_iter=100, max_pn_iter=100, p0=10,
106-
tol=1e-4, verbose=0):
109+
tol=1e-4, verbose=0, fit_intercept=True):
107110
super().__init__()
108111
self.alpha = alpha
109112
self.max_iter = max_iter
@@ -112,6 +115,7 @@ def __init__(self, alpha=1., max_iter=100, max_pn_iter=100, p0=10,
112115
self.p0 = p0
113116
self.tol = tol
114117
self.verbose = verbose
118+
self.fit_intercept = fit_intercept
115119

116120
def fit(self, X, y):
117121
"""Fit the model according to the given training data.
@@ -131,7 +135,11 @@ def fit(self, X, y):
131135
Fitted estimator.
132136
"""
133137
self.coef_ = self.path(X, y, alphas=[self.alpha])[1][0]
134-
self.intercept_ = 0. # TODO handle fit_intercept
138+
if self.fit_intercept:
139+
self.intercept_ = self.coef_[-1]
140+
self.coef_ = self.coef_[:-1]
141+
else:
142+
self.intercept_ = 0.
135143
return self
136144

137145
def path(self, X, y, alphas=None, eps=1e-3, n_alphas=10):
@@ -168,7 +176,7 @@ def path(self, X, y, alphas=None, eps=1e-3, n_alphas=10):
168176
if not hasattr(self, "solver_"):
169177
self.solver_ = ProxNewton(
170178
tol=self.tol, max_iter=self.max_iter, verbose=self.verbose,
171-
fit_intercept=False)
179+
fit_intercept=self.fit_intercept)
172180
# build path
173181
if alphas is None:
174182
alpha_max = norm(X.T @ y, ord=np.inf) / (np.sqrt(len(y)) * norm(y))
@@ -181,7 +189,7 @@ def path(self, X, y, alphas=None, eps=1e-3, n_alphas=10):
181189
sqrt_quadratic = SqrtQuadratic()
182190
l1_penalty = L1(1.) # alpha is set along the path
183191

184-
coefs = np.zeros((n_alphas, n_features))
192+
coefs = np.zeros((n_alphas, n_features + self.fit_intercept))
185193

186194
for i in range(n_alphas):
187195
if self.verbose:
@@ -192,12 +200,14 @@ def path(self, X, y, alphas=None, eps=1e-3, n_alphas=10):
192200

193201
l1_penalty.alpha = alphas[i]
194202
# no warm start for the first alpha
195-
coef_init = coefs[i].copy() if i else np.zeros(n_features)
203+
coef_init = coefs[i].copy() if i else np.zeros(n_features
204+
+ self.fit_intercept)
196205

197206
try:
198207
coef, _, _ = self.solver_.solve(
199208
X, y, sqrt_quadratic, l1_penalty,
200-
w_init=coef_init, Xw_init=X @ coef_init)
209+
w_init=coef_init, Xw_init=X @ coef_init[:-1] + coef_init[-1]
210+
if self.fit_intercept else X @ coef_init)
201211
coefs[i] = coef
202212
except ValueError as val_exception:
203213
# make sure to catch residual error
@@ -208,7 +218,8 @@ def path(self, X, y, alphas=None, eps=1e-3, n_alphas=10):
208218
# save coef despite not converging
209219
# coef_init holds a ref to coef
210220
coef = coef_init
211-
res_norm = norm(y - X @ coef)
221+
X_coef = X @ coef[:-1] + coef[-1] if self.fit_intercept else X @ coef
222+
res_norm = norm(y - X_coef)
212223
warnings.warn(
213224
f"Small residuals prevented the solver from converging "
214225
f"at alpha={alphas[i]:.2e} (residuals' norm: {res_norm:.4e}). "

skglm/experimental/tests/test_sqrt_lasso.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from skglm.experimental.sqrt_lasso import (SqrtLasso, SqrtQuadratic,
88
_chambolle_pock_sqrt)
99
from skglm.experimental.pdcd_ws import PDCD_WS
10+
from skglm import Lasso
1011

1112

1213
def test_alpha_max():
@@ -16,7 +17,10 @@ def test_alpha_max():
1617

1718
sqrt_lasso = SqrtLasso(alpha=alpha_max).fit(X, y)
1819

19-
np.testing.assert_equal(sqrt_lasso.coef_, 0)
20+
if sqrt_lasso.fit_intercept:
21+
np.testing.assert_equal(sqrt_lasso.coef_[:-1], 0)
22+
else:
23+
np.testing.assert_equal(sqrt_lasso.coef_, 0)
2024

2125

2226
def test_vs_statsmodels():
@@ -31,7 +35,7 @@ def test_vs_statsmodels():
3135
n_alphas = 3
3236
alphas = alpha_max * np.geomspace(1, 1e-2, n_alphas+1)[1:]
3337

34-
sqrt_lasso = SqrtLasso(tol=1e-9)
38+
sqrt_lasso = SqrtLasso(tol=1e-9, fit_intercept=False)
3539
coefs_skglm = sqrt_lasso.path(X, y, alphas)[1]
3640

3741
coefs_statsmodels = np.zeros((len(alphas), n_features))
@@ -54,7 +58,7 @@ def test_prox_newton_cp():
5458

5559
alpha_max = norm(X.T @ y, ord=np.inf) / norm(y)
5660
alpha = alpha_max / 10
57-
clf = SqrtLasso(alpha=alpha, tol=1e-12).fit(X, y)
61+
clf = SqrtLasso(alpha=alpha, fit_intercept=False, tol=1e-12).fit(X, y)
5862
w, _, _ = _chambolle_pock_sqrt(X, y, alpha, max_iter=1000)
5963
np.testing.assert_allclose(clf.coef_, w)
6064

@@ -73,9 +77,26 @@ def test_PDCD_WS(with_dual_init):
7377
penalty = L1(alpha)
7478

7579
w = PDCD_WS(dual_init=dual_init).solve(X, y, datafit, penalty)[0]
76-
clf = SqrtLasso(alpha=alpha, tol=1e-12).fit(X, y)
80+
clf = SqrtLasso(alpha=alpha, fit_intercept=False, tol=1e-12).fit(X, y)
7781
np.testing.assert_allclose(clf.coef_, w, atol=1e-6)
7882

7983

84+
@pytest.mark.parametrize("fit_intercept", [True, False])
85+
def test_lasso_sqrt_lasso_equivalence(fit_intercept):
86+
n_samples, n_features = 50, 10
87+
X, y, _ = make_correlated_data(n_samples, n_features, random_state=0)
88+
89+
alpha_max = norm(X.T @ y, ord=np.inf) / norm(y)
90+
alpha = alpha_max / 10
91+
92+
lasso = Lasso(alpha=alpha, fit_intercept=fit_intercept, tol=1e-8).fit(X, y)
93+
94+
scal = n_samples / norm(y - lasso.predict(X))
95+
sqrt = SqrtLasso(
96+
alpha=alpha * scal, fit_intercept=fit_intercept, tol=1e-8).fit(X, y)
97+
98+
np.testing.assert_allclose(sqrt.coef_, lasso.coef_, rtol=1e-6)
99+
100+
80101
if __name__ == '__main__':
81102
pass

0 commit comments

Comments
 (0)