
Commit 16f268f

MarkDaoust authored and copybara-github committed
Use py_function decorators.
PiperOrigin-RevId: 568972897
1 parent 2176ed2 commit 16f268f

4 files changed: +210 -297 lines changed
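
Note: every hunk below applies the same pattern named in the commit message. The old call style, tf.py_function(func, inp, Tout), threads inputs and output dtypes through each call site; the decorator style binds Tout once, so the wrapped function is called directly. A minimal sketch of the two styles (the function names and values are illustrative, not taken from the commit):

import tensorflow as tf

# Old call style: pass the function, inputs, and output dtype explicitly.
def plus_one(x):
  return x + 1.0

y_old = tf.py_function(plus_one, inp=[tf.constant(1.0)], Tout=tf.float32)

# New decorator style: bind Tout up front; call the function directly.
@tf.py_function(Tout=tf.float32)
def plus_one_eager(x):
  return x + 1.0  # runs eagerly, even when called from a traced tf.function

y_new = plus_one_eager(tf.constant(1.0))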

site/en/guide/data.ipynb

Lines changed: 9 additions & 9 deletions
@@ -15,7 +15,6 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
-    "cellView": "form",
     "id": "llMNufAK7nfK"
    },
    "outputs": [],
@@ -589,7 +588,7 @@
    "source": [
     "The first output is an `int32` the second is a `float32`.\n",
     "\n",
-    "The first item is a scalar, shape `()`, and the second is a vector of unknown length, shape `(None,)` "
+    "The first item is a scalar, shape `()`, and the second is a vector of unknown length, shape `(None,)`"
    ]
   },
   {
@@ -601,8 +600,8 @@
    "outputs": [],
    "source": [
     "ds_series = tf.data.Dataset.from_generator(\n",
-    "    gen_series, \n",
-    "    output_types=(tf.int32, tf.float32), \n",
+    "    gen_series,\n",
+    "    output_types=(tf.int32, tf.float32),\n",
     "    output_shapes=((), (None,)))\n",
     "\n",
     "ds_series"
@@ -710,8 +709,8 @@
    "outputs": [],
    "source": [
     "ds = tf.data.Dataset.from_generator(\n",
-    "    lambda: img_gen.flow_from_directory(flowers), \n",
-    "    output_types=(tf.float32, tf.float32), \n",
+    "    lambda: img_gen.flow_from_directory(flowers),\n",
+    "    output_types=(tf.float32, tf.float32),\n",
     "    output_shapes=([32,256,256,3], [32,5])\n",
     ")\n",
     "\n",
@@ -1932,6 +1931,7 @@
    "source": [
     "import scipy.ndimage as ndimage\n",
     "\n",
+    "@tf.py_function(Tout=tf.float32)\n",
     "def random_rotate_image(image):\n",
     "  image = ndimage.rotate(image, np.random.uniform(-30, 30), reshape=False)\n",
     "  return image"
@@ -1969,7 +1969,7 @@
    "source": [
     "def tf_random_rotate_image(image, label):\n",
     "  im_shape = image.shape\n",
-    "  [image,] = tf.py_function(random_rotate_image, [image], [tf.float32])\n",
+    "  image = random_rotate_image(image)\n",
     "  image.set_shape(im_shape)\n",
     "  return image, label"
    ]
@@ -2819,7 +2819,7 @@
     "])\n",
     "\n",
     "model.compile(optimizer='adam',\n",
-    "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), \n",
+    "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
     "              metrics=['accuracy'])"
    ]
   },
@@ -2953,8 +2953,8 @@
   ],
   "metadata": {
    "colab": {
-    "collapsed_sections": [],
     "name": "data.ipynb",
+    "provenance": [],
     "toc_visible": true
    },
    "kernelspec": {

site/en/guide/function.ipynb

Lines changed: 102 additions & 2 deletions
@@ -1069,7 +1069,66 @@
     "id": "e1I0dPiqTV8H"
    },
    "source": [
-    "If you would like to execute Python code during each invocation of a `Function`, `tf.py_function` is an exit hatch. The drawback of `tf.py_function` is that it's not portable or particularly performant, cannot be saved with SavedModel, and does not work well in distributed (multi-GPU, TPU) setups. Also, since `tf.py_function` has to be wired into the graph, it casts all inputs/outputs to tensors."
+    "If you would like to execute Python code during each invocation of a `Function`, `tf.py_function` is an exit hatch. The drawbacks of `tf.py_function` are that it's not portable or particularly performant, cannot be saved with SavedModel, and does not work well in distributed (multi-GPU, TPU) setups. Also, since `tf.py_function` has to be wired into the graph, it casts all inputs/outputs to tensors."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "ZbI7XA_e6yA2"
+   },
+   "outputs": [],
+   "source": [
+    "@tf.py_function(Tout=tf.float32)\n",
+    "def py_plus(x, y):\n",
+    "  print('Executing eagerly.')\n",
+    "  return x + y\n",
+    "\n",
+    "@tf.function\n",
+    "def tf_wrapper(x, y):\n",
+    "  print('Tracing.')\n",
+    "  return py_plus(x, y)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "h5ttN_sI7TdQ"
+   },
+   "source": [
+    "The `tf.function` will trace the first time:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "mAK4XINl7Ldy"
+   },
+   "outputs": [],
+   "source": [
+    "tf_wrapper(tf.constant(1.0), tf.constant(2.0)).numpy()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "Atxvrd_o7dSy"
+   },
+   "source": [
+    "But the `tf.py_function` inside executes eagerly every time:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "vv7qTiTU7bjy"
+   },
+   "outputs": [],
+   "source": [
+    "tf_wrapper(tf.constant(1.0), tf.constant(2.0)).numpy()"
    ]
   },
   {
@@ -1778,7 +1837,46 @@
     "id": "7Q8BRPCThTjB"
    },
    "source": [
-    "If you need to change the optimizer during training, a workaround is to create a new `Function` for each optimizer, calling the [`ConcreteFunction`](#obtaining_concrete_functions) directly."
+    "If you need to change a stateful object between calls, it's simplest to define a `tf.Module` subclass, and create instances to hold those objects:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "3P59ocmIslHz"
+   },
+   "outputs": [],
+   "source": [
+    "class TrainStep(tf.Module):\n",
+    "  def __init__(self, optimizer):\n",
+    "    self.optimizer = optimizer\n",
+    "\n",
+    "  @tf.function\n",
+    "  def __call__(self, w, x, y):\n",
+    "    with tf.GradientTape() as tape:\n",
+    "      L = tf.reduce_sum(tf.square(w*x - y))\n",
+    "    gradients = tape.gradient(L, [w])\n",
+    "    self.optimizer.apply_gradients(zip(gradients, [w]))\n",
+    "\n",
+    "\n",
+    "opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)\n",
+    "opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)\n",
+    "\n",
+    "train_o1 = TrainStep(opt1)\n",
+    "train_o2 = TrainStep(opt2)\n",
+    "\n",
+    "train_o1(w, x, y)\n",
+    "train_o2(w, x, y)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "dUHUi881smHF"
+   },
+   "source": [
+    "You could also do this manually by creating multiple instances of the `@tf.function` wrapper, one for each optimizer:"
    ]
   },
   {
@@ -1841,6 +1939,8 @@
   "metadata": {
    "colab": {
     "name": "function.ipynb",
+    "private_outputs": true,
+    "provenance": [],
     "toc_visible": true
    },
    "kernelspec": {
