
Commit 8563574
Address comments
1 parent 635e22d


rfcs/20200929-keras-mixed-precision.md

Lines changed: 4 additions & 0 deletions
@@ -224,6 +224,8 @@ print(tf.keras.mixed_precision.global_policy()) # float32
These rules for tracking floatx are relatively unimportant and exist primarily for backwards compatibility. For the most part, users do not have to be aware of floatx and can simply assume the global policy defaults to float32.
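As an illustrative sketch of these rules (run with no global policy explicitly set, so the policy tracks floatx as described above):

```python
import tensorflow as tf

# With no global policy explicitly set, the global policy tracks floatx,
# so it defaults to float32.
print(tf.keras.mixed_precision.global_policy())  # float32

# Changing floatx is reflected in the global policy, purely for backwards
# compatibility; most users never need to touch floatx.
tf.keras.backend.set_floatx("float64")
print(tf.keras.mixed_precision.global_policy())  # float64
```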
+Users must know which device their model runs on so they can choose between setting the global policy to "mixed_float16" or "mixed_bfloat16". Alternatively, Keras could choose for the user by allowing the global policy to be set to a special value, say "mixed_auto", which would select between "mixed_float16" and "mixed_bfloat16" depending on the hardware available. The reason the API does not have "mixed_auto" is that users should be aware of which policy they are using, as it affects the checkpoint/SavedModel format as well as whether loss scaling is required (users must explicitly use a LossScaleOptimizer in a custom training loop).
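As a purely hypothetical sketch (the `choose_mixed_policy` helper below is not part of the proposed API), a user who wants "mixed_auto"-like behavior could write something along these lines:

```python
import tensorflow as tf

def choose_mixed_policy():
    """Hypothetical helper: pick a mixed precision policy based on available hardware."""
    if tf.config.list_logical_devices("TPU"):
        return "mixed_bfloat16"  # TPUs support bfloat16 natively
    if tf.config.list_physical_devices("GPU"):
        return "mixed_float16"   # float16 benefits from GPU Tensor Cores
    return "float32"             # mixed precision rarely helps on CPUs

tf.keras.mixed_precision.set_global_policy(choose_mixed_policy())
```

Keeping this choice explicit in user code is exactly the point of omitting "mixed_auto": the user sees which policy governs their checkpoints and whether loss scaling is needed.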
`set_global_policy` requires the policy to be floating-point. That is, the policy’s name must be one of "float16", "bfloat16", "float32", "float64", "mixed_float16", or "mixed_bfloat16". The reason is that most layers do not support non-floating-point policies.
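For example (a minimal sketch; the exact error type is assumed to be ValueError):

```python
import tensorflow as tf

tf.keras.mixed_precision.set_global_policy("mixed_bfloat16")  # allowed: floating-point

# A non-floating-point policy is rejected, since most layers do not support it.
try:
    tf.keras.mixed_precision.set_global_policy("int32")
except ValueError as err:
    print(err)

tf.keras.mixed_precision.set_global_policy("float32")  # restore the default
```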
A warning will be issued if the global policy is set to "float16" or "bfloat16", as these policies typically result in substantially worse training results. Also, such policies are typically not useful for inference, as a model with float16 variables cannot load training checkpoints with float32 variables.
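Under the proposed behavior, setting such a policy still succeeds but produces a warning, roughly:

```python
import tensorflow as tf

# Allowed, but warns: pure float16 variables typically hurt training quality and
# cannot load checkpoints whose variables were saved in float32.
tf.keras.mixed_precision.set_global_policy("float16")

tf.keras.mixed_precision.set_global_policy("float32")  # back to the default
```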
@@ -808,6 +810,8 @@ To get the loss scale as a float32 tensor, a user calls `optimizer.loss_scale()`
Like all other optimizers, `LossScaleOptimizer.apply_gradients` expects gradients to have the same dtype as the variables. During mixed precision training, variables are float32 (although they are cast to float16), so `LossScaleOptimizer.apply_gradients` expects float32 gradients and will raise an error otherwise. `tf.GradientTape.gradient` will return float32 gradients when passed float32 AutoCastVariables, so this requirement typically causes no problems.
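A minimal custom-training-loop sketch under the proposed API (the model, shapes, and the explicit float32 cast of the model output are illustrative choices, not requirements of the API):

```python
import tensorflow as tf

tf.keras.mixed_precision.set_global_policy("mixed_float16")

model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(tf.keras.optimizers.SGD())

x = tf.random.normal((2, 8))
y = tf.random.normal((2, 4))

with tf.GradientTape() as tape:
    # The layer computes in float16, but its variables are float32 AutoCastVariables.
    predictions = tf.cast(model(x), tf.float32)
    loss = tf.reduce_mean(tf.square(predictions - y))
    scaled_loss = optimizer.get_scaled_loss(loss)

# Gradients w.r.t. the float32 AutoCastVariables come back as float32, which is
# the dtype apply_gradients expects; float16 gradients would raise an error.
scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
grads = optimizer.get_unscaled_gradients(scaled_grads)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
```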
+Instead of skipping steps when there are NaNs in the gradients, LossScaleOptimizer could alternatively lower the loss scale and recompute gradients in a loop until they are all finite. However, this can only be done in `minimize`, not `apply_gradients`, as the latter does not compute gradients but only applies them. In the future, an option may be added to LossScaleOptimizer to compute gradients in a loop when `minimize` is used.
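A conceptual sketch of that alternative, written with an explicit loss-scale value instead of the proposed LossScaleOptimizer (the `minimize_with_retries` name and the halving schedule are illustrative only):

```python
import tensorflow as tf

def minimize_with_retries(loss_fn, variables, optimizer, loss_scale):
    """Illustrative only: lower the loss scale and recompute until gradients are finite."""
    while True:
        with tf.GradientTape() as tape:
            scaled_loss = loss_fn() * loss_scale
        scaled_grads = tape.gradient(scaled_loss, variables)
        grads = [g / loss_scale for g in scaled_grads]
        if all(bool(tf.reduce_all(tf.math.is_finite(g))) for g in grads):
            optimizer.apply_gradients(zip(grads, variables))
            return loss_scale
        loss_scale /= 2.0  # lower the loss scale and retry instead of skipping the step
```

Note that the loop needs `loss_fn` so it can recompute the gradients, which is why this behavior could only live in `minimize` and not in `apply_gradients`.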
## LossScaleOptimizer `__getattribute__` and `__setattr__` delegation
Optimizer subclasses such as SGD add hyperparameters by calling the `_set_hyper()` method. The base Optimizer overrides `__getattribute__` and `__setattr__` so that hyperparameters can be accessed and set as attributes.
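For example (assuming the standard TF 2.x Optimizer behavior this section describes):

```python
import tensorflow as tf

opt = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)

# Hyperparameters registered via _set_hyper are exposed as attributes through the
# base Optimizer's __getattribute__/__setattr__ overrides.
print(opt.learning_rate)  # hyperparameter variable holding 0.1
opt.learning_rate = 0.05  # routed through __setattr__ to the hyperparameter
print(opt.momentum)       # 0.9
```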
