Skip to content

Commit 33d2b21

Browse files
author
Guillaume Lemaitre
committed
Merge branch 'deprecation_warning' of https://github.com/dvro/UnbalancedDataset into pr/196
Conflicts: imblearn/under_sampling/instance_hardness_threshold.py
2 parents c747362 + c8260ed commit 33d2b21

File tree

1 file changed

+16
-7
lines changed

1 file changed

+16
-7
lines changed

imblearn/under_sampling/instance_hardness_threshold.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,30 @@
33
from __future__ import division, print_function
44

55
import warnings
6-
76
from collections import Counter
87

98
import numpy as np
10-
9+
from six import string_types
10+
import sklearn
1111
from sklearn.base import ClassifierMixin
1212
from sklearn.ensemble import RandomForestClassifier
13-
from sklearn.cross_validation import StratifiedKFold
14-
15-
from six import string_types
1613

1714
from ..base import BaseBinarySampler
1815

1916

17+
def _get_cv_splits(X, y, cv, random_state):
18+
if hasattr(sklearn, 'model_selection'):
19+
from sklearn.model_selection import StratifiedKFold
20+
cv_iterator = StratifiedKFold(
21+
n_splits=cv, shuffle=False, random_state=random_state).split(X, y)
22+
else:
23+
from sklearn.cross_validation import StratifiedKFold
24+
cv_iterator = StratifiedKFold(
25+
y, n_folds=cv, shuffle=False, random_state=random_state)
26+
27+
return cv_iterator
28+
29+
2030
class InstanceHardnessThreshold(BaseBinarySampler):
2131
"""Class to perform under-sampling based on the instance hardness
2232
threshold.
@@ -225,8 +235,7 @@ def _sample(self, X, y):
225235
"""
226236

227237
# Create the different folds
228-
skf = StratifiedKFold(
229-
y, n_folds=self.cv, shuffle=False, random_state=self.random_state)
238+
skf = _get_cv_splits(X, y, self.cv, self.random_state)
230239

231240
probabilities = np.zeros(y.shape[0], dtype=float)
232241

0 commit comments

Comments
 (0)