Skip to content

Commit a5cb58b

Browse files
committed
iter
1 parent 6d1805a commit a5cb58b

File tree

7 files changed

+53
-42
lines changed

7 files changed

+53
-42
lines changed

imblearn/base.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from sklearn.utils.multiclass import check_classification_targets
1313

1414
from .utils import check_sampling_strategy, check_target_type
15-
from .utils._sklearn_compat import _fit_context, validate_data
15+
from .utils._sklearn_compat import _fit_context, get_tags, validate_data
1616
from .utils._validation import ArraysTransformer
1717

1818

@@ -217,7 +217,11 @@ def is_sampler(estimator):
217217
is_sampler : bool
218218
True if estimator is a sampler, otherwise False.
219219
"""
220-
if estimator._estimator_type == "sampler":
220+
221+
if hasattr(estimator, "_estimator_type") and estimator._estimator_type == "sampler":
222+
return True
223+
tags = get_tags(estimator)
224+
if hasattr(tags, "sampler_tags") and tags.sampler_tags is not None:
221225
return True
222226
return False
223227

imblearn/metrics/pairwise.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ class ValueDifferenceMetric(BaseEstimator):
118118
[0.04, 0. , 1.44],
119119
[1.96, 1.44, 0. ]])
120120
"""
121+
121122
_parameter_constraints: dict = {
122123
"n_categories": [StrOptions({"auto"}), "array-like"],
123124
"k": [numbers.Integral],
@@ -150,6 +151,7 @@ def fit(self, X, y):
150151
self._validate_params()
151152
check_consistent_length(X, y)
152153
X, y = validate_data(self, X=X, y=y, reset=True, dtype=np.int32)
154+
X = check_array(X, ensure_non_negative=True)
153155

154156
if isinstance(self.n_categories, str) and self.n_categories == "auto":
155157
# categories are expected to be encoded from 0 to n_categories - 1
@@ -208,11 +210,11 @@ def pairwise(self, X, Y=None):
208210
The VDM pairwise distance.
209211
"""
210212
check_is_fitted(self)
211-
X = check_array(X, dtype=np.int32)
213+
X = check_array(X, ensure_non_negative=True, dtype=np.int32)
212214
n_samples_X = X.shape[0]
213215

214216
if Y is not None:
215-
Y = check_array(Y, dtype=np.int32)
217+
Y = check_array(Y, ensure_non_negative=True, dtype=np.int32)
216218
n_samples_Y = Y.shape[0]
217219
else:
218220
n_samples_Y = n_samples_X

imblearn/tests/test_common.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,13 @@
1212
import sklearn
1313
from sklearn.exceptions import ConvergenceWarning
1414
from sklearn.utils._testing import ignore_warnings
15-
from sklearn.utils.estimator_checks import (
16-
parametrize_with_checks as parametrize_with_checks_sklearn,
17-
)
1815
from sklearn.utils.fixes import parse_version
1916

2017
from imblearn.over_sampling import RandomOverSampler
2118
from imblearn.under_sampling import RandomUnderSampler
19+
from imblearn.utils._sklearn_compat import (
20+
parametrize_with_checks as parametrize_with_checks_sklearn,
21+
)
2222
from imblearn.utils._test_common.instance_generator import (
2323
_get_check_estimator_ids,
2424
_get_expected_failed_checks,

imblearn/tests/test_docstring_parameters.py

Lines changed: 21 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
import pytest
1212
from sklearn.datasets import make_classification
13-
from sklearn.linear_model import LogisticRegression
1413
from sklearn.utils._testing import (
1514
_get_func_name,
1615
check_docstring_parameters,
@@ -24,9 +23,9 @@
2423

2524
import imblearn
2625
from imblearn.base import is_sampler
27-
from imblearn.utils._sklearn_compat import _construct_instances
26+
from imblearn.under_sampling import NearMiss
27+
from imblearn.utils._test_common.instance_generator import _tested_estimators
2828
from imblearn.utils.estimator_checks import _set_checking_parameters
29-
from imblearn.utils.testing import all_estimators
3029

3130
# walk_packages() ignores DeprecationWarnings, now we need to ignore
3231
# FutureWarnings
@@ -43,10 +42,10 @@
4342
)
4443

4544
# functions to ignore args / docstring of
46-
_DOCSTRING_IGNORES = [
47-
"RUSBoostClassifier", # TODO remove after releasing scikit-learn 1.0.1
48-
"ValueDifferenceMetric",
49-
]
45+
_DOCSTRING_IGNORES = ["ValueDifferenceMetric"]
46+
_IGNORE_ATTRIBUTES = {
47+
NearMiss: ["nn_ver3_"],
48+
}
5049

5150
# Methods where y param should be ignored if y=None by default
5251
_METHODS_IGNORE_NONE_Y = [
@@ -159,28 +158,19 @@ def test_tabs():
159158
)
160159

161160

162-
def _construct_compose_pipeline_instance(Estimator):
163-
# Minimal / degenerate instances: only useful to test the docstrings.
164-
if Estimator.__name__ == "Pipeline":
165-
return Estimator(steps=[("clf", LogisticRegression())])
166-
167-
168-
@pytest.mark.parametrize("name, Estimator", all_estimators())
169-
def test_fit_docstring_attributes(name, Estimator):
161+
@pytest.mark.parametrize("estimator", list(_tested_estimators()))
162+
def test_fit_docstring_attributes(estimator):
170163
pytest.importorskip("numpydoc")
171164
from numpydoc import docscrape
172165

166+
Estimator = estimator.__class__
173167
if Estimator.__name__ in _DOCSTRING_IGNORES:
174168
return
175169

176170
doc = docscrape.ClassDoc(Estimator)
177171
attributes = doc["Attributes"]
178172

179-
if Estimator.__name__ == "Pipeline":
180-
est = _construct_compose_pipeline_instance(Estimator)
181-
else:
182-
est = next(_construct_instances(Estimator))
183-
_set_checking_parameters(est)
173+
_set_checking_parameters(estimator)
184174

185175
X, y = make_classification(
186176
n_samples=20,
@@ -190,16 +180,16 @@ def test_fit_docstring_attributes(name, Estimator):
190180
random_state=2,
191181
)
192182

193-
y = _enforce_estimator_tags_y(est, y)
194-
X = _enforce_estimator_tags_X(est, X)
183+
y = _enforce_estimator_tags_y(estimator, y)
184+
X = _enforce_estimator_tags_X(estimator, X)
195185

196-
if "oob_score" in est.get_params():
197-
est.set_params(bootstrap=True, oob_score=True)
186+
if "oob_score" in estimator.get_params():
187+
estimator.set_params(bootstrap=True, oob_score=True)
198188

199-
if is_sampler(est):
200-
est.fit_resample(X, y)
189+
if is_sampler(estimator):
190+
estimator.fit_resample(X, y)
201191
else:
202-
est.fit(X, y)
192+
estimator.fit(X, y)
203193

204194
skipped_attributes = set(
205195
[
@@ -218,9 +208,11 @@ def test_fit_docstring_attributes(name, Estimator):
218208
continue
219209
# ignore deprecation warnings
220210
with ignore_warnings(category=FutureWarning):
221-
assert hasattr(est, attr.name)
211+
if attr.name in _IGNORE_ATTRIBUTES.get(Estimator, []):
212+
continue
213+
assert hasattr(estimator, attr.name)
222214

223-
fit_attr = _get_all_fitted_attributes(est)
215+
fit_attr = _get_all_fitted_attributes(estimator)
224216
fit_attr_names = [attr.name for attr in attributes]
225217
undocumented_attrs = set(fit_attr).difference(fit_attr_names)
226218
undocumented_attrs = set(undocumented_attrs).difference(skipped_attributes)

imblearn/under_sampling/_prototype_selection/_nearmiss.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ class NearMiss(BaseUnderSampler):
6262
nn_ : estimator object
6363
Validated K-nearest Neighbours object created from `n_neighbors` parameter.
6464
65+
nn_ver3_ : estimator object
66+
Validated K-nearest Neighbours object created from `n_neighbors_ver3` parameter.
67+
6568
sample_indices_ : ndarray of shape (n_new_samples,)
6669
Indices of the samples selected.
6770

imblearn/utils/_test_common/instance_generator.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
RandomOverSampler: dict(random_state=42),
6868
SMOTE: dict(random_state=42),
6969
SMOTEN: dict(random_state=42),
70+
SMOTENC: dict(categorical_features=[0], random_state=42),
7071
SVMSMOTE: dict(random_state=42),
7172
# under-sampling
7273
ClusterCentroids: dict(random_state=42),
@@ -199,6 +200,8 @@ def _yield_instances_for_check(check, estimator_orig):
199200
PER_ESTIMATOR_XFAIL_CHECKS = {
200201
BalancedRandomForestClassifier: {
201202
"check_sample_weight_equivalence": "FIXME",
203+
"check_sample_weight_equivalence_on_sparse_data": "FIXME",
204+
"check_sample_weight_equivalence_on_dense_data": "FIXME",
202205
},
203206
NearMiss: {
204207
"check_samplers_fit_resample": "FIXME",
@@ -212,9 +215,14 @@ def _yield_instances_for_check(check, estimator_orig):
212215
"Pipeline changes the `steps` parameter, which it shouldn't."
213216
"Therefore this test is x-fail until we fix this."
214217
),
218+
"check_classifiers_train": "FIXME",
219+
"check_supervised_y_2d": "FIXME",
215220
},
216221
RUSBoostClassifier: {
217222
"check_sample_weight_equivalence": "FIXME",
223+
"check_sample_weight_equivalence_on_sparse_data": "FIXME",
224+
"check_sample_weight_equivalence_on_dense_data": "FIXME",
225+
"check_estimator_sparse_matrix": "FIXME",
218226
},
219227
}
220228

imblearn/utils/tests/test_estimator_checks.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from imblearn.base import BaseSampler
77
from imblearn.over_sampling.base import BaseOverSampler
88
from imblearn.utils import check_target_type as target_check
9+
from imblearn.utils._sklearn_compat import validate_data
910
from imblearn.utils.estimator_checks import (
1011
check_samplers_fit,
1112
check_samplers_nan,
@@ -47,15 +48,15 @@ class NotFittedSampler(BaseBadSampler):
4748
"""Sampler without target checking."""
4849

4950
def fit(self, X, y):
50-
X, y = self._validate_data(X, y)
51+
X, y = validate_data(self, X=X, y=y)
5152
return self
5253

5354

5455
class NoAcceptingSparseSampler(BaseBadSampler):
5556
"""Sampler which does not accept sparse matrix."""
5657

5758
def fit(self, X, y):
58-
X, y = self._validate_data(X, y)
59+
X, y = validate_data(self, X=X, y=y)
5960
self.sampling_strategy_ = "sampling_strategy_"
6061
return self
6162

@@ -72,12 +73,13 @@ def _fit_resample(self, X, y):
7273
class IndicesSampler(BaseOverSampler):
7374
def _check_X_y(self, X, y):
7475
y, binarize_y = target_check(y, indicate_one_vs_all=True)
75-
X, y = self._validate_data(
76-
X,
77-
y,
76+
X, y = validate_data(
77+
self,
78+
X=X,
79+
y=y,
7880
reset=True,
7981
dtype=None,
80-
force_all_finite=False,
82+
ensure_all_finite=False,
8183
)
8284
return X, y, binarize_y
8385

0 commit comments

Comments
 (0)