|
1069 | 1069 | "id": "e1I0dPiqTV8H"
|
1070 | 1070 | },
|
1071 | 1071 | "source": [
|
1072 |
| - "If you would like to execute Python code during each invocation of a `Function`, `tf.py_function` is an exit hatch. The drawback of `tf.py_function` is that it's not portable or particularly performant, cannot be saved with SavedModel, and does not work well in distributed (multi-GPU, TPU) setups. Also, since `tf.py_function` has to be wired into the graph, it casts all inputs/outputs to tensors." |
| 1072 | + "If you would like to execute Python code during each invocation of a `Function`, `tf. py_function` is an exit hatch. The drawbacks of `tf.py_function` are that it's not portable or particularly performant, cannot be saved with SavedModel, and does not work well in distributed (multi-GPU, TPU) setups. Also, since `tf.py_function` has to be wired into the graph, it casts all inputs/outputs to tensors." |
| 1073 | + ] |
| 1074 | + }, |
| 1075 | + { |
| 1076 | + "cell_type": "code", |
| 1077 | + "execution_count": null, |
| 1078 | + "metadata": { |
| 1079 | + "id": "ZbI7XA_e6yA2" |
| 1080 | + }, |
| 1081 | + "outputs": [], |
| 1082 | + "source": [ |
| 1083 | + "@tf.py_function(Tout=tf.float32)\n", |
| 1084 | + "def py_plus(x, y):\n", |
| 1085 | + " print('Executing eagerly.')\n", |
| 1086 | + " return x + y\n", |
| 1087 | + "\n", |
| 1088 | + "@tf.function\n", |
| 1089 | + "def tf_wrapper(x, y):\n", |
| 1090 | + " print('Tracing.')\n", |
| 1091 | + " return py_plus(x, y)" |
| 1092 | + ] |
| 1093 | + }, |
| 1094 | + { |
| 1095 | + "cell_type": "markdown", |
| 1096 | + "metadata": { |
| 1097 | + "id": "h5ttN_sI7TdQ" |
| 1098 | + }, |
| 1099 | + "source": [ |
| 1100 | + "The `tf.function` will trace the first time:" |
| 1101 | + ] |
| 1102 | + }, |
| 1103 | + { |
| 1104 | + "cell_type": "code", |
| 1105 | + "execution_count": null, |
| 1106 | + "metadata": { |
| 1107 | + "id": "mAK4XINl7Ldy" |
| 1108 | + }, |
| 1109 | + "outputs": [], |
| 1110 | + "source": [ |
| 1111 | + "tf_wrapper(tf.constant(1.0), tf.constant(2.0)).numpy()" |
| 1112 | + ] |
| 1113 | + }, |
| 1114 | + { |
| 1115 | + "cell_type": "markdown", |
| 1116 | + "metadata": { |
| 1117 | + "id": "Atxvrd_o7dSy" |
| 1118 | + }, |
| 1119 | + "source": [ |
| 1120 | + "But the `tf.py_function` inside executes eagerly every time:" |
| 1121 | + ] |
| 1122 | + }, |
| 1123 | + { |
| 1124 | + "cell_type": "code", |
| 1125 | + "execution_count": null, |
| 1126 | + "metadata": { |
| 1127 | + "id": "vv7qTiTU7bjy" |
| 1128 | + }, |
| 1129 | + "outputs": [], |
| 1130 | + "source": [ |
| 1131 | + "tf_wrapper(tf.constant(1.0), tf.constant(2.0)).numpy()" |
1073 | 1132 | ]
|
1074 | 1133 | },
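| | + {
| | + "cell_type": "markdown",
| | + "metadata": {},
| | + "source": [
| | + "As noted above, `tf.py_function` casts all inputs and outputs to tensors. The two cells below are an illustrative sketch added here (the `show_types` helper is not part of the original guide): even a plain Python float arrives inside the function as a tensor, and the result comes back as one too."
| | + ]
| | + },
| | + {
| | + "cell_type": "code",
| | + "execution_count": null,
| | + "metadata": {},
| | + "outputs": [],
| | + "source": [
| | + "@tf.py_function(Tout=tf.float32)\n",
| | + "def show_types(x):\n",
| | + "  # `x` is a tf.Tensor here even though a Python float was passed in.\n",
| | + "  print('Input type:', type(x))\n",
| | + "  return x\n",
| | + "\n",
| | + "result = show_types(1.0)\n",
| | + "# The output is likewise a tensor, not a Python float.\n",
| | + "print('Output type:', type(result))"
| | + ]
| | + },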
|
1075 | 1134 | {
|
|
1778 | 1837 | "id": "7Q8BRPCThTjB"
|
1779 | 1838 | },
|
1780 | 1839 | "source": [
|
1781 |
| - "If you need to change the optimizer during training, a workaround is to create a new `Function` for each optimizer, calling the [`ConcreteFunction`](#obtaining_concrete_functions) directly." |
| 1840 | + "If you need to change a stateful object between calls, it's simplest to define a `tf.Module` subclass, and create instances to hold those objects:" |
| 1841 | + ] |
| 1842 | + }, |
| 1843 | + { |
| 1844 | + "cell_type": "code", |
| 1845 | + "execution_count": null, |
| 1846 | + "metadata": { |
| 1847 | + "id": "3P59ocmIslHz" |
| 1848 | + }, |
| 1849 | + "outputs": [], |
| 1850 | + "source": [ |
| 1851 | + "class TrainStep(tf.Module):\n", |
| 1852 | + " def __init__(self, optimizer):\n", |
| 1853 | + " self.optimizer = optimizer\n", |
| 1854 | + "\n", |
| 1855 | + " @tf.function\n", |
| 1856 | + " def __call__(self, w, x, y):\n", |
| 1857 | + " with tf.GradientTape() as tape:\n", |
| 1858 | + " L = tf.reduce_sum(tf.square(w*x - y))\n", |
| 1859 | + " gradients = tape.gradient(L, [w])\n", |
| 1860 | + " self.optimizer.apply_gradients(zip(gradients, [w]))\n", |
| 1861 | + "\n", |
| 1862 | + "\n", |
| 1863 | + "opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)\n", |
| 1864 | + "opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)\n", |
| 1865 | + "\n", |
| 1866 | + "train_o1 = TrainStep(opt1)\n", |
| 1867 | + "train_o2 = TrainStep(opt2)\n", |
| 1868 | + "\n", |
| 1869 | + "train_o1(w, x, y)\n", |
| 1870 | + "train_o2(w, x, y)" |
| 1871 | + ] |
| 1872 | + }, |
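| | + {
| | + "cell_type": "markdown",
| | + "metadata": {},
| | + "source": [
| | + "Since each `TrainStep` instance is its own `tf.Module`, the two optimizers (and their internal state) stay fully independent. A quick check, added here for illustration:"
| | + ]
| | + },
| | + {
| | + "cell_type": "code",
| | + "execution_count": null,
| | + "metadata": {},
| | + "outputs": [],
| | + "source": [
| | + "# Each instance holds its own optimizer with its own hyperparameters.\n",
| | + "print(float(train_o1.optimizer.learning_rate))  # 0.01\n",
| | + "print(float(train_o2.optimizer.learning_rate))  # 0.001"
| | + ]
| | + },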
| 1873 | + { |
| 1874 | + "cell_type": "markdown", |
| 1875 | + "metadata": { |
| 1876 | + "id": "dUHUi881smHF" |
| 1877 | + }, |
| 1878 | + "source": [ |
| 1879 | + "You could also do this manually by creating multiple instances of the `@tf.function` wrapper, one for each optimizer:" |
1782 | 1880 | ]
|
1783 | 1881 | },
|
1784 | 1882 | {
|
|
1841 | 1939 | "metadata": {
|
1842 | 1940 | "colab": {
|
1843 | 1941 | "name": "function.ipynb",
|
| 1942 | + "private_outputs": true, |
| 1943 | + "provenance": [], |
1844 | 1944 | "toc_visible": true
|
1845 | 1945 | },
|
1846 | 1946 | "kernelspec": {
|
|