@@ -1168,34 +1168,45 @@ def get_metadata_routing(self):
11681168 router = MetadataRouter (owner = self .__class__ .__name__ )
11691169
11701170 # first we add all steps except the last one
1171- for _ , name , trans in self ._iter (with_final = False , filter_passthrough = True ):
1171+ for _ , name , trans in self ._iter (
1172+ with_final = False , filter_passthrough = True , filter_resample = False
1173+ ):
11721174 method_mapping = MethodMapping ()
11731175 # fit, fit_predict, and fit_transform call fit_transform if it
11741176 # exists, or else fit and transform
11751177 if hasattr (trans , "fit_transform" ):
1176- method_mapping .add (caller = "fit" , callee = "fit_transform" )
1177- method_mapping .add (caller = "fit_transform" , callee = "fit_transform" )
1178- method_mapping .add (caller = "fit_predict" , callee = "fit_transform" )
1179- method_mapping .add (caller = "fit_resample" , callee = "fit_transform" )
1178+ (
1179+ method_mapping .add (caller = "fit" , callee = "fit_transform" )
1180+ .add (caller = "fit_transform" , callee = "fit_transform" )
1181+ .add (caller = "fit_predict" , callee = "fit_transform" )
1182+ )
11801183 else :
1181- method_mapping .add (caller = "fit" , callee = "fit" )
1182- method_mapping .add (caller = "fit" , callee = "transform" )
1183- method_mapping .add (caller = "fit_transform" , callee = "fit" )
1184- method_mapping .add (caller = "fit_transform" , callee = "transform" )
1185- method_mapping .add (caller = "fit_predict" , callee = "fit" )
1186- method_mapping .add (caller = "fit_predict" , callee = "transform" )
1187- method_mapping .add (caller = "fit_resample" , callee = "fit" )
1188- method_mapping .add (caller = "fit_resample" , callee = "transform" )
1189-
1190- method_mapping .add (caller = "predict" , callee = "transform" )
1191- method_mapping .add (caller = "predict" , callee = "transform" )
1192- method_mapping .add (caller = "predict_proba" , callee = "transform" )
1193- method_mapping .add (caller = "decision_function" , callee = "transform" )
1194- method_mapping .add (caller = "predict_log_proba" , callee = "transform" )
1195- method_mapping .add (caller = "transform" , callee = "transform" )
1196- method_mapping .add (caller = "inverse_transform" , callee = "inverse_transform" )
1197- method_mapping .add (caller = "score" , callee = "transform" )
1198- method_mapping .add (caller = "fit_resample" , callee = "transform" )
1184+ (
1185+ method_mapping .add (caller = "fit" , callee = "fit" )
1186+ .add (caller = "fit" , callee = "transform" )
1187+ .add (caller = "fit_transform" , callee = "fit" )
1188+ .add (caller = "fit_transform" , callee = "transform" )
1189+ .add (caller = "fit_predict" , callee = "fit" )
1190+ .add (caller = "fit_predict" , callee = "transform" )
1191+ )
1192+
1193+ (
1194+ # handling sampler if the fit_* stage
1195+ method_mapping .add (caller = "fit" , callee = "fit_resample" )
1196+ .add (caller = "fit_transform" , callee = "fit_resample" )
1197+ .add (caller = "fit_predict" , callee = "fit_resample" )
1198+ )
1199+ (
1200+ method_mapping .add (caller = "predict" , callee = "transform" )
1201+ .add (caller = "predict" , callee = "transform" )
1202+ .add (caller = "predict_proba" , callee = "transform" )
1203+ .add (caller = "decision_function" , callee = "transform" )
1204+ .add (caller = "predict_log_proba" , callee = "transform" )
1205+ .add (caller = "transform" , callee = "transform" )
1206+ .add (caller = "inverse_transform" , callee = "inverse_transform" )
1207+ .add (caller = "score" , callee = "transform" )
1208+ .add (caller = "fit_resample" , callee = "transform" )
1209+ )
11991210
12001211 router .add (method_mapping = method_mapping , ** {name : trans })
12011212
@@ -1207,23 +1218,24 @@ def get_metadata_routing(self):
12071218 method_mapping = MethodMapping ()
12081219 if hasattr (final_est , "fit_transform" ):
12091220 method_mapping .add (caller = "fit_transform" , callee = "fit_transform" )
1210- method_mapping .add (caller = "fit_resample" , callee = "fit_transform" )
12111221 else :
1222+ (
1223+ method_mapping .add (caller = "fit" , callee = "fit" ).add (
1224+ caller = "fit" , callee = "transform"
1225+ )
1226+ )
1227+ (
12121228 method_mapping .add (caller = "fit" , callee = "fit" )
1213- method_mapping .add (caller = "fit" , callee = "transform" )
1214- method_mapping .add (caller = "fit_resample" , callee = "fit" )
1215- method_mapping .add (caller = "fit_resample" , callee = "transform" )
1216-
1217- method_mapping .add (caller = "fit" , callee = "fit" )
1218- method_mapping .add (caller = "predict" , callee = "predict" )
1219- method_mapping .add (caller = "fit_predict" , callee = "fit_predict" )
1220- method_mapping .add (caller = "predict_proba" , callee = "predict_proba" )
1221- method_mapping .add (caller = "decision_function" , callee = "decision_function" )
1222- method_mapping .add (caller = "predict_log_proba" , callee = "predict_log_proba" )
1223- method_mapping .add (caller = "transform" , callee = "transform" )
1224- method_mapping .add (caller = "inverse_transform" , callee = "inverse_transform" )
1225- method_mapping .add (caller = "score" , callee = "score" )
1226- method_mapping .add (caller = "fit_resample" , callee = "fit_resample" )
1229+ .add (caller = "predict" , callee = "predict" )
1230+ .add (caller = "fit_predict" , callee = "fit_predict" )
1231+ .add (caller = "predict_proba" , callee = "predict_proba" )
1232+ .add (caller = "decision_function" , callee = "decision_function" )
1233+ .add (caller = "predict_log_proba" , callee = "predict_log_proba" )
1234+ .add (caller = "transform" , callee = "transform" )
1235+ .add (caller = "inverse_transform" , callee = "inverse_transform" )
1236+ .add (caller = "score" , callee = "score" )
1237+ .add (caller = "fit_resample" , callee = "fit_resample" )
1238+ )
12271239
12281240 router .add (method_mapping = method_mapping , ** {final_name : final_est })
12291241 return router
0 commit comments