Skip to content

Commit d1929a8

Browse files
committed
FIX forward properly the metadata with the pipeline
1 parent 2d65471 commit d1929a8

File tree

3 files changed

+92
-49
lines changed

3 files changed

+92
-49
lines changed

imblearn/base.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import numpy as np
1010
from sklearn.base import BaseEstimator, OneToOneFeatureMixin
1111
from sklearn.preprocessing import label_binarize
12-
from sklearn.utils._metadata_requests import METHODS
12+
from sklearn.utils._metadata_requests import METHODS, SIMPLE_METHODS
1313
from sklearn.utils.multiclass import check_classification_targets
1414

1515
from .utils import check_sampling_strategy, check_target_type
@@ -21,6 +21,7 @@
2121
if "fit_transform" not in METHODS:
2222
METHODS.append("fit_transform")
2323
METHODS.append("fit_resample")
24+
SIMPLE_METHODS.append("fit_resample")
2425

2526

2627
class SamplerMixin(metaclass=ABCMeta):
@@ -33,7 +34,7 @@ class SamplerMixin(metaclass=ABCMeta):
3334
_estimator_type = "sampler"
3435

3536
@_fit_context(prefer_skip_nested_validation=True)
36-
def fit(self, X, y):
37+
def fit(self, X, y, **params):
3738
"""Check inputs and statistics of the sampler.
3839
3940
You should use ``fit_resample`` in all cases.
@@ -47,6 +48,9 @@ def fit(self, X, y):
4748
y : array-like of shape (n_samples,)
4849
Target array.
4950
51+
**params : dict
52+
Extra parameters to be used by the sampler.
53+
5054
Returns
5155
-------
5256
self : object
@@ -58,7 +62,8 @@ def fit(self, X, y):
5862
)
5963
return self
6064

61-
def fit_resample(self, X, y):
65+
@_fit_context(prefer_skip_nested_validation=True)
66+
def fit_resample(self, X, y, **params):
6267
"""Resample the dataset.
6368
6469
Parameters
@@ -70,6 +75,9 @@ def fit_resample(self, X, y):
7075
y : array-like of shape (n_samples,)
7176
Corresponding label for each sample in X.
7277
78+
**params : dict
79+
Extra parameters to be used by the sampler.
80+
7381
Returns
7482
-------
7583
X_resampled : {array-like, dataframe, sparse matrix} of shape \
@@ -87,7 +95,7 @@ def fit_resample(self, X, y):
8795
self.sampling_strategy, y, self._sampling_type
8896
)
8997

90-
output = self._fit_resample(X, y)
98+
output = self._fit_resample(X, y, **params)
9199

