|
| 1 | +# License: BSD 3 clause |
| 2 | + |
| 3 | +import numpy as np |
| 4 | +from scipy.linalg import pinvh |
| 5 | + |
| 6 | +from skglm.solvers.gram_cd import barebones_cd_gram |
| 7 | + |
| 8 | + |
| 9 | +class GraphicalLasso(): |
| 10 | + """ A first-order BCD Graphical Lasso solver implementing the GLasso algorithm |
| 11 | + described in Friedman et al., 2008 and the P-GLasso algorithm described in |
| 12 | + Mazumder et al., 2012.""" |
| 13 | + |
| 14 | + def __init__(self, |
| 15 | + alpha=1., |
| 16 | + weights=None, |
| 17 | + algo="dual", |
| 18 | + max_iter=100, |
| 19 | + tol=1e-8, |
| 20 | + warm_start=False, |
| 21 | + inner_tol=1e-4, |
| 22 | + verbose=False |
| 23 | + ): |
| 24 | + self.alpha = alpha |
| 25 | + self.weights = weights |
| 26 | + self.algo = algo |
| 27 | + self.max_iter = max_iter |
| 28 | + self.tol = tol |
| 29 | + self.warm_start = warm_start |
| 30 | + self.inner_tol = inner_tol |
| 31 | + self.verbose = verbose |
| 32 | + |
| 33 | + def fit(self, S): |
| 34 | + p = S.shape[-1] |
| 35 | + indices = np.arange(p) |
| 36 | + |
| 37 | + if self.weights is None: |
| 38 | + Weights = np.ones((p, p)) |
| 39 | + else: |
| 40 | + Weights = self.weights |
| 41 | + if not np.allclose(Weights, Weights.T): |
| 42 | + raise ValueError("Weights should be symmetric.") |
| 43 | + |
| 44 | + if self.warm_start and hasattr(self, "precision_"): |
| 45 | + if self.algo == "dual": |
| 46 | + raise ValueError( |
| 47 | + "dual does not support warm start for now.") |
| 48 | + Theta = self.precision_ |
| 49 | + W = self.covariance_ |
| 50 | + else: |
| 51 | + W = S.copy() |
| 52 | + W *= 0.95 |
| 53 | + diagonal = S.flat[:: p + 1] |
| 54 | + W.flat[:: p + 1] = diagonal |
| 55 | + Theta = pinvh(W) |
| 56 | + |
| 57 | + W_11 = np.copy(W[1:, 1:], order="C") |
| 58 | + eps = np.finfo(np.float64).eps |
| 59 | + it = 0 |
| 60 | + Theta_old = Theta.copy() |
| 61 | + |
| 62 | + for it in range(self.max_iter): |
| 63 | + Theta_old = Theta.copy() |
| 64 | + |
| 65 | + for col in range(p): |
| 66 | + if self.algo == "primal": |
| 67 | + indices_minus_col = np.concatenate( |
| 68 | + [indices[:col], indices[col + 1:]]) |
| 69 | + _11 = indices_minus_col[:, None], indices_minus_col[None] |
| 70 | + _12 = indices_minus_col, col |
| 71 | + _22 = col, col |
| 72 | + |
| 73 | + elif self.algo == "dual": |
| 74 | + if col > 0: |
| 75 | + di = col - 1 |
| 76 | + W_11[di] = W[di][indices != col] |
| 77 | + W_11[:, di] = W[:, di][indices != col] |
| 78 | + else: |
| 79 | + W_11[:] = W[1:, 1:] |
| 80 | + |
| 81 | + s_12 = S[col, indices != col] |
| 82 | + |
| 83 | + if self.algo == "dual": |
| 84 | + beta_init = (Theta[indices != col, col] / |
| 85 | + (Theta[col, col] + 1000 * eps)) |
| 86 | + Q = W_11 |
| 87 | + |
| 88 | + elif self.algo == "primal": |
| 89 | + inv_Theta_11 = (W[_11] - |
| 90 | + np.outer(W[_12], |
| 91 | + W[_12])/W[_22]) |
| 92 | + Q = inv_Theta_11 |
| 93 | + beta_init = Theta[indices != col, col] * S[col, col] |
| 94 | + else: |
| 95 | + raise ValueError(f"Unsupported algo {self.algo}") |
| 96 | + |
| 97 | + beta = barebones_cd_gram( |
| 98 | + Q, |
| 99 | + s_12, |
| 100 | + x=beta_init, |
| 101 | + alpha=self.alpha, |
| 102 | + weights=Weights[indices != col, col], |
| 103 | + tol=self.inner_tol, |
| 104 | + max_iter=self.max_iter, |
| 105 | + ) |
| 106 | + |
| 107 | + if self.algo == "dual": |
| 108 | + w_12 = -np.dot(W_11, beta) |
| 109 | + W[col, indices != col] = w_12 |
| 110 | + W[indices != col, col] = w_12 |
| 111 | + |
| 112 | + Theta[col, col] = 1 / \ |
| 113 | + (W[col, col] + np.dot(beta, w_12)) |
| 114 | + Theta[indices != col, col] = beta*Theta[col, col] |
| 115 | + Theta[col, indices != col] = beta*Theta[col, col] |
| 116 | + |
| 117 | + else: # primal |
| 118 | + s_22 = S[col, col] |
| 119 | + |
| 120 | + # Updating Theta |
| 121 | + theta_12 = beta / s_22 |
| 122 | + Theta[indices != col, col] = theta_12 |
| 123 | + Theta[col, indices != col] = theta_12 |
| 124 | + Theta[col, col] = (1/s_22 + |
| 125 | + theta_12 @ |
| 126 | + inv_Theta_11 @ |
| 127 | + theta_12) |
| 128 | + theta_22 = Theta[col, col] |
| 129 | + |
| 130 | + # Updating W |
| 131 | + W[col, col] = (1/(theta_22 - |
| 132 | + theta_12 @ |
| 133 | + inv_Theta_11 @ |
| 134 | + theta_12)) |
| 135 | + w_22 = W[col, col] |
| 136 | + |
| 137 | + w_12 = (-w_22 * inv_Theta_11 @ theta_12) |
| 138 | + W[indices != col, col] = w_12 |
| 139 | + W[col, indices != col] = w_12 |
| 140 | + |
| 141 | + # Maybe W_11 can be done smarter ? |
| 142 | + W[_11] = (inv_Theta_11 + |
| 143 | + np.outer(w_12, |
| 144 | + w_12)/w_22) |
| 145 | + |
| 146 | + if np.linalg.norm(Theta - Theta_old) < self.tol: |
| 147 | + if self.verbose: |
| 148 | + print(f"Weighted Glasso converged at CD epoch {it + 1}") |
| 149 | + break |
| 150 | + else: |
| 151 | + if self.verbose: |
| 152 | + print( |
| 153 | + f"Not converged at epoch {it + 1}, " |
| 154 | + f"diff={np.linalg.norm(Theta - Theta_old):.2e}" |
| 155 | + ) |
| 156 | + self.precision_, self.covariance_ = Theta, W |
| 157 | + self.n_iter_ = it + 1 |
| 158 | + |
| 159 | + return self |
| 160 | + |
| 161 | + |
| 162 | +class AdaptiveGraphicalLasso(): |
| 163 | + """ An adaptive version of the Graphical Lasso that solves non-convex penalty |
| 164 | + variations using the reweighting strategy from Candès et al., 2007.""" |
| 165 | + |
| 166 | + def __init__( |
| 167 | + self, |
| 168 | + alpha=1., |
| 169 | + strategy="log", |
| 170 | + n_reweights=5, |
| 171 | + max_iter=1000, |
| 172 | + tol=1e-8, |
| 173 | + warm_start=False, |
| 174 | + ): |
| 175 | + self.alpha = alpha |
| 176 | + self.strategy = strategy |
| 177 | + self.n_reweights = n_reweights |
| 178 | + self.max_iter = max_iter |
| 179 | + self.tol = tol |
| 180 | + self.warm_start = warm_start |
| 181 | + |
| 182 | + def fit(self, S): |
| 183 | + glasso = GraphicalLasso( |
| 184 | + alpha=self.alpha, |
| 185 | + algo="primal", |
| 186 | + max_iter=self.max_iter, |
| 187 | + tol=self.tol, |
| 188 | + warm_start=True) |
| 189 | + Weights = np.ones(S.shape) |
| 190 | + self.n_iter_ = [] |
| 191 | + for it in range(self.n_reweights): |
| 192 | + glasso.weights = Weights |
| 193 | + glasso.fit(S) |
| 194 | + Theta = glasso.precision_ |
| 195 | + Weights = update_weights(Theta, self.alpha, strategy=self.strategy) |
| 196 | + self.n_iter_.append(glasso.n_iter_) |
| 197 | + # TODO print losses for original problem? |
| 198 | + glasso.covariance_ = np.linalg.pinv(Theta, hermitian=True) |
| 199 | + self.precision_ = glasso.precision_ |
| 200 | + self.covariance_ = glasso.covariance_ |
| 201 | + return self |
| 202 | + |
| 203 | + |
| 204 | +def update_weights(Theta, alpha, strategy="log"): |
| 205 | + if strategy == "log": |
| 206 | + return 1/(np.abs(Theta) + 1e-10) |
| 207 | + elif strategy == "sqrt": |
| 208 | + return 1/(2*np.sqrt(np.abs(Theta)) + 1e-10) |
| 209 | + elif strategy == "mcp": |
| 210 | + gamma = 3. |
| 211 | + Weights = np.zeros_like(Theta) |
| 212 | + Weights[np.abs(Theta) |
| 213 | + < gamma*alpha] = (alpha - |
| 214 | + np.abs(Theta[np.abs(Theta) |
| 215 | + < gamma*alpha])/gamma) |
| 216 | + return Weights |
| 217 | + else: |
| 218 | + raise ValueError(f"Unknown strategy {strategy}") |
0 commit comments