Commit dfea136

arnoegw and copybara-github authored and committed
Update CTL tutorial and guides to make loss reduction consistent with Keras Model.fit:

* Divide the prediction loss by actual per-replica batch size and num_replicas_in_sync.
* Always include regularization losses.

This change is for the tensorflow repo. The material on keras.io can stay as-is.

PiperOrigin-RevId: 558161949
1 parent 6bdd89b commit dfea136

6 files changed, +191 -95 lines changed

site/en/guide/distributed_training.ipynb

Lines changed: 22 additions & 10 deletions
@@ -526,7 +526,9 @@
 "mirrored_strategy = tf.distribute.MirroredStrategy()\n",
 "\n",
 "with mirrored_strategy.scope():\n",
-" model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])\n",
+" model = tf.keras.Sequential([\n",
+" tf.keras.layers.Dense(1, input_shape=(1,),\n",
+" kernel_regularizer=tf.keras.regularizers.L2(1e-4))])\n",
 " model.compile(loss='mse', optimizer='sgd')"
 ]
 },
@@ -673,7 +675,9 @@
 "outputs": [],
 "source": [
 "with mirrored_strategy.scope():\n",
-" model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])\n",
+" model = tf.keras.Sequential([\n",
+" tf.keras.layers.Dense(1, input_shape=(1,),\n",
+" kernel_regularizer=tf.keras.regularizers.L2(1e-4))])\n",
 " optimizer = tf.keras.optimizers.SGD()"
 ]
 },
@@ -716,20 +720,21 @@
 },
 "outputs": [],
 "source": [
+"# Sets `reduction=NONE` to leave it to tf.nn.compute_average_loss() below.\n",
 "loss_object = tf.keras.losses.BinaryCrossentropy(\n",
 " from_logits=True,\n",
 " reduction=tf.keras.losses.Reduction.NONE)\n",
 "\n",
-"def compute_loss(labels, predictions):\n",
-" per_example_loss = loss_object(labels, predictions)\n",
-" return tf.nn.compute_average_loss(per_example_loss, global_batch_size=global_batch_size)\n",
-"\n",
 "def train_step(inputs):\n",
 " features, labels = inputs\n",
 "\n",
 " with tf.GradientTape() as tape:\n",
 " predictions = model(features, training=True)\n",
-" loss = compute_loss(labels, predictions)\n",
+" per_example_loss = loss_object(labels, predictions)\n",
+" loss = tf.nn.compute_average_loss(per_example_loss)\n",
+" model_losses = model.losses\n",
+" if model_losses:\n",
+" loss += tf.nn.scale_regularization_loss(tf.add_n(model_losses))\n",
 "\n",
 " gradients = tape.gradient(loss, model.trainable_variables)\n",
 " optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n",
@@ -750,9 +755,16 @@
 "source": [
 "A few other things to note in the code above:\n",
 "\n",
-"1. You used `tf.nn.compute_average_loss` to compute the loss. `tf.nn.compute_average_loss` sums the per example loss and divides the sum by the `global_batch_size`. This is important because later after the gradients are calculated on each replica, they are aggregated across the replicas by **summing** them.\n",
-"2. You also used the `tf.distribute.Strategy.reduce` API to aggregate the results returned by `tf.distribute.Strategy.run`. `tf.distribute.Strategy.run` returns results from each local replica in the strategy, and there are multiple ways to consume this result. You can `reduce` them to get an aggregated value. You can also do `tf.distribute.Strategy.experimental_local_results` to get the list of values contained in the result, one per local replica.\n",
-"3. When you call `apply_gradients` within a distribution strategy scope, its behavior is modified. Specifically, before applying gradients on each parallel instance during synchronous training, it performs a sum-over-all-replicas of the gradients.\n"
+" 1. You used `tf.nn.compute_average_loss` to reduce the per-example prediction losses to a scalar. `tf.nn.compute_average_loss` sums the per example loss and divides the sum by the global batch size. This is important because later after the gradients are calculated on each replica, they are aggregated across the replicas by **summing** them.\n",
+"\n",
+" By default, the global batch size is taken to be `tf.get_strategy().num_replicas_in_sync * tf.shape(per_example_loss)[0]`. It can also be specified explicitly as a keyword argument `global_batch_size=`. Without short batches, the default is equivalent to `tf.nn.compute_average_loss(..., global_batch_size=global_batch_size)` with the `global_batch_size` defined above. (For more on short batches and how to avoid or handle them, see the [Custom Training tutorial](../tutorials/distribute/custom_training.ipynb).)\n",
+"\n",
+" 2. You used `tf.nn.scale_regularization_loss` to scale regularization losses registered with the `Model` object, if any, by `1/num_replicas_in_sync` as well. For those regularization losses that are input-dependent, it falls on the modeling code, not the custom training loop, to perform the averaging over the per-replica(!) batch size; that way the modeling code can remain agnostic of replication while the training loop remains agnostic of how regularization losses are computed.\n",
+"\n",
+" 3. When you call `apply_gradients` within a distribution strategy scope, its behavior is modified. Specifically, before applying gradients on each parallel instance during synchronous training, it performs a sum-over-all-replicas of the gradients.\n",
+"\n",
+" 4. You also used the `tf.distribute.Strategy.reduce` API to aggregate the results returned by `tf.distribute.Strategy.run` for reporting. `tf.distribute.Strategy.run` returns results from each local replica in the strategy, and there are multiple ways to consume this result. You can `reduce` them to get an aggregated value. You can also do `tf.distribute.Strategy.experimental_local_results` to get the list of values contained in the result, one per local replica.\n",
+"\n"
 ]
 },
 {
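
Review note: the scaling rules in the new notes 1 and 2 can be checked with plain arithmetic. The sketch below uses no TensorFlow; `average_loss` and `scale_regularization` are hypothetical stand-ins that only mimic the scaling the notebook text describes for `tf.nn.compute_average_loss` and `tf.nn.scale_regularization_loss`, and all numbers are made up for illustration.

```python
# Plain-Python sketch (no TensorFlow). Helper names are hypothetical and
# only imitate the scaling described in the notebook text.

def average_loss(per_example_losses, num_replicas, global_batch_size=None):
    """Sum the per-example losses and divide by the global batch size."""
    if global_batch_size is None:
        # Default described in the guide:
        # num_replicas_in_sync * actual per-replica batch size.
        global_batch_size = num_replicas * len(per_example_losses)
    return sum(per_example_losses) / global_batch_size


def scale_regularization(reg_loss, num_replicas):
    """Scale a regularization loss by 1/num_replicas_in_sync."""
    return reg_loss / num_replicas


num_replicas = 2
# Per-example prediction losses seen by each replica (made-up values).
replica_batches = [[1.0, 3.0], [2.0, 6.0]]
# Assumed identical on every replica, as for a weight-only L2 penalty.
reg_loss = 0.5

per_replica_totals = [
    average_loss(batch, num_replicas)
    + scale_regularization(reg_loss, num_replicas)
    for batch in replica_batches
]

# Summing across replicas (which is how gradients are aggregated) recovers
# the single-device objective: mean prediction loss plus one copy of reg_loss.
summed = sum(per_replica_totals)
all_examples = [x for batch in replica_batches for x in batch]
single_device = sum(all_examples) / len(all_examples) + reg_loss

# Short batch: a replica that receives 1 example instead of the nominal 2.
short_batch = [4.0]
default_short = average_loss(short_batch, num_replicas)
explicit_short = average_loss(short_batch, num_replicas, global_batch_size=4)

print(per_replica_totals, summed, single_device)
print(default_short, explicit_short)
```

With full batches the default and the explicit `global_batch_size=` agree; they differ only for the short batch, which is the case the notebook text defers to the Custom Training tutorial.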

site/en/guide/tpu.ipynb

Lines changed: 37 additions & 11 deletions
@@ -254,13 +254,32 @@
 "outputs": [],
 "source": [
 "def create_model():\n",
+" regularizer = tf.keras.regularizers.L2(1e-5)\n",
 " return tf.keras.Sequential(\n",
-" [tf.keras.layers.Conv2D(256, 3, activation='relu', input_shape=(28, 28, 1)),\n",
-" tf.keras.layers.Conv2D(256, 3, activation='relu'),\n",
+" [tf.keras.layers.Conv2D(256, 3, input_shape=(28, 28, 1),\n",
+" activation='relu',\n",
+" kernel_regularizer=regularizer),\n",
+" tf.keras.layers.Conv2D(256, 3,\n",
+" activation='relu',\n",
+" kernel_regularizer=regularizer),\n",
 " tf.keras.layers.Flatten(),\n",
-" tf.keras.layers.Dense(256, activation='relu'),\n",
-" tf.keras.layers.Dense(128, activation='relu'),\n",
-" tf.keras.layers.Dense(10)])"
+" tf.keras.layers.Dense(256,\n",
+" activation='relu',\n",
+" kernel_regularizer=regularizer),\n",
+" tf.keras.layers.Dense(128,\n",
+" activation='relu',\n",
+" kernel_regularizer=regularizer),\n",
+" tf.keras.layers.Dense(10,\n",
+" kernel_regularizer=regularizer)])"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {
+"id": "h-2qaXgfyONQ"
+},
+"source": [
+"This model puts L2 regularization terms on the weights of each layer, so that the custom training loop below can show how you pick them up from `Model.losses`."
 ]
 },
 {
@@ -442,9 +461,13 @@
 " images, labels = inputs\n",
 " with tf.GradientTape() as tape:\n",
 " logits = model(images, training=True)\n",
-" loss = tf.keras.losses.sparse_categorical_crossentropy(\n",
+" per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(\n",
 " labels, logits, from_logits=True)\n",
-" loss = tf.nn.compute_average_loss(loss, global_batch_size=batch_size)\n",
+" loss = tf.nn.compute_average_loss(per_example_loss)\n",
+" model_losses = model.losses\n",
+" if model_losses:\n",
+" loss += tf.nn.scale_regularization_loss(tf.add_n(model_losses))\n",
+"\n",
 " grads = tape.gradient(loss, model.trainable_variables)\n",
 " optimizer.apply_gradients(list(zip(grads, model.trainable_variables)))\n",
 " training_loss.update_state(loss * strategy.num_replicas_in_sync)\n",
@@ -478,7 +501,7 @@
 "\n",
 " for step in range(steps_per_epoch):\n",
 " train_step(train_iterator)\n",
-" print('Current step: {}, training loss: {}, accuracy: {}%'.format(\n",
+" print('Current step: {}, training loss: {}, training accuracy: {}%'.format(\n",
 " optimizer.iterations.numpy(),\n",
 " round(float(training_loss.result()), 4),\n",
 " round(float(training_accuracy.result()) * 100, 2)))\n",
@@ -516,9 +539,12 @@
 " images, labels = inputs\n",
 " with tf.GradientTape() as tape:\n",
 " logits = model(images, training=True)\n",
-" loss = tf.keras.losses.sparse_categorical_crossentropy(\n",
+" per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(\n",
 " labels, logits, from_logits=True)\n",
-" loss = tf.nn.compute_average_loss(loss, global_batch_size=batch_size)\n",
+" loss = tf.nn.compute_average_loss(per_example_loss)\n",
+" model_losses = model.losses\n",
+" if model_losses:\n",
+" loss += tf.nn.scale_regularization_loss(tf.add_n(model_losses))\n",
 " grads = tape.gradient(loss, model.trainable_variables)\n",
 " optimizer.apply_gradients(list(zip(grads, model.trainable_variables)))\n",
 " training_loss.update_state(loss * strategy.num_replicas_in_sync)\n",
@@ -531,7 +557,7 @@
 "# retraced if the value changes.\n",
 "train_multiple_steps(train_iterator, tf.convert_to_tensor(steps_per_epoch))\n",
 "\n",
-"print('Current step: {}, training loss: {}, accuracy: {}%'.format(\n",
+"print('Current step: {}, training loss: {}, training accuracy: {}%'.format(\n",
 " optimizer.iterations.numpy(),\n",
 " round(float(training_loss.result()), 4),\n",
 " round(float(training_accuracy.result()) * 100, 2)))"
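
Review note: both training steps in `tpu.ipynb` report `loss * strategy.num_replicas_in_sync` to the `training_loss` metric. A plain-Python sketch of why that multiplication undoes the per-replica scaling is given below. It assumes equal per-replica batch sizes, a regularization loss that is identical on every replica, and a metric that averages the values reported by all replicas; `replica_loss` is a hypothetical helper and the numbers are made up.

```python
# Plain-Python sketch (no TensorFlow); names and numbers are illustrative.

num_replicas = 4
# Per-example prediction losses on each replica (equal batch sizes assumed).
per_example = [[0.2, 0.4], [0.6, 0.8], [1.0, 1.2], [1.4, 1.6]]
reg_loss = 0.1  # assumed identical on every replica


def replica_loss(batch):
    """Loss as scaled inside the train step: both terms carry 1/num_replicas."""
    prediction = sum(batch) / (num_replicas * len(batch))
    regularization = reg_loss / num_replicas
    return prediction + regularization


# Each replica reports its scaled loss multiplied back by num_replicas,
# i.e. its own mean prediction loss plus one full copy of reg_loss.
reported = [replica_loss(batch) * num_replicas for batch in per_example]

# Assumption: the metric averages the reported values across replicas.
metric_value = sum(reported) / len(reported)

# Reference: the loss a single device would compute on the whole batch.
all_examples = [x for batch in per_example for x in batch]
expected = sum(all_examples) / len(all_examples) + reg_loss

print(metric_value, expected)
```

Under these assumptions the reported metric matches the single-device loss; without the multiplication it would be smaller by a factor of `num_replicas`.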
