
Commit 95e21e1

API deprecate estimator_ in favor of estimators_ in CNN and OSS (#1011)
1 parent d8cf8d6 commit 95e21e1

File tree: 5 files changed (+137 / −17 lines)

  doc/whats_new/v0.12.rst
  imblearn/under_sampling/_prototype_selection/_condensed_nearest_neighbour.py
  imblearn/under_sampling/_prototype_selection/_one_sided_selection.py
  imblearn/under_sampling/_prototype_selection/tests/test_condensed_nearest_neighbour.py
  imblearn/under_sampling/_prototype_selection/tests/test_one_sided_selection.py

doc/whats_new/v0.12.rst

Lines changed: 9 additions & 1 deletion
@@ -3,6 +3,14 @@
 Version 0.12.0 (Under development)
 ==================================
 
-
 Changelog
 ---------
+
+Deprecations
+............
+
+- Deprecate the fitted attribute `estimator_` in favor of `estimators_` for the classes
+  :class:`~imblearn.under_sampling.CondensedNearestNeighbour` and
+  :class:`~imblearn.under_sampling.OneSidedSelection`. `estimator_` will be removed
+  in 0.14.
+  :pr:`xxx` by :user:`Guillaume Lemaitre <glemaitre>`.
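
For downstream users the migration is a one-line change: the single fitted `estimator_` becomes the list `estimators_`, holding one k-NN per (minority class, other class) pair. The snippet below is an illustrative sketch of that change from user code, not part of the commit; the toy dataset and variable names are assumptions.

    # Hypothetical migration sketch for imbalanced-learn >= 0.12
    from sklearn.datasets import make_classification
    from imblearn.under_sampling import CondensedNearestNeighbour

    X, y = make_classification(
        n_classes=3, weights=[0.2, 0.3, 0.5], n_informative=3, random_state=0
    )
    cnn = CondensedNearestNeighbour(random_state=0)
    X_res, y_res = cnn.fit_resample(X, y)

    # Before 0.12: a single (last fitted) k-NN was exposed.
    # knn = cnn.estimator_        # still works, but now raises a FutureWarning
    # From 0.12 on: one fitted k-NN per (minority, other class) pair.
    for knn in cnn.estimators_:
        print(knn.classes_)       # e.g. [0 1], then [0 2]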

imblearn/under_sampling/_prototype_selection/_condensed_nearest_neighbour.py

Lines changed: 34 additions & 10 deletions
@@ -6,6 +6,7 @@
 # License: MIT
 
 import numbers
+import warnings
 from collections import Counter
 
 import numpy as np
@@ -59,6 +60,16 @@ class CondensedNearestNeighbour(BaseCleaningSampler):
     estimator_ : estimator object
         The validated K-nearest neighbor estimator created from `n_neighbors` parameter.
 
+        .. deprecated:: 0.12
+           `estimator_` is deprecated in 0.12 and will be removed in 0.14. Use
+           `estimators_` instead; it contains the list of all K-nearest
+           neighbors estimators fitted for each pair of classes.
+
+    estimators_ : list of estimator objects of shape (n_resampled_classes - 1,)
+        Contains the K-nearest neighbors estimator fitted for each pair of classes.
+
+        .. versionadded:: 0.12
+
     sample_indices_ : ndarray of shape (n_new_samples,)
         Indices of the samples selected.
 
@@ -87,8 +98,8 @@ class CondensedNearestNeighbour(BaseCleaningSampler):
     -----
     The method is based on [1]_.
 
-    Supports multi-class resampling. A one-vs.-rest scheme is used when
-    sampling a class as proposed in [1]_.
+    Supports multi-class resampling: a one (minority) vs. each other class
+    strategy is applied.
 
     References
     ----------
@@ -142,22 +153,25 @@ def __init__(
     def _validate_estimator(self):
         """Private function to create the NN estimator"""
         if self.n_neighbors is None:
-            self.estimator_ = KNeighborsClassifier(n_neighbors=1, n_jobs=self.n_jobs)
+            estimator = KNeighborsClassifier(n_neighbors=1, n_jobs=self.n_jobs)
         elif isinstance(self.n_neighbors, numbers.Integral):
-            self.estimator_ = KNeighborsClassifier(
+            estimator = KNeighborsClassifier(
                 n_neighbors=self.n_neighbors, n_jobs=self.n_jobs
             )
         elif isinstance(self.n_neighbors, KNeighborsClassifier):
-            self.estimator_ = clone(self.n_neighbors)
+            estimator = clone(self.n_neighbors)
+
+        return estimator
 
     def _fit_resample(self, X, y):
-        self._validate_estimator()
+        estimator = self._validate_estimator()
 
         random_state = check_random_state(self.random_state)
         target_stats = Counter(y)
         class_minority = min(target_stats, key=target_stats.get)
         idx_under = np.empty((0,), dtype=int)
 
+        self.estimators_ = []
         for target_class in np.unique(y):
             if target_class in self.sampling_strategy_.keys():
                 # Randomly get one sample from the majority class
@@ -184,7 +198,7 @@ def _fit_resample(self, X, y):
                 S_y = _safe_indexing(y, S_indices)
 
                 # fit knn on C
-                self.estimator_.fit(C_x, C_y)
+                self.estimators_.append(clone(estimator).fit(C_x, C_y))
 
                 good_classif_label = idx_maj_sample.copy()
                 # Check each sample in S if we keep it or drop it
@@ -196,7 +210,7 @@ def _fit_resample(self, X, y):
                     # Classify on S
                     if not issparse(x_sam):
                         x_sam = x_sam.reshape(1, -1)
-                    pred_y = self.estimator_.predict(x_sam)
+                    pred_y = self.estimators_[-1].predict(x_sam)
 
                     # If the prediction do not agree with the true label
                     # append it in C_x
@@ -210,12 +224,12 @@ def _fit_resample(self, X, y):
                         C_y = _safe_indexing(y, C_indices)
 
                         # fit a knn on C
-                        self.estimator_.fit(C_x, C_y)
+                        self.estimators_[-1].fit(C_x, C_y)
 
                         # This experimental to speed up the search
                         # Classify all the element in S and avoid to test the
                         # well classified elements
-                        pred_S_y = self.estimator_.predict(S_x)
+                        pred_S_y = self.estimators_[-1].predict(S_x)
                         good_classif_label = np.unique(
                             np.append(idx_maj_sample, np.flatnonzero(pred_S_y == S_y))
                         )
@@ -230,5 +244,15 @@ def _fit_resample(self, X, y):
 
         return _safe_indexing(X, idx_under), _safe_indexing(y, idx_under)
 
+    @property
+    def estimator_(self):
+        """Last fitted k-NN estimator."""
+        warnings.warn(
+            "`estimator_` attribute has been deprecated in 0.12 and will be "
+            "removed in 0.14. Use `estimators_` instead.",
+            FutureWarning,
+        )
+        return self.estimators_[-1]
+
     def _more_tags(self):
         return {"sample_indices": True}
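
For backward compatibility, the new `estimator_` property returns the last element of `estimators_` and emits a `FutureWarning`. A minimal sketch of what calling code can expect during the transition period; the dataset and names below are illustrative assumptions, not part of the commit.

    import warnings

    from sklearn.datasets import make_classification
    from imblearn.under_sampling import CondensedNearestNeighbour

    X, y = make_classification(
        n_classes=3, weights=[0.2, 0.3, 0.5], n_informative=3, random_state=0
    )
    cnn = CondensedNearestNeighbour(random_state=0)
    cnn.fit_resample(X, y)

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        last_knn = cnn.estimator_  # deprecated accessor
    assert any(issubclass(w.category, FutureWarning) for w in caught)
    # The property simply aliases the estimator fitted for the last pair of classes.
    assert last_knn is cnn.estimators_[-1]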

imblearn/under_sampling/_prototype_selection/_one_sided_selection.py

Lines changed: 30 additions & 6 deletions
@@ -5,6 +5,7 @@
 # License: MIT
 
 import numbers
+import warnings
 from collections import Counter
 
 import numpy as np
@@ -58,6 +59,16 @@ class OneSidedSelection(BaseCleaningSampler):
     estimator_ : estimator object
         Validated K-nearest neighbors estimator created from parameter `n_neighbors`.
 
+        .. deprecated:: 0.12
+           `estimator_` is deprecated in 0.12 and will be removed in 0.14. Use
+           `estimators_` instead; it contains the list of all K-nearest
+           neighbors estimators fitted for each pair of classes.
+
+    estimators_ : list of estimator objects of shape (n_resampled_classes - 1,)
+        Contains the K-nearest neighbors estimator fitted for each pair of classes.
+
+        .. versionadded:: 0.12
+
     sample_indices_ : ndarray of shape (n_new_samples,)
         Indices of the samples selected.
 
@@ -138,23 +149,26 @@ def __init__(
     def _validate_estimator(self):
         """Private function to create the NN estimator"""
         if self.n_neighbors is None:
-            self.estimator_ = KNeighborsClassifier(n_neighbors=1, n_jobs=self.n_jobs)
+            estimator = KNeighborsClassifier(n_neighbors=1, n_jobs=self.n_jobs)
         elif isinstance(self.n_neighbors, int):
-            self.estimator_ = KNeighborsClassifier(
+            estimator = KNeighborsClassifier(
                 n_neighbors=self.n_neighbors, n_jobs=self.n_jobs
            )
         elif isinstance(self.n_neighbors, KNeighborsClassifier):
-            self.estimator_ = clone(self.n_neighbors)
+            estimator = clone(self.n_neighbors)
+
+        return estimator
 
     def _fit_resample(self, X, y):
-        self._validate_estimator()
+        estimator = self._validate_estimator()
 
         random_state = check_random_state(self.random_state)
         target_stats = Counter(y)
         class_minority = min(target_stats, key=target_stats.get)
 
         idx_under = np.empty((0,), dtype=int)
 
+        self.estimators_ = []
         for target_class in np.unique(y):
             if target_class in self.sampling_strategy_.keys():
                 # select a sample from the current class
@@ -177,8 +191,8 @@ def _fit_resample(self, X, y):
                 idx_maj_extracted = np.delete(idx_maj, sel_idx_maj, axis=0)
                 S_x = _safe_indexing(X, idx_maj_extracted)
                 S_y = _safe_indexing(y, idx_maj_extracted)
-                self.estimator_.fit(C_x, C_y)
-                pred_S_y = self.estimator_.predict(S_x)
+                self.estimators_.append(clone(estimator).fit(C_x, C_y))
+                pred_S_y = self.estimators_[-1].predict(S_x)
 
                 S_misclassified_indices = np.flatnonzero(pred_S_y != S_y)
                 idx_tmp = idx_maj_extracted[S_misclassified_indices]
@@ -199,5 +213,15 @@ def _fit_resample(self, X, y):
 
         return X_cleaned, y_cleaned
 
+    @property
+    def estimator_(self):
+        """Last fitted k-NN estimator."""
+        warnings.warn(
+            "`estimator_` attribute has been deprecated in 0.12 and will be "
+            "removed in 0.14. Use `estimators_` instead.",
+            FutureWarning,
+        )
+        return self.estimators_[-1]
+
     def _more_tags(self):
         return {"sample_indices": True}
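
`OneSidedSelection` mirrors the same pattern, so legacy code that still reads `estimator_` keeps working until 0.14. Until it is migrated to `estimators_`, the warning can be silenced locally with the standard `warnings` machinery; the following is a sketch under the same illustrative assumptions as above, not part of the commit.

    import warnings

    from sklearn.datasets import make_classification
    from imblearn.under_sampling import OneSidedSelection

    X, y = make_classification(
        n_classes=3, weights=[0.2, 0.3, 0.5], n_informative=3, random_state=0
    )
    oss = OneSidedSelection(random_state=0)
    oss.fit_resample(X, y)

    # Preferred: read the new attribute directly.
    knn_per_pair = oss.estimators_

    # Temporary escape hatch for code that has not migrated yet.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", FutureWarning)
        legacy_knn = oss.estimator_  # no FutureWarning escapes this block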

imblearn/under_sampling/_prototype_selection/tests/test_condensed_nearest_neighbour.py

Lines changed: 32 additions & 0 deletions
@@ -5,6 +5,7 @@
 
 import numpy as np
 import pytest
+from sklearn.datasets import make_classification
 from sklearn.neighbors import KNeighborsClassifier
 from sklearn.utils._testing import assert_array_equal
 

@@ -95,3 +96,34 @@ def test_cnn_fit_resample_with_object(n_neighbors):
     X_resampled, y_resampled = cnn.fit_resample(X, Y)
     assert_array_equal(X_resampled, X_gt)
     assert_array_equal(y_resampled, y_gt)
+
+
+def test_condensed_nearest_neighbour_multiclass():
+    """Check the validity of the fitted attribute `estimators_`."""
+    X, y = make_classification(
+        n_samples=1_000,
+        n_classes=4,
+        weights=[0.1, 0.2, 0.2, 0.5],
+        n_clusters_per_class=1,
+        random_state=0,
+    )
+    cnn = CondensedNearestNeighbour(random_state=RND_SEED)
+    cnn.fit_resample(X, y)
+
+    assert len(cnn.estimators_) == len(cnn.sampling_strategy_)
+    other_classes = []
+    for est in cnn.estimators_:
+        assert est.classes_[0] == 0  # minority class
+        assert est.classes_[1] in {1, 2, 3}  # other classes
+        other_classes.append(est.classes_[1])
+    assert len(set(other_classes)) == len(other_classes)
+
+
+# TODO: remove in 0.14
+def test_condensed_nearest_neighbors_deprecation():
+    """Check that we raise a FutureWarning when accessing the attribute `estimator_`."""
+    cnn = CondensedNearestNeighbour(random_state=RND_SEED)
+    cnn.fit_resample(X, Y)
+    warn_msg = "`estimator_` attribute has been deprecated"
+    with pytest.warns(FutureWarning, match=warn_msg):
+        cnn.estimator_

imblearn/under_sampling/_prototype_selection/tests/test_one_sided_selection.py

Lines changed: 32 additions & 0 deletions
@@ -5,6 +5,7 @@
 
 import numpy as np
 import pytest
+from sklearn.datasets import make_classification
 from sklearn.neighbors import KNeighborsClassifier
 from sklearn.utils._testing import assert_array_equal
 

@@ -95,3 +96,34 @@ def test_oss_with_object(n_neighbors):
     X_resampled, y_resampled = oss.fit_resample(X, Y)
     assert_array_equal(X_resampled, X_gt)
     assert_array_equal(y_resampled, y_gt)
+
+
+def test_one_sided_selection_multiclass():
+    """Check the validity of the fitted attribute `estimators_`."""
+    X, y = make_classification(
+        n_samples=1_000,
+        n_classes=4,
+        weights=[0.1, 0.2, 0.2, 0.5],
+        n_clusters_per_class=1,
+        random_state=0,
+    )
+    oss = OneSidedSelection(random_state=RND_SEED)
+    oss.fit_resample(X, y)
+
+    assert len(oss.estimators_) == len(oss.sampling_strategy_)
+    other_classes = []
+    for est in oss.estimators_:
+        assert est.classes_[0] == 0  # minority class
+        assert est.classes_[1] in {1, 2, 3}  # other classes
+        other_classes.append(est.classes_[1])
+    assert len(set(other_classes)) == len(other_classes)
+
+
+# TODO: remove in 0.14
+def test_one_sided_selection_deprecation():
+    """Check that we raise a FutureWarning when accessing the attribute `estimator_`."""
+    oss = OneSidedSelection(random_state=RND_SEED)
+    oss.fit_resample(X, Y)
+    warn_msg = "`estimator_` attribute has been deprecated"
+    with pytest.warns(FutureWarning, match=warn_msg):
+        oss.estimator_

0 commit comments
