Commit 4538002 (parent: ca6960f)

fix glasso solver issues, move estimator to own file, create dedicated tests file

File tree

7 files changed: +363 −372 lines

examples/plot_reweighted_glasso_reg_path.py

Lines changed: 3 additions & 4 deletions

@@ -13,14 +13,13 @@
 import matplotlib.pyplot as plt
 from sklearn.metrics import f1_score

-from skglm.utils.data import generate_GraphicalLasso_data
-from skglm.estimators import GraphicalLasso
-from skglm.estimators import AdaptiveGraphicalLasso
+from skglm.covariance import GraphicalLasso, AdaptiveGraphicalLasso
+from skglm.utils.data import make_dummy_covariance_data


 p = 100
 n = 1000
-S, Theta_true, alpha_max = generate_GraphicalLasso_data(n, p)
+S, Theta_true, alpha_max = make_dummy_covariance_data(n, p)
 alphas = alpha_max*np.geomspace(1, 1e-4, num=10)

 penalties = [
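
For context, the rest of the example script sweeps the penalties over the alphas grid; a minimal sketch of how the renamed helper and estimator fit together (variable names are illustrative, and f1_score is presumably used to compare estimated and true supports, which is an assumption based on the imports above) could be:

    # Hypothetical illustration, not part of the committed example script.
    from sklearn.metrics import f1_score
    from skglm.covariance import GraphicalLasso

    # Fit at one alpha on the regularization path and score support recovery.
    model = GraphicalLasso(alpha=alphas[5], algo="primal").fit(S)
    f1 = f1_score((Theta_true != 0).ravel(), (model.precision_ != 0).ravel())
    print(f"F1 of support recovery at alpha={alphas[5]:.2e}: {f1:.3f}")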

skglm/covariance.py

Lines changed: 218 additions & 0 deletions (new file)

@@ -0,0 +1,218 @@
# License: BSD 3 clause

import numpy as np
from scipy.linalg import pinvh

from skglm.solvers.gram_cd import barebones_cd_gram


class GraphicalLasso():
    """A first-order BCD Graphical Lasso solver implementing the GLasso algorithm
    described in Friedman et al., 2008 and the P-GLasso algorithm described in
    Mazumder et al., 2012."""

    def __init__(self,
                 alpha=1.,
                 weights=None,
                 algo="dual",
                 max_iter=100,
                 tol=1e-8,
                 warm_start=False,
                 inner_tol=1e-4,
                 verbose=False,
                 ):
        self.alpha = alpha
        self.weights = weights
        self.algo = algo
        self.max_iter = max_iter
        self.tol = tol
        self.warm_start = warm_start
        self.inner_tol = inner_tol
        self.verbose = verbose

    def fit(self, S):
        p = S.shape[-1]
        indices = np.arange(p)

        if self.weights is None:
            Weights = np.ones((p, p))
        else:
            Weights = self.weights
            if not np.allclose(Weights, Weights.T):
                raise ValueError("Weights should be symmetric.")

        if self.warm_start and hasattr(self, "precision_"):
            if self.algo == "dual":
                raise ValueError(
                    "dual does not support warm start for now.")
            Theta = self.precision_
            W = self.covariance_
        else:
            W = S.copy()
            W *= 0.95
            diagonal = S.flat[::p + 1]
            W.flat[::p + 1] = diagonal
            Theta = pinvh(W)

        W_11 = np.copy(W[1:, 1:], order="C")
        eps = np.finfo(np.float64).eps
        it = 0
        Theta_old = Theta.copy()

        for it in range(self.max_iter):
            Theta_old = Theta.copy()

            for col in range(p):
                if self.algo == "primal":
                    indices_minus_col = np.concatenate(
                        [indices[:col], indices[col + 1:]])
                    _11 = indices_minus_col[:, None], indices_minus_col[None]
                    _12 = indices_minus_col, col
                    _22 = col, col

                elif self.algo == "dual":
                    if col > 0:
                        di = col - 1
                        W_11[di] = W[di][indices != col]
                        W_11[:, di] = W[:, di][indices != col]
                    else:
                        W_11[:] = W[1:, 1:]

                s_12 = S[col, indices != col]

                if self.algo == "dual":
                    beta_init = (Theta[indices != col, col] /
                                 (Theta[col, col] + 1000 * eps))
                    Q = W_11
                elif self.algo == "primal":
                    inv_Theta_11 = (W[_11] -
                                    np.outer(W[_12], W[_12]) / W[_22])
                    Q = inv_Theta_11
                    beta_init = Theta[indices != col, col] * S[col, col]
                else:
                    raise ValueError(f"Unsupported algo {self.algo}")

                beta = barebones_cd_gram(
                    Q,
                    s_12,
                    x=beta_init,
                    alpha=self.alpha,
                    weights=Weights[indices != col, col],
                    tol=self.inner_tol,
                    max_iter=self.max_iter,
                )

                if self.algo == "dual":
                    w_12 = -np.dot(W_11, beta)
                    W[col, indices != col] = w_12
                    W[indices != col, col] = w_12

                    Theta[col, col] = 1 / (W[col, col] + np.dot(beta, w_12))
                    Theta[indices != col, col] = beta * Theta[col, col]
                    Theta[col, indices != col] = beta * Theta[col, col]

                else:  # primal
                    s_22 = S[col, col]

                    # Updating Theta
                    theta_12 = beta / s_22
                    Theta[indices != col, col] = theta_12
                    Theta[col, indices != col] = theta_12
                    Theta[col, col] = (1 / s_22 +
                                       theta_12 @ inv_Theta_11 @ theta_12)
                    theta_22 = Theta[col, col]

                    # Updating W
                    W[col, col] = 1 / (theta_22 -
                                       theta_12 @ inv_Theta_11 @ theta_12)
                    w_22 = W[col, col]

                    w_12 = -w_22 * inv_Theta_11 @ theta_12
                    W[indices != col, col] = w_12
                    W[col, indices != col] = w_12

                    # Maybe W_11 can be done smarter?
                    W[_11] = inv_Theta_11 + np.outer(w_12, w_12) / w_22

            if np.linalg.norm(Theta - Theta_old) < self.tol:
                if self.verbose:
                    print(f"Weighted Glasso converged at CD epoch {it + 1}")
                break
        else:
            if self.verbose:
                print(
                    f"Not converged at epoch {it + 1}, "
                    f"diff={np.linalg.norm(Theta - Theta_old):.2e}"
                )
        self.precision_, self.covariance_ = Theta, W
        self.n_iter_ = it + 1

        return self


class AdaptiveGraphicalLasso():
    """An adaptive version of the Graphical Lasso that solves non-convex penalty
    variations using the reweighting strategy from Candès et al., 2007."""

    def __init__(
        self,
        alpha=1.,
        strategy="log",
        n_reweights=5,
        max_iter=1000,
        tol=1e-8,
        warm_start=False,
    ):
        self.alpha = alpha
        self.strategy = strategy
        self.n_reweights = n_reweights
        self.max_iter = max_iter
        self.tol = tol
        self.warm_start = warm_start

    def fit(self, S):
        glasso = GraphicalLasso(
            alpha=self.alpha,
            algo="primal",
            max_iter=self.max_iter,
            tol=self.tol,
            warm_start=True)
        Weights = np.ones(S.shape)
        self.n_iter_ = []
        for it in range(self.n_reweights):
            glasso.weights = Weights
            glasso.fit(S)
            Theta = glasso.precision_
            Weights = update_weights(Theta, self.alpha, strategy=self.strategy)
            self.n_iter_.append(glasso.n_iter_)
            # TODO print losses for original problem?
        glasso.covariance_ = np.linalg.pinv(Theta, hermitian=True)
        self.precision_ = glasso.precision_
        self.covariance_ = glasso.covariance_
        return self


def update_weights(Theta, alpha, strategy="log"):
    if strategy == "log":
        return 1 / (np.abs(Theta) + 1e-10)
    elif strategy == "sqrt":
        return 1 / (2 * np.sqrt(np.abs(Theta)) + 1e-10)
    elif strategy == "mcp":
        gamma = 3.
        Weights = np.zeros_like(Theta)
        Weights[np.abs(Theta) < gamma * alpha] = (
            alpha - np.abs(Theta[np.abs(Theta) < gamma * alpha]) / gamma)
        return Weights
    else:
        raise ValueError(f"Unknown strategy {strategy}")

0 commit comments
