Skip to content

Commit fe2aaf7

Browse files
committed
FIX - add alpha to SLOPE penalty
Add `alpha` to make `repr()` work on objects created with the SLOPE penalty. Closes #315.
1 parent d57e361 commit fe2aaf7

File tree

2 files changed

+14
-4
lines changed

2 files changed

+14
-4
lines changed

skglm/penalties/non_separable.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ class SLOPE(BasePenalty):
1414
Contain regularization levels for every feature.
1515
When ``alphas`` contain a single unique value, ``SLOPE``
1616
is equivalent to the ``L1``penalty.
17+
alpha : float, default=1.0
18+
Scaling factor for the penalty. `alphas` is multiplied by this value.
1719
1820
References
1921
----------
@@ -23,24 +25,26 @@ class SLOPE(BasePenalty):
2325
https://doi.org/10.1214/15-AOAS842
2426
"""
2527

26-
def __init__(self, alphas):
28+
def __init__(self, alphas, alpha=1):
2729
self.alphas = alphas
30+
self.alpha = alpha
2831

2932
def get_spec(self):
3033
spec = (
34+
('alpha', float64),
3135
('alphas', float64[:]),
3236
)
3337
return spec
3438

3539
def params_to_dict(self):
36-
return dict(alphas=self.alphas)
40+
return dict(alphas=self.alphas, alpha=self.alpha)
3741

3842
def value(self, w):
3943
"""Compute the value of SLOPE at w."""
40-
return np.sum(np.sort(np.abs(w)) * self.alphas[::-1])
44+
return np.sum(np.sort(np.abs(w)) * self.alphas[::-1] * self.alpha)
4145

4246
def prox_vec(self, x, stepsize):
43-
alphas = self.alphas
47+
alphas = self.alphas * self.alpha
4448
prox = np.zeros_like(x)
4549

4650
abs_x = np.abs(x)

skglm/tests/test_estimators.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -621,6 +621,12 @@ def test_SparseLogReg_elasticnet(X, l1_ratio):
621621
np.testing.assert_allclose(
622622
estimator_sk.intercept_, estimator_ours.intercept_, rtol=1e-4)
623623

624+
def test_SLOPE_printing():
625+
alphas = [0.5, 0.1]
626+
model = GeneralizedLinearEstimator(penalty = SLOPE(alphas))
627+
res = repr(model)
628+
assert isinstance(res, str)
629+
624630

625631
if __name__ == "__main__":
626632
pass

0 commit comments

Comments
 (0)