Skip to content

Commit 6f68666

Browse files
fix inverting the cases (fit_intercept)
1 parent a44c11f commit 6f68666

File tree

1 file changed

+4
-6
lines changed

1 file changed

+4
-6
lines changed

skglm/experimental/sqrt_lasso.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -207,11 +207,11 @@ def path(self, X, y, alphas=None, eps=1e-3, n_alphas=10):
207207
if self.fit_intercept:
208208
coef, _, _ = self.solver_.solve(
209209
X, y, sqrt_quadratic, l1_penalty,
210-
w_init=coef_init, Xw_init=X @ coef_init)
210+
w_init=coef_init, Xw_init=X @ coef_init[:-1] + coef_init[-1])
211211
else:
212212
coef, _, _ = self.solver_.solve(
213213
X, y, sqrt_quadratic, l1_penalty,
214-
w_init=coef_init, Xw_init=X @ coef_init[:-1] + coef_init[-1])
214+
w_init=coef_init, Xw_init=X @ coef_init)
215215
coefs[i] = coef
216216
except ValueError as val_exception:
217217
# make sure to catch residual error
@@ -222,10 +222,8 @@ def path(self, X, y, alphas=None, eps=1e-3, n_alphas=10):
222222
# save coef despite not converging
223223
# coef_init holds a ref to coef
224224
coef = coef_init
225-
if self.fit_intercept:
226-
res_norm = norm(y - X @ coef[:-1] - coef[-1])
227-
else:
228-
res_norm = norm(y - X @ coef)
225+
X_coef = X @ coef[:-1] + coef[-1] if self.fit_intercept else X @ coef
226+
res_norm = norm(y - X_coeff)
229227
warnings.warn(
230228
f"Small residuals prevented the solver from converging "
231229
f"at alpha={alphas[i]:.2e} (residuals' norm: {res_norm:.4e}). "

0 commit comments

Comments
 (0)