1515# License: BSD
1616import warnings
1717from contextlib import contextmanager
18+ from copy import deepcopy
1819
1920import sklearn
2021from sklearn import pipeline
2526 METHODS ,
2627 MetadataRouter ,
2728 MethodMapping ,
28- _raise_for_params ,
2929 _routing_enabled ,
3030 get_routing_for_object ,
31- process_routing ,
3231)
3332from sklearn .utils ._param_validation import HasMethods
3433from sklearn .utils .fixes import parse_version
3837from .utils ._sklearn_compat import (
3938 _fit_context ,
4039 _print_elapsed_time ,
40+ _raise_for_params ,
41+ get_tags ,
42+ process_routing ,
4143 validate_params ,
4244)
4345
46+ if "fit_predict" not in METHODS :
47+ METHODS .append ("fit_predict" )
4448METHODS .append ("fit_resample" )
4549
4650__all__ = ["Pipeline" , "make_pipeline" ]
@@ -245,6 +249,12 @@ class Pipeline(pipeline.Pipeline):
245249 "verbose" : ["boolean" ],
246250 }
247251
252+ def __init__ (self , steps , * , transform_input = None , memory = None , verbose = False ):
253+ self .steps = steps
254+ self .transform_input = transform_input
255+ self .memory = memory
256+ self .verbose = verbose
257+
248258 # BaseEstimator interface
249259
250260 def _validate_steps (self ):
@@ -1162,35 +1172,29 @@ def get_metadata_routing(self):
11621172 # fit, fit_predict, and fit_transform call fit_transform if it
11631173 # exists, or else fit and transform
11641174 if hasattr (trans , "fit_transform" ):
1165- (
1166- method_mapping .add (caller = "fit" , callee = "fit_transform" )
1167- .add (caller = "fit_transform" , callee = "fit_transform" )
1168- .add (caller = "fit_predict" , callee = "fit_transform" )
1169- .add (caller = "fit_resample" , callee = "fit_transform" )
1170- )
1175+ method_mapping .add (caller = "fit" , callee = "fit_transform" )
1176+ method_mapping .add (caller = "fit_transform" , callee = "fit_transform" )
1177+ method_mapping .add (caller = "fit_predict" , callee = "fit_transform" )
1178+ method_mapping .add (caller = "fit_resample" , callee = "fit_transform" )
11711179 else :
1172- (
1173- method_mapping .add (caller = "fit" , callee = "fit" )
1174- .add (caller = "fit" , callee = "transform" )
1175- .add (caller = "fit_transform" , callee = "fit" )
1176- .add (caller = "fit_transform" , callee = "transform" )
1177- .add (caller = "fit_predict" , callee = "fit" )
1178- .add (caller = "fit_predict" , callee = "transform" )
1179- .add (caller = "fit_resample" , callee = "fit" )
1180- .add (caller = "fit_resample" , callee = "transform" )
1181- )
1182-
1183- (
1184- method_mapping .add (caller = "predict" , callee = "transform" )
1185- .add (caller = "predict" , callee = "transform" )
1186- .add (caller = "predict_proba" , callee = "transform" )
1187- .add (caller = "decision_function" , callee = "transform" )
1188- .add (caller = "predict_log_proba" , callee = "transform" )
1189- .add (caller = "transform" , callee = "transform" )
1190- .add (caller = "inverse_transform" , callee = "inverse_transform" )
1191- .add (caller = "score" , callee = "transform" )
1192- .add (caller = "fit_resample" , callee = "transform" )
1193- )
1180+ method_mapping .add (caller = "fit" , callee = "fit" )
1181+ method_mapping .add (caller = "fit" , callee = "transform" )
1182+ method_mapping .add (caller = "fit_transform" , callee = "fit" )
1183+ method_mapping .add (caller = "fit_transform" , callee = "transform" )
1184+ method_mapping .add (caller = "fit_predict" , callee = "fit" )
1185+ method_mapping .add (caller = "fit_predict" , callee = "transform" )
1186+ method_mapping .add (caller = "fit_resample" , callee = "fit" )
1187+ method_mapping .add (caller = "fit_resample" , callee = "transform" )
1188+
1189+ method_mapping .add (caller = "predict" , callee = "transform" )
1190+ method_mapping .add (caller = "predict" , callee = "transform" )
1191+ method_mapping .add (caller = "predict_proba" , callee = "transform" )
1192+ method_mapping .add (caller = "decision_function" , callee = "transform" )
1193+ method_mapping .add (caller = "predict_log_proba" , callee = "transform" )
1194+ method_mapping .add (caller = "transform" , callee = "transform" )
1195+ method_mapping .add (caller = "inverse_transform" , callee = "inverse_transform" )
1196+ method_mapping .add (caller = "score" , callee = "transform" )
1197+ method_mapping .add (caller = "fit_resample" , callee = "transform" )
11941198
11951199 router .add (method_mapping = method_mapping , ** {name : trans })
11961200
@@ -1201,30 +1205,24 @@ def get_metadata_routing(self):
12011205 # then we add the last step
12021206 method_mapping = MethodMapping ()
12031207 if hasattr (final_est , "fit_transform" ):
1204- (
1205- method_mapping .add (caller = "fit_transform" , callee = "fit_transform" ).add (
1206- caller = "fit_resample" , callee = "fit_transform"
1207- )
1208- )
1208+ method_mapping .add (caller = "fit_transform" , callee = "fit_transform" )
1209+ method_mapping .add (caller = "fit_resample" , callee = "fit_transform" )
12091210 else :
1210- (
1211- method_mapping .add (caller = "fit" , callee = "fit" )
1212- .add (caller = "fit" , callee = "transform" )
1213- .add (caller = "fit_resample" , callee = "fit" )
1214- .add (caller = "fit_resample" , callee = "transform" )
1215- )
1216- (
12171211 method_mapping .add (caller = "fit" , callee = "fit" )
1218- .add (caller = "predict" , callee = "predict" )
1219- .add (caller = "fit_predict" , callee = "fit_predict" )
1220- .add (caller = "predict_proba" , callee = "predict_proba" )
1221- .add (caller = "decision_function" , callee = "decision_function" )
1222- .add (caller = "predict_log_proba" , callee = "predict_log_proba" )
1223- .add (caller = "transform" , callee = "transform" )
1224- .add (caller = "inverse_transform" , callee = "inverse_transform" )
1225- .add (caller = "score" , callee = "score" )
1226- .add (caller = "fit_resample" , callee = "fit_resample" )
1227- )
1212+ method_mapping .add (caller = "fit" , callee = "transform" )
1213+ method_mapping .add (caller = "fit_resample" , callee = "fit" )
1214+ method_mapping .add (caller = "fit_resample" , callee = "transform" )
1215+
1216+ method_mapping .add (caller = "fit" , callee = "fit" )
1217+ method_mapping .add (caller = "predict" , callee = "predict" )
1218+ method_mapping .add (caller = "fit_predict" , callee = "fit_predict" )
1219+ method_mapping .add (caller = "predict_proba" , callee = "predict_proba" )
1220+ method_mapping .add (caller = "decision_function" , callee = "decision_function" )
1221+ method_mapping .add (caller = "predict_log_proba" , callee = "predict_log_proba" )
1222+ method_mapping .add (caller = "transform" , callee = "transform" )
1223+ method_mapping .add (caller = "inverse_transform" , callee = "inverse_transform" )
1224+ method_mapping .add (caller = "score" , callee = "score" )
1225+ method_mapping .add (caller = "fit_resample" , callee = "fit_resample" )
12281226
12291227 router .add (method_mapping = method_mapping , ** {final_name : final_est })
12301228 return router
@@ -1258,6 +1256,67 @@ def _check_method_params(self, method, props, **kwargs):
12581256 fit_params_steps [step ]["fit_predict" ][param ] = pval
12591257 return fit_params_steps
12601258
1259+ def __sklearn_is_fitted__ (self ):
1260+ """Indicate whether pipeline has been fit.
1261+
1262+ This is done by checking whether the last non-`passthrough` step of the
1263+ pipeline is fitted.
1264+
1265+ An empty pipeline is considered fitted.
1266+ """
1267+
1268+ # First find the last step that is not 'passthrough'
1269+ last_step = None
1270+ for _ , estimator in reversed (self .steps ):
1271+ if estimator != "passthrough" :
1272+ last_step = estimator
1273+ break
1274+
1275+ if last_step is None :
1276+ # All steps are 'passthrough', so the pipeline is considered fitted
1277+ return True
1278+
1279+ try :
1280+ # check if the last step of the pipeline is fitted
1281+ # we only check the last step since if the last step is fit, it
1282+ # means the previous steps should also be fit. This is faster than
1283+ # checking if every step of the pipeline is fit.
1284+ check_is_fitted (last_step )
1285+ return True
1286+ except NotFittedError :
1287+ return False
1288+
1289+ def __sklearn_tags__ (self ):
1290+ tags = super ().__sklearn_tags__ ()
1291+
1292+ if not self .steps :
1293+ return tags
1294+
1295+ try :
1296+ if self .steps [0 ][1 ] is not None and self .steps [0 ][1 ] != "passthrough" :
1297+ tags .input_tags .pairwise = get_tags (
1298+ self .steps [0 ][1 ]
1299+ ).input_tags .pairwise
1300+ except (ValueError , AttributeError , TypeError ):
1301+ # This happens when the `steps` is not a list of (name, estimator)
1302+ # tuples and `fit` is not called yet to validate the steps.
1303+ pass
1304+
1305+ try :
1306+ if self .steps [- 1 ][1 ] is not None and self .steps [- 1 ][1 ] != "passthrough" :
1307+ last_step_tags = get_tags (self .steps [- 1 ][1 ])
1308+ tags .estimator_type = last_step_tags .estimator_type
1309+ tags .target_tags .multi_output = last_step_tags .target_tags .multi_output
1310+ tags .classifier_tags = deepcopy (last_step_tags .classifier_tags )
1311+ tags .regressor_tags = deepcopy (last_step_tags .regressor_tags )
1312+ tags .transformer_tags = deepcopy (last_step_tags .transformer_tags )
1313+ except (ValueError , AttributeError , TypeError ):
1314+ # This happens when the `steps` is not a list of (name, estimator)
1315+ # tuples and `fit` is not called yet to validate the steps.
1316+ pass
1317+
1318+ return tags
1319+
12611320
12621321def _fit_resample_one (sampler , X , y , message_clsname = "" , message = None , params = None ):
12631322 with _print_elapsed_time (message_clsname , message ):
0 commit comments