@@ -81,21 +81,21 @@ def test_PDCD_WS(with_dual_init):
8181 np .testing .assert_allclose (clf .coef_ , w , atol = 1e-6 )
8282
8383
84- def test_lasso_vs_sqrt_lasso_with_intercept ():
85-
84+ @ pytest . mark . parametrize ( "fit_intercept" , [ True , False ])
85+ def test_lasso_sqrt_lasso_equivalence ( fit_intercept ):
8686 n_samples , n_features = 50 , 10
8787 X , y , _ = make_correlated_data (n_samples , n_features , random_state = 0 )
8888
8989 alpha_max = norm (X .T @ y , ord = np .inf ) / norm (y )
9090 alpha = alpha_max / 10
9191
92- lasso = Lasso (alpha = alpha , fit_intercept = True , tol = 1e-8 ).fit (X , y )
93- w_lasso = lasso .coef_
92+ lasso = Lasso (alpha = alpha , fit_intercept = fit_intercept , tol = 1e-8 ).fit (X , y )
9493
9594 scal = n_samples / norm (y - lasso .predict (X ))
96- sqrt = SqrtLasso (alpha = alpha * scal , fit_intercept = True , tol = 1e-8 ).fit (X , y )
95+ sqrt = SqrtLasso (
96+ alpha = alpha * scal , fit_intercept = fit_intercept , tol = 1e-8 ).fit (X , y )
9797
98- np .testing .assert_allclose (w_lasso , sqrt .coef_ , rtol = 1e-6 )
98+ np .testing .assert_allclose (sqrt . coef_ , lasso .coef_ , rtol = 1e-6 )
9999
100100
101101if __name__ == '__main__' :
0 commit comments