Skip to content

Commit c457b4a

Browse files
committed
iter
1 parent 468f925 commit c457b4a

File tree

2 files changed

+148
-89
lines changed

2 files changed

+148
-89
lines changed

imblearn/pipeline.py

Lines changed: 140 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
The :mod:`imblearn.pipeline` module implements utilities to build a
33
composite estimator, as a chain of transforms, samples and estimators.
44
"""
5+
56
# Adapted from scikit-learn
67

78
# Author: Edouard Duchesnay
@@ -12,13 +13,18 @@
1213
# Christos Aridas
1314
# Guillaume Lemaitre <[email protected]>
1415
# License: BSD
16+
import warnings
17+
from contextlib import contextmanager
18+
from copy import deepcopy
19+
1520
import sklearn
1621
from sklearn import pipeline
1722
from sklearn.base import clone
23+
from sklearn.exceptions import NotFittedError
1824
from sklearn.utils import Bunch
1925
from sklearn.utils.fixes import parse_version
2026
from sklearn.utils.metaestimators import available_if
21-
from sklearn.utils.validation import check_memory
27+
from sklearn.utils.validation import check_memory, check_is_fitted
2228

2329
from .base import _ParamsValidationMixin
2430
from .utils._metadata_requests import (
@@ -30,7 +36,7 @@
3036
process_routing,
3137
)
3238
from .utils._param_validation import HasMethods, validate_params
33-
from .utils.fixes import _fit_context
39+
from .utils.fixes import _fit_context, get_tags
3440

3541
METHODS.append("fit_resample")
3642

@@ -43,6 +49,31 @@
4349
from sklearn.utils._user_interface import _print_elapsed_time
4450

4551

52+
@contextmanager
53+
def _raise_or_warn_if_not_fitted(estimator):
54+
"""A context manager to make sure a NotFittedError is raised, if a sub-estimator
55+
raises the error.
56+
Otherwise, we raise a warning if the pipeline is not fitted, with the deprecation.
57+
TODO(1.8): remove this context manager and replace with check_is_fitted.
58+
"""
59+
try:
60+
yield
61+
except NotFittedError as exc:
62+
raise NotFittedError("Pipeline is not fitted yet.") from exc
63+
64+
# we only get here if the above didn't raise
65+
try:
66+
check_is_fitted(estimator)
67+
except NotFittedError:
68+
warnings.warn(
69+
"This Pipeline instance is not fitted yet. Call 'fit' with "
70+
"appropriate arguments before using other methods such as transform, "
71+
"predict, etc. This will raise an error in 1.8 instead of the current "
72+
"warning.",
73+
FutureWarning,
74+
)
75+
76+
4677
class Pipeline(_ParamsValidationMixin, pipeline.Pipeline):
4778
"""Pipeline of transforms and resamples with a final estimator.
4879
@@ -456,18 +487,22 @@ def predict(self, X, **params):
456487
y_pred : ndarray
457488
Result of calling `predict` on the final estimator.
458489
"""
459-
Xt = X
490+
# TODO(1.8): Remove the context manager and use check_is_fitted(self)
491+
with _raise_or_warn_if_not_fitted(self):
492+
Xt = X
460493

461-
if not _routing_enabled():
462-
for _, name, transform in self._iter(with_final=False):
463-
Xt = transform.transform(Xt)
464-
return self.steps[-1][1].predict(Xt, **params)
494+
if not _routing_enabled():
495+
for _, name, transform in self._iter(with_final=False):
496+
Xt = transform.transform(Xt)
497+
return self.steps[-1][1].predict(Xt, **params)
465498

466-
# metadata routing enabled
467-
routed_params = process_routing(self, "predict", **params)
468-
for _, name, transform in self._iter(with_final=False):
469-
Xt = transform.transform(Xt, **routed_params[name].transform)
470-
return self.steps[-1][1].predict(Xt, **routed_params[self.steps[-1][0]].predict)
499+
# metadata routing enabled
500+
routed_params = process_routing(self, "predict", **params)
501+
for _, name, transform in self._iter(with_final=False):
502+
Xt = transform.transform(Xt, **routed_params[name].transform)
503+
return self.steps[-1][1].predict(
504+
Xt, **routed_params[self.steps[-1][0]].predict
505+
)
471506

472507
def _can_fit_resample(self):
473508
return self._final_estimator == "passthrough" or hasattr(
@@ -646,20 +681,22 @@ def predict_proba(self, X, **params):
646681
y_proba : ndarray of shape (n_samples, n_classes)
647682
Result of calling `predict_proba` on the final estimator.
648683
"""
649-
Xt = X
684+
# TODO(1.8): Remove the context manager and use check_is_fitted(self)
685+
with _raise_or_warn_if_not_fitted(self):
686+
Xt = X
687+
688+
if not _routing_enabled():
689+
for _, name, transform in self._iter(with_final=False):
690+
Xt = transform.transform(Xt)
691+
return self.steps[-1][1].predict_proba(Xt, **params)
650692

