Skip to content

Commit 25ee4fa

Browse files
chkoarglemaitre
authored andcommitted
[MRG] Benchmark over-sampling methods in a face regognition task (#198)
* Benchmark over-sampling methods using a 3NN classifier * fulfil the review
1 parent 62f6d2f commit 25ee4fa

File tree

1 file changed

+97
-0
lines changed

1 file changed

+97
-0
lines changed
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
"""
2+
=============================================================
3+
Benchmark over-sampling methods in a face regognition task
4+
=============================================================
5+
In this face recognition example two faces are used from the LFW
6+
(Faces in the Wild) dataset. Several implemented over-sampling
7+
methods are used in conjunction with a 3NN classifier in order
8+
to examine the improvement of the classifier's output quality
9+
by using an over-sampler.
10+
"""
11+
print(__doc__)
12+
13+
import matplotlib.pyplot as plt
14+
import numpy as np
15+
from scipy import interp
16+
from sklearn import datasets, neighbors
17+
from sklearn.metrics import auc, roc_curve
18+
from sklearn.model_selection import StratifiedKFold
19+
20+
from imblearn.over_sampling import ADASYN, SMOTE, RandomOverSampler
21+
from imblearn.pipeline import make_pipeline
22+
23+
LW = 2
24+
RANDOM_STATE = 42
25+
26+
27+
class DummySampler(object):
28+
29+
def sample(self, X, y):
30+
return X, y
31+
32+
def fit(self, X, y):
33+
return self
34+
35+
def fit_sample(self, X, y):
36+
return self.sample(X, y)
37+
38+
39+
cv = StratifiedKFold(n_splits=3)
40+
41+
# Load the dataset
42+
data = datasets.fetch_lfw_people()
43+
majority_person = 1871 # 530 photos of George W Bush
44+
minority_person = 531 # 29 photos of Bill Clinton
45+
majority_idxs = np.flatnonzero(data.target == majority_person)
46+
minority_idxs = np.flatnonzero(data.target == minority_person)
47+
idxs = np.hstack((majority_idxs, minority_idxs))
48+
49+
X = data.data[idxs]
50+
y = data.target[idxs]
51+
y[y == majority_person] = 0
52+
y[y == minority_person] = 1
53+
54+
55+
classifier = ['3NN', neighbors.KNeighborsClassifier(3)]
56+
57+
samplers = [
58+
['Standard', DummySampler()],
59+
['ADASYN', ADASYN(random_state=RANDOM_STATE)],
60+
['ROS', RandomOverSampler(random_state=RANDOM_STATE)],
61+
['SMOTE', SMOTE(random_state=RANDOM_STATE)],
62+
]
63+
64+
pipelines = [
65+
['{}-{}'.format(sampler[0], classifier[0]),
66+
make_pipeline(sampler[1], classifier[1])]
67+
for sampler in samplers
68+
]
69+
70+
71+
for name, pipeline in pipelines:
72+
mean_tpr = 0.0
73+
mean_fpr = np.linspace(0, 1, 100)
74+
for train, test in cv.split(X, y):
75+
probas_ = pipeline.fit(X[train], y[train]).predict_proba(X[test])
76+
fpr, tpr, thresholds = roc_curve(y[test], probas_[:, 1])
77+
mean_tpr += interp(mean_fpr, fpr, tpr)
78+
mean_tpr[0] = 0.0
79+
roc_auc = auc(fpr, tpr)
80+
81+
mean_tpr /= cv.get_n_splits(X, y)
82+
mean_tpr[-1] = 1.0
83+
mean_auc = auc(mean_fpr, mean_tpr)
84+
plt.plot(mean_fpr, mean_tpr, linestyle='--',
85+
label='{} (area = %0.2f)'.format(name) % mean_auc, lw=LW)
86+
87+
plt.xlim([-0.05, 1.05])
88+
plt.ylim([-0.05, 1.05])
89+
plt.xlabel('False Positive Rate')
90+
plt.ylabel('True Positive Rate')
91+
plt.title('Receiver operating characteristic example')
92+
plt.legend(loc="lower right")
93+
94+
plt.plot([0, 1], [0, 1], linestyle='--', lw=LW, color='k',
95+
label='Luck')
96+
97+
plt.show()

0 commit comments

Comments
 (0)