Commit ed6f5f9

addressing comments until l1_ratio is None
1 parent 6622cf0 commit ed6f5f9

2 files changed: 138 additions & 1 deletion

examples/plot_generalized_linear_estimator_cv.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 """
 import numpy as np
 from sklearn.datasets import make_regression
-from skglm.penalties.generalized_linear_cv import GeneralizedLinearEstimatorCV
+from skglm.cv import GeneralizedLinearEstimatorCV
 from skglm.datafits import Quadratic
 from skglm.penalties import L1_plus_L2
 from skglm.solvers import AndersonCD
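
Based on the imports above, a minimal usage sketch of the class at its new location (hyperparameter values are illustrative, not taken from the example file):

import numpy as np
from sklearn.datasets import make_regression
from skglm.cv import GeneralizedLinearEstimatorCV
from skglm.datafits import Quadratic
from skglm.penalties import L1_plus_L2
from skglm.solvers import AndersonCD

X, y = make_regression(n_samples=100, n_features=40, noise=1.0, random_state=0)
model = GeneralizedLinearEstimatorCV(
    datafit=Quadratic(),
    penalty=L1_plus_L2(alpha=1.0, l1_ratio=0.5),
    solver=AndersonCD(),
    l1_ratio=[0.2, 0.5, 0.8],   # candidate ratios searched by CV
    cv=4,
    random_state=0,
)
model.fit(X, y)
print(model.alpha_, model.l1_ratio_)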

skglm/cv.py

Lines changed: 137 additions & 0 deletions
@@ -0,0 +1,137 @@
import numpy as np
from joblib import Parallel, delayed
from sklearn.utils.extmath import softmax

from skglm.datafits import Logistic, QuadraticSVC
from skglm.estimators import GeneralizedLinearEstimator


def _kfold_split(n_samples, k, rng):
    """Yield (train, test) index arrays for k shuffled, near-equal folds."""
    indices = rng.permutation(n_samples)
    fold_size = n_samples // k
    extra = n_samples % k

    start = 0
    for i in range(k):
        # Spread the n_samples % k leftover samples over the first folds.
        end = start + fold_size + (1 if i < extra else 0)
        test = indices[start:end]
        train = np.concatenate([indices[:start], indices[end:]])
        yield train, test
        start = end

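A quick check of the splitter above (illustrative values): with 10 samples and 4 folds, the two leftover samples land in the first two folds.

rng = np.random.RandomState(0)
print([len(test) for _, test in _kfold_split(10, 4, rng)])  # [3, 3, 2, 2]
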
class GeneralizedLinearEstimatorCV(GeneralizedLinearEstimator):
    """CV wrapper for GeneralizedLinearEstimator.

    Selects ``alpha`` (and ``l1_ratio`` when the penalty exposes one) by
    k-fold cross-validation, then refits on the full dataset.
    """

    def __init__(self, datafit, penalty, solver, alphas=None, l1_ratio=None,
                 cv=4, n_jobs=1, random_state=None, scoring=None,
                 eps=1e-3, n_alphas=100):
        super().__init__(datafit=datafit, penalty=penalty, solver=solver)
        self.alphas = alphas
        self.l1_ratio = l1_ratio
        self.cv = cv
        self.n_jobs = n_jobs
        self.random_state = random_state
        self.scoring = scoring  # stored but not used yet
        self.eps = eps
        self.n_alphas = n_alphas

    def _score(self, y_true, y_pred):
        """Compute the performance score (higher is better)."""
        if isinstance(self.datafit, (Logistic, QuadraticSVC)):
            # Classification: accuracy.
            return float(np.mean(y_true == y_pred))
        # Regression: negative mean squared error.
        return -float(np.mean((y_true - y_pred) ** 2))

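The sign convention here matters in fit below, where fold scores are negated into losses before being stored in mse_path_. For a regression datafit the returned value is simply the negative MSE, e.g.:

y_true, y_pred = np.array([1., 2.]), np.array([1., 3.])
print(-float(np.mean((y_true - y_pred) ** 2)))  # -0.5, what _score returns here
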
    def fit(self, X, y):
        """Fit the model using cross-validation, then refit on the full data."""
        if not hasattr(self.penalty, "alpha"):
            raise ValueError(
                "GeneralizedLinearEstimatorCV only supports penalties with 'alpha'."
            )
        y = np.asarray(y)
        n_samples, n_features = X.shape
        rng = np.random.RandomState(self.random_state)

        if self.alphas is not None:
            alphas = np.sort(self.alphas)[::-1]
        else:
            # Lasso-type rule: alpha_max is the smallest alpha with an all-zero
            # solution. Explore a geometric grid from alpha_max down to
            # alpha_max * eps, in decreasing order so warm starts follow the path.
            alpha_max = np.max(np.abs(X.T @ y)) / n_samples
            alphas = np.geomspace(alpha_max, alpha_max * self.eps, self.n_alphas)
        has_l1_ratio = hasattr(self.penalty, "l1_ratio")
        l1_ratios = [1.] if not has_l1_ratio else np.atleast_1d(
            self.l1_ratio if self.l1_ratio is not None else [1.])

        # Use the same folds for every (l1_ratio, alpha) pair so results are
        # comparable across the grid and warm starts stay within a fold.
        folds = list(_kfold_split(n_samples, self.cv, rng))
        mse_path = np.empty((len(l1_ratios), len(alphas), self.cv))
        best_loss = np.inf

        def _solve_fold(k, train, test, alpha, l1, w_start):
            # Rebuild the penalty with the candidate hyperparameters.
            pen_kwargs = {name: value for name, value in self.penalty.__dict__.items()
                          if name not in ("alpha", "l1_ratio")}
            if has_l1_ratio:
                pen_kwargs['l1_ratio'] = l1
            pen = type(self.penalty)(alpha=alpha, **pen_kwargs)

            kw = dict(X=X[train], y=y[train], datafit=self.datafit, penalty=pen)
            if 'w' in self.solver.solve.__code__.co_varnames:
                kw['w'] = w_start  # warm start from the previous alpha
            w = self.solver.solve(**kw)
            w = w[0] if isinstance(w, tuple) else w

            coef, intercept = (w[:n_features], w[n_features]
                               ) if w.size == n_features + 1 else (w, 0.0)

            y_pred = X[test] @ coef + intercept
            if isinstance(self.datafit, (Logistic, QuadraticSVC)):
                # Threshold decision values so they are comparable with labels
                # (mirrors predict below).
                y_pred = (y_pred > 0).astype(int)
            # _score is 'higher is better'; negate so mse_path_ stores a loss.
            return w, -self._score(y[test], y_pred)

        for idx_ratio, l1_ratio in enumerate(l1_ratios):
            warm_start = [None] * self.cv

            for idx_alpha, alpha in enumerate(alphas):
                fold_results = Parallel(self.n_jobs)(
                    delayed(_solve_fold)(k, tr, te, alpha, l1_ratio, warm_start[k])
                    for k, (tr, te) in enumerate(folds)
                )

                for k, (w_fold, loss_fold) in enumerate(fold_results):
                    warm_start[k] = w_fold
                    mse_path[idx_ratio, idx_alpha, k] = loss_fold

                mean_loss = np.mean(mse_path[idx_ratio, idx_alpha])
                if mean_loss < best_loss:
                    best_loss = mean_loss
                    self.alpha_ = float(alpha)
                    self.l1_ratio_ = float(l1_ratio) if has_l1_ratio else None

        # Refit on the full dataset with the selected hyperparameters.
        self.penalty.alpha = self.alpha_
        if has_l1_ratio:
            self.penalty.l1_ratio = self.l1_ratio_
        super().fit(X, y)
        self.alphas_ = alphas
        self.mse_path_ = mse_path
        return self

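For reference, a standalone sketch of the default grid rule used above (random data for illustration): alpha_max is the usual Lasso-type bound max|X^T y| / n_samples, and with eps=1e-3 the grid spans three decades.

X = np.random.RandomState(0).randn(20, 5)
y = X @ np.array([1., 0., 0., -2., 0.])
alpha_max = np.max(np.abs(X.T @ y)) / X.shape[0]
alphas = np.geomspace(alpha_max, alpha_max * 1e-3, 100)  # decreasing, for warm starts
print(round(alphas[0] / alphas[-1]))  # 1000
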
    def predict(self, X):
        """Predict using the linear model."""
        X = np.asarray(X)
        if isinstance(self.datafit, (Logistic, QuadraticSVC)):
            # Classification: threshold the decision function at zero.
            return (X @ self.coef_ + self.intercept_ > 0).astype(int)
        return X @ self.coef_ + self.intercept_

    def predict_proba(self, X):
        """Probability estimates for classification tasks."""
        if not isinstance(self.datafit, (Logistic, QuadraticSVC)):
            raise AttributeError(
                "predict_proba is only available for classification tasks"
            )
        X = np.asarray(X)
        decision = X @ self.coef_ + self.intercept_
        # Stack negative and positive scores, then map through softmax to get
        # one column per class.
        decision_2d = np.c_[-decision, decision]
        return softmax(decision_2d, copy=False)

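A minimal check of the two-column construction above (made-up decision values): softmax over [-d, d] yields complementary class probabilities, one row per sample.

decision = np.array([-2.0, 0.0, 3.0])
proba = softmax(np.c_[-decision, decision], copy=False)
print(proba.sum(axis=1))  # [1. 1. 1.]
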
    def score(self, X, y):
        """Return a 'higher = better' performance metric."""
        # _score already follows the 'higher is better' convention
        # (accuracy or negative MSE), so return it unnegated.
        return self._score(np.asarray(y), self.predict(X))
