The :mod:`imblearn.pipeline` module implements utilities to build a
composite estimator, as a chain of transforms, samples and estimators.
"""
5+
56# Adapted from scikit-learn
67
78# Author: Edouard Duchesnay
1213# Christos Aridas
#          Guillaume Lemaitre <[email protected]>
# License: BSD
16+ import warnings
17+ from contextlib import contextmanager
18+ from copy import deepcopy
19+
1520import sklearn
1621from sklearn import pipeline
1722from sklearn .base import clone
23+ from sklearn .exceptions import NotFittedError
1824from sklearn .utils import Bunch
1925from sklearn .utils .fixes import parse_version
2026from sklearn .utils .metaestimators import available_if
21- from sklearn .utils .validation import check_memory
27+ from sklearn .utils .validation import check_memory , check_is_fitted
2228
2329from .base import _ParamsValidationMixin
2430from .utils ._metadata_requests import (
3036 process_routing ,
3137)
3238from .utils ._param_validation import HasMethods , validate_params
33- from .utils .fixes import _fit_context
39+ from .utils .fixes import _fit_context , get_tags
3440
# Register the sampler-specific "fit_resample" verb with the metadata-routing
# method list so routing requests can be declared for it like fit/transform/
# predict.  NOTE(review): METHODS appears to be imported from
# .utils._metadata_requests above — confirm it is the canonical routable-method
# list before relying on this mutation.
METHODS.append("fit_resample")
3642
4349 from sklearn .utils ._user_interface import _print_elapsed_time
4450
4551
@contextmanager
def _raise_or_warn_if_not_fitted(estimator):
    """Context manager guarding estimator methods on a possibly unfitted pipeline.

    If the wrapped body raises :class:`~sklearn.exceptions.NotFittedError`
    (typically from a sub-estimator), it is re-raised as the pipeline's own
    ``NotFittedError``.  If the body succeeds but the pipeline itself is not
    fitted, a :class:`FutureWarning` is emitted instead (deprecation path).

    TODO(1.8): remove this context manager and replace with check_is_fitted.
    """
    try:
        yield
    except NotFittedError as exc:
        # Surface a sub-estimator's NotFittedError as the pipeline's own.
        raise NotFittedError("Pipeline is not fitted yet.") from exc
    else:
        # Body exited cleanly: still warn (for now) when the pipeline has not
        # been fitted, instead of raising as future versions will.
        try:
            check_is_fitted(estimator)
        except NotFittedError:
            warnings.warn(
                "This Pipeline instance is not fitted yet. Call 'fit' with "
                "appropriate arguments before using other methods such as transform, "
                "predict, etc. This will raise an error in 1.8 instead of the current "
                "warning.",
                FutureWarning,
            )
76+
4677class Pipeline (_ParamsValidationMixin , pipeline .Pipeline ):
4778 """Pipeline of transforms and resamples with a final estimator.
4879
@@ -456,18 +487,22 @@ def predict(self, X, **params):
456487 y_pred : ndarray
457488 Result of calling `predict` on the final estimator.
458489 """
459- Xt = X
490+ # TODO(1.8): Remove the context manager and use check_is_fitted(self)
491+ with _raise_or_warn_if_not_fitted (self ):
492+ Xt = X
460493
461- if not _routing_enabled ():
462- for _ , name , transform in self ._iter (with_final = False ):
463- Xt = transform .transform (Xt )
464- return self .steps [- 1 ][1 ].predict (Xt , ** params )
494+ if not _routing_enabled ():
495+ for _ , name , transform in self ._iter (with_final = False ):
496+ Xt = transform .transform (Xt )
497+ return self .steps [- 1 ][1 ].predict (Xt , ** params )
465498
466- # metadata routing enabled
467- routed_params = process_routing (self , "predict" , ** params )
468- for _ , name , transform in self ._iter (with_final = False ):
469- Xt = transform .transform (Xt , ** routed_params [name ].transform )
470- return self .steps [- 1 ][1 ].predict (Xt , ** routed_params [self .steps [- 1 ][0 ]].predict )
499+ # metadata routing enabled
500+ routed_params = process_routing (self , "predict" , ** params )
501+ for _ , name , transform in self ._iter (with_final = False ):
502+ Xt = transform .transform (Xt , ** routed_params [name ].transform )
503+ return self .steps [- 1 ][1 ].predict (
504+ Xt , ** routed_params [self .steps [- 1 ][0 ]].predict
505+ )
471506
472507 def _can_fit_resample (self ):
473508 return self ._final_estimator == "passthrough" or hasattr (
@@ -646,20 +681,22 @@ def predict_proba(self, X, **params):
646681 y_proba : ndarray of shape (n_samples, n_classes)
647682 Result of calling `predict_proba` on the final estimator.
648683 """
649- Xt = X
684+ # TODO(1.8): Remove the context manager and use check_is_fitted(self)
685+ with _raise_or_warn_if_not_fitted (self ):
686+ Xt = X
687+
688+ if not _routing_enabled ():
689+ for _ , name , transform in self ._iter (with_final = False ):
690+ Xt = transform .transform (Xt )
691+ return self .steps [- 1 ][1 ].predict_proba (Xt , ** params )
650692
651- if not _routing_enabled ():
693+ # metadata routing enabled
694+ routed_params = process_routing (self , "predict_proba" , ** params )
652695 for _ , name , transform in self ._iter (with_final = False ):
653- Xt = transform .transform (Xt )
654- return self .steps [- 1 ][1 ].predict_proba (Xt , ** params )
655-
656- # metadata routing enabled
657- routed_params = process_routing (self , "predict_proba" , ** params )
658- for _ , name , transform in self ._iter (with_final = False ):
659- Xt = transform .transform (Xt , ** routed_params [name ].transform )
660- return self .steps [- 1 ][1 ].predict_proba (
661- Xt , ** routed_params [self .steps [- 1 ][0 ]].predict_proba
662- )
696+ Xt = transform .transform (Xt , ** routed_params [name ].transform )
697+ return self .steps [- 1 ][1 ].predict_proba (
698+ Xt , ** routed_params [self .steps [- 1 ][0 ]].predict_proba
699+ )
663700
664701 @available_if (pipeline ._final_estimator_has ("decision_function" ))
665702 def decision_function (self , X , ** params ):
@@ -691,20 +728,23 @@ def decision_function(self, X, **params):
691728 y_score : ndarray of shape (n_samples, n_classes)
692729 Result of calling `decision_function` on the final estimator.
693730 """
694- _raise_for_params (params , self , "decision_function" )
731+ # TODO(1.8): Remove the context manager and use check_is_fitted(self)
732+ with _raise_or_warn_if_not_fitted (self ):
733+ _raise_for_params (params , self , "decision_function" )
695734
696- # not branching here since params is only available if
697- # enable_metadata_routing=True
698- routed_params = process_routing (self , "decision_function" , ** params )
735+ # not branching here since params is only available if
736+ # enable_metadata_routing=True
737+ routed_params = process_routing (self , "decision_function" , ** params )
699738
700- Xt = X
701- for _ , name , transform in self ._iter (with_final = False ):
702- Xt = transform .transform (
703- Xt , ** routed_params .get (name , {}).get ("transform" , {})
739+ Xt = X
740+ for _ , name , transform in self ._iter (with_final = False ):
741+ Xt = transform .transform (
742+ Xt , ** routed_params .get (name , {}).get ("transform" , {})
743+ )
744+ return self .steps [- 1 ][1 ].decision_function (
745+ Xt ,
746+ ** routed_params .get (self .steps [- 1 ][0 ], {}).get ("decision_function" , {}),
704747 )
705- return self .steps [- 1 ][1 ].decision_function (
706- Xt , ** routed_params .get (self .steps [- 1 ][0 ], {}).get ("decision_function" , {})
707- )
708748
709749 @available_if (pipeline ._final_estimator_has ("score_samples" ))
710750 def score_samples (self , X ):
@@ -726,10 +766,12 @@ def score_samples(self, X):
726766 y_score : ndarray of shape (n_samples,)
727767 Result of calling `score_samples` on the final estimator.
728768 """
729- Xt = X
730- for _ , _ , transformer in self ._iter (with_final = False ):
731- Xt = transformer .transform (Xt )
732- return self .steps [- 1 ][1 ].score_samples (Xt )
769+ # TODO(1.8): Remove the context manager and use check_is_fitted(self)
770+ with _raise_or_warn_if_not_fitted (self ):
771+ Xt = X
772+ for _ , _ , transformer in self ._iter (with_final = False ):
773+ Xt = transformer .transform (Xt )
774+ return self .steps [- 1 ][1 ].score_samples (Xt )
733775
734776 @available_if (pipeline ._final_estimator_has ("predict_log_proba" ))
735777 def predict_log_proba (self , X , ** params ):
@@ -773,20 +815,22 @@ def predict_log_proba(self, X, **params):
773815 y_log_proba : ndarray of shape (n_samples, n_classes)
774816 Result of calling `predict_log_proba` on the final estimator.
775817 """
776- Xt = X
818+ # TODO(1.8): Remove the context manager and use check_is_fitted(self)
819+ with _raise_or_warn_if_not_fitted (self ):
820+ Xt = X
777821
778- if not _routing_enabled ():
822+ if not _routing_enabled ():
823+ for _ , name , transform in self ._iter (with_final = False ):
824+ Xt = transform .transform (Xt )
825+ return self .steps [- 1 ][1 ].predict_log_proba (Xt , ** params )
826+
827+ # metadata routing enabled
828+ routed_params = process_routing (self , "predict_log_proba" , ** params )
779829 for _ , name , transform in self ._iter (with_final = False ):
780- Xt = transform .transform (Xt )
781- return self .steps [- 1 ][1 ].predict_log_proba (Xt , ** params )
782-
783- # metadata routing enabled
784- routed_params = process_routing (self , "predict_log_proba" , ** params )
785- for _ , name , transform in self ._iter (with_final = False ):
786- Xt = transform .transform (Xt , ** routed_params [name ].transform )
787- return self .steps [- 1 ][1 ].predict_log_proba (
788- Xt , ** routed_params [self .steps [- 1 ][0 ]].predict_log_proba
789- )
830+ Xt = transform .transform (Xt , ** routed_params [name ].transform )
831+ return self .steps [- 1 ][1 ].predict_log_proba (
832+ Xt , ** routed_params [self .steps [- 1 ][0 ]].predict_log_proba
833+ )
790834
791835 def _can_transform (self ):
792836 return self ._final_estimator == "passthrough" or hasattr (
@@ -826,15 +870,17 @@ def transform(self, X, **params):
826870 Xt : ndarray of shape (n_samples, n_transformed_features)
827871 Transformed data.
828872 """
829- _raise_for_params (params , self , "transform" )
830-
831- # not branching here since params is only available if
832- # enable_metadata_routing=True
833- routed_params = process_routing (self , "transform" , ** params )
834- Xt = X
835- for _ , name , transform in self ._iter ():
836- Xt = transform .transform (Xt , ** routed_params [name ].transform )
837- return Xt
873+ # TODO(1.8): Remove the context manager and use check_is_fitted(self)
874+ with _raise_or_warn_if_not_fitted (self ):
875+ _raise_for_params (params , self , "transform" )
876+
877+ # not branching here since params is only available if
878+ # enable_metadata_routing=True
879+ routed_params = process_routing (self , "transform" , ** params )
880+ Xt = X
881+ for _ , name , transform in self ._iter ():
882+ Xt = transform .transform (Xt , ** routed_params [name ].transform )
883+ return Xt
838884
839885 def _can_inverse_transform (self ):
840886 return all (hasattr (t , "inverse_transform" ) for _ , _ , t in self ._iter ())
@@ -869,17 +915,19 @@ def inverse_transform(self, Xt, **params):
869915 Inverse transformed data, that is, data in the original feature
870916 space.
871917 """
872- _raise_for_params (params , self , "inverse_transform" )
873-
874- # we don't have to branch here, since params is only non-empty if
875- # enable_metadata_routing=True.
876- routed_params = process_routing (self , "inverse_transform" , ** params )
877- reverse_iter = reversed (list (self ._iter ()))
878- for _ , name , transform in reverse_iter :
879- Xt = transform .inverse_transform (
880- Xt , ** routed_params [name ].inverse_transform
881- )
882- return Xt
918+ # TODO(1.8): Remove the context manager and use check_is_fitted(self)
919+ with _raise_or_warn_if_not_fitted (self ):
920+ _raise_for_params (params , self , "inverse_transform" )
921+
922+ # we don't have to branch here, since params is only non-empty if
923+ # enable_metadata_routing=True.
924+ routed_params = process_routing (self , "inverse_transform" , ** params )
925+ reverse_iter = reversed (list (self ._iter ()))
926+ for _ , name , transform in reverse_iter :
927+ Xt = transform .inverse_transform (
928+ Xt , ** routed_params [name ].inverse_transform
929+ )
930+ return Xt
883931
884932 @available_if (pipeline ._final_estimator_has ("score" ))
885933 def score (self , X , y = None , sample_weight = None , ** params ):
@@ -918,24 +966,28 @@ def score(self, X, y=None, sample_weight=None, **params):
918966 score : float
919967 Result of calling `score` on the final estimator.
920968 """
921- Xt = X
922- if not _routing_enabled ():
923- for _ , name , transform in self ._iter (with_final = False ):
924- Xt = transform .transform (Xt )
925- score_params = {}
926- if sample_weight is not None :
927- score_params ["sample_weight" ] = sample_weight
928- return self .steps [- 1 ][1 ].score (Xt , y , ** score_params )
929-
930- # metadata routing is enabled.
931- routed_params = process_routing (
932- self , "score" , sample_weight = sample_weight , ** params
933- )
969+ # TODO(1.8): Remove the context manager and use check_is_fitted(self)
970+ with _raise_or_warn_if_not_fitted (self ):
971+ Xt = X
972+ if not _routing_enabled ():
973+ for _ , name , transform in self ._iter (with_final = False ):
974+ Xt = transform .transform (Xt )
975+ score_params = {}
976+ if sample_weight is not None :
977+ score_params ["sample_weight" ] = sample_weight
978+ return self .steps [- 1 ][1 ].score (Xt , y , ** score_params )
979+
980+ # metadata routing is enabled.
981+ routed_params = process_routing (
982+ self , "score" , sample_weight = sample_weight , ** params
983+ )
934984
935- Xt = X
936- for _ , name , transform in self ._iter (with_final = False ):
937- Xt = transform .transform (Xt , ** routed_params [name ].transform )
938- return self .steps [- 1 ][1 ].score (Xt , y , ** routed_params [self .steps [- 1 ][0 ]].score )
985+ Xt = X
986+ for _ , name , transform in self ._iter (with_final = False ):
987+ Xt = transform .transform (Xt , ** routed_params [name ].transform )
988+ return self .steps [- 1 ][1 ].score (
989+ Xt , y , ** routed_params [self .steps [- 1 ][0 ]].score
990+ )
939991
940992 # TODO: once scikit-learn >= 1.4, the following function should be simplified by
941993 # calling `super().get_metadata_routing()`
0 commit comments