
Commit 3fa6560

BERT fp16 perf improvements: do the matmul in the intermediate layer in fp16, and remove the explicit casting to fp32 for layer norm.
PiperOrigin-RevId: 273379063
1 parent 4129326 · commit 3fa6560

official/nlp/bert_modeling.py

Lines changed: 7 additions & 10 deletions
--- a/official/nlp/bert_modeling.py
+++ b/official/nlp/bert_modeling.py
@@ -637,12 +637,14 @@ def __init__(self,
                kernel_initializer=None,
                bias_initializer="zeros",
                activation=None,
+               fp32_activation=False,
                **kwargs):
     super(Dense2DProjection, self).__init__(**kwargs)
     self.output_size = output_size
     self.kernel_initializer = kernel_initializer
     self.bias_initializer = bias_initializer
     self.activation = activation
+    self.fp32_activation = fp32_activation
 
   def build(self, input_shape):
     """Implements build() for the layer."""
@@ -685,6 +687,8 @@ def call(self, inputs):
     ret = tf.einsum("abc,cd->abd", inputs, self.kernel)
     ret += self.bias
     if self.activation is not None:
+      if self.dtype == tf.float16 and self.fp32_activation:
+        ret = tf.cast(ret, tf.float32)
       return self.activation(ret)
     return ret
 
@@ -753,7 +757,7 @@ def build(self, unused_input_shapes):
         kernel_initializer=get_initializer(self.initializer_range),
         activation=self.intermediate_activation,
         # Uses float32 so that gelu activation is done in float32.
-        dtype=tf.float32,
+        fp32_activation=True,
         name="intermediate")
     self.output_dense = Dense2DProjection(
         output_size=self.hidden_size,
@@ -788,23 +792,16 @@ def call(self, inputs):
     attention_output = self.attention_dropout(attention_output)
     # Use float32 in keras layer norm and the gelu activation in the
     # intermediate dense layer for numeric stability
-    # TODO(reedwm): These casts are probably unnecessary, as we passed
-    # dtype=tf.float32 to the layer norm constructor, so it will cast its inputs
-    # to float32 automatically. These manual casts additionally do the "+"
-    # operator in float32, but "+" is numerically stable in float16.
-    if self.float_type == tf.float16:
-      input_tensor = tf.cast(input_tensor, tf.float32)
-      attention_output = tf.cast(attention_output, tf.float32)
     attention_output = self.attention_layer_norm(input_tensor +
                                                  attention_output)
+    if self.float_type == tf.float16:
+      attention_output = tf.cast(attention_output, tf.float16)
     intermediate_output = self.intermediate_dense(attention_output)
     if self.float_type == tf.float16:
       intermediate_output = tf.cast(intermediate_output, tf.float16)
     layer_output = self.output_dense(intermediate_output)
     layer_output = self.output_dropout(layer_output)
     # Use float32 in keras layer norm for numeric stability
-    if self.float_type == tf.float16:
-      layer_output = tf.cast(layer_output, tf.float32)
     layer_output = self.output_layer_norm(layer_output + attention_output)
     if self.float_type == tf.float16:
       layer_output = tf.cast(layer_output, tf.float16)
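
As a reading aid, here is a minimal, hypothetical sketch of the pattern the diff introduces (assuming a recent TF 2.x with tf.keras and tf.nn.gelu; Dense2DProjectionSketch and the shapes below are illustrative, not code from the repository): the einsum and bias add stay in float16, and only the activation input is upcast to float32 when fp32_activation=True.

# Minimal sketch, not the repository's Dense2DProjection; names containing
# "Sketch" and the usage shapes are hypothetical.
import tensorflow as tf


class Dense2DProjectionSketch(tf.keras.layers.Layer):
  """Dense projection over 3D inputs with an optional float32 activation."""

  def __init__(self, output_size, activation=None, fp32_activation=False,
               **kwargs):
    super(Dense2DProjectionSketch, self).__init__(**kwargs)
    self.output_size = output_size
    self.activation = activation
    self.fp32_activation = fp32_activation

  def build(self, input_shape):
    hidden_size = int(input_shape[-1])
    # Weights are created in the layer dtype (float16 when dtype=tf.float16).
    self.kernel = self.add_weight(
        name="kernel", shape=[hidden_size, self.output_size],
        initializer="glorot_uniform")
    self.bias = self.add_weight(
        name="bias", shape=[self.output_size], initializer="zeros")
    super(Dense2DProjectionSketch, self).build(input_shape)

  def call(self, inputs):
    # The matmul and bias add run in the layer dtype (float16 here); this is
    # the perf win the commit message describes.
    ret = tf.einsum("abc,cd->abd", inputs, self.kernel)
    ret += self.bias
    if self.activation is not None:
      if self.dtype == tf.float16 and self.fp32_activation:
        # Upcast only the activation input (e.g. gelu) for numeric stability.
        ret = tf.cast(ret, tf.float32)
      return self.activation(ret)
    return ret


# Usage: intermediate-style projection under float16; gelu runs in float32.
layer = Dense2DProjectionSketch(
    output_size=3072, activation=tf.nn.gelu, fp32_activation=True,
    dtype=tf.float16)
x = tf.random.uniform([2, 8, 768], dtype=tf.float16)
y = layer(x)    # einsum and bias add in float16
print(y.dtype)  # float32, because the activation input was upcast

The removed casts around the layer norms follow the same logic: as the removed TODO notes, the Keras layer norm was constructed with dtype=tf.float32 and already casts float16 inputs up automatically, so the diff drops the manual pre-casts and keeps only a cast back to float16 after each norm so the following fp16 ops see fp16 inputs.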
