|
2 | 2 | from joblib import Parallel, delayed |
3 | 3 | from skglm.datafits import Logistic, QuadraticSVC |
4 | 4 | from skglm.estimators import GeneralizedLinearEstimator |
5 | | - |
6 | | - |
7 | | -def _kfold_split(n_samples, k, rng): |
8 | | - indices = rng.permutation(n_samples) |
9 | | - fold_size = n_samples // k |
10 | | - extra = n_samples % k |
11 | | - |
12 | | - start = 0 |
13 | | - for i in range(k): |
14 | | - end = start + fold_size + (1 if i < extra else 0) |
15 | | - test = indices[start:end] |
16 | | - train = np.concatenate([indices[:start], indices[end:]]) |
17 | | - yield train, test |
18 | | - start = end |
| 5 | +from sklearn.model_selection import KFold, StratifiedKFold |
19 | 6 |
|
20 | 7 |
|
21 | 8 | class GeneralizedLinearEstimatorCV(GeneralizedLinearEstimator): |
@@ -48,7 +35,6 @@ def fit(self, X, y): |
48 | 35 | "expose an 'alpha' parameter." |
49 | 36 | ) |
50 | 37 | n_samples, n_features = X.shape |
51 | | - rng = np.random.RandomState(self.random_state) |
52 | 38 |
|
53 | 39 | if self.alphas is not None: |
54 | 40 | alphas = np.sort(self.alphas)[::-1] |
@@ -86,9 +72,17 @@ def _solve_fold(k, train, test, alpha, l1, w_init): |
86 | 72 | warm_start = [None] * self.cv |
87 | 73 |
|
88 | 74 | for idx_alpha, alpha in enumerate(alphas): |
| 75 | + if isinstance(self.datafit, (Logistic, QuadraticSVC)): |
| 76 | + kf = StratifiedKFold(n_splits=self.cv, shuffle=True, |
| 77 | + random_state=self.random_state) |
| 78 | + split_iter = kf.split(np.arange(n_samples), y) |
| 79 | + else: |
| 80 | + kf = KFold(n_splits=self.cv, shuffle=True, |
| 81 | + random_state=self.random_state) |
| 82 | + split_iter = kf.split(np.arange(n_samples)) |
89 | 83 | fold_results = Parallel(self.n_jobs)( |
90 | 84 | delayed(_solve_fold)(k, tr, te, alpha, l1_ratio, warm_start[k]) |
91 | | - for k, (tr, te) in enumerate(_kfold_split(n_samples, self.cv, rng)) |
| 85 | + for k, (tr, te) in enumerate(split_iter) |
92 | 86 | ) |
93 | 87 |
|
94 | 88 | for k, (coef_fold, intercept_fold, loss_fold) in \ |
|
0 commit comments