Skip to content

Commit 16dbe4e

Browse files
fix linters
1 parent 264faed commit 16dbe4e

File tree

3 files changed: +59 additions, -14 deletions

skglm/covariance.py

Lines changed: 53 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import matplotlib.colors as mcolors
66
from skglm.penalties.separable import LogSumPenalty
77
from sklearn.datasets import make_sparse_spd_matrix
8-
from skglm.utils.data import make_dummy_covariance_data
98
import matplotlib.pyplot as plt
109
import numpy as np
1110
from scipy.linalg import pinvh
@@ -15,9 +14,11 @@
1514

1615

1716
class GraphicalLasso():
18-
""" A first-order BCD Graphical Lasso solver implementing the GLasso algorithm
19-
described in Friedman et al., 2008 and the P-GLasso algorithm described in
20-
Mazumder et al., 2012."""
17+
"""A first-order BCD Graphical Lasso solver.
18+
19+
Implementing the GLasso algorithm described in Friedman et al., 2008 and
20+
the P-GLasso algorithm described in Mazumder et al., 2012.
21+
"""
2122

2223
def __init__(self,
2324
alpha=1.,
@@ -168,8 +169,26 @@ def fit(self, S):
168169

169170

170171
class AdaptiveGraphicalLassoPenalty():
171-
""" An adaptive version of the Graphical Lasso that solves non-convex penalty
172-
variations using the reweighting strategy from Candès et al., 2007."""
172+
"""An adaptive version of the Graphical Lasso with non-convex penalties.
173+
174+
Solves non-convex penalty variations using the reweighting strategy
175+
from Candès et al., 2007.
176+
177+
Parameters
178+
----------
179+
alpha : float, default=1.0
180+
Regularization parameter controlling sparsity.
181+
n_reweights : int, default=5
182+
Number of reweighting iterations.
183+
max_iter : int, default=1000
184+
Maximum iterations for inner solver.
185+
tol : float, default=1e-8
186+
Convergence tolerance.
187+
warm_start : bool, default=False
188+
Whether to use warm start.
189+
penalty : Penalty object, default=L0_5(1.)
190+
Non-convex penalty function to use for reweighting.
191+
"""
173192

174193
def __init__(
175194
self,
@@ -190,7 +209,7 @@ def __init__(
190209
self.penalty = penalty
191210

192211
def fit(self, S):
193-
""" Fit the AdaptiveGraphicalLasso model on the empirical covariance matrix S."""
212+
"""Fit the AdaptiveGraphicalLasso model on the empirical covariance matrix S."""
194213
glasso = GraphicalLasso(
195214
alpha=self.alpha,
196215
algo="primal",
@@ -212,7 +231,8 @@ def fit(self, S):
212231
)
213232

214233
print(
215-
f"Min/Max Weights after penalty derivative: {Weights.min():.2e}, {Weights.max():.2e}")
234+
f"Min/Max Weights after penalty derivative: "
235+
f"{Weights.min():.2e}, {Weights.max():.2e}")
216236

217237
self.n_iter_.append(glasso.n_iter_)
218238
# TODO print losses for original problem?
@@ -222,15 +242,34 @@ def fit(self, S):
222242
self.covariance_ = glasso.covariance_
223243
if not np.isclose(self.alpha, self.penalty.alpha):
224244
print(
225-
f"Alpha mismatch: GLasso alpha = {self.alpha}, Penalty alpha = {self.penalty.alpha}")
245+
f"Alpha mismatch: GLasso alpha = {self.alpha}, "
246+
f"Penalty alpha = {self.penalty.alpha}")
226247
else:
227248
print(f"Alpha values match: {self.alpha}")
228249
return self
229250

230251

231252
class AdaptiveGraphicalLasso():
232-
""" An adaptive version of the Graphical Lasso that solves non-convex penalty
233-
variations using the reweighting strategy from Candès et al., 2007."""
253+
"""An adaptive version of the Graphical Lasso with non-convex penalties.
254+
255+
Solves non-convex penalty variations using the reweighting strategy
256+
from Candès et al., 2007.
257+
258+
Parameters
259+
----------
260+
alpha : float, default=1.0
261+
Regularization parameter controlling sparsity.
262+
strategy : str, default="log"
263+
Reweighting strategy: "log", "sqrt", or "mcp".
264+
n_reweights : int, default=5
265+
Number of reweighting iterations.
266+
max_iter : int, default=1000
267+
Maximum iterations for inner solver.
268+
tol : float, default=1e-8
269+
Convergence tolerance.
270+
warm_start : bool, default=False
271+
Whether to use warm start.
272+
"""
234273

235274
def __init__(
236275
self,
@@ -271,6 +310,7 @@ def fit(self, S):
271310

272311

273312
def update_weights(Theta, alpha, strategy="log"):
313+
"""Update weights for adaptive graphical lasso based on strategy."""
274314
if strategy == "log":
275315
return 1/(np.abs(Theta) + 1e-10)
276316
elif strategy == "sqrt":
@@ -345,7 +385,8 @@ def generate_problem(dim=20, n_samples=100, seed=42):
345385
# Compare the two estimated models
346386
rel_diff_between_models = frobenius_norm_diff(Theta_penalty, Theta_strategy)
347387
print(
348-
f"\n Frobenius norm relative difference between models: {rel_diff_between_models:.2e}")
388+
f"\n Frobenius norm relative difference between models: "
389+
f"{rel_diff_between_models:.2e}")
349390
print(" Matrices are close?", np.allclose(
350391
Theta_penalty, Theta_strategy, atol=1e-4))
351392

skglm/solvers/gram_cd.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,7 @@ def _gram_cd_epoch(scaled_gram, w, grad, penalty, greedy_cd):
170170

171171
@njit
172172
def barebones_cd_gram(H, q, x, alpha, weights, max_iter=100, tol=1e-4):
173-
"""
174-
Solve min .5 * x.T H x + q.T @ x + alpha * norm(x, 1).
173+
"""Solve min .5 * x.T H x + q.T @ x + alpha * norm(x, 1).
175174
176175
H must be symmetric.
177176
"""

skglm/utils/data.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,11 @@ def _alpha_max_group_lasso(X, y, grp_indices, grp_ptr, weights):
256256

257257

258258
def make_dummy_covariance_data(n_samples, n_features):
259+
"""Generate dummy data for covariance estimation problems.
260+
261+
Returns empirical covariance matrix, true precision matrix, and max
262+
off-diagonal value.
263+
"""
259264
rng = check_random_state(0)
260265
Theta_true = make_sparse_spd_matrix(
261266
n_features, alpha=0.9, random_state=rng)

Comments (0)