
Commit cca6d48

FEAT Add positivity penalty (#126)
1 parent: 5762352

File tree: 3 files changed, +66 −3 lines


skglm/penalties/__init__.py

Lines changed: 3 additions & 2 deletions

@@ -1,6 +1,7 @@
 from .base import BasePenalty
 from .separable import (
-    L1_plus_L2, L0_5, L1, L2_3, MCPenalty, SCAD, WeightedL1, IndicatorBox
+    L1_plus_L2, L0_5, L1, L2_3, MCPenalty, SCAD, WeightedL1, IndicatorBox,
+    PositiveConstraint
 )
 from .block_separable import (
     L2_05, L2_1, BlockMCPenalty, BlockSCAD, WeightedGroupL2
@@ -12,5 +13,5 @@
 __all__ = [
     BasePenalty,
     L1_plus_L2, L0_5, L1, L2_3, MCPenalty, SCAD, WeightedL1, IndicatorBox,
-    L2_05, L2_1, BlockMCPenalty, BlockSCAD, WeightedGroupL2, SLOPE
+    PositiveConstraint, L2_05, L2_1, BlockMCPenalty, BlockSCAD, WeightedGroupL2, SLOPE
 ]
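With this change, PositiveConstraint is exposed at the package level alongside the other separable penalties. A quick sanity check of the new import path (assuming a skglm install that includes this commit):

from skglm.penalties import PositiveConstraint

penalty = PositiveConstraint()  # takes no hyperparameters; get_spec() returns an empty tuple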

skglm/penalties/separable.py

Lines changed: 45 additions & 0 deletions

@@ -466,3 +466,48 @@ def is_penalized(self, n_features):
     def generalized_support(self, w):
         """Return a mask with non-zero coefficients."""
         return w != 0
+
+
+class PositiveConstraint(BasePenalty):
+    """Positivity constraint penalty."""
+
+    def __init__(self):
+        pass
+
+    def get_spec(self):
+        return ()
+
+    def params_to_dict(self):
+        return dict()
+
+    def value(self, w):
+        """Compute the value of the PositiveConstraint penalty at w."""
+        return np.inf if (w < 0).any() else 0.
+
+    def prox_1d(self, value, stepsize, j):
+        """Compute the proximal operator of the PositiveConstraint."""
+        return max(0., value)
+
+    def subdiff_distance(self, w, grad, ws):
+        """Compute distance of negative gradient to the subdifferential at w."""
+        subdiff_dist = np.zeros_like(grad)
+        for idx, j in enumerate(ws):
+            if w[j] == 0:
+                # distance of - grad_j to ]-infty, 0]
+                subdiff_dist[idx] = max(0, -grad[idx])
+            elif w[j] > 0:
+                # distance of - grad_j to 0
+                subdiff_dist[idx] = abs(-grad[idx])
+            else:
+                # subdiff is empty, distance is infinite
+                subdiff_dist[idx] = np.inf
+
+        return subdiff_dist
+
+    def is_penalized(self, n_features):
+        """Return a binary mask with the penalized features."""
+        return np.ones(n_features, bool_)
+
+    def generalized_support(self, w):
+        """Return a mask with non-zero coefficients."""
+        return w != 0
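For context, PositiveConstraint is the indicator function of the non-negative orthant: value(w) is 0 when every coefficient is non-negative and +inf otherwise, and prox_1d is the coordinate-wise projection onto [0, +inf). The plain-NumPy sketch below only mirrors that behaviour outside of skglm and numba; the array w, the stepsize, and the helper names positive_value/positive_prox are made up for illustration and are not part of the library.

import numpy as np

def positive_value(w):
    # indicator of the non-negative orthant: 0 if w >= 0 everywhere, +inf otherwise
    return np.inf if (w < 0).any() else 0.0

def positive_prox(w, stepsize):
    # projection onto [0, +inf); the stepsize has no effect for an indicator function
    return np.maximum(0.0, w)

w = np.array([-0.3, 0.0, 1.2])
print(positive_value(w))               # inf: one coefficient is negative
print(positive_prox(w, stepsize=0.1))  # [0.  0.  1.2]

The subdiff_distance cases follow from the same geometry, as the inline comments note: at w_j = 0 the subdifferential is ]-infty, 0], so only the positive part of -grad_j contributes; at w_j > 0 it is {0}, so the distance is |grad_j|; and w_j < 0 is infeasible, hence an infinite distance.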

skglm/tests/test_penalties.py

Lines changed: 18 additions & 1 deletion

@@ -4,10 +4,12 @@
 from numpy.linalg import norm
 from numpy.testing import assert_array_less

+from sklearn.linear_model import LinearRegression
+
 from skglm.datafits import Quadratic, QuadraticMultiTask
 from skglm.penalties import (
     L1, L1_plus_L2, WeightedL1, MCPenalty, SCAD, IndicatorBox, L0_5, L2_3, SLOPE,
-    L2_1, L2_05, BlockMCPenalty, BlockSCAD)
+    PositiveConstraint, L2_1, L2_05, BlockMCPenalty, BlockSCAD)
 from skglm import GeneralizedLinearEstimator, Lasso
 from skglm.solvers import AndersonCD, MultiTaskBCD, FISTA
 from skglm.utils.data import make_correlated_data
@@ -101,5 +103,20 @@ def test_slope():
     np.testing.assert_allclose(ours.coef_, pyslope_out["beta"], rtol=1e-5)


+@pytest.mark.parametrize("fit_intercept", [True, False])
+def test_nnls(fit_intercept):
+    # compare solutions with sklearn's LinearRegression, note that n_samples >=
+    # n_features for the design matrix to be injective, hence the solution unique
+    clf = GeneralizedLinearEstimator(
+        datafit=Quadratic(),
+        penalty=PositiveConstraint(),
+        solver=AndersonCD(tol=tol, fit_intercept=fit_intercept),
+    ).fit(X, y)
+    reg_nnls = LinearRegression(positive=True, fit_intercept=fit_intercept).fit(X, y)
+
+    np.testing.assert_allclose(clf.coef_, reg_nnls.coef_)
+    np.testing.assert_allclose(clf.intercept_, reg_nnls.intercept_)
+
+
 if __name__ == "__main__":
     pass
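The new test relies on module-level fixtures (X, y, and tol) defined earlier in test_penalties.py, which this diff does not show. A self-contained variant of the same non-negative least-squares comparison might look as follows; the random data and the tolerance values here are assumptions for illustration, not the fixtures the test actually uses.

import numpy as np
from sklearn.linear_model import LinearRegression

from skglm import GeneralizedLinearEstimator
from skglm.datafits import Quadratic
from skglm.penalties import PositiveConstraint
from skglm.solvers import AndersonCD

# toy data with n_samples >= n_features so the NNLS solution is unique
rng = np.random.default_rng(0)
X = rng.standard_normal((50, 10))
y = X @ np.abs(rng.standard_normal(10)) + 0.1 * rng.standard_normal(50)

clf = GeneralizedLinearEstimator(
    datafit=Quadratic(),
    penalty=PositiveConstraint(),
    solver=AndersonCD(tol=1e-10, fit_intercept=True),
).fit(X, y)
reg_nnls = LinearRegression(positive=True, fit_intercept=True).fit(X, y)

# both solvers should land on the same non-negative coefficients
np.testing.assert_allclose(clf.coef_, reg_nnls.coef_, rtol=1e-5, atol=1e-8)
np.testing.assert_allclose(clf.intercept_, reg_nnls.intercept_, rtol=1e-5, atol=1e-8)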
