Commit 8a89980

Added L1 plus L2 regularisation to SparseLogisticRegression
1 parent 495333b

1 file changed: skglm/estimators.py (11 additions, 3 deletions)

@@ -967,6 +967,12 @@ class SparseLogisticRegression(LinearClassifierMixin, SparseCoefMixin, BaseEstim
     alpha : float, default=1.0
         Regularization strength; must be a positive float.
 
+    l1_ratio : float, default=1.0
+        The ElasticNet mixing parameter, with ``0 <= l1_ratio <= 1``. For
+        ``l1_ratio = 0`` the penalty is an L2 penalty. For ``l1_ratio = 1`` it
+        is an L1 penalty. For ``0 < l1_ratio < 1``, the penalty is a
+        combination of L1 and L2.
+
     tol : float, optional
         Stopping criterion for the optimization.
 
@@ -1003,10 +1009,11 @@ class SparseLogisticRegression(LinearClassifierMixin, SparseCoefMixin, BaseEstim
         Number of subproblems solved to reach the specified tolerance.
     """
 
-    def __init__(self, alpha=1.0, tol=1e-4, max_iter=20, max_epochs=1_000, verbose=0,
-                 fit_intercept=True, warm_start=False):
+    def __init__(self, alpha=1.0, l1_ratio=1.0, tol=1e-4, max_iter=20, max_epochs=1_000,
+                 verbose=0, fit_intercept=True, warm_start=False):
         super().__init__()
         self.alpha = alpha
+        self.l1_ratio = l1_ratio
         self.tol = tol
         self.max_iter = max_iter
         self.max_epochs = max_epochs
@@ -1035,7 +1042,8 @@ def fit(self, X, y):
             max_iter=self.max_iter, max_pn_iter=self.max_epochs, tol=self.tol,
             fit_intercept=self.fit_intercept, warm_start=self.warm_start,
             verbose=self.verbose)
-        return _glm_fit(X, y, self, Logistic(), L1(self.alpha), solver)
+        return _glm_fit(X, y, self, Logistic(), L1_plus_L2(self.alpha, self.l1_ratio),
+                        solver)
 
     def predict_proba(self, X):
         """Probability estimates.
