1 | 1 | """ |
2 | 2 | ================================================== |
3 | | -Cross Validation for Generalized Linear Estimator |
 | 3 | +Cross-Validation for Generalized Linear Models |
4 | 4 | ================================================== |
| 5 | +
| 6 | +This example shows how to use cross-validation to automatically select |
| 7 | +the optimal regularization parameter for generalized linear models. |
5 | 8 | """ |
| 9 | + |
| 10 | +# Author: Florian Kozikowski |
| 11 | + |
6 | 12 | import numpy as np |
7 | | -from sklearn.datasets import make_regression |
| 13 | +import matplotlib.pyplot as plt |
| 14 | + |
| 15 | +from skglm.utils.data import make_correlated_data |
8 | 16 | from skglm.cv import GeneralizedLinearEstimatorCV |
| 17 | +from skglm.estimators import GeneralizedLinearEstimator |
9 | 18 | from skglm.datafits import Quadratic |
10 | 19 | from skglm.penalties import L1_plus_L2 |
11 | 20 | from skglm.solvers import AndersonCD |
12 | 21 |
| 22 | +# %% |
| 23 | +# Generate correlated data with sparse ground truth |
| 24 | +# -------------------------------------------------- |
| 25 | +X, y, true_coef = make_correlated_data( |
| 26 | + n_samples=150, n_features=300, random_state=42 |
| 27 | +) |
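 | | + |
 | | +# The setting is high-dimensional (more features than samples), so a |
 | | +# sparsity-inducing penalty is needed to recover the underlying coefficients. |
 | | +print(f"Data: {X.shape[0]} samples, {X.shape[1]} features") |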
13 | 28 |
14 | | -X, y = make_regression(n_samples=100, n_features=20, noise=0.1, random_state=42) |
15 | | - |
| 29 | +# %% |
| 30 | +# Fit model using cross-validation |
| 31 | +# -------------------------------- |
| 32 | +# The CV estimator automatically finds the best regularization strength |
16 | 33 | estimator = GeneralizedLinearEstimatorCV( |
17 | 34 | datafit=Quadratic(), |
18 | 35 | penalty=L1_plus_L2(alpha=1.0, l1_ratio=0.5), |
19 | | - solver=AndersonCD(max_iter=50, tol=1e-4), |
20 | | - cv=6, |
| 36 | + solver=AndersonCD(max_iter=100), |
| 37 | + cv=5, |
| 38 | + n_alphas=50, |
21 | 39 | ) |
22 | 40 | estimator.fit(X, y) |
23 | 41 |
24 | 42 | print(f"Best alpha: {estimator.alpha_:.3f}") |
25 | | -print(f"L1 ratio: {estimator.penalty.l1_ratio:.3f}") |
26 | | -print(f"Number of non-zero coefficients: {np.sum(estimator.coef_ != 0)}") |
| 43 | +n_nonzero = np.sum(estimator.coef_ != 0) |
| 44 | +n_true_nonzero = np.sum(true_coef != 0) |
| 45 | +print(f"Non-zero coefficients: {n_nonzero} (true: {n_true_nonzero})") |
| 46 | + |
| 47 | +# %% |
| 48 | +# Visualize the cross-validation path |
| 49 | +# ----------------------------------- |
 | 50 | +# The plots below show how CV balances model complexity against prediction error |
| 51 | + |
 | 52 | +# Mean and std of the CV scores across folds (higher is better; negated below to plot MSE) |
| 53 | +mean_scores = np.mean(estimator.scores_path_, axis=1) |
| 54 | +std_scores = np.std(estimator.scores_path_, axis=1) |
| 55 | +best_idx = np.argmax(mean_scores) |
| 56 | +best_alpha = estimator.alphas_[best_idx] |
| 57 | + |
 | 58 | +# Compute coefficient paths by refitting on the full data at each alpha of the CV grid |
| 59 | +coef_paths = [] |
| 60 | +for alpha in estimator.alphas_: |
| 61 | + est_temp = GeneralizedLinearEstimator( |
| 62 | + datafit=Quadratic(), |
| 63 | + penalty=L1_plus_L2(alpha=alpha, l1_ratio=0.5), |
| 64 | + solver=AndersonCD(max_iter=100) |
| 65 | + ) |
| 66 | + est_temp.fit(X, y) |
| 67 | + coef_paths.append(est_temp.coef_) |
| 68 | +coef_paths = np.array(coef_paths) |
| 69 | + |
| 70 | +fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 10), sharex=True) |
| 71 | + |
| 72 | +ax1.semilogx(estimator.alphas_, -mean_scores, 'b-', linewidth=2, label='MSE') |
| 73 | +ax1.fill_between(estimator.alphas_, |
| 74 | + -mean_scores - std_scores, |
| 75 | + -mean_scores + std_scores, |
| 76 | + alpha=0.2, label='±1 std. dev.') |
| 77 | +ax1.axvline(best_alpha, color='red', linestyle='--', |
| 78 | + label=f'Best alpha = {best_alpha:.2e}') |
| 79 | +ax1.set_ylabel('MSE') |
| 80 | +ax1.set_title('Cross-Validation Score vs. Regularization') |
| 81 | +ax1.legend(loc='best') |
| 82 | +ax1.grid(True, alpha=0.3) |
| 84 | + |
| 85 | +for j in range(coef_paths.shape[1]): |
| 86 | + ax2.semilogx(estimator.alphas_, coef_paths[:, j], lw=1, alpha=0.3) |
| 87 | +ax2.axvline(best_alpha, color='red', linestyle='--') |
| 88 | +ax2.set_xlabel('alpha') |
| 89 | +ax2.set_ylabel('Coefficient value') |
| 90 | +ax2.set_title('Regularization Path of Coefficients') |
| 91 | +ax2.grid(True, alpha=0.3) |
| 92 | + |
| 93 | +plt.tight_layout() |
| 94 | +plt.show() |
| 95 | + |
| 96 | +# %% [markdown] |
 | 97 | +# Top panel: the mean CV MSE follows a U-shape and is minimized at the selected |
 | 98 | +# alpha, which balances bias and variance. |
 | 99 | +# |
 | 100 | +# Bottom panel: at this alpha, most coefficients are shrunk to (or close to) zero, |
 | 101 | +# leaving a sparse subset of informative predictors. |
| 102 | + |
| 103 | + |
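 | | +# %% |
 | | +# A quick numeric check using the quantities computed above: the alpha that |
 | | +# minimizes the mean CV MSE should coincide with the estimator's selected alpha_. |
 | | +print(f"Minimum mean CV MSE: {-mean_scores[best_idx]:.3f}") |
 | | +print(f"Alpha at CV minimum: {best_alpha:.3e} (estimator.alpha_ = {estimator.alpha_:.3e})") |
 | | + |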
| 104 | +# %% |
| 105 | +# Visualize distance to true coefficients |
| 106 | +# ---------------------------------------- |
| 107 | +# Compute how well different regularization strengths recover the true coefficients |
| 108 | + |
 | 109 | +# Reuse the coefficient paths computed above instead of refitting at every alpha |
 | 110 | +distances = np.linalg.norm(coef_paths - true_coef, ord=1, axis=1) |
| 118 | + |
| 119 | +plt.figure(figsize=(8, 5)) |
| 120 | +plt.loglog(estimator.alphas_, distances, 'b-', linewidth=2) |
| 121 | +plt.axvline(estimator.alpha_, color='red', linestyle='--', |
| 122 | + label=f'CV-selected alpha = {estimator.alpha_:.3f}') |
| 123 | +plt.xlabel('Alpha (regularization strength)') |
| 124 | +plt.ylabel('L1 distance to true coefficients') |
| 125 | +plt.title('Recovery of True Coefficients') |
| 126 | +plt.legend() |
| 127 | +plt.grid(True, alpha=0.3) |
| 128 | +plt.show() |
| 129 | + |
| 130 | +print( |
| 131 | + f"Distance at CV-selected alpha: " |
| 132 | + f"{np.linalg.norm(estimator.coef_ - true_coef, ord=1):.3f}") |
27 | 133 |
28 | | -# TODO: add plot, test with other penalties and datafits |
| 134 | +# %% [markdown] |
| 135 | +# The U-shaped curve shows two failure modes: small alpha doesn't induce |
| 136 | +# enough sparsity (keeping noisy/irrelevant features), while large alpha |
| 137 | +# overshrinks all coefficients including the true signals. Cross-validation |
| 138 | +# finds a good balance without needing access to the ground truth. |
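 | 139 | + |
 | 140 | +# %% |
 | 141 | +# A small illustrative check of the two failure modes, using the paths and |
 | 142 | +# distances computed above (exact values depend on the random data): |
 | 143 | +idx_small = np.argmin(estimator.alphas_) |
 | 144 | +idx_large = np.argmax(estimator.alphas_) |
 | 145 | +for name, idx in [("smallest alpha", idx_small), |
 | 146 | +                  ("CV-selected alpha", best_idx), |
 | 147 | +                  ("largest alpha", idx_large)]: |
 | 148 | +    n_nz = np.sum(coef_paths[idx] != 0) |
 | 149 | +    print(f"{name}: {n_nz} non-zero coefficients, " |
 | 150 | +          f"L1 distance to truth = {distances[idx]:.2f}") |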