Skip to content

Commit 265f653

Browse files
committed
do not use mixin but implement functionality
1 parent d544dc4 commit 265f653

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

imblearn/model_selection/_split.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1+
import warnings
2+
13
import numpy as np
24
from sklearn.base import clone
35
from sklearn.model_selection import LeaveOneGroupOut, cross_val_predict
4-
from sklearn.model_selection._split import BaseCrossValidator, _UnsupportedGroupCVMixin
6+
from sklearn.model_selection._split import BaseCrossValidator
57
from sklearn.utils.multiclass import type_of_target
68
from sklearn.utils.validation import _num_samples
79

810

9-
class InstanceHardnessCV(_UnsupportedGroupCVMixin, BaseCrossValidator):
11+
class InstanceHardnessCV(BaseCrossValidator):
1012
"""Instance-hardness cross-validation splitter.
1113
1214
Cross-validation splitter that distributes samples with large instance hardness
@@ -72,6 +74,12 @@ def split(self, X, y, groups=None):
7274
test : ndarray
7375
The testing set indices for that split.
7476
"""
77+
if groups is not None:
78+
warnings.warn(
79+
f"The groups parameter is ignored by {self.__class__.__name__}",
80+
UserWarning,
81+
)
82+
7583
classes = np.unique(y)
7684
y_type = type_of_target(y)
7785
if y_type != "binary":

imblearn/model_selection/tests/test_split.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,16 @@ def data():
2121
)
2222

2323

24+
def test_groups_parameter_warning(data):
25+
"""Test that a warning is raised when groups parameter is provided."""
26+
X, y = data
27+
ih_cv = InstanceHardnessCV(estimator=LogisticRegression())
28+
29+
warning_msg = "The groups parameter is ignored by InstanceHardnessCV"
30+
with pytest.warns(UserWarning, match=warning_msg):
31+
list(ih_cv.split(X, y, groups=np.ones_like(y)))
32+
33+
2434
def test_error_on_multiclass():
2535
"""Test that an error is raised when the target is not binary."""
2636
X, y = make_classification(n_classes=3, n_clusters_per_class=1)

0 commit comments

Comments
 (0)