|
516 | 516 | "source": [
|
517 | 517 | "### Important note about BatchNormalization layers\n",
|
518 | 518 | "\n",
|
519 |     | - "Many models contain `tf.keras.layers.BatchNormalization` layers. This layer is a special case and precautions should be taken in the context of fine-tuning, as shown later in this tutorial. \n",
    | 519 | + "Many models contain `tf.keras.layers.BatchNormalization` layers. This layer is a special case and precautions should be taken in the context of fine-tuning, as shown later in this tutorial.\n",
520 | 520 | "\n",
|
521 |     | - "When you set `layer.trainable = False`, the `BatchNormalization` layer will run in inference mode, and will not update its mean and variance statistics. \n",
    | 521 | + "When you set `layer.trainable = False`, the `BatchNormalization` layer will run in inference mode, and will not update its mean and variance statistics.\n",
522 | 522 | "\n",
|
523 | 523 | "When you unfreeze a model that contains BatchNormalization layers in order to do fine-tuning, you should keep the BatchNormalization layers in inference mode by passing `training = False` when calling the base model. Otherwise, the updates applied to the non-trainable weights will destroy what the model has learned.\n",
|
524 | 524 | "\n",
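As a concrete illustration of the note above, here is a minimal sketch of a feature-extraction model built this way. The MobileNetV2 base and the 160x160 input size are assumptions (adjust to your own pipeline); the key point is that the base is called with `training=False`:

```python
import tensorflow as tf

IMG_SHAPE = (160, 160, 3)  # assumed image size; adjust to your data pipeline

# Sketch: a frozen MobileNetV2 feature extractor with a single-logit head.
base_model = tf.keras.applications.MobileNetV2(
    input_shape=IMG_SHAPE, include_top=False, weights='imagenet')
base_model.trainable = False  # freezes every layer, including BatchNormalization

inputs = tf.keras.Input(shape=IMG_SHAPE)
x = tf.keras.applications.mobilenet_v2.preprocess_input(inputs)
x = base_model(x, training=False)  # inference mode: BN statistics are not updated
x = tf.keras.layers.GlobalAveragePooling2D()(x)
outputs = tf.keras.layers.Dense(1)(x)  # one logit for the two-class problem
model = tf.keras.Model(inputs, outputs)
```

Because `training=False` is baked into the call, the BatchNormalization statistics stay fixed even after `base_model.trainable` is later flipped back to `True` for fine-tuning.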
|
|
617 | 617 | "model = tf.keras.Model(inputs, outputs)"
|
618 | 618 | ]
|
619 | 619 | },
|
| 620 | + { |
| 621 | + "cell_type": "code", |
| 622 | + "execution_count": null, |
| 623 | + "metadata": { |
| 624 | + "id": "I8ARiyMFsgbH" |
| 625 | + }, |
| 626 | + "outputs": [], |
| 627 | + "source": [ |
| 628 | + "model.summary()" |
| 629 | + ] |
| 630 | + }, |
620 | 631 | {
|
621 | 632 | "cell_type": "markdown",
|
622 | 633 | "metadata": {
|
623 |     | - "id": "g0ylJXE_kRLi"
    | 634 | + "id": "lxOcmVr0ydFZ"
624 | 635 | },
|
625 | 636 | "source": [
|
626 |     | - "### Compile the model\n",
627 |     | - "\n",
628 |     | - "Compile the model before training it. Since there are two classes, use the `tf.keras.losses.BinaryCrossentropy` loss with `from_logits=True` since the model provides a linear output."
    | 637 | + "The 8+ million parameters in MobileNet are frozen, but there are 1.2 thousand _trainable_ parameters in the Dense layer. These are divided between two `tf.Variable` objects, the weights and biases."
629 | 638 | ]
|
630 | 639 | },
|
631 | 640 | {
|
632 | 641 | "cell_type": "code",
|
633 | 642 | "execution_count": null,
|
634 | 643 | "metadata": {
|
635 |     | - "id": "RpR8HdyMhukJ"
    | 644 | + "id": "krvBumovycVA"
636 | 645 | },
|
637 | 646 | "outputs": [],
|
638 | 647 | "source": [
|
639 |     | - "base_learning_rate = 0.0001\n",
640 |     | - "model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate),\n",
641 |     | - "              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),\n",
642 |     | - "              metrics=['accuracy'])"
    | 648 | + "len(model.trainable_variables)"
643 | 649 | ]
|
644 | 650 | },
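If it is not obvious where those roughly 1.2 thousand trainable parameters live, a quick check (a sketch run against the `model` built above) is to print each trainable variable; the two entries are the Dense layer's kernel and bias:

```python
# The only trainable variables are the Dense head's kernel and bias. With a
# 1280-dimensional pooled feature vector (the assumed MobileNetV2 feature
# depth) this prints a (1280, 1) kernel and a (1,) bias: 1281 parameters.
for var in model.trainable_variables:
    print(var.name, var.shape)
```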
|
645 | 651 | {
|
646 | 652 | "cell_type": "code",
|
647 | 653 | "execution_count": null,
|
648 | 654 | "metadata": {
|
649 |     | - "id": "I8ARiyMFsgbH"
    | 655 | + "id": "jeGk93R2ahav"
650 | 656 | },
|
651 | 657 | "outputs": [],
|
652 | 658 | "source": [
|
653 |     | - "model.summary()"
    | 659 | + "tf.keras.utils.plot_model(model, show_shapes=True)"
654 | 660 | ]
|
655 | 661 | },
|
656 | 662 | {
|
657 | 663 | "cell_type": "markdown",
|
658 | 664 | "metadata": {
|
659 |     | - "id": "lxOcmVr0ydFZ"
    | 665 | + "id": "g0ylJXE_kRLi"
660 | 666 | },
|
661 | 667 | "source": [
|
662 |     | - "The 2.5 million parameters in MobileNet are frozen, but there are 1.2 thousand _trainable_ parameters in the Dense layer. These are divided between two `tf.Variable` objects, the weights and biases."
    | 668 | + "### Compile the model\n",
    | 669 | + "\n",
    | 670 | + "Compile the model before training it. Since there are two classes, use the `tf.keras.losses.BinaryCrossentropy` loss with `from_logits=True` since the model provides a linear output."
663 | 671 | ]
|
664 | 672 | },
|
665 | 673 | {
|
666 | 674 | "cell_type": "code",
|
667 | 675 | "execution_count": null,
|
668 | 676 | "metadata": {
|
669 |     | - "id": "krvBumovycVA"
    | 677 | + "id": "RpR8HdyMhukJ"
670 | 678 | },
|
671 | 679 | "outputs": [],
|
672 | 680 | "source": [
|
673 |     | - "len(model.trainable_variables)"
    | 681 | + "base_learning_rate = 0.0001\n",
    | 682 | + "model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate),\n",
    | 683 | + "              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),\n",
    | 684 | + "              metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0, name='accuracy')])"
674 | 685 | ]
|
675 | 686 | },
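One note on the metric used here: since the model outputs a single raw logit, `tf.keras.metrics.BinaryAccuracy(threshold=0)` thresholds that logit at 0, which is the same decision as thresholding `tf.sigmoid(logit)` at 0.5, so the metric is consistent with the `from_logits=True` loss. A small sketch of the equivalence:

```python
import tensorflow as tf

# Thresholding logits at 0 and thresholding sigmoid probabilities at 0.5
# produce the same class labels.
logits = tf.constant([-2.0, -0.1, 0.3, 4.0])
print(tf.cast(logits > 0, tf.int32).numpy())                # [0 0 1 1]
print(tf.cast(tf.sigmoid(logits) > 0.5, tf.int32).numpy())  # [0 0 1 1]
```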
|
676 | 687 | {
|
|
681 | 692 | "source": [
|
682 | 693 | "### Train the model\n",
|
683 | 694 | "\n",
|
684 |     | - "After training for 10 epochs, you should see ~94% accuracy on the validation set.\n"
    | 695 | + "After training for 10 epochs, you should see ~96% accuracy on the validation set.\n"
685 | 696 | ]
|
686 | 697 | },
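The cell that runs this training falls outside the hunks shown here; a rough sketch of what it looks like (the dataset variable names are assumptions, not taken from this diff):

```python
# Sketch: train only the Dense head for 10 epochs. `train_dataset` and
# `validation_dataset` are assumed to be tf.data pipelines built earlier
# in the tutorial.
initial_epochs = 10
history = model.fit(train_dataset,
                    epochs=initial_epochs,
                    validation_data=validation_dataset)
```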
|
687 | 698 | {
|
|
863 | 874 | "source": [
|
864 | 875 | "model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),\n",
|
865 | 876 | " optimizer = tf.keras.optimizers.RMSprop(learning_rate=base_learning_rate/10),\n",
|
866 |     | - "              metrics=['accuracy'])"
    | 877 | + "              metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0, name='accuracy')])"
867 | 878 | ]
|
868 | 879 | },
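For context on this recompile: it happens after part of the base model has been unfrozen, which is why the learning rate drops to `base_learning_rate/10`; large updates at this stage would quickly distort the pretrained weights. A hedged sketch of the unfreezing step that precedes it (the `fine_tune_at` cutoff is illustrative, not taken from this diff):

```python
# Sketch: unfreeze the top of the base model, keeping the earliest layers frozen.
base_model.trainable = True
fine_tune_at = 100  # assumed cutoff; tune for your model
for layer in base_model.layers[:fine_tune_at]:
    layer.trainable = False
# The BatchNormalization layers still run in inference mode because the model
# was built with base_model(x, training=False).
```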
|
869 | 880 | {
|
|
1070 | 1081 | "\n",
|
1071 | 1082 | "To learn more, visit the [Transfer learning guide](https://www.tensorflow.org/guide/keras/transfer_learning).\n"
|
1072 | 1083 | ]
|
| 1084 | + }, |
| 1085 | + { |
| 1086 | + "cell_type": "code", |
| 1087 | + "execution_count": null, |
| 1088 | + "metadata": { |
| 1089 | + "id": "uKIByL01da8c" |
| 1090 | + }, |
| 1091 | + "outputs": [], |
| 1092 | + "source": [] |
1073 | 1093 | }
|
1074 | 1094 | ],
|
1075 | 1095 | "metadata": {
|
1076 | 1096 | "accelerator": "GPU",
|
1077 | 1097 | "colab": {
|
1078 |      | - "collapsed_sections": [],
1079 | 1098 | "name": "transfer_learning.ipynb",
|
| 1099 | + "private_outputs": true, |
| 1100 | + "provenance": [], |
1080 | 1101 | "toc_visible": true
|
1081 | 1102 | },
|
1082 | 1103 | "kernelspec": {
|
|