651-
if not _routing_enabled():
693+
# metadata routing enabled
694+
routed_params = process_routing(self, "predict_proba", **params)
652695
for _, name, transform in self._iter(with_final=False):
653-
Xt = transform.transform(Xt)
654-
return self.steps[-1][1].predict_proba(Xt, **params)
655-
656-
# metadata routing enabled
657-
routed_params = process_routing(self, "predict_proba", **params)
658-
for _, name, transform in self._iter(with_final=False):
659-
Xt = transform.transform(Xt, **routed_params[name].transform)
660-
return self.steps[-1][1].predict_proba(
661-
Xt, **routed_params[self.steps[-1][0]].predict_proba
662-
)
696+
Xt = transform.transform(Xt, **routed_params[name].transform)
697+
return self.steps[-1][1].predict_proba(
698+
Xt, **routed_params[self.steps[-1][0]].predict_proba
699+
)
663700

664701
@available_if(pipeline._final_estimator_has("decision_function"))
665702
def decision_function(self, X, **params):
@@ -691,20 +728,23 @@ def decision_function(self, X, **params):
691728
y_score : ndarray of shape (n_samples, n_classes)
692729
Result of calling `decision_function` on the final estimator.
693730
"""
694-
_raise_for_params(params, self, "decision_function")
731+
# TODO(1.8): Remove the context manager and use check_is_fitted(self)
732+
with _raise_or_warn_if_not_fitted(self):
733+
_raise_for_params(params, self, "decision_function")
695734

696-
# not branching here since params is only available if
697-
# enable_metadata_routing=True
698-
routed_params = process_routing(self, "decision_function", **params)
735+
# not branching here since params is only available if
736+
# enable_metadata_routing=True
737+
routed_params = process_routing(self, "decision_function", **params)
699738

700-
Xt = X
701-
for _, name, transform in self._iter(with_final=False):
702-
Xt = transform.transform(
703-
Xt, **routed_params.get(name, {}).get("transform", {})
739+
Xt = X
740+
for _, name, transform in self._iter(with_final=False):
741+
Xt = transform.transform(
742+
Xt, **routed_params.get(name, {}).get("transform", {})
743+
)
744+
return self.steps[-1][1].decision_function(
745+
Xt,
746+
**routed_params.get(self.steps[-1][0], {}).get("decision_function", {}),
704747
)
705-
return self.steps[-1][1].decision_function(
706-
Xt, **routed_params.get(self.steps[-1][0], {}).get("decision_function", {})
707-
)
708748

709749
@available_if(pipeline._final_estimator_has("score_samples"))
710750
def score_samples(self, X):
@@ -726,10 +766,12 @@ def score_samples(self, X):
726766
y_score : ndarray of shape (n_samples,)
727767
Result of calling `score_samples` on the final estimator.
728768
"""
729-
Xt = X
730-
for _, _, transformer in self._iter(with_final=False):
731-
Xt = transformer.transform(Xt)
732-
return self.steps[-1][1].score_samples(Xt)
769+
# TODO(1.8): Remove the context manager and use check_is_fitted(self)
770+
with _raise_or_warn_if_not_fitted(self):
771+
Xt = X
772+
for _, _, transformer in self._iter(with_final=False):
773+
Xt = transformer.transform(Xt)
774+
return self.steps[-1][1].score_samples(Xt)
733775

