
Commit 4a3e670

Merge pull request #2305 from priankakariatyml:imbalanced_classification_tf_2.16_fixes
PiperOrigin-RevId: 664610546
2 parents: 7d4187e + 4b60ead

File tree

1 file changed (+45, -48 lines)

site/en/tutorials/structured_data/imbalanced_data.ipynb

Lines changed: 45 additions & 48 deletions
@@ -258,10 +258,10 @@
 "train_df, val_df = train_test_split(train_df, test_size=0.2)\n",
 "\n",
 "# Form np arrays of labels and features.\n",
-"train_labels = np.array(train_df.pop('Class'))\n",
-"bool_train_labels = train_labels != 0\n",
-"val_labels = np.array(val_df.pop('Class'))\n",
-"test_labels = np.array(test_df.pop('Class'))\n",
+"train_labels = np.array(train_df.pop('Class')).reshape(-1, 1)\n",
+"bool_train_labels = train_labels[:, 0] != 0\n",
+"val_labels = np.array(val_df.pop('Class')).reshape(-1, 1)\n",
+"test_labels = np.array(test_df.pop('Class')).reshape(-1, 1)\n",
 "\n",
 "train_features = np.array(train_df)\n",
 "val_features = np.array(val_df)\n",
@@ -291,18 +291,17 @@
 ]
 },
 {
-"attachments": {},
 "cell_type": "markdown",
 "metadata": {
-"id": "8a_Z_kBmr7Oh"
+"id": "ueKV4cmcoRnf"
 },
 "source": [
 "Given the small number of positive labels, this seems about right.\n",
 "\n",
 "Normalize the input features using the sklearn StandardScaler.\n",
 "This will set the mean to 0 and standard deviation to 1.\n",
 "\n",
-"Note: The `StandardScaler` is only fit using the `train_features` to be sure the model is not peeking at the validation or test sets. "
+"Note: The `StandardScaler` is only fit using the `train_features` to be sure the model is not peeking at the validation or test sets."
 ]
 },
 {
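
The normalization this cell describes fits the scaler on the training split only; a minimal sketch of that pattern, assuming the `train_features`, `val_features`, and `test_features` arrays built earlier in the notebook:

    from sklearn.preprocessing import StandardScaler

    scaler = StandardScaler()
    # Fit on the training features only, so no statistics leak from val/test.
    train_features = scaler.fit_transform(train_features)
    # Reuse the training mean and standard deviation for the other splits.
    val_features = scaler.transform(val_features)
    test_features = scaler.transform(test_features)
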
@@ -352,7 +351,7 @@
 "\n",
 "Next compare the distributions of the positive and negative examples over a few features. Good questions to ask yourself at this point are:\n",
 "\n",
-"* Do these distributions make sense? \n",
+"* Do these distributions make sense?\n",
 " * Yes. You've normalized the input and these are mostly concentrated in the `+/- 2` range.\n",
 "* Can you see the difference between the distributions?\n",
 " * Yes the positive examples contain a much higher rate of extreme values."
@@ -386,7 +385,7 @@
 "source": [
 "## Define the model and metrics\n",
 "\n",
-"Define a function that creates a simple neural network with a densly connected hidden layer, a [dropout](https://developers.google.com/machine-learning/glossary/#dropout_regularization) layer to reduce overfitting, and an output sigmoid layer that returns the probability of a transaction being fraudulent: "
+"Define a function that creates a simple neural network with a densly connected hidden layer, a [dropout](https://developers.google.com/machine-learning/glossary/#dropout_regularization) layer to reduce overfitting, and an output sigmoid layer that returns the probability of a transaction being fraudulent:"
 ]
 },
 {
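
The body of that function is not part of this diff; the following is only a rough sketch of the architecture the markdown describes (the layer width, dropout rate, and optimizer are illustrative assumptions, not values taken from the hunk):

    import keras

    def make_model(metrics=(), output_bias=None):
        if output_bias is not None:
            output_bias = keras.initializers.Constant(output_bias)
        model = keras.Sequential([
            keras.Input(shape=(train_features.shape[-1],)),
            keras.layers.Dense(16, activation='relu'),    # densely connected hidden layer
            keras.layers.Dropout(0.5),                    # dropout to reduce overfitting
            keras.layers.Dense(1, activation='sigmoid',   # P(transaction is fraudulent)
                               bias_initializer=output_bias),
        ])
        model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=1e-3),
            loss=keras.losses.BinaryCrossentropy(),
            metrics=list(metrics))
        return model
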
@@ -403,7 +402,7 @@
 " keras.metrics.TruePositives(name='tp'),\n",
 " keras.metrics.FalsePositives(name='fp'),\n",
 " keras.metrics.TrueNegatives(name='tn'),\n",
-" keras.metrics.FalseNegatives(name='fn'), \n",
+" keras.metrics.FalseNegatives(name='fn'),\n",
 " keras.metrics.BinaryAccuracy(name='accuracy'),\n",
 " keras.metrics.Precision(name='precision'),\n",
 " keras.metrics.Recall(name='recall'),\n",
@@ -432,7 +431,6 @@
 ]
 },
 {
-"attachments": {},
 "cell_type": "markdown",
 "metadata": {
 "id": "SU0GX6E6mieP"
@@ -456,7 +454,7 @@
 "In the end, one often wants to predict a class label, 0 or 1, *no fraud* or *fraud*.\n",
 "This is called a deterministic classifier.\n",
 "To get a label prediction from our probabilistic classifier, one needs to choose a probability threshold $t$.\n",
-"The default is to predict label 1 (fraud) if the predicted probability is larger than $t=50\\%$ and all the following metrics implicitly use this default. \n",
+"The default is to predict label 1 (fraud) if the predicted probability is larger than $t=50\\%$ and all the following metrics implicitly use this default.\n",
 "\n",
 "* **False** negatives and **false** positives are samples that were **incorrectly** classified\n",
 "* **True** negatives and **true** positives are samples that were **correctly** classified\n",
@@ -474,7 +472,7 @@
 "The following metrics take into account all possible choices of thresholds $t$.\n",
 "\n",
 "* **AUC** refers to the Area Under the Curve of a Receiver Operating Characteristic curve (ROC-AUC). This metric is equal to the probability that a classifier will rank a random positive sample higher than a random negative sample.\n",
-"* **AUPRC** refers to Area Under the Curve of the Precision-Recall Curve. This metric computes precision-recall pairs for different probability thresholds. \n",
+"* **AUPRC** refers to Area Under the Curve of the Precision-Recall Curve. This metric computes precision-recall pairs for different probability thresholds.\n",
 "\n",
 "\n",
 "#### Read more:\n",
@@ -520,8 +518,9 @@
 "EPOCHS = 100\n",
 "BATCH_SIZE = 2048\n",
 "\n",
-"early_stopping = tf.keras.callbacks.EarlyStopping(\n",
-" monitor='val_prc', \n",
+"def early_stopping():\n",
+" return tf.keras.callbacks.EarlyStopping(\n",
+" monitor='val_prc',\n",
 " verbose=1,\n",
 " patience=10,\n",
 " mode='max',\n",
@@ -584,7 +583,7 @@
 "id": "PdbfWDuVpo6k"
 },
 "source": [
-"With the default bias initialization the loss should be about `math.log(2) = 0.69314` "
+"With the default bias initialization the loss should be about `math.log(2) = 0.69314`"
 ]
 },
 {
@@ -630,7 +629,7 @@
 "id": "d1juXI9yY1KD"
 },
 "source": [
-"Set that as the initial bias, and the model will give much more reasonable initial guesses. \n",
+"Set that as the initial bias, and the model will give much more reasonable initial guesses.\n",
 "\n",
 "It should be near: `pos/total = 0.0018`"
 ]
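
The `pos/total = 0.0018` figure comes from choosing the output bias b0 so that sigmoid(b0) equals the base rate pos / (pos + neg); a short check of that arithmetic, with the `pos` and `neg` counts computed earlier in the tutorial:

    import numpy as np

    # b0 = ln(pos/neg)  =>  sigmoid(b0) = pos / (pos + neg), i.e. the base rate of fraud.
    initial_bias = np.log([pos / neg])
    initial_guess = 1 / (1 + np.exp(-initial_bias))   # should be close to 0.0018 here
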
@@ -700,7 +699,7 @@
 },
 "outputs": [],
 "source": [
-"initial_weights = os.path.join(tempfile.mkdtemp(), 'initial_weights')\n",
+"initial_weights = os.path.join(tempfile.mkdtemp(), 'initial.weights.h5')\n",
 "model.save_weights(initial_weights)"
 ]
 },
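
The rename matters because Keras 3 (the default Keras in TF 2.16) expects `save_weights` paths to end in `.weights.h5`; a minimal sketch of the save/restore round trip the notebook relies on:

    import os
    import tempfile

    # Keras 3 rejects weight checkpoints whose filename does not end in `.weights.h5`.
    initial_weights = os.path.join(tempfile.mkdtemp(), 'initial.weights.h5')
    model.save_weights(initial_weights)

    # Later experiments restart from the same initialization:
    model.load_weights(initial_weights)
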
@@ -714,7 +713,7 @@
 "\n",
 "Before moving on, confirm quick that the careful bias initialization actually helped.\n",
 "\n",
-"Train the model for 20 epochs, with and without this careful initialization, and compare the losses: "
+"Train the model for 20 epochs, with and without this careful initialization, and compare the losses:"
 ]
 },
 {
@@ -733,7 +732,7 @@
 " train_labels,\n",
 " batch_size=BATCH_SIZE,\n",
 " epochs=20,\n",
-" validation_data=(val_features, val_labels), \n",
+" validation_data=(val_features, val_labels),\n",
 " verbose=0)"
 ]
 },
@@ -752,7 +751,7 @@
 " train_labels,\n",
 " batch_size=BATCH_SIZE,\n",
 " epochs=20,\n",
-" validation_data=(val_features, val_labels), \n",
+" validation_data=(val_features, val_labels),\n",
 " verbose=0)"
 ]
 },
@@ -794,7 +793,7 @@
 "id": "fKMioV0ddG3R"
 },
 "source": [
-"The above figure makes it clear: In terms of validation loss, on this problem, this careful initialization gives a clear advantage. "
+"The above figure makes it clear: In terms of validation loss, on this problem, this careful initialization gives a clear advantage."
 ]
 },
 {
@@ -821,7 +820,7 @@
 " train_labels,\n",
 " batch_size=BATCH_SIZE,\n",
 " epochs=EPOCHS,\n",
-" callbacks=[early_stopping],\n",
+" callbacks=[early_stopping()],\n",
 " validation_data=(val_features, val_labels))"
 ]
 },
@@ -996,10 +995,9 @@
 ]
 },
 {
-"attachments": {},
 "cell_type": "markdown",
 "metadata": {
-"id": "P-QpQsip_F2Q"
+"id": "kF8k-g9goRni"
 },
 "source": [
 "### Plot the ROC\n",
@@ -1161,10 +1159,10 @@
 " train_labels,\n",
 " batch_size=BATCH_SIZE,\n",
 " epochs=EPOCHS,\n",
-" callbacks=[early_stopping],\n",
+" callbacks=[early_stopping()],\n",
 " validation_data=(val_features, val_labels),\n",
 " # The class weights go here\n",
-" class_weight=class_weight) "
+" class_weight=class_weight)"
 ]
 },
 {
@@ -1333,7 +1331,7 @@
 "source": [
 "#### Using NumPy\n",
 "\n",
-"You can balance the dataset manually by choosing the right number of random \n",
+"You can balance the dataset manually by choosing the right number of random\n",
 "indices from the positive examples:"
 ]
 },
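
A compact sketch of the manual oversampling this cell introduces, assuming the `pos_features`/`neg_features` and `pos_labels`/`neg_labels` splits defined earlier in the tutorial:

    import numpy as np

    # Draw positive indices with replacement until both classes are the same size.
    ids = np.arange(len(pos_features))
    choices = np.random.choice(ids, size=len(neg_features))

    res_pos_features = pos_features[choices]
    res_pos_labels = pos_labels[choices]

    resampled_features = np.concatenate([res_pos_features, neg_features], axis=0)
    resampled_labels = np.concatenate([res_pos_labels, neg_labels], axis=0)

    # Shuffle features and labels together so the classes are interleaved.
    order = np.arange(len(resampled_labels))
    np.random.shuffle(order)
    resampled_features = resampled_features[order]
    resampled_labels = resampled_labels[order]
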
@@ -1485,7 +1483,7 @@
 },
 "outputs": [],
 "source": [
-"resampled_steps_per_epoch = np.ceil(2.0*neg/BATCH_SIZE)\n",
+"resampled_steps_per_epoch = int(np.ceil(2.0*neg/BATCH_SIZE))\n",
 "resampled_steps_per_epoch"
 ]
 },
@@ -1499,7 +1497,7 @@
 "\n",
 "Now try training the model with the resampled data set instead of using class weights to see how these methods compare.\n",
 "\n",
-"Note: Because the data was balanced by replicating the positive examples, the total dataset size is larger, and each epoch runs for more training steps. "
+"Note: Because the data was balanced by replicating the positive examples, the total dataset size is larger, and each epoch runs for more training steps."
 ]
 },
 {
@@ -1514,17 +1512,17 @@
 "resampled_model.load_weights(initial_weights)\n",
 "\n",
 "# Reset the bias to zero, since this dataset is balanced.\n",
-"output_layer = resampled_model.layers[-1] \n",
+"output_layer = resampled_model.layers[-1]\n",
 "output_layer.bias.assign([0])\n",
 "\n",
 "val_ds = tf.data.Dataset.from_tensor_slices((val_features, val_labels)).cache()\n",
-"val_ds = val_ds.batch(BATCH_SIZE).prefetch(2) \n",
+"val_ds = val_ds.batch(BATCH_SIZE).prefetch(2)\n",
 "\n",
 "resampled_history = resampled_model.fit(\n",
 " resampled_ds,\n",
 " epochs=EPOCHS,\n",
 " steps_per_epoch=resampled_steps_per_epoch,\n",
-" callbacks=[early_stopping],\n",
+" callbacks=[early_stopping()],\n",
 " validation_data=val_ds)"
 ]
 },
@@ -1536,7 +1534,7 @@
 "source": [
 "If the training process were considering the whole dataset on each gradient update, this oversampling would be basically identical to the class weighting.\n",
 "\n",
-"But when training the model batch-wise, as you did here, the oversampled data provides a smoother gradient signal: Instead of each positive example being shown in one batch with a large weight, they're shown in many different batches each time with a small weight. \n",
+"But when training the model batch-wise, as you did here, the oversampled data provides a smoother gradient signal: Instead of each positive example being shown in one batch with a large weight, they're shown in many different batches each time with a small weight.\n",
 "\n",
 "This smoother gradient signal makes it easier to train the model."
 ]
@@ -1549,7 +1547,7 @@
 "source": [
 "### Check training history\n",
 "\n",
-"Note that the distributions of metrics will be different here, because the training data has a totally different distribution from the validation and test data. "
+"Note that the distributions of metrics will be different here, because the training data has a totally different distribution from the validation and test data."
 ]
 },
 {
@@ -1578,7 +1576,7 @@
 "id": "KFLxRL8eoDE5"
 },
 "source": [
-"Because training is easier on the balanced data, the above training procedure may overfit quickly. \n",
+"Because training is easier on the balanced data, the above training procedure may overfit quickly.\n",
 "\n",
 "So break up the epochs to give the `tf.keras.callbacks.EarlyStopping` finer control over when to stop training."
 ]
@@ -1595,15 +1593,15 @@
 "resampled_model.load_weights(initial_weights)\n",
 "\n",
 "# Reset the bias to zero, since this dataset is balanced.\n",
-"output_layer = resampled_model.layers[-1] \n",
+"output_layer = resampled_model.layers[-1]\n",
 "output_layer.bias.assign([0])\n",
 "\n",
 "resampled_history = resampled_model.fit(\n",
 " resampled_ds,\n",
 " # These are not real epochs\n",
 " steps_per_epoch=20,\n",
 " epochs=10*EPOCHS,\n",
-" callbacks=[early_stopping],\n",
+" callbacks=[early_stopping()],\n",
 " validation_data=(val_ds))"
 ]
 },
@@ -1696,7 +1694,7 @@
 "id": "vayGnv0VOe_v"
 },
 "source": [
-"### Plot the AUPRC\r\n"
+"### Plot the AUPRC\n"
 ]
 },
 {
@@ -1707,14 +1705,14 @@
 },
 "outputs": [],
 "source": [
-"plot_prc(\"Train Baseline\", train_labels, train_predictions_baseline, color=colors[0])\r\n",
-"plot_prc(\"Test Baseline\", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')\r\n",
-"\r\n",
-"plot_prc(\"Train Weighted\", train_labels, train_predictions_weighted, color=colors[1])\r\n",
-"plot_prc(\"Test Weighted\", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')\r\n",
-"\r\n",
-"plot_prc(\"Train Resampled\", train_labels, train_predictions_resampled, color=colors[2])\r\n",
-"plot_prc(\"Test Resampled\", test_labels, test_predictions_resampled, color=colors[2], linestyle='--')\r\n",
+"plot_prc(\"Train Baseline\", train_labels, train_predictions_baseline, color=colors[0])\n",
+"plot_prc(\"Test Baseline\", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')\n",
+"\n",
+"plot_prc(\"Train Weighted\", train_labels, train_predictions_weighted, color=colors[1])\n",
+"plot_prc(\"Test Weighted\", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')\n",
+"\n",
+"plot_prc(\"Train Resampled\", train_labels, train_predictions_resampled, color=colors[2])\n",
+"plot_prc(\"Test Resampled\", test_labels, test_predictions_resampled, color=colors[2], linestyle='--')\n",
 "plt.legend(loc='lower right');"
 ]
 },
@@ -1732,7 +1730,6 @@
 ],
 "metadata": {
 "colab": {
-"collapsed_sections": [],
 "name": "imbalanced_data.ipynb",
 "toc_visible": true
 },
