Skip to content

Commit 8a581b0

Browse files
PABannierKlopfemathurinmQB3
authored
Add SCAD and BlockSCAD penalties
Co-authored-by: Klopfe <[email protected]> Co-authored-by: mathurinm <[email protected]> Co-authored-by: mathurinm <[email protected]> Co-authored-by: QB3 <[email protected]>
1 parent 9b7f3ef commit 8a581b0

File tree

7 files changed

+282
-46
lines changed

7 files changed

+282
-46
lines changed

examples/plot_pen_prox.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,21 @@
99
import numpy as np
1010
import matplotlib.pyplot as plt
1111

12-
from skglm.penalties import WeightedL1, L1, L1_plus_L2, MCPenalty
12+
from skglm.penalties import WeightedL1, L1, L1_plus_L2, MCPenalty, SCAD, L0_5, L2_3
1313

1414

15-
penalties = [WeightedL1(alpha=1, weights=np.array([2.])),
16-
L1(alpha=1),
17-
L1_plus_L2(alpha=1, l1_ratio=0.7),
18-
MCPenalty(alpha=1, gamma=3.),
19-
]
15+
penalties = [
16+
WeightedL1(alpha=1, weights=np.array([2.])),
17+
L1(alpha=1),
18+
L1_plus_L2(alpha=1, l1_ratio=0.7),
19+
MCPenalty(alpha=1, gamma=3.),
20+
SCAD(alpha=1, gamma=3.),
21+
L0_5(alpha=1),
22+
L2_3(alpha=1),
23+
]
2024

2125

22-
x_range = np.linspace(-5, 5, num=300)
26+
x_range = np.linspace(-4, 4, num=300)
2327

2428
fig, axarr = plt.subplots(1, 2, figsize=(8, 3), constrained_layout=True)
2529

@@ -32,7 +36,6 @@
3236
label=pen.__class__.__name__)
3337

3438
axarr[0].legend()
35-
axarr[1].legend()
3639
axarr[0].set_title("Penalty value")
3740
axarr[1].set_title("Proximal operator of penalty")
3841
plt.show(block=False)

examples/plot_sparse_recovery.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from skglm.utils import make_correlated_data
1919
from skglm.solvers import cd_solver_path
2020
from skglm.datafits import Quadratic
21-
from skglm.penalties import L1, MCPenalty, L0_5, L2_3
21+
from skglm.penalties import L1, MCPenalty, L0_5, L2_3, SCAD
2222

2323
cmap = plt.get_cmap('tab10')
2424

@@ -51,14 +51,16 @@
5151
penalties = {}
5252
penalties['lasso'] = L1(alpha=1)
5353
penalties['mcp'] = MCPenalty(alpha=1, gamma=3)
54+
penalties['scad'] = SCAD(alpha=1, gamma=3)
5455
penalties['l05'] = L0_5(alpha=1)
5556
penalties['l23'] = L2_3(alpha=1)
5657

5758
colors = {}
5859
colors['lasso'] = cmap(0)
5960
colors['mcp'] = cmap(1)
60-
colors['l05'] = cmap(2)
61-
colors['l23'] = cmap(3)
61+
colors['scad'] = cmap(2)
62+
colors['l05'] = cmap(3)
63+
colors['l23'] = cmap(4)
6264

6365
f1 = {}
6466
estimation_error = {}
@@ -83,12 +85,14 @@
8385

8486
name_estimators = {'lasso': "Lasso"}
8587
name_estimators['mcp'] = r"MCP, $\gamma=%s$" % 3
88+
name_estimators['scad'] = r"SCAD, $\gamma=%s$" % 3
8689
name_estimators['l05'] = r"$\ell_{1/2}$"
8790
name_estimators['l23'] = r"$\ell_{2/3}$"
8891

8992

