1
1
import numpy as np
2
2
from numba import njit
3
3
4
+ from skglm .utils import check_group_compatible
4
5
5
- def bcd_solver (X , y , datafit , penalty , w_init = None ,
6
- max_iter = 1000 , max_epochs = 100 , tol = 1e-7 , verbose = False ):
6
+
7
+ def bcd_solver (X , y , datafit , penalty , w_init = None , p0 = 10 ,
8
+ max_iter = 1000 , max_epochs = 100 , tol = 1e-4 , verbose = False ):
7
9
"""Run a group BCD solver.
8
10
9
11
Parameters
@@ -24,13 +26,16 @@ def bcd_solver(X, y, datafit, penalty, w_init=None,
24
26
Initial value of coefficients.
25
27
If set to None, a zero vector is used instead.
26
28
29
+ p0 : int, default 10
30
+ Minimum number of groups to be included in the working set.
31
+
27
32
max_iter : int, default 1000
28
33
Maximum number of iterations.
29
34
30
35
max_epochs : int, default 100
31
36
Maximum number of epochs.
32
37
33
- tol : float, default 1e-6
38
+ tol : float, default 1e-4
34
39
Tolerance for convergence.
35
40
36
41
verbose : bool, default False
@@ -47,6 +52,9 @@ def bcd_solver(X, y, datafit, penalty, w_init=None,
47
52
stop_crit: float
48
53
The value of the stop criterion.
49
54
"""
55
+ check_group_compatible (datafit )
56
+ check_group_compatible (penalty )
57
+
50
58
n_features = X .shape [1 ]
51
59
n_groups = len (penalty .grp_ptr ) - 1
52
60
@@ -56,51 +64,62 @@ def bcd_solver(X, y, datafit, penalty, w_init=None,
56
64
datafit .initialize (X , y )
57
65
all_groups = np .arange (n_groups )
58
66
p_objs_out = np .zeros (max_iter )
67
+ stop_crit = 0. # prevent ref before assign when max_iter == 0
59
68
60
69
for t in range (max_iter ):
61
- if t == 0 : # avoid computing p_obj twice
62
- prev_p_obj = datafit .value (y , w , Xw ) + penalty .value (w )
70
+ if t == 0 : # avoid computing grad and opt twice
71
+ grad = _construct_grad (X , y , w , Xw , datafit , all_groups )
72
+ opt = penalty .subdiff_distance (w , grad , all_groups )
73
+ stop_crit = np .max (opt )
74
+
75
+ if stop_crit <= tol :
76
+ break
77
+
78
+ gsupp_size = penalty .generalized_support (w ).sum ()
79
+ ws_size = max (min (p0 , n_groups ),
80
+ min (n_groups , 2 * gsupp_size ))
81
+ ws = np .argpartition (opt , - ws_size )[- ws_size :] # k-largest items (no sort)
63
82
64
83
for epoch in range (max_epochs ):
65
- _bcd_epoch (X , y , w , Xw , datafit , penalty , all_groups )
84
+ _bcd_epoch (X , y , w , Xw , datafit , penalty , ws )
66
85
67
86
if epoch % 10 == 0 :
68
- current_p_obj = datafit .value (y , w , Xw ) + penalty .value (w )
69
- stop_crit_in = prev_p_obj - current_p_obj
87
+ grad_ws = _construct_grad (X , y , w , Xw , datafit , ws )
88
+ opt_in = penalty .subdiff_distance (w , grad_ws , ws )
89
+ stop_crit_in = np .max (opt_in )
70
90
71
91
if max (verbose - 1 , 0 ):
92
+ p_obj = datafit .value (y , w , Xw ) + penalty .value (w )
72
93
print (
73
- f"Epoch { epoch + 1 } : { current_p_obj :.10f} "
94
+ f"Epoch { epoch + 1 } : { p_obj :.10f} "
74
95
f"obj. variation: { stop_crit_in :.2e} "
75
96
)
76
97
77
- if stop_crit_in <= tol :
78
- print ("Early exit" )
98
+ if stop_crit_in <= 0.3 * stop_crit :
79
99
break
80
- prev_p_obj = current_p_obj
81
100
82
- current_p_obj = datafit .value (y , w , Xw ) + penalty .value (w )
83
- stop_crit = prev_p_obj - current_p_obj
101
+ p_obj = datafit .value (y , w , Xw ) + penalty .value (w )
102
+ grad = _construct_grad (X , y , w , Xw , datafit , all_groups )
103
+ opt = penalty .subdiff_distance (w , grad , all_groups )
104
+ stop_crit = np .max (opt )
84
105
85
- if max ( verbose , 0 ) :
106
+ if verbose :
86
107
print (
87
- f"Iteration { t + 1 } : { current_p_obj :.10f} , "
88
- f"stopping crit: { stop_crit :.2f } "
108
+ f"Iteration { t + 1 } : { p_obj :.10f} , "
109
+ f"stopping crit: { stop_crit :.2e } "
89
110
)
90
111
91
112
if stop_crit <= tol :
92
- print ("Outer solver: Early exit" )
93
113
break
94
114
95
- prev_p_obj = current_p_obj
96
- p_objs_out [t ] = current_p_obj
115
+ p_objs_out [t ] = p_obj
97
116
98
117
return w , p_objs_out , stop_crit
99
118
100
119
101
120
@njit
102
121
def _bcd_epoch (X , y , w , Xw , datafit , penalty , ws ):
103
- """Perform a single BCD epoch on groups in ws."""
122
+ # perform a single BCD epoch on groups in ws
104
123
grp_ptr , grp_indices = penalty .grp_ptr , penalty .grp_indices
105
124
106
125
for g in ws :
@@ -119,3 +138,19 @@ def _bcd_epoch(X, y, w, Xw, datafit, penalty, ws):
119
138
if old_w_g [idx ] != w [j ]:
120
139
Xw += (w [j ] - old_w_g [idx ]) * X [:, j ]
121
140
return
141
+
142
+
143
@njit
def _construct_grad(X, y, w, Xw, datafit, ws):
    # Assemble the negated datafit gradient for every group listed in ws.
    # The per-group pieces are laid out back to back in a single 1d array,
    # following the order of ws: [-grad_ws_1, -grad_ws_2, ...].
    grp_ptr = datafit.grp_ptr

    # Output length = sum of the group sizes, read off consecutive grp_ptr entries.
    total_size = 0
    for g in ws:
        total_size += grp_ptr[g + 1] - grp_ptr[g]

    neg_grads = np.zeros(total_size)
    start = 0
    for g in ws:
        # presumably gradient_g returns the gradient restricted to the
        # features of group g -- confirm against the datafit implementation
        grad_g = datafit.gradient_g(X, y, w, Xw, g)
        stop = start + len(grad_g)
        neg_grads[start:stop] = -grad_g
        start = stop
    return neg_grads
0 commit comments