Skip to content

Commit a9ee326

Browse files
committed
iter
1 parent 7f58be7 commit a9ee326

File tree

2 files changed

+76
-0
lines changed

2 files changed

+76
-0
lines changed

imblearn/misc.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
"""Miscellaneous samplers objects."""
2+
3+
# Authors: Guillaume Lemaitre <[email protected]>
4+
# License: MIT
5+
6+
import logging
7+
8+
from sklearn.utils import check_X_y
9+
from sklearn.utils.validation import check_is_fitted
10+
11+
from .base import SamplerMixin
12+
from .utils import check_target_type, hash_X_y
13+
14+
15+
def _identity(X, y):
16+
return X, y
17+
18+
19+
class FunctionSampler(SamplerMixin):
20+
"""Construct a sampler from calling an arbitrary callable.
21+
22+
Read more in the :ref:`User Guide <function_sampler>`.
23+
24+
Parameters
25+
----------
26+
func : callable or None,
27+
The callable to use for the transformation. This will be passed the
28+
same arguments as transform, with args and kwargs forwarded. If func is
29+
None, then func will be the identity function.
30+
31+
"""
32+
33+
def __init__(self, func=None, accept_sparse=True, kw_args=None,
34+
random_state=None):
35+
self.func = func
36+
self.accept_sparse = accept_sparse
37+
self.kw_args = kw_args
38+
self.random_state = random_state
39+
self.logger = logging.getLogger(__name__)
40+
41+
def _check_X_y(self, X, y):
42+
if self.accept_sparse:
43+
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'])
44+
else:
45+
X, y = check_X_y(X, y, accept_sparse=False)
46+
y = check_target_type(y)
47+
48+
return X, y
49+
50+
def fit(self, X, y):
51+
print(self.accept_sparse)
52+
X, y = self._check_X_y(X, y)
53+
self.X_hash_, self.y_hash_ = hash_X_y(X, y)
54+
# when using a sampler, ratio_ is supposed to exist after fit
55+
self.ratio_ = 'is_fitted'
56+
57+
return self
58+
59+
def _sample(self, X, y, func, kw_args):
60+
X, y = self._check_X_y(X, y)
61+
check_is_fitted(self, 'ratio_')
62+
X_hash, y_hash = hash_X_y(X, y)
63+
if self.X_hash_ != X_hash or self.y_hash_ != y_hash:
64+
raise RuntimeError("X and y need to be same array earlier fitted.")
65+
66+
if func is None:
67+
func = _identity
68+
69+
return func(X, y, **(kw_args if kw_args else {}))
70+
71+
def sample(self, X, y):
72+
return self._sample(X, y, func=self.func, kw_args=self.kw_args)

imblearn/tests/test_misc.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
"""Test for miscellaneous samplers objects."""
2+
3+
# Authors: Guillaume Lemaitre <[email protected]>
4+
# License: MIT

0 commit comments

Comments
 (0)