55import matplotlib .colors as mcolors
66from skglm .penalties .separable import LogSumPenalty
77from sklearn .datasets import make_sparse_spd_matrix
8- from skglm .utils .data import make_dummy_covariance_data
98import matplotlib .pyplot as plt
109import numpy as np
1110from scipy .linalg import pinvh
1514
1615
1716class GraphicalLasso ():
18- """ A first-order BCD Graphical Lasso solver implementing the GLasso algorithm
19- described in Friedman et al., 2008 and the P-GLasso algorithm described in
20- Mazumder et al., 2012."""
17+ """A first-order BCD Graphical Lasso solver.
18+
19+ Implementing the GLasso algorithm described in Friedman et al., 2008 and
20+ the P-GLasso algorithm described in Mazumder et al., 2012.
21+ """
2122
2223 def __init__ (self ,
2324 alpha = 1. ,
@@ -168,8 +169,26 @@ def fit(self, S):
168169
169170
170171class AdaptiveGraphicalLassoPenalty ():
171- """ An adaptive version of the Graphical Lasso that solves non-convex penalty
172- variations using the reweighting strategy from Candès et al., 2007."""
172+ """An adaptive version of the Graphical Lasso with non-convex penalties.
173+
174+ Solves non-convex penalty variations using the reweighting strategy
175+ from Candès et al., 2007.
176+
177+ Parameters
178+ ----------
179+ alpha : float, default=1.0
180+ Regularization parameter controlling sparsity.
181+ n_reweights : int, default=5
182+ Number of reweighting iterations.
183+ max_iter : int, default=1000
184+ Maximum iterations for inner solver.
185+ tol : float, default=1e-8
186+ Convergence tolerance.
187+ warm_start : bool, default=False
188+ Whether to use warm start.
189+ penalty : Penalty object, default=L0_5(1.)
190+ Non-convex penalty function to use for reweighting.
191+ """
173192
174193 def __init__ (
175194 self ,
@@ -190,7 +209,7 @@ def __init__(
190209 self .penalty = penalty
191210
192211 def fit (self , S ):
193- """ Fit the AdaptiveGraphicalLasso model on the empirical covariance matrix S."""
212+ """Fit the AdaptiveGraphicalLasso model on the empirical covariance matrix S."""
194213 glasso = GraphicalLasso (
195214 alpha = self .alpha ,
196215 algo = "primal" ,
@@ -212,7 +231,8 @@ def fit(self, S):
212231 )
213232
214233 print (
215- f"Min/Max Weights after penalty derivative: { Weights .min ():.2e} , { Weights .max ():.2e} " )
234+ f"Min/Max Weights after penalty derivative: "
235+ f"{ Weights .min ():.2e} , { Weights .max ():.2e} " )
216236
217237 self .n_iter_ .append (glasso .n_iter_ )
218238 # TODO print losses for original problem?
@@ -222,15 +242,34 @@ def fit(self, S):
222242 self .covariance_ = glasso .covariance_
223243 if not np .isclose (self .alpha , self .penalty .alpha ):
224244 print (
225- f"Alpha mismatch: GLasso alpha = { self .alpha } , Penalty alpha = { self .penalty .alpha } " )
245+ f"Alpha mismatch: GLasso alpha = { self .alpha } , "
246+ f"Penalty alpha = { self .penalty .alpha } " )
226247 else :
227248 print (f"Alpha values match: { self .alpha } " )
228249 return self
229250
230251
231252class AdaptiveGraphicalLasso ():
232- """ An adaptive version of the Graphical Lasso that solves non-convex penalty
233- variations using the reweighting strategy from Candès et al., 2007."""
253+ """An adaptive version of the Graphical Lasso with non-convex penalties.
254+
255+ Solves non-convex penalty variations using the reweighting strategy
256+ from Candès et al., 2007.
257+
258+ Parameters
259+ ----------
260+ alpha : float, default=1.0
261+ Regularization parameter controlling sparsity.
262+ strategy : str, default="log"
263+ Reweighting strategy: "log", "sqrt", or "mcp".
264+ n_reweights : int, default=5
265+ Number of reweighting iterations.
266+ max_iter : int, default=1000
267+ Maximum iterations for inner solver.
268+ tol : float, default=1e-8
269+ Convergence tolerance.
270+ warm_start : bool, default=False
271+ Whether to use warm start.
272+ """
234273
235274 def __init__ (
236275 self ,
@@ -271,6 +310,7 @@ def fit(self, S):
271310
272311
273312def update_weights (Theta , alpha , strategy = "log" ):
313+ """Update weights for adaptive graphical lasso based on strategy."""
274314 if strategy == "log" :
275315 return 1 / (np .abs (Theta ) + 1e-10 )
276316 elif strategy == "sqrt" :
@@ -345,7 +385,8 @@ def generate_problem(dim=20, n_samples=100, seed=42):
345385 # Compare the two estimated models
346386 rel_diff_between_models = frobenius_norm_diff (Theta_penalty , Theta_strategy )
347387 print (
348- f"\n Frobenius norm relative difference between models: { rel_diff_between_models :.2e} " )
388+ f"\n Frobenius norm relative difference between models: "
389+ f"{ rel_diff_between_models :.2e} " )
349390 print (" Matrices are close?" , np .allclose (
350391 Theta_penalty , Theta_strategy , atol = 1e-4 ))
351392
0 commit comments