734776
@available_if(pipeline._final_estimator_has("predict_log_proba"))
735777
def predict_log_proba(self, X, **params):
@@ -773,20 +815,22 @@ def predict_log_proba(self, X, **params):
773815
y_log_proba : ndarray of shape (n_samples, n_classes)
774816
Result of calling `predict_log_proba` on the final estimator.
775817
"""
776-
Xt = X
818+
# TODO(1.8): Remove the context manager and use check_is_fitted(self)
819+
with _raise_or_warn_if_not_fitted(self):
820+
Xt = X
777821

778-
if not _routing_enabled():
822+
if not _routing_enabled():
823+
for _, name, transform in self._iter(with_final=False):
824+
Xt = transform.transform(Xt)
825+
return self.steps[-1][1].predict_log_proba(Xt, **params)
826+
827+
# metadata routing enabled
828+
routed_params = process_routing(self, "predict_log_proba", **params)
779829
for _, name, transform in self._iter(with_final=False):
780-
Xt = transform.transform(Xt)
781-
return self.steps[-1][1].predict_log_proba(Xt, **params)
782-
783-
# metadata routing enabled
784-
routed_params = process_routing(self, "predict_log_proba", **params)
785-
for _, name, transform in self._iter(with_final=False):
786-
Xt = transform.transform(Xt, **routed_params[name].transform)
787-
return self.steps[-1][1].predict_log_proba(
788-
Xt, **routed_params[self.steps[-1][0]].predict_log_proba
789-
)
830+
Xt = transform.transform(Xt, **routed_params[name].transform)
831+
return self.steps[-1][1].predict_log_proba(
832+
Xt, **routed_params[self.steps[-1][0]].predict_log_proba
833+
)
790834

791835
def _can_transform(self):
792836
return self._final_estimator == "passthrough" or hasattr(
@@ -826,15 +870,17 @@ def transform(self, X, **params):
826870
Xt : ndarray of shape (n_samples, n_transformed_features)
827871
Transformed data.
828872
"""
829-
_raise_for_params(params, self, "transform")
830-
831-
# not branching here since params is only available if
832-
# enable_metadata_routing=True
833-
routed_params = process_routing(self, "transform", **params)
834-
Xt = X
835-
for _, name, transform in self._iter():
836-
Xt = transform.transform(Xt, **routed_params[name].transform)
837-
return Xt
873+
# TODO(1.8): Remove the context manager and use check_is_fitted(self)
874+
with _raise_or_warn_if_not_fitted(self):
875+
_raise_for_params(params, self, "transform")
876+
877+
# not branching here since params is only available if
878+
# enable_metadata_routing=True
879+
routed_params = process_routing(self, "transform", **params)
880+
Xt = X
881+
for _, name, transform in self._iter():
882+
Xt = transform.transform(Xt, **routed_params[name].transform)
883+
return Xt
838884

