
Commit c9da9d4

Merge branch 'main' of https://github.com/scikit-learn-contrib/skglm into graphical_lasso
2 parents 3e3df79 + 41792a0

2 files changed: +15 -7 lines changed

doc/changes/0.4.rst

Lines changed: 1 addition & 1 deletion

@@ -6,7 +6,7 @@ Version 0.4 (in progress)
 - Add support and tutorial for positive coefficients to :ref:`Group Lasso Penalty <skglm.penalties.WeightedGroupL2>` (PR: :gh:`221`)
 - Check compatibility with datafit and penalty in solver (PR :gh:`137`)
 - Add support to weight samples in the quadratic datafit :ref:`Weighted Quadratic Datafit <skglm.datafit.WeightedQuadratic>` (PR: :gh:`258`)
-
+- Add support for ElasticNet regularization (`penalty="l1_plus_l2"`) to :ref:`SparseLogisticRegression <skglm.SparseLogisticRegression>` (PR: :gh:`244`)
 
 Version 0.3.1 (2023/12/21)
 --------------------------

skglm/estimators.py

Lines changed: 14 additions & 6 deletions

@@ -969,19 +969,27 @@ class SparseLogisticRegression(LinearClassifierMixin, SparseCoefMixin, BaseEstim
 
     The optimization objective for sparse Logistic regression is:
 
-    .. math:: 1 / n_"samples" sum_(i=1)^(n_"samples") log(1 + exp(-y_i x_i^T w))
-        + alpha ||w||_1
+    .. math::
+        1 / n_"samples" sum_(i=1)^(n_"samples") log(1 + exp(-y_i x_i^T w))
+        + tt"l1_ratio" xx alpha ||w||_1
+        + (1 - tt"l1_ratio") xx alpha/2 ||w||_2 ^ 2
+
+    By default, ``l1_ratio=1.0`` corresponds to Lasso (pure L1 penalty).
+    When ``0 < l1_ratio < 1``, the penalty is a convex combination of L1 and L2
+    (i.e., ElasticNet). ``l1_ratio=0.0`` corresponds to Ridge (pure L2), but note
+    that pure Ridge is not typically used with this class.
 
     Parameters
     ----------
     alpha : float, default=1.0
         Regularization strength; must be a positive float.
 
     l1_ratio : float, default=1.0
-        The ElasticNet mixing parameter, with ``0 <= l1_ratio <= 1``. For
-        ``l1_ratio = 0`` the penalty is an L2 penalty. ``For l1_ratio = 1`` it
-        is an L1 penalty. For ``0 < l1_ratio < 1``, the penalty is a
-        combination of L1 and L2.
+        The ElasticNet mixing parameter, with ``0 <= l1_ratio <= 1``.
+        Only used when ``penalty="l1_plus_l2"``.
+        For ``l1_ratio = 0`` the penalty is an L2 penalty.
+        For ``l1_ratio = 1`` it is an L1 penalty.
+        For ``0 < l1_ratio < 1``, the penalty is a combination of L1 and L2.
 
     tol : float, optional
         Stopping criterion for the optimization.
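The new objective in the docstring above can be sanity-checked numerically. Below is a minimal NumPy sketch of the ElasticNet-penalized logistic objective; the helper name `sparse_logreg_objective` is hypothetical and not part of skglm's API, and the data is random illustration only.

```python
import numpy as np

def sparse_logreg_objective(X, y, w, alpha, l1_ratio=1.0):
    """ElasticNet-penalized logistic objective, mirroring the docstring formula.

    Hypothetical helper for illustration; not part of skglm.
    """
    # 1 / n_samples * sum_i log(1 + exp(-y_i x_i^T w)), log1p for stability
    datafit = np.mean(np.log1p(np.exp(-y * (X @ w))))
    # l1_ratio * alpha * ||w||_1
    l1_term = l1_ratio * alpha * np.sum(np.abs(w))
    # (1 - l1_ratio) * alpha / 2 * ||w||_2^2
    l2_term = (1 - l1_ratio) * alpha / 2 * np.sum(w ** 2)
    return datafit + l1_term + l2_term

rng = np.random.default_rng(0)
X = rng.standard_normal((20, 5))
y = np.where(rng.standard_normal(20) > 0, 1.0, -1.0)
w = rng.standard_normal(5)

# l1_ratio=1.0 (the default) zeroes the L2 term, recovering the old pure-L1 objective
obj_l1 = sparse_logreg_objective(X, y, w, alpha=0.1, l1_ratio=1.0)
# 0 < l1_ratio < 1 blends L1 and L2, i.e. ElasticNet
obj_mix = sparse_logreg_objective(X, y, w, alpha=0.1, l1_ratio=0.5)
```

With `l1_ratio=1.0` the expression reduces term-by-term to the pre-change docstring objective, which is why the commit can change the formula without changing default behavior.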
