
Commit a9f88f6

QB3 and Badr-MOUFAD authored
ENH - Add weights and positivity constraint to MCP (#184)
Co-authored-by: Badr MOUFAD <[email protected]>
Co-authored-by: Badr-MOUFAD <[email protected]>
1 parent 562f42b commit a9f88f6

File tree: 8 files changed, +184 −39 lines

doc/api.rst
Lines changed: 1 addition & 0 deletions

@@ -44,6 +44,7 @@ Penalties
    PositiveConstraint
    WeightedL1
    WeightedGroupL2
+   WeightedMCPenalty
    SCAD
    BlockSCAD
    SLOPE

doc/changes/0.4.rst
Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+.. _changes_0_4:
+
+Version 0.4 (in progress)
+---------------------------
+- Add support for weights and positive coefficients to :ref:`MCPRegression Estimator <skglm.MCPRegression>` (PR: :gh:`184`)

doc/changes/whats_new.rst
Lines changed: 2 additions & 1 deletion

@@ -5,9 +5,10 @@ What's new
 
 .. currentmodule:: skglm
 
+.. include:: 0.4.rst
+
 .. include:: 0.3.rst
 
 .. include:: 0.2.rst
 
 .. include:: 0.1.rst
-

skglm/estimators.py
Lines changed: 40 additions & 10 deletions

@@ -22,7 +22,7 @@
 from skglm.solvers import AndersonCD, MultiTaskBCD
 from skglm.datafits import Cox, Quadratic, Logistic, QuadraticSVC, QuadraticMultiTask
 from skglm.penalties import (L1, WeightedL1, L1_plus_L2, L2,
-                             MCPenalty, IndicatorBox, L2_1)
+                             MCPenalty, WeightedMCPenalty, IndicatorBox, L2_1)
 
 
 def _glm_fit(X, y, model, datafit, penalty, solver):
@@ -792,6 +792,10 @@ class MCPRegression(LinearModel, RegressorMixin):
         If ``gamma = np.inf`` it is a soft thresholding.
         Should be larger than (or equal to) 1.
 
+    weights : array, shape (n_features,), optional (default=None)
+        Positive weights used in the L1 penalty part of the Lasso
+        objective. If ``None``, weights equal to 1 are used.
+
     max_iter : int, optional
         The maximum number of iterations (subproblem definitions).
 
@@ -807,6 +811,9 @@ class MCPRegression(LinearModel, RegressorMixin):
     tol : float, optional
         Stopping criterion for the optimization.
 
+    positive : bool, optional
+        When set to ``True``, forces the coefficient vector to be positive.
+
     fit_intercept : bool, optional (default=True)
         Whether or not to fit an intercept.
 
@@ -836,20 +843,22 @@ class MCPRegression(LinearModel, RegressorMixin):
     Lasso : Lasso regularization.
     """
 
-    def __init__(self, alpha=1., gamma=3, max_iter=50, max_epochs=50_000, p0=10,
-                 verbose=0, tol=1e-4, fit_intercept=True, warm_start=False,
-                 ws_strategy="subdiff"):
+    def __init__(self, alpha=1., gamma=3, weights=None, max_iter=50, max_epochs=50_000,
+                 p0=10, verbose=0, tol=1e-4, positive=False, fit_intercept=True,
+                 warm_start=False, ws_strategy="subdiff"):
         super().__init__()
         self.alpha = alpha
         self.gamma = gamma
-        self.tol = tol
+        self.weights = weights
         self.max_iter = max_iter
         self.max_epochs = max_epochs
         self.p0 = p0
-        self.ws_strategy = ws_strategy
+        self.verbose = verbose
+        self.tol = tol
+        self.positive = positive
         self.fit_intercept = fit_intercept
         self.warm_start = warm_start
-        self.verbose = verbose
+        self.ws_strategy = ws_strategy
 
     def path(self, X, y, alphas, coef_init=None, return_n_iter=True, **params):
         """Compute MCPRegression path.
@@ -890,7 +899,19 @@ def path(self, X, y, alphas, coef_init=None, return_n_iter=True, **params):
             The number of iterations along the path. If return_n_iter is set to
             ``True``.
         """
-        penalty = compiled_clone(MCPenalty(self.alpha, self.gamma))
+        if self.weights is None:
+            penalty = compiled_clone(
+                MCPenalty(self.alpha, self.gamma, self.positive)
+            )
+        else:
+            if X.shape[1] != len(self.weights):
+                raise ValueError(
+                    "The number of weights must match the number of features. "
+                    f"Got {len(self.weights)}, expected {X.shape[1]}."
+                )
+            penalty = compiled_clone(
+                WeightedMCPenalty(self.alpha, self.gamma, self.weights, self.positive)
+            )
         datafit = compiled_clone(Quadratic(), to_float32=X.dtype == np.float32)
         solver = AndersonCD(
             self.max_iter, self.max_epochs, self.p0, tol=self.tol,
@@ -914,12 +935,21 @@ def fit(self, X, y):
         self :
             Fitted estimator.
         """
+        if self.weights is None:
+            penalty = MCPenalty(self.alpha, self.gamma, self.positive)
+        else:
+            if X.shape[1] != len(self.weights):
+                raise ValueError(
+                    "The number of weights must match the number of features. "
+                    f"Got {len(self.weights)}, expected {X.shape[1]}."
+                )
+            penalty = WeightedMCPenalty(
+                self.alpha, self.gamma, self.weights, self.positive)
         solver = AndersonCD(
             self.max_iter, self.max_epochs, self.p0, tol=self.tol,
             ws_strategy=self.ws_strategy, fit_intercept=self.fit_intercept,
             warm_start=self.warm_start, verbose=self.verbose)
-        return _glm_fit(X, y, self, Quadratic(), MCPenalty(self.alpha, self.gamma),
-                        solver)
+        return _glm_fit(X, y, self, Quadratic(), penalty, solver)
 
 
 class SparseLogisticRegression(LinearClassifierMixin, SparseCoefMixin, BaseEstimator):
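
The new `weights` and `positive` arguments are stored at construction and forwarded to the penalty in both `path` and `fit`. A minimal usage sketch (synthetic data and parameter values are illustrative, not part of this commit):

import numpy as np
from skglm import MCPRegression

rng = np.random.default_rng(0)
X = rng.standard_normal((50, 10))
y = X @ rng.standard_normal(10)

# weights rescale the penalty per feature; positive=True constrains coef_ >= 0
clf = MCPRegression(alpha=0.1, gamma=3., weights=np.ones(10), positive=True)
clf.fit(X, y)
print(clf.coef_.min() >= 0)  # True: all coefficients are non-negative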

skglm/penalties/__init__.py
Lines changed: 5 additions & 4 deletions

@@ -1,7 +1,7 @@
 from .base import BasePenalty
 from .separable import (
-    L1_plus_L2, L0_5, L1, L2, L2_3, MCPenalty, SCAD, WeightedL1, IndicatorBox,
-    PositiveConstraint
+    L1_plus_L2, L0_5, L1, L2, L2_3, MCPenalty, WeightedMCPenalty, SCAD,
+    WeightedL1, IndicatorBox, PositiveConstraint
 )
 from .block_separable import (
     L2_05, L2_1, BlockMCPenalty, BlockSCAD, WeightedGroupL2
@@ -12,6 +12,7 @@
 
 __all__ = [
     BasePenalty,
-    L1_plus_L2, L0_5, L1, L2, L2_3, MCPenalty, SCAD, WeightedL1, IndicatorBox,
-    PositiveConstraint, L2_05, L2_1, BlockMCPenalty, BlockSCAD, WeightedGroupL2, SLOPE
+    L1_plus_L2, L0_5, L1, L2, L2_3, MCPenalty, WeightedMCPenalty, SCAD, WeightedL1,
+    IndicatorBox, PositiveConstraint, L2_05, L2_1, BlockMCPenalty, BlockSCAD,
+    WeightedGroupL2, SLOPE
 ]
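
With the updated import and `__all__`, the new penalty can be instantiated directly from the package namespace, e.g. (parameter values are illustrative):

import numpy as np
from skglm.penalties import WeightedMCPenalty

penalty = WeightedMCPenalty(alpha=0.1, gamma=3., weights=np.ones(5), positive=False)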

skglm/penalties/separable.py
Lines changed: 107 additions & 14 deletions

@@ -4,7 +4,8 @@
 
 from skglm.penalties.base import BasePenalty
 from skglm.utils.prox_funcs import (
-    ST, box_proj, prox_05, prox_2_3, prox_SCAD, value_SCAD, prox_MCP, value_MCP)
+    ST, box_proj, prox_05, prox_2_3, prox_SCAD, value_SCAD, prox_MCP,
+    value_MCP, value_weighted_MCP)
 
 
 class L1(BasePenalty):
@@ -216,48 +217,57 @@ class MCPenalty(BasePenalty):
     With :math:`x >= 0`:
 
     .. math::
-        "pen"(x) = {(alpha x - x^2 / (2 gamma), if x =< alpha gamma),
+        "pen"(x) = {(alpha x - x^2 / (2 gamma), if x <= alpha gamma),
                     (gamma alpha^2 / 2 , if x > alpha gamma):}
     .. math::
         "value" = sum_(j=1)^(n_"features") "pen"(abs(w_j))
     """
 
-    def __init__(self, alpha, gamma):
+    def __init__(self, alpha, gamma, positive=False):
         self.alpha = alpha
         self.gamma = gamma
+        self.positive = positive
 
     def get_spec(self):
         spec = (
             ('alpha', float64),
             ('gamma', float64),
+            ('positive', bool_)
         )
         return spec
 
     def params_to_dict(self):
         return dict(alpha=self.alpha,
-                    gamma=self.gamma)
+                    gamma=self.gamma,
+                    positive=self.positive)
 
     def value(self, w):
         return value_MCP(w, self.alpha, self.gamma)
 
     def prox_1d(self, value, stepsize, j):
         """Compute the proximal operator of MCP."""
-        return prox_MCP(value, stepsize, self.alpha, self.gamma)
+        return prox_MCP(value, stepsize, self.alpha, self.gamma, self.positive)
 
     def subdiff_distance(self, w, grad, ws):
         """Compute distance of negative gradient to the subdifferential at w."""
         subdiff_dist = np.zeros_like(grad)
         for idx, j in enumerate(ws):
-            if w[j] == 0:
-                # distance of -grad to alpha * [-1, 1]
-                subdiff_dist[idx] = max(0, np.abs(grad[idx]) - self.alpha)
-            elif np.abs(w[j]) < self.alpha * self.gamma:
-                # distance of -grad_j to (alpha * sign(w[j]) - w[j] / gamma)
-                subdiff_dist[idx] = np.abs(
-                    grad[idx] + self.alpha * np.sign(w[j]) - w[j] / self.gamma)
+            if self.positive and w[j] < 0:
+                subdiff_dist[idx] = np.inf
+            elif self.positive and w[j] == 0:
+                # distance of -grad to (-infty, alpha]
+                subdiff_dist[idx] = max(0, - grad[idx] - self.alpha)
             else:
-                # distance of grad to 0
-                subdiff_dist[idx] = np.abs(grad[idx])
+                if w[j] == 0:
+                    # distance of -grad to [-alpha, alpha]
+                    subdiff_dist[idx] = max(0, np.abs(grad[idx]) - self.alpha)
+                elif np.abs(w[j]) < self.alpha * self.gamma:
+                    # distance of -grad to {alpha * sign(w[j]) - w[j] / gamma}
+                    subdiff_dist[idx] = np.abs(
+                        grad[idx] + self.alpha * np.sign(w[j]) - w[j] / self.gamma)
+                else:
+                    # distance of grad to 0
+                    subdiff_dist[idx] = np.abs(grad[idx])
         return subdiff_dist
 
     def is_penalized(self, n_features):
@@ -273,6 +283,89 @@ def alpha_max(self, gradient0):
         return np.max(np.abs(gradient0))
 
 
+class WeightedMCPenalty(BasePenalty):
+    """Weighted Minimax Concave Penalty (MCP), a non-convex sparse penalty.
+
+    Notes
+    -----
+    With :math:`x >= 0`:
+
+    .. math::
+        "pen"(x) = {(alpha x - x^2 / (2 gamma), if x <= alpha gamma),
+                    (gamma alpha^2 / 2 , if x > alpha gamma):}
+    .. math::
+        "value" = sum_(j=1)^(n_"features") "weights"_j xx "pen"(abs(w_j))
+    """
+
+    def __init__(self, alpha, gamma, weights, positive=False):
+        self.alpha = alpha
+        self.gamma = gamma
+        self.weights = weights.astype(np.float64)
+        self.positive = positive
+
+    def get_spec(self):
+        spec = (
+            ('alpha', float64),
+            ('gamma', float64),
+            ('weights', float64[:]),
+            ('positive', bool_)
+        )
+        return spec
+
+    def params_to_dict(self):
+        return dict(alpha=self.alpha,
+                    gamma=self.gamma,
+                    weights=self.weights,
+                    positive=self.positive)
+
+    def value(self, w):
+        return value_weighted_MCP(w, self.alpha, self.gamma, self.weights)
+
+    def prox_1d(self, value, stepsize, j):
+        """Compute the proximal operator of the weighted MCP."""
+        return prox_MCP(
+            value, stepsize, self.alpha, self.gamma, self.positive, self.weights[j])
+
+    def subdiff_distance(self, w, grad, ws):
+        """Compute distance of negative gradient to the subdifferential at w."""
+        subdiff_dist = np.zeros_like(grad)
+        for idx, j in enumerate(ws):
+            if self.positive and w[j] < 0:
+                subdiff_dist[idx] = np.inf
+            elif self.positive and w[j] == 0:
+                # distance of -grad to (-infty, alpha * weights[j]]
+                subdiff_dist[idx] = max(
+                    0, - grad[idx] - self.alpha * self.weights[j])
+            else:
+                if w[j] == 0:
+                    # distance of -grad to weights[j] * [-alpha, alpha]
+                    subdiff_dist[idx] = max(
+                        0, np.abs(grad[idx]) - self.alpha * self.weights[j])
+                elif np.abs(w[j]) < self.alpha * self.gamma:
+                    # distance of -grad to
+                    # {weights[j] * alpha * sign(w[j]) - w[j] / gamma}
+                    subdiff_dist[idx] = np.abs(
+                        grad[idx] + self.alpha * self.weights[j] * np.sign(w[j])
+                        - self.weights[j] * w[j] / self.gamma)
+                else:
+                    # distance of grad to 0
+                    subdiff_dist[idx] = np.abs(grad[idx])
        return subdiff_dist
+
+    def is_penalized(self, n_features):
+        """Return a binary mask with the penalized features."""
+        return np.ones(n_features, bool_)
+
+    def generalized_support(self, w):
+        """Return a mask with non-zero coefficients."""
+        return w != 0
+
+    def alpha_max(self, gradient0):
+        """Return penalization value for which 0 is solution."""
+        nnz_weights = self.weights != 0
+        return np.max(np.abs(gradient0[nnz_weights] / self.weights[nnz_weights]))
+
+
 class SCAD(BasePenalty):
     r"""Smoothly Clipped Absolute Deviation.
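
To make the docstring notation concrete: the weighted MCP applies the scalar penalty "pen" to each coefficient and scales it by weights_j. A plain-NumPy restatement of value_weighted_MCP (a sketch for illustration, not code from this commit):

import numpy as np

def weighted_mcp_value(w, alpha, gamma, weights):
    # pen(|w_j|): concave quadratic up to alpha * gamma, constant beyond
    pen = np.where(
        np.abs(w) < alpha * gamma,
        alpha * np.abs(w) - w ** 2 / (2 * gamma),
        gamma * alpha ** 2 / 2.,
    )
    return np.sum(weights * pen)

w = np.array([0., 0.2, 5.])
print(weighted_mcp_value(w, alpha=1., gamma=2., weights=np.ones(3)))  # 1.19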

skglm/tests/test_estimators.py
Lines changed: 9 additions & 3 deletions

@@ -74,6 +74,11 @@
 dict_estimators_ours["MCP"] = MCPRegression(
     alpha=alpha, gamma=np.inf, tol=tol)
 
+dict_estimators_sk["wMCP"] = Lasso_sklearn(
+    alpha=alpha, tol=tol)
+dict_estimators_ours["wMCP"] = MCPRegression(
+    alpha=alpha, gamma=np.inf, tol=tol, weights=np.ones(n_features))
+
 dict_estimators_sk["LogisticRegression"] = LogReg_sklearn(
     C=1/(alpha * n_samples), tol=tol, penalty='l1',
     solver='liblinear')
@@ -88,7 +93,7 @@
 
 @pytest.mark.parametrize(
     "estimator_name",
-    ["Lasso", "wLasso", "ElasticNet", "MCP", "LogisticRegression", "SVC"])
+    ["Lasso", "wLasso", "ElasticNet", "MCP", "wMCP", "LogisticRegression", "SVC"])
 def test_check_estimator(estimator_name):
     if estimator_name == "SVC":
         pytest.xfail("SVC check_estimator is too slow due to bug.")
@@ -97,7 +102,7 @@ def test_check_estimator(estimator_name):
         pytest.xfail("ProxNewton does not yet support intercept fitting")
     clf = clone(dict_estimators_ours[estimator_name])
     clf.tol = 1e-6  # failure in float32 computation otherwise
-    if isinstance(clf, WeightedLasso):
+    if isinstance(clf, (WeightedLasso, MCPRegression)):
         clf.weights = None
     check_estimator(clf)
 
@@ -113,7 +118,8 @@ def test_estimator(estimator_name, X, fit_intercept, positive):
         pytest.xfail("sklearn LogisticRegression does not support intercept.")
     if fit_intercept and estimator_name == "SVC":
         pytest.xfail("Intercept is not supported for SVC.")
-    if positive and estimator_name not in ("Lasso", "ElasticNet", "WeightedLasso"):
+    if positive and estimator_name not in (
+            "Lasso", "ElasticNet", "wLasso", "MCP", "wMCP"):
         pytest.xfail("`positive` option is only supported by L1, L1_plus_L2 and wL1.")
 
     estimator_sk = clone(dict_estimators_sk[estimator_name])
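
The wMCP entries use sklearn's Lasso as the reference: with gamma=np.inf the MCP prox reduces to soft thresholding, so MCPRegression with unit weights must match the Lasso. A standalone sketch of that equivalence (data and tolerances are illustrative, not from the test suite):

import numpy as np
from sklearn.linear_model import Lasso
from skglm import MCPRegression

X = np.random.RandomState(0).randn(30, 5)
y = X @ np.array([1., 0., -1., 0., 2.])

ours = MCPRegression(alpha=0.1, gamma=np.inf, weights=np.ones(5), tol=1e-8).fit(X, y)
ref = Lasso(alpha=0.1, tol=1e-8).fit(X, y)
np.testing.assert_allclose(ours.coef_, ref.coef_, rtol=1e-4, atol=1e-6)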

skglm/utils/prox_funcs.py
Lines changed: 15 additions & 7 deletions

@@ -60,15 +60,23 @@ def value_MCP(w, alpha, gamma):
 
 
 @njit
-def prox_MCP(value, stepsize, alpha, gamma):
-    """Compute the proximal operator of stepsize * MCP penalty."""
-    tau = alpha * stepsize
-    g = gamma / stepsize  # what does g stand for ?
-    if np.abs(value) <= tau:
+def value_weighted_MCP(w, alpha, gamma, weights):
+    """Compute the value of the weighted MCP."""
+    s0 = np.abs(w) < gamma * alpha
+    value = np.full_like(w, gamma * alpha ** 2 / 2.)
+    value[s0] = alpha * np.abs(w[s0]) - w[s0]**2 / (2 * gamma)
+    return np.sum(weights * value)
+
+
+@njit
+def prox_MCP(value, stepsize, alpha, gamma, positive=False, weight=1.):
+    """Compute the proximal operator of stepsize * weight MCP penalty."""
+    wstepsize = weight * stepsize  # weighted stepsize
+    if (np.abs(value) <= alpha * wstepsize) or (positive and value <= 0.):
         return 0.
-    if np.abs(value) > g * tau:
+    if np.abs(value) > alpha * gamma:
         return value
-    return np.sign(value) * (np.abs(value) - tau) / (1. - 1./g)
+    return np.sign(value) * (np.abs(value) - alpha * wstepsize) / (1. - wstepsize/gamma)
 
 
 @njit
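
The rewritten prox_MCP has three regimes: inputs below the (weighted) threshold map to zero, inputs beyond alpha * gamma pass through unchanged, and the middle range is shrunk; with positive=True, non-positive inputs also map to zero. A pure-Python sanity check mirroring the @njit version above (a sketch, assuming gamma > weight * stepsize):

import numpy as np

def prox_mcp(value, stepsize, alpha, gamma, positive=False, weight=1.):
    wstepsize = weight * stepsize
    if (np.abs(value) <= alpha * wstepsize) or (positive and value <= 0.):
        return 0.  # thresholded to zero
    if np.abs(value) > alpha * gamma:
        return value  # flat region: prox is the identity
    return np.sign(value) * (np.abs(value) - alpha * wstepsize) / (1. - wstepsize / gamma)

print(prox_mcp(0.05, 1., alpha=0.1, gamma=3.))  # 0.0 (below threshold)
print(prox_mcp(0.20, 1., alpha=0.1, gamma=3.))  # 0.15 (shrunk)
print(prox_mcp(2.00, 1., alpha=0.1, gamma=3.))  # 2.0 (unchanged)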
