
Commit 51a06aa

MarkDaoust authored and copybara-github committed
Fix accuracy metric for logit output.
Fixes: tensorflow/tensorflow#41413
PiperOrigin-RevId: 555773892
1 parent 1637c45 commit 51a06aa

File tree

1 file changed: +41 -20 lines

site/en/tutorials/images/transfer_learning.ipynb

Lines changed: 41 additions & 20 deletions
@@ -516,9 +516,9 @@
 "source": [
 "### Important note about BatchNormalization layers\n",
 "\n",
-"Many models contain `tf.keras.layers.BatchNormalization` layers. This layer is a special case and precautions should be taken in the context of fine-tuning, as shown later in this tutorial. \n",
+"Many models contain `tf.keras.layers.BatchNormalization` layers. This layer is a special case and precautions should be taken in the context of fine-tuning, as shown later in this tutorial.\n",
 "\n",
-"When you set `layer.trainable = False`, the `BatchNormalization` layer will run in inference mode, and will not update its mean and variance statistics. \n",
+"When you set `layer.trainable = False`, the `BatchNormalization` layer will run in inference mode, and will not update its mean and variance statistics.\n",
 "\n",
 "When you unfreeze a model that contains BatchNormalization layers in order to do fine-tuning, you should keep the BatchNormalization layers in inference mode by passing `training = False` when calling the base model. Otherwise, the updates applied to the non-trainable weights will destroy what the model has learned.\n",
 "\n",
@@ -617,60 +617,71 @@
 "model = tf.keras.Model(inputs, outputs)"
 ]
 },
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"id": "I8ARiyMFsgbH"
+},
+"outputs": [],
+"source": [
+"model.summary()"
+]
+},
 {
 "cell_type": "markdown",
 "metadata": {
-"id": "g0ylJXE_kRLi"
+"id": "lxOcmVr0ydFZ"
 },
 "source": [
-"### Compile the model\n",
-"\n",
-"Compile the model before training it. Since there are two classes, use the `tf.keras.losses.BinaryCrossentropy` loss with `from_logits=True` since the model provides a linear output."
+"The 8+ million parameters in MobileNet are frozen, but there are 1.2 thousand _trainable_ parameters in the Dense layer. These are divided between two `tf.Variable` objects, the weights and biases."
 ]
 },
 {
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"id": "RpR8HdyMhukJ"
+"id": "krvBumovycVA"
 },
 "outputs": [],
 "source": [
-"base_learning_rate = 0.0001\n",
-"model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate),\n",
-"              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),\n",
-"              metrics=['accuracy'])"
+"len(model.trainable_variables)"
 ]
 },
 {
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"id": "I8ARiyMFsgbH"
+"id": "jeGk93R2ahav"
 },
 "outputs": [],
 "source": [
-"model.summary()"
+"tf.keras.utils.plot_model(model, show_shapes=True)"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"id": "lxOcmVr0ydFZ"
+"id": "g0ylJXE_kRLi"
 },
 "source": [
-"The 2.5 million parameters in MobileNet are frozen, but there are 1.2 thousand _trainable_ parameters in the Dense layer. These are divided between two `tf.Variable` objects, the weights and biases."
+"### Compile the model\n",
+"\n",
+"Compile the model before training it. Since there are two classes, use the `tf.keras.losses.BinaryCrossentropy` loss with `from_logits=True` since the model provides a linear output."
 ]
 },
 {
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"id": "krvBumovycVA"
+"id": "RpR8HdyMhukJ"
 },
 "outputs": [],
 "source": [
-"len(model.trainable_variables)"
+"base_learning_rate = 0.0001\n",
+"model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate),\n",
+"              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),\n",
+"              metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0, name='accuracy')])"
 ]
 },
 {
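The metric swap in the hunk above is the substance of this commit. The string `'accuracy'` resolves to a metric that thresholds predictions at 0.5, which is correct for probabilities but not for the raw logits this model outputs with `from_logits=True`; `tf.keras.metrics.BinaryAccuracy(threshold=0)` thresholds the logit at 0, which is equivalent to thresholding `sigmoid(logit)` at 0.5. A small self-contained check, with illustrative values not taken from the notebook:

import tensorflow as tf

logits = tf.constant([[-2.0], [-0.5], [0.5], [2.0]])  # raw model outputs
labels = tf.constant([[0.0], [0.0], [1.0], [1.0]])

# Default BinaryAccuracy thresholds at 0.5: the logit 0.5 fails the
# "> 0.5" test despite being a positive prediction, so accuracy drops.
default_acc = tf.keras.metrics.BinaryAccuracy()
default_acc.update_state(labels, logits)
print(default_acc.result().numpy())  # 0.75

# Thresholding logits at 0 matches sigmoid(logit) > 0.5.
logit_acc = tf.keras.metrics.BinaryAccuracy(threshold=0)
logit_acc.update_state(labels, logits)
print(logit_acc.result().numpy())    # 1.0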
@@ -681,7 +692,7 @@
 "source": [
 "### Train the model\n",
 "\n",
-"After training for 10 epochs, you should see ~94% accuracy on the validation set.\n"
+"After training for 10 epochs, you should see ~96% accuracy on the validation set.\n"
 ]
 },
 {
@@ -863,7 +874,7 @@
 "source": [
 "model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),\n",
 "              optimizer = tf.keras.optimizers.RMSprop(learning_rate=base_learning_rate/10),\n",
-"              metrics=['accuracy'])"
+"              metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0, name='accuracy')])"
 ]
 },
 {
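The same metric fix applies to the fine-tuning recompile above, now with RMSprop at a tenth of the original learning rate. As a hedged sketch of the surrounding unfreeze step, assuming `base_model` and `base_learning_rate` from the earlier cells and the tutorial's choice of freezing everything below layer 100:

# Unfreeze the top of the base model; layers below fine_tune_at stay frozen.
base_model.trainable = True
fine_tune_at = 100  # assumption: matches the published tutorial's choice
for layer in base_model.layers[:fine_tune_at]:
    layer.trainable = False

# Recompile with a much lower learning rate so the unfrozen weights are only
# nudged, and keep the threshold=0 accuracy metric for the logit output.
model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer=tf.keras.optimizers.RMSprop(learning_rate=base_learning_rate/10),
              metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0, name='accuracy')])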
@@ -1070,13 +1081,23 @@
 "\n",
 "To learn more, visit the [Transfer learning guide](https://www.tensorflow.org/guide/keras/transfer_learning).\n"
 ]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"id": "uKIByL01da8c"
+},
+"outputs": [],
+"source": []
 }
 ],
 "metadata": {
 "accelerator": "GPU",
 "colab": {
-"collapsed_sections": [],
 "name": "transfer_learning.ipynb",
+"private_outputs": true,
+"provenance": [],
 "toc_visible": true
 },
 "kernelspec": {
