Skip to content

Commit c38c00d

Browse files
add fit_intercept to test_alpha_max and to self.solver_
1 parent cdc21ea commit c38c00d

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

skglm/experimental/sqrt_lasso.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def path(self, X, y, alphas=None, eps=1e-3, n_alphas=10):
176176
if not hasattr(self, "solver_"):
177177
self.solver_ = ProxNewton(
178178
tol=self.tol, max_iter=self.max_iter, verbose=self.verbose,
179-
fit_intercept=False)
179+
fit_intercept=self.fit_intercept)
180180
# build path
181181
if alphas is None:
182182
alpha_max = norm(X.T @ y, ord=np.inf) / (np.sqrt(len(y)) * norm(y))

skglm/experimental/tests/test_sqrt_lasso.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@ def test_alpha_max():
1616

1717
sqrt_lasso = SqrtLasso(alpha=alpha_max).fit(X, y)
1818

19-
np.testing.assert_equal(sqrt_lasso.coef_, 0)
19+
if sqrt_lasso.fit_intercept:
20+
np.testing.assert_equal(sqrt_lasso.coef_[:-1], 0)
21+
else:
22+
np.testing.assert_equal(sqrt_lasso.coef_, 0)
2023

2124

2225
def test_vs_statsmodels():

0 commit comments

Comments
 (0)