9093
plt.close('all')
91-
fig, axarr = plt.subplots(2, 1, sharex=True, sharey=False, figsize=[8.2, 5.7])
94+
fig, axarr = plt.subplots(2, 1, sharex=True, sharey=False, figsize=[
95+
6.3, 3.7], constrained_layout=True)
9296

9397
for idx, estimator in enumerate(penalties.keys()):
9498

@@ -127,5 +131,6 @@
127131
axarr[1].set_ylabel("pred. RMSE left-out")
128132
axarr[0].legend(
129133
bbox_to_anchor=(0, 1.02, 1, 0.2), loc="lower left",
130-
mode="expand", borderaxespad=0, ncol=4)
131-
plt.show(block=False)
134+
mode="expand", borderaxespad=0, ncol=1)
135+
136+
plt.show(block=False)

skglm/penalties/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from .base import BasePenalty # noqa F401
22

33
from .separable import ( # noqa F401
4-
L1_plus_L2, L0_5, L1, L2_3, MCPenalty, WeightedL1, IndicatorBox, BasePenalty
4+
L1_plus_L2, L0_5, L1, L2_3, MCPenalty, SCAD, WeightedL1, IndicatorBox, BasePenalty
55
)
66

7-
from .block_separable import L2_05, L2_1, BlockMCPenalty, WeightedGroupL2 # noqa F401
7+
from .block_separable import ( # noqa F401
8+
L2_05, L2_1, BlockMCPenalty, BlockSCAD, WeightedGroupL2
9+
)

skglm/penalties/block_separable.py

Lines changed: 68 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from numba.types import bool_
77

88
from skglm.penalties.base import BasePenalty
9-
from skglm.utils import BST, prox_block_2_05
9+
from skglm.utils import (
10+
BST, prox_block_2_05, prox_SCAD, value_SCAD, prox_MCP, value_MCP)
1011

1112

