Skip to content

Commit 0d6bb26

Browse files
Yuefeng Zhoucopybara-github
authored andcommitted
Update multi-worker Keras tutorial: add more details around sharding and steps_per_epoch; add a section for evaluation.
PiperOrigin-RevId: 314266301
1 parent 9af3e7c commit 0d6bb26

File tree

1 file changed

+27
-2
lines changed

1 file changed

+27
-2
lines changed

site/en/tutorials/distribute/multi_worker_with_keras.ipynb

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@
301301
"\n",
302302
"Note: In this Colab, the following code can run with expected result, but however this is effectively single-worker training since `TF_CONFIG` is not set. Once you set `TF_CONFIG` in your own example, you should expect speed-up with training on multiple machines.\n",
303303
"\n",
304-
"Note: Always pass in `steps_per_epoch` argument to `model.fit()` since `MultiWorkerMirroredStrategy` does not support last partial batch handling. When using `steps_per_epoch`, `model.fit()` does not create a new iterator from the input every epoch, but continues from wherever the last epoch ended. Hence, make sure to call `.repeat()` on the dataset so it has an adequate number of examples for N epochs."
304+
"Note: Always pass in `steps_per_epoch` argument to `model.fit()` since `MultiWorkerMirroredStrategy` does not support last partial batch handling. When using `steps_per_epoch`, `model.fit()` does not create a new iterator from the input every epoch, but continues from wherever the last epoch ended. Hence, make sure to call `.repeat()` on the dataset so it has an adequate number of examples for N epochs. If your dataset is not a repeated dataset, the `steps_per_epoch` should be set based on the amount of training data on each worker so that all workers would perform the same number of steps of training or evaluation, which is required by allreduce. In particular, if the sharding is not balanced, `steps_per_epoch` should be set to the size of the smallest sharded devided by the per-worker batch size."
305305
]
306306
},
307307
{
@@ -340,7 +340,7 @@
340340
},
341341
"source": [
342342
"### Dataset sharding and batch size\n",
343-
"In multi-worker training, sharding data into multiple parts is needed to ensure convergence and performance. However, note that in above code snippet, the datasets are directly sent to `model.fit()` without needing to shard; this is because `tf.distribute.Strategy` API takes care of the dataset sharding automatically in multi-worker trainings.\n",
343+
"In multi-worker training, sharding data into multiple parts is needed to ensure convergence and performance. However, note that in above code snippet, the datasets are directly sent to `model.fit()` without needing to shard; this is because `tf.distribute.Strategy` API takes care of the dataset sharding automatically in multi-worker trainings. It shards the dataset at the file level which may create skewed shards. In extreme cases where there is only one file, only the first shard (i.e. worker) will get training or evaluation data and as a result all workers will get errors.\n",
344344
"\n",
345345
"If you prefer manual sharding for your training, automatic sharding can be turned off via `tf.data.experimental.DistributeOptions` api. Concretely,"
346346
]
@@ -370,6 +370,31 @@
370370
"Another thing to notice is the batch size for the `datasets`. In the code snippet above, we use `global_batch_size = per_worker_batch_size * num_workers`, which is `num_workers` times as large as the case it was for single worker, because the effective per worker batch size is the global batch size (the parameter passed in `tf.data.Dataset.batch()`) divided by the number of workers, and with this change we are keeping the per worker batch size same as before."
371371
]
372372
},
373+
{
374+
"cell_type": "markdown",
375+
"metadata": {
376+
"colab_type": "text",
377+
"id": "gmqvlh5LhAoU"
378+
},
379+
"source": [
380+
"### Evaluation\n",
381+
"\n",
382+
"If you pass `validation_data` into `model.fit`, it will alternate between training and evaluation for each epoch. The evaluation taking `validation_data` is distributed across the same set of workers and the evaluation results are aggregated and available for all workers. Similar to training, the validation dataset is automatically sharded at the file level. You need to set a global batch size in the validation dataset and set `validation_steps`. A repeated dataset is also recommended for evaluation.\n",
383+
"\n",
384+
"Alternatively, you can also create another task that periodically reads checkpoints and runs the evaluation. This is what Estimator does. But this is not a recommended way to perform evaluation and thus its details are omitted."
385+
]
386+
},
387+
{
388+
"cell_type": "markdown",
389+
"metadata": {
390+
"colab_type": "text",
391+
"id": "CsQsRfBxKMcw"
392+
},
393+
"source": [
394+
"### Prediction\n",
395+
"Currently `model.predict` doesn't work with `MultiWorkerMirroredStrategy.`"
396+
]
397+
},
373398
{
374399
"cell_type": "markdown",
375400
"metadata": {

0 commit comments

Comments
 (0)