Commit 3ec9a34

Update section on nested models in migration shim guide.
PiperOrigin-RevId: 411117072
1 parent 69ddbac commit 3ec9a34

site/en/guide/migrate/model_mapping.ipynb

Lines changed: 110 additions & 61 deletions
@@ -106,7 +106,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 2,
 "metadata": {
 "id": "PzkV-2cna823"
 },
@@ -594,65 +594,71 @@
 "source": [
 "## Nesting `tf.Variable`s, `tf.Module`s, `tf.keras.layers` & `tf.keras.models` in decorated calls\n",
 "\n",
-"Decorating your layer call in `tf.compat.v1.keras.utils.track_tf1_style_variables` will only add automatic implicit tracking of variables created (and reused) via `tf.compat.v1.get_variable`. It will not capture weights directly created by `tf.Variable` calls, such as those used by typical Keras layers and most `tf.Module`s. You still need to explicitly track these in the same way you would for any other Keras layer or `tf.Module`.\n",
-"\n",
-"If you need to embed `tf.Variable` calls, Keras layers/models, or `tf.Module`s in your decorators (either because you are following the incremental migration to Native TF2 described later in this guide, or because your TF1.x code partially consisted of Keras modules):\n",
-"* Explicitly make sure that the variable/module/layer is only created once\n",
-"* Explicitly attach them as instance attributes just as you would when defining a [typical module/layer](https://www.tensorflow.org/guide/intro_to_modules#defining_models_and_layers_in_tensorflow)\n",
-"* Explicitly reuse the already-created object in follow-on calls\n",
+"Decorating your layer call in `tf.compat.v1.keras.utils.track_tf1_style_variables` will only add automatic implicit tracking of variables created (and reused) via `tf.compat.v1.get_variable`. It will not capture weights directly created by `tf.Variable` calls, such as those used by typical Keras layers and most `tf.Module`s. This section describes how to handle these nested cases.\n"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {
+"id": "Azxza3bVOZlv"
+},
+"source": [
+"### (Pre-existing usages) `tf.keras.layers` and `tf.keras.models`\n",
 "\n",
-"This ensures that weights are not created new and are correctly resued. Additionally, this also ensures that existing weights and regularization losses get tracked.\n",
+"For pre-existing usages of nested Keras layers and models, use `tf.compat.v1.keras.utils.get_or_create_layer`. This is only recommended for easing migration of existing TF1.x nested Keras usages; new code should use explicit attribute setting as described below for tf.Variables and tf.Modules.\n",
 "\n",
-"Here is an example of how this could look:"
+"To use `tf.compat.v1.keras.utils.get_or_create_layer`, wrap the code that constructs your nested model into a method, and pass the method in to `get_or_create_layer`. Example:"
 ]
 },
 {
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"id": "mrRPPoJ5ap5U"
+"id": "LN15TcRgHKsq"
 },
 "outputs": [],
 "source": [
-"class WrappedDenseLayer(tf.keras.layers.Layer):\n",
+"class NestedModel(tf.keras.Model):\n",
 "\n",
-"  def __init__(self, units, **kwargs):\n",
-"    super().__init__(**kwargs)\n",
+"  def __init__(self, units, *args, **kwargs):\n",
+"    super().__init__(*args, **kwargs)\n",
 "    self.units = units\n",
-"    self._dense_model = None\n",
+"\n",
+"  def build_model(self):\n",
+"    inp = tf.keras.Input(shape=(5, 5))\n",
+"    dense_layer = tf.keras.layers.Dense(\n",
+"        10, name=\"dense\", kernel_regularizer=\"l2\",\n",
+"        kernel_initializer=tf.compat.v1.ones_initializer())\n",
+"    model = tf.keras.Model(inputs=inp, outputs=dense_layer(inp))\n",
+"    return model\n",
 "\n",
 "  @tf.compat.v1.keras.utils.track_tf1_style_variables\n",
 "  def call(self, inputs):\n",
-"    # Create the nested tf.variable/module/layer/model\n",
-"    # only if it has not been created already\n",
-"    if not self._dense_model:\n",
-"      inp = tf.keras.Input(shape=inputs.shape)\n",
-"      dense_layer = tf.keras.layers.Dense(\n",
-"          self.units, name=\"dense\",\n",
-"          kernel_regularizer=\"l2\")\n",
-"      self._dense_model = tf.keras.Model(\n",
-"          inputs=inp, outputs=dense_layer(inp))\n",
-"    return self._dense_model(inputs)\n",
+"    # Get or create a nested model without assigning it as an explicit property\n",
+"    model = tf.compat.v1.keras.utils.get_or_create_layer(\n",
+"        \"dense_model\", self.build_model)\n",
+"    return model(inputs)\n",
 "\n",
-"layer = WrappedDenseLayer(10)\n",
-"\n",
-"layer(tf.ones(shape=(5, 5)))"
+"layer = NestedModel(10)\n",
+"layer(tf.ones(shape=(5, 5)))"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"id": "Lo9h6wc6bmEF"
+"id": "DgsKlltPHI8z"
 },
 "source": [
-"The weights are correctly tracked:"
+"This method ensures that these nested layers are correctly reused and tracked by TensorFlow. Note that the `@track_tf1_style_variables` decorator is still required on the appropriate method. The model builder method passed into `get_or_create_layer` (in this case, `self.build_model`) should take no arguments.\n",
+"\n",
+"Weights are tracked:"
 ]
 },
 {
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"id": "Qt6USaTVbauM"
+"id": "3zO5A78MJsqO"
 },
 "outputs": [],
 "source": [
@@ -667,55 +673,46 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"id": "oyH4lIcPb45r"
+"id": "o3Xsi-JbKTuj"
 },
 "source": [
-"As is the regularization loss (if present):"
+"And the regularization loss as well:"
 ]
 },
 {
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"id": "N7cmuhRGbfFt"
+"id": "mdK5RGm5KW5C"
 },
 "outputs": [],
 "source": [
-"regularization_loss = tf.add_n(layer.losses)\n",
-"regularization_loss"
+"tf.add_n(layer.losses)"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"id": "FsTgnydkdezQ"
+"id": "J_VRycQYJrXu"
 },
 "source": [
-"### Guidance on variable names\n",
+"### Incremental migration: `tf.Variables` and `tf.Modules`\n",
 "\n",
-"Explicit `tf.Variable` calls and Keras layers use a different layer name / variable name autogeneration mechanism than you may be used to from the combination of `get_variable` and `variable_scopes`. Although the shim will make your variable names match for variables created by `get_variable` even when going from TF1.x graphs to TF2 eager execution & `tf.function`, it cannot guarantee the same for the variable names generated for `tf.Variable` calls and Keras layers that you embed within your method decorators. It is even possible for multiple variables to share the same name in TF2 eager execution and `tf.function`.\n",
-"\n",
-"You should take special care with this when following the sections on validating correctness and mapping TF1.x checkpoints later on in this guide."
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"id": "mSFaHTCvhUso"
-},
-"source": [
-"### Nesting layers/modules that use `@track_tf1_style_variables`\n",
+"If you need to embed `tf.Variable` calls or `tf.Module`s in your decorated methods (for example, if you are following the incremental migration to non-legacy TF2 APIs described later in this guide), you still need to explicitly track these, with the following requirements:\n",
+"* Explicitly make sure that the variable/module/layer is only created once\n",
+"* Explicitly attach them as instance attributes just as you would when defining a [typical module or layer](https://www.tensorflow.org/guide/intro_to_modules#defining_models_and_layers_in_tensorflow)\n",
+"* Explicitly reuse the already-created object in follow-on calls\n",
 "\n",
-"If you are nesting one layer that uses the `@track_tf1_style_variables` decorator inside of another, you should treat it the same way you would treat any Keras layer or `tf.Module` that did not use `get_variable` to create its variables.\n",
+"This ensures that weights are not created anew on each call and are correctly reused. Additionally, this also ensures that existing weights and regularization losses get tracked.\n",
 "\n",
-"For example,"
+"Here is an example of how this could look:"
 ]
 },
 {
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"id": "SI5V-1JLhTfW"
+"id": "mrRPPoJ5ap5U"
 },
 "outputs": [],
 "source": [
@@ -726,9 +723,9 @@
 "    self.units = units\n",
 "\n",
 "  @tf.compat.v1.keras.utils.track_tf1_style_variables\n",
-"  def call(self, inputs):\n",
+"  def __call__(self, inputs):\n",
 "    out = inputs\n",
-"    with tf.compat.v1.variable_scope(\"dense\"):\n",
+"    with tf.compat.v1.variable_scope(\"inner_dense\"):\n",
 "      # The weights are created with a `regularizer`,\n",
 "      # so the layer should track their regularization losses\n",
 "      kernel = tf.compat.v1.get_variable(\n",
@@ -762,29 +759,81 @@
 "\n",
 "layer = WrappedDenseLayer(10)\n",
 "\n",
-"layer(tf.ones(shape=(5, 5)))\n",
+"layer(tf.ones(shape=(5, 5)))"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {
+"id": "Lo9h6wc6bmEF"
+},
+"source": [
+"Note that explicit tracking of the nested module is needed even though it is decorated with the `track_tf1_style_variables` decorator. This is because each module/layer with decorated methods has its own variable store associated with it.\n",
+"\n",
+"The weights are correctly tracked:"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"id": "Qt6USaTVbauM"
+},
+"outputs": [],
+"source": [
+"assert len(layer.weights) == 6\n",
+"weights = {x.name: x for x in layer.variables}\n",
+"\n",
+"assert set(weights.keys()) == {\"outer/inner_dense/bias:0\",\n",
+"                               \"outer/inner_dense/kernel:0\",\n",
+"                               \"outer/dense/bias:0\",\n",
+"                               \"outer/dense/kernel:0\",\n",
+"                               \"outer/dense_1/bias:0\",\n",
+"                               \"outer/dense_1/kernel:0\"}\n",
 "\n",
-"# Recursively track weights and regularization losses\n",
-"layer.trainable_weights\n",
+"layer.trainable_weights"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {
+"id": "dHn-bJoNJw7l"
+},
+"source": [
+"As well as the regularization loss:"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"id": "pq5GFtXjJyut"
+},
+"outputs": [],
+"source": [
 "layer.losses"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"id": "DkEkLnGbipSS"
+"id": "p7VKJj3JOCEk"
 },
 "source": [
-"Notice that `variable_scope`s set in the outer layer may affect the naming of variables set in the nested layer, *but* `get_variable` will not share variables by name across the outer shim-based layer and the nested shim-based layer even if they have the same name, because the nested and outer layer utilize different internal variable stores."
+"Note that if the `NestedLayer` were a non-Keras `tf.Module` instead, variables would still be tracked but regularization losses would not be automatically tracked, so you would have to explicitly track them separately."
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"id": "PfbiY08UizLz"
+"id": "FsTgnydkdezQ"
 },
 "source": [
-"As mentioned previously, if you are using a shim-decorated `tf.Module` there is no `losses` property to recursively and automatically track the regularization loss of your nested layer, and you will have to track it separately."
+"### Guidance on variable names\n",
+"\n",
+"Explicit `tf.Variable` calls and Keras layers use a different layer name / variable name autogeneration mechanism than you may be used to from the combination of `get_variable` and `variable_scopes`. Although the shim will make your variable names match for variables created by `get_variable` even when going from TF1.x graphs to TF2 eager execution & `tf.function`, it cannot guarantee the same for the variable names generated for `tf.Variable` calls and Keras layers that you embed within your method decorators. It is even possible for multiple variables to share the same name in TF2 eager execution and `tf.function`.\n",
+"\n",
+"You should take special care with this when following the sections on validating correctness and mapping TF1.x checkpoints later on in this guide."
 ]
 },
 {

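A note on the changed markdown cell above about non-Keras `tf.Module`s: a `tf.Module` has no Keras `losses` property, so regularization losses must be tracked by hand. A minimal sketch of what that could look like, assuming a shim-decorated module (the `DenseModule` class, the "dense" scope name, and the 0.01 L2 factor are illustrative, not part of this commit):

import tensorflow as tf

class DenseModule(tf.Module):
  """Hypothetical shim-decorated module: variables are tracked, losses are not."""

  def __init__(self, units, name=None):
    super().__init__(name=name)
    self.units = units

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def __call__(self, inputs):
    with tf.compat.v1.variable_scope("dense"):
      kernel = tf.compat.v1.get_variable(
          "kernel", shape=[inputs.shape[-1], self.units])
      bias = tf.compat.v1.get_variable(
          "bias", shape=[self.units],
          initializer=tf.compat.v1.zeros_initializer())
    # There is no automatically populated `losses` collection on a
    # tf.Module, so store the regularization term explicitly.
    self.regularization_loss = 0.01 * tf.reduce_sum(tf.square(kernel))
    return tf.matmul(inputs, kernel) + bias

module = DenseModule(8)
out = module(tf.ones([2, 4]))
print(len(module.variables))       # 2: kernel and bias are tracked
print(module.regularization_loss)  # tracked by hand, not via `losses`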
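Similarly, for the "Guidance on variable names" cell retained at the end of the diff: in TF2 eager execution the names passed to `tf.Variable` are not uniquified the way `get_variable` names are in a TF1.x graph, so distinct variables can report the same name. A short illustrative check (the name "v" is arbitrary):

import tensorflow as tf

v1 = tf.Variable(1.0, name="v")
v2 = tf.Variable(2.0, name="v")
# Both report "v:0" in eager execution; they are still distinct objects.
print(v1.name, v2.name)
assert v1 is not v2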