Skip to content

Commit d9ba4af

Browse files
authored
TST create sparse and dataframe tags (#803)
1 parent 6155658 commit d9ba4af

File tree

4 files changed

+32
-14
lines changed

4 files changed

+32
-14
lines changed

imblearn/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,9 @@ def _check_X_y(self, X, y, accept_sparse=None):
130130
X, y = self._validate_data(X, y, reset=True, accept_sparse=accept_sparse)
131131
return X, y, binarize_y
132132

133+
def _more_tags(self):
134+
return {"X_types": ["2darray", "sparse", "dataframe"]}
135+
133136

134137
def _identity(X, y):
135138
return X, y

imblearn/over_sampling/_random_over_sampler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ def _fit_resample(self, X, y):
241241

242242
def _more_tags(self):
243243
return {
244-
"X_types": ["2darray", "string"],
244+
"X_types": ["2darray", "string", "sparse", "dataframe"],
245245
"sample_indices": True,
246246
"allow_nan": True,
247247
}

imblearn/under_sampling/_prototype_selection/_random_under_sampler.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,10 @@ def _fit_resample(self, X, y):
108108
index_target_class = slice(None)
109109

110110
idx_under = np.concatenate(
111-
(idx_under, np.flatnonzero(y == target_class)[index_target_class],),
111+
(
112+
idx_under,
113+
np.flatnonzero(y == target_class)[index_target_class],
114+
),
112115
axis=0,
113116
)
114117

@@ -118,7 +121,7 @@ def _fit_resample(self, X, y):
118121

119122
def _more_tags(self):
120123
return {
121-
"X_types": ["2darray", "string"],
124+
"X_types": ["2darray", "string", "sparse", "dataframe"],
122125
"sample_indices": True,
123126
"allow_nan": True,
124127
}

imblearn/utils/estimator_checks.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -51,16 +51,22 @@ def _set_checking_parameters(estimator):
5151

5252

5353
def _yield_sampler_checks(sampler):
54+
tags = sampler._get_tags()
5455
yield check_target_type
5556
yield check_samplers_one_label
5657
yield check_samplers_fit
5758
yield check_samplers_fit_resample
5859
yield check_samplers_sampling_strategy_fit_resample
59-
yield check_samplers_sparse
60-
yield check_samplers_pandas
60+
if "sparse" in tags["X_types"]:
61+
yield check_samplers_sparse
62+
if "dataframe" in tags["X_types"]:
63+
yield check_samplers_pandas
6164
yield check_samplers_list
6265
yield check_samplers_multiclass_ova
6366
yield check_samplers_preserve_dtype
67+
# we don't filter samplers based on their tag here because we want to make
68+
# sure that the fitted attribute does not exist if the tag is not
69+
# stipulated
6470
yield check_samplers_sample_indices
6571
yield check_samplers_2d_target
6672

@@ -75,7 +81,8 @@ def _yield_all_checks(estimator):
7581
tags = estimator._get_tags()
7682
if tags["_skip_test"]:
7783
warnings.warn(
78-
f"Explicit SKIP via _skip_test tag for estimator {name}.", SkipTestWarning,
84+
f"Explicit SKIP via _skip_test tag for estimator {name}.",
85+
SkipTestWarning,
7986
)
8087
return
8188
# trigger our checks if this is a SamplerMixin
@@ -116,6 +123,7 @@ def parametrize_with_checks(estimators):
116123
... def test_sklearn_compatible_estimator(estimator, check):
117124
... check(estimator)
118125
"""
126+
119127
def checks_generator():
120128
for estimator in estimators:
121129
name = type(estimator).__name__
@@ -124,9 +132,7 @@ def checks_generator():
124132
yield _maybe_mark_xfail(estimator, check, pytest)
125133

126134
return pytest.mark.parametrize(
127-
"estimator, check",
128-
checks_generator(),
129-
ids=_get_check_estimator_ids
135+
"estimator, check", checks_generator(), ids=_get_check_estimator_ids
130136
)
131137

132138

@@ -137,14 +143,22 @@ def check_target_type(name, estimator_orig):
137143
y = np.linspace(0, 1, 20)
138144
msg = "Unknown label type: 'continuous'"
139145
assert_raises_regex(
140-
ValueError, msg, estimator.fit_resample, X, y,
146+
ValueError,
147+
msg,
148+
estimator.fit_resample,
149+
X,
150+
y,
141151
)
142152
# if the target is multilabel then we should raise an error
143153
rng = np.random.RandomState(42)
144154
y = rng.randint(2, size=(20, 3))
145155
msg = "Multilabel and multioutput targets are not supported."
146156
assert_raises_regex(
147-
ValueError, msg, estimator.fit_resample, X, y,
157+
ValueError,
158+
msg,
159+
estimator.fit_resample,
160+
X,
161+
y,
148162
)
149163

150164

@@ -385,9 +399,7 @@ def check_samplers_sample_indices(name, sampler_orig):
385399
assert not hasattr(sampler, "sample_indices_")
386400

387401

388-
def check_classifier_on_multilabel_or_multioutput_targets(
389-
name, estimator_orig
390-
):
402+
def check_classifier_on_multilabel_or_multioutput_targets(name, estimator_orig):
391403
estimator = clone(estimator_orig)
392404
X, y = make_multilabel_classification(n_samples=30)
393405
msg = "Multilabel and multioutput targets are not supported."

0 commit comments

Comments
 (0)