Skip to content

Commit 3e3df79

Browse files
committed
add Reweighted GLasso regularization path example
1 parent 5696d40 commit 3e3df79

File tree

1 file changed

+143
-0
lines changed

1 file changed

+143
-0
lines changed
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
import numpy as np
2+
from numpy.linalg import norm
3+
import matplotlib.pyplot as plt
4+
from sklearn.metrics import f1_score
5+
from sklearn.datasets import make_sparse_spd_matrix
6+
from sklearn.utils import check_random_state
7+
8+
from skglm.estimators import GraphicalLasso
9+
from skglm.estimators import AdaptiveGraphicalLasso
10+
11+
# Data
12+
p = 100
13+
n = 1000
14+
rng = check_random_state(0)
15+
Theta_true = make_sparse_spd_matrix(
16+
p,
17+
alpha=0.9,
18+
random_state=rng)
19+
20+
Theta_true += 0.1*np.eye(p)
21+
Sigma_true = np.linalg.pinv(Theta_true, hermitian=True)
22+
X = rng.multivariate_normal(
23+
mean=np.zeros(p),
24+
cov=Sigma_true,
25+
size=n,
26+
)
27+
28+
S = np.cov(X, bias=True, rowvar=False)
29+
S_cpy = np.copy(S)
30+
np.fill_diagonal(S_cpy, 0.)
31+
alpha_max = np.max(np.abs(S_cpy))
32+
33+
alphas = alpha_max*np.geomspace(1, 1e-4, num=10)
34+
35+
36+
penalties = [
37+
"L1",
38+
"R-L1 (log)",
39+
"R-L1 (L0.5)",
40+
"R-L1 (MCP)",
41+
]
42+
n_reweights = 5
43+
models_tol = 1e-4
44+
models = [
45+
GraphicalLasso(algo="primal",
46+
warm_start=True,
47+
tol=models_tol),
48+
AdaptiveGraphicalLasso(warm_start=True,
49+
strategy="log",
50+
n_reweights=n_reweights,
51+
tol=models_tol),
52+
AdaptiveGraphicalLasso(warm_start=True,
53+
strategy="sqrt",
54+
n_reweights=n_reweights,
55+
tol=models_tol),
56+
AdaptiveGraphicalLasso(warm_start=True,
57+
strategy="mcp",
58+
n_reweights=n_reweights,
59+
tol=models_tol),
60+
]
61+
62+
my_glasso_nmses = {penalty: [] for penalty in penalties}
63+
my_glasso_f1_scores = {penalty: [] for penalty in penalties}
64+
65+
sk_glasso_nmses = []
66+
sk_glasso_f1_scores = []
67+
68+
69+
for i, (penalty, model) in enumerate(zip(penalties, models)):
70+
print(penalty)
71+
for alpha_idx, alpha in enumerate(alphas):
72+
print(f"======= alpha {alpha_idx+1}/{len(alphas)} =======")
73+
model.alpha = alpha
74+
model.fit(S)
75+
Theta = model.precision_
76+
77+
my_nmse = norm(Theta - Theta_true)**2 / norm(Theta_true)**2
78+
79+
my_f1_score = f1_score(Theta.flatten() != 0.,
80+
Theta_true.flatten() != 0.)
81+
print(f"NMSE: {my_nmse:.3f}")
82+
print(f"F1 : {my_f1_score:.3f}")
83+
84+
my_glasso_nmses[penalty].append(my_nmse)
85+
my_glasso_f1_scores[penalty].append(my_f1_score)
86+
87+
88+
plt.close('all')
89+
fig, ax = plt.subplots(2, 1, sharex=True, figsize=(
90+
[6.11, 3.91]), layout="constrained")
91+
cmap = plt.get_cmap("tab10")
92+
for i, penalty in enumerate(penalties):
93+
94+
ax[0].semilogx(alphas/alpha_max,
95+
my_glasso_nmses[penalty],
96+
color=cmap(i),
97+
linewidth=2.,
98+
label=penalty)
99+
min_nmse = np.argmin(my_glasso_nmses[penalty])
100+
ax[0].vlines(
101+
x=alphas[min_nmse] / alphas[0],
102+
ymin=0,
103+
ymax=np.min(my_glasso_nmses[penalty]),
104+
linestyle='--',
105+
color=cmap(i))
106+
line0 = ax[0].plot(
107+
[alphas[min_nmse] / alphas[0]],
108+
0,
109+
clip_on=False,
110+
marker='X',
111+
color=cmap(i),
112+
markersize=12)
113+
114+
ax[1].semilogx(alphas/alpha_max,
115+
my_glasso_f1_scores[penalty],
116+
linewidth=2.,
117+
color=cmap(i))
118+
max_f1 = np.argmax(my_glasso_f1_scores[penalty])
119+
ax[1].vlines(
120+
x=alphas[max_f1] / alphas[0],
121+
ymin=0,
122+
ymax=np.max(my_glasso_f1_scores[penalty]),
123+
linestyle='--',
124+
color=cmap(i))
125+
line1 = ax[1].plot(
126+
[alphas[max_f1] / alphas[0]],
127+
0,
128+
clip_on=False,
129+
marker='X',
130+
markersize=12,
131+
color=cmap(i))
132+
133+
134+
ax[0].set_title(f"{p=},{n=}", fontsize=18)
135+
ax[0].set_ylabel("NMSE", fontsize=18)
136+
ax[1].set_ylabel("F1 score", fontsize=18)
137+
ax[1].set_xlabel(f"$\lambda / \lambda_\mathrm{{max}}$", fontsize=18)
138+
139+
ax[0].legend(fontsize=14)
140+
ax[0].grid(which='both', alpha=0.9)
141+
ax[1].grid(which='both', alpha=0.9)
142+
# plt.savefig(f"./non_convex_p{p}_n{n}.pdf")
143+
plt.show(block=False)

0 commit comments

Comments
 (0)