|
| 1 | +""" |
| 2 | +========================== |
| 3 | +Plotting Validation Curves |
| 4 | +========================== |
| 5 | +
|
| 6 | +In this example the impact of the SMOTE's k_neighbors parameter is examined. |
| 7 | +In the plot you can see the validation scores of a SMOTE-CART classifier for |
| 8 | +different values of the SMOTE's k_neighbors parameter. |
| 9 | +""" |
| 10 | + |
| 11 | + |
| 12 | +import matplotlib.pyplot as plt |
| 13 | +import numpy as np |
| 14 | +from sklearn import model_selection as ms |
| 15 | +from sklearn import datasets, metrics, tree |
| 16 | + |
| 17 | +from imblearn import over_sampling as os |
| 18 | +from imblearn import pipeline as pl |
| 19 | + |
# Echo the module docstring so the example's description appears in the output.
print(__doc__)

# Line width for the plot and the seed shared by the sampler and classifier.
LW = 2
RANDOM_STATE = 42

# Cohen's kappa wrapped as a scorer object for scikit-learn model selection.
scorer = metrics.make_scorer(metrics.cohen_kappa_score)

# Generate the dataset: a synthetic two-class problem with class weights
# of 0.1 and 0.9.
X, y = datasets.make_classification(
    n_samples=5000,
    n_features=20,
    n_informative=10,
    n_redundant=1,
    n_classes=2,
    n_clusters_per_class=4,
    weights=[0.1, 0.9],
    class_sep=2,
    flip_y=0,
    random_state=10,
)

# SMOTE over-sampling chained with a decision tree (CART) in one pipeline,
# so that the sampler's parameters can be addressed as "smote__<name>".
cart = tree.DecisionTreeClassifier(random_state=RANDOM_STATE)
smote = os.SMOTE(random_state=RANDOM_STATE)
pipeline = pl.make_pipeline(smote, cart)
| 36 | + |
# Candidate values for SMOTE's k_neighbors: the integers 1 through 10.
param_range = range(1, 11)

# Evaluate the pipeline once per candidate value using 3-fold
# cross-validation and the scorer passed in `scoring`.
train_scores, test_scores = ms.validation_curve(
    pipeline,
    X,
    y,
    param_name="smote__k_neighbors",
    param_range=param_range,
    cv=3,
    scoring=scorer,
    n_jobs=1,
)

# Reduce along axis 1 to get one mean and one standard deviation per
# candidate value (assumes rows are parameter values and columns are CV
# folds, as described in the validation_curve documentation).
train_scores_mean = train_scores.mean(axis=1)
train_scores_std = train_scores.std(axis=1)
test_scores_mean = test_scores.mean(axis=1)
test_scores_std = test_scores.std(axis=1)
| 45 | + |
# Plot the mean validation score (Cohen's kappa) against k_neighbors,
# with a shaded band of one standard deviation around the mean.
plt.title("Validation Curve with SMOTE-CART")
plt.xlabel("k_neighbors")
plt.ylabel("Cohen's kappa")
# The curve must carry a label: without one, plt.legend() has nothing to
# display and matplotlib warns that no labelled artists were found.
plt.plot(param_range, test_scores_mean, color="navy", lw=LW,
         label="Cross-validation score")
# Show the spread of the validation scores, which was computed above but
# previously never drawn.
plt.fill_between(param_range,
                 test_scores_mean - test_scores_std,
                 test_scores_mean + test_scores_std,
                 alpha=0.2, color="navy")
plt.legend(loc="best")
plt.show()
0 commit comments