Skip to content

Commit 7622936

Browse files
PABanniermathurinm
andauthored
ENH Add IterativeRewieghtedL1 (#87)
Co-authored-by: mathurinm <[email protected]>
1 parent 359f4da commit 7622936

File tree

5 files changed

+222
-0
lines changed

5 files changed

+222
-0
lines changed

examples/plot_reweighted_l1.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
"""
2+
=================================================================
3+
Timing comparison between direct prox computation and reweighting
4+
=================================================================
5+
Compare time and objective value of L0_5-regularized problem with
6+
direct proximal computation and iterative reweighting.
7+
"""
8+
# Author: Pierre-Antoine Bannier <[email protected]>
9+
10+
import time
11+
import numpy as np
12+
import pandas as pd
13+
from numpy.linalg import norm
14+
import matplotlib.pyplot as plt
15+
16+
from skglm.penalties.separable import L0_5
17+
from skglm.utils import make_correlated_data
18+
from skglm.estimators import GeneralizedLinearEstimator
19+
from skglm.experimental import IterativeReweightedL1
20+
from skglm.solvers import AndersonCD
21+
22+
23+
n_samples, n_features = 200, 500
24+
X, y, w_true = make_correlated_data(
25+
n_samples=n_samples, n_features=n_features, random_state=24)
26+
27+
alpha_max = norm(X.T @ y, ord=np.inf) / n_samples
28+
alphas = [alpha_max / 10, alpha_max / 100, alpha_max / 1000]
29+
tol = 1e-10
30+
31+
32+
def _obj(w):
33+
return (np.sum((y - X @ w) ** 2) / (2 * n_samples)
34+
+ alpha * np.sum(np.sqrt(np.abs(w))))
35+
36+
37+
def fit_l05(alpha):
38+
start = time.time()
39+
iterative_l05 = IterativeReweightedL1(
40+
penalty=L0_5(alpha),
41+
solver=AndersonCD(tol=tol, fit_intercept=False)).fit(X, y)
42+
iterative_time = time.time() - start
43+
44+
# `subdiff` strategy for WS is uninformative for L0_5
45+
start = time.time()
46+
direct_l05 = GeneralizedLinearEstimator(
47+
penalty=L0_5(alpha),
48+
solver=AndersonCD(tol=tol, fit_intercept=False,
49+
ws_strategy="fixpoint")).fit(X, y)
50+
direct_time = time.time() - start
51+
52+
results = {
53+
"iterative": (iterative_l05, iterative_time),
54+
"direct": (direct_l05, direct_time),
55+
}
56+
return results
57+
58+
59+
# caching Numba compilation
60+
fit_l05(alpha_max/10)
61+
62+
time_results = np.zeros((2, len(alphas)))
63+
obj_results = np.zeros((2, len(alphas)))
64+
65+
# actual run
66+
for i, alpha in enumerate(alphas):
67+
results = fit_l05(alpha=alpha)
68+
iterative_l05, iterative_time = results["iterative"]
69+
direct_l05, direct_time = results["direct"]
70+
71+
iterative_obj = _obj(iterative_l05.coef_)
72+
direct_obj = _obj(direct_l05.coef_)
73+
74+
obj_results[:, i] = np.array([iterative_obj, direct_obj])
75+
time_results[:, i] = np.array([iterative_time, direct_time])
76+
77+
time_df = pd.DataFrame(time_results.T, columns=["Iterative", "Direct"])
78+
obj_df = pd.DataFrame(obj_results.T, columns=["Iterative", "Direct"])
79+
80+
time_df.index = [1e-1, 1e-2, 1e-3]
81+
obj_df.index = [1e-1, 1e-2, 1e-3]
82+
83+
fig, axarr = plt.subplots(1, 2, figsize=(8, 3.5), constrained_layout=True)
84+
ax = axarr[0]
85+
time_df.plot.bar(rot=0, ax=ax)
86+
ax.set_xlabel(r"$\lambda/\lambda_{max}$")
87+
ax.set_ylabel("time (in s)")
88+
ax.set_title("Time to fit")
89+
90+
ax = axarr[1]
91+
obj_df.plot.bar(rot=0, ax=ax)
92+
ax.set_xlabel(r"$\lambda/\lambda_{max}$")
93+
ax.set_ylabel("obj. value")
94+
ax.set_title("Objective at solution")
95+
plt.show(block=False)

skglm/experimental/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
from .reweighted import IterativeReweightedL1
12
from .sqrt_lasso import SqrtLasso
23

34
__all__ = [
45
SqrtLasso,
6+
IterativeReweightedL1,
57
]

skglm/experimental/reweighted.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import numpy as np
2+
from skglm.datafits import Quadratic
3+
from skglm.estimators import GeneralizedLinearEstimator
4+
from skglm.penalties import WeightedL1, L0_5
5+
from skglm.utils import compiled_clone
6+
7+
8+
class IterativeReweightedL1(GeneralizedLinearEstimator):
9+
r"""Reweighted L1-norm estimator.
10+
11+
This estimator solves a non-convex problems by iteratively solving
12+
convex surrogates involving weighted L1 norms.
13+
14+
Parameters
15+
----------
16+
alpha : float, optional
17+
Penalty strength.
18+
19+
datafit : instance of BaseDatafit, optional
20+
Datafit. If None, ``datafit`` is initialized as a ``Quadratic`` datafit.
21+
``datafit`` is replaced by a JIT-compiled instance when calling fit.
22+
23+
solver : instance of BaseSolver, optional
24+
Solver. If None, `solver` is initialized as an `AndersonCD` solver.
25+
26+
n_reweights : int, optional
27+
Number of reweighting performed (convex surrogates solved).
28+
29+
Attributes
30+
----------
31+
coef_ : array, shape (n_features,)
32+
Parameter vector (w in the cost function formula).
33+
34+
loss_history_ : list
35+
Objective history after every reweighting.
36+
37+
References
38+
----------
39+
.. [1] Candès et al. (2007), Enhancing sparsity by reweighted l1 minimization
40+
https://web.stanford.edu/~boyd/papers/pdf/rwl1.pdf
41+
"""
42+
43+
def __init__(self, datafit=Quadratic(), penalty=L0_5(1.), solver=None,
44+
n_reweights=5):
45+
super().__init__(datafit=datafit, penalty=penalty, solver=solver)
46+
self.n_reweights = n_reweights
47+
48+
def fit(self, X, y):
49+
"""Fit the model according to the given training data.
50+
51+
Parameters
52+
----------
53+
X : array-like, shape (n_samples, n_features)
54+
Training data, where n_samples is the number of samples and
55+
n_features is the number of features.
56+
57+
y : array-like, shape (n_samples,)
58+
Target vector relative to X.
59+
60+
Returns
61+
-------
62+
self :
63+
Fitted estimator.
64+
"""
65+
if not hasattr(self.penalty, "derivative"):
66+
raise ValueError(
67+
"Missing `derivative` method. Reweighting is not implemented for " +
68+
f"penalty {self.penalty.__class__.__name__}")
69+
70+
n_features = X.shape[1]
71+
_penalty = compiled_clone(WeightedL1(self.penalty.alpha, np.ones(n_features)))
72+
self.datafit = compiled_clone(self.datafit)
73+
self.penalty = compiled_clone(self.penalty)
74+
75+
self.loss_history_ = []
76+
77+
for iter_reweight in range(self.n_reweights):
78+
coef_ = self.solver.solve(X, y, self.datafit, _penalty)[0]
79+
_penalty.weights = self.penalty.derivative(coef_)
80+
81+
loss = (self.datafit.value(y, coef_, X @ coef_)
82+
+ self.penalty.value(coef_))
83+
self.loss_history_.append(loss)
84+
85+
if self.solver.verbose:
86+
print(f"Reweight {iter_reweight}/{self.n_reweights}, objective {loss}")
87+
88+
self.coef_ = coef_
89+
90+
return self
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import numpy as np
2+
from numpy.linalg import norm
3+
4+
from skglm.penalties.separable import L0_5
5+
from skglm.utils import make_correlated_data
6+
from skglm.experimental import IterativeReweightedL1
7+
from skglm.solvers import AndersonCD
8+
9+
10+
n_samples, n_features = 20, 50
11+
X, y, w_true = make_correlated_data(
12+
n_samples=n_samples, n_features=n_features, random_state=24)
13+
14+
alpha_max = norm(X.T @ y, ord=np.inf) / n_samples
15+
alpha = alpha_max / 100
16+
tol = 1e-10
17+
18+
19+
def test_decreasing_loss():
20+
# reweighting can't increase the L0.5 objective
21+
iterative_l05 = IterativeReweightedL1(
22+
penalty=L0_5(alpha),
23+
solver=AndersonCD(tol=tol, fit_intercept=False)).fit(X, y)
24+
np.testing.assert_array_less(
25+
iterative_l05.loss_history_[-1], iterative_l05.loss_history_[0])
26+
diffs = np.diff(iterative_l05.loss_history_)
27+
np.testing.assert_array_less(diffs, 1e-5)

skglm/penalties/separable.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,10 @@ def value(self, w):
380380
"""Compute the value of L0_5 at w."""
381381
return self.alpha * np.sum(np.abs(w) ** 0.5)
382382

383+
def derivative(self, w):
384+
"""Compute the element-wise derivative."""
385+
return 1. / (2. * np.sqrt(np.abs(w)) + 1e-12)
386+
383387
def prox_1d(self, value, stepsize, j):
384388
"""Compute the proximal operator of L0_5."""
385389
return prox_05(value, self.alpha * stepsize)
@@ -429,6 +433,10 @@ def value(self, w):
429433
"""Compute the value of the L2_3 norm at w."""
430434
return self.alpha * np.sum(np.abs(w) ** (2/3))
431435

436+
def derivative(self, w):
437+
"""Compute the element-wise derivative."""
438+
return 2 / (3 * np.abs(w) ** (1/3) + 1e-12)
439+
432440
def prox_1d(self, value, stepsize, j):
433441
"""Compute the proximal operator of the L2_3 norm."""
434442
return prox_2_3(value, self.alpha * stepsize)

0 commit comments

Comments
 (0)