Skip to content

Commit 3050034

Browse files
Add a recursive tf.function section to the tf.function limitation.
The api.StagingError or its base AutoGraphError is not exported. So use the Python Exception in the doc. Also tweaked the print format of leaky tensors a little bit. PiperOrigin-RevId: 402405540
1 parent 334203c commit 3050034

File tree

1 file changed

+64
-4
lines changed

1 file changed

+64
-4
lines changed

site/en/guide/function.ipynb

Lines changed: 64 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1166,9 +1166,10 @@
11661166
"correct_a = leaky_function(tf.constant(1))\n",
11671167
"\n",
11681168
"print(correct_a.numpy()) # Good - value obtained from function's returns\n",
1169-
"with assert_raises(AttributeError):\n",
1169+
"try:\n",
11701170
" x.numpy() # Bad - tensor leaked from inside the function, cannot be used here\n",
1171-
"print(x)"
1171+
"except AttributeError as expected:\n",
1172+
" print(expected)"
11721173
]
11731174
},
11741175
{
@@ -1197,9 +1198,10 @@
11971198
"correct_a = leaky_function(tf.constant(1))\n",
11981199
"\n",
11991200
"print(correct_a.numpy()) # Good - value obtained from function's returns\n",
1200-
"with assert_raises(AttributeError):\n",
1201+
"try:\n",
12011202
" x.numpy() # Bad - tensor leaked from inside the function, cannot be used here\n",
1202-
"print(x)\n",
1203+
"except AttributeError as expected:\n",
1204+
" print(expected)\n",
12031205
"\n",
12041206
"@tf.function\n",
12051207
"def captures_leaked_tensor(b):\n",
@@ -1244,6 +1246,64 @@
12441246
" external_object.field = a # Bad - leaks tensor"
12451247
]
12461248
},
1249+
{
1250+
"cell_type": "markdown",
1251+
"metadata": {
1252+
"id": "g-XVQcD-wf5K"
1253+
},
1254+
"source": [
1255+
"### Recursive tf.functions are not supported\n",
1256+
"\n",
1257+
"Recursive `Function`s are not supported and could cause infinite loops. For example,"
1258+
]
1259+
},
1260+
{
1261+
"cell_type": "code",
1262+
"execution_count": null,
1263+
"metadata": {
1264+
"id": "QSN-T1m5EFcR"
1265+
},
1266+
"outputs": [],
1267+
"source": [
1268+
"@tf.function\n",
1269+
"def recursive_fn(n):\n",
1270+
" if n > 0:\n",
1271+
" return recursive_fn(n - 1)\n",
1272+
" else:\n",
1273+
" return 1\n",
1274+
"\n",
1275+
"with assert_raises(Exception):\n",
1276+
" recursive_fn(tf.constant(5)) # Bad - maximum recursion error."
1277+
]
1278+
},
1279+
{
1280+
"cell_type": "markdown",
1281+
"metadata": {
1282+
"id": "LyRyooKGUxNV"
1283+
},
1284+
"source": [
1285+
"Even if a recursive `Function` seems to work, the python function will be traced multiple times and could have performance implication. For example,"
1286+
]
1287+
},
1288+
{
1289+
"cell_type": "code",
1290+
"execution_count": null,
1291+
"metadata": {
1292+
"id": "7FlmTqfMUwmT"
1293+
},
1294+
"outputs": [],
1295+
"source": [
1296+
"@tf.function\n",
1297+
"def recursive_fn(n):\n",
1298+
" if n > 0:\n",
1299+
" print('tracing')\n",
1300+
" return recursive_fn(n - 1)\n",
1301+
" else:\n",
1302+
" return 1\n",
1303+
"\n",
1304+
"recursive_fn(5) # Warning - multiple tracings"
1305+
]
1306+
},
12471307
{
12481308
"cell_type": "markdown",
12491309
"metadata": {

0 commit comments

Comments
 (0)