|
76 | 76 | "* Create train, validation, and test sets.\n",
|
77 | 77 | "* Define and train a model using Keras (including setting class weights).\n",
|
78 | 78 | "* Evaluate the model using various metrics (including precision and recall).\n",
|
79 |
| - "* Try common techniques for dealing with imbalanced data like:\n", |
80 |
| - " * Class weighting \n", |
81 |
| - " * Oversampling\n" |
| 79 | + "* Select a threshold for a probabilistic classifier to get a deterministic classifier.\n", |
| 80 | + "* Try and compare with class weighted modelling and oversampling." |
82 | 81 | ]
|
83 | 82 | },
|
84 | 83 | {
|
|
275 | 274 | "id": "8a_Z_kBmr7Oh"
|
276 | 275 | },
|
277 | 276 | "source": [
|
| 277 | + "We check whether the distribution of the classes in the three sets is about the same or not." |
| 278 | + ] |
| 279 | + }, |
| 280 | + { |
| 281 | + "cell_type": "code", |
| 282 | + "execution_count": null, |
| 283 | + "metadata": { |
| 284 | + "id": "96520cffee66" |
| 285 | + }, |
| 286 | + "outputs": [], |
| 287 | + "source": [ |
| 288 | + "print(f'Average class probability in training set: {train_labels.mean():.4f}')\n", |
| 289 | + "print(f'Average class probability in validation set: {val_labels.mean():.4f}')\n", |
| 290 | + "print(f'Average class probability in test set: {test_labels.mean():.4f}')" |
| 291 | + ] |
| 292 | + }, |
| 293 | + { |
| 294 | + "attachments": {}, |
| 295 | + "cell_type": "markdown", |
| 296 | + "metadata": { |
| 297 | + "id": "8a_Z_kBmr7Oh" |
| 298 | + }, |
| 299 | + "source": [ |
| 300 | + "Given the small number of positive labels, this seems about right.\n", |
| 301 | + "\n", |
278 | 302 | "Normalize the input features using the sklearn StandardScaler.\n",
|
279 | 303 | "This will set the mean to 0 and standard deviation to 1.\n",
|
280 | 304 | "\n",
|
|
374 | 398 | "outputs": [],
|
375 | 399 | "source": [
|
376 | 400 | "METRICS = [\n",
|
| 401 | + " keras.metrics.BinaryCrossentropy(name='cross entropy'), # same as model's loss\n", |
| 402 | + " keras.metrics.MeanSquaredError(name='Brier score'),\n", |
377 | 403 | " keras.metrics.TruePositives(name='tp'),\n",
|
378 | 404 | " keras.metrics.FalsePositives(name='fp'),\n",
|
379 | 405 | " keras.metrics.TrueNegatives(name='tn'),\n",
|
|
406 | 432 | ]
|
407 | 433 | },
|
408 | 434 | {
|
| 435 | + "attachments": {}, |
409 | 436 | "cell_type": "markdown",
|
410 | 437 | "metadata": {
|
411 | 438 | "id": "SU0GX6E6mieP"
|
|
414 | 441 | "### Understanding useful metrics\n",
|
415 | 442 | "\n",
|
416 | 443 | "Notice that there are a few metrics defined above that can be computed by the model that will be helpful when evaluating the performance.\n",
|
| 444 | + "These can be divided into three groups.\n", |
| 445 | + "\n", |
| 446 | + "#### Metrics for probability predictions\n", |
| 447 | + "\n", |
| 448 | + "As we train our network with the cross entropy as a loss function, it is fully capable of predicting class probabilities, i.e. it is a probabilistic classifier.\n", |
| 449 | + "Good metrics to assess probabilistic predictions are, in fact, **proper scoring rules**. Their key property is that predicting the true probability is optimal. We give two well-known examples:\n", |
417 | 450 | "\n",
|
| 451 | + "* **cross entropy** also known as log loss\n", |
| 452 | + "* **Mean squared error** also known as the Brier score\n", |
418 | 453 | "\n",
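| + "As a minimal, self-contained sketch (with made-up labels and probabilities, not taken from this dataset), both scores can be computed by hand:\n", |
| + "\n", |
| + "```python\n", |
| + "import numpy as np\n", |
| + "\n", |
| + "y_true = np.array([0, 0, 0, 1])          # true 0/1 labels\n", |
| + "y_prob = np.array([0.1, 0.2, 0.1, 0.7])  # predicted probabilities of class 1\n", |
| + "\n", |
| + "# Cross entropy (log loss): -mean(y*log(p) + (1-y)*log(1-p))\n", |
| + "cross_entropy = -np.mean(y_true * np.log(y_prob) + (1 - y_true) * np.log(1 - y_prob))\n", |
| + "\n", |
| + "# Brier score: mean squared difference between predicted probability and label\n", |
| + "brier = np.mean((y_prob - y_true) ** 2)\n", |
| + "\n", |
| + "print(f'cross entropy: {cross_entropy:.4f}, Brier score: {brier:.4f}')\n", |
| + "```\n", |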
|
| 454 | + "#### Metrics for deterministic 0/1 predictions\n", |
| 455 | + "\n", |
| 456 | + "In the end, one often wants to predict a class label, 0 or 1, *no fraud* or *fraud*.\n", |
| 457 | + "This is called a deterministic classifier.\n", |
| 458 | + "To get a label prediction from our probabilistic classifier, one needs to choose a probability threshold $t$.\n", |
| 459 | + "The default is to predict label 1 (fraud) if the predicted probability is larger than $t=50\\%$ and all the following metrics implicitly use this default. \n", |
419 | 460 | "\n",
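| + "As a small follow-on sketch (reusing the hypothetical `y_prob` from above), turning probabilities into 0/1 labels at a given threshold is a one-liner:\n", |
| + "\n", |
| + "```python\n", |
| + "t = 0.5                            # default threshold\n", |
| + "y_pred = (y_prob > t).astype(int)  # 1 where the predicted probability exceeds t\n", |
| + "```\n", |
| + "\n", |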
|
420 | 461 | "* **False** negatives and **false** positives are samples that were **incorrectly** classified\n",
|
421 | 462 | "* **True** negatives and **true** positives are samples that were **correctly** classified\n",
|
|
425 | 466 | "> $\\frac{\\text{true positives}}{\\text{true positives + false positives}}$\n",
|
426 | 467 | "* **Recall** is the percentage of **actual** positives that were correctly classified\n",
|
427 | 468 | "> $\\frac{\\text{true positives}}{\\text{true positives + false negatives}}$\n",
|
| 469 | + "\n", |
| 470 | + "**Note:** Accuracy is not a helpful metric for this task. You can have 99.8%+ accuracy on this task by predicting False all the time. \n", |
| 471 | + "\n", |
| 472 | + "#### Other metrices\n", |
| 473 | + "\n", |
| 474 | + "The following metrics take into account all possible choices of thresholds $t$.\n", |
| 475 | + "\n", |
428 | 476 | "* **AUC** refers to the Area Under the Curve of a Receiver Operating Characteristic curve (ROC-AUC). This metric is equal to the probability that a classifier will rank a random positive sample higher than a random negative sample.\n",
|
429 | 477 | "* **AUPRC** refers to Area Under the Curve of the Precision-Recall Curve. This metric computes precision-recall pairs for different probability thresholds. \n",
|
430 | 478 | "\n",
|
431 |
| - "Note: Accuracy is not a helpful metric for this task. You can have 99.8%+ accuracy on this task by predicting False all the time. \n", |
432 | 479 | "\n",
|
433 |
| - "Read more:\n", |
434 |
| - "* [True vs. False and Positive vs. Negative](https://developers.google.com/machine-learning/crash-course/classification/true-false-positive-negative)\n", |
435 |
| - "* [Accuracy](https://developers.google.com/machine-learning/crash-course/classification/accuracy)\n", |
| 480 | + "#### Read more:\n", |
| 481 | + "* [Strictly Proper Scoring Rules, Prediction, and Estimation](https://www.stat.washington.edu/people/raftery/Research/PDF/Gneiting2007jasa.pdf)\n", |
| 482 | + "* [True vs. False and Positive vs. Negative](https://developers.google.com/machine-learning/crash-course/classification/true-false-positive-negative)\n", |
| 483 | + "* [Accuracy](https://developers.google.com/machine-learning/crash-course/classification/accuracy)\n", |
436 | 484 | "* [Precision and Recall](https://developers.google.com/machine-learning/crash-course/classification/precision-and-recall)\n",
|
437 | 485 | "* [ROC-AUC](https://developers.google.com/machine-learning/crash-course/classification/roc-and-auc)\n",
|
438 | 486 | "* [Relationship between Precision-Recall and ROC Curves](https://www.biostat.wisc.edu/~page/rocpr.pdf)"
|
|
458 | 506 | "Now create and train your model using the function that was defined earlier. Notice that the model is fit using a larger than default batch size of 2048, this is important to ensure that each batch has a decent chance of containing a few positive samples. If the batch size was too small, they would likely have no fraudulent transactions to learn from.\n",
|
459 | 507 | "\n",
|
460 | 508 | "\n",
|
461 |
| - "Note: this model will not handle the class imbalance well. You will improve it later in this tutorial." |
| 509 | + "Note: Fitting this model will not handle the class imbalance efficiently. You will improve it later in this tutorial." |
462 | 510 | ]
|
463 | 511 | },
|
464 | 512 | {
|
|
527 | 575 | "id": "qk_3Ry6EoYDq"
|
528 | 576 | },
|
529 | 577 | "source": [
|
530 |
| - "These initial guesses are not great. You know the dataset is imbalanced. Set the output layer's bias to reflect that (See: [A Recipe for Training Neural Networks: \"init well\"](http://karpathy.github.io/2019/04/25/recipe/#2-set-up-the-end-to-end-trainingevaluation-skeleton--get-dumb-baselines)). This can help with initial convergence." |
| 578 | + "These initial guesses are not great. You know the dataset is imbalanced. Set the output layer's bias to reflect that, see [A Recipe for Training Neural Networks: \"init well\"](http://karpathy.github.io/2019/04/25/recipe/#2-set-up-the-end-to-end-trainingevaluation-skeleton--get-dumb-baselines). This can help with initial convergence." |
531 | 579 | ]
|
532 | 580 | },
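| + { |
| + "cell_type": "markdown", |
| + "metadata": {}, |
| + "source": [ |
| + "A quick sketch of the reasoning (assuming the output layer is a single sigmoid unit): with near-zero initial weights the model predicts roughly $\\sigma(b_0)$ for every sample, so choosing the bias so that this equals the positive rate $p = \\mathrm{pos}/(\\mathrm{pos}+\\mathrm{neg})$ gives\n", |
| + "\n", |
| + "$$b_0 = \\log\\frac{p}{1-p} = \\log\\frac{\\mathrm{pos}}{\\mathrm{neg}},$$\n", |
| + "\n", |
| + "which is the value used for the output bias below." |
| + ] |
| + }, |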
|
533 | 581 | {
|
|
628 | 676 | "id": "FrDC8hvNr9yw"
|
629 | 677 | },
|
630 | 678 | "source": [
|
631 |
| - "This initial loss is about 50 times less than if would have been with naive initialization.\n", |
| 679 | + "This initial loss is about 50 times less than it would have been with naive initialization.\n", |
632 | 680 | "\n",
|
633 |
| - "This way the model doesn't need to spend the first few epochs just learning that positive examples are unlikely. This also makes it easier to read plots of the loss during training." |
| 681 | + "This way the model doesn't need to spend the first few epochs just learning that positive examples are unlikely. It also makes it easier to read plots of the loss during training." |
634 | 682 | ]
|
635 | 683 | },
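| + { |
| + "cell_type": "markdown", |
| + "metadata": {}, |
| + "source": [ |
| + "A rough back-of-the-envelope check (assuming a positive rate of roughly $p \\approx 0.0017$, the fraction of fraudulent transactions in this dataset): with naive initialization the model predicts about $0.5$ everywhere, giving an initial cross entropy of about $\\ln 2 \\approx 0.69$. With the careful initialization it predicts about $p$ everywhere, so the initial loss is about\n", |
| + "\n", |
| + "$$-p\\ln p - (1-p)\\ln(1-p) \\approx 0.013,$$\n", |
| + "\n", |
| + "roughly 50 times smaller." |
| + ] |
| + }, |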
|
636 | 684 | {
|
|
724 | 772 | " color=colors[n], label='Val ' + label,\n",
|
725 | 773 | " linestyle=\"--\")\n",
|
726 | 774 | " plt.xlabel('Epoch')\n",
|
727 |
| - " plt.ylabel('Loss')" |
| 775 | + " plt.ylabel('Loss')\n", |
| 776 | + " plt.legend()" |
728 | 777 | ]
|
729 | 778 | },
|
730 | 779 | {
|
|
868 | 917 | },
|
869 | 918 | "outputs": [],
|
870 | 919 | "source": [
|
871 |
| - "def plot_cm(labels, predictions, p=0.5):\n", |
872 |
| - " cm = confusion_matrix(labels, predictions > p)\n", |
| 920 | + "def plot_cm(labels, predictions, threshold=0.5):\n", |
| 921 | + " cm = confusion_matrix(labels, predictions > threshold)\n", |
873 | 922 | " plt.figure(figsize=(5,5))\n",
|
874 | 923 | " sns.heatmap(cm, annot=True, fmt=\"d\")\n",
|
875 |
| - " plt.title('Confusion matrix @{:.2f}'.format(p))\n", |
| 924 | + " plt.title('Confusion matrix @{:.2f}'.format(threshold))\n", |
876 | 925 | " plt.ylabel('Actual label')\n",
|
877 | 926 | " plt.xlabel('Predicted label')\n",
|
878 | 927 | "\n",
|
|
915 | 964 | "id": "PyZtSr1v6L4t"
|
916 | 965 | },
|
917 | 966 | "source": [
|
918 |
| - "If the model had predicted everything perfectly, this would be a [diagonal matrix](https://en.wikipedia.org/wiki/Diagonal_matrix) where values off the main diagonal, indicating incorrect predictions, would be zero. In this case the matrix shows that you have relatively few false positives, meaning that there were relatively few legitimate transactions that were incorrectly flagged. However, you would likely want to have even fewer false negatives despite the cost of increasing the number of false positives. This trade off may be preferable because false negatives would allow fraudulent transactions to go through, whereas false positives may cause an email to be sent to a customer to ask them to verify their card activity." |
| 967 | + "If the model had predicted everything perfectly (impossible with true randomness), this would be a [diagonal matrix](https://en.wikipedia.org/wiki/Diagonal_matrix) where values off the main diagonal, indicating incorrect predictions, would be zero. In this case, the matrix shows that you have relatively few false positives, meaning that there were relatively few legitimate transactions that were incorrectly flagged." |
| 968 | + ] |
| 969 | + }, |
| 970 | + { |
| 971 | + "cell_type": "markdown", |
| 972 | + "metadata": { |
| 973 | + "id": "P-QpQsip_F2Q" |
| 974 | + }, |
| 975 | + "source": [ |
| 976 | + "### Changing the threshold\n", |
| 977 | + "\n", |
| 978 | + "The default threshold of $t=50\\%$ corresponds to equal costs of false negatives and false positives.\n", |
| 979 | + "In the case of fraud detection, however, you would likely associate higher costs to false negatives than to false positives.\n", |
| 980 | + "This trade off may be preferable because false negatives would allow fraudulent transactions to go through, whereas false positives may cause an email to be sent to a customer to ask them to verify their card activity.\n", |
| 981 | + "\n", |
| 982 | + "By decreasing the threshold, we attribute higher cost to false negatives, thereby increasing missed transactions at the price of more false positives.\n", |
| 983 | + "We test thresholds at 10% and at 1%." |
| 984 | + ] |
| 985 | + }, |
| 986 | + { |
| 987 | + "cell_type": "code", |
| 988 | + "execution_count": null, |
| 989 | + "metadata": { |
| 990 | + "id": "52bd793e04bb" |
| 991 | + }, |
| 992 | + "outputs": [], |
| 993 | + "source": [ |
| 994 | + "plot_cm(test_labels, test_predictions_baseline, threshold=0.1)\n", |
| 995 | + "plot_cm(test_labels, test_predictions_baseline, threshold=0.01)" |
919 | 996 | ]
|
920 | 997 | },
|
921 | 998 | {
|
| 999 | + "attachments": {}, |
922 | 1000 | "cell_type": "markdown",
|
923 | 1001 | "metadata": {
|
924 | 1002 | "id": "P-QpQsip_F2Q"
|
925 | 1003 | },
|
926 | 1004 | "source": [
|
927 | 1005 | "### Plot the ROC\n",
|
928 | 1006 | "\n",
|
929 |
| - "Now plot the [ROC](https://developers.google.com/machine-learning/glossary#ROC). This plot is useful because it shows, at a glance, the range of performance the model can reach just by tuning the output threshold." |
| 1007 | + "Now plot the [ROC](https://developers.google.com/machine-learning/glossary#ROC). This plot is useful because it shows, at a glance, the range of performance the model can reach by tuning the output threshold over its full range (0 to 1). So each point corresponds to a single value of the threshold." |
930 | 1008 | ]
|
931 | 1009 | },
|
932 | 1010 | {
|
|
969 | 1047 | "id": "Y5twGRLfNwmO"
|
970 | 1048 | },
|
971 | 1049 | "source": [
|
972 |
| - "### Plot the AUPRC\r\n", |
| 1050 | + "### Plot the PRC\n", |
973 | 1051 | "\n",
|
974 |
| - "Now plot the [AUPRC](https://developers.google.com/machine-learning/glossary?hl=en#PR_AUC). Area under the interpolated precision-recall curve, obtained by plotting (recall, precision) points for different values of the classification threshold. Depending on how it's calculated, PR AUC may be equivalent to the average precision of the model.\r\n" |
| 1052 | + "Now plot the [AUPRC](https://developers.google.com/machine-learning/glossary?hl=en#PR_AUC). Area under the interpolated precision-recall curve, obtained by plotting (recall, precision) points for different values of the classification threshold. Depending on how it's calculated, PR AUC may be equivalent to the average precision of the model.\n" |
975 | 1053 | ]
|
976 | 1054 | },
|
977 | 1055 | {
|
|
982 | 1060 | },
|
983 | 1061 | "outputs": [],
|
984 | 1062 | "source": [
|
985 |
| - "def plot_prc(name, labels, predictions, **kwargs):\r\n", |
986 |
| - " precision, recall, _ = sklearn.metrics.precision_recall_curve(labels, predictions)\r\n", |
987 |
| - "\r\n", |
988 |
| - " plt.plot(precision, recall, label=name, linewidth=2, **kwargs)\r\n", |
989 |
| - " plt.xlabel('Precision')\r\n", |
990 |
| - " plt.ylabel('Recall')\r\n", |
991 |
| - " plt.grid(True)\r\n", |
992 |
| - " ax = plt.gca()\r\n", |
| 1063 | + "def plot_prc(name, labels, predictions, **kwargs):\n", |
| 1064 | + " precision, recall, _ = sklearn.metrics.precision_recall_curve(labels, predictions)\n", |
| 1065 | + "\n", |
| 1066 | + " plt.plot(precision, recall, label=name, linewidth=2, **kwargs)\n", |
| 1067 | + " plt.xlabel('Precision')\n", |
| 1068 | + " plt.ylabel('Recall')\n", |
| 1069 | + " plt.grid(True)\n", |
| 1070 | + " ax = plt.gca()\n", |
993 | 1071 | " ax.set_aspect('equal')"
|
994 | 1072 | ]
|
995 | 1073 | },
|
|
1001 | 1079 | },
|
1002 | 1080 | "outputs": [],
|
1003 | 1081 | "source": [
|
1004 |
| - "plot_prc(\"Train Baseline\", train_labels, train_predictions_baseline, color=colors[0])\r\n", |
1005 |
| - "plot_prc(\"Test Baseline\", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')\r\n", |
| 1082 | + "plot_prc(\"Train Baseline\", train_labels, train_predictions_baseline, color=colors[0])\n", |
| 1083 | + "plot_prc(\"Test Baseline\", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')\n", |
1006 | 1084 | "plt.legend(loc='lower right');"
|
1007 | 1085 | ]
|
1008 | 1086 | },
|
|
1032 | 1110 | "source": [
|
1033 | 1111 | "### Calculate class weights\n",
|
1034 | 1112 | "\n",
|
1035 |
| - "The goal is to identify fraudulent transactions, but you don't have very many of those positive samples to work with, so you would want to have the classifier heavily weight the few examples that are available. You can do this by passing Keras weights for each class through a parameter. These will cause the model to \"pay more attention\" to examples from an under-represented class." |
| 1113 | + "The goal is to identify fraudulent transactions, but you don't have very many of those positive samples to work with, so you would want to have the classifier heavily weight the few examples that are available. You can do this by passing Keras weights for each class through a parameter. These will cause the model to \"pay more attention\" to examples from an under-represented class. Note, however, that this does not increase in any way the amount of information of your dataset. In the end, using class weights is more or less equivalent to changing the output bias or to changing the threshold. Let's see how it works out." |
1036 | 1114 | ]
|
1037 | 1115 | },
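| + { |
| + "cell_type": "markdown", |
| + "metadata": {}, |
| + "source": [ |
| + "A sketch of the call shape only (the weight values below are hypothetical, the array and model names mirror those used elsewhere in this notebook, and the actual weights are computed in the following cells):\n", |
| + "\n", |
| + "```python\n", |
| + "# Hypothetical weights: up-weight the rare positive class so it contributes comparably to the loss.\n", |
| + "class_weight = {0: 0.5, 1: 290.0}\n", |
| + "\n", |
| + "weighted_model.fit(\n", |
| + "    train_features, train_labels,\n", |
| + "    batch_size=2048,\n", |
| + "    class_weight=class_weight,  # Keras scales each sample's loss term by its class weight\n", |
| + "    validation_data=(val_features, val_labels))\n", |
| + "```" |
| + ] |
| + }, |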
|
1038 | 1116 | {
|
|
1153 | 1231 | "id": "PTh1rtDn8r4-"
|
1154 | 1232 | },
|
1155 | 1233 | "source": [
|
1156 |
| - "Here you can see that with class weights the accuracy and precision are lower because there are more false positives, but conversely the recall and AUC are higher because the model also found more true positives. Despite having lower accuracy, this model has higher recall (and identifies more fraudulent transactions). Of course, there is a cost to both types of error (you wouldn't want to bug users by flagging too many legitimate transactions as fraudulent, either). Carefully consider the trade-offs between these different types of errors for your application." |
| 1234 | + "Here you can see that with class weights the accuracy and precision are lower because there are more false positives, but conversely the recall and AUC are higher because the model also found more true positives. Despite having lower accuracy, this model has higher recall (and identifies more fraudulent transactions than the baseline model at threshold 50%). Of course, there is a cost to both types of error (you wouldn't want to bug users by flagging too many legitimate transactions as fraudulent, either). Carefully consider the trade-offs between these different types of errors for your application.\n", |
| 1235 | + "\n", |
| 1236 | + "Compared to the baseline model with changed threshold, the class weighted model is clearly inferior. The superiority of the baseline model is further confirmed by the lower test loss value (cross entropy and mean squared error) and additionally can be seen by plotting the ROC curves of both models together." |
1157 | 1237 | ]
|
1158 | 1238 | },
|
1159 | 1239 | {
|
|
1189 | 1269 | "id": "_0krS8g1OTbD"
|
1190 | 1270 | },
|
1191 | 1271 | "source": [
|
1192 |
| - "### Plot the AUPRC" |
| 1272 | + "### Plot the PRC" |
1193 | 1273 | ]
|
1194 | 1274 | },
|
1195 | 1275 | {
|
|
1200 | 1280 | },
|
1201 | 1281 | "outputs": [],
|
1202 | 1282 | "source": [
|
1203 |
| - "plot_prc(\"Train Baseline\", train_labels, train_predictions_baseline, color=colors[0])\r\n", |
1204 |
| - "plot_prc(\"Test Baseline\", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')\r\n", |
1205 |
| - "\r\n", |
1206 |
| - "plot_prc(\"Train Weighted\", train_labels, train_predictions_weighted, color=colors[1])\r\n", |
1207 |
| - "plot_prc(\"Test Weighted\", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')\r\n", |
1208 |
| - "\r\n", |
1209 |
| - "\r\n", |
| 1283 | + "plot_prc(\"Train Baseline\", train_labels, train_predictions_baseline, color=colors[0])\n", |
| 1284 | + "plot_prc(\"Test Baseline\", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')\n", |
| 1285 | + "\n", |
| 1286 | + "plot_prc(\"Train Weighted\", train_labels, train_predictions_weighted, color=colors[1])\n", |
| 1287 | + "plot_prc(\"Test Weighted\", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')\n", |
| 1288 | + "\n", |
| 1289 | + "\n", |
1210 | 1290 | "plt.legend(loc='lower right');"
|
1211 | 1291 | ]
|
1212 | 1292 | },
|
|
1581 | 1661 | "for name, value in zip(resampled_model.metrics_names, resampled_results):\n",
|
1582 | 1662 | " print(name, ': ', value)\n",
|
1583 | 1663 | "print()\n",
|
1584 |
| - "\n", |
1585 | 1664 | "plot_cm(test_labels, test_predictions_resampled)"
|
1586 | 1665 | ]
|
1587 | 1666 | },
|
|
1604 | 1683 | "source": [
|
1605 | 1684 | "plot_roc(\"Train Baseline\", train_labels, train_predictions_baseline, color=colors[0])\n",
|
1606 | 1685 | "plot_roc(\"Test Baseline\", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')\n",
|
1607 |
| - "\n", |
1608 | 1686 | "plot_roc(\"Train Weighted\", train_labels, train_predictions_weighted, color=colors[1])\n",
|
1609 | 1687 | "plot_roc(\"Test Weighted\", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')\n",
|
1610 |
| - "\n", |
1611 | 1688 | "plot_roc(\"Train Resampled\", train_labels, train_predictions_resampled, color=colors[2])\n",
|
1612 | 1689 | "plot_roc(\"Test Resampled\", test_labels, test_predictions_resampled, color=colors[2], linestyle='--')\n",
|
1613 | 1690 | "plt.legend(loc='lower right');"
|
|