Commit 1086fb6

DOC add common pitfalls regarding data leakage in sampling (#776)

1 parent 9f3872d commit 1086fb6

File tree

2 files changed: +179 −0 lines changed
doc/common_pitfalls.rst

Lines changed: 178 additions & 0 deletions
@@ -0,0 +1,178 @@
.. _common_pitfalls:

=========================================
Common pitfalls and recommended practices
=========================================

This section is a complement to the documentation given
`here <https://scikit-learn.org/dev/common_pitfalls.html>`_ in scikit-learn.
We will highlight the issue of misusing resampling, which leads to
**data leakage**. Due to this leakage, the reported performance of a model
will be over-optimistic.

Data leakage
============

As mentioned in the scikit-learn documentation, data leakage occurs when
information that would not be available at prediction time is used when
building the model.

In the resampling setting, a common pitfall consists in resampling the
**entire** dataset before splitting it into train and test partitions. Note
that resampling the train and test partitions separately after splitting
would be equivalent.

Such processing leads to two issues:

* the model will not be tested on a dataset with a class distribution similar
  to the real use case. Indeed, by resampling the entire dataset, both the
  training and testing sets will potentially be balanced, while the model
  should be tested on the naturally imbalanced dataset to evaluate the
  potential bias of the model;
* the resampling procedure might use information about samples in the dataset
  to either generate or select some of the samples. Therefore, we might use
  information from samples which will later be used as testing samples, which
  is the typical data leakage issue.
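The first issue can be made concrete with a small sketch on synthetic labels
(the array sizes and variable names below are illustrative, not taken from
the census example used later): undersampling the entire dataset before
splitting yields a near-balanced test partition, while splitting first
preserves the natural class distribution.

```python
import numpy as np
from sklearn.model_selection import train_test_split

rng = np.random.RandomState(0)
# Toy imbalanced problem: 990 majority samples, 10 minority samples.
y = np.array([0] * 990 + [1] * 10)
X = rng.randn(y.size, 2)

# Wrong: undersample the majority class on the *entire* dataset, then split.
majority_kept = rng.choice(np.flatnonzero(y == 0), size=10, replace=False)
keep = np.concatenate([majority_kept, np.flatnonzero(y == 1)])
_, _, _, y_test_wrong = train_test_split(
    X[keep], y[keep], stratify=y[keep], random_state=0
)

# Right: split first; any resampling would then touch the train split only.
_, _, _, y_test_right = train_test_split(X, y, stratify=y, random_state=0)

# The "wrong" test partition is close to balanced; the "right" one keeps the
# natural ~1% minority rate that the model will face at prediction time.
print(y_test_wrong.mean(), y_test_right.mean())
```

Only the second test partition reflects the distribution on which the model
will actually be deployed.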

We will demonstrate the wrong and right ways to do some sampling and
emphasize the tools that one should use to avoid falling into the trap.

We will use the adult census dataset. For the sake of simplicity, we will
only use the numerical features. Also, we will make the dataset more
imbalanced to increase the effect of the wrongdoings::

    >>> from sklearn.datasets import fetch_openml
    >>> from imblearn.datasets import make_imbalance
    >>> X, y = fetch_openml(
    ...     data_id=1119, as_frame=True, return_X_y=True
    ... )
    >>> X = X.select_dtypes(include="number")
    >>> X, y = make_imbalance(
    ...     X, y, sampling_strategy={">50K": 300}, random_state=1
    ... )

Let's first check the balancing ratio on this dataset::

    >>> y.value_counts(normalize=True)
    <=50K    0.98801
    >50K     0.01199
    Name: class, dtype: float64

To later highlight some of the issues, we will keep aside a left-out set
that we will not use for the evaluation of the model::

    >>> from sklearn.model_selection import train_test_split
    >>> X, X_left_out, y, y_left_out = train_test_split(
    ...     X, y, stratify=y, random_state=0
    ... )

We will use a :class:`sklearn.ensemble.HistGradientBoostingClassifier` as a
baseline classifier. First, we will train and check the performance of this
classifier, without any preprocessing to alleviate the bias toward the
majority class. We evaluate the generalization performance of the classifier
via cross-validation::

    >>> from sklearn.experimental import enable_hist_gradient_boosting
    >>> from sklearn.ensemble import HistGradientBoostingClassifier
    >>> from sklearn.model_selection import cross_validate
    >>> model = HistGradientBoostingClassifier(random_state=0)
    >>> cv_results = cross_validate(
    ...     model, X, y, scoring="balanced_accuracy",
    ...     return_train_score=True, return_estimator=True,
    ...     n_jobs=-1
    ... )
    >>> print(
    ...     f"Balanced accuracy mean +/- std. dev.: "
    ...     f"{cv_results['test_score'].mean():.3f} +/- "
    ...     f"{cv_results['test_score'].std():.3f}"
    ... )
    Balanced accuracy mean +/- std. dev.: 0.609 +/- 0.024

We see that the classifier does not give good performance in terms of
balanced accuracy, mainly due to the class imbalance issue.

In the cross-validation, we stored the different classifiers of all folds.
We will show that evaluating these classifiers on the left-out data gives
close statistical performance::

    >>> import numpy as np
    >>> from sklearn.metrics import balanced_accuracy_score
    >>> scores = []
    >>> for fold_id, cv_model in enumerate(cv_results["estimator"]):
    ...     scores.append(
    ...         balanced_accuracy_score(
    ...             y_left_out, cv_model.predict(X_left_out)
    ...         )
    ...     )
    >>> print(
    ...     f"Balanced accuracy mean +/- std. dev.: "
    ...     f"{np.mean(scores):.3f} +/- {np.std(scores):.3f}"
    ... )
    Balanced accuracy mean +/- std. dev.: 0.628 +/- 0.009

Let's now show the **wrong** pattern to apply when it comes to resampling
to alleviate the class imbalance issue. We will use a sampler to balance
the **entire** dataset and check the statistical performance of our
classifier via cross-validation::

    >>> from imblearn.under_sampling import RandomUnderSampler
    >>> sampler = RandomUnderSampler(random_state=0)
    >>> X_resampled, y_resampled = sampler.fit_resample(X, y)
    >>> model = HistGradientBoostingClassifier(random_state=0)
    >>> cv_results = cross_validate(
    ...     model, X_resampled, y_resampled, scoring="balanced_accuracy",
    ...     return_train_score=True, return_estimator=True,
    ...     n_jobs=-1
    ... )
    >>> print(
    ...     f"Balanced accuracy mean +/- std. dev.: "
    ...     f"{cv_results['test_score'].mean():.3f} +/- "
    ...     f"{cv_results['test_score'].std():.3f}"
    ... )
    Balanced accuracy mean +/- std. dev.: 0.724 +/- 0.042
133+
134+
We see that the statistical performance are worse than in the previous case.
135+
Indeed, the data leakage gave us too optimistic results due to the reason
136+
stated earlier in this section.

We will now illustrate the correct pattern to use. Indeed, as in
scikit-learn, using a :class:`~imblearn.pipeline.Pipeline` avoids any data
leakage because the resampling is delegated to imbalanced-learn and does
not require any manual steps::

    >>> from imblearn.pipeline import make_pipeline
    >>> model = make_pipeline(
    ...     RandomUnderSampler(random_state=0),
    ...     HistGradientBoostingClassifier(random_state=0)
    ... )
    >>> cv_results = cross_validate(
    ...     model, X, y, scoring="balanced_accuracy",
    ...     return_train_score=True, return_estimator=True,
    ...     n_jobs=-1
    ... )
    >>> print(
    ...     f"Balanced accuracy mean +/- std. dev.: "
    ...     f"{cv_results['test_score'].mean():.3f} +/- "
    ...     f"{cv_results['test_score'].std():.3f}"
    ... )
    Balanced accuracy mean +/- std. dev.: 0.732 +/- 0.019

We observe that we get good statistical performance as well. However, now
we can check the performance of the model from each cross-validation fold
on the left-out data to ensure that we have similar performance::

    >>> scores = []
    >>> for fold_id, cv_model in enumerate(cv_results["estimator"]):
    ...     scores.append(
    ...         balanced_accuracy_score(
    ...             y_left_out, cv_model.predict(X_left_out)
    ...         )
    ...     )
    >>> print(
    ...     f"Balanced accuracy mean +/- std. dev.: "
    ...     f"{np.mean(scores):.3f} +/- {np.std(scores):.3f}"
    ... )
    Balanced accuracy mean +/- std. dev.: 0.727 +/- 0.008

We see that the statistical performance on the left-out data is very close
to that of the cross-validation study, without any sign of over-optimistic
results.

doc/user_guide.rst

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ User Guide
      ensemble.rst
      miscellaneous.rst
      metrics.rst
 +    common_pitfalls.rst
      Dataset loading utilities <datasets/index.rst>
      developers_utils.rst
      zzz_references.rst
