Skip to content

Commit ee949b9

Browse files
chkoar authored and glemaitre committed
[MRG] Add an example using validation curves (#203)
* Add an example using validation curves
* pep8
1 parent b45a3e4 commit ee949b9

File tree

1 file changed

+51
-0
lines changed

1 file changed

+51
-0
lines changed
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
"""
==========================
Plotting Validation Curves
==========================

In this example the impact of the SMOTE's k_neighbors parameter is examined.
In the plot you can see the validation scores of a SMOTE-CART classifier for
different values of the SMOTE's k_neighbors parameter.
"""

import matplotlib.pyplot as plt
import numpy as np
from sklearn import model_selection as ms
from sklearn import datasets, metrics, tree

from imblearn import over_sampling as os
from imblearn import pipeline as pl

print(__doc__)

# Line width of the plotted curve and the seed shared by SMOTE and the tree.
LW = 2
RANDOM_STATE = 42

# Score every cross-validation fold with Cohen's kappa.
scorer = metrics.make_scorer(metrics.cohen_kappa_score)

# Generate the dataset: a binary problem with a 10% / 90% class split.
X, y = datasets.make_classification(n_classes=2, class_sep=2,
                                    weights=[0.1, 0.9], n_informative=10,
                                    n_redundant=1, flip_y=0, n_features=20,
                                    n_clusters_per_class=4, n_samples=5000,
                                    random_state=10)

# SMOTE over-sampling followed by a CART decision tree.
smote = os.SMOTE(random_state=RANDOM_STATE)
cart = tree.DecisionTreeClassifier(random_state=RANDOM_STATE)
pipeline = pl.make_pipeline(smote, cart)

# Evaluate k_neighbors = 1..10 with 3-fold cross-validation. The parameter is
# addressed through the pipeline step name ("smote") created by make_pipeline.
param_range = range(1, 11)
train_scores, test_scores = ms.validation_curve(
    pipeline, X, y, param_name="smote__k_neighbors", param_range=param_range,
    cv=3, scoring=scorer, n_jobs=1)

# Aggregate over the fold axis. Only the validation scores are plotted, so
# the training-score statistics are not computed.
test_scores_mean = np.mean(test_scores, axis=1)
test_scores_std = np.std(test_scores, axis=1)

plt.title("Validation Curve with SMOTE-CART")
plt.xlabel("k_neighbors")
plt.ylabel("Cohen's kappa")
# The curve needs a label, otherwise plt.legend() has nothing to display.
plt.plot(param_range, test_scores_mean, label="Cross-validation score",
         color="navy", lw=LW)
# Shade +/- one standard deviation across the folds around the mean curve.
plt.fill_between(param_range, test_scores_mean - test_scores_std,
                 test_scores_mean + test_scores_std, alpha=0.2,
                 color="navy", lw=LW)
plt.legend(loc="best")
plt.show()

0 commit comments

Comments
 (0)