
Commit 0257b27

reedwm authored and tensorflower-gardener committed
Simplify LayerNorm mixed precision logic.
Instead of manually ensuring variables are float32, casting inputs to float32, etc., dtype="float32" is now passed to the layer constructor, which handles all of that logic automatically. The only difference is that the output of LayerNorm is now float32 instead of float16, so an extra cast is needed elsewhere.

PiperOrigin-RevId: 273833286
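For context, the sketch below is not part of this commit; it uses the built-in tf.keras.layers.LayerNormalization purely for illustration, assuming TF 2.x autocasting behavior. It shows what constructing a layer with dtype="float32" does when the surrounding model runs in float16: the layer's variables are created in float32, floating-point inputs are cast up to float32 before call() runs, and the caller casts the float32 output back down where needed.

import tensorflow as tf

# Illustration only: a Keras layer built with dtype="float32" keeps its
# variables in float32 and autocasts floating-point inputs to float32,
# so no manual casting is needed inside the layer itself.
layer_norm = tf.keras.layers.LayerNormalization(dtype="float32")

x = tf.random.uniform([2, 4, 8], dtype=tf.float16)  # fp16 activations
y = layer_norm(x)             # computed and returned in float32
y = tf.cast(y, tf.float16)    # the "extra cast needed elsewhere"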
1 parent 3980d2a commit 0257b27

File tree

1 file changed: 7 additions, 15 deletions


official/transformer/v2/transformer.py

Lines changed: 7 additions & 15 deletions
@@ -290,6 +290,7 @@ def symbols_to_logits_fn(ids, i, cache):

  def predict(self, encoder_outputs, encoder_decoder_attention_bias, training):
    """Return predicted sequence."""
+    encoder_outputs = tf.cast(encoder_outputs, self.params["dtype"])
    if self.params["padded_decode"]:
      batch_size = encoder_outputs.shape.as_list()[0]
      input_length = encoder_outputs.shape.as_list()[1]
@@ -356,27 +357,21 @@ class LayerNormalization(tf.keras.layers.Layer):
  """Applies layer normalization."""

  def __init__(self, hidden_size):
-    super(LayerNormalization, self).__init__()
+    # Pass dtype=float32, as we have not yet tested if layer norm is numerically
+    # stable in float16 and bfloat16.
+    super(LayerNormalization, self).__init__(dtype="float32")
    self.hidden_size = hidden_size

  def build(self, input_shape):
    """Builds the layer."""
-    # Passing experimental_autocast=False causes these variables to not be
-    # automatically casted to fp16 when mixed precision is used. Since we use
-    # float32 in call() for numeric stability, we do not want variables to be
-    # casted to fp16.
    self.scale = self.add_weight(
        "layer_norm_scale",
        shape=[self.hidden_size],
-        dtype="float32",
-        initializer=tf.ones_initializer(),
-        experimental_autocast=False)
+        initializer=tf.ones_initializer())
    self.bias = self.add_weight(
        "layer_norm_bias",
        shape=[self.hidden_size],
-        dtype="float32",
-        initializer=tf.zeros_initializer(),
-        experimental_autocast=False)
+        initializer=tf.zeros_initializer())
    super(LayerNormalization, self).build(input_shape)

  def get_config(self):
@@ -385,13 +380,10 @@ def get_config(self):
    }

  def call(self, x, epsilon=1e-6):
-    input_dtype = x.dtype
-    if input_dtype == tf.float16 or input_dtype == tf.bfloat16:
-      x = tf.cast(x, tf.float32)
    mean = tf.reduce_mean(x, axis=[-1], keepdims=True)
    variance = tf.reduce_mean(tf.square(x - mean), axis=[-1], keepdims=True)
    norm_x = (x - mean) * tf.math.rsqrt(variance + epsilon)
-    return tf.cast(norm_x * self.scale + self.bias, input_dtype)
+    return norm_x * self.scale + self.bias


class PrePostProcessingWrapper(tf.keras.layers.Layer):
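For reference, below is a reconstruction of the LayerNormalization class as it reads after this change, assembled from the hunks above (get_config is unchanged and omitted); the comment inside call() is added here, not in the commit, to spell out why the manual casts could go away.

class LayerNormalization(tf.keras.layers.Layer):
  """Applies layer normalization."""

  def __init__(self, hidden_size):
    # Pass dtype=float32, as we have not yet tested if layer norm is numerically
    # stable in float16 and bfloat16.
    super(LayerNormalization, self).__init__(dtype="float32")
    self.hidden_size = hidden_size

  def build(self, input_shape):
    """Builds the layer."""
    self.scale = self.add_weight(
        "layer_norm_scale",
        shape=[self.hidden_size],
        initializer=tf.ones_initializer())
    self.bias = self.add_weight(
        "layer_norm_bias",
        shape=[self.hidden_size],
        initializer=tf.zeros_initializer())
    super(LayerNormalization, self).build(input_shape)

  def call(self, x, epsilon=1e-6):
    # With dtype="float32" set on the layer, Keras casts x to float32 before it
    # reaches this method, so the math below runs in float32 without the
    # explicit casts the old code carried.
    mean = tf.reduce_mean(x, axis=[-1], keepdims=True)
    variance = tf.reduce_mean(tf.square(x - mean), axis=[-1], keepdims=True)
    norm_x = (x - mean) * tf.math.rsqrt(variance + epsilon)
    return norm_x * self.scale + self.bias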
