
Commit 1225970

ENH - Adds support for L1 + L2 regularization in SparseLogisticRegression (#278)
Co-authored-by: mathurinm <[email protected]>
1 parent 495333b · commit 1225970

File tree

2 files changed: +31 −3 lines changed

skglm/estimators.py (11 additions & 3 deletions)

@@ -967,6 +967,12 @@ class SparseLogisticRegression(LinearClassifierMixin, SparseCoefMixin, BaseEstim
     alpha : float, default=1.0
         Regularization strength; must be a positive float.

+    l1_ratio : float, default=1.0
+        The ElasticNet mixing parameter, with ``0 <= l1_ratio <= 1``. For
+        ``l1_ratio = 0`` the penalty is an L2 penalty. For ``l1_ratio = 1`` it
+        is an L1 penalty. For ``0 < l1_ratio < 1``, the penalty is a
+        combination of L1 and L2.
+
     tol : float, optional
         Stopping criterion for the optimization.

@@ -1003,10 +1009,11 @@ class SparseLogisticRegression(LinearClassifierMixin, SparseCoefMixin, BaseEstim
         Number of subproblems solved to reach the specified tolerance.
     """

-    def __init__(self, alpha=1.0, tol=1e-4, max_iter=20, max_epochs=1_000, verbose=0,
-                 fit_intercept=True, warm_start=False):
+    def __init__(self, alpha=1.0, l1_ratio=1.0, tol=1e-4, max_iter=20, max_epochs=1_000,
+                 verbose=0, fit_intercept=True, warm_start=False):
         super().__init__()
         self.alpha = alpha
+        self.l1_ratio = l1_ratio
         self.tol = tol
         self.max_iter = max_iter
         self.max_epochs = max_epochs

@@ -1035,7 +1042,8 @@ def fit(self, X, y):
             max_iter=self.max_iter, max_pn_iter=self.max_epochs, tol=self.tol,
             fit_intercept=self.fit_intercept, warm_start=self.warm_start,
             verbose=self.verbose)
-        return _glm_fit(X, y, self, Logistic(), L1(self.alpha), solver)
+        return _glm_fit(X, y, self, Logistic(), L1_plus_L2(self.alpha, self.l1_ratio),
+                        solver)

     def predict_proba(self, X):
         """Probability estimates.

skglm/tests/test_estimators.py (20 additions & 0 deletions)

@@ -600,5 +600,25 @@ def test_GroupLasso_estimator_sparse_vs_dense(positive):
     np.testing.assert_allclose(coef_sparse, coef_dense, atol=1e-7, rtol=1e-5)


+@pytest.mark.parametrize("X, l1_ratio", product([X, X_sparse], [1., 0.7, 0.]))
+def test_SparseLogReg_elasticnet(X, l1_ratio):
+
+    estimator_sk = clone(dict_estimators_sk['LogisticRegression'])
+    estimator_ours = clone(dict_estimators_ours['LogisticRegression'])
+    estimator_sk.set_params(fit_intercept=True, solver='saga',
+                            penalty='elasticnet', l1_ratio=l1_ratio, max_iter=10_000)
+    estimator_ours.set_params(fit_intercept=True, l1_ratio=l1_ratio, max_iter=10_000)
+
+    estimator_sk.fit(X, y)
+    estimator_ours.fit(X, y)
+    coef_sk = estimator_sk.coef_
+    coef_ours = estimator_ours.coef_
+
+    np.testing.assert_array_less(1e-5, norm(coef_ours))
+    np.testing.assert_allclose(coef_ours, coef_sk, atol=1e-6)
+    np.testing.assert_allclose(
+        estimator_sk.intercept_, estimator_ours.intercept_, rtol=1e-4)
+
+
 if __name__ == "__main__":
     pass
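
The new test checks skglm's coefficients and intercept against scikit-learn's saga solver with penalty='elasticnet', over dense and sparse X and over l1_ratio in {1.0, 0.7, 0.0}; the norm assertion guards against the comparison passing vacuously with an all-zero coefficient vector. To run just this test from a development checkout, an invocation along these lines should work (the -k filter is standard pytest, not something added by this commit):

    pytest skglm/tests/test_estimators.py -k test_SparseLogReg_elasticnet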
