Commit 494221c

add example for webpage, add cv to api
1 parent 299fdb0

File tree: 3 files changed, +121 -9 lines changed

3 files changed

+121
-9
lines changed

doc/api.rst
Lines changed: 1 addition & 0 deletions

@@ -18,6 +18,7 @@ Estimators
    :toctree: generated/
 
    GeneralizedLinearEstimator
+   GeneralizedLinearEstimatorCV
    CoxEstimator
    ElasticNet
    GroupLasso
Lines changed: 119 additions & 9 deletions

@@ -1,28 +1,138 @@
 """
 ===================================
-Cross Validation for Generalized Linear Estimator
+Cross-Validation for Generalized Linear Models
 ===================================
+
+This example shows how to use cross-validation to automatically select
+the optimal regularization parameter for generalized linear models.
 """
+
+# Author: Florian Kozikowski
+
 import numpy as np
-from sklearn.datasets import make_regression
+import matplotlib.pyplot as plt
+
+from skglm.utils.data import make_correlated_data
 from skglm.cv import GeneralizedLinearEstimatorCV
+from skglm.estimators import GeneralizedLinearEstimator
 from skglm.datafits import Quadratic
 from skglm.penalties import L1_plus_L2
 from skglm.solvers import AndersonCD
 
+# %%
+# Generate correlated data with sparse ground truth
+# --------------------------------------------------
+X, y, true_coef = make_correlated_data(
+    n_samples=150, n_features=300, random_state=42
+)
 
-X, y = make_regression(n_samples=100, n_features=20, noise=0.1, random_state=42)
-
+# %%
+# Fit model using cross-validation
+# --------------------------------
+# The CV estimator automatically finds the best regularization strength
 estimator = GeneralizedLinearEstimatorCV(
     datafit=Quadratic(),
     penalty=L1_plus_L2(alpha=1.0, l1_ratio=0.5),
-    solver=AndersonCD(max_iter=50, tol=1e-4),
-    cv=6,
+    solver=AndersonCD(max_iter=100),
+    cv=5,
+    n_alphas=50,
 )
 estimator.fit(X, y)
 
 print(f"Best alpha: {estimator.alpha_:.3f}")
-print(f"L1 ratio: {estimator.penalty.l1_ratio:.3f}")
-print(f"Number of non-zero coefficients: {np.sum(estimator.coef_ != 0)}")
+n_nonzero = np.sum(estimator.coef_ != 0)
+n_true_nonzero = np.sum(true_coef != 0)
+print(f"Non-zero coefficients: {n_nonzero} (true: {n_true_nonzero})")
+
+# %%
+# Visualize the cross-validation path
+# -----------------------------------
+# Plot shows how CV balances model complexity with prediction performance
+
+# Get mean CV scores
+mean_scores = np.mean(estimator.scores_path_, axis=1)
+std_scores = np.std(estimator.scores_path_, axis=1)
+best_idx = np.argmax(mean_scores)
+best_alpha = estimator.alphas_[best_idx]
+
+# Compute coefficient paths
+coef_paths = []
+for alpha in estimator.alphas_:
+    est_temp = GeneralizedLinearEstimator(
+        datafit=Quadratic(),
+        penalty=L1_plus_L2(alpha=alpha, l1_ratio=0.5),
+        solver=AndersonCD(max_iter=100)
+    )
+    est_temp.fit(X, y)
+    coef_paths.append(est_temp.coef_)
+coef_paths = np.array(coef_paths)
+
+fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 10), sharex=True)
+
+ax1.semilogx(estimator.alphas_, -mean_scores, 'b-', linewidth=2, label='MSE')
+ax1.fill_between(estimator.alphas_,
+                 -mean_scores - std_scores,
+                 -mean_scores + std_scores,
+                 alpha=0.2, label='±1 std. dev.')
+ax1.axvline(best_alpha, color='red', linestyle='--',
+            label=f'Best alpha = {best_alpha:.2e}')
+ax1.set_ylabel('MSE')
+ax1.set_title('Cross-Validation Score vs. Regularization')
+ax1.legend(loc='best')
+ax1.grid(True, alpha=0.3)
+ax1.set_xlabel('alpha')
+
+for j in range(coef_paths.shape[1]):
+    ax2.semilogx(estimator.alphas_, coef_paths[:, j], lw=1, alpha=0.3)
+ax2.axvline(best_alpha, color='red', linestyle='--')
+ax2.set_xlabel('alpha')
+ax2.set_ylabel('Coefficient value')
+ax2.set_title('Regularization Path of Coefficients')
+ax2.grid(True, alpha=0.3)
+
+plt.tight_layout()
+plt.show()
+
+# %% [markdown]
+# Top panel: Mean CV MSE shows U-shape, minimized at chosen alpha for optimal
+# bias-variance tradeoff.
+#
+# Bottom panel: At this alpha, most coefficients are shrunk (many near zero),
+# highlighting a sparse subset of key predictors.
+
+
+# %%
+# Visualize distance to true coefficients
+# ----------------------------------------
+# Compute how well different regularization strengths recover the true coefficients
+
+distances = []
+for alpha in estimator.alphas_:
+    est_temp = GeneralizedLinearEstimator(
+        datafit=Quadratic(),
+        penalty=L1_plus_L2(alpha=alpha, l1_ratio=0.5),
+        solver=AndersonCD(max_iter=100)
+    )
+    est_temp.fit(X, y)
+    distances.append(np.linalg.norm(est_temp.coef_ - true_coef, ord=1))
+
+plt.figure(figsize=(8, 5))
+plt.loglog(estimator.alphas_, distances, 'b-', linewidth=2)
+plt.axvline(estimator.alpha_, color='red', linestyle='--',
+            label=f'CV-selected alpha = {estimator.alpha_:.3f}')
+plt.xlabel('Alpha (regularization strength)')
+plt.ylabel('L1 distance to true coefficients')
+plt.title('Recovery of True Coefficients')
+plt.legend()
+plt.grid(True, alpha=0.3)
+plt.show()
+
+print(
+    f"Distance at CV-selected alpha: "
+    f"{np.linalg.norm(estimator.coef_ - true_coef, ord=1):.3f}")
 
-# TODO: add plot, test with other penalties and datafits
+# %% [markdown]
+# The U-shaped curve shows two failure modes: small alpha doesn't induce
+# enough sparsity (keeping noisy/irrelevant features), while large alpha
+# overshrinks all coefficients including the true signals. Cross-validation
+# finds a good balance without needing access to the ground truth.
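A side note on the example above: the coefficient-path loop and the recovery-curve loop fit exactly the same sequence of estimators, so the path is computed twice. A single-pass sketch, reusing the constructor arguments from the example (the helper name compute_path is ours, not part of skglm):

import numpy as np

from skglm.estimators import GeneralizedLinearEstimator
from skglm.datafits import Quadratic
from skglm.penalties import L1_plus_L2
from skglm.solvers import AndersonCD


def compute_path(X, y, alphas, l1_ratio=0.5):
    # Fit one estimator per alpha; rows of the result follow the alpha grid.
    coefs = []
    for alpha in alphas:
        est = GeneralizedLinearEstimator(
            datafit=Quadratic(),
            penalty=L1_plus_L2(alpha=alpha, l1_ratio=l1_ratio),
            solver=AndersonCD(max_iter=100),
        )
        est.fit(X, y)
        coefs.append(est.coef_)
    return np.array(coefs)  # shape (n_alphas, n_features)


# Both figures can then share one array:
# coef_paths = compute_path(X, y, estimator.alphas_)
# distances = np.abs(coef_paths - true_coef).sum(axis=1)  # L1 distance per alpha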

skglm/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -4,3 +4,4 @@
     Lasso, WeightedLasso, ElasticNet, MCPRegression, MultiTaskLasso, LinearSVC,
     SparseLogisticRegression, GeneralizedLinearEstimator, CoxEstimator, GroupLasso,
 )
+from .cv import GeneralizedLinearEstimatorCV  # noqa F401
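The re-export above makes the new estimator importable from the package root, so user code does not need to know the submodule layout. A quick sanity check, assuming a build that includes this commit:

# Both import paths now resolve to the same class.
from skglm import GeneralizedLinearEstimatorCV
from skglm.cv import GeneralizedLinearEstimatorCV as FromSubmodule

assert GeneralizedLinearEstimatorCV is FromSubmodule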
