Skip to content

Commit 9f3872d

Browse files
authored
FIX make sure that FunctionSampler will bypass validation in fit (#790)
1 parent 0b48def commit 9f3872d

File tree

3 files changed

+51
-0
lines changed

3 files changed

+51
-0
lines changed

doc/whats_new/v0.7.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ Bug fixes
5454
the targeted class.
5555
:pr:`769` by :user:`Guillaume Lemaitre <glemaitre>`.
5656

57+
- Fix a bug in :class:`imblearn.FunctionSampler` where validation was performed
58+
even with `validate=False` when calling `fit`.
59+
:pr:`790` by :user:`Guillaume Lemaitre <glemaitre>`.
60+
5761
Enhancements
5862
............
5963

imblearn/base.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,38 @@ def __init__(self, *, func=None, accept_sparse=True, kw_args=None,
220220
self.kw_args = kw_args
221221
self.validate = validate
222222

223+
def fit(self, X, y):
    """Record the sampling strategy, validating inputs only on request.

    This overrides the inherited ``fit`` so that input validation is
    skipped when ``validate`` is false. Prefer ``fit_resample`` in all
    cases.

    Parameters
    ----------
    X : {array-like, dataframe, sparse matrix} of shape \
            (n_samples, n_features)
        Data array.

    y : array-like of shape (n_samples,)
        Target array.

    Returns
    -------
    self : object
        Return the instance itself.
    """
    if self.validate:
        # Validation is opt-in: the targets and the (X, y) pair are only
        # checked when the user asked for it.
        check_classification_targets(y)
        # Only the validated targets are needed below; the validated X and
        # the third returned value are not used by this method.
        _, y, _ = self._check_X_y(X, y, accept_sparse=self.accept_sparse)

    # Always computed, with the validated ``y`` if validation ran and the
    # raw ``y`` otherwise.
    strategy = check_sampling_strategy(
        self.sampling_strategy, y, self._sampling_type
    )
    self.sampling_strategy_ = strategy
    return self
254+
223255
def fit_resample(self, X, y):
224256
"""Resample the dataset.
225257

imblearn/tests/test_base.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,18 @@ def dummy_sampler(X, y):
9494
y_pred = pipeline.fit(X, y).predict(X)
9595

9696
assert type_of_target(y_pred) == 'continuous'
97+
98+
99+
def test_function_resampler_fit():
    # Non-regression test for:
    # https://github.com/scikit-learn-contrib/imbalanced-learn/issues/782
    # With `validate=False`, validation must be bypassed when calling `fit`,
    # so the NaN and inf entries below must not make either call raise.
    def keep_first_sample(X, y):
        return X[:1], y[:1]

    y = np.array([0, 1, 1])
    X = np.array([[1, np.nan], [2, 3], [np.inf, 4]])

    sampler = FunctionSampler(func=keep_first_sample, validate=False)
    sampler.fit(X, y)
    sampler.fit_resample(X, y)

0 commit comments

Comments
 (0)