Skip to content

Commit d0ad9a4

Browse files
committed
fiw what's new, slight changes UT
1 parent 0290bd2 commit d0ad9a4

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

doc/changes/0.5.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22

33
Version 0.5 (in progress)
44
-------------------------
5-
- Add support for fitting an intercept in :ref:`SqrtLasso <skglm.experimental.sqrt_lasso.SqrtLasso> (PR: :gh:`298`)
5+
- Add support for fitting an intercept in :ref:`SqrtLasso <skglm.experimental.sqrt_lasso.SqrtLasso>` (PR: :gh:`298`)

skglm/experimental/tests/test_sqrt_lasso.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,21 +81,21 @@ def test_PDCD_WS(with_dual_init):
8181
np.testing.assert_allclose(clf.coef_, w, atol=1e-6)
8282

8383

84-
def test_lasso_vs_sqrt_lasso_with_intercept():
85-
84+
@pytest.mark.parametrize("fit_intercept", [True, False])
85+
def test_lasso_sqrt_lasso_equivalence(fit_intercept):
8686
n_samples, n_features = 50, 10
8787
X, y, _ = make_correlated_data(n_samples, n_features, random_state=0)
8888

8989
alpha_max = norm(X.T @ y, ord=np.inf) / norm(y)
9090
alpha = alpha_max / 10
9191

92-
lasso = Lasso(alpha=alpha, fit_intercept=True, tol=1e-8).fit(X, y)
93-
w_lasso = lasso.coef_
92+
lasso = Lasso(alpha=alpha, fit_intercept=fit_intercept, tol=1e-8).fit(X, y)
9493

9594
scal = n_samples / norm(y - lasso.predict(X))
96-
sqrt = SqrtLasso(alpha=alpha * scal, fit_intercept=True, tol=1e-8).fit(X, y)
95+
sqrt = SqrtLasso(
96+
alpha=alpha * scal, fit_intercept=fit_intercept, tol=1e-8).fit(X, y)
9797

98-
np.testing.assert_allclose(w_lasso, sqrt.coef_, rtol=1e-6)
98+
np.testing.assert_allclose(sqrt.coef_, lasso.coef_, rtol=1e-6)
9999

100100

101101
if __name__ == '__main__':

0 commit comments

Comments
 (0)