
Commit 10313cb

hubingallin authored and copybara-github committed
Review and update public doc "Distributed training with Keras"
PiperOrigin-RevId: 438662329
1 parent ee52c35 commit 10313cb

File tree

1 file changed: +12 −9 lines


site/en/tutorials/distribute/keras.ipynb

Lines changed: 12 additions & 9 deletions
@@ -76,7 +76,7 @@
     "\n",
     "You will use the `tf.keras` APIs to build the model and `Model.fit` for training it. (To learn about distributed training with a custom training loop and the `MirroredStrategy`, check out [this tutorial](custom_training.ipynb).)\n",
     "\n",
-    "`MirroredStrategy` trains your model on multiple GPUs on a single machine. For _synchronous training on many GPUs on multiple workers_, use the `tf.distribute.MultiWorkerMirroredStrategy` [with the Keras Model.fit](multi_worker_with_keras.ipynb) or [a custom training loop](multi_worker_with_ctl.ipynb). For other options, refer to the [Distributed training guide](../../guide/distributed_training.ipynb).\n",
+    "`MirroredStrategy` trains your model on multiple GPUs on a single machine. For _synchronous training on many GPUs on multiple workers_, use the `tf.distribute.MultiWorkerMirroredStrategy` with the [Keras Model.fit](multi_worker_with_keras.ipynb) or [a custom training loop](multi_worker_with_ctl.ipynb). For other options, refer to the [Distributed training guide](../../guide/distributed_training.ipynb).\n",
     "\n",
     "To learn about various other strategies, there is the [Distributed training with TensorFlow](../../guide/distributed_training.ipynb) guide."
    ]
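The strategy choice discussed in this hunk can be sketched minimally as follows (the single-machine case; everything here is an illustrative snippet, not part of the commit):

```python
import tensorflow as tf

# MirroredStrategy replicates the model onto every GPU visible on this one
# machine, and falls back to a single CPU replica when no GPU is found.
strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
```

For the multi-worker case, `tf.distribute.MultiWorkerMirroredStrategy()` is constructed the same way, but additionally requires a `TF_CONFIG` cluster specification in the environment.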
@@ -289,7 +289,7 @@
     "id": "1BnQYQTpB3YA"
    },
    "source": [
-    "Create and compile the Keras model in the context of `Strategy.scope`:"
+    "Within the context of `Strategy.scope`, create and compile the model using the Keras API:"
   ]
  },
  {
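The step this hunk rewords can be sketched as below; the layer sizes and loss are illustrative assumptions, not taken from the notebook:

```python
import tensorflow as tf

# Variables created under Strategy.scope become distributed (mirrored)
# variables; compiling inside the scope keeps the optimizer state there too.
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(16, activation='relu', input_shape=(8,)),
      tf.keras.layers.Dense(1)
  ])
  model.compile(loss=tf.keras.losses.MeanSquaredError(),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['mae'])
```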
@@ -329,13 +329,16 @@
     "id": "YOXO5nvvK3US"
    },
    "source": [
-    "Define the following `tf.keras.callbacks`:\n",
+    "Define the following [Keras Callbacks](https://www.tensorflow.org/guide/keras/train_and_evaluate):\n",
     "\n",
     "- `tf.keras.callbacks.TensorBoard`: writes a log for TensorBoard, which allows you to visualize the graphs.\n",
     "- `tf.keras.callbacks.ModelCheckpoint`: saves the model at a certain frequency, such as after every epoch.\n",
+    "- `tf.keras.callbacks.BackupAndRestore`: provides the fault tolerance functionality by backing up the model and current epoch number. Learn more in the _Fault tolerance_ section of the [Multi-worker training with Keras](multi_worker_with_keras.ipynb) tutorial.\n",
     "- `tf.keras.callbacks.LearningRateScheduler`: schedules the learning rate to change after, for example, every epoch/batch.\n",
     "\n",
-    "For illustrative purposes, add a custom callback called `PrintLR` to display the *learning rate* in the notebook."
+    "For illustrative purposes, add a [custom callback](https://www.tensorflow.org/guide/keras/custom_callback) called `PrintLR` to display the *learning rate* in the notebook.\n",
+    "\n",
+    "**Note:** Use the `BackupAndRestore` callback instead of `ModelCheckpoint` as the main mechanism to restore the training state upon a restart from a job failure. Since `BackupAndRestore` only supports eager mode, in graph mode consider using `ModelCheckpoint`."
   ]
  },
  {
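The four callbacks this hunk enumerates (including the newly added `BackupAndRestore`) can be assembled as in the sketch below; the directory paths and the step-wise `decay` schedule are illustrative placeholders, not values from the notebook:

```python
import tensorflow as tf

# An assumed step-wise learning-rate schedule for LearningRateScheduler.
def decay(epoch):
  if epoch < 3:
    return 1e-3
  elif epoch < 7:
    return 1e-4
  return 1e-5

callbacks = [
    # Writes logs that TensorBoard can visualize.
    tf.keras.callbacks.TensorBoard(log_dir='./logs'),
    # Saves weights after every epoch (placeholder path).
    tf.keras.callbacks.ModelCheckpoint(
        filepath='./ckpt/ckpt_{epoch}.weights.h5', save_weights_only=True),
    # Backs up the model and epoch number for fault tolerance.
    tf.keras.callbacks.BackupAndRestore(backup_dir='./backup'),
    # Applies the schedule above at the start of each epoch.
    tf.keras.callbacks.LearningRateScheduler(decay),
]
```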
@@ -382,8 +385,8 @@
     "# Define a callback for printing the learning rate at the end of each epoch.\n",
     "class PrintLR(tf.keras.callbacks.Callback):\n",
     "  def on_epoch_end(self, epoch, logs=None):\n",
-    "    print('\\nLearning rate for epoch {} is {}'.format(epoch + 1,\n",
-    "                                                      model.optimizer.lr.numpy()))"
+    "    print('\\nLearning rate for epoch {} is {}'.format(",
+    "        epoch + 1, model.optimizer.lr.numpy()))"
   ]
  },
  {
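A self-contained variant of the `PrintLR` callback reformatted in this hunk is sketched below. It reads the rate through `self.model` instead of the notebook's global `model`, and uses the canonical `learning_rate` attribute rather than the `lr` alias that the diff uses; the toy model and data are illustrative:

```python
import numpy as np
import tensorflow as tf

# Print the learning rate at the end of each epoch, via the callback's
# own reference to the model it is attached to.
class PrintLR(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    print('\nLearning rate for epoch {} is {}'.format(
        epoch + 1, self.model.optimizer.learning_rate.numpy()))

# Illustrative usage on a toy model and all-zero data.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(2,))])
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.01), loss='mse')
model.fit(np.zeros((4, 2), dtype='float32'), np.zeros((4, 1), dtype='float32'),
          epochs=1, verbose=0, callbacks=[PrintLR()])
```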
@@ -419,7 +422,7 @@
     "id": "6EophnOAB3YD"
    },
    "source": [
-    "Now, train the model in the usual way by calling `Model.fit` on the model and passing in the dataset created at the beginning of the tutorial. This step is the same whether you are distributing the training or not."
+    "Now, train the model in the usual way by calling Keras `Model.fit` on the model and passing in the dataset created at the beginning of the tutorial. This step is the same whether you are distributing the training or not."
   ]
  },
  {
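The "same whether distributed or not" point of this hunk can be seen in a minimal sketch; here a small random `tf.data` pipeline stands in for the tutorial's dataset, and the shapes are illustrative assumptions:

```python
import numpy as np
import tensorflow as tf

# Model.fit is called identically with or without a distribution strategy.
train_ds = tf.data.Dataset.from_tensor_slices(
    (np.random.rand(64, 4).astype('float32'),
     np.random.rand(64, 1).astype('float32'))).batch(16)

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(loss='mse', optimizer='sgd')
history = model.fit(train_ds, epochs=2, verbose=0)
```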
@@ -535,7 +538,7 @@
     "id": "Xa87y_A0vRma"
    },
    "source": [
-    "Export the graph and the variables to the platform-agnostic SavedModel format using `Model.save`. After your model is saved, you can load it with or without the `Strategy.scope`."
+    "Export the graph and the variables to the platform-agnostic SavedModel format using Keras `Model.save`. After your model is saved, you can load it with or without the `Strategy.scope`."
   ]
  },
  {
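The save-then-load flow this hunk describes can be sketched as follows; the `saved_model` path is an illustrative placeholder, and the extension-less path assumes the TF 2.x behavior (current at the time of this commit) where `Model.save` defaults to the SavedModel format:

```python
import numpy as np
import tensorflow as tf

# Export graph and variables to the SavedModel format.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(2,))])
model.compile(loss='mse', optimizer='sgd')
model.save('saved_model')

# Load back without a strategy scope:
restored = tf.keras.models.load_model('saved_model')

# Or load within Strategy.scope, so the restored variables are mirrored:
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  restored_dist = tf.keras.models.load_model('saved_model')
```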
@@ -626,7 +629,7 @@
     "\n",
     "More examples that use different distribution strategies with the Keras `Model.fit` API:\n",
     "\n",
-    "1. The [Solve GLUE tasks using BERT on TPU](https://www.tensorflow.org/text/tutorials/bert_glue) tutorial uses `tf.distribute.MirroredStrategy` for training on GPUs and `tf.distribute.TPUStrategy`on TPUs.\n",
+    "1. The [Solve GLUE tasks using BERT on TPU](https://www.tensorflow.org/text/tutorials/bert_glue) tutorial uses `tf.distribute.MirroredStrategy` for training on GPUs and `tf.distribute.TPUStrategy` on TPUs.\n",
     "1. The [Save and load a model using a distribution strategy](save_and_load.ipynb) tutorial demonstates how to use the SavedModel APIs with `tf.distribute.Strategy`.\n",
     "1. The [official TensorFlow models](https://github.com/tensorflow/models/tree/master/official) can be configured to run multiple distribution strategies.\n",
     "\n",