1213
spec_L21 = [
@@ -117,21 +118,13 @@ def __init__(self, alpha, gamma):
117118
def value(self, W):
    """Evaluate the BlockMCP penalty at W.

    The Euclidean norm of each row of ``W`` is computed and passed to
    ``value_MCP`` together with this penalty's ``alpha`` and ``gamma``.
    """
    squared = W ** 2
    row_norms = np.sqrt(np.sum(squared, axis=1))
    return value_MCP(row_norms, self.alpha, self.gamma)
124122

125123
def prox_1feat(self, value, stepsize, j):
    """Compute the proximal operator of BlockMCP for a single block.

    Parameters
    ----------
    value : array
        Block at which the proximal operator is evaluated.

    stepsize : float
        Step size forwarded to ``prox_MCP``.

    j : int
        Index of the block. Not used by this penalty.

    Returns
    -------
    array
        ``value`` multiplied by ``prox_MCP(||value||) / ||value||``, or an
        all-zero array when ``value`` itself is all zeros.
    """
    norm_value = norm(value)
    # A zero block has no direction to rescale: without this guard the
    # expression below evaluates 0 * value / 0 and returns NaN entries.
    if norm_value == 0.:
        return np.zeros_like(value)
    prox = prox_MCP(norm_value, stepsize, self.alpha, self.gamma)
    return prox * value / norm_value
135128

136129
def subdiff_distance(self, W, grad, ws):
137130
"""Compute distance of negative gradient to the subdifferential at W."""
@@ -156,6 +149,68 @@ def is_penalized(self, n_features):
156149
return np.ones(n_features, bool_)
157150

158151

152+
spec_BlockSCAD = [
    ('alpha', float64),
    ('gamma', float64),
]


@jitclass(spec_BlockSCAD)
class BlockSCAD(BasePenalty):
    """Block Smoothly Clipped Absolute Deviation.

    Notes
    -----
    With W_j the j-th row of W, the penalty is::

        pen(||W_j||) =
            alpha * ||W_j||                  if ||W_j|| <= alpha
            (2 * gamma * alpha * ||W_j|| - ||W_j|| ** 2 - alpha ** 2)
                / (2 * (gamma - 1))          if alpha < ||W_j|| <= gamma * alpha
            alpha ** 2 * (gamma + 1) / 2     if ||W_j|| > gamma * alpha

        value = sum_{j=1}^{n_features} pen(||W_j||)
    """

    def __init__(self, alpha, gamma):
        self.alpha = alpha
        self.gamma = gamma

    def value(self, W):
        """Compute the value of the BlockSCAD penalty at W.

        The row-wise Euclidean norms of ``W`` are passed to ``value_SCAD``.
        """
        norm_rows = np.sqrt(np.sum(W ** 2, axis=1))
        return value_SCAD(norm_rows, self.alpha, self.gamma)

    def prox_1feat(self, value, stepsize, j):
        """Compute the proximal operator of BlockSCAD for a single block.

        Parameters
        ----------
        value : array
            Block at which the proximal operator is evaluated.

        stepsize : float
            Step size forwarded to ``prox_SCAD``.

        j : int
            Index of the block. Not used by this penalty.

        Returns
        -------
        array
            ``value`` multiplied by ``prox_SCAD(||value||) / ||value||``, or
            an all-zero array when ``value`` itself is all zeros.
        """
        norm_value = norm(value)
        # A zero block has no direction to rescale: without this guard the
        # expression below evaluates 0 * value / 0 and returns NaN entries.
        if norm_value == 0.:
            return np.zeros_like(value)
        prox = prox_SCAD(norm_value, stepsize, self.alpha, self.gamma)
        return prox * value / norm_value

    def subdiff_distance(self, W, grad, ws):
        """Compute distance of negative gradient to the subdifferential at W.

        ``W`` is indexed by the feature indices stored in ``ws`` while
        ``grad`` is indexed by the position inside ``ws``. One distance is
        returned per entry of ``ws``.
        """
        subdiff_dist = np.zeros_like(ws, dtype=grad.dtype)
        for idx, j in enumerate(ws):
            norm_Wj = norm(W[j])
            if not np.any(W[j]):
                # distance of -grad_j to alpha * unit_ball
                subdiff_dist[idx] = max(0, norm(grad[idx]) - self.alpha)
            elif norm_Wj <= self.alpha:
                # distance of -grad_j to alpha * W[j] / ||W[j]||
                subdiff_dist[idx] = norm(grad[idx] + self.alpha * W[j] / norm_Wj)
            elif norm_Wj <= self.gamma * self.alpha:
                # distance of -grad_j to (alpha * gamma - ||W[j]||)
                # / ((gamma - 1) * ||W[j]||) * W[j]
                subdiff_dist[idx] = norm(grad[idx] + (
                    (self.alpha * self.gamma - norm_Wj) / (norm_Wj * (self.gamma - 1))
                ) * W[j])
            else:
                # distance of -grad_j to 0
                subdiff_dist[idx] = norm(grad[idx])
        return subdiff_dist

    def is_penalized(self, n_features):
        """Return a binary mask with the penalized features (all of them)."""
        return np.ones(n_features, bool_)
212+
213+
159214
spec_WeightedGroupL2 = [
160215
('alpha', float64),
161216
('weights', float64[:]),

skglm/penalties/separable.py

Lines changed: 71 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
from numba.types import bool_
55

66
from skglm.penalties.base import BasePenalty
7-
from skglm.utils import ST, box_proj, prox_05, prox_2_3
7+
from skglm.utils import (
8+
ST, box_proj, prox_05, prox_2_3, prox_SCAD, value_SCAD, prox_MCP, value_MCP)
89

910

1011
spec_L1 = [
@@ -174,7 +175,7 @@ class MCPenalty(BasePenalty):
174175
With x >= 0
175176
pen(x) =
176177
alpha * x - x^2 / (2 * gamma) if x <= gamma * alpha
177-
gamma * alpha 2 / 2 if x > gamma * alpha
178+
gamma * alpha^2 / 2 if x > gamma * alpha
178179
value = sum_{j=1}^{n_features} pen(abs(w_j))
179180
"""
180181

@@ -183,21 +184,11 @@ def __init__(self, alpha, gamma):
183184
self.gamma = gamma
184185

185186
def value(self, w):
    """Compute the value of the MCP penalty at w.

    Delegates to ``value_MCP`` with this penalty's ``alpha`` and ``gamma``.
    """
    alpha, gamma = self.alpha, self.gamma
    return value_MCP(w, alpha, gamma)
191188

192189
def prox_1d(self, value, stepsize, j):
    """Compute the proximal operator of MCP.

    Delegates to ``prox_MCP`` with this penalty's ``alpha`` and ``gamma``;
    the coordinate index ``j`` is not used.
    """
    alpha, gamma = self.alpha, self.gamma
    return prox_MCP(value, stepsize, alpha, gamma)
201192

202193
def subdiff_distance(self, w, grad, ws):
203194
"""Compute distance of negative gradient to the subdifferential at w."""
@@ -207,10 +198,9 @@ def subdiff_distance(self, w, grad, ws):
207198
# distance of -grad to alpha * [-1, 1]
208199
subdiff_dist[idx] = max(0, np.abs(grad[idx]) - self.alpha)
209200
elif np.abs(w[j]) < self.alpha * self.gamma:
210-
# distance of -grad_j to (alpha - abs(w[j])/gamma) * sign(w[j])
201+
# distance of -grad_j to (alpha * sign(w[j]) - w[j] / gamma)
211202
subdiff_dist[idx] = np.abs(
212-
grad[idx] + self.alpha * np.sign(w[j])
213-
- w[j] / self.gamma)
203+
grad[idx] + self.alpha * np.sign(w[j]) - w[j] / self.gamma)
214204
else:
215205
# distance of grad to 0
216206
subdiff_dist[idx] = np.abs(grad[idx])
@@ -229,6 +219,70 @@ def alpha_max(self, gradient0):
229219
return np.max(np.abs(gradient0))
230220

231221

222+
spec_SCAD = [
    ('alpha', float64),
    ('gamma', float64),
]


@jitclass(spec_SCAD)
class SCAD(BasePenalty):
    """Smoothly Clipped Absolute Deviation.

    Notes
    -----
    With x >= 0::

        pen(x) =
            alpha * x                        if x <= alpha
            (2 * gamma * alpha * x - x^2 - alpha^2)
                / (2 * (gamma - 1))          if alpha < x <= gamma * alpha
            alpha^2 * (gamma + 1) / 2        if x > gamma * alpha

        value = sum_{j=1}^{n_features} pen(abs(w_j))
    """

    def __init__(self, alpha, gamma):
        self.alpha = alpha
        self.gamma = gamma

    def value(self, w):
        """Evaluate the SCAD penalty at w by delegating to ``value_SCAD``."""
        return value_SCAD(w, self.alpha, self.gamma)

    def prox_1d(self, value, stepsize, j):
        """Evaluate the SCAD proximal operator by delegating to ``prox_SCAD``.

        The coordinate index ``j`` is not used.
        """
        return prox_SCAD(value, stepsize, self.alpha, self.gamma)

    def subdiff_distance(self, w, grad, ws):
        """Compute distance of negative gradient to the subdifferential at w.

        ``w`` is indexed by the feature indices stored in ``ws`` while
        ``grad`` is indexed by the position inside ``ws``.
        """
        alpha, gamma = self.alpha, self.gamma
        dist = np.zeros_like(grad)
        for idx, j in enumerate(ws):
            w_j = w[j]
            abs_w_j = np.abs(w_j)
            if w_j == 0:
                # subdifferential is the interval alpha * [-1, 1]
                dist[idx] = max(0, np.abs(grad[idx]) - alpha)
            elif abs_w_j <= alpha:
                # penalty derivative is alpha * sign(w_j)
                dist[idx] = np.abs(grad[idx] + alpha * np.sign(w_j))
            elif abs_w_j <= alpha * gamma:
                # penalty derivative is
                # (alpha * gamma * sign(w_j) - w_j) / (gamma - 1)
                slope = (np.sign(w_j) * alpha * gamma - w_j) / (gamma - 1)
                dist[idx] = np.abs(grad[idx] + slope)
            else:
                # penalty is flat here, derivative is 0
                dist[idx] = np.abs(grad[idx])
        return dist

    def is_penalized(self, n_features):
        """Return a boolean mask marking every feature as penalized."""
        return np.ones(n_features, bool_)

    def generalized_support(self, w):
        """Return a boolean mask of the non-zero coefficients of w."""
        return w != 0
284+
285+
232286
spec_IndicatorBox = [
233287
('alpha', float64)
234288
]

skglm/tests/test_penalties.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import pytest
2+
import numpy as np
3+
4+
from numpy.linalg import norm
5+
from numpy.testing import assert_array_less
6+
7+
from skglm.datafits import Quadratic, QuadraticMultiTask
8+
from skglm.penalties import (
9+
L1, L1_plus_L2, WeightedL1, MCPenalty, SCAD, IndicatorBox, L0_5, L2_3,
10+
L2_1, L2_05, BlockMCPenalty, BlockSCAD)
11+
from skglm import GeneralizedLinearEstimator
12+
from skglm.utils import make_correlated_data
13+
14+
15+
# Problem dimensions requested from the data generator.
n_samples = 20
n_features = 10
n_tasks = 10

X, Y, _ = make_correlated_data(
    n_samples=n_samples,
    n_features=n_features,
    n_tasks=n_tasks,
    density=0.1,
    random_state=0,
)
# Single-task target: first column of the multitask target.
y = Y[:, 0]

# Re-read the dimensions from the generated design matrix.
n_samples, n_features = X.shape
alpha_max = norm(X.T @ y, ord=np.inf) / n_samples
# Regularization strength shared by the penalties below: alpha_max / 1000.
alpha = alpha_max / 1000

# Separable penalties, fitted with the Quadratic datafit.
penalties = [
    L1(alpha=alpha),
    L1_plus_L2(alpha=alpha, l1_ratio=0.5),
    WeightedL1(alpha=1, weights=np.arange(n_features)),
    MCPenalty(alpha=alpha, gamma=4),
    SCAD(alpha=alpha, gamma=4),
    IndicatorBox(alpha=alpha),
    L0_5(alpha=alpha),
    L2_3(alpha=alpha),
]

# Block-separable penalties, fitted with the QuadraticMultiTask datafit.
block_penalties = [
    L2_1(alpha=alpha),
    L2_05(alpha=alpha),
    BlockMCPenalty(alpha=alpha, gamma=4),
    BlockSCAD(alpha=alpha, gamma=4),
]
42+
43+
44+
@pytest.mark.parametrize('penalty', penalties)
def test_subdiff_diff(penalty):
    """Fit with each separable penalty and check the stopping criterion."""
    model = GeneralizedLinearEstimator(
        datafit=Quadratic(), penalty=penalty, tol=1e-14)
    fitted = model.fit(X, y)
    # the solver must have reached the requested tolerance
    assert_array_less(fitted.stop_crit_, fitted.tol)
53+
54+
55+
@pytest.mark.parametrize('block_penalty', block_penalties)
def test_subdiff_diff_block(block_penalty):
    """Fit with each block penalty and check the stopping criterion."""
    model = GeneralizedLinearEstimator(
        datafit=QuadraticMultiTask(), penalty=block_penalty, tol=1e-14)
    fitted = model.fit(X, Y)
    # the solver must have reached the requested tolerance
    assert_array_less(fitted.stop_crit_, fitted.tol)

0 commit comments

Comments
 (0)