Commit 3756d94

Automated rollback of commit 49eca1f
PiperOrigin-RevId: 405918182
1 parent 11c42a8 commit 3756d94

File tree

1 file changed

+76
-2
lines changed

site/en/guide/function.ipynb

Lines changed: 76 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -927,7 +927,7 @@
927927
" state = rnn_step(input_data[i], state)\n",
928928
" states = states.write(i, state)\n",
929929
" return tf.transpose(states.stack(), [1, 0, 2])\n",
930-
" \n",
930+
"\n",
931931
"dynamic_rnn(rnn_step,\n",
932932
" tf.random.uniform([batch_size, seq_len, feature_size]),\n",
933933
" tf.zeros([batch_size, feature_size]))"
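For context, the cell touched by this hunk accumulates per-step RNN states in a `tf.TensorArray` inside a `tf.function`. A minimal runnable sketch of that pattern, with `rnn_step` as a stand-in cell and example dimensions not taken from the diff, might look like:

```python
import tensorflow as tf

batch_size, seq_len, feature_size = 2, 3, 4

def rnn_step(inp, state):
    # Stand-in for a real RNN cell: just mix the input into the state.
    return inp + state

@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
    # input_data is [batch, time, features]; move time to the front to loop over it.
    input_data = tf.transpose(input_data, [1, 0, 2])
    max_seq_len = input_data.shape[0]
    # TensorArray is the graph-friendly way to accumulate values in a loop.
    states = tf.TensorArray(tf.float32, size=max_seq_len)
    state = initial_state
    for i in tf.range(max_seq_len):
        state = rnn_step(input_data[i], state)
        states = states.write(i, state)
    # Stack back to [time, batch, features], then restore [batch, time, features].
    return tf.transpose(states.stack(), [1, 0, 2])

out = dynamic_rnn(rnn_step,
                  tf.random.uniform([batch_size, seq_len, feature_size]),
                  tf.zeros([batch_size, feature_size]))
print(out.shape)  # (2, 3, 4)
```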
@@ -1017,13 +1017,86 @@
10171017
"assert len(external_list) == 1"
10181018
]
10191019
},
1020+
{
1021+
"cell_type": "markdown",
1022+
"metadata": {
1023+
"id": "5eZTFRv_k_nR"
1024+
},
1025+
"source": [
1026+
"Sometimes unexpected behaviors are very hard to notice. In the example below, the `counter` is intended to safeguard the increment of a variable. However, because it is a Python integer and not a TensorFlow object, its value is captured during the first trace. When the `tf.function` is used, the `assign_add` will be recorded unconditionally in the underlying graph. Therefore `v` will increase by 1 every time the `tf.function` is called. This issue is common among users who try to migrate their graph-mode TensorFlow code to TensorFlow 2 using `tf.function` decorators, when Python side effects (the `counter` in the example) are used to determine what ops to run (`assign_add` in the example). Usually, users realize this only after seeing suspicious numerical results, or significantly lower performance than expected (e.g. if the guarded operation is very costly)."
1027+
]
1028+
},
1029+
{
1030+
"cell_type": "code",
1031+
"execution_count": null,
1032+
"metadata": {
1033+
"id": "5r6p7-9jk_3L"
1034+
},
1035+
"outputs": [],
1036+
"source": [
1037+
"class Model(tf.Module):\n",
1038+
" def __init__(self):\n",
1039+
" self.v = tf.Variable(0)\n",
1040+
" self.counter = 0\n",
1041+
"\n",
1042+
" @tf.function\n",
1043+
" def __call__(self):\n",
1044+
" if self.counter == 0:\n",
1045+
" # A python side-effect\n",
1046+
" self.counter += 1\n",
1047+
" self.v.assign_add(1)\n",
1048+
"\n",
1049+
" return self.v\n",
1050+
"\n",
1051+
"m = Model()\n",
1052+
"for n in range(3):\n",
1053+
" print(m().numpy()) # prints 1, 2, 3"
1054+
]
1055+
},
1056+
{
1057+
"cell_type": "markdown",
1058+
"metadata": {
1059+
"id": "tXCTcHoVcxhX"
1060+
},
1061+
"source": [
1062+
"A workaround to achieve the expected behavior is to use [`tf.init_scope`](https://www.tensorflow.org/api_docs/python/tf/init_scope) to lift the operations outside of the function graph. This ensures that the variable increment happens only once, at tracing time. Note that `init_scope` has other side effects, including clearing control flow and the gradient tape. Sometimes the usage of `init_scope` can become too complex to manage realistically."
1063+
]
1064+
},
1065+
{
1066+
"cell_type": "code",
1067+
"execution_count": null,
1068+
"metadata": {
1069+
"id": "An4MrIbrcvi8"
1070+
},
1071+
"outputs": [],
1072+
"source": [
1073+
"class Model(tf.Module):\n",
1074+
" def __init__(self):\n",
1075+
" self.v = tf.Variable(0)\n",
1076+
" self.counter = 0\n",
1077+
"\n",
1078+
" @tf.function\n",
1079+
" def __call__(self):\n",
1080+
" if self.counter == 0:\n",
1081+
" # Lifts ops out of function-building graphs\n",
1082+
" with tf.init_scope():\n",
1083+
" self.counter += 1\n",
1084+
" self.v.assign_add(1)\n",
1085+
"\n",
1086+
" return self.v\n",
1087+
"\n",
1088+
"m = Model()\n",
1089+
"for n in range(3):\n",
1090+
" print(m().numpy()) # prints 1, 1, 1"
1091+
]
1092+
},
10201093
{
10211094
"cell_type": "markdown",
10221095
"metadata": {
10231096
"id": "pbFG5CX4LwQA"
10241097
},
10251098
"source": [
1026-
"You should avoid mutating containers like lists, dicts, other objects that live outside the `Function`. Instead, use arguments and TF objects. For example, the section [\"Accumulating values in a loop\"](#accumulating_values_in_a_loop) has one example of how list-like operations can be implemented.\n",
1099+
"In summary, as a rule of thumb, you should avoid mutating Python objects such as integers or containers like lists that live outside the `Function`. Instead, use arguments and TF objects. For example, the section [\"Accumulating values in a loop\"](#accumulating_values_in_a_loop) has one example of how list-like operations can be implemented.\n",
10271100
"\n",
10281101
"You can, in some cases, capture and manipulate state if it is a [`tf.Variable`](https://www.tensorflow.org/guide/variable). This is how the weights of Keras models are updated with repeated calls to the same `ConcreteFunction`."
10291102
]
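Since the takeaway of the added cells is to keep mutable state in TF objects rather than in Python attributes, a minimal sketch (not part of this diff) of the same guard rewritten with a `tf.Variable` counter might look like this. With a variable, AutoGraph converts the `if` into a `tf.cond` that is evaluated on every call, so the guard behaves as intended without `tf.init_scope`:

```python
import tensorflow as tf

class Model(tf.Module):
    def __init__(self):
        self.v = tf.Variable(0)
        # A tf.Variable instead of a Python int: the condition below becomes
        # part of the graph rather than being frozen at trace time.
        self.counter = tf.Variable(0)

    @tf.function
    def __call__(self):
        if self.counter == 0:  # traced as tf.cond, checked on every call
            self.counter.assign_add(1)
            self.v.assign_add(1)
        return self.v

m = Model()
for n in range(3):
    print(m().numpy())  # prints 1, 1, 1
```

Unlike the `tf.init_scope` workaround, this keeps the guard inside the graph, so it also behaves correctly if the function is retraced or saved.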
@@ -1625,6 +1698,7 @@
16251698
"colab": {
16261699
"collapsed_sections": [],
16271700
"name": "function.ipynb",
1701+
"provenance": [],
16281702
"toc_visible": true
16291703
},
16301704
"kernelspec": {