Skip to content

Commit 79cc5e1

Browse files
committed
clean up
1 parent 880f7c7 commit 79cc5e1

File tree

6 files changed: 20 additions (+20) and 146 deletions (-146)

.coveragerc

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -10,3 +10,4 @@ ignore_errors = True
1010
omit =
1111
*/tests/*
1212
**/setup.py
13+
**/_sklearn_compat.py

imblearn/over_sampling/_smote/base.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
import numpy as np
1414
import sklearn
1515
from scipy import sparse
16+
from scipy.stats import mode
1617
from sklearn.base import clone
1718
from sklearn.exceptions import DataConversionWarning
1819
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder
@@ -31,9 +32,8 @@
3132
from ...metrics.pairwise import ValueDifferenceMetric
3233
from ...utils import Substitution, check_neighbors_object, check_target_type
3334
from ...utils._docstring import _n_jobs_docstring, _random_state_docstring
34-
from ...utils._sklearn_compat import validate_data
35+
from ...utils._sklearn_compat import _is_pandas_df, validate_data
3536
from ...utils._validation import _check_X
36-
from ...utils.fixes import _is_pandas_df, _mode
3737
from ..base import BaseOverSampler
3838

3939
sklearn_version = parse_version(sklearn.__version__).base_version
@@ -997,7 +997,8 @@ def _make_samples(self, X_class, klass, y_dtype, nn_indices, n_samples):
997997
# where for each feature individually, each category generated is the
998998
# most common category
999999
X_new = np.squeeze(
1000-
_mode(X_class[nn_indices[samples_indices]], axis=1).mode, axis=1
1000+
mode(X_class[nn_indices[samples_indices]], axis=1, keepdims=True).mode,
1001+
axis=1,
10011002
)
10021003
y_new = np.full(n_samples, fill_value=klass, dtype=y_dtype)
10031004
return X_new, y_new

imblearn/under_sampling/_prototype_selection/_edited_nearest_neighbours.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -10,12 +10,12 @@
1010
from collections import Counter
1111

1212
import numpy as np
13+
from scipy.stats import mode
1314
from sklearn.utils import _safe_indexing
1415
from sklearn.utils._param_validation import HasMethods, Interval, StrOptions
1516

1617
from ...utils import Substitution, check_neighbors_object
1718
from ...utils._docstring import _n_jobs_docstring
18-
from ...utils.fixes import _mode
1919
from ..base import BaseCleaningSampler
2020

2121
SEL_KIND = ("all", "mode")
@@ -168,7 +168,7 @@ def _fit_resample(self, X, y):
168168
nnhood_idx = self.nn_.kneighbors(X_class, return_distance=False)[:, 1:]
169169
nnhood_label = y[nnhood_idx]
170170
if self.kind_sel == "mode":
171-
nnhood_label, _ = _mode(nnhood_label, axis=1)
171+
nnhood_label, _ = mode(nnhood_label, axis=1, keepdims=False)
172172
nnhood_bool = np.ravel(nnhood_label) == y_class
173173
elif self.kind_sel == "all":
174174
nnhood_label = nnhood_label == target_class

imblearn/utils/_sklearn_compat.py

Lines changed: 12 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -240,12 +240,23 @@ def _raise_for_params(params, owner, method):
240240
f" details. Extra parameters passed are: {set(params)}"
241241
)
242242

243+
def _is_pandas_df(X):
244+
"""Return True if the X is a pandas dataframe."""
245+
try:
246+
pd = sys.modules["pandas"]
247+
except KeyError:
248+
return False
249+
return isinstance(X, pd.DataFrame)
250+
243251
else:
244252
from sklearn.utils._metadata_requests import (
245253
_raise_for_params, # noqa: F401
246254
process_routing, # noqa: F401
247255
)
248-
from sklearn.utils.validation import _is_fitted # noqa: F401
256+
from sklearn.utils.validation import (
257+
_is_fitted, # noqa: F401
258+
_is_pandas_df, # noqa: F401
259+
)
249260

250261

251262
########################################################################################

imblearn/utils/_validation.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -17,8 +17,7 @@
1717
from sklearn.utils.multiclass import type_of_target
1818
from sklearn.utils.validation import _num_samples
1919

20-
from ..utils._sklearn_compat import check_array
21-
from .fixes import _is_pandas_df
20+
from ..utils._sklearn_compat import _is_pandas_df, check_array
2221

2322
SAMPLING_KIND = (
2423
"over-sampling",

imblearn/utils/fixes.py

Lines changed: 0 additions & 138 deletions
This file was deleted.

0 commit comments

Comments
 (0)