258 | 258 | "train_df, val_df = train_test_split(train_df, test_size=0.2)\n",
259 | 259 | "\n",
260 | 260 | "# Form np arrays of labels and features.\n",
261 |     | - "train_labels = np.array(train_df.pop('Class'))\n",
262 |     | - "bool_train_labels = train_labels != 0\n",
263 |     | - "val_labels = np.array(val_df.pop('Class'))\n",
264 |     | - "test_labels = np.array(test_df.pop('Class'))\n",
    | 261 | + "train_labels = np.array(train_df.pop('Class')).reshape(-1, 1)\n",
    | 262 | + "bool_train_labels = train_labels[:, 0] != 0\n",
    | 263 | + "val_labels = np.array(val_df.pop('Class')).reshape(-1, 1)\n",
    | 264 | + "test_labels = np.array(test_df.pop('Class')).reshape(-1, 1)\n",
265 | 265 | "\n",
266 | 266 | "train_features = np.array(train_df)\n",
267 | 267 | "val_features = np.array(val_df)\n",
291 | 291 | ]
292 | 292 | },
293 | 293 | {
294 |     | - "attachments": {},
295 | 294 | "cell_type": "markdown",
296 | 295 | "metadata": {
297 |     | - "id": "8a_Z_kBmr7Oh"
    | 296 | + "id": "ueKV4cmcoRnf"
298 | 297 | },
299 | 298 | "source": [
300 | 299 | "Given the small number of positive labels, this seems about right.\n",
301 | 300 | "\n",
302 | 301 | "Normalize the input features using the sklearn StandardScaler.\n",
303 | 302 | "This will set the mean to 0 and standard deviation to 1.\n",
304 | 303 | "\n",
305 |     | - "Note: The `StandardScaler` is only fit using the `train_features` to be sure the model is not peeking at the validation or test sets. "
    | 304 | + "Note: The `StandardScaler` is only fit using the `train_features` to be sure the model is not peeking at the validation or test sets."
306 | 305 | ]
307 | 306 | },
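As a quick illustration of the leakage-safe pattern described in that note, a minimal sketch, assuming the `train_features`/`val_features`/`test_features` arrays formed above:

```python
from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()
# Fit on the training split only, then reuse those statistics everywhere,
# so nothing about the validation/test sets leaks into preprocessing.
train_features = scaler.fit_transform(train_features)
val_features = scaler.transform(val_features)
test_features = scaler.transform(test_features)
```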
308 | 307 | {
352 | 351 | "\n",
353 | 352 | "Next compare the distributions of the positive and negative examples over a few features. Good questions to ask yourself at this point are:\n",
354 | 353 | "\n",
355 |     | - "* Do these distributions make sense? \n",
    | 354 | + "* Do these distributions make sense?\n",
356 | 355 | "  * Yes. You've normalized the input and these are mostly concentrated in the `+/- 2` range.\n",
357 | 356 | "* Can you see the difference between the distributions?\n",
358 | 357 | "  * Yes, the positive examples contain a much higher rate of extreme values."
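To make that comparison concrete, here is one way to eyeball the distributions (a sketch: `seaborn` imported as `sns` and the feature pair `V5`/`V6` are illustrative assumptions):

```python
import pandas as pd
import seaborn as sns

# Split the scaled training data by label and compare two features.
pos_df = pd.DataFrame(train_features[bool_train_labels], columns=train_df.columns)
neg_df = pd.DataFrame(train_features[~bool_train_labels], columns=train_df.columns)

sns.jointplot(x=pos_df['V5'], y=pos_df['V6'], kind='hex')
sns.jointplot(x=neg_df['V5'], y=neg_df['V6'], kind='hex')
```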
386 | 385 | "source": [
387 | 386 | "## Define the model and metrics\n",
388 | 387 | "\n",
389 |     | - "Define a function that creates a simple neural network with a densly connected hidden layer, a [dropout](https://developers.google.com/machine-learning/glossary/#dropout_regularization) layer to reduce overfitting, and an output sigmoid layer that returns the probability of a transaction being fraudulent: "
    | 388 | + "Define a function that creates a simple neural network with a densely connected hidden layer, a [dropout](https://developers.google.com/machine-learning/glossary/#dropout_regularization) layer to reduce overfitting, and an output sigmoid layer that returns the probability of a transaction being fraudulent:"
390 | 389 | ]
391 | 390 | },
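For context, a sketch of what such a function typically looks like in this tutorial family (the `make_model` name, layer width, and dropout rate are assumptions, not fixed by this diff; `METRICS` is the list defined in the next cell):

```python
def make_model(metrics=METRICS, output_bias=None):
  # One densely connected hidden layer, dropout, and a sigmoid output.
  if output_bias is not None:
    output_bias = tf.keras.initializers.Constant(output_bias)
  model = keras.Sequential([
      keras.layers.Dense(16, activation='relu',
                         input_shape=(train_features.shape[-1],)),
      keras.layers.Dropout(0.5),
      keras.layers.Dense(1, activation='sigmoid',
                         bias_initializer=output_bias),
  ])
  model.compile(
      optimizer=keras.optimizers.Adam(learning_rate=1e-3),
      loss=keras.losses.BinaryCrossentropy(),
      metrics=metrics)
  return model
```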
392 | 391 | {
403 | 402 | "      keras.metrics.TruePositives(name='tp'),\n",
404 | 403 | "      keras.metrics.FalsePositives(name='fp'),\n",
405 | 404 | "      keras.metrics.TrueNegatives(name='tn'),\n",
406 |     | - "      keras.metrics.FalseNegatives(name='fn'), \n",
    | 405 | + "      keras.metrics.FalseNegatives(name='fn'),\n",
407 | 406 | "      keras.metrics.BinaryAccuracy(name='accuracy'),\n",
408 | 407 | "      keras.metrics.Precision(name='precision'),\n",
409 | 408 | "      keras.metrics.Recall(name='recall'),\n",
432 | 431 | ]
433 | 432 | },
434 | 433 | {
435 |     | - "attachments": {},
436 | 434 | "cell_type": "markdown",
437 | 435 | "metadata": {
438 | 436 | "id": "SU0GX6E6mieP"
456 | 454 | "In the end, one often wants to predict a class label, 0 or 1, *no fraud* or *fraud*.\n",
457 | 455 | "This is called a deterministic classifier.\n",
458 | 456 | "To get a label prediction from our probabilistic classifier, one needs to choose a probability threshold $t$.\n",
459 |     | - "The default is to predict label 1 (fraud) if the predicted probability is larger than $t=50\%$ and all the following metrics implicitly use this default. \n",
    | 457 | + "The default is to predict label 1 (fraud) if the predicted probability is larger than $t=50\%$, and all the following metrics implicitly use this default (see the sketch after this list).\n",
460 | 458 | "\n",
461 | 459 | "* **False** negatives and **false** positives are samples that were **incorrectly** classified\n",
462 | 460 | "* **True** negatives and **true** positives are samples that were **correctly** classified\n",
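A minimal sketch of that default thresholding, with variable names assumed from the surrounding notebook:

```python
t = 0.5  # the default threshold discussed above
probabilities = model.predict(test_features)        # shape (n, 1), values in [0, 1]
predicted_labels = (probabilities > t).astype(int)  # 1 = fraud, 0 = no fraud
```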
474 | 472 | "The following metrics take into account all possible choices of thresholds $t$.\n",
475 | 473 | "\n",
476 | 474 | "* **AUC** refers to the Area Under the Curve of a Receiver Operating Characteristic curve (ROC-AUC). This metric is equal to the probability that a classifier will rank a random positive sample higher than a random negative sample.\n",
477 |     | - "* **AUPRC** refers to Area Under the Curve of the Precision-Recall Curve. This metric computes precision-recall pairs for different probability thresholds. \n",
    | 475 | + "* **AUPRC** refers to the Area Under the Curve of the Precision-Recall Curve. This metric computes precision-recall pairs for different probability thresholds.\n",
478 | 476 | "\n",
479 | 477 | "\n",
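Both areas are easy to compute outside Keras as a cross-check; a sketch using scikit-learn, assuming the label arrays and baseline predictions from other cells in this notebook:

```python
from sklearn.metrics import average_precision_score, roc_auc_score

# ROC-AUC: chance a random positive is ranked above a random negative.
print('ROC-AUC:', roc_auc_score(test_labels, test_predictions_baseline))
# Average precision summarizes the precision-recall curve (an AUPRC estimate).
print('AUPRC  :', average_precision_score(test_labels, test_predictions_baseline))
```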
480 | 478 | "#### Read more:\n",
520 | 518 | "EPOCHS = 100\n",
521 | 519 | "BATCH_SIZE = 2048\n",
522 | 520 | "\n",
523 |     | - "early_stopping = tf.keras.callbacks.EarlyStopping(\n",
524 |     | - "    monitor='val_prc', \n",
    | 521 | + "def early_stopping():\n",
    | 522 | + "  return tf.keras.callbacks.EarlyStopping(\n",
    | 523 | + "    monitor='val_prc',\n",
525 | 524 | "    verbose=1,\n",
526 | 525 | "    patience=10,\n",
527 | 526 | "    mode='max',\n",
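A plausible motivation for turning `early_stopping` into a factory: Keras callbacks are stateful objects (the patience counter, best-metric bookkeeping), so sharing one `EarlyStopping` instance across the several `fit()` runs in this notebook could carry state between runs. A fresh instance per run avoids that:

```python
# Each call returns an independent callback, so one training run's patience
# counter can never bleed into another's.
stop_a = early_stopping()
stop_b = early_stopping()
assert stop_a is not stop_b
```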
584 | 583 | "id": "PdbfWDuVpo6k"
585 | 584 | },
586 | 585 | "source": [
587 |     | - "With the default bias initialization the loss should be about `math.log(2) = 0.69314` "
    | 586 | + "With the default bias initialization the loss should be about `math.log(2) = 0.69314`, since the untrained sigmoid output starts near 0.5 and the binary cross-entropy of a 0.5 prediction is `-log(0.5) = log(2)`."
588 | 587 | ]
589 | 588 | },
590 | 589 | {
630 | 629 | "id": "d1juXI9yY1KD"
631 | 630 | },
632 | 631 | "source": [
633 |     | - "Set that as the initial bias, and the model will give much more reasonable initial guesses. \n",
    | 632 | + "Set that as the initial bias, and the model will give much more reasonable initial guesses.\n",
634 | 633 | "\n",
635 | 634 | "It should be near: `pos/total = 0.0018`"
636 | 635 | ]
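For reference, the usual derivation behind that initial bias (a sketch; `pos`, `neg`, and the `make_model` signature are assumed from elided cells): choosing `b0 = log(pos/neg)` makes the untrained model's sigmoid output equal `pos/(pos + neg)`, the base rate.

```python
initial_bias = np.log([pos / neg])            # sigmoid(b0) == pos / (pos + neg)
model = make_model(output_bias=initial_bias)  # hypothetical signature, as sketched earlier
model.predict(train_features[:10])            # outputs should sit near pos/total
```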
700 | 699 | },
701 | 700 | "outputs": [],
702 | 701 | "source": [
703 |     | - "initial_weights = os.path.join(tempfile.mkdtemp(), 'initial_weights')\n",
    | 702 | + "initial_weights = os.path.join(tempfile.mkdtemp(), 'initial.weights.h5')\n",
704 | 703 | "model.save_weights(initial_weights)"
705 | 704 | ]
706 | 705 | },
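A note on the filename change above: Keras 3's `Model.save_weights` requires the path to end in `.weights.h5`, so the old extension-less `'initial_weights'` name would be rejected there.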
714 | 713 | "\n",
715 | 714 | "Before moving on, confirm quickly that the careful bias initialization actually helped.\n",
716 | 715 | "\n",
717 |     | - "Train the model for 20 epochs, with and without this careful initialization, and compare the losses: "
    | 716 | + "Train the model for 20 epochs, with and without this careful initialization, and compare the losses:"
718 | 717 | ]
719 | 718 | },
720 | 719 | {
733 | 732 | "    train_labels,\n",
734 | 733 | "    batch_size=BATCH_SIZE,\n",
735 | 734 | "    epochs=20,\n",
736 |     | - "    validation_data=(val_features, val_labels), \n",
    | 735 | + "    validation_data=(val_features, val_labels),\n",
737 | 736 | "    verbose=0)"
738 | 737 | ]
739 | 738 | },
752 | 751 | "    train_labels,\n",
753 | 752 | "    batch_size=BATCH_SIZE,\n",
754 | 753 | "    epochs=20,\n",
755 |     | - "    validation_data=(val_features, val_labels), \n",
    | 754 | + "    validation_data=(val_features, val_labels),\n",
756 | 755 | "    verbose=0)"
757 | 756 | ]
758 | 757 | },
794 | 793 | "id": "fKMioV0ddG3R"
795 | 794 | },
796 | 795 | "source": [
797 |     | - "The above figure makes it clear: In terms of validation loss, on this problem, this careful initialization gives a clear advantage. "
    | 796 | + "The figure above makes it plain: on this problem, the careful initialization gives a clear advantage in validation loss."
798 | 797 | ]
799 | 798 | },
800 | 799 | {
821 | 820 | "    train_labels,\n",
822 | 821 | "    batch_size=BATCH_SIZE,\n",
823 | 822 | "    epochs=EPOCHS,\n",
824 |     | - "    callbacks=[early_stopping],\n",
    | 823 | + "    callbacks=[early_stopping()],\n",
825 | 824 | "    validation_data=(val_features, val_labels))"
826 | 825 | ]
827 | 826 | },
996 | 995 | ]
997 | 996 | },
998 | 997 | {
999 |     | - "attachments": {},
1000 | 998 | "cell_type": "markdown",
1001 | 999 | "metadata": {
1002 |      | - "id": "P-QpQsip_F2Q"
     | 1000 | + "id": "kF8k-g9goRni"
1003 | 1001 | },
1004 | 1002 | "source": [
1005 | 1003 | "### Plot the ROC\n",
1161 | 1159 | "    train_labels,\n",
1162 | 1160 | "    batch_size=BATCH_SIZE,\n",
1163 | 1161 | "    epochs=EPOCHS,\n",
1164 |      | - "    callbacks=[early_stopping],\n",
     | 1162 | + "    callbacks=[early_stopping()],\n",
1165 | 1163 | "    validation_data=(val_features, val_labels),\n",
1166 | 1164 | "    # The class weights go here\n",
1167 |      | - "    class_weight=class_weight) "
     | 1165 | + "    class_weight=class_weight)"
1168 | 1166 | ]
1169 | 1167 | },
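For readers landing on this hunk directly: the `class_weight` dict is defined in an earlier, elided cell. A sketch of the usual computation, with the `pos`, `neg`, and `total` counts assumed:

```python
# Scale the per-class loss so each class contributes about equally overall.
weight_for_0 = (1 / neg) * (total / 2.0)
weight_for_1 = (1 / pos) * (total / 2.0)
class_weight = {0: weight_for_0, 1: weight_for_1}
```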
1170 | 1168 | {
1333 | 1331 | "source": [
1334 | 1332 | "#### Using NumPy\n",
1335 | 1333 | "\n",
1336 |      | - "You can balance the dataset manually by choosing the right number of random \n",
     | 1334 | + "You can balance the dataset manually by choosing the right number of random\n",
1337 | 1335 | "indices from the positive examples:"
1338 | 1336 | ]
1339 | 1337 | },
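A sketch of that manual oversampling (the `pos_features`/`neg_features`/`pos_labels` split is assumed from a nearby, elided cell):

```python
# Sample positive indices with replacement until the classes are equal-sized.
ids = np.arange(len(pos_features))
choices = np.random.choice(ids, size=len(neg_features))

res_pos_features = pos_features[choices]
res_pos_labels = pos_labels[choices]
print(res_pos_features.shape)  # now matches the number of negative examples
```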
1485 | 1483 | },
1486 | 1484 | "outputs": [],
1487 | 1485 | "source": [
1488 |      | - "resampled_steps_per_epoch = np.ceil(2.0*neg/BATCH_SIZE)\n",
     | 1486 | + "resampled_steps_per_epoch = int(np.ceil(2.0*neg/BATCH_SIZE))\n",
1489 | 1487 | "resampled_steps_per_epoch"
1490 | 1488 | ]
1491 | 1489 | },
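The added `int(...)` is worth a note: `np.ceil` returns a NumPy float, and recent Keras releases insist that `steps_per_epoch` be an integer, so the bare float can fail at `fit()` time.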
1499 | 1497 | "\n",
1500 | 1498 | "Now try training the model with the resampled data set instead of using class weights to see how these methods compare.\n",
1501 | 1499 | "\n",
1502 |      | - "Note: Because the data was balanced by replicating the positive examples, the total dataset size is larger, and each epoch runs for more training steps. "
     | 1500 | + "Note: Because the data was balanced by replicating the positive examples, the total dataset size is larger, and each epoch runs for more training steps."
1503 | 1501 | ]
1504 | 1502 | },
1505 | 1503 | {
1514 | 1512 | "resampled_model.load_weights(initial_weights)\n",
1515 | 1513 | "\n",
1516 | 1514 | "# Reset the bias to zero, since this dataset is balanced.\n",
1517 |      | - "output_layer = resampled_model.layers[-1] \n",
     | 1515 | + "output_layer = resampled_model.layers[-1]\n",
1518 | 1516 | "output_layer.bias.assign([0])\n",
1519 | 1517 | "\n",
1520 | 1518 | "val_ds = tf.data.Dataset.from_tensor_slices((val_features, val_labels)).cache()\n",
1521 |      | - "val_ds = val_ds.batch(BATCH_SIZE).prefetch(2) \n",
     | 1519 | + "val_ds = val_ds.batch(BATCH_SIZE).prefetch(2)\n",
1522 | 1520 | "\n",
1523 | 1521 | "resampled_history = resampled_model.fit(\n",
1524 | 1522 | "    resampled_ds,\n",
1525 | 1523 | "    epochs=EPOCHS,\n",
1526 | 1524 | "    steps_per_epoch=resampled_steps_per_epoch,\n",
1527 |      | - "    callbacks=[early_stopping],\n",
     | 1525 | + "    callbacks=[early_stopping()],\n",
1528 | 1526 | "    validation_data=val_ds)"
1529 | 1527 | ]
1530 | 1528 | },
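The `resampled_ds` consumed above is built in an earlier, elided cell; a hedged sketch of the usual construction (`BUFFER_SIZE`, the `make_ds` helper, and the 50/50 mix are assumptions based on the surrounding tutorial):

```python
BUFFER_SIZE = 100000  # illustrative shuffle-buffer size

def make_ds(features, labels):
  # An endlessly repeating, shuffled stream of one class.
  ds = tf.data.Dataset.from_tensor_slices((features, labels))
  return ds.shuffle(BUFFER_SIZE).repeat()

pos_ds = make_ds(pos_features, pos_labels)
neg_ds = make_ds(neg_features, neg_labels)

# Draw from each class with equal probability so batches come out balanced.
resampled_ds = tf.data.Dataset.sample_from_datasets(
    [pos_ds, neg_ds], weights=[0.5, 0.5])
resampled_ds = resampled_ds.batch(BATCH_SIZE).prefetch(2)
```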
1536 | 1534 | "source": [
1537 | 1535 | "If the training process were considering the whole dataset on each gradient update, this oversampling would be basically identical to the class weighting.\n",
1538 | 1536 | "\n",
1539 |      | - "But when training the model batch-wise, as you did here, the oversampled data provides a smoother gradient signal: Instead of each positive example being shown in one batch with a large weight, they're shown in many different batches each time with a small weight. \n",
     | 1537 | + "But when training the model batch-wise, as you did here, the oversampled data provides a smoother gradient signal: instead of each positive example appearing in one batch with a large weight, it appears in many different batches, each time with a small weight.\n",
1540 | 1538 | "\n",
1541 | 1539 | "This smoother gradient signal makes it easier to train the model."
1542 | 1540 | ]
1549 | 1547 | "source": [
1550 | 1548 | "### Check training history\n",
1551 | 1549 | "\n",
1552 |      | - "Note that the distributions of metrics will be different here, because the training data has a totally different distribution from the validation and test data. "
     | 1550 | + "Note that the distributions of metrics will be different here, because the training data has a totally different distribution from the validation and test data."
1553 | 1551 | ]
1554 | 1552 | },
1555 | 1553 | {
1578 | 1576 | "id": "KFLxRL8eoDE5"
1579 | 1577 | },
1580 | 1578 | "source": [
1581 |      | - "Because training is easier on the balanced data, the above training procedure may overfit quickly. \n",
     | 1579 | + "Because training is easier on the balanced data, the above training procedure may overfit quickly.\n",
1582 | 1580 | "\n",
1583 | 1581 | "So break up the epochs to give the `tf.keras.callbacks.EarlyStopping` finer control over when to stop training."
1584 | 1582 | ]
1595 | 1593 | "resampled_model.load_weights(initial_weights)\n",
1596 | 1594 | "\n",
1597 | 1595 | "# Reset the bias to zero, since this dataset is balanced.\n",
1598 |      | - "output_layer = resampled_model.layers[-1] \n",
     | 1596 | + "output_layer = resampled_model.layers[-1]\n",
1599 | 1597 | "output_layer.bias.assign([0])\n",
1600 | 1598 | "\n",
1601 | 1599 | "resampled_history = resampled_model.fit(\n",
1602 | 1600 | "    resampled_ds,\n",
1603 | 1601 | "    # These are not real epochs\n",
1604 | 1602 | "    steps_per_epoch=20,\n",
1605 | 1603 | "    epochs=10*EPOCHS,\n",
1606 |      | - "    callbacks=[early_stopping],\n",
     | 1604 | + "    callbacks=[early_stopping()],\n",
1607 | 1605 | "    validation_data=(val_ds))"
1608 | 1606 | ]
1609 | 1607 | },
1696 | 1694 | "id": "vayGnv0VOe_v"
1697 | 1695 | },
1698 | 1696 | "source": [
1699 |      | - "### Plot the AUPRC\r\n"
     | 1697 | + "### Plot the AUPRC\n"
1700 | 1698 | ]
1701 | 1699 | },
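The `plot_prc` helper called in the next cell is defined in an earlier, elided cell; a sketch of what it plausibly looks like, built on scikit-learn's `precision_recall_curve`:

```python
from sklearn.metrics import precision_recall_curve

def plot_prc(name, labels, predictions, **kwargs):
  # Trace precision against recall over all thresholds for one model.
  precision, recall, _ = precision_recall_curve(labels, predictions)
  plt.plot(recall, precision, label=name, linewidth=2, **kwargs)
  plt.xlabel('Recall')
  plt.ylabel('Precision')
  plt.grid(True)
```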
1702 | 1700 | {
1707 | 1705 | },
1708 | 1706 | "outputs": [],
1709 | 1707 | "source": [
1710 |      | - "plot_prc(\"Train Baseline\", train_labels, train_predictions_baseline, color=colors[0])\r\n",
1711 |      | - "plot_prc(\"Test Baseline\", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')\r\n",
1712 |      | - "\r\n",
1713 |      | - "plot_prc(\"Train Weighted\", train_labels, train_predictions_weighted, color=colors[1])\r\n",
1714 |      | - "plot_prc(\"Test Weighted\", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')\r\n",
1715 |      | - "\r\n",
1716 |      | - "plot_prc(\"Train Resampled\", train_labels, train_predictions_resampled, color=colors[2])\r\n",
1717 |      | - "plot_prc(\"Test Resampled\", test_labels, test_predictions_resampled, color=colors[2], linestyle='--')\r\n",
     | 1708 | + "plot_prc(\"Train Baseline\", train_labels, train_predictions_baseline, color=colors[0])\n",
     | 1709 | + "plot_prc(\"Test Baseline\", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')\n",
     | 1710 | + "\n",
     | 1711 | + "plot_prc(\"Train Weighted\", train_labels, train_predictions_weighted, color=colors[1])\n",
     | 1712 | + "plot_prc(\"Test Weighted\", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')\n",
     | 1713 | + "\n",
     | 1714 | + "plot_prc(\"Train Resampled\", train_labels, train_predictions_resampled, color=colors[2])\n",
     | 1715 | + "plot_prc(\"Test Resampled\", test_labels, test_predictions_resampled, color=colors[2], linestyle='--')\n",
1718 | 1716 | "plt.legend(loc='lower right');"
1719 | 1717 | ]
1720 | 1718 | },
1732 | 1730 | ],
1733 | 1731 | "metadata": {
1734 | 1732 | "colab": {
1735 |      | - "collapsed_sections": [],
1736 | 1733 | "name": "imbalanced_data.ipynb",
1737 | 1734 | "toc_visible": true
1738 | 1735 | },