Skip to content

Commit 9929e71

Browse files
authored
ENH add working set strategy to group_bcd_solver (#28)
1 parent de7af7c commit 9929e71

File tree

5 files changed

+119
-26
lines changed

5 files changed

+119
-26
lines changed

skglm/penalties/block_separable.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -212,16 +212,23 @@ def prox_1group(self, value, stepsize, g):
212212
"""Compute the proximal operator of group ``g``."""
213213
return BST(value, self.alpha * stepsize * self.weights[g])
214214

215-
def subdiff_distance(self, w, grad, ws):
216-
"""Compute distance of negative gradient to the subdifferential at ``w``."""
215+
def subdiff_distance(self, w, grad_ws, ws):
216+
"""Compute distance to the subdifferential at ``w`` of negative gradient.
217+
218+
Note: ``grad_ws`` is a stacked array of negative gradients.
219+
([-grad_ws_1, -grad_ws_2, ...])
220+
"""
217221
alpha, weights = self.alpha, self.weights
218222
grp_ptr, grp_indices = self.grp_ptr, self.grp_indices
219223

220224
scores = np.zeros(len(ws))
225+
grad_ptr = 0
221226
for idx, g in enumerate(ws):
222-
grad_g = grad[idx]
223-
224227
grp_g_indices = grp_indices[grp_ptr[g]: grp_ptr[g+1]]
228+
229+
grad_g = grad_ws[grad_ptr: grad_ptr + len(grp_g_indices)]
230+
grad_ptr += len(grp_g_indices)
231+
225232
w_g = w[grp_g_indices]
226233
norm_w_g = norm(w_g)
227234

@@ -232,3 +239,23 @@ def subdiff_distance(self, w, grad, ws):
232239
scores[idx] = norm(grad_g - subdiff)
233240

234241
return scores
242+
243+
def is_penalized(self, n_groups):
    """Return a boolean mask telling which groups are penalized.

    Every group is reported as penalized: the returned mask is all ``True``.

    Parameters
    ----------
    n_groups : int
        Number of groups.

    Returns
    -------
    mask : array of bool, shape (n_groups,)
        Array filled with ``True``.
    """
    return np.ones(n_groups, dtype=np.bool_)
245+
246+
def generalized_support(self, w):
    """Return a boolean mask of the groups in the generalized support of ``w``.

    A group is flagged ``True`` when ``self.is_penalized`` reports it as not
    penalized, or when at least one of its coefficients in ``w`` is nonzero.

    Parameters
    ----------
    w : array
        Coefficient vector, indexed by the feature indices stored in
        ``self.grp_indices``.

    Returns
    -------
    gsupp : array of bool, shape (n_groups,)
        ``gsupp[g]`` is ``True`` if group ``g`` is in the generalized support.
    """
    grp_indices, grp_ptr = self.grp_indices, self.grp_ptr
    n_groups = len(grp_ptr) - 1
    is_penalized = self.is_penalized(n_groups)

    gsupp = np.zeros(n_groups, dtype=np.bool_)
    for g in range(n_groups):
        # unpenalized groups always belong to the generalized support
        if not is_penalized[g]:
            gsupp[g] = True
            continue

        # features of group ``g`` are the slice of ``grp_indices`` delimited
        # by two consecutive entries of ``grp_ptr``
        grp_g_indices = grp_indices[grp_ptr[g]: grp_ptr[g+1]]
        if np.any(w[grp_g_indices]):
            gsupp[g] = True

    return gsupp

skglm/solvers/cd_solver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def cd_solver_path(X, y, datafit, penalty, alphas=None,
9292
# X_dense, X_data, X_indices, X_indptr = _sparse_and_dense(X)
9393

9494
if alphas is None:
95-
raise ValueError('alphas should be passed explicitely')
95+
raise ValueError('alphas should be passed explicitly')
9696
# if hasattr(penalty, "alpha_max"):
9797
# if sparse.issparse(X):
9898
# grad0 = construct_grad_sparse(

skglm/solvers/group_bcd_solver.py

Lines changed: 56 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import numpy as np
22
from numba import njit
33

4+
from skglm.utils import check_group_compatible
45

5-
def bcd_solver(X, y, datafit, penalty, w_init=None,
6-
max_iter=1000, max_epochs=100, tol=1e-7, verbose=False):
6+
7+
def bcd_solver(X, y, datafit, penalty, w_init=None, p0=10,
8+
max_iter=1000, max_epochs=100, tol=1e-4, verbose=False):
79
"""Run a group BCD solver.
810
911
Parameters
@@ -24,13 +26,16 @@ def bcd_solver(X, y, datafit, penalty, w_init=None,
2426
Initial value of coefficients.
2527
If set to None, a zero vector is used instead.
2628
29+
p0 : int, default 10
30+
Minimum number of groups to be included in the working set.
31+
2732
max_iter : int, default 1000
2833
Maximum number of iterations.
2934
3035
max_epochs : int, default 100
3136
Maximum number of epochs.
3237
33-
tol : float, default 1e-6
38+
tol : float, default 1e-4
3439
Tolerance for convergence.
3540
3641
verbose : bool, default False
@@ -47,6 +52,9 @@ def bcd_solver(X, y, datafit, penalty, w_init=None,
4752
stop_crit: float
4853
The value of the stop criterion.
4954
"""
55+
check_group_compatible(datafit)
56+
check_group_compatible(penalty)
57+
5058
n_features = X.shape[1]
5159
n_groups = len(penalty.grp_ptr) - 1
5260

@@ -56,51 +64,62 @@ def bcd_solver(X, y, datafit, penalty, w_init=None,
5664
datafit.initialize(X, y)
5765
all_groups = np.arange(n_groups)
5866
p_objs_out = np.zeros(max_iter)
67+
stop_crit = 0. # prevent ref before assign when max_iter == 0
5968

6069
for t in range(max_iter):
61-
if t == 0: # avoid computing p_obj twice
62-
prev_p_obj = datafit.value(y, w, Xw) + penalty.value(w)
70+
if t == 0: # avoid computing grad and opt twice
71+
grad = _construct_grad(X, y, w, Xw, datafit, all_groups)
72+
opt = penalty.subdiff_distance(w, grad, all_groups)
73+
stop_crit = np.max(opt)
74+
75+
if stop_crit <= tol:
76+
break
77+
78+
gsupp_size = penalty.generalized_support(w).sum()
79+
ws_size = max(min(p0, n_groups),
80+
min(n_groups, 2 * gsupp_size))
81+
ws = np.argpartition(opt, -ws_size)[-ws_size:] # k-largest items (no sort)
6382

6483
for epoch in range(max_epochs):
65-
_bcd_epoch(X, y, w, Xw, datafit, penalty, all_groups)
84+
_bcd_epoch(X, y, w, Xw, datafit, penalty, ws)
6685

6786
if epoch % 10 == 0:
68-
current_p_obj = datafit.value(y, w, Xw) + penalty.value(w)
69-
stop_crit_in = prev_p_obj - current_p_obj
87+
grad_ws = _construct_grad(X, y, w, Xw, datafit, ws)
88+
opt_in = penalty.subdiff_distance(w, grad_ws, ws)
89+
stop_crit_in = np.max(opt_in)
7090

7191
if max(verbose - 1, 0):
92+
p_obj = datafit.value(y, w, Xw) + penalty.value(w)
7293
print(
73-
f"Epoch {epoch+1}: {current_p_obj:.10f} "
94+
f"Epoch {epoch+1}: {p_obj:.10f} "
7495
f"obj. variation: {stop_crit_in:.2e}"
7596
)
7697

77-
if stop_crit_in <= tol:
78-
print("Early exit")
98+
if stop_crit_in <= 0.3 * stop_crit:
7999
break
80-
prev_p_obj = current_p_obj
81100

82-
current_p_obj = datafit.value(y, w, Xw) + penalty.value(w)
83-
stop_crit = prev_p_obj - current_p_obj
101+
p_obj = datafit.value(y, w, Xw) + penalty.value(w)
102+
grad = _construct_grad(X, y, w, Xw, datafit, all_groups)
103+
opt = penalty.subdiff_distance(w, grad, all_groups)
104+
stop_crit = np.max(opt)
84105

85-
if max(verbose, 0):
106+
if verbose:
86107
print(
87-
f"Iteration {t+1}: {current_p_obj:.10f}, "
88-
f"stopping crit: {stop_crit:.2f}"
108+
f"Iteration {t+1}: {p_obj:.10f}, "
109+
f"stopping crit: {stop_crit:.2e}"
89110
)
90111

91112
if stop_crit <= tol:
92-
print("Outer solver: Early exit")
93113
break
94114

95-
prev_p_obj = current_p_obj
96-
p_objs_out[t] = current_p_obj
115+
p_objs_out[t] = p_obj
97116

98117
return w, p_objs_out, stop_crit
99118

100119

101120
@njit
102121
def _bcd_epoch(X, y, w, Xw, datafit, penalty, ws):
103-
"""Perform a single BCD epoch on groups in ws."""
122+
# perform a single BCD epoch on groups in ws
104123
grp_ptr, grp_indices = penalty.grp_ptr, penalty.grp_indices
105124

106125
for g in ws:
@@ -119,3 +138,19 @@ def _bcd_epoch(X, y, w, Xw, datafit, penalty, ws):
119138
if old_w_g[idx] != w[j]:
120139
Xw += (w[j] - old_w_g[idx]) * X[:, j]
121140
return
141+
142+
143+
@njit
def _construct_grad(X, y, w, Xw, datafit, ws):
    # compute the -gradient according to each group in ws
    # note: -gradients are stacked in a 1d array ([-grad_ws_1, -grad_ws_2, ...])
    #
    # X, y, w, Xw are forwarded untouched to ``datafit.gradient_g``;
    # ws is the iterable of group indices whose gradients are requested.
    # Returns a 1d array whose length is the total number of features
    # of the groups in ws, as given by ``datafit.grp_ptr``.
    grp_ptr = datafit.grp_ptr

    # total size of the output: sum of the sizes of the groups in ws
    n_features_ws = sum([grp_ptr[g+1] - grp_ptr[g] for g in ws])

    grads = np.zeros(n_features_ws)
    grad_ptr = 0  # write position in ``grads``
    for g in ws:
        grad_g = datafit.gradient_g(X, y, w, Xw, g)
        # store the negated gradient of group g right after the previous one
        grads[grad_ptr: grad_ptr+len(grad_g)] = -grad_g
        grad_ptr += len(grad_g)
    return grads

skglm/tests/test_group.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import numpy as np
33
from numpy.linalg import norm
44

5+
from skglm.penalties import L1
6+
from skglm.datafits import Quadratic
57
from skglm.penalties.block_separable import WeightedGroupL2
68
from skglm.datafits.group import QuadraticGroup
79
from skglm.solvers.group_bcd_solver import bcd_solver
@@ -26,6 +28,15 @@ def _generate_random_grp(n_groups, n_features, shuffle=True):
2628
return grp_indices, splits, groups
2729

2830

31+
def test_check_group_compatible():
    """Check that ``bcd_solver`` raises for a non-group datafit and penalty."""
    l1_penalty = L1(1e-3)
    quad_datafit = Quadratic()
    X, y = np.random.randn(5, 5), np.random.randn(5)

    # only the raise matters here; the random data values are irrelevant
    with np.testing.assert_raises(Exception):
        bcd_solver(X, y, quad_datafit, l1_penalty)
38+
39+
2940
@pytest.mark.parametrize("n_groups, n_features, shuffle",
3041
[[10, 50, True], [10, 50, False], [17, 53, False]])
3142
def test_alpha_max(n_groups, n_features, shuffle):

skglm/utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,3 +234,23 @@ def grp_converter(groups, n_features):
234234
else:
235235
raise ValueError("Unsupported group format.")
236236
return grp_indices.astype(np.int32), grp_ptr.astype(np.int32)
237+
238+
239+
def check_group_compatible(obj):
    """Check whether ``obj`` is compatible with ``bcd_solver``.

    ``obj`` is considered compatible when it exposes both the ``grp_ptr``
    and ``grp_indices`` attributes.

    Parameters
    ----------
    obj : instance of BaseDatafit or BasePenalty
        Object to check.

    Raises
    ------
    AttributeError
        If ``obj`` lacks the ``grp_ptr`` or ``grp_indices`` attribute. The
        message names the class of ``obj`` and the first missing attribute.
    """
    obj_name = obj.__class__.__name__
    group_attrs = ('grp_ptr', 'grp_indices')

    for attr in group_attrs:
        if not hasattr(obj, attr):
            # AttributeError is a subclass of Exception, so callers that
            # catch the generic Exception keep working
            raise AttributeError(
                "datafit and penalty must be compatible with 'bcd_solver'.\n"
                f"'{obj_name}' is not block-separable. "
                f"Missing '{attr}' attribute."
            )

0 commit comments

Comments
 (0)