Skip to content

Commit fefc1a8

Browse files
adress further comments from mathurin
1 parent ed6f5f9 commit fefc1a8

File tree

1 file changed

+39
-37
lines changed

1 file changed

+39
-37
lines changed

skglm/cv.py

Lines changed: 39 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import numpy as np
22
from joblib import Parallel, delayed
3-
from sklearn.utils.extmath import softmax
43
from skglm.datafits import Logistic, QuadraticSVC
54
from skglm.estimators import GeneralizedLinearEstimator
65

@@ -45,9 +44,9 @@ def fit(self, X, y):
4544
"""Fit the model using cross-validation."""
4645
if not hasattr(self.penalty, "alpha"):
4746
raise ValueError(
48-
"GeneralizedLinearEstimatorCV only supports penalties with 'alpha'."
47+
"GeneralizedLinearEstimatorCV only supports penalties which "
48+
"expose an 'alpha' parameter."
4949
)
50-
y = np.asarray(y)
5150
n_samples, n_features = X.shape
5251
rng = np.random.RandomState(self.random_state)
5352

@@ -67,24 +66,21 @@ def fit(self, X, y):
6766
mse_path = np.empty((len(l1_ratios), len(alphas), self.cv))
6867
best_loss = np.inf
6968

70-
def _solve_fold(k, train, test, alpha, l1, w_start):
69+
def _solve_fold(k, train, test, alpha, l1, w_init):
7170
pen_kwargs = {k: v for k, v in self.penalty.__dict__.items()
7271
if k not in ("alpha", "l1_ratio")}
7372
if has_l1_ratio:
7473
pen_kwargs['l1_ratio'] = l1
7574
pen = type(self.penalty)(alpha=alpha, **pen_kwargs)
7675

77-
kw = dict(X=X[train], y=y[train], datafit=self.datafit, penalty=pen)
78-
if 'w' in self.solver.solve.__code__.co_varnames:
79-
kw['w'] = w_start
80-
w = self.solver.solve(**kw)
81-
w = w[0] if isinstance(w, tuple) else w
82-
83-
coef, intercept = (w[:n_features], w[n_features]
84-
) if w.size == n_features + 1 else (w, 0.0)
85-
86-
y_pred = X[test] @ coef + intercept
87-
return w, self._score(y[test], y_pred)
76+
est = GeneralizedLinearEstimator(
77+
datafit=self.datafit, penalty=pen, solver=self.solver
78+
)
79+
est.penalty.alpha = alpha
80+
est.solver.warm_start = True
81+
est.fit(X[train], y[train])
82+
y_pred = est.predict(X[test])
83+
return est.coef_, est.intercept_, self._score(y[test], y_pred)
8884

8985
for idx_ratio, l1_ratio in enumerate(l1_ratios):
9086
warm_start = [None] * self.cv
@@ -95,8 +91,9 @@ def _solve_fold(k, train, test, alpha, l1, w_start):
9591
for k, (tr, te) in enumerate(_kfold_split(n_samples, self.cv, rng))
9692
)
9793

98-
for k, (w_fold, loss_fold) in enumerate(fold_results):
99-
warm_start[k] = w_fold
94+
for k, (coef_fold, intercept_fold, loss_fold) in \
95+
enumerate(fold_results):
96+
warm_start[k] = (coef_fold, intercept_fold)
10097
mse_path[idx_ratio, idx_alpha, k] = loss_fold
10198

10299
mean_loss = np.mean(mse_path[idx_ratio, idx_alpha])
@@ -106,32 +103,37 @@ def _solve_fold(k, train, test, alpha, l1, w_start):
106103
self.l1_ratio_ = float(l1_ratio) if l1_ratio is not None else None
107104

108105
# Refit on full dataset
109-
self.penalty.alpha = self.alpha_
106+
pen_kwargs = {k: v for k, v in self.penalty.__dict__.items()
107+
if k not in ("alpha", "l1_ratio")}
110108
if hasattr(self.penalty, "l1_ratio"):
111-
self.penalty.l1_ratio = self.l1_ratio_
112-
super().fit(X, y)
109+
best_penalty = type(self.penalty)(
110+
alpha=self.alpha_, l1_ratio=self.l1_ratio_, **pen_kwargs
111+
)
112+
else:
113+
best_penalty = type(self.penalty)(
114+
alpha=self.alpha_, **pen_kwargs
115+
)
116+
best_estimator = GeneralizedLinearEstimator(
117+
datafit=self.datafit,
118+
penalty=best_penalty,
119+
solver=self.solver
120+
)
121+
best_estimator.fit(X, y)
122+
self.best_estimator_ = best_estimator
123+
self.coef_ = best_estimator.coef_
124+
self.intercept_ = best_estimator.intercept_
125+
self.n_iter_ = getattr(best_estimator, "n_iter_", None)
126+
self.n_features_in_ = getattr(best_estimator, "n_features_in_", None)
127+
self.feature_names_in_ = getattr(best_estimator, "feature_names_in_", None)
113128
self.alphas_ = alphas
114-
self.mse_path_ = mse_path
129+
self.mse_path_ = np.squeeze(mse_path)
115130
return self
116131

117132
def predict(self, X):
118-
"""Predict using the linear model."""
119-
X = np.asarray(X)
120-
if isinstance(self.datafit, (Logistic, QuadraticSVC)):
121-
return (X @ self.coef_ + self.intercept_ > 0).astype(int)
122-
return X @ self.coef_ + self.intercept_
133+
return self.best_estimator_.predict(X)
123134

124135
def predict_proba(self, X):
125-
"""Probability estimates for classification tasks."""
126-
if not isinstance(self.datafit, (Logistic, QuadraticSVC)):
127-
raise AttributeError(
128-
"predict_proba is only available for classification tasks"
129-
)
130-
X = np.asarray(X)
131-
decision = X @ self.coef_ + self.intercept_
132-
decision_2d = np.c_[-decision, decision]
133-
return softmax(decision_2d, copy=False)
136+
return self.best_estimator_.predict_proba(X)
134137

135138
def score(self, X, y):
136-
"""Return a 'higher = better' performance metric."""
137-
return -self._score(np.asarray(y), self.predict(X))
139+
return self.best_estimator_.score(X, y)

0 commit comments

Comments
 (0)