Commit 99d4c40

Merge pull request #20 from agramfort/cosmits
cosmit
2 parents 55793f4 + 952b464

File tree

1 file changed: +17 -22 lines changed


examples/plot_credit_default.py

Lines changed: 17 additions & 22 deletions
@@ -17,7 +17,7 @@
 
 ###############################################################################
 # Data import and preparation
-# ..................
+# ...........................
 #
 # There are 3 categorical variables (SEX, EDUCATION and MARRIAGE) and 20
 # numerical variables.
@@ -69,9 +69,7 @@
 data['PAY_AMT_old_std'] = data[old_PAY_AMT].apply(
     lambda x: np.std(x), axis=1)
 
-data = data.drop(old_PAY_AMT, axis=1)
-data = data.drop(old_BILL_AMT, axis=1)
-data = data.drop(old_PAY, axis=1)
+data.drop(old_PAY_AMT + old_BILL_AMT + old_PAY, axis=1, inplace=True)
 
 # Creating the train/test split
 feature_names = list(data.columns)
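For context on the consolidated drop in this hunk: a single drop over the concatenated column lists behaves like the three chained drops, except that inplace=True mutates the frame instead of returning copies. A minimal sketch with hypothetical column names (not the example's real data):

import pandas as pd

# Toy frame standing in for the credit-default data (hypothetical columns).
df = pd.DataFrame({'PAY_1': [0, 1], 'BILL_AMT1': [100, 200],
                   'PAY_AMT1': [50, 60], 'LIMIT_BAL': [1000, 2000]})
old_PAY, old_BILL_AMT, old_PAY_AMT = ['PAY_1'], ['BILL_AMT1'], ['PAY_AMT1']

# Three separate drops, each returning a new frame (the pre-commit style) ...
chained = (df.drop(old_PAY_AMT, axis=1)
             .drop(old_BILL_AMT, axis=1)
             .drop(old_PAY, axis=1))

# ... versus one drop over the concatenated lists, mutating df in place.
df.drop(old_PAY_AMT + old_BILL_AMT + old_PAY, axis=1, inplace=True)

assert chained.equals(df)  # both leave only LIMIT_BAL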
@@ -85,44 +83,42 @@
 X_test = data[n_samples_train:]
 
 ###############################################################################
-# Benchmark with a Random Forest classifier.
-# ..................
+# Benchmark with a Random Forest classifier
+# .........................................
 #
 # This part shows the training and performance evaluation of a random forest
 # model. The objective remains to extract rules which targets credit defaults.
 
-RF = GridSearchCV(
+rf = GridSearchCV(
     RandomForestClassifier(
         random_state=rng,
         n_estimators=30,
         class_weight='balanced'),
-    param_grid={
-        'max_depth': range(3, 8, 1),
-        'max_features': np.linspace(0.1, 1., 5)
-    },
+    param_grid={'max_depth': range(3, 8, 1),
+                'max_features': np.linspace(0.1, 1., 5)},
     scoring={'AUC': 'roc_auc'}, cv=5,
     refit='AUC', n_jobs=-1)
 
-RF.fit(X_train, y_train)
-scoring_RF = RF.predict_proba(X_test)[:, 1]
+rf.fit(X_train, y_train)
+scoring_rf = rf.predict_proba(X_test)[:, 1]
 
-print("Random Forest selected parameters : " + str(RF.best_params_))
+print("Random Forest selected parameters : %s" % rf.best_params_)
 
 # Plot ROC and PR curves
 
 fig, axes = plt.subplots(1, 2, figsize=(12, 5),
                          sharex=True, sharey=True)
 
 ax = axes[0]
-fpr_RF, tpr_RF, _ = roc_curve(y_test, scoring_RF)
+fpr_RF, tpr_RF, _ = roc_curve(y_test, scoring_rf)
 ax.step(fpr_RF, tpr_RF, linestyle='-.', c='g', lw=1, where='post')
 ax.set_title("ROC", fontsize=20)
 ax.legend(loc='upper center', fontsize=8)
 ax.set_xlabel('False Positive Rate', fontsize=18)
 ax.set_ylabel('True Positive Rate (Recall)', fontsize=18)
 
 ax = axes[1]
-precision_RF, recall_RF, _ = precision_recall_curve(y_test, scoring_RF)
+precision_RF, recall_RF, _ = precision_recall_curve(y_test, scoring_rf)
 ax.step(recall_RF, precision_RF, linestyle='-.', c='g', lw=1, where='post')
 ax.set_title("Precision-Recall", fontsize=20)
 ax.set_xlabel('Recall (True Positive Rate)', fontsize=18)
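A side note on the GridSearchCV configuration touched above: when scoring is a dict, refit must name the metric used to select and refit the best estimator, and the fitted search object then delegates predict_proba to that refitted estimator. A minimal sketch on toy data (make_classification stands in for the credit-default split; the parameter values are illustrative only):

from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

# Toy imbalanced stand-in for the credit-default split (not the real data).
X, y = make_classification(n_samples=300, weights=[0.8, 0.2], random_state=0)

search = GridSearchCV(
    RandomForestClassifier(n_estimators=10, class_weight='balanced',
                           random_state=0),
    param_grid={'max_depth': [3, 5]},
    scoring={'AUC': 'roc_auc'},   # dict scoring requires naming which metric
    refit='AUC', cv=3)            # is used to refit the best estimator
search.fit(X, y)

# The fitted search delegates predict_proba to the refitted best estimator,
# so search.predict_proba(X)[:, 1] yields the positive-class scores.
print(search.best_params_, search.predict_proba(X)[:, 1].shape)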
@@ -145,7 +141,7 @@
 
 ###############################################################################
 # Getting rules with skrules
-# ..................
+# ..........................
 #
 # This part shows how SkopeRules can be fitted to detect credit defaults.
 # Performances are compared with the random forest model previously trained.
@@ -155,8 +151,7 @@
 clf = SkopeRules(
     similarity_thres=.9, max_depth=3, max_features=0.5,
     max_samples_features=0.5, random_state=rng, n_estimators=30,
-    feature_names=feature_names, recall_min=0.02, precision_min=0.6
-)
+    feature_names=feature_names, recall_min=0.02, precision_min=0.6)
 clf.fit(X_train, y_train)
 
 # in the separate_rules_score method, a score of k means that rule number k
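The comment above describes a rank-style score. Assuming scores of that kind (each sample scored by the index of a rule that fires for it; the numbers below are purely hypothetical), integer rule indices can be fed directly to roc_curve and precision_recall_curve, which only need a ranking; each distinct score value contributes one threshold, i.e. one point on the curves:

import numpy as np
from sklearn.metrics import precision_recall_curve, roc_curve

# Hypothetical rank-style scores: 0 when no rule fires, otherwise the index
# of the rule that fires for the sample (higher index = better rule here).
y_true = np.array([0, 0, 1, 1, 0, 1])
rule_scores = np.array([0, 1, 3, 2, 0, 3])

# Both metrics accept integer scores; one curve point per distinct value.
fpr, tpr, _ = roc_curve(y_true, rule_scores)
precision, recall, _ = precision_recall_curve(y_true, rule_scores)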
@@ -178,7 +173,7 @@
 
 ax = axes[0]
 fpr, tpr, _ = roc_curve(y_test, scoring)
-fpr_RF, tpr_RF, _ = roc_curve(y_test, scoring_RF)
+fpr_rf, tpr_rf, _ = roc_curve(y_test, scoring_rf)
 ax.scatter(fpr[:-1], tpr[:-1], c='b', s=10)
 ax.step(fpr_RF, tpr_RF, linestyle='-.', c='g', lw=1, where='post')
 ax.set_title("ROC", fontsize=20)
@@ -188,7 +183,7 @@
 
 ax = axes[1]
 precision, recall, _ = precision_recall_curve(y_test, scoring)
-precision_RF, recall_RF, _ = precision_recall_curve(y_test, scoring_RF)
+precision_rf, recall_rf, _ = precision_recall_curve(y_test, scoring_rf)
 ax.scatter(recall[1:-1], precision[1:-1], c='b', s=10)
 ax.step(recall_RF, precision_RF, linestyle='-.', c='g', lw=1, where='post')
 ax.set_title("Precision-Recall", fontsize=20)
@@ -198,7 +193,7 @@
 
 ###############################################################################
 # The ROC and Precision-Recall curves show the performance of the rules
-# generated by SkopeRulesthe (the blue points) and the performance of the
+# generated by SkopeRules the (the blue points) and the performance of the
 # Random Forest classifier fitted above.
 # Each blue point represents the performance of a set of rules: The kth point
 # represents the score associated to the concatenation (union) of the k first
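To make the "union of the k first rules" reading of the blue points concrete, here is a toy computation with a hypothetical rule-activation matrix (not derived from the credit-default data): adding a rule to the union can only extend the set of flagged samples, so recall is non-decreasing in k while precision typically drifts down.

import numpy as np

# Hypothetical boolean activation matrix: activations[i, j] is True when
# rule j fires for sample i, with rules ordered from best to worst.
y_true = np.array([1, 1, 1, 0, 0, 0, 1, 0])
activations = np.array([[1, 0], [1, 0], [0, 1], [0, 1],
                        [0, 0], [0, 0], [0, 1], [0, 0]], dtype=bool)

for k in (1, 2):
    # The union of the k first rules flags a sample if any of them fires.
    y_pred = activations[:, :k].any(axis=1)
    tp = np.sum(y_pred & (y_true == 1))
    precision = tp / max(y_pred.sum(), 1)
    recall = tp / (y_true == 1).sum()
    print("k=%d  precision=%.2f  recall=%.2f" % (k, precision, recall))

With these toy numbers, k=1 gives precision 1.00 and recall 0.50, and k=2 gives precision 0.80 and recall 1.00, which is the trade-off the blue points trace out.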
