
Commit f8be785

change to sklearn KFold()
1 parent 6288205 commit f8be785

File tree: 1 file changed (+10, -16 lines)


skglm/cv.py

Lines changed: 10 additions & 16 deletions
@@ -2,20 +2,7 @@
 from joblib import Parallel, delayed
 from skglm.datafits import Logistic, QuadraticSVC
 from skglm.estimators import GeneralizedLinearEstimator
-
-
-def _kfold_split(n_samples, k, rng):
-    indices = rng.permutation(n_samples)
-    fold_size = n_samples // k
-    extra = n_samples % k
-
-    start = 0
-    for i in range(k):
-        end = start + fold_size + (1 if i < extra else 0)
-        test = indices[start:end]
-        train = np.concatenate([indices[:start], indices[end:]])
-        yield train, test
-        start = end
+from sklearn.model_selection import KFold, StratifiedKFold
 
 
 class GeneralizedLinearEstimatorCV(GeneralizedLinearEstimator):
@@ -48,7 +35,6 @@ def fit(self, X, y):
                 "expose an 'alpha' parameter."
             )
         n_samples, n_features = X.shape
-        rng = np.random.RandomState(self.random_state)
 
         if self.alphas is not None:
             alphas = np.sort(self.alphas)[::-1]
@@ -86,9 +72,17 @@ def _solve_fold(k, train, test, alpha, l1, w_init):
         warm_start = [None] * self.cv
 
         for idx_alpha, alpha in enumerate(alphas):
+            if isinstance(self.datafit, (Logistic, QuadraticSVC)):
+                kf = StratifiedKFold(n_splits=self.cv, shuffle=True,
+                                     random_state=self.random_state)
+                split_iter = kf.split(np.arange(n_samples), y)
+            else:
+                kf = KFold(n_splits=self.cv, shuffle=True,
+                           random_state=self.random_state)
+                split_iter = kf.split(np.arange(n_samples))
             fold_results = Parallel(self.n_jobs)(
                 delayed(_solve_fold)(k, tr, te, alpha, l1_ratio, warm_start[k])
-                for k, (tr, te) in enumerate(_kfold_split(n_samples, self.cv, rng))
+                for k, (tr, te) in enumerate(split_iter)
             )
 
             for k, (coef_fold, intercept_fold, loss_fold) in \
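For reference, a self-contained sketch of the swap (the helper is copied from the removed lines; `n_folds` and `seed` are illustrative names, not from the file). Both the old generator and sklearn's KFold yield (train, test) index arrays and both hand the n_samples % k remainder to the first folds, so the enumerate(...) loop in fit only changes its iterator:

import numpy as np
from sklearn.model_selection import KFold

n_samples, n_folds, seed = 7, 3, 42

# Removed helper, reproduced from the old side of the diff: yields
# (train, test) index arrays over a random permutation, spreading the
# n_samples % k remainder over the first `extra` folds.
def _kfold_split(n_samples, k, rng):
    indices = rng.permutation(n_samples)
    fold_size = n_samples // k
    extra = n_samples % k

    start = 0
    for i in range(k):
        end = start + fold_size + (1 if i < extra else 0)
        test = indices[start:end]
        train = np.concatenate([indices[:start], indices[end:]])
        yield train, test
        start = end

rng = np.random.RandomState(seed)
for k, (tr, te) in enumerate(_kfold_split(n_samples, n_folds, rng)):
    print(f"manual fold {k}: test size = {te.size}")   # sizes 3, 2, 2

# sklearn replacement: same (train, test) contract and the same
# remainder rule (the first n_samples % n_splits folds are one larger).
kf = KFold(n_splits=n_folds, shuffle=True, random_state=seed)
for k, (tr, te) in enumerate(kf.split(np.arange(n_samples))):
    print(f"sklearn fold {k}: test size = {te.size}")  # sizes 3, 2, 2

Note the exact fold memberships differ (the two shufflers consume the seed differently); only the splitting contract is preserved.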

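Finally, a minimal sketch of why the Logistic and QuadraticSVC datafits are routed to StratifiedKFold (the toy labels are illustrative only): plain shuffled KFold splits on indices and ignores y, so a rare class can be missing from a test fold entirely, while StratifiedKFold keeps each fold's class proportions close to the overall ones.

import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

# Toy imbalanced binary labels: 8 negatives, 2 positives.
n_samples = 10
y = np.array([0] * 8 + [1] * 2)

# Plain KFold splits on indices only: a test fold may end up with
# no positive samples at all.
kf = KFold(n_splits=2, shuffle=True, random_state=0)
for k, (train, test) in enumerate(kf.split(np.arange(n_samples))):
    print(f"KFold fold {k}: test labels = {y[test]}")

# StratifiedKFold also uses y: each test fold gets one of the two
# positives, matching the overall 80/20 proportion.
skf = StratifiedKFold(n_splits=2, shuffle=True, random_state=0)
for k, (train, test) in enumerate(skf.split(np.arange(n_samples), y)):
    print(f"StratifiedKFold fold {k}: test labels = {y[test]}")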