839885
def _can_inverse_transform(self):
840886
return all(hasattr(t, "inverse_transform") for _, _, t in self._iter())
@@ -869,17 +915,19 @@ def inverse_transform(self, Xt, **params):
869915
Inverse transformed data, that is, data in the original feature
870916
space.
871917
"""
872-
_raise_for_params(params, self, "inverse_transform")
873-
874-
# we don't have to branch here, since params is only non-empty if
875-
# enable_metadata_routing=True.
876-
routed_params = process_routing(self, "inverse_transform", **params)
877-
reverse_iter = reversed(list(self._iter()))
878-
for _, name, transform in reverse_iter:
879-
Xt = transform.inverse_transform(
880-
Xt, **routed_params[name].inverse_transform
881-
)
882-
return Xt
918+
# TODO(1.8): Remove the context manager and use check_is_fitted(self)
919+
with _raise_or_warn_if_not_fitted(self):
920+
_raise_for_params(params, self, "inverse_transform")
921+
922+
# we don't have to branch here, since params is only non-empty if
923+
# enable_metadata_routing=True.
924+
routed_params = process_routing(self, "inverse_transform", **params)
925+
reverse_iter = reversed(list(self._iter()))
926+
for _, name, transform in reverse_iter:
927+
Xt = transform.inverse_transform(
928+
Xt, **routed_params[name].inverse_transform
929+
)
930+
return Xt
883931

884932
@available_if(pipeline._final_estimator_has("score"))
885933
def score(self, X, y=None, sample_weight=None, **params):
@@ -918,24 +966,28 @@ def score(self, X, y=None, sample_weight=None, **params):
918966
score : float
919967
Result of calling `score` on the final estimator.
920968
"""
921-
Xt = X
922-
if not _routing_enabled():
923-
for _, name, transform in self._iter(with_final=False):
924-
Xt = transform.transform(Xt)
925-
score_params = {}
926-
if sample_weight is not None:
927-
score_params["sample_weight"] = sample_weight
928-
return self.steps[-1][1].score(Xt, y, **score_params)
929-
930-
# metadata routing is enabled.
931-
routed_params = process_routing(
932-
self, "score", sample_weight=sample_weight, **params
933-
)
969+
# TODO(1.8): Remove the context manager and use check_is_fitted(self)
970+
with _raise_or_warn_if_not_fitted(self):
971+
Xt = X
972+
if not _routing_enabled():
973+
for _, name, transform in self._iter(with_final=False):
974+
Xt = transform.transform(Xt)
975+
score_params = {}
976+
if sample_weight is not None:
977+
score_params["sample_weight"] = sample_weight
978+
return self.steps[-1][1].score(Xt, y, **score_params)
979+
980+
# metadata routing is enabled.
981+
routed_params = process_routing(
982+
self, "score", sample_weight=sample_weight, **params
983+
)
934984

935-
Xt = X
936-
for _, name, transform in self._iter(with_final=False):
937-
Xt = transform.transform(Xt, **routed_params[name].transform)
938-
return self.steps[-1][1].score(Xt, y, **routed_params[self.steps[-1][0]].score)
985+
Xt = X
986+
for _, name, transform in self._iter(with_final=False):
987+
Xt = transform.transform(Xt, **routed_params[name].transform)
988+
return self.steps[-1][1].score(
989+
Xt, y, **routed_params[self.steps[-1][0]].score
990+
)
939991

940992
# TODO: once scikit-learn >= 1.4, the following function should be simplified by
941993
# calling `super().get_metadata_routing()`

imblearn/tests/test_pipeline.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
R_TOL = 1e-4
5050

5151

52-
class NoFit:
52+
class NoFit(BaseEstimator):
5353
"""Small class to test parameter dispatching."""
5454

5555
def __init__(self, a=None, b=None):
@@ -109,6 +109,9 @@ def predict(self, X):
109109
def score(self, X, y=None):
110110
return np.sum(X)
111111

112+
def __sklearn_is_fitted__(self):
113+
return True
114+
112115

113116
class FitParamT(BaseEstimator):
114117
"""Mock classifier"""
@@ -118,6 +121,7 @@ def __init__(self):
118121

119122
def fit(self, X, y, should_succeed=False):
120123
self.successful = should_succeed
124+
self.fitted_ = True
121125

122126
def predict(self, X):
123127
return self.successful
@@ -146,6 +150,9 @@ def fit(self, X, y):
146150
class DummyEstimatorParams(BaseEstimator):
147151
"""Mock classifier that takes params on predict"""
148152

153+
def __sklearn_is_fitted__(self):
154+
return True
155+
149156
def fit(self, X, y):
150157
return self
151158

0 commit comments

Comments
 (0)