"""
==============================
Show U-curve of regularization
==============================

Illustrate the sweet spot of regularization: not too much, not too little.
We showcase that for the Lasso estimator on the ``rcv1.binary`` dataset.
"""

import numpy as np
from numpy.linalg import norm
import matplotlib.pyplot as plt
from libsvmdata import fetch_libsvm

from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

from skglm import Lasso

# %%
# First, we load the dataset and keep only the first 2000 features.
# After the train/test split, we also retain only 2000 training samples.
X, y = fetch_libsvm("rcv1.binary")

X = X[:, :2000]
X_train, X_test, y_train, y_test = train_test_split(X, y)
X_train, y_train = X_train[:2000], y_train[:2000]

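# %%
# As a quick sanity check, we can inspect the shapes of the resulting splits
# before fitting anything.
print("Training set:", X_train.shape, "- Test set:", X_test.shape)
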
# %%
# Next, we define the regularization path.
# For the Lasso, it is well known that there is a value ``alpha_max`` above which the
# optimal solution is the zero vector; we span four decades of regularization below it.
alpha_max = norm(X_train.T @ y_train, ord=np.inf) / len(y_train)
alphas = alpha_max * np.geomspace(1, 1e-4)

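# %%
# As a quick check of this property, we can fit the estimator just above ``alpha_max``
# and verify that all coefficients are indeed zero (this extra fit is only illustrative;
# it is not needed for the rest of the example).
clf_zero = Lasso(alpha=1.01 * alpha_max, fit_intercept=False, tol=1e-8)
clf_zero.fit(X_train, y_train)
print("Non-zero coefficients at alpha > alpha_max:", np.sum(clf_zero.coef_ != 0))
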
# %%
# Let's train the estimator along the regularization path and then compute the MSE on
# the train and test data.
mse_train = []
mse_test = []

# ``warm_start=True`` reuses the previous solution as initialization for the next
# value of ``alpha``, which speeds up the path computation.
clf = Lasso(fit_intercept=False, tol=1e-8, warm_start=True)
for alpha in alphas:
    clf.alpha = alpha
    clf.fit(X_train, y_train)

    mse_train.append(mean_squared_error(y_train, clf.predict(X_train)))
    mse_test.append(mean_squared_error(y_test, clf.predict(X_test)))

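# %%
# Before plotting, we can also locate the "sweet spot" numerically: the ``alpha``
# achieving the lowest test MSE.
best_alpha = alphas[np.argmin(mse_test)]
print(f"Best alpha: {best_alpha:.2e} (alpha / alpha_max = {best_alpha / alpha_max:.2e})")
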
# %%
# Finally, we plot the train and test MSE along the path.
# Notice the U shape of the test error: large values of ``alpha`` underfit while small
# ones overfit, and the "sweet spot" at around ``1e-4`` sits at the boundary between
# the two regimes.
plt.close('all')
plt.semilogx(alphas, mse_train, label='train MSE')
plt.semilogx(alphas, mse_test, label='test MSE')
plt.legend()
plt.title("Mean squared error")
plt.xlabel(r"Lasso regularization strength $\lambda$")
plt.show(block=False)