
Commit 7054fd0

DOC add ucurve example (#239)
Co-authored-by: Badr-MOUFAD <[email protected]>
1 parent ef88450 commit 7054fd0

File tree

1 file changed: +56 -0

examples/plot_ucurve.py

Lines changed: 56 additions & 0 deletions
"""
==============================
Show U-curve of regularization
==============================
Illustrate the sweet spot of regularization: not too much, not too little.
We showcase this for the Lasso estimator on the ``rcv1.binary`` dataset.
"""

import numpy as np
from numpy.linalg import norm
import matplotlib.pyplot as plt
from libsvmdata import fetch_libsvm

from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

from skglm import Lasso

# %%
# First, we load the dataset and keep the first 2000 features.
# We also keep only 2000 samples in the training set.
X, y = fetch_libsvm("rcv1.binary")

X = X[:, :2000]
X_train, X_test, y_train, y_test = train_test_split(X, y)
X_train, y_train = X_train[:2000], y_train[:2000]
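
# %%
# Note that ``rcv1.binary`` is a classification dataset with labels in
# {-1, 1}; the Lasso below simply treats them as regression targets.
# A quick illustrative check:
print("unique labels:", np.unique(y_train))  # expected: [-1.  1.]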

# %%
# Next, we define the regularization path.
# For the Lasso, it is well known that there is an ``alpha_max`` above which
# the optimal solution is the zero vector.
alpha_max = norm(X_train.T @ y_train, ord=np.inf) / len(y_train)
alphas = alpha_max * np.geomspace(1, 1e-4)
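
# %%
# As an illustrative aside (``clf_zero`` is an added name, not used later):
# at ``alpha = alpha_max`` the zero vector satisfies the Lasso optimality
# conditions, so the fitted coefficients should come out as all zeros
# (up to solver tolerance).
clf_zero = Lasso(alpha=alpha_max, fit_intercept=False, tol=1e-8).fit(X_train, y_train)
print("max |coef| at alpha_max:", np.abs(clf_zero.coef_).max())  # expected: ~0
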
# %%
# Let's train the estimator along the regularization path, then compute the
# MSE on the train and test data.
mse_train = []
mse_test = []

clf = Lasso(fit_intercept=False, tol=1e-8, warm_start=True)
for alpha in alphas:
    clf.alpha = alpha
    clf.fit(X_train, y_train)

    mse_train.append(mean_squared_error(y_train, clf.predict(X_train)))
    mse_test.append(mean_squared_error(y_test, clf.predict(X_test)))
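
# %%
# As a quick illustrative readout (``best_alpha`` is an added name, not used
# by the plot below): the sweet spot is the alpha minimizing the test MSE.
best_alpha = alphas[np.argmin(mse_test)]
print(f"alpha minimizing test MSE: {best_alpha:.2e}")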

# %%
# Finally, we can plot the train and test MSE.
# Notice the "sweet spot" around ``1e-4``, which sits at the boundary between
# underfitting and overfitting.
plt.close('all')
plt.semilogx(alphas, mse_train, label='train MSE')
plt.semilogx(alphas, mse_test, label='test MSE')
plt.legend()
plt.title("Mean squared error")
plt.xlabel(r"Lasso regularization strength $\lambda$")
plt.show(block=False)
