
Commit 44a3c7b

hubingalli authored and copybara-github committed
Review and update public doc "Custom training with tf.distribute.Strategy"
PiperOrigin-RevId: 440216715
1 parent 4ab7c5e commit 44a3c7b

File tree

1 file changed (+29 / -37 lines)


site/en/tutorials/distribute/custom_training.ipynb

Lines changed: 29 additions & 37 deletions
@@ -68,9 +68,9 @@
  "id": "FbVhjPpzn6BM"
 },
 "source": [
- "This tutorial demonstrates how to use [`tf.distribute.Strategy`](https://www.tensorflow.org/guide/distributed_training) with custom training loops. We will train a simple CNN model on the fashion MNIST dataset. The fashion MNIST dataset contains 60000 train images of size 28 x 28 and 10000 test images of size 28 x 28.\n",
+ "This tutorial demonstrates how to use `tf.distribute.Strategy` — a TensorFlow API that provides an abstraction for [distributing your training](../../guide/distributed_training.ipynb) across multiple processing units (GPUs, multiple machines, or TPUs) — with custom training loops. In this example, you will train a simple convolutional neural network on the [Fashion MNIST dataset](https://github.com/zalandoresearch/fashion-mnist) containing 70,000 images of size 28 x 28.\n",
  "\n",
- "We are using custom training loops to train our model because they give us flexibility and a greater control on training. Moreover, it is easier to debug the model and the training loop."
+ "[Custom training loops](../customization/custom_training_walkthrough.ipynb) provide flexibility and greater control over training. They also make it easier to debug the model and the training loop."
 ]
},
{
@@ -97,7 +97,7 @@
 "id": "MM6W__qraV55"
},
"source": [
- "## Download the fashion MNIST dataset"
+ "## Download the Fashion MNIST dataset"
]
},
{
@@ -112,14 +112,14 @@
"\n",
"(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()\n",
"\n",
- "# Adding a dimension to the array -> new shape == (28, 28, 1)\n",
- "# We are doing this because the first layer in our model is a convolutional\n",
+ "# Add a dimension to the array -> new shape == (28, 28, 1)\n",
+ "# This is done because the first layer in our model is a convolutional\n",
"# layer and it requires a 4D input (batch_size, height, width, channels).\n",
"# batch_size dimension will be added later on.\n",
"train_images = train_images[..., None]\n",
"test_images = test_images[..., None]\n",
"\n",
- "# Getting the images in [0, 1] range.\n",
+ "# Scale the images to the [0, 1] range.\n",
"train_images = train_images / np.float32(255)\n",
"test_images = test_images / np.float32(255)"
]
@@ -141,13 +141,13 @@
"source": [
"How does `tf.distribute.MirroredStrategy` strategy work?\n",
"\n",
- "* All the variables and the model graph is replicated on the replicas.\n",
+ "* All the variables and the model graph are replicated across the replicas.\n",
"* Input is evenly distributed across the replicas.\n",
"* Each replica calculates the loss and gradients for the input it received.\n",
"* The gradients are synced across all the replicas by summing them.\n",
"* After the sync, the same update is made to the copies of the variables on each replica.\n",
"\n",
- "Note: You can put all the code below inside a single scope. We are dividing it into several code cells for illustration purposes.\n"
+ "Note: You can put all the code below inside a single scope. This example divides it into several code cells for illustration purposes.\n"
]
},
{
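To make the first bullet concrete, here is a small standalone sketch (not part of the notebook) showing that a variable created inside the strategy's scope gets one copy per replica; device auto-detection is assumed:

```python
import tensorflow as tf

# Hypothetical standalone snippet; devices are auto-detected.
strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

with strategy.scope():
  # Variables created inside the scope are replicated on every replica,
  # and the same update is applied to each copy after every sync.
  v = tf.Variable(1.0)

# On a multi-GPU machine this prints a MirroredVariable with one copy per GPU;
# on a single device it still runs, with a single replica.
print(v)
```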
@@ -158,8 +158,8 @@
},
"outputs": [],
"source": [
- "# If the list of devices is not specified in the\n",
- "# `tf.distribute.MirroredStrategy` constructor, it will be auto-detected.\n",
+ "# If the list of devices is not specified in the\n",
+ "# `tf.distribute.MirroredStrategy` constructor, they will be auto-detected.\n",
"strategy = tf.distribute.MirroredStrategy()"
]
},
@@ -171,7 +171,7 @@
},
"outputs": [],
"source": [
- "print ('Number of devices: {}'.format(strategy.num_replicas_in_sync))"
+ "print('Number of devices: {}'.format(strategy.num_replicas_in_sync))"
]
},
{
@@ -183,15 +183,6 @@
"## Setup input pipeline"
]
},
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "0Qb6nDgxiN_n"
- },
- "source": [
- "Export the graph and the variables to the platform-agnostic SavedModel format. After your model is saved, you can load it with or without the scope."
- ]
- },
{
"cell_type": "code",
"execution_count": null,
@@ -240,7 +231,7 @@
"source": [
"## Create the model\n",
"\n",
- "Create a model using `tf.keras.Sequential`. You can also use the Model Subclassing API to do this."
+ "Create a model using `tf.keras.Sequential`. You can also use the [Model Subclassing API](https://www.tensorflow.org/guide/keras/custom_layers_and_models) or the [functional API](https://www.tensorflow.org/guide/keras/functional) to do this."
]
},
{
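As a point of comparison, a hedged sketch of a similar small CNN written with the functional API could look like the following; the layer sizes are illustrative and not necessarily identical to the tutorial's model:

```python
import tensorflow as tf

def create_functional_model():
  # Illustrative layer sizes; adjust to match the tutorial's Sequential model.
  inputs = tf.keras.Input(shape=(28, 28, 1))
  x = tf.keras.layers.Conv2D(32, 3, activation='relu')(inputs)
  x = tf.keras.layers.MaxPooling2D()(x)
  x = tf.keras.layers.Conv2D(64, 3, activation='relu')(x)
  x = tf.keras.layers.MaxPooling2D()(x)
  x = tf.keras.layers.Flatten()(x)
  x = tf.keras.layers.Dense(64, activation='relu')(x)
  outputs = tf.keras.layers.Dense(10)(x)
  return tf.keras.Model(inputs=inputs, outputs=outputs)

model = create_functional_model()
model.summary()
```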
@@ -286,14 +277,14 @@
"source": [
"## Define the loss function\n",
"\n",
- "Normally, on a single machine with 1 GPU/CPU, loss is divided by the number of examples in the batch of input.\n",
+ "Normally, on a single machine with a single GPU/CPU, the loss is divided by the number of examples in the batch of input.\n",
"\n",
"*So, how should the loss be calculated when using a `tf.distribute.Strategy`?*\n",
"\n",
"* For an example, let's say you have 4 GPU's and a batch size of 64. One batch of input is distributed\n",
"across the replicas (4 GPUs), each replica getting an input of size 16.\n",
"\n",
- "* The model on each replica does a forward pass with its respective input and calculates the loss. Now, instead of dividing the loss by the number of examples in its respective input (BATCH_SIZE_PER_REPLICA = 16), the loss should be divided by the GLOBAL_BATCH_SIZE (64)."
+ "* The model on each replica does a forward pass with its respective input and calculates the loss. Now, instead of dividing the loss by the number of examples in its respective input (`BATCH_SIZE_PER_REPLICA` = 16), the loss should be divided by the `GLOBAL_BATCH_SIZE` (64)."
]
},
{
@@ -315,10 +306,10 @@
"source": [
"*How to do this in TensorFlow?*\n",
"\n",
- "* If you're writing a custom training loop, as in this tutorial, you should sum the per example losses and divide the sum by the GLOBAL_BATCH_SIZE: \n",
+ "* If you're writing a custom training loop, as in this tutorial, you should sum the per example losses and divide the sum by the `GLOBAL_BATCH_SIZE`: \n",
"`scale_loss = tf.reduce_sum(loss) * (1. / GLOBAL_BATCH_SIZE)`\n",
"or you can use `tf.nn.compute_average_loss` which takes the per example loss,\n",
- "optional sample weights, and GLOBAL_BATCH_SIZE as arguments and returns the scaled loss.\n",
+ "optional sample weights, and `GLOBAL_BATCH_SIZE` as arguments and returns the scaled loss.\n",
"\n",
"* If you are using regularization losses in your model then you need to scale\n",
"the loss value by number of replicas. You can do this by using the `tf.nn.scale_regularization_loss` function.\n",
@@ -351,7 +342,7 @@
"outputs": [],
"source": [
"with strategy.scope():\n",
- " # Set reduction to `none` so we can do the reduction afterwards and divide by\n",
+ " # Set reduction to `NONE` so you can do the reduction afterwards and divide by\n",
" # global batch size.\n",
" loss_object = tf.keras.losses.SparseCategoricalCrossentropy(\n",
" from_logits=True,\n",
@@ -484,9 +475,9 @@
"\n",
" template = (\"Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, \"\n",
" \"Test Accuracy: {}\")\n",
- " print (template.format(epoch+1, train_loss,\n",
- " train_accuracy.result()*100, test_loss.result(),\n",
- " test_accuracy.result()*100))\n",
+ " print(template.format(epoch + 1, train_loss,\n",
+ " train_accuracy.result() * 100, test_loss.result(),\n",
+ " test_accuracy.result() * 100))\n",
"\n",
" test_loss.reset_states()\n",
" train_accuracy.reset_states()\n",
@@ -501,7 +492,7 @@
"source": [
"Things to note in the example above:\n",
"\n",
- "* We are iterating over the `train_dist_dataset` and `test_dist_dataset` using a `for x in ...` construct.\n",
+ "* Iterate over the `train_dist_dataset` and `test_dist_dataset` using a `for x in ...` construct.\n",
"* The scaled loss is the return value of the `distributed_train_step`. This value is aggregated across replicas using the `tf.distribute.Strategy.reduce` call and then across batches by summing the return value of the `tf.distribute.Strategy.reduce` calls.\n",
"* `tf.keras.Metrics` should be updated inside `train_step` and `test_step` that gets executed by `tf.distribute.Strategy.run`.\n",
"*`tf.distribute.Strategy.run` returns results from each local replica in the strategy, and there are multiple ways to consume this result. You can do `tf.distribute.Strategy.reduce` to get an aggregated value. You can also do `tf.distribute.Strategy.experimental_local_results` to get the list of values contained in the result, one per local replica.\n"
@@ -570,8 +561,8 @@
"for images, labels in test_dataset:\n",
" eval_step(images, labels)\n",
"\n",
- "print ('Accuracy after restoring the saved model without strategy: {}'.format(\n",
- " eval_accuracy.result()*100))"
+ "print('Accuracy after restoring the saved model without strategy: {}'.format(\n",
+ " eval_accuracy.result() * 100))"
]
},
{
@@ -606,7 +597,7 @@
" average_train_loss = total_loss / num_batches\n",
"\n",
" template = (\"Epoch {}, Loss: {}, Accuracy: {}\")\n",
- " print (template.format(epoch+1, average_train_loss, train_accuracy.result()*100))\n",
+ " print(template.format(epoch + 1, average_train_loss, train_accuracy.result() * 100))\n",
" train_accuracy.reset_states()"
]
},
@@ -617,7 +608,7 @@
},
"source": [
"### Iterating inside a tf.function\n",
- "You can also iterate over the entire input `train_dist_dataset` inside a tf.function using the `for x in ...` construct or by creating iterators like we did above. The example below demonstrates wrapping one epoch of training in a tf.function and iterating over `train_dist_dataset` inside the function."
+ "You can also iterate over the entire input `train_dist_dataset` inside a `tf.function` using the `for x in ...` construct or by creating iterators like you did above. The example below demonstrates wrapping one epoch of training with a `@tf.function` decorator and iterating over `train_dist_dataset` inside the function."
]
},
{
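For context when reading the next hunk, here is a hedged sketch of the kind of epoch-level function being described, assuming `strategy`, `train_step`, and `train_dist_dataset` are defined as earlier in the tutorial:

```python
@tf.function
def distributed_train_epoch(dataset):
  total_loss = 0.0
  num_batches = 0
  # The `for x in ...` loop over the distributed dataset runs inside the tf.function.
  for x in dataset:
    per_replica_losses = strategy.run(train_step, args=(x,))
    total_loss += strategy.reduce(tf.distribute.ReduceOp.SUM,
                                  per_replica_losses, axis=None)
    num_batches += 1
  return total_loss / tf.cast(num_batches, dtype=tf.float32)

# train_loss = distributed_train_epoch(train_dist_dataset)
```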
@@ -643,7 +634,7 @@
" train_loss = distributed_train_epoch(train_dist_dataset)\n",
"\n",
" template = (\"Epoch {}, Loss: {}, Accuracy: {}\")\n",
- " print (template.format(epoch+1, train_loss, train_accuracy.result()*100))\n",
+ " print(template.format(epoch + 1, train_loss, train_accuracy.result() * 100))\n",
"\n",
" train_accuracy.reset_states()"
]
@@ -658,7 +649,7 @@
"\n",
"Note: As a general rule, you should use `tf.keras.Metrics` to track per-sample values and avoid values that have been aggregated within a replica.\n",
"\n",
- "We do *not* recommend using `tf.metrics.Mean` to track the training loss across different replicas, because of the loss scaling computation that is carried out.\n",
+ "Because of the loss scaling computation that is carried out, it's not recommended to use `tf.metrics.Mean` to track the training loss across different replicas.\n",
"\n",
"For example, if you run a training job with the following characteristics:\n",
"* Two replicas\n",
@@ -699,7 +690,8 @@
"## Next steps\n",
"\n",
"* Try out the new `tf.distribute.Strategy` API on your models.\n",
- "* Visit the [Performance section](../../guide/function.ipynb) in the guide to learn more about other strategies and [tools](../../guide/profiler.md) you can use to optimize the performance of your TensorFlow models."
+ "* Visit the [Better performance with tf.function](../../guide/function.ipynb) and [TensorFlow Profiler](../../guide/profiler.md) guides to learn more about tools to optimize the performance of your TensorFlow models.\n",
+ "* The [Distributed training in TensorFlow](../../guide/distributed_training.ipynb) guide provides an overview of the available distribution strategies."
]
}
],
