|
1166 | 1166 | "correct_a = leaky_function(tf.constant(1))\n",
|
1167 | 1167 | "\n",
|
1168 | 1168 | "print(correct_a.numpy()) # Good - value obtained from function's returns\n",
|
1169 |
| - "with assert_raises(AttributeError):\n", |
| 1169 | + "try:\n", |
1170 | 1170 | " x.numpy() # Bad - tensor leaked from inside the function, cannot be used here\n",
|
1171 |
| - "print(x)" |
| 1171 | + "except AttributeError as expected:\n", |
| 1172 | + " print(expected)" |
1172 | 1173 | ]
|
1173 | 1174 | },
|
1174 | 1175 | {
|
|
1197 | 1198 | "correct_a = leaky_function(tf.constant(1))\n",
|
1198 | 1199 | "\n",
|
1199 | 1200 | "print(correct_a.numpy()) # Good - value obtained from function's returns\n",
|
1200 |
| - "with assert_raises(AttributeError):\n", |
| 1201 | + "try:\n", |
1201 | 1202 | " x.numpy() # Bad - tensor leaked from inside the function, cannot be used here\n",
|
1202 |
| - "print(x)\n", |
| 1203 | + "except AttributeError as expected:\n", |
| 1204 | + " print(expected)\n", |
1203 | 1205 | "\n",
|
1204 | 1206 | "@tf.function\n",
|
1205 | 1207 | "def captures_leaked_tensor(b):\n",
|
|
1244 | 1246 | " external_object.field = a # Bad - leaks tensor"
|
1245 | 1247 | ]
|
1246 | 1248 | },
|
| 1249 | + { |
| 1250 | + "cell_type": "markdown", |
| 1251 | + "metadata": { |
| 1252 | + "id": "g-XVQcD-wf5K" |
| 1253 | + }, |
| 1254 | + "source": [ |
| 1255 | + "### Recursive tf.functions are not supported\n", |
| 1256 | + "\n", |
| 1257 | + "Recursive `Function`s are not supported and could cause infinite loops. For example," |
| 1258 | + ] |
| 1259 | + }, |
| 1260 | + { |
| 1261 | + "cell_type": "code", |
| 1262 | + "execution_count": null, |
| 1263 | + "metadata": { |
| 1264 | + "id": "QSN-T1m5EFcR" |
| 1265 | + }, |
| 1266 | + "outputs": [], |
| 1267 | + "source": [ |
| 1268 | + "@tf.function\n", |
| 1269 | + "def recursive_fn(n):\n", |
| 1270 | + " if n > 0:\n", |
| 1271 | + " return recursive_fn(n - 1)\n", |
| 1272 | + " else:\n", |
| 1273 | + " return 1\n", |
| 1274 | + "\n", |
| 1275 | + "with assert_raises(Exception):\n", |
| 1276 | + " recursive_fn(tf.constant(5)) # Bad - maximum recursion error." |
| 1277 | + ] |
| 1278 | + }, |
| 1279 | + { |
| 1280 | + "cell_type": "markdown", |
| 1281 | + "metadata": { |
| 1282 | + "id": "LyRyooKGUxNV" |
| 1283 | + }, |
| 1284 | + "source": [ |
| 1285 | + "Even if a recursive `Function` seems to work, the python function will be traced multiple times and could have performance implication. For example," |
| 1286 | + ] |
| 1287 | + }, |
| 1288 | + { |
| 1289 | + "cell_type": "code", |
| 1290 | + "execution_count": null, |
| 1291 | + "metadata": { |
| 1292 | + "id": "7FlmTqfMUwmT" |
| 1293 | + }, |
| 1294 | + "outputs": [], |
| 1295 | + "source": [ |
| 1296 | + "@tf.function\n", |
| 1297 | + "def recursive_fn(n):\n", |
| 1298 | + " if n > 0:\n", |
| 1299 | + " print('tracing')\n", |
| 1300 | + " return recursive_fn(n - 1)\n", |
| 1301 | + " else:\n", |
| 1302 | + " return 1\n", |
| 1303 | + "\n", |
| 1304 | + "recursive_fn(5) # Warning - multiple tracings" |
| 1305 | + ] |
| 1306 | + }, |
1247 | 1307 | {
|
1248 | 1308 | "cell_type": "markdown",
|
1249 | 1309 | "metadata": {
|
|
0 commit comments