Skip to content

Commit a549bba

Browse files
Merge branch 'fix_intercept_SqrtLasso' of https://github.com/PascalCarrivain/skglm into fix_intercept_SqrtLasso
2 parents 6f68666 + 1975371 commit a549bba

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

skglm/experimental/tests/test_sqrt_lasso.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def test_vs_statsmodels():
3131
n_alphas = 3
3232
alphas = alpha_max * np.geomspace(1, 1e-2, n_alphas+1)[1:]
3333

34-
sqrt_lasso = SqrtLasso(tol=1e-9)
34+
sqrt_lasso = SqrtLasso(tol=1e-9, fit_intercept=False)
3535
coefs_skglm = sqrt_lasso.path(X, y, alphas)[1]
3636

3737
coefs_statsmodels = np.zeros((len(alphas), n_features))
@@ -54,7 +54,7 @@ def test_prox_newton_cp():
5454

5555
alpha_max = norm(X.T @ y, ord=np.inf) / norm(y)
5656
alpha = alpha_max / 10
57-
clf = SqrtLasso(alpha=alpha, tol=1e-12).fit(X, y)
57+
clf = SqrtLasso(alpha=alpha, fit_intercept=False, tol=1e-12).fit(X, y)
5858
w, _, _ = _chambolle_pock_sqrt(X, y, alpha, max_iter=1000)
5959
np.testing.assert_allclose(clf.coef_, w)
6060

@@ -70,7 +70,7 @@ def test_PDCD_WS(with_dual_init):
7070
dual_init = y / norm(y) if with_dual_init else None
7171

7272
w = PDCD_WS(dual_init=dual_init).solve(X, y, SqrtQuadratic(), L1(alpha))[0]
73-
clf = SqrtLasso(alpha=alpha, tol=1e-12).fit(X, y)
73+
clf = SqrtLasso(alpha=alpha, fit_intercept=False, tol=1e-12).fit(X, y)
7474
np.testing.assert_allclose(clf.coef_, w, atol=1e-6)
7575

7676

0 commit comments

Comments
 (0)