Skip to content

Commit e310fa7

Browse files
committed
Allow kwargs in FunctionSampler
1 parent 4151711 commit e310fa7

File tree

2 files changed

+42
-8
lines changed

2 files changed

+42
-8
lines changed

imblearn/misc.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,10 @@ class FunctionSampler(SamplerMixin):
3030
3131
"""
3232

33-
def __init__(self, func=None, accept_sparse=True, kw_args=None,
34-
random_state=None):
33+
def __init__(self, func=None, accept_sparse=True, kw_args=None):
3534
self.func = func
3635
self.accept_sparse = accept_sparse
3736
self.kw_args = kw_args
38-
self.random_state = random_state
3937
self.logger = logging.getLogger(__name__)
4038

4139
def _check_X_y(self, X, y):
@@ -48,15 +46,14 @@ def _check_X_y(self, X, y):
4846
return X, y
4947

5048
def fit(self, X, y):
51-
print(self.accept_sparse)
5249
X, y = self._check_X_y(X, y)
5350
self.X_hash_, self.y_hash_ = hash_X_y(X, y)
5451
# when using a sampler, ratio_ is supposed to exist after fit
5552
self.ratio_ = 'is_fitted'
5653

5754
return self
5855

59-
def _sample(self, X, y, func, kw_args):
56+
def _sample(self, X, y, func=None, kw_args=None):
6057
X, y = self._check_X_y(X, y)
6158
check_is_fitted(self, 'ratio_')
6259
X_hash, y_hash = hash_X_y(X, y)
@@ -66,7 +63,7 @@ def _sample(self, X, y, func, kw_args):
6663
if func is None:
6764
func = _identity
6865

69-
return func(X, y, **(kw_args if kw_args else {}))
66+
return func(X, y, **(kw_args if self.kw_args else {}))
7067

7168
def sample(self, X, y):
7269
return self._sample(X, y, func=self.func, kw_args=self.kw_args)

imblearn/tests/test_misc.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,45 @@
33
# Authors: Guillaume Lemaitre <[email protected]>
44
# License: MIT
55

6+
from sklearn.datasets import load_iris
7+
from sklearn.utils.testing import assert_array_equal
8+
9+
from imblearn.datasets import make_imbalance
610
from imblearn.misc import FunctionSampler
11+
from imblearn.under_sampling import RandomUnderSampler
12+
13+
iris = load_iris()
14+
X, y = make_imbalance(iris.data, iris.target, ratio={0: 10, 1: 25},
15+
random_state=0)
16+
17+
18+
def test_function_sampler_identity():
19+
sampler = FunctionSampler()
20+
X_res, y_res = sampler.fit_sample(X, y)
21+
assert_array_equal(X_res, X)
22+
assert_array_equal(y_res, y)
23+
24+
25+
def test_function_sampler_func():
26+
27+
def func(X, y):
28+
return X[:10], y[:10]
29+
30+
sampler = FunctionSampler(func=func)
31+
X_res, y_res = sampler.fit_sample(X, y)
32+
assert_array_equal(X_res, X[:10])
33+
assert_array_equal(y_res, y[:10])
34+
35+
36+
def test_function_sampler_func_kwargs():
737

38+
def func(X, y, ratio, random_state):
39+
rus = RandomUnderSampler(ratio=ratio, random_state=random_state)
40+
return rus.fit_sample(X, y)
841

9-
def function_sampler_identity():
10-
sampler = FunctionSampler(1)
42+
sampler = FunctionSampler(func=func, kw_args={'ratio': 'auto',
43+
'random_state': 0})
44+
X_res, y_res = sampler.fit_sample(X, y)
45+
X_res_2, y_res_2 = RandomUnderSampler(random_state=0).fit_sample(X, y)
46+
assert_array_equal(X_res, X_res_2)
47+
assert_array_equal(y_res, y_res_2)

0 commit comments

Comments
 (0)