|
69 | 69 | "\n",
|
70 | 70 | "This guide first demonstrates how to add fault tolerance to training with `tf.estimator.Estimator` in TensorFlow 1 by specifying metric saving with `tf.estimator.RunConfig`. Then, you will learn how to implement fault tolerance for training in Tensorflow 2 in two ways:\n",
|
71 | 71 | "\n",
|
72 |
| - "- If you use the Keras `Model.fit` API, you can pass the `tf.keras.callbacks.experimental.BackupAndRestore` callback to it.\n", |
| 72 | + "- If you use the Keras `Model.fit` API, you can pass the `tf.keras.callbacks.BackupAndRestore` callback to it.\n", |
73 | 73 | "- If you use a custom training loop (with `tf.GradientTape`), you can arbitrarily save checkpoints using the `tf.train.Checkpoint` and `tf.train.CheckpointManager` APIs.\n",
|
74 | 74 | "\n",
|
75 | 75 | "Both of these methods will back up and restore the training states in [checkpoint](../../guide/checkpoint.ipynb) files.\n"
|
|
252 | 252 | "source": [
|
253 | 253 | "## TensorFlow 2: Back up and restore with a callback and Model.fit\n",
|
254 | 254 | "\n",
|
255 |
| - "In TensorFlow 2, if you use the Keras `Model.fit` API for training, you can provide the `tf.keras.callbacks.experimental.BackupAndRestore` callback to add the fault tolerance functionality.\n", |
| 255 | + "In TensorFlow 2, if you use the Keras `Model.fit` API for training, you can provide the `tf.keras.callbacks.BackupAndRestore` callback to add the fault tolerance functionality.\n", |
256 | 256 | "\n",
|
257 | 257 | "To help demonstrate this, let's first start by defining a callback class that artificially throws an error during the fifth checkpoint:\n"
|
258 | 258 | ]
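For orientation, such a callback might look roughly like the sketch below. The class name and error message are illustrative, and it assumes checkpoints are saved once per epoch (the `BackupAndRestore` default), so the fifth checkpoint corresponds to epoch index 4:

```python
class InterruptingCallback(tf.keras.callbacks.Callback):
  """Artificially interrupts training at the start of epoch 4 (the fifth epoch)."""

  def on_epoch_begin(self, epoch, logs=None):
    if epoch == 4:
      raise RuntimeError('Interrupting training to simulate a failure.')
```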
|
|
278 | 278 | "id": "AhU3VTYZoDh-"
|
279 | 279 | },
|
280 | 280 | "source": [
|
281 |
| - "Then, define and instantiate a simple Keras model, define the loss function, call `Model.compile`, and set up a `tf.keras.callbacks.experimental.BackupAndRestore` callback that will save the checkpoints in a temporary directory:" |
| 281 | + "Then, define and instantiate a simple Keras model, define the loss function, call `Model.compile`, and set up a `tf.keras.callbacks.BackupAndRestore` callback that will save the checkpoints in a temporary directory:" |
282 | 282 | ]
|
283 | 283 | },
|
284 | 284 | {
|
|
307 | 307 | "\n",
|
308 | 308 | "log_dir = tempfile.mkdtemp()\n",
|
309 | 309 | "\n",
|
310 |
| - "backup_restore_callback = tf.keras.callbacks.experimental.BackupAndRestore(\n", |
| 310 | + "backup_restore_callback = tf.keras.callbacks.BackupAndRestore(\n", |
311 | 311 | " backup_dir = log_dir\n",
|
312 | 312 | ")"
|
313 | 313 | ]
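Once instantiated, the callback would typically be passed to `Model.fit` through its `callbacks` argument, roughly as in the following sketch; `x_train`, `y_train`, the epoch count, and the `InterruptingCallback` name are placeholders for the objects defined elsewhere in this guide:

```python
# Hypothetical usage -- `model`, `x_train`, `y_train`, and `InterruptingCallback`
# stand in for the objects defined earlier in this guide.
try:
  model.fit(x_train, y_train,
            epochs=10,
            callbacks=[backup_restore_callback, InterruptingCallback()])
except RuntimeError:
  print('Training was interrupted.')

# Calling `Model.fit` again restores from the backup and resumes training
# from the last completed epoch instead of starting over.
model.fit(x_train, y_train,
          epochs=10,
          callbacks=[backup_restore_callback])
```

Because `BackupAndRestore` writes a temporary backup checkpoint at the end of each epoch by default, the second `Model.fit` call picks up from the last completed epoch rather than from scratch.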
|
|
452 | 452 | "\n",
|
453 | 453 | "To learn more about fault tolerance and checkpointing in TensorFlow 2, consider the following documentation:\n",
|
454 | 454 | "\n",
|
455 |
| - "- The `tf.keras.callbacks.experimental.BackupAndRestore` callback API docs.\n", |
| 455 | + "- The `tf.keras.callbacks.BackupAndRestore` callback API docs.\n", |
456 | 456 | "- The `tf.train.Checkpoint` and `tf.train.CheckpointManager` API docs.\n",
|
457 | 457 | "- The [Training checkpoints](../../guide/checkpoint.ipynb) guide, including the _Writing checkpoints_ section.\n",
|
458 | 458 | "\n",
|
|