Commit d544dc4

MAINT couple of fixes
1 parent bf260c4 commit d544dc4

File tree

16 files changed

+359
-291
lines changed

doc/cross_validation.rst renamed to doc/model_selection.rst

Lines changed: 40 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4,24 +4,25 @@
44
Cross validation
55
================
66

7-
.. currentmodule:: imblearn.cross_validation
7+
.. currentmodule:: imblearn.model_selection
88

99

10-
.. _instance_hardness_threshold:
10+
.. _instance_hardness_threshold_cv:
1111

12-
The term instance hardness is used in literature to express the difficulty to
13-
correctly classify an instance. An instance for which the predicted probability
14-
of the true class is low, has large instance hardness. The way these
15-
hard-to-classify instances are distributed over train and test sets in cross
16-
validation, has significant effect on the test set performance metrics. The
17-
`InstanceHardnessCV` splitter distributes samples with large instance hardness
18-
equally over the folds, resulting in more robust cross validation.
12+
The term instance hardness is used in the literature to express the difficulty of
13+
correctly classifying an instance. An instance for which the predicted probability of
14+
the true class is low has large instance hardness. The way these hard-to-classify
15+
instances are distributed over train and test sets in cross validation has a significant
16+
effect on the test set performance metrics. The
17+
:class:`~imblearn.model_selection.InstanceHardnessCV` splitter distributes samples with
18+
large instance hardness equally over the folds, making cross validation more robust.
1919

2020
We will discuss instance hardness in this document and explain how to use the
21-
`InstanceHardnessCV` splitter.
21+
:class:`~imblearn.model_selection.InstanceHardnessCV` splitter.
2222

2323
Instance hardness and average precision
2424
=======================================
25+
2526
Instance hardness is defined as 1 minus the probability of the most probable class:
2627

2728
.. math::
@@ -32,7 +33,7 @@ In this equation :math:`H(x)` is the instance hardness for a sample with feature
3233
:math:`x` and :math:`P(\hat{y}|x)` the probability of predicted label :math:`\hat{y}`
3334
given the features. If the model predicts label 0 and gives a `predict_proba` output
3435
of [0.9, 0.1], the probability of the most probable class (0) is 0.9 and the
35-
instance hardness is 1-0.9=0.1.
36+
instance hardness is `1-0.9=0.1`.
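
As a rough illustration of the definition above, the following sketch computes instance hardness from rows of class probabilities. The helper name `instance_hardness` and the probability values are hypothetical and chosen only for illustration; this is not the code used inside `InstanceHardnessCV`.

```python
def instance_hardness(proba_rows):
    """Return 1 - max(p) for each row of class probabilities.

    Follows the definition in the text: one minus the probability
    of the most probable class.
    """
    return [1.0 - max(row) for row in proba_rows]


# Hypothetical `predict_proba`-style output for three samples (made-up numbers).
proba = [[0.9, 0.1], [0.55, 0.45], [0.2, 0.8]]
hardness = instance_hardness(proba)

# The second sample is the hardest: the model is least sure about it.
hardest_index = max(range(len(hardness)), key=lambda i: hardness[i])
```

The first row reproduces the worked example from the text: a prediction of `[0.9, 0.1]` gives a hardness of about 0.1.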
3637

3738
Samples with large instance hardness have significant effect on the area under
3839
precision-recall curve, or average precision. Especially samples with label 0
@@ -42,7 +43,7 @@ where the area is largest; the precision is lowered in the range of low recall
4243
and high thresholds. When doing cross validation, e.g. in case of hyperparameter
4344
tuning or recursive feature elimination, random gathering of these points in
4445
some folds introduces variance in CV results that deteriorates the robustness of the
45-
cross validation task. The `InstanceHardnessCV`
46+
cross validation task. The :class:`~imblearn.model_selection.InstanceHardnessCV`
4647
splitter aims to distribute the samples with large instance hardness over the
4748
folds in order to reduce undesired variance. Note that one should use this
4849
splitter to make model *selection* tasks robust like hyperparameter tuning and
@@ -53,8 +54,8 @@ want to know the variance of performance to be expected in production.
5354
Create imbalanced dataset with samples with large instance hardness
5455
===================================================================
5556

56-
Lets start by creating a dataset to work with. We create a dataset with 5% class
57-
imbalance using scikit-learn’s `make_blobs` function.
57+
Let's start by creating a dataset to work with. We create a dataset with 5% class
58+
imbalance using scikit-learn's :func:`~sklearn.datasets.make_blobs` function.
5859

5960
>>> import numpy as np
6061
>>> from matplotlib import pyplot as plt
@@ -66,8 +67,8 @@ imbalance using scikit-learn’s `make_blobs` function.
6667
>>> plt.scatter(X[:, 0], X[:, 1], c=y)
6768
>>> plt.show()
6869

69-
.. image:: ./auto_examples/cross_validation/images/sphx_glr_plot_instance_hardness_cv_001.png
70-
:target: ./auto_examples/cross_validation/plot_instance_hardness_cv.html
70+
.. image:: ./auto_examples/model_selection/images/sphx_glr_plot_instance_hardness_cv_001.png
71+
:target: ./auto_examples/model_selection/plot_instance_hardness_cv.html
7172
:align: center
7273

7374
Now we add some samples with large instance hardness
@@ -80,40 +81,48 @@ Now we add some samples with large instance hardness
8081
>>> plt.scatter(X[:, 0], X[:, 1], c=y)
8182
>>> plt.show()
8283

83-
.. image:: ./auto_examples/cross_validation/images/sphx_glr_plot_instance_hardness_cv_002.png
84-
:target: ./auto_examples/cross_validation/plot_instance_hardness_cv.html
84+
.. image:: ./auto_examples/model_selection/images/sphx_glr_plot_instance_hardness_cv_002.png
85+
:target: ./auto_examples/model_selection/plot_instance_hardness_cv.html
8586
:align: center
8687

87-
Assess cross validation performance variance using InstanceHardnessCV splitter
88-
==============================================================================
88+
Assess cross validation performance variance using `InstanceHardnessCV` splitter
89+
================================================================================
8990

90-
Then we take a `LogisticRegressionClassifier` and assess the cross validation
91-
performance using a `StratifiedKFold` cv splitter and the `cross_validate`
92-
function.
91+
Then we take a :class:`~sklearn.linear_model.LogisticRegression` and assess the
92+
cross validation performance using a :class:`~sklearn.model_selection.StratifiedKFold`
93+
cv splitter and the :func:`~sklearn.model_selection.cross_validate` function.
9394

9495
>>> from sklearn.linear_model import LogisticRegression
9596
>>> clf = LogisticRegression(random_state=random_state)
9697
>>> skf_cv = StratifiedKFold(n_splits=5, shuffle=True,
9798
... random_state=random_state)
9899
>>> skf_result = cross_validate(clf, X, y, cv=skf_cv, scoring="average_precision")
99100

100-
Now, we do the same using an `InstanceHardnessCV` splitter. We use provide our
101-
classifier to the splitter to calculate instance hardness and distribute samples
102-
with large instance hardness equally over the folds.
101+
Now, we do the same using an :class:`~imblearn.model_selection.InstanceHardnessCV`
102+
splitter. We provide our classifier to the splitter to calculate instance hardness
103+
and distribute samples with large instance hardness equally over the folds.
103104

104105
>>> ih_cv = InstanceHardnessCV(estimator=clf, n_splits=5,
105106
... random_state=random_state)
106107
>>> ih_result = cross_validate(clf, X, y, cv=ih_cv, scoring="average_precision")
107108

108-
When we plot the test scores for both cv splitters, we see that the variance using
109-
the `InstanceHardnessCV` splitter is lower than for the `StratifiedKFold` splitter.
109+
When we plot the test scores for both cv splitters, we see that the variance using the
110+
:class:`~imblearn.model_selection.InstanceHardnessCV` splitter is lower than for the
111+
:class:`~sklearn.model_selection.StratifiedKFold` splitter.
110112

111113
>>> plt.boxplot([skf_result['test_score'], ih_result['test_score']],
112114
... tick_labels=["StratifiedKFold", "InstanceHardnessCV"],
113115
... vert=False)
114116
>>> plt.xlabel('Average precision')
115117
>>> plt.tight_layout()
116118

117-
.. image:: ./auto_examples/cross_validation/images/sphx_glr_plot_instance_hardness_cv_003.png
118-
:target: ./auto_examples/cross_validation/plot_instance_hardness_cv.html
119-
:align: center
119+
.. image:: ./auto_examples/model_selection/images/sphx_glr_plot_instance_hardness_cv_003.png
120+
:target: ./auto_examples/model_selection/plot_instance_hardness_cv.html
121+
:align: center
122+
123+
Be aware that the most important job of a cross-validation splitter is to simulate the
124+
conditions that one will encounter in production. Therefore, if difficult samples are
125+
likely to occur in production, one should use a cross-validation splitter that
126+
emulates this situation. In our case, the
127+
:class:`~sklearn.model_selection.StratifiedKFold` splitter did not let us distribute
128+
the difficult samples over the folds, which was likely a problem for our use case.
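
To make the idea of spreading hard samples over the folds concrete, here is a minimal sketch under stated assumptions: samples are sorted from hardest to easiest and dealt out to the folds in turn. The function name `assign_folds_by_hardness` and the hardness values are hypothetical, the sketch ignores class stratification, and it should not be read as the actual `InstanceHardnessCV` implementation.

```python
def assign_folds_by_hardness(hardness, n_splits):
    """Sort samples from hardest to easiest and deal them out round-robin.

    Returns a list with the fold index assigned to each sample.
    """
    order = sorted(range(len(hardness)), key=lambda i: hardness[i], reverse=True)
    folds = [0] * len(hardness)
    for rank, sample_index in enumerate(order):
        folds[sample_index] = rank % n_splits
    return folds


# Made-up hardness values: four samples are clearly hard (> 0.5).
hardness = [0.05, 0.9, 0.1, 0.8, 0.02, 0.85, 0.07, 0.03, 0.95, 0.04]
folds = assign_folds_by_hardness(hardness, n_splits=2)

# Count how many hard samples ended up in each fold.
hard_per_fold = [
    sum(1 for h, f in zip(hardness, folds) if h > 0.5 and f == k) for k in range(2)
]
```

With a purely random split, all four hard samples could land in the same fold; dealing them out in hardness order gives every fold the same share.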

doc/references/cross_validation.rst

Lines changed: 0 additions & 23 deletions
This file was deleted.

doc/references/index.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,6 @@ This is the full API documentation of the `imbalanced-learn` toolbox.
1818
miscellaneous
1919
pipeline
2020
metrics
21-
cross_validation
21+
model_selection
2222
datasets
2323
utils

doc/references/model_selection.rst

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
.. _model_selection_ref:
2+
3+
Model selection methods
4+
=======================
5+
6+
.. automodule:: imblearn.model_selection
7+
:no-members:
8+
:no-inherited-members:
9+
10+
Cross-validation splitters
11+
--------------------------
12+
13+
.. automodule:: imblearn.model_selection._split
14+
:no-members:
15+
:no-inherited-members:
16+
17+
.. currentmodule:: imblearn.model_selection
18+
19+
.. autosummary::
20+
:toctree: generated/
21+
:template: class.rst
22+
23+
InstanceHardnessCV

doc/user_guide.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ User Guide
1919
ensemble.rst
2020
miscellaneous.rst
2121
metrics.rst
22-
cross_validation.rst
22+
model_selection.rst
2323
common_pitfalls.rst
2424
Dataset loading utilities <datasets/index.rst>
2525
developers_utils.rst

examples/cross_validation/README.txt

Lines changed: 0 additions & 6 deletions
This file was deleted.

examples/cross_validation/plot_instance_hardness_cv.py

Lines changed: 0 additions & 82 deletions
This file was deleted.
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
"""
2+
====================================================
3+
Distribute hard-to-classify datapoints over CV folds
4+
====================================================
5+
6+
'Instance hardness' refers to the difficulty of classifying an instance. The way
7+
hard-to-classify instances are distributed over train and test sets has a
8+
significant effect on the test set performance metrics. In this example we
9+
show how to deal with this problem. We compare against the standard
10+
:class:`~sklearn.model_selection.StratifiedKFold` cross-validation splitter.
11+
"""
12+
13+
# Authors: Frits Hermans, https://fritshermans.github.io
14+
# License: MIT
15+
16+
# %%
17+
print(__doc__)
18+
19+
# %%
20+
# Create an imbalanced dataset with instance hardness
21+
# ---------------------------------------------------
22+
#
23+
# We create an imbalanced dataset using scikit-learn's
24+
# :func:`~sklearn.datasets.make_blobs` function and set the class imbalance ratio to
25+
# 5%.
26+
import numpy as np
27+
from matplotlib import pyplot as plt
28+
from sklearn.datasets import make_blobs
29+
30+
X, y = make_blobs(n_samples=[950, 50], centers=((-3, 0), (3, 0)), random_state=10)
31+
plt.scatter(X[:, 0], X[:, 1], c=y)
32+
33+
# %%
34+
# To introduce instance hardness in our dataset, we add some hard-to-classify samples:
35+
X_hard, y_hard = make_blobs(
36+
n_samples=10, centers=((3, 0), (-3, 0)), cluster_std=1, random_state=10
37+
)
38+
X, y = np.vstack((X, X_hard)), np.hstack((y, y_hard))
39+
plt.scatter(X[:, 0], X[:, 1], c=y)
40+
41+
# %%
42+
# Compare cross validation scores using `StratifiedKFold` and `InstanceHardnessCV`
43+
# --------------------------------------------------------------------------------
44+
#
45+
# Now, we want to assess a linear predictive model. Therefore, we should use
46+
# cross-validation. The most important concept with cross-validation is to create
47+
# training and test splits that are representative of the data in production to have
48+
# statistical results that one can expect in production.
49+
#
50+
# By applying a standard :class:`~sklearn.model_selection.StratifiedKFold`
51+
# cross-validation splitter, we do not control in which fold the hard-to-classify
52+
# samples will be.
53+
#
54+
# The :class:`~imblearn.model_selection.InstanceHardnessCV` splitter allows us to
55+
# control the distribution of the hard-to-classify samples over the folds.
56+
#
57+
# Let's run an experiment to compare the results that we get with both splitters.
58+
# We use a :class:`~sklearn.linear_model.LogisticRegression` classifier and
59+
# :func:`~sklearn.model_selection.cross_validate` to calculate the cross validation
60+
# scores. We use average precision for scoring.
61+
import pandas as pd
62+
from sklearn.linear_model import LogisticRegression
63+
from sklearn.model_selection import StratifiedKFold, cross_validate
64+
65+
from imblearn.model_selection import InstanceHardnessCV
66+
67+
logistic_regression = LogisticRegression()
68+
69+
results = {}
70+
for cv in (
71+
StratifiedKFold(n_splits=5, shuffle=True, random_state=10),
72+
InstanceHardnessCV(estimator=LogisticRegression(), n_splits=5, random_state=10),
73+
):
74+
result = cross_validate(
75+
logistic_regression,
76+
X,
77+
y,
78+
cv=cv,
79+
scoring="average_precision",
80+
)
81+
results[cv.__class__.__name__] = result["test_score"]
82+
results = pd.DataFrame(results)
83+
84+
# %%
85+
ax = results.plot.box(vert=False, whis=[0, 100])
86+
ax.set(
87+
xlabel="Average precision",
88+
title="Cross validation scores with different splitters",
89+
xlim=(0, 1),
90+
)
91+
92+
# %%
93+
# The boxplot shows that the :class:`~imblearn.model_selection.InstanceHardnessCV`
94+
# splitter results in less variation of average precision than the
95+
# :class:`~sklearn.model_selection.StratifiedKFold` splitter. When doing
96+
# hyperparameter tuning or feature selection using a wrapper method (like
97+
# :class:`~sklearn.feature_selection.RFECV`) this will give more stable results.
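
To see why the placement of a few hard samples moves average precision so much, the sketch below computes average precision for two tiny test folds by hand. The helper `average_precision` and both rankings are hypothetical; the helper assumes binary labels already sorted by descending score, no tied scores, and at least one positive label.

```python
def average_precision(labels_ranked):
    """Average precision for binary labels sorted by descending score.

    Equals the mean of the precision values measured at each positive sample.
    """
    hits = 0
    precisions = []
    for rank, label in enumerate(labels_ranked, start=1):
        if label == 1:
            hits += 1
            precisions.append(hits / rank)
    return sum(precisions) / len(precisions)


# Fold without hard samples: both positives are ranked first.
clean_fold = [1, 1, 0, 0, 0]
# Fold with one hard negative that the model scores above every positive.
hard_fold = [0, 1, 1, 0, 0]

ap_clean = average_precision(clean_fold)
ap_hard = average_precision(hard_fold)
```

A single hard negative at the top of the ranking lowers the precision at every positive, so a fold that happens to collect such samples scores noticeably worse than a fold that contains none.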
