|
927 | 927 | " state = rnn_step(input_data[i], state)\n",
|
928 | 928 | " states = states.write(i, state)\n",
|
929 | 929 | " return tf.transpose(states.stack(), [1, 0, 2])\n",
|
930 |
| - " \n", |
| 930 | + "\n", |
931 | 931 | "dynamic_rnn(rnn_step,\n",
|
932 | 932 | " tf.random.uniform([batch_size, seq_len, feature_size]),\n",
|
933 | 933 | " tf.zeros([batch_size, feature_size]))"
|
|
1017 | 1017 | "assert len(external_list) == 1"
|
1018 | 1018 | ]
|
1019 | 1019 | },
|
| 1020 | + { |
| 1021 | + "cell_type": "markdown", |
| 1022 | + "metadata": { |
| 1023 | + "id": "5eZTFRv_k_nR" |
| 1024 | + }, |
| 1025 | + "source": [ |
| 1026 | + "Sometimes unexpected behaviors are very hard to notice. In the example below, the `counter` is intended to safeguard the increment of a variable. However because it is a python integer and not a TensorFlow object, it's value is captured during the first trace. When the `tf.function` is used, the `assign_add` will be recorded unconditionally in the underlying graph. Therefore `v` will increase by 1, every time the `tf.function` is called. This issue is common among users that try to migrate their Grpah-mode Tensorflow code to Tensorflow 2 using `tf.function` decorators, when python side-effects (the `counter` in the example) are used to determine what ops to run (`assign_add` in the example). Usually, users realize this only after seeing suspicious numerical results, or significantly lower performance than expected (e.g. if the guarded operation is very costly)." |
| 1027 | + ] |
| 1028 | + }, |
| 1029 | + { |
| 1030 | + "cell_type": "code", |
| 1031 | + "execution_count": null, |
| 1032 | + "metadata": { |
| 1033 | + "id": "5r6p7-9jk_3L" |
| 1034 | + }, |
| 1035 | + "outputs": [], |
| 1036 | + "source": [ |
| 1037 | + "class Model(tf.Module):\n", |
| 1038 | + " def __init__(self):\n", |
| 1039 | + " self.v = tf.Variable(0)\n", |
| 1040 | + " self.counter = 0\n", |
| 1041 | + "\n", |
| 1042 | + " @tf.function\n", |
| 1043 | + " def __call__(self):\n", |
| 1044 | + " if self.counter == 0:\n", |
| 1045 | + " # A python side-effect\n", |
| 1046 | + " self.counter += 1\n", |
| 1047 | + " self.v.assign_add(1)\n", |
| 1048 | + "\n", |
| 1049 | + " return self.v\n", |
| 1050 | + "\n", |
| 1051 | + "m = Model()\n", |
| 1052 | + "for n in range(3):\n", |
| 1053 | + " print(m().numpy()) # prints 1, 2, 3" |
| 1054 | + ] |
| 1055 | + }, |
| 1056 | + { |
| 1057 | + "cell_type": "markdown", |
| 1058 | + "metadata": { |
| 1059 | + "id": "tXCTcHoVcxhX" |
| 1060 | + }, |
| 1061 | + "source": [ |
| 1062 | + "A workaround to achieve the expected behavior is using [`tf.init_scope`](https://www.tensorflow.org/api_docs/python/tf/init_scope) to lift the operations outside of the function graph. This ensures that the variable increment is only done once during tracing time. It should be noted `init_scope` has other side effects including cleared control flow and gradient tape. Sometimes the usage of `init_scope` can become too complex to manage realistically." |
| 1063 | + ] |
| 1064 | + }, |
| 1065 | + { |
| 1066 | + "cell_type": "code", |
| 1067 | + "execution_count": null, |
| 1068 | + "metadata": { |
| 1069 | + "id": "An4MrIbrcvi8" |
| 1070 | + }, |
| 1071 | + "outputs": [], |
| 1072 | + "source": [ |
| 1073 | + "class Model(tf.Module):\n", |
| 1074 | + " def __init__(self):\n", |
| 1075 | + " self.v = tf.Variable(0)\n", |
| 1076 | + " self.counter = 0\n", |
| 1077 | + "\n", |
| 1078 | + " @tf.function\n", |
| 1079 | + " def __call__(self):\n", |
| 1080 | + " if self.counter == 0:\n", |
| 1081 | + " # Lifts ops out of function-building graphs\n", |
| 1082 | + " with tf.init_scope():\n", |
| 1083 | + " self.counter += 1\n", |
| 1084 | + " self.v.assign_add(1)\n", |
| 1085 | + "\n", |
| 1086 | + " return self.v\n", |
| 1087 | + "\n", |
| 1088 | + "m = Model()\n", |
| 1089 | + "for n in range(3):\n", |
| 1090 | + " print(m().numpy()) # prints 1, 1, 1" |
| 1091 | + ] |
| 1092 | + }, |
1020 | 1093 | {
|
1021 | 1094 | "cell_type": "markdown",
|
1022 | 1095 | "metadata": {
|
1023 | 1096 | "id": "pbFG5CX4LwQA"
|
1024 | 1097 | },
|
1025 | 1098 | "source": [
|
1026 |
| - "You should avoid mutating containers like lists, dicts, other objects that live outside the `Function`. Instead, use arguments and TF objects. For example, the section [\"Accumulating values in a loop\"](#accumulating_values_in_a_loop) has one example of how list-like operations can be implemented.\n", |
| 1099 | + "In summary, as a rule of thumb, you should avoid mutating python objects such as integers or containers like lists that live outside the `Function`. Instead, use arguments and TF objects. For example, the section [\"Accumulating values in a loop\"](#accumulating_values_in_a_loop) has one example of how list-like operations can be implemented.\n", |
1027 | 1100 | "\n",
|
1028 | 1101 | "You can, in some cases, capture and manipulate state if it is a [`tf.Variable`](https://www.tensorflow.org/guide/variable). This is how the weights of Keras models are updated with repeated calls to the same `ConcreteFunction`."
|
1029 | 1102 | ]
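The list-like pattern that the summary above points to can be sketched as follows. This is a minimal illustration, not taken from the notebook itself: it assumes TensorFlow 2 is installed, and uses `tf.TensorArray` as the graph-friendly stand-in for appending to a Python list inside a `tf.function`.

```python
import tensorflow as tf

@tf.function
def accumulate(n):
    # A Python list appended to inside a tf.function would be mutated
    # only at trace time; tf.TensorArray is the graph-safe alternative.
    ta = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
    for i in tf.range(n):
        ta = ta.write(i, i * i)
    # Stack the accumulated elements into a single tensor.
    return ta.stack()

print(accumulate(tf.constant(5)).numpy())  # squares: 0, 1, 4, 9, 16
```

Because `ta.write` returns a new `TensorArray` handle, reassigning `ta` inside the loop is what lets AutoGraph thread the state through the converted `tf.while_loop`.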
|
|
1625 | 1698 | "colab": {
|
1626 | 1699 | "collapsed_sections": [],
|
1627 | 1700 | "name": "function.ipynb",
|
| 1701 | + "provenance": [], |
1628 | 1702 | "toc_visible": true
|
1629 | 1703 | },
|
1630 | 1704 | "kernelspec": {
|
|