
Commit 4b2344c

FEAT - GeneralizedLinearEstimatorCV for automatic CV of models with Elastic penalty (#311)
1 parent c40d4dc commit 4b2344c

File tree

6 files changed: +366 −0 lines changed


doc/api.rst

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ Estimators
    :toctree: generated/

    GeneralizedLinearEstimator
+   GeneralizedLinearEstimatorCV
    CoxEstimator
    ElasticNet
    GroupLasso

doc/changes/0.5.rst

Lines changed: 1 addition & 0 deletions
@@ -4,3 +4,4 @@ Version 0.5 (in progress)
 -------------------------
 - Add support for fitting an intercept in :ref:`SqrtLasso <skglm.experimental.sqrt_lasso.SqrtLasso>` (PR: :gh:`298`)
 - Add experimental :ref:`QuantileHuber <skglm.experimental.quantile_huber.QuantileHuber>` and :ref:`SmoothQuantileRegressor <skglm.experimental.quantile_huber.SmoothQuantileRegressor>` for quantile regression, and an example script (PR: :gh:`312`).
+- Add :ref:`GeneralizedLinearEstimatorCV <skglm.cv.GeneralizedLinearEstimatorCV>` for cross-validation with automatic parameter selection for L1 and elastic-net penalties (PR: :gh:`299`)
Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
"""
==============================================
Cross-Validation for Generalized Linear Models
==============================================

This example shows how to use cross-validation to automatically select
the optimal regularization parameter for generalized linear models.
"""

# Author: Florian Kozikowski

import numpy as np
import matplotlib.pyplot as plt

from skglm.utils.data import make_correlated_data
from skglm.cv import GeneralizedLinearEstimatorCV
from skglm.estimators import GeneralizedLinearEstimator
from skglm.datafits import Quadratic
from skglm.penalties import L1_plus_L2
from skglm.solvers import AndersonCD

# %%
# Generate correlated data with sparse ground truth
# --------------------------------------------------
X, y, true_coef = make_correlated_data(
    n_samples=150, n_features=300, random_state=42
)

# %%
# Fit model using cross-validation
# --------------------------------
# The CV estimator automatically finds the best regularization strength
estimator = GeneralizedLinearEstimatorCV(
    datafit=Quadratic(),
    penalty=L1_plus_L2(alpha=1.0, l1_ratio=0.5),
    solver=AndersonCD(max_iter=100),
    cv=5,
    n_alphas=50,
)
estimator.fit(X, y)

print(f"Best alpha: {estimator.alpha_:.3f}")
n_nonzero = np.sum(estimator.coef_ != 0)
n_true_nonzero = np.sum(true_coef != 0)
print(f"Non-zero coefficients: {n_nonzero} (true: {n_true_nonzero})")

# %%
# Visualize the cross-validation path
# -----------------------------------
# The plot shows how CV balances model complexity against prediction performance

# Get mean CV scores
mean_scores = np.mean(estimator.scores_path_, axis=1)
std_scores = np.std(estimator.scores_path_, axis=1)
best_idx = np.argmax(mean_scores)
best_alpha = estimator.alphas_[best_idx]

# Compute coefficient paths
coef_paths = []
for alpha in estimator.alphas_:
    est_temp = GeneralizedLinearEstimator(
        datafit=Quadratic(),
        penalty=L1_plus_L2(alpha=alpha, l1_ratio=0.5),
        solver=AndersonCD(max_iter=100)
    )
    est_temp.fit(X, y)
    coef_paths.append(est_temp.coef_)
coef_paths = np.array(coef_paths)

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 10), sharex=True)

ax1.semilogx(estimator.alphas_, -mean_scores, 'b-', linewidth=2, label='MSE')
ax1.fill_between(estimator.alphas_,
                 -mean_scores - std_scores,
                 -mean_scores + std_scores,
                 alpha=0.2, label='±1 std. dev.')
ax1.axvline(best_alpha, color='red', linestyle='--',
            label=f'Best alpha = {best_alpha:.2e}')
ax1.set_ylabel('MSE')
ax1.set_title('Cross-Validation Score vs. Regularization')
ax1.legend(loc='best')
ax1.grid(True, alpha=0.3)
ax1.set_xlabel('alpha')

for j in range(coef_paths.shape[1]):
    ax2.semilogx(estimator.alphas_, coef_paths[:, j], lw=1, alpha=0.3)
ax2.axvline(best_alpha, color='red', linestyle='--')
ax2.set_xlabel('alpha')
ax2.set_ylabel('Coefficient value')
ax2.set_title('Regularization Path of Coefficients')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# %% [markdown]
# Top panel: the mean CV MSE traces a U-shape, minimized at the chosen alpha
# for an optimal bias-variance tradeoff.
#
# Bottom panel: at this alpha, most coefficients are shrunk (many near zero),
# highlighting a sparse subset of key predictors.

# %%
# Visualize distance to true coefficients
# ----------------------------------------
# Compute how well different regularization strengths recover the true coefficients

distances = []
for alpha in estimator.alphas_:
    est_temp = GeneralizedLinearEstimator(
        datafit=Quadratic(),
        penalty=L1_plus_L2(alpha=alpha, l1_ratio=0.5),
        solver=AndersonCD(max_iter=100)
    )
    est_temp.fit(X, y)
    distances.append(np.linalg.norm(est_temp.coef_ - true_coef, ord=1))

plt.figure(figsize=(8, 5))
plt.loglog(estimator.alphas_, distances, 'b-', linewidth=2)
plt.axvline(estimator.alpha_, color='red', linestyle='--',
            label=f'CV-selected alpha = {estimator.alpha_:.3f}')
plt.xlabel('Alpha (regularization strength)')
plt.ylabel('L1 distance to true coefficients')
plt.title('Recovery of True Coefficients')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

print(
    f"Distance at CV-selected alpha: "
    f"{np.linalg.norm(estimator.coef_ - true_coef, ord=1):.3f}")

# %% [markdown]
# The U-shaped curve shows two failure modes: small alpha doesn't induce
# enough sparsity (keeping noisy/irrelevant features), while large alpha
# overshrinks all coefficients including the true signals. Cross-validation
# finds a good balance without needing access to the ground truth.

skglm/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -4,3 +4,4 @@
     Lasso, WeightedLasso, ElasticNet, MCPRegression, MultiTaskLasso, LinearSVC,
     SparseLogisticRegression, GeneralizedLinearEstimator, CoxEstimator, GroupLasso,
 )
+from .cv import GeneralizedLinearEstimatorCV  # noqa F401
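With this re-export, the new estimator is importable from the package root as well as from its module. A quick illustrative check (not part of the commit):

# Both import paths should resolve to the same class object.
from skglm import GeneralizedLinearEstimatorCV as top_level
from skglm.cv import GeneralizedLinearEstimatorCV as from_module

assert top_level is from_module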

skglm/cv.py

Lines changed: 181 additions & 0 deletions
@@ -0,0 +1,181 @@
import numpy as np
from joblib import Parallel, delayed
from skglm.datafits import Logistic, QuadraticSVC
from skglm.estimators import GeneralizedLinearEstimator
from sklearn.model_selection import KFold, StratifiedKFold
from sklearn.metrics import accuracy_score, mean_squared_error


class GeneralizedLinearEstimatorCV(GeneralizedLinearEstimator):
    """Cross-validated wrapper for GeneralizedLinearEstimator.

    This class performs cross-validated selection of the regularization parameter(s)
    for a generalized linear estimator, supporting both L1 and elastic-net penalties.

    Parameters
    ----------
    datafit : object
        Datafit (loss) function instance (e.g., Logistic, Quadratic).
    penalty : object
        Penalty instance with an 'alpha' parameter (and optionally 'l1_ratio').
    solver : object
        Solver instance to use for optimization.
    alphas : array-like of shape (n_alphas,), optional
        List of alpha values to try. If None, they are set automatically.
    l1_ratio : float or array-like, optional
        The ElasticNet mixing parameter(s), with 0 <= l1_ratio <= 1.
        Only used if the penalty supports 'l1_ratio'. If None, defaults to 1.0 (Lasso).
    cv : int, default=4
        Number of cross-validation folds.
    n_jobs : int, default=1
        Number of jobs to run in parallel for cross-validation.
    random_state : int or None, default=None
        Random seed for cross-validation splitting.
    eps : float, default=1e-3
        Ratio of minimum to maximum alpha if alphas are set automatically.
    n_alphas : int, default=100
        Number of alphas along the regularization path if alphas are set automatically.

    Attributes
    ----------
    alpha_ : float
        Best alpha found by cross-validation.
    l1_ratio_ : float or None
        Best l1_ratio found by cross-validation (if applicable).
    best_estimator_ : GeneralizedLinearEstimator
        Estimator fitted on the full data with the best parameters.
    coef_ : ndarray
        Coefficients of the fitted model.
    intercept_ : float or ndarray
        Intercept of the fitted model.
    alphas_ : ndarray
        Array of alphas used in the search.
    scores_path_ : ndarray
        Cross-validation scores for each parameter combination.
    n_iter_ : int or None
        Number of iterations run by the solver (if available).
    n_features_in_ : int or None
        Number of features seen during fit.
    feature_names_in_ : ndarray or None
        Names of features seen during fit.
    """

    def __init__(self, datafit, penalty, solver, alphas=None, l1_ratio=None,
                 cv=4, n_jobs=1, random_state=None,
                 eps=1e-3, n_alphas=100):
        super().__init__(datafit=datafit, penalty=penalty, solver=solver)
        self.alphas = alphas
        self.l1_ratio = l1_ratio
        self.cv = cv
        self.n_jobs = n_jobs
        self.random_state = random_state
        self.eps = eps
        self.n_alphas = n_alphas

    def _score(self, y_true, y_pred):
        """Compute the performance score (higher is better)."""
        if isinstance(self.datafit, (Logistic, QuadraticSVC)):
            return accuracy_score(y_true, y_pred)
        return -mean_squared_error(y_true, y_pred)

    def fit(self, X, y):
        """Fit the model using cross-validation."""
        if not hasattr(self.penalty, "alpha"):
            raise ValueError(
                "GeneralizedLinearEstimatorCV only supports penalties which "
                "expose an 'alpha' parameter."
            )
        n_samples, n_features = X.shape

        if self.alphas is not None:
            alphas = np.sort(self.alphas)[::-1]
        else:
            alpha_max = np.max(np.abs(X.T @ y)) / n_samples
            alphas = np.geomspace(
                alpha_max,
                alpha_max * self.eps,
                self.n_alphas
            )
        has_l1_ratio = hasattr(self.penalty, "l1_ratio")
        l1_ratios = [1.] if not has_l1_ratio else np.atleast_1d(
            self.l1_ratio if self.l1_ratio is not None else [1.])

        scores_path = np.empty((len(l1_ratios), len(alphas), self.cv))
        best_loss = -np.inf

        def _solve_fold(k, train, test, alpha, l1, w_init):
            pen_kwargs = {k: v for k, v in self.penalty.__dict__.items()
                          if k not in ("alpha", "l1_ratio")}
            if has_l1_ratio:
                pen_kwargs['l1_ratio'] = l1
            pen = type(self.penalty)(alpha=alpha, **pen_kwargs)

            est = GeneralizedLinearEstimator(
                datafit=self.datafit, penalty=pen, solver=self.solver
            )
            if w_init is not None:
                est.coef_ = w_init[0]
                est.intercept_ = w_init[1]
            est.fit(X[train], y[train])
            y_pred = est.predict(X[test])
            return est.coef_, est.intercept_, self._score(y[test], y_pred)

        for idx_ratio, l1_ratio in enumerate(l1_ratios):
            warm_start = [None] * self.cv

            for idx_alpha, alpha in enumerate(alphas):
                if isinstance(self.datafit, (Logistic, QuadraticSVC)):
                    kf = StratifiedKFold(n_splits=self.cv, shuffle=True,
                                         random_state=self.random_state)
                    split_iter = kf.split(np.arange(n_samples), y)
                else:
                    kf = KFold(n_splits=self.cv, shuffle=True,
                               random_state=self.random_state)
                    split_iter = kf.split(np.arange(n_samples))
                fold_result = Parallel(self.n_jobs)(
                    delayed(_solve_fold)(k, tr, te, alpha, l1_ratio, warm_start[k])
                    for k, (tr, te) in enumerate(split_iter)
                )

                for k, (coef_fold, intercept_fold, loss_fold) in enumerate(fold_result):
                    warm_start[k] = (coef_fold, intercept_fold)
                    scores_path[idx_ratio, idx_alpha, k] = loss_fold

                mean_loss = np.mean(scores_path[idx_ratio, idx_alpha])
                if mean_loss > best_loss:
                    best_loss = mean_loss
                    self.alpha_ = float(alpha)
                    self.l1_ratio_ = float(l1_ratio) if has_l1_ratio else None

        # Refit on full dataset
        pen_kwargs = {k: v for k, v in self.penalty.__dict__.items()
                      if k not in ("alpha", "l1_ratio")}
        if has_l1_ratio:
            pen_kwargs["l1_ratio"] = self.l1_ratio_
        best_penalty = type(self.penalty)(
            alpha=self.alpha_, **pen_kwargs
        )
        best_estimator = GeneralizedLinearEstimator(
            datafit=self.datafit,
            penalty=best_penalty,
            solver=self.solver
        )
        best_estimator.fit(X, y)
        self.best_estimator_ = best_estimator
        self.coef_ = best_estimator.coef_
        self.intercept_ = best_estimator.intercept_
        self.n_iter_ = getattr(best_estimator, "n_iter_", None)
        self.n_features_in_ = getattr(best_estimator, "n_features_in_", None)
        self.feature_names_in_ = getattr(best_estimator, "feature_names_in_", None)
        self.alphas_ = alphas
        self.scores_path_ = np.squeeze(scores_path)
        return self

    def predict(self, X):
        return self.best_estimator_.predict(X)

    def predict_proba(self, X):
        return self.best_estimator_.predict_proba(X)

    def score(self, X, y):
        return self.best_estimator_.score(X, y)
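The _score method above scores classification datafits (Logistic, QuadraticSVC) by accuracy and splits them with StratifiedKFold, a branch the shipped regression example does not exercise. A minimal sketch of that use, assuming skglm's Logistic datafit, L1 penalty, and AndersonCD solver work together as in the regression case, and that Logistic expects labels in {-1, 1}; the dataset and hyperparameter values are illustrative, not from the PR:

from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score

from skglm.cv import GeneralizedLinearEstimatorCV
from skglm.datafits import Logistic
from skglm.penalties import L1
from skglm.solvers import AndersonCD

X, y = make_classification(n_samples=200, n_features=50, random_state=0)
y = 2 * y - 1  # assumption: map {0, 1} labels to the {-1, 1} convention

clf = GeneralizedLinearEstimatorCV(
    datafit=Logistic(),
    penalty=L1(alpha=1.0),           # alpha is replaced along the CV grid
    solver=AndersonCD(max_iter=100),
    cv=5,                            # StratifiedKFold is used for Logistic
    n_alphas=30,                     # grid from alpha_max down to eps * alpha_max
)
clf.fit(X, y)

print(f"Best alpha: {clf.alpha_:.4f}")
print(f"Train accuracy: {accuracy_score(y, clf.predict(X)):.3f}")

Since L1 has no l1_ratio attribute, the search here runs over alphas only and clf.l1_ratio_ is set to None.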

skglm/tests/test_cv.py

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import ElasticNet
from sklearn.model_selection import GridSearchCV, KFold
from skglm.datafits import Quadratic
from skglm.penalties import L1_plus_L2
from skglm.solvers import AndersonCD
from skglm.cv import GeneralizedLinearEstimatorCV
import pytest


@pytest.mark.parametrize("n_samples,n_features,noise",
                         [(100, 10, 0.1), (100, 500, 0.2), (100, 500, 0.3)])
def test_elasticnet_cv_matches_sklearn(n_samples, n_features, noise):
    """Test GeneralizedLinearEstimatorCV matches sklearn GridSearchCV for ElasticNet."""
    seed = 42
    X, y = make_regression(n_samples=n_samples,
                           n_features=n_features, noise=noise, random_state=seed)

    n = X.shape[0]
    alpha_max = np.max(np.abs(X.T @ y)) / n
    alphas = alpha_max * np.array([1, 0.1, 0.01, 0.001])
    l1_ratios = np.array([0.2, 0.5, 0.8])
    cv = KFold(n_splits=5, shuffle=True, random_state=seed)

    sklearn_model = GridSearchCV(
        ElasticNet(max_iter=10000, tol=1e-8),
        {'alpha': alphas, 'l1_ratio': l1_ratios},
        cv=cv, scoring='neg_mean_squared_error', n_jobs=1
    ).fit(X, y)

    skglm_model = GeneralizedLinearEstimatorCV(
        Quadratic(), L1_plus_L2(0.1, 0.5), AndersonCD(max_iter=10000, tol=1e-8),
        alphas=alphas, l1_ratio=l1_ratios, cv=5, random_state=seed, n_jobs=1
    ).fit(X, y)

    np.testing.assert_equal(sklearn_model.best_params_['alpha'],
                            skglm_model.alpha_)
    np.testing.assert_equal(sklearn_model.best_params_['l1_ratio'],
                            skglm_model.l1_ratio_)
    np.testing.assert_allclose(sklearn_model.best_estimator_.coef_,
                               skglm_model.coef_.ravel(), rtol=1e-4, atol=1e-6)
    np.testing.assert_allclose(sklearn_model.best_estimator_.intercept_,
                               skglm_model.intercept_, rtol=1e-4, atol=1e-6)
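To exercise the new test module on its own, one option is to invoke pytest programmatically (a sketch, assuming pytest is installed and this is run from the repository root):

# Runs only the new CV tests; "-q" keeps the report terse.
import pytest

pytest.main(["-q", "skglm/tests/test_cv.py"])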
