Skip to content

Commit 1130324

Browse files
glemaitre and chkoar authored
FIX remove smoothed_bootstrap and use only shrinkage param (#794)
Co-authored-by: Christos Aridas <[email protected]>
1 parent 3444430 commit 1130324

File tree

6 files changed

+83
-79
lines changed

6 files changed

+83
-79
lines changed

doc/over_sampling.rst

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,12 @@ It would also work with pandas dataframe::
8080
>>> df_resampled, y_resampled = ros.fit_resample(df_adult, y_adult)
8181
>>> df_resampled.head() # doctest: +SKIP
8282

83-
If repeating samples is an issue, the parameter `smoothed_bootstrap` can be
84-
turned to `True` to create a smoothed bootstrap. However, the original data
85-
needs to be numerical. The `shrinkage` parameter controls the dispersion of the
86-
new generated samples. We show an example illustrate that the new samples are
87-
not overlapping anymore once using a smoothed bootstrap. This ways of
88-
generating smoothed bootstrap is also known a Random Over-Sampler Examples
83+
If repeating samples is an issue, the parameter `shrinkage` allows creating a
84+
smoothed bootstrap. However, the original data needs to be numerical. The
85+
`shrinkage` parameter controls the dispersion of the new generated samples. We
86+
show an example illustrating that the new samples are not overlapping anymore
87+
once using a smoothed bootstrap. This way of generating a smoothed bootstrap is
88+
also known as Random Over-Sampling Examples
8989
(ROSE) :cite:`torelli2014rose`.
9090

9191
.. image:: ./auto_examples/over-sampling/images/sphx_glr_plot_comparison_over_sampling_003.png

doc/whats_new/v0.7.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,8 @@ Enhancements
7474

7575
- Added an option to generate smoothed bootstrap in
7676
:class:`imblearn.over_sampling.RandomOverSampler`. It is controlled by the
77-
parameters `smoothed_bootstrap` and `shrinkage`. This method is also known as
78-
Random Over-Sampling Examples (ROSE).
77+
parameter `shrinkage`. This method is also known as Random Over-Sampling
78+
Examples (ROSE).
7979
:pr:`754` by :user:`Andrea Lorenzon <andrealorenzon>` and
8080
:user:`Guillaume Lemaitre <glemaitre>`.
8181

examples/over-sampling/plot_comparison_over_sampling.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,15 +144,15 @@ def plot_decision_function(X, y, clf, ax):
144144

145145
###############################################################################
146146
# By default, random over-sampling generates a bootstrap. The parameter
147-
# `smoothed_bootstrap` allows adding a small perturbation to the generated data
147+
# `shrinkage` allows adding a small perturbation to the generated data
148148
# to generate a smoothed bootstrap instead. The plot below shows the difference
149149
# between the two data generation strategies.
150150

151151
fig, axs = plt.subplots(1, 2, figsize=(15, 7))
152152
sampler = RandomOverSampler(random_state=0)
153153
plot_resampling(X, y, sampler, ax=axs[0])
154154
axs[0].set_title("RandomOverSampler with normal bootstrap")
155-
sampler = RandomOverSampler(smoothed_bootstrap=True, shrinkage=0.2, random_state=0)
155+
sampler = RandomOverSampler(shrinkage=0.2, random_state=0)
156156
plot_resampling(X, y, sampler, ax=axs[1])
157157
axs[1].set_title("RandomOverSampler with smoothed bootstrap")
158158
fig.tight_layout()

examples/over-sampling/plot_shrinkage_effect.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,9 @@
6161
# from the majority class. Indeed, it is due to the fact that these samples
6262
# of the minority class are repeated during the bootstrap generation.
6363
#
64-
# We can set `smoothed_bootstrap=True` to add a small perturbation to the
64+
# We can set `shrinkage` to a floating value to add a small perturbation to the
6565
# samples created and therefore create a smoothed bootstrap.
66-
sampler = RandomOverSampler(smoothed_bootstrap=True, random_state=0)
66+
sampler = RandomOverSampler(shrinkage=1, random_state=0)
6767
X_res, y_res = sampler.fit_resample(X, y)
6868
Counter(y_res)
6969

@@ -81,7 +81,7 @@
8181
#
8282
# The parameter `shrinkage` allows to add more or less perturbation. Let's
8383
# add more perturbation when generating the smoothed bootstrap.
84-
sampler = RandomOverSampler(smoothed_bootstrap=True, shrinkage=3, random_state=0)
84+
sampler = RandomOverSampler(shrinkage=3, random_state=0)
8585
X_res, y_res = sampler.fit_resample(X, y)
8686
Counter(y_res)
8787

@@ -96,7 +96,7 @@
9696
# %%
9797
# Increasing the value of `shrinkage` will disperse the new samples. Forcing
9898
# the shrinkage to 0 will be equivalent to generating a normal bootstrap.
99-
sampler = RandomOverSampler(smoothed_bootstrap=True, shrinkage=0, random_state=0)
99+
sampler = RandomOverSampler(shrinkage=0, random_state=0)
100100
X_res, y_res = sampler.fit_resample(X, y)
101101
Counter(y_res)
102102

imblearn/over_sampling/_random_over_sampler.py

Lines changed: 45 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# Christos Aridas
55
# License: MIT
66

7+
from collections.abc import Mapping
78
from numbers import Real
89

910
import numpy as np
@@ -37,20 +38,20 @@ class RandomOverSampler(BaseOverSampler):
3738
3839
{random_state}
3940
40-
smoothed_bootstrap : bool, default=False
41-
Whether or not to generate smoothed bootstrap samples. When this option
42-
is triggered, be aware that the data to be resampled needs to be
43-
numerical data since a Gaussian perturbation will be generated and
44-
added to the bootstrap.
41+
shrinkage : float or dict, default=None
42+
Parameter controlling the shrinkage applied to the covariance matrix
43+
when a smoothed bootstrap is generated. The options are:
4544
46-
.. versionadded:: 0.7
45+
- if `None`, a normal bootstrap will be generated without perturbation.
46+
It is equivalent to `shrinkage=0` as well;
47+
- if a `float` is given, the shrinkage factor will be used for all
48+
classes to generate the smoothed bootstrap;
49+
- if a `dict` is given, the shrinkage factor will be specific for each
50+
class. The key corresponds to the targeted class and the value is
51+
the shrinkage factor.
4752
48-
shrinkage : float or dict, default=1.0
49-
Factor to shrink the covariance matrix used to generate the
50-
smoothed bootstrap. A factor could be shared by all classes by
51-
providing a floating number or different for each class over-sampled
52-
by providing a dictionary where the key are the class targeted and the
53-
value is the shrinkage factor.
53+
The value of the shrinkage parameter needs to be higher than or equal
54+
to 0.
5455
5556
.. versionadded:: 0.7
5657
@@ -63,7 +64,7 @@ class RandomOverSampler(BaseOverSampler):
6364
6465
shrinkage_ : dict or None
6566
The per-class shrinkage factor used to generate the smoothed bootstrap
66-
sample. `None` when `smoothed_bootstrap=False`.
67+
sample. When `shrinkage=None` a normal bootstrap will be generated.
6768
6869
.. versionadded:: 0.7
6970
@@ -125,12 +126,10 @@ def __init__(
125126
*,
126127
sampling_strategy="auto",
127128
random_state=None,
128-
smoothed_bootstrap=False,
129-
shrinkage=1.0,
129+
shrinkage=None,
130130
):
131131
super().__init__(sampling_strategy=sampling_strategy)
132132
self.random_state = random_state
133-
self.smoothed_bootstrap = smoothed_bootstrap
134133
self.shrinkage = shrinkage
135134

136135
def _check_X_y(self, X, y):
@@ -148,34 +147,47 @@ def _check_X_y(self, X, y):
148147
def _fit_resample(self, X, y):
149148
random_state = check_random_state(self.random_state)
150149

151-
if self.smoothed_bootstrap:
152-
if isinstance(self.shrinkage, Real):
153-
self.shrinkage_ = {
154-
klass: self.shrinkage for klass in self.sampling_strategy_
155-
}
156-
else:
157-
missing_shrinkage_keys = (
158-
self.sampling_strategy_.keys() - self.shrinkage.keys()
150+
if isinstance(self.shrinkage, Real):
151+
self.shrinkage_ = {
152+
klass: self.shrinkage for klass in self.sampling_strategy_
153+
}
154+
elif self.shrinkage is None or isinstance(self.shrinkage, Mapping):
155+
self.shrinkage_ = self.shrinkage
156+
else:
157+
raise ValueError(
158+
f"`shrinkage` should either be a positive floating number or "
159+
f"a dictionary mapping a class to a positive floating number. "
160+
f"Got {repr(self.shrinkage)} instead."
161+
)
162+
163+
if self.shrinkage_ is not None:
164+
missing_shrinkage_keys = (
165+
self.sampling_strategy_.keys() - self.shrinkage_.keys()
166+
)
167+
if missing_shrinkage_keys:
168+
raise ValueError(
169+
f"`shrinkage` should contain a shrinkage factor for "
170+
f"each class that will be resampled. The missing "
171+
f"classes are: {repr(missing_shrinkage_keys)}"
159172
)
160-
if missing_shrinkage_keys:
173+
174+
for klass, shrink_factor in self.shrinkage_.items():
175+
if shrink_factor < 0:
161176
raise ValueError(
162-
f"`shrinkage` should contain a shrinkage factor for "
163-
f"each class that will be resampled. The missing "
164-
f"classes are: {repr(missing_shrinkage_keys)}"
177+
f"The shrinkage factor needs to be >= 0. "
178+
f"Got {shrink_factor} for class {klass}."
165179
)
166-
self.shrinkage_ = self.shrinkage
180+
167181
# smoothed bootstrap imposes to make numerical operation; we need
168182
# to be sure to have only numerical data in X
169183
try:
170184
X = check_array(X, accept_sparse=["csr", "csc"], dtype="numeric")
171185
except ValueError as exc:
172186
raise ValueError(
173-
"When smoothed_bootstrap=True, X needs to contain only "
187+
"When shrinkage is not None, X needs to contain only "
174188
"numerical data to later generate a smoothed bootstrap "
175189
"sample."
176190
) from exc
177-
else:
178-
self.shrinkage_ = None
179191

180192
X_resampled = [X.copy()]
181193
y_resampled = [y.copy()]
@@ -189,7 +201,7 @@ def _fit_resample(self, X, y):
189201
replace=True,
190202
)
191203
sample_indices = np.append(sample_indices, bootstrap_indices)
192-
if self.smoothed_bootstrap:
204+
if self.shrinkage_ is not None:
193205
# generate a smoothed bootstrap with a perturbation
194206
n_samples, n_features = X.shape
195207
smoothing_constant = (4 / ((n_features + 2) * n_samples)) ** (

imblearn/over_sampling/tests/test_random_over_sampler.py

Lines changed: 24 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,7 @@ def test_ros_init():
4343
assert ros.random_state == RND_SEED
4444

4545

46-
@pytest.mark.parametrize(
47-
"params",
48-
[{"smoothed_bootstrap": False}, {"smoothed_bootstrap": True, "shrinkage": 0}]
49-
)
46+
@pytest.mark.parametrize("params", [{"shrinkage": None}, {"shrinkage": 0}])
5047
@pytest.mark.parametrize("X_type", ["array", "dataframe"])
5148
def test_ros_fit_resample(X_type, data, params):
5249
X, Y = data
@@ -80,16 +77,13 @@ def test_ros_fit_resample(X_type, data, params):
8077
assert_allclose(X_resampled, X_gt)
8178
assert_array_equal(y_resampled, y_gt)
8279

83-
if not params["smoothed_bootstrap"]:
80+
if params["shrinkage"] is None:
8481
assert ros.shrinkage_ is None
8582
else:
8683
assert ros.shrinkage_ == {0: 0}
8784

8885

89-
@pytest.mark.parametrize(
90-
"params",
91-
[{"smoothed_bootstrap": False}, {"smoothed_bootstrap": True, "shrinkage": 0}]
92-
)
86+
@pytest.mark.parametrize("params", [{"shrinkage": None}, {"shrinkage": 0}])
9387
def test_ros_fit_resample_half(data, params):
9488
X, Y = data
9589
sampling_strategy = {0: 3, 1: 7}
@@ -115,16 +109,13 @@ def test_ros_fit_resample_half(data, params):
115109
assert_allclose(X_resampled, X_gt)
116110
assert_array_equal(y_resampled, y_gt)
117111

118-
if not params["smoothed_bootstrap"]:
112+
if params["shrinkage"] is None:
119113
assert ros.shrinkage_ is None
120114
else:
121115
assert ros.shrinkage_ == {0: 0, 1: 0}
122116

123117

124-
@pytest.mark.parametrize(
125-
"params",
126-
[{"smoothed_bootstrap": False}, {"smoothed_bootstrap": True, "shrinkage": 0}]
127-
)
118+
@pytest.mark.parametrize("params", [{"shrinkage": None}, {"shrinkage": 0}])
128119
def test_multiclass_fit_resample(data, params):
129120
# check the random over-sampling with a multiclass problem
130121
X, Y = data
@@ -138,7 +129,7 @@ def test_multiclass_fit_resample(data, params):
138129
assert count_y_res[1] == 5
139130
assert count_y_res[2] == 5
140131

141-
if not params["smoothed_bootstrap"]:
132+
if params["shrinkage"] is None:
142133
assert ros.shrinkage_ is None
143134
else:
144135
assert ros.shrinkage_ == {0: 0, 2: 0}
@@ -188,11 +179,8 @@ def test_random_over_sampling_heterogeneous_data_smoothed_bootstrap():
188179
[["xxx", 1, 1.0], ["yyy", 2, 2.0], ["zzz", 3, 3.0]], dtype=object
189180
)
190181
y = np.array([0, 0, 1])
191-
ros = RandomOverSampler(
192-
smoothed_bootstrap=True,
193-
random_state=RND_SEED,
194-
)
195-
err_msg = "When smoothed_bootstrap=True, X needs to contain only numerical"
182+
ros = RandomOverSampler(shrinkage=1, random_state=RND_SEED)
183+
err_msg = "When shrinkage is not None, X needs to contain only numerical"
196184
with pytest.raises(ValueError, match=err_msg):
197185
ros.fit_resample(X_hetero, y)
198186

@@ -201,7 +189,7 @@ def test_random_over_sampling_heterogeneous_data_smoothed_bootstrap():
201189
def test_random_over_sampler_smoothed_bootstrap(X_type, data):
202190
# check that smoothed bootstrap is working for numerical array
203191
X, y = data
204-
sampler = RandomOverSampler(smoothed_bootstrap=True, shrinkage=1)
192+
sampler = RandomOverSampler(shrinkage=1)
205193
X = _convert_container(X, X_type)
206194
X_res, y_res = sampler.fit_resample(X, y)
207195

@@ -217,10 +205,8 @@ def test_random_over_sampler_equivalence_shrinkage(data):
217205
# bootstrap
218206
X, y = data
219207

220-
ros_not_shrink = RandomOverSampler(
221-
smoothed_bootstrap=True, shrinkage=0, random_state=0
222-
)
223-
ros_hard_bootstrap = RandomOverSampler(smoothed_bootstrap=False, random_state=0)
208+
ros_not_shrink = RandomOverSampler(shrinkage=0, random_state=0)
209+
ros_hard_bootstrap = RandomOverSampler(shrinkage=None, random_state=0)
224210

225211
X_res_not_shrink, y_res_not_shrink = ros_not_shrink.fit_resample(X, y)
226212
X_res, y_res = ros_hard_bootstrap.fit_resample(X, y)
@@ -240,7 +226,7 @@ def test_random_over_sampler_shrinkage_behaviour(data):
240226
# should also be larger.
241227
X, y = data
242228

243-
ros = RandomOverSampler(smoothed_bootstrap=True, shrinkage=1, random_state=0)
229+
ros = RandomOverSampler(shrinkage=1, random_state=0)
244230
X_res_shink_1, y_res_shrink_1 = ros.fit_resample(X, y)
245231

246232
ros.set_params(shrinkage=5)
@@ -252,12 +238,18 @@ def test_random_over_sampler_shrinkage_behaviour(data):
252238
assert disperstion_shrink_1 < disperstion_shrink_5
253239

254240

255-
def test_random_over_sampler_shrinkage_error(data):
256-
# check that we raise proper error when shrinkage do not contain the
257-
# necessary information
241+
@pytest.mark.parametrize(
242+
"shrinkage, err_msg",
243+
[
244+
({}, "`shrinkage` should contain a shrinkage factor for each class"),
245+
(-1, "The shrinkage factor needs to be >= 0"),
246+
({0: -1}, "The shrinkage factor needs to be >= 0"),
247+
([1, ], "`shrinkage` should either be a positive floating number or")
248+
]
249+
)
250+
def test_random_over_sampler_shrinkage_error(data, shrinkage, err_msg):
251+
# check the validation of the shrinkage parameter
258252
X, y = data
259-
shrinkage = {}
260-
ros = RandomOverSampler(smoothed_bootstrap=True, shrinkage=shrinkage)
261-
err_msg = "`shrinkage` should contain a shrinkage factor for each class"
253+
ros = RandomOverSampler(shrinkage=shrinkage)
262254
with pytest.raises(ValueError, match=err_msg):
263255
ros.fit_resample(X, y)

0 commit comments

Comments
 (0)