Skip to content

Commit e9684c1

Browse files
committed
fix tests
1 parent 265f653 commit e9684c1

File tree

2 files changed: 33 additions and 14 deletions

2 files changed: 33 additions and 14 deletions

imblearn/model_selection/_split.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def split(self, X, y, groups=None):
 94  94             )
 95  95             # sorting first on y and then by the instance hardness
 96  96             sorted_indices = np.lexsort((y_proba[:, pos_label], y))
 97      -          groups = np.zeros(len(X), dtype=int)
     97  +          groups = np.empty(_num_samples(X), dtype=int)
 98  98             groups[sorted_indices] = np.arange(_num_samples(X)) % self.n_splits
 99  99             cv = LeaveOneGroupOut()
100 100             for train_index, test_index in cv.split(X, y, groups):

imblearn/model_selection/tests/test_split.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
  2   2     import pytest
  3   3     from sklearn.datasets import make_classification
  4   4     from sklearn.linear_model import LogisticRegression
      5  +  from sklearn.metrics import make_scorer, precision_score
  5   6     from sklearn.model_selection import cross_validate
  6   7     from sklearn.utils._testing import assert_allclose
  7   8

@@ -11,20 +12,20 @@
 11  12     @pytest.fixture
 12  13     def data():
 13  14         return make_classification(
 14      -          weights=[0.9, 0.1],
 15      -          class_sep=2,
     15  +          weights=[0.5, 0.5],
     16  +          class_sep=0.5,
 16  17             n_informative=3,
 17  18             n_redundant=1,
 18  19             flip_y=0.05,
 19      -          n_samples=1000,
     20  +          n_samples=50,
 20  21             random_state=10,
 21  22         )
 22  23
 23  24
 24  25     def test_groups_parameter_warning(data):
 25  26         """Test that a warning is raised when groups parameter is provided."""
 26  27         X, y = data
 27      -      ih_cv = InstanceHardnessCV(estimator=LogisticRegression())
     28  +      ih_cv = InstanceHardnessCV(estimator=LogisticRegression(), n_splits=3)
 28  29
 29  30         warning_msg = "The groups parameter is ignored by InstanceHardnessCV"
 30  31         with pytest.warns(UserWarning, match=warning_msg):
@@ -42,9 +43,11 @@ def test_error_on_multiclass():
 42  43     def test_default_params(data):
 43  44         """Test that the default parameters are used."""
 44  45         X, y = data
 45      -      ih_cv = InstanceHardnessCV(estimator=LogisticRegression())
 46      -      cv_result = cross_validate(LogisticRegression(), X, y, cv=ih_cv)
 47      -      assert_allclose(cv_result["test_score"], [0.975, 0.965, 0.96, 0.955, 0.965])
     46  +      ih_cv = InstanceHardnessCV(estimator=LogisticRegression(), n_splits=3)
     47  +      cv_result = cross_validate(
     48  +          LogisticRegression(), X, y, cv=ih_cv, scoring="precision"
     49  +      )
     50  +      assert_allclose(cv_result["test_score"], [0.625, 0.6, 0.625], atol=1e-6, rtol=1e-6)
 48  51
 49  52
 50  53     @pytest.mark.parametrize("dtype_target", [None, object])
@@ -53,9 +56,15 @@ def test_target_string_labels(data, dtype_target):
 53  56         X, y = data
 54  57         labels = np.array(["a", "b"], dtype=dtype_target)
 55  58         y = labels[y]
 56      -      ih_cv = InstanceHardnessCV(estimator=LogisticRegression())
 57      -      cv_result = cross_validate(LogisticRegression(), X, y, cv=ih_cv)
 58      -      assert_allclose(cv_result["test_score"], [0.975, 0.965, 0.96, 0.955, 0.965])
     59  +      ih_cv = InstanceHardnessCV(estimator=LogisticRegression(), n_splits=3)
     60  +      cv_result = cross_validate(
     61  +          LogisticRegression(),
     62  +          X,
     63  +          y,
     64  +          cv=ih_cv,
     65  +          scoring=make_scorer(precision_score, pos_label="b"),
     66  +      )
     67  +      assert_allclose(cv_result["test_score"], [0.625, 0.6, 0.625], atol=1e-6, rtol=1e-6)
 59  68
 60  69
 61  70     @pytest.mark.parametrize("dtype_target", [None, object])
@@ -68,9 +77,19 @@ def test_target_string_pos_label(data, dtype_target):
 68  77         X, y = data
 69  78         labels = np.array(["a", "b"], dtype=dtype_target)
 70  79         y = labels[y]
 71      -      ih_cv = InstanceHardnessCV(estimator=LogisticRegression(), pos_label="a")
 72      -      cv_result = cross_validate(LogisticRegression(), X, y, cv=ih_cv)
 73      -      assert_allclose(cv_result["test_score"], [0.965, 0.975, 0.965, 0.955, 0.96])
     80  +      ih_cv = InstanceHardnessCV(
     81  +          estimator=LogisticRegression(), pos_label="a", n_splits=3
     82  +      )
     83  +      cv_result = cross_validate(
     84  +          LogisticRegression(),
     85  +          X,
     86  +          y,
     87  +          cv=ih_cv,
     88  +          scoring=make_scorer(precision_score, pos_label="a"),
     89  +      )
     90  +      assert_allclose(
     91  +          cv_result["test_score"], [0.666667, 0.666667, 0.4], atol=1e-6, rtol=1e-6
     92  +      )
 74  93
 75  94
 76  95     @pytest.mark.parametrize("n_splits", [2, 3, 4])

0 commit comments

Comments
 (0)