
Commit cb715d2

ENH Add optional positivity constraint in L1, WeightedL1 and L1_plus_L2 (#110)
Co-authored-by: Badr-MOUFAD <[email protected]>
1 parent d0536dc commit cb715d2
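
For reference, with ``positive=True`` the estimators solve the same problem as before with an added nonnegativity constraint; for the plain Lasso this reads

    min_w  1 / (2 * n_samples) * ||y - X w||_2^2 + alpha * ||w||_1    subject to   w_j >= 0  for all j,

and the weighted and elastic net variants constrain their own objectives in the same way. Concretely (see the diffs below), the constraint is folded into the penalty: soft-thresholding becomes one-sided and the subdifferential distance used for working-set optimality checks gains a positive branch.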

4 files changed, +107 -49 lines changed

skglm/estimators.py

Lines changed: 25 additions & 12 deletions

@@ -314,6 +314,9 @@ class Lasso(LinearModel, RegressorMixin):
     tol : float, optional
         Stopping criterion for the optimization.
 
+    positive : bool, optional
+        When set to ``True``, forces the coefficient vector to be positive.
+
     fit_intercept : bool, optional (default=True)
         Whether or not to fit an intercept.
 
@@ -345,14 +348,16 @@ class Lasso(LinearModel, RegressorMixin):
     """
 
     def __init__(self, alpha=1., max_iter=50, max_epochs=50_000, p0=10, verbose=0,
-                 tol=1e-4, fit_intercept=True, warm_start=False, ws_strategy="subdiff"):
+                 tol=1e-4, positive=False, fit_intercept=True, warm_start=False,
+                 ws_strategy="subdiff"):
         super().__init__()
         self.alpha = alpha
         self.tol = tol
         self.max_iter = max_iter
         self.max_epochs = max_epochs
         self.p0 = p0
         self.ws_strategy = ws_strategy
+        self.positive = positive
         self.fit_intercept = fit_intercept
         self.warm_start = warm_start
         self.verbose = verbose

@@ -378,7 +383,7 @@ def fit(self, X, y):
             self.max_iter, self.max_epochs, self.p0, tol=self.tol,
             ws_strategy=self.ws_strategy, fit_intercept=self.fit_intercept,
             warm_start=self.warm_start, verbose=self.verbose)
-        return _glm_fit(X, y, self, Quadratic(), L1(self.alpha), solver)
+        return _glm_fit(X, y, self, Quadratic(), L1(self.alpha, self.positive), solver)
 
     def path(self, X, y, alphas, coef_init=None, return_n_iter=True, **params):
         """Compute Lasso path.

@@ -417,7 +422,7 @@ def path(self, X, y, alphas, coef_init=None, return_n_iter=True, **params):
         n_iters : array, shape (n_alphas,), optional
             The number of iterations along the path. If return_n_iter is set to `True`.
         """
-        penalty = compiled_clone(L1(self.alpha))
+        penalty = compiled_clone(L1(self.alpha, self.positive))
         datafit = compiled_clone(Quadratic(), to_float32=X.dtype == np.float32)
         solver = AndersonCD(
             self.max_iter, self.max_epochs, self.p0, tol=self.tol,

@@ -457,6 +462,9 @@ class WeightedLasso(LinearModel, RegressorMixin):
     tol : float, optional
         Stopping criterion for the optimization.
 
+    positive : bool, optional
+        When set to ``True``, forces the coefficient vector to be positive.
+
     fit_intercept : bool, optional (default=True)
         Whether or not to fit an intercept.

@@ -492,8 +500,8 @@ class WeightedLasso(LinearModel, RegressorMixin):
     """
 
     def __init__(self, alpha=1., weights=None, max_iter=50, max_epochs=50_000, p0=10,
-                 verbose=0, tol=1e-4, fit_intercept=True, warm_start=False,
-                 ws_strategy="subdiff"):
+                 verbose=0, tol=1e-4, positive=False, fit_intercept=True,
+                 warm_start=False, ws_strategy="subdiff"):
         super().__init__()
         self.alpha = alpha
         self.weights = weights

@@ -502,6 +510,7 @@ def __init__(self, alpha=1., weights=None, max_iter=50, max_epochs=50_000, p0=10
         self.max_epochs = max_epochs
         self.p0 = p0
         self.ws_strategy = ws_strategy
+        self.positive = positive
         self.fit_intercept = fit_intercept
         self.warm_start = warm_start
         self.verbose = verbose

@@ -548,7 +557,7 @@ def path(self, X, y, alphas, coef_init=None, return_n_iter=True, **params):
             raise ValueError("The number of weights must match the number of \
                 features. Got %s, expected %s." % (
                     len(weights), X.shape[1]))
-        penalty = compiled_clone(WeightedL1(self.alpha, weights))
+        penalty = compiled_clone(WeightedL1(self.alpha, weights, self.positive))
         datafit = compiled_clone(Quadratic(), to_float32=X.dtype == np.float32)
         solver = AndersonCD(
             self.max_iter, self.max_epochs, self.p0, tol=self.tol,

@@ -574,9 +583,9 @@ def fit(self, X, y):
         """
         if self.weights is None:
             warnings.warn('Weights are not provided, fitting with Lasso penalty')
-            penalty = L1(self.alpha)
+            penalty = L1(self.alpha, self.positive)
         else:
-            penalty = WeightedL1(self.alpha, self.weights)
+            penalty = WeightedL1(self.alpha, self.weights, self.positive)
         solver = AndersonCD(
             self.max_iter, self.max_epochs, self.p0, tol=self.tol,
             ws_strategy=self.ws_strategy, fit_intercept=self.fit_intercept,

@@ -618,6 +627,9 @@ class ElasticNet(LinearModel, RegressorMixin):
     tol : float, optional
         Stopping criterion for the optimization.
 
+    positive : bool, optional
+        When set to ``True``, forces the coefficient vector to be positive.
+
     fit_intercept : bool, optional (default=True)
         Whether or not to fit an intercept.

@@ -648,8 +660,8 @@ class ElasticNet(LinearModel, RegressorMixin):
     """
 
     def __init__(self, alpha=1., l1_ratio=0.5, max_iter=50, max_epochs=50_000, p0=10,
-                 verbose=0, tol=1e-4, fit_intercept=True, warm_start=False,
-                 ws_strategy="subdiff"):
+                 verbose=0, tol=1e-4, positive=False, fit_intercept=True,
+                 warm_start=False, ws_strategy="subdiff"):
         super().__init__()
         self.alpha = alpha
         self.l1_ratio = l1_ratio

@@ -659,6 +671,7 @@ def __init__(self, alpha=1., l1_ratio=0.5, max_iter=50, max_epochs=50_000, p0=10
         self.p0 = p0
         self.ws_strategy = ws_strategy
         self.fit_intercept = fit_intercept
+        self.positive = positive
         self.warm_start = warm_start
         self.verbose = verbose
 
@@ -699,7 +712,7 @@ def path(self, X, y, alphas, coef_init=None, return_n_iter=True, **params):
         n_iters : array, shape (n_alphas,), optional
             The number of iterations along the path. If return_n_iter is set to `True`.
         """
-        penalty = compiled_clone(L1_plus_L2(self.alpha, self.l1_ratio))
+        penalty = compiled_clone(L1_plus_L2(self.alpha, self.l1_ratio, self.positive))
         datafit = compiled_clone(Quadratic(), to_float32=X.dtype == np.float32)
         solver = AndersonCD(
             self.max_iter, self.max_epochs, self.p0, tol=self.tol,

@@ -728,7 +741,7 @@ def fit(self, X, y):
             ws_strategy=self.ws_strategy, fit_intercept=self.fit_intercept,
             warm_start=self.warm_start, verbose=self.verbose)
         return _glm_fit(X, y, self, Quadratic(),
-                        L1_plus_L2(self.alpha, self.l1_ratio), solver)
+                        L1_plus_L2(self.alpha, self.l1_ratio, self.positive), solver)
 
 
 class MCPRegression(LinearModel, RegressorMixin):
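
With these changes the constraint can be requested directly from the estimators. A minimal usage sketch, not part of the diff: the data and hyperparameters are invented for illustration, and it assumes ``Lasso`` is importable from ``skglm`` as in the library's examples.

    import numpy as np
    from skglm import Lasso

    rng = np.random.default_rng(0)
    X = rng.standard_normal((50, 30))
    y = X @ np.maximum(rng.standard_normal(30), 0) + 0.1 * rng.standard_normal(50)

    # with positive=True the fitted coefficients are constrained to be nonnegative
    clf = Lasso(alpha=0.1, positive=True).fit(X, y)
    print((clf.coef_ >= 0).all())  # expected: True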

skglm/penalties/separable.py

Lines changed: 72 additions & 34 deletions

@@ -10,37 +10,48 @@
 class L1(BasePenalty):
     """L1 penalty."""
 
-    def __init__(self, alpha):
+    def __init__(self, alpha, positive=False):
         self.alpha = alpha
+        self.positive = positive
 
     def get_spec(self):
         spec = (
             ('alpha', float64),
+            ('positive', bool_),
         )
         return spec
 
     def params_to_dict(self):
-        return dict(alpha=self.alpha)
+        return dict(alpha=self.alpha, positive=self.positive)
 
     def value(self, w):
         """Compute L1 penalty value."""
         return self.alpha * np.sum(np.abs(w))
 
     def prox_1d(self, value, stepsize, j):
         """Compute proximal operator of the L1 penalty (soft-thresholding operator)."""
-        return ST(value, self.alpha * stepsize)
+        return ST(value, self.alpha * stepsize, self.positive)
 
     def subdiff_distance(self, w, grad, ws):
         """Compute distance of negative gradient to the subdifferential at w."""
         subdiff_dist = np.zeros_like(grad)
         for idx, j in enumerate(ws):
-            if w[j] == 0:
-                # distance of - grad_j to [-alpha, alpha]
-                subdiff_dist[idx] = max(0, np.abs(grad[idx]) - self.alpha)
+            if self.positive:
+                if w[j] < 0:
+                    subdiff_dist[idx] = np.inf
+                elif w[j] == 0:
+                    # distance of -grad_j to (-infty, alpha]
+                    subdiff_dist[idx] = max(0, -grad[idx] - self.alpha)
+                else:
+                    # distance of -grad_j to {alpha}
+                    subdiff_dist[idx] = np.abs(grad[idx] + self.alpha)
             else:
-                # distance of - grad_j to alpha * sign(w[j])
-                subdiff_dist[idx] = np.abs(
-                    - grad[idx] - np.sign(w[j]) * self.alpha)
+                if w[j] == 0:
+                    # distance of -grad_j to [-alpha, alpha]
+                    subdiff_dist[idx] = max(0, np.abs(grad[idx]) - self.alpha)
+                else:
+                    # distance of -grad_j to {alpha * sign(w[j])}
+                    subdiff_dist[idx] = np.abs(grad[idx] + np.sign(w[j]) * self.alpha)
         return subdiff_dist
 
     def is_penalized(self, n_features):

@@ -59,20 +70,21 @@ def alpha_max(self, gradient0):
 class L1_plus_L2(BasePenalty):
     """L1 + L2 penalty (aka ElasticNet penalty)."""
 
-    def __init__(self, alpha, l1_ratio):
+    def __init__(self, alpha, l1_ratio, positive=False):
         self.alpha = alpha
         self.l1_ratio = l1_ratio
+        self.positive = positive
 
     def get_spec(self):
         spec = (
             ('alpha', float64),
             ('l1_ratio', float64),
+            ('positive', bool_),
         )
         return spec
 
     def params_to_dict(self):
-        return dict(alpha=self.alpha,
-                    l1_ratio=self.l1_ratio)
+        return dict(alpha=self.alpha, l1_ratio=self.l1_ratio, positive=self.positive)
 
     def value(self, w):
         """Compute the L1 + L2 penalty value."""

@@ -82,25 +94,38 @@ def value(self, w):
 
     def prox_1d(self, value, stepsize, j):
         """Compute the proximal operator (scaled soft-thresholding)."""
-        prox = ST(value, self.l1_ratio * self.alpha * stepsize)
+        prox = ST(value, self.l1_ratio * self.alpha * stepsize, self.positive)
         prox /= (1 + stepsize * (1 - self.l1_ratio) * self.alpha)
         return prox
 
     def subdiff_distance(self, w, grad, ws):
         """Compute distance of negative gradient to the subdifferential at w."""
         subdiff_dist = np.zeros_like(grad)
+        alpha = self.alpha
+        l1_ratio = self.l1_ratio
+
         for idx, j in enumerate(ws):
-            if w[j] == 0:
-                # distance of - grad_j to alpha * l1_ratio * [-1, 1]
-                subdiff_dist[idx] = max(
-                    0, np.abs(grad[idx]) - self.alpha * self.l1_ratio)
+            if self.positive:
+                if w[j] < 0:
+                    subdiff_dist[idx] = np.inf
+                elif w[j] == 0:
+                    # distance of -grad_j to (-infty, alpha * l1_ratio]
+                    subdiff_dist[idx] = max(0, -grad[idx] - alpha * l1_ratio)
+                else:
+                    # distance of -grad_j to alpha * {l1_ratio + (1 - l1_ratio) * w[j]}
+                    subdiff_dist[idx] = np.abs(
+                        grad[idx] + alpha * (l1_ratio
+                                             + (1 - l1_ratio) * w[j]))
             else:
-                # distance of - grad_j to alpha * l1_ratio * sign(w[j]) +
-                # alpha * (1 - l1_ratio) * w[j]
-                subdiff_dist[idx] = np.abs(
-                    - grad[idx] -
-                    self.alpha * (self.l1_ratio *
-                                  np.sign(w[j]) + (1 - self.l1_ratio) * w[j]))
+                if w[j] == 0:
+                    # distance of -grad_j to alpha * l1_ratio * [-1, 1]
+                    subdiff_dist[idx] = max(0, np.abs(grad[idx]) - alpha * l1_ratio)
+                else:
+                    # distance of -grad_j to
+                    # {alpha * (l1_ratio * sign(w[j]) + (1 - l1_ratio) * w[j])}
+                    subdiff_dist[idx] = np.abs(
+                        grad[idx] + alpha * (l1_ratio * np.sign(w[j])
+                                             + (1 - l1_ratio) * w[j]))
         return subdiff_dist
 
     def is_penalized(self, n_features):

@@ -119,41 +144,54 @@ def alpha_max(self, gradient0):
 class WeightedL1(BasePenalty):
     """Weighted L1 penalty."""
 
-    def __init__(self, alpha, weights):
+    def __init__(self, alpha, weights, positive=False):
         self.alpha = alpha
         self.weights = weights.astype(np.float64)
+        self.positive = positive
 
     def get_spec(self):
         spec = (
             ('alpha', float64),
             ('weights', float64[:]),
+            ('positive', bool_),
         )
         return spec
 
     def params_to_dict(self):
-        return dict(alpha=self.alpha,
-                    weights=self.weights)
+        return dict(alpha=self.alpha, weights=self.weights, positive=self.positive)
 
     def value(self, w):
         """Compute the weighted L1 penalty."""
         return self.alpha * np.sum(np.abs(w) * self.weights)
 
     def prox_1d(self, value, stepsize, j):
         """Compute the proximal operator of weighted L1 (weighted soft-thresholding)."""
-        return ST(value, self.alpha * stepsize * self.weights[j])
+        return ST(value, self.alpha * stepsize * self.weights[j], self.positive)
 
     def subdiff_distance(self, w, grad, ws):
         """Compute distance of negative gradient to the subdifferential at w."""
         subdiff_dist = np.zeros_like(grad)
+        alpha = self.alpha
+        weights = self.weights
+
         for idx, j in enumerate(ws):
-            if w[j] == 0:
-                # distance of - grad_j to alpha * weights[j] * [-1, 1]
-                subdiff_dist[idx] = max(
-                    0, np.abs(grad[idx]) - self.alpha * self.weights[j])
+            if self.positive:
+                if w[j] < 0:
+                    subdiff_dist[idx] = np.inf
+                elif w[j] == 0:
+                    # distance of -grad_j to (-infty, alpha * weights[j]]
+                    subdiff_dist[idx] = max(0, -grad[idx] - alpha * weights[j])
+                else:
+                    # distance of -grad_j to {alpha * weights[j]}
+                    subdiff_dist[idx] = np.abs(grad[idx] + alpha * weights[j])
             else:
-                # distance of - grad_j to alpha * weights[j] * sign(w[j])
-                subdiff_dist[idx] = np.abs(
-                    - grad[idx] - self.alpha * self.weights[j] * np.sign(w[j]))
+                if w[j] == 0:
+                    # distance of -grad_j to alpha * weights[j] * [-1, 1]
+                    subdiff_dist[idx] = max(0, np.abs(grad[idx]) - alpha * weights[j])
+                else:
+                    # distance of -grad_j to {alpha * weights[j] * sign(w[j])}
+                    subdiff_dist[idx] = np.abs(
+                        grad[idx] + alpha * weights[j] * np.sign(w[j]))
         return subdiff_dist
 
     def is_penalized(self, n_features):
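
The positive branches above implement the optimality conditions of the nonnegativity-constrained problem: at ``w[j] == 0`` optimality requires ``-grad[j] <= alpha`` (hence the distance to the half-line ``(-infty, alpha]``), at ``w[j] > 0`` it requires ``-grad[j] == alpha``, and ``w[j] < 0`` is infeasible, hence the infinite distance. A small illustrative check, not from the commit; it assumes ``L1`` is importable from ``skglm.penalties``.

    import numpy as np
    from skglm.penalties import L1

    pen = L1(alpha=1., positive=True)
    w = np.array([0., 0.5, -0.1])
    grad = np.array([-2., -1., 0.])   # datafit gradient restricted to the working set
    ws = np.array([0, 1, 2])          # working set: all three features

    print(pen.subdiff_distance(w, grad, ws))
    # expected: [1., 0., inf] -- feature 0 violates -grad <= alpha by 1,
    # feature 1 satisfies -grad == alpha, feature 2 is infeasible (w < 0)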

skglm/tests/test_estimators.py

Lines changed: 8 additions & 1 deletion

@@ -99,20 +99,27 @@ def test_check_estimator(estimator_name):
 @pytest.mark.parametrize("estimator_name", dict_estimators_ours.keys())
 @pytest.mark.parametrize('X', [X, X_sparse])
 @pytest.mark.parametrize('fit_intercept', [True, False])
-def test_estimator(estimator_name, X, fit_intercept):
+@pytest.mark.parametrize('positive', [True, False])
+def test_estimator(estimator_name, X, fit_intercept, positive):
     if estimator_name == "GeneralizedLinearEstimator":
         pytest.skip()
     if fit_intercept and estimator_name == "LogisticRegression":
         pytest.xfail("sklearn LogisticRegression does not support intercept.")
     if fit_intercept and estimator_name == "SVC":
         pytest.xfail("Intercept is not supported for SVC.")
+    if positive and estimator_name not in ("Lasso", "ElasticNet", "WeightedLasso"):
+        pytest.xfail("`positive` option is only supported by L1, L1_plus_L2 and wL1.")
 
     estimator_sk = clone(dict_estimators_sk[estimator_name])
     estimator_ours = clone(dict_estimators_ours[estimator_name])
 
     estimator_sk.set_params(fit_intercept=fit_intercept)
     estimator_ours.set_params(fit_intercept=fit_intercept)
 
+    if positive:
+        estimator_sk.set_params(positive=positive)
+        estimator_ours.set_params(positive=positive)
+
     estimator_sk.fit(X, y)
     estimator_ours.fit(X, y)
     coef_sk = estimator_sk.coef_
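
The new ``positive`` leg of the test works because scikit-learn's ``Lasso`` and ``ElasticNet`` also accept ``positive=True``, so the two implementations can be compared coefficient by coefficient. A sketch of that comparison outside pytest, with invented data and an illustrative tolerance:

    import numpy as np
    from sklearn.linear_model import Lasso as Lasso_sklearn
    from skglm import Lasso

    rng = np.random.default_rng(42)
    X = rng.standard_normal((100, 20))
    y = X @ np.abs(rng.standard_normal(20)) + 0.01 * rng.standard_normal(100)

    params = dict(alpha=0.05, positive=True, fit_intercept=True, tol=1e-10)
    coef_sk = Lasso_sklearn(**params).fit(X, y).coef_
    coef_ours = Lasso(**params).fit(X, y).coef_

    # both solvers should agree up to their tolerances
    np.testing.assert_allclose(coef_ours, coef_sk, atol=1e-5)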

skglm/utils/prox_funcs.py

Lines changed: 2 additions & 2 deletions

@@ -4,11 +4,11 @@
 
 
 @njit
-def ST(x, u):
+def ST(x, u, positive=False):
     """Soft-thresholding of scalar x at level u."""
     if x > u:
         return x - u
-    elif x < - u:
+    elif x < - u and not positive:
         return x + u
     else:
         return 0.
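
With ``positive=True``, ``ST`` becomes the proximal operator of ``u * |x|`` plus the indicator of the nonnegative half-line, i.e. ``max(x - u, 0)``. A standalone reference implementation for comparison, a sketch rather than part of the commit:

    import numpy as np

    def st_reference(x, u, positive=False):
        """Reference soft-thresholding: one-sided when positive=True."""
        if positive:
            return max(x - u, 0.)
        return np.sign(x) * max(abs(x) - u, 0.)   # classic soft-thresholding

    for x in (-2., -0.5, 0.3, 1.5):
        print(x, st_reference(x, 1.0), st_reference(x, 1.0, positive=True))
    # e.g. x = -2. maps to -1. without the constraint and to 0. with it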
