|
| 1 | +import warnings |
| 2 | + |
| 3 | +import numpy as np |
| 4 | +from numpy.linalg import norm |
| 5 | +from scipy.sparse import issparse |
| 6 | + |
| 7 | +from numba import njit |
| 8 | +from skglm.utils.jit_compilation import compiled_clone |
| 9 | +from sklearn.exceptions import ConvergenceWarning |
| 10 | + |
| 11 | + |
| 12 | +class PDCD_WS: |
| 13 | + r"""Primal-Dual Coordinate Descent solver with working sets. |
| 14 | +
|
| 15 | + It solves:: |
| 16 | +
|
| 17 | + \min_w F(Xw) + G(w) |
| 18 | +
|
| 19 | + using a primal-dual method on the saddle point problem:: |
| 20 | +
|
| 21 | + \min_w \max_z <Xw, z> + G(w) - F^*(z) |
| 22 | +
|
| 23 | + where :math:`F` is the datafit term (:math:`F^*` its Fenchel conjugate) |
| 24 | + and :math:`G` is the penalty term. |
| 25 | +
|
| 26 | + The datafit is required to be convex and proximable. Also, the penalty |
| 27 | + is required to be convex, separable, and proximable. |
| 28 | +
|
| 29 | + The solver is an adaptation of algorithm [1] to working sets [2]. |
| 30 | + The working sets are built using a fixed point distance strategy |
| 31 | + where each feature is assigned a score based how much its coefficient varies |
| 32 | + when performing a primal update:: |
| 33 | +
|
| 34 | + \text{score}_j = \abs{w_j - prox_{\tau_j, G_j}(w_j - \tau_j <X_j, z>)} |
| 35 | +
|
| 36 | + where :maths:`\tau_j` is the primal step associated with the j-th feature. |
| 37 | +
|
| 38 | + Parameters |
| 39 | + ---------- |
| 40 | + max_iter : int, optional |
| 41 | + The maximum number of iterations or equivalently the |
| 42 | + the maximum number of solved subproblems. |
| 43 | +
|
| 44 | + max_epochs : int, optional |
| 45 | + Maximum number of primal CD epochs on each subproblem. |
| 46 | +
|
| 47 | + dual_init : array, shape (n_samples,) default None |
| 48 | + The initialization of dual variables. |
| 49 | + If None, they are initialized as the 0 vector ``np.zeros(n_samples)``. |
| 50 | +
|
| 51 | + p0 : int, optional |
| 52 | + First working set size. |
| 53 | +
|
| 54 | + tol : float, optional |
| 55 | + The tolerance for the optimization. |
| 56 | +
|
| 57 | + verbose : bool or int, default False |
| 58 | + Amount of verbosity. 0/False is silent. |
| 59 | +
|
| 60 | + References |
| 61 | + ---------- |
| 62 | + .. [1] Olivier Fercoq and Pascal Bianchi, |
| 63 | + "A Coordinate-Descent Primal-Dual Algorithm with Large Step Size and Possibly |
| 64 | + Nonseparable Functions", SIAM Journal on Optimization, 2020, |
| 65 | + https://epubs.siam.org/doi/10.1137/18M1168480, |
| 66 | + code: https://github.com/Badr-MOUFAD/Fercoq-Bianchi-solver |
| 67 | +
|
| 68 | + .. [2] Bertrand, Q. and Klopfenstein, Q. and Bannier, P.-A. and Gidel, G. |
| 69 | + and Massias, M. |
| 70 | + "Beyond L1: Faster and Better Sparse Models with skglm", NeurIPS, 2022 |
| 71 | + https://arxiv.org/abs/2204.07826 |
| 72 | + """ |
| 73 | + |
| 74 | + def __init__(self, max_iter=1000, max_epochs=1000, dual_init=None, |
| 75 | + p0=100, tol=1e-6, verbose=False): |
| 76 | + self.max_iter = max_iter |
| 77 | + self.max_epochs = max_epochs |
| 78 | + self.dual_init = dual_init |
| 79 | + self.p0 = p0 |
| 80 | + self.tol = tol |
| 81 | + self.verbose = verbose |
| 82 | + |
| 83 | + def solve(self, X, y, datafit_, penalty_, w_init=None, Xw_init=None): |
| 84 | + if issparse(X): |
| 85 | + raise ValueError("Sparse matrices are not yet support in PDCD_WS solver.") |
| 86 | + |
| 87 | + datafit, penalty = PDCD_WS._validate_init(datafit_, penalty_) |
| 88 | + n_samples, n_features = X.shape |
| 89 | + |
| 90 | + # init steps |
| 91 | + # Despite violating the conditions mentioned in [1] |
| 92 | + # this choice of steps yield in practice a convergent algorithm |
| 93 | + # with better speed of convergence |
| 94 | + dual_step = 1 / norm(X, ord=2) |
| 95 | + primal_steps = 1 / norm(X, axis=0, ord=2) |
| 96 | + |
| 97 | + # primal vars |
| 98 | + w = np.zeros(n_features) if w_init is None else w_init |
| 99 | + Xw = np.zeros(n_samples) if Xw_init is None else Xw_init |
| 100 | + |
| 101 | + # dual vars |
| 102 | + if self.dual_init is None: |
| 103 | + z = np.zeros(n_samples) |
| 104 | + z_bar = np.zeros(n_samples) |
| 105 | + else: |
| 106 | + z = self.dual_init.copy() |
| 107 | + z_bar = self.dual_init.copy() |
| 108 | + |
| 109 | + p_objs = [] |
| 110 | + stop_crit = 0. |
| 111 | + all_features = np.arange(n_features) |
| 112 | + |
| 113 | + for iteration in range(self.max_iter): |
| 114 | + |
| 115 | + # check convergence using fixed-point criteria on both dual and primal |
| 116 | + opts_primal = _scores_primal(X, w, z, penalty, primal_steps, all_features) |
| 117 | + opt_dual = _score_dual(y, z, Xw, datafit, dual_step) |
| 118 | + |
| 119 | + stop_crit = max(max(opts_primal), opt_dual) |
| 120 | + |
| 121 | + if self.verbose: |
| 122 | + current_p_obj = datafit.value(y, w, Xw) + penalty.value(w) |
| 123 | + print( |
| 124 | + f"Iteration {iteration+1}: {current_p_obj:.10f}, " |
| 125 | + f"stopping crit: {stop_crit:.2e}") |
| 126 | + |
| 127 | + if stop_crit <= self.tol: |
| 128 | + break |
| 129 | + |
| 130 | + # build ws |
| 131 | + gsupp_size = (w != 0).sum() |
| 132 | + ws_size = max(min(self.p0, n_features), |
| 133 | + min(n_features, 2 * gsupp_size)) |
| 134 | + |
| 135 | + # similar to np.argsort()[-ws_size:] but without full sort |
| 136 | + ws = np.argpartition(opts_primal, -ws_size)[-ws_size:] |
| 137 | + |
| 138 | + # solve sub problem |
| 139 | + # inplace update of w, Xw, z, z_bar |
| 140 | + PDCD_WS._solve_subproblem( |
| 141 | + y, X, w, Xw, z, z_bar, datafit, penalty, |
| 142 | + primal_steps, dual_step, ws, self.max_epochs, tol_in=0.3*stop_crit) |
| 143 | + |
| 144 | + current_p_obj = datafit.value(y, w, Xw) + penalty.value(w) |
| 145 | + p_objs.append(current_p_obj) |
| 146 | + else: |
| 147 | + warnings.warn( |
| 148 | + f"PDCD_WS did not converge for tol={self.tol:.3e} " |
| 149 | + f"and max_iter={self.max_iter}.\n" |
| 150 | + "Considering increasing `max_iter` or `tol`.", |
| 151 | + category=ConvergenceWarning |
| 152 | + ) |
| 153 | + |
| 154 | + return w, np.asarray(p_objs), stop_crit |
| 155 | + |
| 156 | + @staticmethod |
| 157 | + @njit |
| 158 | + def _solve_subproblem(y, X, w, Xw, z, z_bar, datafit, penalty, |
| 159 | + primal_steps, dual_step, ws, max_epochs, tol_in): |
| 160 | + n_features = X.shape[1] |
| 161 | + |
| 162 | + for epoch in range(max_epochs): |
| 163 | + |
| 164 | + for j in ws: |
| 165 | + # update primal |
| 166 | + old_w_j = w[j] |
| 167 | + pseudo_grad = X[:, j] @ (2 * z_bar - z) |
| 168 | + w[j] = penalty.prox_1d( |
| 169 | + old_w_j - primal_steps[j] * pseudo_grad, |
| 170 | + primal_steps[j], j) |
| 171 | + |
| 172 | + # keep Xw syncr with X @ w |
| 173 | + delta_w_j = w[j] - old_w_j |
| 174 | + if delta_w_j: |
| 175 | + Xw += delta_w_j * X[:, j] |
| 176 | + |
| 177 | + # update dual |
| 178 | + z_bar[:] = datafit.prox_conjugate(z + dual_step * Xw, |
| 179 | + dual_step, y) |
| 180 | + z += (z_bar - z) / n_features |
| 181 | + |
| 182 | + # check convergence using fixed-point criteria on both dual and primal |
| 183 | + if epoch % 10 == 0: |
| 184 | + opts_primal_in = _scores_primal(X, w, z, penalty, primal_steps, ws) |
| 185 | + opt_dual_in = _score_dual(y, z, Xw, datafit, dual_step) |
| 186 | + |
| 187 | + stop_crit_in = max(max(opts_primal_in), opt_dual_in) |
| 188 | + |
| 189 | + if stop_crit_in <= tol_in: |
| 190 | + break |
| 191 | + |
| 192 | + @staticmethod |
| 193 | + def _validate_init(datafit_, penalty_): |
| 194 | + # validate datafit |
| 195 | + missing_attrs = [] |
| 196 | + for attr in ('prox_conjugate', 'subdiff_distance'): |
| 197 | + if not hasattr(datafit_, attr): |
| 198 | + missing_attrs.append(f"`{attr}`") |
| 199 | + |
| 200 | + if len(missing_attrs): |
| 201 | + raise AttributeError( |
| 202 | + "Datafit is not compatible with PDCD_WS solver.\n" |
| 203 | + "Datafit must implement `prox_conjugate` and `subdiff_distance`.\n" |
| 204 | + f"Missing {' and '.join(missing_attrs)}." |
| 205 | + ) |
| 206 | + |
| 207 | + # jit compile classes |
| 208 | + compiled_datafit = compiled_clone(datafit_) |
| 209 | + compiled_penalty = compiled_clone(penalty_) |
| 210 | + |
| 211 | + return compiled_datafit, compiled_penalty |
| 212 | + |
| 213 | + |
| 214 | +@njit |
| 215 | +def _scores_primal(X, w, z, penalty, primal_steps, ws): |
| 216 | + scores_ws = np.zeros(len(ws)) |
| 217 | + |
| 218 | + for idx, j in enumerate(ws): |
| 219 | + next_w_j = penalty.prox_1d(w[j] - primal_steps[j] * X[:, j] @ z, |
| 220 | + primal_steps[j], j) |
| 221 | + scores_ws[idx] = abs(w[j] - next_w_j) |
| 222 | + |
| 223 | + return scores_ws |
| 224 | + |
| 225 | + |
| 226 | +@njit |
| 227 | +def _score_dual(y, z, Xw, datafit, dual_step): |
| 228 | + next_z = datafit.prox_conjugate(z + dual_step * Xw, |
| 229 | + dual_step, y) |
| 230 | + return norm(z - next_z, ord=np.inf) |
0 commit comments