92100
y_ = (
93101
label_binarize(output[1], classes=np.unique(y)) if binarize_y else output[1]
@@ -97,7 +105,7 @@ def fit_resample(self, X, y):
97105
return (X_, y_) if len(output) == 2 else (X_, y_, output[2])
98106

99107
@abstractmethod
100-
def _fit_resample(self, X, y):
108+
def _fit_resample(self, X, y, **params):
101109
"""Base method defined in each sampler to define the sampling
102110
strategy.
103111
@@ -109,6 +117,9 @@ def _fit_resample(self, X, y):
109117
y : array-like of shape (n_samples,)
110118
Corresponding label for each sample in X.
111119
120+
**params : dict
121+
Extra parameters to be used by the sampler.
122+
112123
Returns
113124
-------
114125
X_resampled : {ndarray, sparse matrix} of shape \
@@ -139,7 +150,7 @@ def _check_X_y(self, X, y, accept_sparse=None):
139150
X, y = validate_data(self, X=X, y=y, reset=True, accept_sparse=accept_sparse)
140151
return X, y, binarize_y
141152

142-
def fit(self, X, y):
153+
def fit(self, X, y, **params):
143154
"""Check inputs and statistics of the sampler.
144155
145156
You should use ``fit_resample`` in all cases.
@@ -158,10 +169,9 @@ def fit(self, X, y):
158169
self : object
159170
Return the instance itself.
160171
"""
161-
self._validate_params()
162-
return super().fit(X, y)
172+
return super().fit(X, y, **params)
163173

164-
def fit_resample(self, X, y):
174+
def fit_resample(self, X, y, **params):
165175
"""Resample the dataset.
166176
167177
Parameters
@@ -182,8 +192,7 @@ def fit_resample(self, X, y):
182192
y_resampled : array-like of shape (n_samples_new,)
183193
The corresponding label of `X_resampled`.
184194
"""
185-
self._validate_params()
186-
return super().fit_resample(X, y)
195+
return super().fit_resample(X, y, **params)
187196

188197
def _more_tags(self):
189198
return {"X_types": ["2darray", "sparse", "dataframe"]}

imblearn/pipeline.py

Lines changed: 50 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1168,34 +1168,45 @@ def get_metadata_routing(self):
11681168
router = MetadataRouter(owner=self.__class__.__name__)
11691169

11701170
# first we add all steps except the last one
1171-
for _, name, trans in self._iter(with_final=False, filter_passthrough=True):
1171+
for _, name, trans in self._iter(
1172+
with_final=False, filter_passthrough=True, filter_resample=False
1173+
):
11721174
method_mapping = MethodMapping()
11731175
# fit, fit_predict, and fit_transform call fit_transform if it
11741176
# exists, or else fit and transform
11751177
if hasattr(trans, "fit_transform"):
1176-
method_mapping.add(caller="fit", callee="fit_transform")
1177-
method_mapping.add(caller="fit_transform", callee="fit_transform")
1178-
method_mapping.add(caller="fit_predict", callee="fit_transform")
1179-
method_mapping.add(caller="fit_resample", callee="fit_transform")
1178+
(
1179+
method_mapping.add(caller="fit", callee="fit_transform")
1180+
.add(caller="fit_transform", callee="fit_transform")
1181+
.add(caller="fit_predict", callee="fit_transform")
1182+
)
11801183
else:
1181-
method_mapping.add(caller="fit", callee="fit")
1182-
method_mapping.add(caller="fit", callee="transform")
1183-
method_mapping.add(caller="fit_transform", callee="fit")
1184-
method_mapping.add(caller="fit_transform", callee="transform")
1185-
method_mapping.add(caller="fit_predict", callee="fit")
1186-
method_mapping.add(caller="fit_predict", callee="transform")
1187-
method_mapping.add(caller="fit_resample", callee="fit")
1188-
method_mapping.add(caller="fit_resample", callee="transform")
1189-
1190-
method_mapping.add(caller="predict", callee="transform")
1191-
method_mapping.add(caller="predict", callee="transform")
1192-
method_mapping.add(caller="predict_proba", callee="transform")
1193-
method_mapping.add(caller="decision_function", callee="transform")
1194-
method_mapping.add(caller="predict_log_proba", callee="transform")
1195-
method_mapping.add(caller="transform", callee="transform")
1196-
method_mapping.add(caller="inverse_transform", callee="inverse_transform")
1197-
method_mapping.add(caller="score", callee="transform")
1198-
method_mapping.add(caller="fit_resample", callee="transform")
1184+
(
1185+
method_mapping.add(caller="fit", callee="fit")
1186+
.add(caller="fit", callee="transform")
1187+
.add(caller="fit_transform", callee="fit")
1188+
.add(caller="fit_transform", callee="transform")
1189+
.add(caller="fit_predict", callee="fit")
1190+
.add(caller="fit_predict", callee="transform")
1191+
)
1192+
1193+
(
1194+
# handling sampler in the fit_* stage
1195+
method_mapping.add(caller="fit", callee="fit_resample")
1196+
.add(caller="fit_transform", callee="fit_resample")
1197+
.add(caller="fit_predict", callee="fit_resample")
1198+
)
1199+
(
1200+
method_mapping.add(caller="predict", callee="transform")
1201+
.add(caller="predict", callee="transform")
1202+
.add(caller="predict_proba", callee="transform")
1203+
.add(caller="decision_function", callee="transform")
1204+
.add(caller="predict_log_proba", callee="transform")
1205+
.add(caller="transform", callee="transform")
1206+
.add(caller="inverse_transform", callee="inverse_transform")
1207+
.add(caller="score", callee="transform")
1208+
.add(caller="fit_resample", callee="transform")
1209+
)
11991210

12001211
router.add(method_mapping=method_mapping, **{name: trans})
12011212

@@ -1207,23 +1218,24 @@ def get_metadata_routing(self):
12071218
method_mapping = MethodMapping()
12081219
if hasattr(final_est, "fit_transform"):
12091220
method_mapping.add(caller="fit_transform", callee="fit_transform")
1210-
method_mapping.add(caller="fit_resample", callee="fit_transform")
12111221
else:
1222+
(
1223+
method_mapping.add(caller="fit", callee="fit").add(
1224+
caller="fit", callee="transform"
1225+
)
1226+
)
1227+
(
12121228
method_mapping.add(caller="fit", callee="fit")
1213-
method_mapping.add(caller="fit", callee="transform")
1214-
method_mapping.add(caller="fit_resample", callee="fit")
1215-
method_mapping.add(caller="fit_resample", callee="transform")
1216-
1217-
method_mapping.add(caller="fit", callee="fit")
1218-
method_mapping.add(caller="predict", callee="predict")
1219-
method_mapping.add(caller="fit_predict", callee="fit_predict")
1220-
method_mapping.add(caller="predict_proba", callee="predict_proba")
1221-
method_mapping.add(caller="decision_function", callee="decision_function")
1222-
method_mapping.add(caller="predict_log_proba", callee="predict_log_proba")
1223-
method_mapping.add(caller="transform", callee="transform")
1224-
method_mapping.add(caller="inverse_transform", callee="inverse_transform")
1225-
method_mapping.add(caller="score", callee="score")
1226-
method_mapping.add(caller="fit_resample", callee="fit_resample")
1229+
.add(caller="predict", callee="predict")
1230+
.add(caller="fit_predict", callee="fit_predict")
1231+
.add(caller="predict_proba", callee="predict_proba")
1232+
.add(caller="decision_function", callee="decision_function")
1233+
.add(caller="predict_log_proba", callee="predict_log_proba")
1234+
.add(caller="transform", callee="transform")
1235+
.add(caller="inverse_transform", callee="inverse_transform")
1236+
.add(caller="score", callee="score")
1237+
.add(caller="fit_resample", callee="fit_resample")
1238+
)
12271239

12281240
router.add(method_mapping=method_mapping, **{final_name: final_est})
12291241
return router

imblearn/tests/test_pipeline.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
)
3535
from sklearn.utils.fixes import parse_version
3636

37+
from imblearn.base import BaseSampler
3738
from imblearn.datasets import make_imbalance
3839
from imblearn.pipeline import Pipeline, make_pipeline
3940
from imblearn.under_sampling import EditedNearestNeighbours as ENN
@@ -1495,3 +1496,24 @@ def test_transform_input_sklearn_version():
14951496

14961497
# end of transform_input tests
14971498
# =============================
1499+
1500+
1501+
def test_metadata_routing_with_sampler():
1502+
"""Check that we can use a sampler with metadata routing."""
1503+
X, y = make_classification()
1504+
cost_matrix = np.random.rand(X.shape[0], 2, 2)
1505+
1506+
class CostSensitiveSampler(BaseSampler):
1507+
def fit_resample(self, X, y, cost_matrix=None):
1508+
return self._fit_resample(X, y, cost_matrix=cost_matrix)
1509+
1510+
def _fit_resample(self, X, y, cost_matrix=None):
1511+
self.cost_matrix_ = cost_matrix
1512+
return X, y
1513+
1514+
with config_context(enable_metadata_routing=True):
1515+
sampler = CostSensitiveSampler().set_fit_resample_request(cost_matrix=True)
1516+
pipeline = Pipeline([("sampler", sampler), ("model", LogisticRegression())])
1517+
pipeline.fit(X, y, cost_matrix=cost_matrix)
1518+
1519+
assert_allclose(pipeline[0].cost_matrix_, cost_matrix)

0 commit comments

Comments
 (0)