Skip to content

Commit de7af7c

Browse files
ENH add single task group solver (#26)
Co-authored-by: mathurinm <[email protected]>
1 parent 18db9b1 commit de7af7c

File tree

7 files changed

+435
-2
lines changed

7 files changed

+435
-2
lines changed

.github/workflows/main.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ jobs:
2020
- name: Install dependencies
2121
run: |
2222
python -m pip install --upgrade pip
23+
python -m pip install git+https://github.com/mathurinm/celer.git
2324
pip install pytest
2425
pip install numpydoc
2526
pip install .

doc/api.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ Penalties
4040
L2_3
4141
MCPenalty
4242
WeightedL1
43+
WeightedGroupL2
4344

4445

4546
Datafits
@@ -50,6 +51,7 @@ Datafits
5051
.. autosummary::
5152
:toctree: generated/
5253

54+
GroupQuadratic
5355
Logistic
5456
Quadratic
5557
QuadraticSVC

skglm/datafits/group.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import numpy as np
2+
from numpy.linalg import norm
3+
from numba.experimental import jitclass
4+
from numba import int32, float64
5+
6+
from skglm.datafits.base import BaseDatafit
7+
8+
9+
spec_QuadraticGroup = [
10+
('grp_ptr', int32[:]),
11+
('grp_indices', int32[:]),
12+
('lipschitz', float64[:])
13+
]
14+
15+
16+
@jitclass(spec_QuadraticGroup)
17+
class QuadraticGroup(BaseDatafit):
18+
"""Quadratic datafit used with group penalties.
19+
20+
The datafit reads::
21+
22+
(1 / (2 * n_samples)) * ||y - X w||^2_2
23+
24+
Attributes
25+
----------
26+
grp_indices : array, shape (n_features,)
27+
The group indices stacked contiguously
28+
(e.g. [grp1_indices, grp2_indices, ...]).
29+
30+
grp_ptr : array, shape (n_groups + 1,)
31+
The group pointers such that two consecutive elements delimit
32+
the indices of a group in ``grp_indices``.
33+
34+
lipschitz : array, shape (n_groups,)
35+
The lipschitz constants for each group.
36+
"""
37+
38+
def __init__(self, grp_ptr, grp_indices):
39+
self.grp_ptr, self.grp_indices = grp_ptr, grp_indices
40+
41+
def initialize(self, X, y):
42+
grp_ptr, grp_indices = self.grp_ptr, self.grp_indices
43+
n_groups = len(grp_ptr) - 1
44+
45+
lipschitz = np.zeros(n_groups)
46+
for g in range(n_groups):
47+
grp_g_indices = grp_indices[grp_ptr[g]: grp_ptr[g+1]]
48+
X_g = X[:, grp_g_indices]
49+
lipschitz[g] = norm(X_g, ord=2) ** 2 / len(y)
50+
51+
self.lipschitz = lipschitz
52+
53+
def value(self, y, w, Xw):
54+
return norm(y - Xw) ** 2 / (2 * len(y))
55+
56+
def gradient_g(self, X, y, w, Xw, g):
57+
grp_ptr, grp_indices = self.grp_ptr, self.grp_indices
58+
grp_g_indices = grp_indices[grp_ptr[g]: grp_ptr[g+1]]
59+
60+
grad_g = np.zeros(len(grp_g_indices))
61+
for idx, j in enumerate(grp_g_indices):
62+
grad_g[idx] = self.gradient_scalar(X, y, w, Xw, j)
63+
64+
return grad_g
65+
66+
def gradient_scalar(self, X, y, w, Xw, j):
67+
return X[:, j] @ (Xw - y) / len(y)

skglm/penalties/block_separable.py

Lines changed: 81 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
2-
from numpy.linalg.linalg import norm
3-
from numba import float64
2+
from numpy.linalg import norm
3+
4+
from numba import float64, int32
45
from numba.experimental import jitclass
56
from numba.types import bool_
67

@@ -153,3 +154,81 @@ def subdiff_distance(self, W, grad, ws):
153154
def is_penalized(self, n_features):
154155
"""Return a binary mask with the penalized features."""
155156
return np.ones(n_features, bool_)
157+
158+
159+
spec_WeightedGroupL2 = [
160+
('alpha', float64),
161+
('weights', float64[:]),
162+
('grp_ptr', int32[:]),
163+
('grp_indices', int32[:]),
164+
]
165+
166+
167+
@jitclass(spec_WeightedGroupL2)
168+
class WeightedGroupL2(BasePenalty):
169+
r"""Weighted Group L2 penalty.
170+
171+
The penalty reads::
172+
173+
\sum_{g} weights[g] * ||w_g||_2
174+
175+
Attributes
176+
----------
177+
alpha : float
178+
The regularization parameter.
179+
180+
weights : array, shape (n_groups,)
181+
The weights of the groups.
182+
183+
grp_indices : array, shape (n_features,)
184+
The group indices stacked contiguously
185+
(e.g. [grp1_indices, grp2_indices, ...]).
186+
187+
grp_ptr : array, shape (n_groups + 1,)
188+
The group pointers such that two consecutive elements delimit
189+
the indices of a group in ``grp_indices``.
190+
"""
191+
192+
def __init__(self, alpha, weights, grp_ptr, grp_indices):
193+
self.alpha, self.weights = alpha, weights
194+
self.grp_ptr, self.grp_indices = grp_ptr, grp_indices
195+
196+
def value(self, w):
197+
"""Value of penalty at vector ``w``."""
198+
alpha, weights = self.alpha, self.weights
199+
grp_ptr, grp_indices = self.grp_ptr, self.grp_indices
200+
n_grp = len(grp_ptr) - 1
201+
202+
sum_weighted_L2 = 0.
203+
for g in range(n_grp):
204+
grp_g_indices = grp_indices[grp_ptr[g]: grp_ptr[g+1]]
205+
w_g = w[grp_g_indices]
206+
207+
sum_weighted_L2 += alpha * weights[g] * norm(w_g)
208+
209+
return sum_weighted_L2
210+
211+
def prox_1group(self, value, stepsize, g):
212+
"""Compute the proximal operator of group ``g``."""
213+
return BST(value, self.alpha * stepsize * self.weights[g])
214+
215+
def subdiff_distance(self, w, grad, ws):
216+
"""Compute distance of negative gradient to the subdifferential at ``w``."""
217+
alpha, weights = self.alpha, self.weights
218+
grp_ptr, grp_indices = self.grp_ptr, self.grp_indices
219+
220+
scores = np.zeros(len(ws))
221+
for idx, g in enumerate(ws):
222+
grad_g = grad[idx]
223+
224+
grp_g_indices = grp_indices[grp_ptr[g]: grp_ptr[g+1]]
225+
w_g = w[grp_g_indices]
226+
norm_w_g = norm(w_g)
227+
228+
if norm_w_g == 0:
229+
scores[idx] = max(0, norm(grad_g) - alpha * weights[g])
230+
else:
231+
subdiff = alpha * weights[g] * w_g / norm_w_g
232+
scores[idx] = norm(grad_g - subdiff)
233+
234+
return scores

skglm/solvers/group_bcd_solver.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
import numpy as np
2+
from numba import njit
3+
4+
5+
def bcd_solver(X, y, datafit, penalty, w_init=None,
6+
max_iter=1000, max_epochs=100, tol=1e-7, verbose=False):
7+
"""Run a group BCD solver.
8+
9+
Parameters
10+
----------
11+
X : array, shape (n_samples, n_features)
12+
Design matrix.
13+
14+
y : array, shape (n_samples,)
15+
Target vector.
16+
17+
datafit : instance of BaseDatafit
18+
Datafit object.
19+
20+
penalty : instance of BasePenalty
21+
Penalty object.
22+
23+
w_init : array, shape (n_features,), default None
24+
Initial value of coefficients.
25+
If set to None, a zero vector is used instead.
26+
27+
max_iter : int, default 1000
28+
Maximum number of iterations.
29+
30+
max_epochs : int, default 100
31+
Maximum number of epochs.
32+
33+
tol : float, default 1e-6
34+
Tolerance for convergence.
35+
36+
verbose : bool, default False
37+
Amount of verbosity. 0/False is silent.
38+
39+
Returns
40+
-------
41+
w : array, shape (n_features,)
42+
Solution that minimizes the problem defined by datafit and penalty.
43+
44+
p_objs_out: array (max_iter,)
45+
The objective values at every outer iteration.
46+
47+
stop_crit: float
48+
The value of the stop criterion.
49+
"""
50+
n_features = X.shape[1]
51+
n_groups = len(penalty.grp_ptr) - 1
52+
53+
# init
54+
w = np.zeros(n_features) if w_init is None else w_init
55+
Xw = X @ w
56+
datafit.initialize(X, y)
57+
all_groups = np.arange(n_groups)
58+
p_objs_out = np.zeros(max_iter)
59+
60+
for t in range(max_iter):
61+
if t == 0: # avoid computing p_obj twice
62+
prev_p_obj = datafit.value(y, w, Xw) + penalty.value(w)
63+
64+
for epoch in range(max_epochs):
65+
_bcd_epoch(X, y, w, Xw, datafit, penalty, all_groups)
66+
67+
if epoch % 10 == 0:
68+
current_p_obj = datafit.value(y, w, Xw) + penalty.value(w)
69+
stop_crit_in = prev_p_obj - current_p_obj
70+
71+
if max(verbose - 1, 0):
72+
print(
73+
f"Epoch {epoch+1}: {current_p_obj:.10f} "
74+
f"obj. variation: {stop_crit_in:.2e}"
75+
)
76+
77+
if stop_crit_in <= tol:
78+
print("Early exit")
79+
break
80+
prev_p_obj = current_p_obj
81+
82+
current_p_obj = datafit.value(y, w, Xw) + penalty.value(w)
83+
stop_crit = prev_p_obj - current_p_obj
84+
85+
if max(verbose, 0):
86+
print(
87+
f"Iteration {t+1}: {current_p_obj:.10f}, "
88+
f"stopping crit: {stop_crit:.2f}"
89+
)
90+
91+
if stop_crit <= tol:
92+
print("Outer solver: Early exit")
93+
break
94+
95+
prev_p_obj = current_p_obj
96+
p_objs_out[t] = current_p_obj
97+
98+
return w, p_objs_out, stop_crit
99+
100+
101+
@njit
102+
def _bcd_epoch(X, y, w, Xw, datafit, penalty, ws):
103+
"""Perform a single BCD epoch on groups in ws."""
104+
grp_ptr, grp_indices = penalty.grp_ptr, penalty.grp_indices
105+
106+
for g in ws:
107+
grp_g_indices = grp_indices[grp_ptr[g]: grp_ptr[g+1]]
108+
old_w_g = w[grp_g_indices].copy()
109+
110+
lipschitz_g = datafit.lipschitz[g]
111+
grad_g = datafit.gradient_g(X, y, w, Xw, g)
112+
113+
w[grp_g_indices] = penalty.prox_1group(
114+
old_w_g - grad_g / lipschitz_g,
115+
1 / lipschitz_g, g
116+
)
117+
118+
for idx, j in enumerate(grp_g_indices):
119+
if old_w_g[idx] != w[j]:
120+
Xw += (w[j] - old_w_g[idx]) * X[:, j]
121+
return

0 commit comments

Comments
 (0)