Skip to content

Commit aaa1aa3

Browse files
committed
Added test for L1_plus_L2 penalty in SparseLogisticRegression
1 parent 8a89980 commit aaa1aa3

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

skglm/tests/test_estimators.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -600,5 +600,25 @@ def test_GroupLasso_estimator_sparse_vs_dense(positive):
600600
np.testing.assert_allclose(coef_sparse, coef_dense, atol=1e-7, rtol=1e-5)
601601

602602

603+
@pytest.mark.parametrize("X, l1_ratio", product([X, X_sparse], [1., 0.7, 0.]))
604+
def test_SparseLogReg_elasticnet(X, l1_ratio):
605+
606+
estimator_sk = clone(dict_estimators_sk['LogisticRegression'])
607+
estimator_ours = clone(dict_estimators_ours['LogisticRegression'])
608+
estimator_sk.set_params(fit_intercept=False, solver='saga',
609+
penalty='elasticnet', l1_ratio=l1_ratio, max_iter=10_000)
610+
estimator_ours.set_params(fit_intercept=False, l1_ratio=l1_ratio, max_iter=10_000)
611+
612+
estimator_sk.fit(X, y)
613+
estimator_ours.fit(X, y)
614+
coef_sk = estimator_sk.coef_
615+
coef_ours = estimator_ours.coef_
616+
617+
np.testing.assert_array_less(1e-5, norm(coef_ours))
618+
np.testing.assert_allclose(coef_ours, coef_sk, atol=1e-6)
619+
np.testing.assert_allclose(
620+
estimator_sk.intercept_, estimator_ours.intercept_, rtol=1e-4)
621+
622+
603623
if __name__ == "__main__":
604624
pass

0 commit comments

Comments
 (0)