Skip to content

Commit aabbf91

Browse files
fix add support for intercept in SqrtLasso
1 parent 1b699a3 commit aabbf91

File tree

1 file changed

+23
-8
lines changed

1 file changed

+23
-8
lines changed

skglm/experimental/sqrt_lasso.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,13 @@ class SqrtLasso(LinearModel, RegressorMixin):
101101
102102
verbose : bool, default False
103103
Amount of verbosity. 0/False is silent.
104+
105+
fit_intercept: bool, default True
106+
xxx
104107
"""
105108

106109
def __init__(self, alpha=1., max_iter=100, max_pn_iter=100, p0=10,
107-
tol=1e-4, verbose=0):
110+
tol=1e-4, verbose=0, fit_intercept=True):
108111
super().__init__()
109112
self.alpha = alpha
110113
self.max_iter = max_iter
@@ -113,6 +116,7 @@ def __init__(self, alpha=1., max_iter=100, max_pn_iter=100, p0=10,
113116
self.p0 = p0
114117
self.tol = tol
115118
self.verbose = verbose
119+
self.fit_intercept = fit_intercept
116120

117121
def fit(self, X, y):
118122
"""Fit the model according to the given training data.
@@ -132,7 +136,10 @@ def fit(self, X, y):
132136
Fitted estimator.
133137
"""
134138
self.coef_ = self.path(X, y, alphas=[self.alpha])[1][0]
135-
self.intercept_ = 0. # TODO handle fit_intercept
139+
if self.fit_intercept:
140+
self.intercept_ = self.coef_[-1]
141+
else:
142+
self.intercept_ = 0.
136143
return self
137144

138145
def path(self, X, y, alphas=None, eps=1e-3, n_alphas=10):
@@ -182,7 +189,7 @@ def path(self, X, y, alphas=None, eps=1e-3, n_alphas=10):
182189
sqrt_quadratic = compiled_clone(SqrtQuadratic())
183190
l1_penalty = compiled_clone(L1(1.)) # alpha is set along the path
184191

185-
coefs = np.zeros((n_alphas, n_features))
192+
coefs = np.zeros((n_alphas, n_features + self.fit_intercept))
186193

187194
for i in range(n_alphas):
188195
if self.verbose:
@@ -193,12 +200,17 @@ def path(self, X, y, alphas=None, eps=1e-3, n_alphas=10):
193200

194201
l1_penalty.alpha = alphas[i]
195202
# no warm start for the first alpha
196-
coef_init = coefs[i].copy() if i else np.zeros(n_features)
203+
coef_init = coefs[i].copy() if i else np.zeros(n_features + self.fit_intercept)
197204

198205
try:
199-
coef, _, _ = self.solver_.solve(
200-
X, y, sqrt_quadratic, l1_penalty,
201-
w_init=coef_init, Xw_init=X @ coef_init)
206+
if self.fit_intercept:
207+
coef, _, _ = self.solver_.solve(
208+
X, y, sqrt_quadratic, l1_penalty,
209+
w_init=coef_init, Xw_init=X @ coef_init)
210+
else:
211+
coef, _, _ = self.solver_.solve(
212+
X, y, sqrt_quadratic, l1_penalty,
213+
w_init=coef_init, Xw_init=X @ coef_init[:-1] + coef_init[-1])
202214
coefs[i] = coef
203215
except ValueError as val_exception:
204216
# make sure to catch residual error
@@ -209,7 +221,10 @@ def path(self, X, y, alphas=None, eps=1e-3, n_alphas=10):
209221
# save coef despite not converging
210222
# coef_init holds a ref to coef
211223
coef = coef_init
212-
res_norm = norm(y - X @ coef)
224+
if self.fit_intercept:
225+
res_norm = norm(y - X @ coef[:-1] - coef[-1])
226+
else:
227+
res_norm = norm(y - X @ coef)
213228
warnings.warn(
214229
f"Small residuals prevented the solver from converging "
215230
f"at alpha={alphas[i]:.2e} (residuals' norm: {res_norm:.4e}). "

0 commit comments

Comments
 (0)