"""
Cross-validation for Generalized Linear Models
==============================================

This example demonstrates how to use cross-validation to select the optimal
regularization parameters for several types of generalized linear models.

We cover:

1. Lasso regression (L1 penalty)
2. Elastic Net regression (L1 + L2 penalty)
3. Logistic regression with an L1 penalty
4. Logistic regression with an Elastic Net penalty

Understanding Cross-Validation
------------------------------
Cross-validation (CV) is a technique to estimate how well a model will perform
on unseen data [1]. In this example, we use K-fold CV (K=5 by default) to:

1. Split the data into K folds
2. Train the model K times, each time using K-1 folds for training
3. Evaluate the model on the held-out fold
4. Average the results to obtain a robust estimate of model performance

A minimal sketch of these splitting mechanics appears right after the setup
code below.

The Process
-----------
For each model type, we:

1. Generate synthetic data
2. Split it into training and test sets
3. Use CV to find the best regularization parameters
4. Train the final model with the best parameters
5. Evaluate it on the test set

References
----------
[1] scikit-learn. Cross-validation: evaluating estimator performance.
    https://scikit-learn.org/stable/modules/cross_validation.html
"""

import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_classification, make_regression
from sklearn.metrics import accuracy_score, mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

from skglm.datafits import Logistic, Quadratic
from skglm.estimators import GeneralizedLinearEstimator
from skglm.penalties import L1, L1_plus_L2
from skglm.solvers import AndersonCD

# Set random seed for reproducibility
np.random.seed(42)
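
# To make the K-fold procedure described in the docstring concrete, here is a
# minimal sketch of the splitting mechanics on a toy array. This is for
# illustration only; the CV routines used below handle the splitting internally.
from sklearn.model_selection import KFold

toy = np.arange(10)
kf = KFold(n_splits=5, shuffle=True, random_state=0)
for fold, (train_idx, val_idx) in enumerate(kf.split(toy)):
    print(f"Fold {fold}: train={train_idx}, validation={val_idx}")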

# 1. Lasso Regression Example
# ---------------------------
print("1. Lasso Regression Example")
print("-" * 30)

# Generate synthetic data
X, y = make_regression(n_samples=100, n_features=20, noise=0.1)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# Create and fit Lasso with CV; after the search, the model is refit on the
# full training set with the best alpha found, and used for prediction below
lasso = GeneralizedLinearEstimator(
    datafit=Quadratic(),
    penalty=L1(alpha=1.0),
    solver=AndersonCD(),
)
lasso.cross_validate(X_train, y_train, alphas='auto', cv=5,
                     scoring='neg_mean_squared_error')

# Evaluate on the test set
y_pred = lasso.predict(X_test)
mse = mean_squared_error(y_test, y_pred)
print(f"Best alpha: {lasso.best_alpha_:.3f}")
print(f"Test MSE: {mse:.3f}")
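
# Conceptually, the search above scores each candidate alpha with K-fold CV and
# keeps the best one. Below is a hand-rolled sketch relying only on the
# estimator's scikit-learn compatibility; the alpha grid is a hypothetical
# stand-in for the one derived internally by alphas='auto'.
from sklearn.model_selection import cross_val_score

alphas = np.logspace(-3, 1, 10)
mean_scores = [
    cross_val_score(
        GeneralizedLinearEstimator(
            datafit=Quadratic(), penalty=L1(alpha=a), solver=AndersonCD()),
        X_train, y_train, cv=5, scoring='neg_mean_squared_error',
    ).mean()
    for a in alphas
]
print(f"Hand-rolled best alpha: {alphas[int(np.argmax(mean_scores))]:.3f}")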

# Plot CV scores; cv_scores_ is keyed by l1_ratio, and a pure-L1 search stores
# its per-fold scores under the key None
plt.figure(figsize=(10, 6))
plt.semilogx(lasso.alphas_, lasso.cv_scores_[None].mean(axis=0))
plt.axvline(lasso.best_alpha_, color='r', linestyle='--',
            label=f'Best alpha: {lasso.best_alpha_:.3f}')
plt.xlabel('Alpha')
plt.ylabel('CV Score')
plt.title('Lasso CV Scores')
plt.legend()
plt.grid(True)

# 2. Elastic Net Regression Example
# ---------------------------------
print("\n2. Elastic Net Regression Example")
print("-" * 30)

# Create and fit Elastic Net with CV
enet = GeneralizedLinearEstimator(
    datafit=Quadratic(),
    penalty=L1_plus_L2(alpha=1.0, l1_ratio=0.5),
    solver=AndersonCD(),
)
enet.cross_validate(X_train, y_train, alphas='auto',
                    l1_ratios=[0.1, 0.5, 0.9], cv=5,
                    scoring='neg_mean_squared_error')

# Evaluate on the test set
y_pred = enet.predict(X_test)
mse = mean_squared_error(y_test, y_pred)
print(f"Best alpha: {enet.best_alpha_:.3f}")
print(f"Best l1_ratio: {enet.best_l1_ratio_:.3f}")
print(f"Test MSE: {mse:.3f}")
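
# The L1 component drives coefficients exactly to zero, while the L2 component
# tends to keep groups of correlated features in the model. A quick sparsity
# check on the two fitted models (a sketch, assuming both expose a fitted
# `coef_` attribute):
print(f"Non-zero coefficients -- Lasso: {np.sum(lasso.coef_ != 0)}, "
      f"Elastic Net: {np.sum(enet.coef_ != 0)}")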

# Plot CV scores for different l1_ratios
plt.figure(figsize=(10, 6))
for ratio in enet.cv_scores_:
    plt.semilogx(enet.alphas_, enet.cv_scores_[ratio].mean(axis=0),
                 label=f'l1_ratio={ratio}')
plt.axvline(enet.best_alpha_, color='r', linestyle='--',
            label=f'Best alpha: {enet.best_alpha_:.3f}')
plt.xlabel('Alpha')
plt.ylabel('CV Score')
plt.title('Elastic Net CV Scores')
plt.legend()
plt.grid(True)

# 3. Logistic Regression with L1 Penalty
# --------------------------------------
print("\n3. Logistic Regression with L1 Penalty")
print("-" * 30)

# Generate synthetic classification data
X, y = make_classification(n_samples=100, n_features=20, n_classes=2)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# Create and fit Logistic Regression with CV
logreg = GeneralizedLinearEstimator(
    datafit=Logistic(),
    penalty=L1(alpha=1.0),
    solver=AndersonCD(),
)
logreg.cross_validate(X_train, y_train, alphas='auto', cv=5)

# Evaluate on the test set
y_pred = logreg.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"Best alpha: {logreg.best_alpha_:.3f}")
print(f"Test Accuracy: {accuracy:.3f}")
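
# As a cross-check, the selected alpha can be mapped onto scikit-learn's
# LogisticRegression parameterization. This sketch assumes skglm's logistic
# datafit averages the loss over samples, giving C = 1 / (n_samples * alpha).
from sklearn.linear_model import LogisticRegression

sk_logreg = LogisticRegression(penalty='l1', solver='liblinear',
                               C=1 / (X_train.shape[0] * logreg.best_alpha_))
sk_logreg.fit(X_train, y_train)
print(f"scikit-learn L1 logistic test accuracy: "
      f"{accuracy_score(y_test, sk_logreg.predict(X_test)):.3f}")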

# Plot CV scores
plt.figure(figsize=(10, 6))
plt.semilogx(logreg.alphas_, logreg.cv_scores_[None].mean(axis=0))
plt.axvline(logreg.best_alpha_, color='r', linestyle='--',
            label=f'Best alpha: {logreg.best_alpha_:.3f}')
plt.xlabel('Alpha')
plt.ylabel('CV Score')
plt.title('Logistic Regression CV Scores')
plt.legend()
plt.grid(True)

# 4. Logistic Regression with Elastic Net Penalty
# -----------------------------------------------
print("\n4. Logistic Regression with Elastic Net Penalty")
print("-" * 30)

# Create and fit Logistic Regression with Elastic Net penalty
logreg_enet = GeneralizedLinearEstimator(
    datafit=Logistic(),
    penalty=L1_plus_L2(alpha=1.0, l1_ratio=0.5),
    solver=AndersonCD(),
)
logreg_enet.cross_validate(X_train, y_train, alphas='auto',
                           l1_ratios=[0.1, 0.5, 0.9], cv=5)

# Evaluate on the test set
y_pred = logreg_enet.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"Best alpha: {logreg_enet.best_alpha_:.3f}")
print(f"Best l1_ratio: {logreg_enet.best_l1_ratio_:.3f}")
print(f"Test Accuracy: {accuracy:.3f}")
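
# Sanity-check the selected (alpha, l1_ratio) pair with an independent CV run.
# A sketch relying only on the estimator's scikit-learn compatibility:
from sklearn.model_selection import cross_val_score

check = GeneralizedLinearEstimator(
    datafit=Logistic(),
    penalty=L1_plus_L2(alpha=logreg_enet.best_alpha_,
                       l1_ratio=logreg_enet.best_l1_ratio_),
    solver=AndersonCD(),
)
cv_acc = cross_val_score(check, X_train, y_train, cv=5, scoring='accuracy')
print(f"5-fold CV accuracy at the selected parameters: {cv_acc.mean():.3f}")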

# Plot CV scores for different l1_ratios
plt.figure(figsize=(10, 6))
for ratio in logreg_enet.cv_scores_:
    plt.semilogx(logreg_enet.alphas_, logreg_enet.cv_scores_[ratio].mean(axis=0),
                 label=f'l1_ratio={ratio}')
plt.axvline(logreg_enet.best_alpha_, color='r', linestyle='--',
            label=f'Best alpha: {logreg_enet.best_alpha_:.3f}')
plt.xlabel('Alpha')
plt.ylabel('CV Score')
plt.title('Logistic Regression with Elastic Net CV Scores')
plt.legend()
plt.grid(True)

plt.show()