106 | 106 | },
107 | 107 | {
108 | 108 | "cell_type": "code",
109 | | - "execution_count": null,
| 109 | + "execution_count": 2,
110 | 110 | "metadata": {
111 | 111 | "id": "PzkV-2cna823"
112 | 112 | },

594 | 594 | "source": [
595 | 595 | "## Nesting `tf.Variable`s, `tf.Module`s, `tf.keras.layers` & `tf.keras.models` in decorated calls\n",
596 | 596 | "\n",
597 | | - "Decorating your layer call in `tf.compat.v1.keras.utils.track_tf1_style_variables` will only add automatic implicit tracking of variables created (and reused) via `tf.compat.v1.get_variable`. It will not capture weights directly created by `tf.Variable` calls, such as those used by typical Keras layers and most `tf.Module`s. You still need to explicitly track these in the same way you would for any other Keras layer or `tf.Module`.\n",
598 | | - "\n",
599 | | - "If you need to embed `tf.Variable` calls, Keras layers/models, or `tf.Module`s in your decorators (either because you are following the incremental migration to Native TF2 described later in this guide, or because your TF1.x code partially consisted of Keras modules):\n",
600 | | - "* Explicitly make sure that the variable/module/layer is only created once\n",
601 | | - "* Explicitly attach them as instance attributes just as you would when defining a [typical module/layer](https://www.tensorflow.org/guide/intro_to_modules#defining_models_and_layers_in_tensorflow)\n",
602 | | - "* Explicitly reuse the already-created object in follow-on calls\n",
| 597 | + "Decorating your layer call in `tf.compat.v1.keras.utils.track_tf1_style_variables` will only add automatic implicit tracking of variables created (and reused) via `tf.compat.v1.get_variable`. It will not capture weights directly created by `tf.Variable` calls, such as those used by typical Keras layers and most `tf.Module`s. This section describes how to handle these nested cases.\n"
| 598 | + ]
| 599 | + },
| 600 | + {
| 601 | + "cell_type": "markdown",
| 602 | + "metadata": {
| 603 | + "id": "Azxza3bVOZlv"
| 604 | + },
| 605 | + "source": [
| 606 | + "### (Pre-existing usages) `tf.keras.layers` and `tf.keras.models`\n",
603 | 607 | "\n",
604 | | - "This ensures that weights are not created new and are correctly resued. Additionally, this also ensures that existing weights and regularization losses get tracked.\n",
| 608 | + "For pre-existing usages of nested Keras layers and models, use `tf.compat.v1.keras.utils.get_or_create_layer`. This is only recommended for easing migration of existing TF1.x nested Keras usages; new code should use explicit attribute setting as described below for `tf.Variable`s and `tf.Module`s.\n",
605 | 609 | "\n",
606 | | - "Here is an example of how this could look:"
| 610 | + "To use `tf.compat.v1.keras.utils.get_or_create_layer`, wrap the code that constructs your nested model into a method, and pass that method to `get_or_create_layer`. Example:"
607 | 611 | ]
608 | 612 | },
|
609 | 613 | {
|
610 | 614 | "cell_type": "code",
|
611 | 615 | "execution_count": null,
|
612 | 616 | "metadata": {
|
613 |
| - "id": "mrRPPoJ5ap5U" |
| 617 | + "id": "LN15TcRgHKsq" |
614 | 618 | },
|
615 | 619 | "outputs": [],
|
616 | 620 | "source": [
|
617 |
| - "class WrappedDenseLayer(tf.keras.layers.Layer):\n", |
| 621 | + "class NestedModel(tf.keras.Model):\n", |
618 | 622 | "\n",
|
619 |
| - " def __init__(self, units, **kwargs):\n", |
620 |
| - " super().__init__(**kwargs)\n", |
| 623 | + " def __init__(self, units, *args, **kwargs):\n", |
| 624 | + " super().__init__(*args, **kwargs)\n", |
621 | 625 | " self.units = units\n",
|
622 |
| - " self._dense_model = None\n", |
| 626 | + "\n", |
| 627 | + " def build_model(self):\n", |
| 628 | + " inp = tf.keras.Input(shape=(5, 5))\n", |
| 629 | + " dense_layer = tf.keras.layers.Dense(\n", |
| 630 | + " 10, name=\"dense\", kernel_regularizer=\"l2\",\n", |
| 631 | + " kernel_initializer=tf.compat.v1.ones_initializer())\n", |
| 632 | + " model = tf.keras.Model(inputs=inp, outputs=dense_layer(inp))\n", |
| 633 | + " return model\n", |
623 | 634 | "\n",
|
624 | 635 | " @tf.compat.v1.keras.utils.track_tf1_style_variables\n",
|
625 | 636 | " def call(self, inputs):\n",
|
626 |
| - " # Create the nested tf.variable/module/layer/model\n", |
627 |
| - " # only if it has not been created already\n", |
628 |
| - " if not self._dense_model:\n", |
629 |
| - " inp = tf.keras.Input(shape=inputs.shape)\n", |
630 |
| - " dense_layer = tf.keras.layers.Dense(\n", |
631 |
| - " self.units, name=\"dense\",\n", |
632 |
| - " kernel_regularizer=\"l2\")\n", |
633 |
| - " self._dense_model = tf.keras.Model(\n", |
634 |
| - " inputs=inp, outputs=dense_layer(inp))\n", |
635 |
| - " return self._dense_model(inputs)\n", |
| 637 | + " # Get or create a nested model without assigning it as an explicit property\n", |
| 638 | + " model = tf.compat.v1.keras.utils.get_or_create_layer(\n", |
| 639 | + " \"dense_model\", self.build_model)\n", |
| 640 | + " return model(inputs)\n", |
636 | 641 | "\n",
|
637 |
| - "layer = WrappedDenseLayer(10)\n", |
638 |
| - "\n", |
639 |
| - "layer(tf.ones(shape=(5, 5)))" |
| 642 | + "layer = NestedModel(10)\n", |
| 643 | + "layer(tf.ones(shape=(5,5)))" |
640 | 644 | ]
|
641 | 645 | },
|
642 | 646 | {
|
643 | 647 | "cell_type": "markdown",
|
644 | 648 | "metadata": {
|
645 |
| - "id": "Lo9h6wc6bmEF" |
| 649 | + "id": "DgsKlltPHI8z" |
646 | 650 | },
|
647 | 651 | "source": [
|
648 |
| - "The weights are correctly tracked:" |
| 652 | + "This method ensures that these nested layers are correctly reused and tracked by tensorflow. Note that the `@track_tf1_style_variables` decorator is still required on the appropriate method. The model builder method passed into `get_or_create_layer` (in this case, `self.build_model`), should take no arguments.\n", |
| 653 | + "\n", |
| 654 | + "Weights are tracked:" |
649 | 655 | ]
|
650 | 656 | },
|
651 | 657 | {
652 | 658 | "cell_type": "code",
653 | 659 | "execution_count": null,
654 | 660 | "metadata": {
655 | | - "id": "Qt6USaTVbauM"
| 661 | + "id": "3zO5A78MJsqO"
656 | 662 | },
657 | 663 | "outputs": [],
658 | 664 | "source": [

667 | 673 | {
668 | 674 | "cell_type": "markdown",
669 | 675 | "metadata": {
670 | | - "id": "oyH4lIcPb45r"
| 676 | + "id": "o3Xsi-JbKTuj"
671 | 677 | },
672 | 678 | "source": [
673 | | - "As is the regularization loss (if present):"
| 679 | + "And regularization loss as well:"
674 | 680 | ]
675 | 681 | },
676 | 682 | {
677 | 683 | "cell_type": "code",
678 | 684 | "execution_count": null,
679 | 685 | "metadata": {
680 | | - "id": "N7cmuhRGbfFt"
| 686 | + "id": "mdK5RGm5KW5C"
681 | 687 | },
682 | 688 | "outputs": [],
683 | 689 | "source": [
684 | | - "regularization_loss = tf.add_n(layer.losses)\n",
685 | | - "regularization_loss"
| 690 | + "tf.add_n(layer.losses)"
686 | 691 | ]
687 | 692 | },
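The reuse guarantee is easy to check directly. Below is a minimal sketch (not part of the original notebook), assuming the `NestedModel` class defined above: because `get_or_create_layer` only runs the builder method on the first call, repeated calls should hand back the very same variable objects.

```python
import tensorflow as tf

# A minimal sketch, assuming the `NestedModel` class defined above.
# `get_or_create_layer` builds the nested model on the first call only,
# so repeated calls should return the very same variable objects.
layer = NestedModel(10)
layer(tf.ones(shape=(5, 5)))
first_call_vars = list(layer.variables)

layer(tf.ones(shape=(5, 5)))
second_call_vars = list(layer.variables)

# Same objects, not freshly created copies of the weights
assert len(first_call_vars) == len(second_call_vars)
assert all(a is b for a, b in zip(first_call_vars, second_call_vars))
```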
|
688 | 693 | {
|
689 | 694 | "cell_type": "markdown",
|
690 | 695 | "metadata": {
|
691 |
| - "id": "FsTgnydkdezQ" |
| 696 | + "id": "J_VRycQYJrXu" |
692 | 697 | },
|
693 | 698 | "source": [
|
694 |
| - "### Guidance on variable names\n", |
| 699 | + "### Incremental migration: `tf.Variables` and `tf.Modules`\n", |
695 | 700 | "\n",
|
696 |
| - "Explicit `tf.Variable` calls and Keras layers use a different layer name / variable name autogeneration mechanism than you may be used to from the combination of `get_variable` and `variable_scopes`. Although the shim will make your variable names match for variables created by `get_variable` even when going from TF1.x graphs to TF2 eager execution & `tf.function`, it cannot guarantee the same for the variable names generated for `tf.Variable` calls and Keras layers that you embed within your method decorators. It is even possible for multiple variables to share the same name in TF2 eager execution and `tf.function`.\n", |
697 |
| - "\n", |
698 |
| - "You should take special care with this when following the sections on validating correctness and mapping TF1.x checkpoints later on in this guide." |
699 |
| - ] |
700 |
| - }, |
701 |
| - { |
702 |
| - "cell_type": "markdown", |
703 |
| - "metadata": { |
704 |
| - "id": "mSFaHTCvhUso" |
705 |
| - }, |
706 |
| - "source": [ |
707 |
| - "### Nesting layers/modules that use `@track_tf1_style_variables`\n", |
| 701 | + "If you need to embed `tf.Variable` calls or `tf.Module`s in your decorated methods (for example, if you are following the incremental migration to non-legacy TF2 APIs described later in this guide), you still need to explicitly track these, with the following requirements:\n", |
| 702 | + "* Explicitly make sure that the variable/module/layer is only created once\n", |
| 703 | + "* Explicitly attach them as instance attributes just as you would when defining a [typical module or layer](https://www.tensorflow.org/guide/intro_to_modules#defining_models_and_layers_in_tensorflow)\n", |
| 704 | + "* Explicitly reuse the already-created object in follow-on calls\n", |
708 | 705 | "\n",
|
709 |
| - "If you are nesting one layer that uses the `@track_tf1_style_variables` decorator inside of another, you should treat it the same way you would treat any Keras layer or `tf.Module` that did not use `get_variable` to create its variables.\n", |
| 706 | + "This ensures that weights are not created new each call and are correctly reused. Additionally, this also ensures that existing weights and regularization losses get tracked.\n", |
710 | 707 | "\n",
|
711 |
| - "For example," |
| 708 | + "Here is an example of how this could look:" |
712 | 709 | ]
|
713 | 710 | },
|
714 | 711 | {
|
715 | 712 | "cell_type": "code",
|
716 | 713 | "execution_count": null,
|
717 | 714 | "metadata": {
|
718 |
| - "id": "SI5V-1JLhTfW" |
| 715 | + "id": "mrRPPoJ5ap5U" |
719 | 716 | },
|
720 | 717 | "outputs": [],
|
721 | 718 | "source": [
|
|
726 | 723 | "    self.units = units\n",
727 | 724 | "\n",
728 | 725 | "  @tf.compat.v1.keras.utils.track_tf1_style_variables\n",
729 | | - "  def call(self, inputs):\n",
| 726 | + "  def __call__(self, inputs):\n",
730 | 727 | "    out = inputs\n",
731 | | - "    with tf.compat.v1.variable_scope(\"dense\"):\n",
| 728 | + "    with tf.compat.v1.variable_scope(\"inner_dense\"):\n",
732 | 729 | "      # The weights are created with a `regularizer`,\n",
733 | 730 | "      # so the layer should track their regularization losses\n",
734 | 731 | "      kernel = tf.compat.v1.get_variable(\n",

762 | 759 | "\n",
763 | 760 | "layer = WrappedDenseLayer(10)\n",
764 | 761 | "\n",
765 | | - "layer(tf.ones(shape=(5, 5)))\n",
| 762 | + "layer(tf.ones(shape=(5, 5)))"
| 763 | + ]
| 764 | + },
| 765 | + {
| 766 | + "cell_type": "markdown",
| 767 | + "metadata": {
| 768 | + "id": "Lo9h6wc6bmEF"
| 769 | + },
| 770 | + "source": [
| 771 | + "Note that explicit tracking of the nested module is needed even though it is decorated with the `track_tf1_style_variables` decorator. This is because each module/layer with decorated methods has its own variable store associated with it.\n",
| 772 | + "\n",
| 773 | + "The weights are correctly tracked:"
| 774 | + ]
| 775 | + },
| 776 | + {
| 777 | + "cell_type": "code",
| 778 | + "execution_count": null,
| 779 | + "metadata": {
| 780 | + "id": "Qt6USaTVbauM"
| 781 | + },
| 782 | + "outputs": [],
| 783 | + "source": [
| 784 | + "assert len(layer.weights) == 6\n",
| 785 | + "weights = {x.name: x for x in layer.variables}\n",
| 786 | + "\n",
| 787 | + "assert set(weights.keys()) == {\"outer/inner_dense/bias:0\",\n",
| 788 | + "                               \"outer/inner_dense/kernel:0\",\n",
| 789 | + "                               \"outer/dense/bias:0\",\n",
| 790 | + "                               \"outer/dense/kernel:0\",\n",
| 791 | + "                               \"outer/dense_1/bias:0\",\n",
| 792 | + "                               \"outer/dense_1/kernel:0\"}\n",
766 | 793 | "\n",
767 | | - "# Recursively track weights and regularization losses\n",
768 | | - "layer.trainable_weights\n",
| 794 | + "layer.trainable_weights"
| 795 | + ]
| 796 | + },
| 797 | + {
| 798 | + "cell_type": "markdown",
| 799 | + "metadata": {
| 800 | + "id": "dHn-bJoNJw7l"
| 801 | + },
| 802 | + "source": [
| 803 | + "As is the regularization loss:"
| 804 | + ]
| 805 | + },
| 806 | + {
| 807 | + "cell_type": "code",
| 808 | + "execution_count": null,
| 809 | + "metadata": {
| 810 | + "id": "pq5GFtXjJyut"
| 811 | + },
| 812 | + "outputs": [],
| 813 | + "source": [
769 | 814 | "layer.losses"
770 | 815 | ]
771 | 816 | },
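The per-instance variable store mentioned above can be demonstrated directly. The following sketch is illustrative (the `TinyShimLayer` class is hypothetical and not part of the guide): two shim-decorated layers do not share weights through `get_variable`, even when their scope and variable names match exactly.

```python
import tensorflow as tf

# A hypothetical minimal layer to illustrate per-instance variable stores:
# `get_variable` does not share weights across two shim-decorated layers,
# even when the scope and variable names match exactly.
class TinyShimLayer(tf.keras.layers.Layer):

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs):
    with tf.compat.v1.variable_scope("dense"):
      kernel = tf.compat.v1.get_variable(
          "kernel", shape=[inputs.shape[-1], 4])
    return tf.matmul(inputs, kernel)

layer_a = TinyShimLayer()
layer_b = TinyShimLayer()
layer_a(tf.ones(shape=(2, 3)))
layer_b(tf.ones(shape=(2, 3)))

# Each instance created its own kernel in its own variable store
assert layer_a.weights[0] is not layer_b.weights[0]
```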
|
772 | 817 | {
|
773 | 818 | "cell_type": "markdown",
|
774 | 819 | "metadata": {
|
775 |
| - "id": "DkEkLnGbipSS" |
| 820 | + "id": "p7VKJj3JOCEk" |
776 | 821 | },
|
777 | 822 | "source": [
|
778 |
| - "Notice that `variable_scope`s set in the outer layer may affect the naming of variables set in the nested layer, *but* `get_variable` will not share variables by name across the outer shim-based layer and the nested shim-based layer even if they have the same name, because the nested and outer layer utilize different internal variable stores." |
| 823 | + "Note that if the `NestedLayer` were a non-Keras `tf.Module` instead, variables would still be tracked but regularization losses would not be automatically tracked, so you would have to explicitly track them separately." |
779 | 824 | ]
|
780 | 825 | },
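One workable pattern for that non-Keras case is sketched below under stated assumptions (the `ShimmedDense` class and its `regularization_loss` attribute are hypothetical, not an official API): compute and store the penalty yourself, since `tf.Module` offers no `losses` property to collect it for you.

```python
import tensorflow as tf

# A hypothetical sketch: a shim-decorated `tf.Module` still tracks its
# variables, but `tf.Module` has no `losses` property, so the
# regularization penalty is computed and stored by hand here.
class ShimmedDense(tf.Module):

  def __init__(self, units, name=None):
    super().__init__(name=name)
    self.units = units
    self.regularization_loss = None

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def __call__(self, inputs):
    with tf.compat.v1.variable_scope("dense"):
      kernel = tf.compat.v1.get_variable(
          "kernel", shape=[inputs.shape[-1], self.units])
      bias = tf.compat.v1.get_variable(
          "bias", shape=[self.units],
          initializer=tf.compat.v1.zeros_initializer())
    # Track the penalty explicitly instead of relying on a `losses` property
    self.regularization_loss = 0.01 * tf.reduce_sum(tf.square(kernel))
    return tf.matmul(inputs, kernel) + bias

module = ShimmedDense(10)
out = module(tf.ones(shape=(5, 5)))

print(len(module.variables))       # variables are still tracked: 2
print(module.regularization_loss)  # the manually tracked penalty
```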
781 | 826 | {
782 | 827 | "cell_type": "markdown",
783 | 828 | "metadata": {
784 | | - "id": "PfbiY08UizLz"
| 829 | + "id": "FsTgnydkdezQ"
785 | 830 | },
786 | 831 | "source": [
787 | | - "As mentioned previously, if you are using a shim-decorated `tf.Module` there is no `losses` property to recursively and automatically track the regularization loss of your nested layer, and you will have to track it separately."
| 832 | + "### Guidance on variable names\n",
| 833 | + "\n",
| 834 | + "Explicit `tf.Variable` calls and Keras layers use a different layer name / variable name autogeneration mechanism than you may be used to from the combination of `get_variable` and `variable_scopes`. Although the shim will make your variable names match for variables created by `get_variable` even when going from TF1.x graphs to TF2 eager execution & `tf.function`, it cannot guarantee the same for the variable names generated for `tf.Variable` calls and Keras layers that you embed within your method decorators. It is even possible for multiple variables to share the same name in TF2 eager execution and `tf.function`.\n",
| 835 | + "\n",
| 836 | + "You should take special care with this when following the sections on validating correctness and mapping TF1.x checkpoints later on in this guide."
788 | 837 | ]
789 | 838 | },
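As a practical aid, a quick sanity check before mapping checkpoints is to look for colliding names. This sketch is illustrative and assumes `layer` is any shim-decorated layer from the examples above:

```python
import collections

# Count the generated variable names and surface any duplicates, since
# uniqueness is no longer guaranteed in TF2 eager execution.
name_counts = collections.Counter(v.name for v in layer.variables)
duplicates = {name: n for name, n in name_counts.items() if n > 1}
print(duplicates or "All variable names are unique.")
```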
790 | 839 | {