Skip to content

Commit db85c84

Browse files
chkoarglemaitre
authored andcommitted
Pipeline checks (#166)
* All intermediate estimators should not implement both sample and transform methods * All intermediate estimators should not be or inherit from the Pipeline class
1 parent 23c0b4c commit db85c84

File tree

2 files changed

+72
-6
lines changed

2 files changed

+72
-6
lines changed

imblearn/pipeline.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,32 @@
2424
__all__ = ['Pipeline']
2525

2626

27+
def _validate_step_methods(step):
28+
29+
if (not (hasattr(step, "fit") or hasattr(step, "fit_transform") or hasattr(
30+
step, "fit_sample")) or
31+
not (hasattr(step, "transform") or hasattr(step, "sample"))):
32+
raise TypeError(
33+
"All intermediate steps of the chain should "
34+
"be estimators that implement fit and transform or sample (but not both)"
35+
" '%s' (type %s) doesn't)" % (step, type(t)))
36+
37+
38+
def _validate_step_behaviour(step):
39+
if (hasattr(step, "fit_sample") and hasattr(step, "fit_transform")) or (
40+
hasattr(step, "sample") and hasattr(step, "transform")):
41+
raise TypeError(
42+
"All intermediate steps of the chain should "
43+
"be estimators that implement fit and transform or sample."
44+
" '%s' implements both)" % (step))
45+
46+
47+
def _validate_step_class(step):
48+
if isinstance(step, pipeline.Pipeline):
49+
raise TypeError(
50+
"All intermediate steps of the chain should not be Pipelines")
51+
52+
2753
class Pipeline(pipeline.Pipeline):
2854

2955
"""Pipeline of transforms and resamples with a final estimator.
@@ -104,12 +130,9 @@ def __init__(self, steps):
104130
estimator = estimators[-1]
105131

106132
for t in transforms:
107-
if (not (hasattr(t, "fit") or hasattr(t, "fit_transform") or
108-
hasattr(t, "fit_sample")) or
109-
not (hasattr(t, "transform") or hasattr(t, "sample"))):
110-
raise TypeError("All intermediate steps of the chain should "
111-
"be transforms and implement fit and transform"
112-
" '%s' (type %s) doesn't)" % (t, type(t)))
133+
_validate_step_methods(t)
134+
_validate_step_behaviour(t)
135+
_validate_step_class(t)
113136

114137
if not hasattr(estimator, "fit"):
115138
raise TypeError("Last step of chain should implement fit "

imblearn/tests/test_pipeline.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,20 @@ def predict(self, X):
7474
return self.successful
7575

7676

77+
78+
class FitTransformSample(T):
79+
"""Mock classifier
80+
"""
81+
82+
def fit(self, X, y, should_succeed=False):
83+
pass
84+
85+
def sample(X, y=None):
86+
return X, y
87+
88+
def transform(X, y=None):
89+
return X
90+
7791
def test_pipeline_init():
7892
# Test the various init parameters of the pipeline.
7993
assert_raises(TypeError, Pipeline)
@@ -431,3 +445,32 @@ def test_pipeline_methods_anova_rus():
431445
pipe.predict_proba(X)
432446
pipe.predict_log_proba(X)
433447
pipe.score(X, y)
448+
449+
450+
451+
452+
def test_pipeline_with_step_that_implements_both_sample_and_transform():
453+
# Test the various methods of the pipeline (anova).
454+
X, y = make_classification(n_classes=2, class_sep=2, weights=[0.1, 0.9],
455+
n_informative=3, n_redundant=1, flip_y=0,
456+
n_features=20, n_clusters_per_class=1,
457+
n_samples=5000, random_state=0)
458+
459+
clf = LogisticRegression()
460+
assert_raises(TypeError, Pipeline, [('step', FitTransformSample()), ('logistic', clf)])
461+
#assert_raises(TypeError, lambda x: [][0])
462+
463+
464+
def test_pipeline_with_step_that_it_is_pipeline():
465+
# Test the various methods of the pipeline (anova).
466+
X, y = make_classification(n_classes=2, class_sep=2, weights=[0.1, 0.9],
467+
n_informative=3, n_redundant=1, flip_y=0,
468+
n_features=20, n_clusters_per_class=1,
469+
n_samples=5000, random_state=0)
470+
# Test with RandomUnderSampling + Anova + LogisticRegression
471+
clf = LogisticRegression()
472+
rus = RandomUnderSampler(random_state=0)
473+
filter1 = SelectKBest(f_classif, k=2)
474+
pipe1 = Pipeline([('rus', rus), ('anova', filter1)])
475+
assert_raises(TypeError, Pipeline, [('pipe1', pipe1), ('logistic', clf)])
476+

0 commit comments

Comments
 (0)