diff --git a/mlxtend/feature_selection/sequential_feature_selector.py b/mlxtend/feature_selection/sequential_feature_selector.py
index e907a4ab1..6879c127e 100644
--- a/mlxtend/feature_selection/sequential_feature_selector.py
+++ b/mlxtend/feature_selection/sequential_feature_selector.py
@@ -18,15 +18,19 @@
 from sklearn.base import BaseEstimator
 from sklearn.base import MetaEstimatorMixin
 from ..externals.name_estimators import _name_estimators
-from sklearn.model_selection import cross_val_score
+from sklearn.model_selection import cross_val_score, RepeatedKFold, RepeatedStratifiedKFold
 from sklearn.externals.joblib import Parallel, delayed
 
-
 def _calc_score(selector, X, y, indices):
     if selector.cv:
+        if selector.n_cv_repeats > 0:
+            if selector.est_._estimator_type == 'classifier':
+                cv_folds = RepeatedStratifiedKFold(n_splits=selector.cv, n_repeats=selector.n_cv_repeats)
+            else:
+                cv_folds = RepeatedKFold(n_splits=selector.cv, n_repeats=selector.n_cv_repeats)
         scores = cross_val_score(selector.est_,
                                  X[:, indices], y,
-                                 cv=selector.cv,
+                                 cv=selector.cv if selector.n_cv_repeats == 0 else cv_folds,
                                  scoring=selector.scorer,
                                  n_jobs=1,
                                  pre_dispatch=selector.pre_dispatch)
@@ -103,6 +107,11 @@ class SequentialFeatureSelector(BaseEstimator, MetaEstimatorMixin):
         if False. Set to False if the estimator doesn't
         implement scikit-learn's set_params and get_params methods.
         In addition, it is required to set cv=0, and n_jobs=1.
+    n_cv_repeats : int (default = 0)
+        Number of times cross-validation is repeated; 0 disables repetition
+        and negative values raise an exception. Cannot be combined with cv=0.
+        Uses RepeatedStratifiedKFold for classifiers, RepeatedKFold otherwise.
+
 
     Attributes
     ----------
@@ -125,7 +134,8 @@ def __init__(self, estimator, k_features=1,
                  verbose=0, scoring=None,
                  cv=5, n_jobs=1,
                  pre_dispatch='2*n_jobs',
-                 clone_estimator=True):
+                 clone_estimator=True,
+                 n_cv_repeats=0):
 
         self.estimator = estimator
         self.k_features = k_features
@@ -149,6 +159,13 @@ def __init__(self, estimator, k_features=1,
             self.est_ = self.estimator
         self.scoring = scoring
 
+        self.n_cv_repeats = n_cv_repeats
+        if self.n_cv_repeats < 0:
+            raise AttributeError('Number of cross-validation repeats should be >= 0.')
+        if not self.cv and self.n_cv_repeats > 0:
+            raise AttributeError('Cannot repeat cross-validation when it\'s set to 0.')
+
+
         if scoring is None:
             if self.est_._estimator_type == 'classifier':
                 scoring = 'accuracy'