Skip to content

Commit bd81345

Browse files
add unit test for sqrt lasso with intercept
1 parent 9eb19ae commit bd81345

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

skglm/experimental/tests/test_sqrt_lasso.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from skglm.experimental.sqrt_lasso import (SqrtLasso, SqrtQuadratic,
88
_chambolle_pock_sqrt)
99
from skglm.experimental.pdcd_ws import PDCD_WS
10+
from skglm import Lasso
1011

1112

1213
def test_alpha_max():
@@ -80,5 +81,22 @@ def test_PDCD_WS(with_dual_init):
8081
np.testing.assert_allclose(clf.coef_, w, atol=1e-6)
8182

8283

84+
def test_lasso_vs_sqrt_lasso_with_intercept():
85+
86+
n_samples, n_features = 50, 10
87+
X, y, _ = make_correlated_data(n_samples, n_features, random_state=0)
88+
89+
alpha_max = norm(X.T @ y, ord=np.inf) / norm(y)
90+
alpha = alpha_max / 10
91+
92+
lasso = Lasso(alpha=alpha, fit_intercept=True, tol=1e-8).fit(X, y)
93+
w_lasso = lasso.coef_
94+
95+
scal = n_samples / norm(y - lasso.predict(X))
96+
sqrt = SqrtLasso(alpha=alpha * scal, fit_intercept=True, tol=1e-8).fit(X, y)
97+
98+
np.testing.assert_allclose(w_lasso, sqrt.coef_, rtol=1e-6)
99+
100+
83101
if __name__ == '__main__':
84102
pass

0 commit comments

Comments
 (0)