import warnings
import numpy as np
from numba import njit
from scipy.sparse import issparse

from skglm.utils import AndersonAcceleration


def gram_cd_solver(X, y, penalty, max_iter=100, w_init=None,
                   use_acc=True, greedy_cd=True, tol=1e-4, verbose=False):
    r"""Run coordinate descent while keeping the gradients up-to-date with Gram updates.

    This solver should be used when n_features < n_samples, as it computes the
    (n_features, n_features) Gram matrix, which comes with an overhead. It is only
    suited to Quadratic datafits.

    It minimizes::

        1 / (2*n_samples) * norm(y - Xw)**2 + penalty(w)

    which can be rewritten, up to the constant term norm(y)**2 / (2*n_samples), as::

        w.T @ Q @ w / (2*n_samples) - q.T @ w / n_samples + penalty(w)

    where::

        Q = X.T @ X (Gram matrix), and q = X.T @ y

    Parameters
    ----------
    X : array or sparse CSC matrix, shape (n_samples, n_features)
        Design matrix.

    y : array, shape (n_samples,)
        Target vector.

    penalty : instance of BasePenalty
        Penalty object.

    max_iter : int, default 100
        Maximum number of iterations.

    w_init : array, shape (n_features,), default None
        Initial value of the coefficients.
        If set to None, a zero vector is used instead.

    use_acc : bool, default True
        If set to True, the iterates are extrapolated with Anderson acceleration,
        using the last 5 iterates.

    greedy_cd : bool, default True
        Use a greedy strategy to select features to update in coordinate descent epochs
        if set to True. A cyclic strategy is used otherwise.

    tol : float, default 1e-4
        Tolerance for convergence.

    verbose : bool or int, default False
        Amount of verbosity. 0/False is silent.

    Returns
    -------
    w : array, shape (n_features,)
        Solution that minimizes the problem defined by the quadratic datafit and penalty.

    p_objs_out : array, shape (n_iter,)
        The objective values at every outer iteration.

    stop_crit : float
        The value of the stopping criterion when the solver stops.
    """
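    # Expanding norm(y - Xw)**2 / (2*n_samples) gives
    #     w @ (X.T @ X) @ w / (2*n_samples) - (X.T @ y) @ w / n_samples
    #     + norm(y)**2 / (2*n_samples),
    # so only the scaled Gram matrix, the scaled X.T @ y, and the scaled norm of y
    # are needed below.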
    n_samples, n_features = X.shape

    if issparse(X):
        scaled_gram = X.T.dot(X)
        scaled_gram = scaled_gram.toarray() / n_samples
        scaled_Xty = X.T.dot(y) / n_samples
    else:
        scaled_gram = X.T @ X / n_samples
        scaled_Xty = X.T @ y / n_samples
    # TODO potential improvement: allow to pass scaled_gram (e.g. for path computation)

    scaled_y_norm2 = np.linalg.norm(y)**2 / (2 * n_samples)

    all_features = np.arange(n_features)
    stop_crit = np.inf  # prevent reference before assignment
    p_objs_out = []

    w = np.zeros(n_features) if w_init is None else w_init
    grad = -scaled_Xty if w_init is None else scaled_gram @ w_init - scaled_Xty
    opt = penalty.subdiff_distance(w, grad, all_features)

    if use_acc:
        if greedy_cd:
            warnings.warn(
                "Anderson acceleration does not work with greedy_cd, set use_acc=False",
                UserWarning)
        accelerator = AndersonAcceleration(K=5)
        w_acc = np.zeros(n_features)
        grad_acc = np.zeros(n_features)

    for t in range(max_iter):
        # check convergence
        stop_crit = np.max(opt)
        if verbose:
            p_obj = (0.5 * w @ (scaled_gram @ w) - scaled_Xty @ w +
                     scaled_y_norm2 + penalty.value(w))
            print(
                f"Iteration {t+1}: {p_obj:.10f}, "
                f"stopping crit: {stop_crit:.2e}"
            )

        if stop_crit <= tol:
            if verbose:
                print(f"Stopping criterion max violation: {stop_crit:.2e}")
            break

        # in-place update of w and grad
        opt = _gram_cd_epoch(scaled_gram, w, grad, penalty, greedy_cd)

        # perform Anderson extrapolation
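        # (the last K=5 iterates are combined into an extrapolated point, which is
        # kept below only if it decreases the objective, so acceleration never
        # degrades the iterate)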
        if use_acc:
            w_acc, grad_acc, is_extrapolated = accelerator.extrapolate(w, grad)

            if is_extrapolated:
                # omit constant term for comparison
                p_obj_acc = (0.5 * w_acc @ (scaled_gram @ w_acc) - scaled_Xty @ w_acc +
                             penalty.value(w_acc))
                p_obj = 0.5 * w @ (scaled_gram @ w) - scaled_Xty @ w + penalty.value(w)
                if p_obj_acc < p_obj:
                    w[:] = w_acc
                    grad[:] = grad_acc

        # store p_obj
        p_obj = (0.5 * w @ (scaled_gram @ w) - scaled_Xty @ w + scaled_y_norm2 +
                 penalty.value(w))
        p_objs_out.append(p_obj)
    return w, np.array(p_objs_out), stop_crit


@njit
def _gram_cd_epoch(scaled_gram, w, grad, penalty, greedy_cd):
    all_features = np.arange(len(w))
    for cd_iter in all_features:
        # select feature j
        if greedy_cd:
            opt = penalty.subdiff_distance(w, grad, all_features)
            j = np.argmax(opt)
        else:  # cyclic
            j = cd_iter

        # update w_j
        old_w_j = w[j]
        step = 1 / scaled_gram[j, j]  # 1 / lipschitz_j
        w[j] = penalty.prox_1d(old_w_j - step * grad[j], step, j)

        # gradient update with Gram matrix
        if w[j] != old_w_j:
            grad += (w[j] - old_w_j) * scaled_gram[:, j]

    return penalty.subdiff_distance(w, grad, all_features)
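

# Minimal usage sketch (illustrative only, not part of the solver itself): solve a
# Lasso-type problem with gram_cd_solver. It assumes skglm's L1 penalty and the
# compiled_clone helper from skglm.utils; exact names may differ across versions.
if __name__ == "__main__":
    from skglm.penalties import L1
    from skglm.utils import compiled_clone

    rng = np.random.default_rng(0)
    X_demo = rng.standard_normal((100, 20))
    y_demo = rng.standard_normal(100)

    # regularization strength as a fraction of alpha_max (smallest alpha giving w = 0)
    alpha_max = np.linalg.norm(X_demo.T @ y_demo, ord=np.inf) / len(y_demo)
    penalty_demo = compiled_clone(L1(alpha=0.1 * alpha_max))

    # cyclic CD with Anderson acceleration (greedy_cd=True would trigger the warning above)
    w_demo, p_objs_demo, crit_demo = gram_cd_solver(
        X_demo, y_demo, penalty_demo, greedy_cd=False, tol=1e-8, verbose=True)
    print("final stopping criterion:", crit_demo)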