Skip to content

Commit 4225dff

Browse files
author
chkoar
committed
using a function
2 parents 62f6d2f + 2b653d0 commit 4225dff

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

imblearn/under_sampling/instance_hardness_threshold.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,28 @@
33
from __future__ import division, print_function
44

55
import warnings
6-
76
from collections import Counter
87

98
import numpy as np
10-
9+
from six import string_types
1110
from sklearn.base import ClassifierMixin
1211
from sklearn.ensemble import RandomForestClassifier
13-
from sklearn.cross_validation import StratifiedKFold
14-
15-
from six import string_types
1612

1713
from ..base import BaseBinarySampler
1814

1915

16+
def _get_cv_splits(X, y, cv, random_state):
17+
try:
18+
from sklearn.model_selection import StratifiedKFold
19+
cv_iterator = StratifiedKFold(
20+
n_splits=cv, shuffle=False, random_state=random_state).split(X, y)
21+
except:
22+
from sklearn.cross_validation import StratifiedKFold
23+
cv_iterator = StratifiedKFold(
24+
y, n_folds=cv, shuffle=False, random_state=random_state)
25+
return cv_iterator
26+
27+
2028
class InstanceHardnessThreshold(BaseBinarySampler):
2129
"""Class to perform under-sampling based on the instance hardness
2230
threshold.
@@ -225,8 +233,7 @@ def _sample(self, X, y):
225233
"""
226234

227235
# Create the different folds
228-
skf = StratifiedKFold(y, n_folds=self.cv, shuffle=False,
229-
random_state=self.random_state)
236+
skf = _get_cv_splits(X, y, self.cv, self.random_state)
230237

231238
probabilities = np.zeros(y.shape[0], dtype=float)
232239

0 commit comments

Comments
 (0)