|
| 1 | +"""Miscellaneous samplers objects.""" |
| 2 | + |
| 3 | +# Authors: Guillaume Lemaitre <[email protected]> |
| 4 | +# License: MIT |
| 5 | + |
| 6 | +import logging |
| 7 | + |
| 8 | +from sklearn.utils import check_X_y |
| 9 | +from sklearn.utils.validation import check_is_fitted |
| 10 | + |
| 11 | +from .base import SamplerMixin |
| 12 | +from .utils import check_target_type, hash_X_y |
| 13 | + |
| 14 | + |
| 15 | +def _identity(X, y): |
| 16 | + return X, y |
| 17 | + |
| 18 | + |
| 19 | +class FunctionSampler(SamplerMixin): |
| 20 | + """Construct a sampler from calling an arbitrary callable. |
| 21 | +
|
| 22 | + Read more in the :ref:`User Guide <function_sampler>`. |
| 23 | +
|
| 24 | + Parameters |
| 25 | + ---------- |
| 26 | + func : callable or None, |
| 27 | + The callable to use for the transformation. This will be passed the |
| 28 | + same arguments as transform, with args and kwargs forwarded. If func is |
| 29 | + None, then func will be the identity function. |
| 30 | +
|
| 31 | + """ |
| 32 | + |
| 33 | + def __init__(self, func=None, accept_sparse=True, kw_args=None, |
| 34 | + random_state=None): |
| 35 | + self.func = func |
| 36 | + self.accept_sparse = accept_sparse |
| 37 | + self.kw_args = kw_args |
| 38 | + self.random_state = random_state |
| 39 | + self.logger = logging.getLogger(__name__) |
| 40 | + |
| 41 | + def _check_X_y(self, X, y): |
| 42 | + if self.accept_sparse: |
| 43 | + X, y = check_X_y(X, y, accept_sparse=['csr', 'csc']) |
| 44 | + else: |
| 45 | + X, y = check_X_y(X, y, accept_sparse=False) |
| 46 | + y = check_target_type(y) |
| 47 | + |
| 48 | + return X, y |
| 49 | + |
| 50 | + def fit(self, X, y): |
| 51 | + print(self.accept_sparse) |
| 52 | + X, y = self._check_X_y(X, y) |
| 53 | + self.X_hash_, self.y_hash_ = hash_X_y(X, y) |
| 54 | + # when using a sampler, ratio_ is supposed to exist after fit |
| 55 | + self.ratio_ = 'is_fitted' |
| 56 | + |
| 57 | + return self |
| 58 | + |
| 59 | + def _sample(self, X, y, func, kw_args): |
| 60 | + X, y = self._check_X_y(X, y) |
| 61 | + check_is_fitted(self, 'ratio_') |
| 62 | + X_hash, y_hash = hash_X_y(X, y) |
| 63 | + if self.X_hash_ != X_hash or self.y_hash_ != y_hash: |
| 64 | + raise RuntimeError("X and y need to be same array earlier fitted.") |
| 65 | + |
| 66 | + if func is None: |
| 67 | + func = _identity |
| 68 | + |
| 69 | + return func(X, y, **(kw_args if kw_args else {})) |
| 70 | + |
| 71 | + def sample(self, X, y): |
| 72 | + return self._sample(X, y, func=self.func, kw_args=self.kw_args) |
0 commit comments