Skip to content

Commit ca6960f

Browse files
committed
fix issues in glasso reg path example
1 parent c9da9d4 commit ca6960f

File tree

2 files changed

+58
-194
lines changed

2 files changed

+58
-194
lines changed

examples/plot_graphical_lasso.py

Lines changed: 0 additions & 113 deletions
This file was deleted.
Lines changed: 58 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,33 @@
1+
# Authors: Can Pouliquen
2+
# Mathurin Massias
3+
"""
4+
=======================================================================
5+
Regularization paths for the Graphical Lasso and its Adaptive variation
6+
=======================================================================
7+
Highlight the importance of using non-convex regularization for improved performance,
8+
solved using the reweighting strategy.
9+
"""
10+
111
import numpy as np
212
from numpy.linalg import norm
313
import matplotlib.pyplot as plt
414
from sklearn.metrics import f1_score
5-
from sklearn.datasets import make_sparse_spd_matrix
6-
from sklearn.utils import check_random_state
715

16+
from skglm.utils.data import generate_GraphicalLasso_data
817
from skglm.estimators import GraphicalLasso
918
from skglm.estimators import AdaptiveGraphicalLasso
1019

11-
# Data
20+
1221
p = 100
1322
n = 1000
14-
rng = check_random_state(0)
15-
Theta_true = make_sparse_spd_matrix(
16-
p,
17-
alpha=0.9,
18-
random_state=rng)
19-
20-
Theta_true += 0.1*np.eye(p)
21-
Sigma_true = np.linalg.pinv(Theta_true, hermitian=True)
22-
X = rng.multivariate_normal(
23-
mean=np.zeros(p),
24-
cov=Sigma_true,
25-
size=n,
26-
)
27-
28-
S = np.cov(X, bias=True, rowvar=False)
29-
S_cpy = np.copy(S)
30-
np.fill_diagonal(S_cpy, 0.)
31-
alpha_max = np.max(np.abs(S_cpy))
32-
23+
S, Theta_true, alpha_max = generate_GraphicalLasso_data(n, p)
3324
alphas = alpha_max*np.geomspace(1, 1e-4, num=10)
3425

35-
3626
penalties = [
3727
"L1",
38-
"R-L1 (log)",
39-
"R-L1 (L0.5)",
40-
"R-L1 (MCP)",
28+
"Log",
29+
"L0.5",
30+
"MCP",
4131
]
4232
n_reweights = 5
4333
models_tol = 1e-4
@@ -67,9 +57,8 @@
6757

6858

6959
for i, (penalty, model) in enumerate(zip(penalties, models)):
70-
print(penalty)
7160
for alpha_idx, alpha in enumerate(alphas):
72-
print(f"======= alpha {alpha_idx+1}/{len(alphas)} =======")
61+
print(f"======= {penalty} penalty, alpha {alpha_idx+1}/{len(alphas)} =======")
7362
model.alpha = alpha
7463
model.fit(S)
7564
Theta = model.precision_
@@ -78,66 +67,54 @@
7867

7968
my_f1_score = f1_score(Theta.flatten() != 0.,
8069
Theta_true.flatten() != 0.)
81-
print(f"NMSE: {my_nmse:.3f}")
82-
print(f"F1 : {my_f1_score:.3f}")
8370

8471
my_glasso_nmses[penalty].append(my_nmse)
8572
my_glasso_f1_scores[penalty].append(my_f1_score)
8673

8774

8875
plt.close('all')
89-
fig, ax = plt.subplots(2, 1, sharex=True, figsize=(
90-
[6.11, 3.91]), layout="constrained")
76+
fig, axarr = plt.subplots(2, 1, sharex=True, figsize=([6.11, 3.91]),
77+
layout="constrained")
9178
cmap = plt.get_cmap("tab10")
9279
for i, penalty in enumerate(penalties):
9380

94-
ax[0].semilogx(alphas/alpha_max,
95-
my_glasso_nmses[penalty],
96-
color=cmap(i),
97-
linewidth=2.,
98-
label=penalty)
99-
min_nmse = np.argmin(my_glasso_nmses[penalty])
100-
ax[0].vlines(
101-
x=alphas[min_nmse] / alphas[0],
102-
ymin=0,
103-
ymax=np.min(my_glasso_nmses[penalty]),
104-
linestyle='--',
105-
color=cmap(i))
106-
line0 = ax[0].plot(
107-
[alphas[min_nmse] / alphas[0]],
108-
0,
109-
clip_on=False,
110-
marker='X',
111-
color=cmap(i),
112-
markersize=12)
113-
114-
ax[1].semilogx(alphas/alpha_max,
115-
my_glasso_f1_scores[penalty],
116-
linewidth=2.,
117-
color=cmap(i))
118-
max_f1 = np.argmax(my_glasso_f1_scores[penalty])
119-
ax[1].vlines(
120-
x=alphas[max_f1] / alphas[0],
121-
ymin=0,
122-
ymax=np.max(my_glasso_f1_scores[penalty]),
123-
linestyle='--',
124-
color=cmap(i))
125-
line1 = ax[1].plot(
126-
[alphas[max_f1] / alphas[0]],
127-
0,
128-
clip_on=False,
129-
marker='X',
130-
markersize=12,
131-
color=cmap(i))
132-
133-
134-
ax[0].set_title(f"{p=},{n=}", fontsize=18)
135-
ax[0].set_ylabel("NMSE", fontsize=18)
136-
ax[1].set_ylabel("F1 score", fontsize=18)
137-
ax[1].set_xlabel(f"$\lambda / \lambda_\mathrm{{max}}$", fontsize=18)
138-
139-
ax[0].legend(fontsize=14)
140-
ax[0].grid(which='both', alpha=0.9)
141-
ax[1].grid(which='both', alpha=0.9)
142-
# plt.savefig(f"./non_convex_p{p}_n{n}.pdf")
81+
for j, ax in enumerate(axarr):
82+
83+
if j == 0:
84+
metric = my_glasso_nmses
85+
best_idx = np.argmin(metric[penalty])
86+
ystop = np.min(metric[penalty])
87+
else:
88+
metric = my_glasso_f1_scores
89+
best_idx = np.argmax(metric[penalty])
90+
ystop = np.max(metric[penalty])
91+
92+
ax.semilogx(alphas/alpha_max,
93+
metric[penalty],
94+
color=cmap(i),
95+
linewidth=2.,
96+
label=penalty)
97+
98+
ax.vlines(
99+
x=alphas[best_idx] / alphas[0],
100+
ymin=0,
101+
ymax=ystop,
102+
linestyle='--',
103+
color=cmap(i))
104+
line = ax.plot(
105+
[alphas[best_idx] / alphas[0]],
106+
0,
107+
clip_on=False,
108+
marker='X',
109+
color=cmap(i),
110+
markersize=12)
111+
112+
ax.grid(which='both', alpha=0.9)
113+
114+
axarr[0].legend(fontsize=14)
115+
axarr[0].set_title(f"{p=},{n=}", fontsize=18)
116+
axarr[0].set_ylabel("NMSE", fontsize=18)
117+
axarr[1].set_ylabel("F1 score", fontsize=18)
118+
axarr[1].set_xlabel(r"$\lambda / \lambda_\mathrm{{max}}$", fontsize=18)
119+
143120
plt.show(block=False)

0 commit comments

Comments
 (0)