@@ -637,12 +637,14 @@ def __init__(self,
                kernel_initializer=None,
                bias_initializer="zeros",
                activation=None,
+               fp32_activation=False,
                **kwargs):
     super(Dense2DProjection, self).__init__(**kwargs)
     self.output_size = output_size
     self.kernel_initializer = kernel_initializer
     self.bias_initializer = bias_initializer
     self.activation = activation
+    self.fp32_activation = fp32_activation
 
   def build(self, input_shape):
     """Implements build() for the layer."""
@@ -685,6 +687,8 @@ def call(self, inputs):
     ret = tf.einsum("abc,cd->abd", inputs, self.kernel)
     ret += self.bias
     if self.activation is not None:
+      if self.dtype == tf.float16 and self.fp32_activation:
+        ret = tf.cast(ret, tf.float32)
       return self.activation(ret)
     return ret
 
@@ -753,7 +757,7 @@ def build(self, unused_input_shapes):
         kernel_initializer=get_initializer(self.initializer_range),
         activation=self.intermediate_activation,
         # Uses float32 so that gelu activation is done in float32.
-        dtype=tf.float32,
+        fp32_activation=True,
         name="intermediate")
     self.output_dense = Dense2DProjection(
         output_size=self.hidden_size,
@@ -788,23 +792,16 @@ def call(self, inputs):
     attention_output = self.attention_dropout(attention_output)
     # Use float32 in keras layer norm and the gelu activation in the
     # intermediate dense layer for numeric stability
-    # TODO(reedwm): These casts are probably unnecessary, as we passed
-    # dtype=tf.float32 to the layer norm constructor, so it will cast its inputs
-    # to float32 automatically. These manual casts additionally do the "+"
-    # operator in float32, but "+" is numerically stable in float16.
-    if self.float_type == tf.float16:
-      input_tensor = tf.cast(input_tensor, tf.float32)
-      attention_output = tf.cast(attention_output, tf.float32)
     attention_output = self.attention_layer_norm(input_tensor +
                                                  attention_output)
+    if self.float_type == tf.float16:
+      attention_output = tf.cast(attention_output, tf.float16)
     intermediate_output = self.intermediate_dense(attention_output)
     if self.float_type == tf.float16:
       intermediate_output = tf.cast(intermediate_output, tf.float16)
     layer_output = self.output_dense(intermediate_output)
     layer_output = self.output_dropout(layer_output)
     # Use float32 in keras layer norm for numeric stability
-    if self.float_type == tf.float16:
-      layer_output = tf.cast(layer_output, tf.float32)
     layer_output = self.output_layer_norm(layer_output + attention_output)
     if self.float_type == tf.float16:
       layer_output = tf.cast(layer_output, tf.float16)
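
For readers skimming the diff, the following is a minimal sketch of the fp32-activation pattern the change introduces. It is illustrative only: the free-standing function name and signature are invented for this example and are not the repository's Dense2DProjection API. The idea is that the einsum and bias add stay in the layer's compute dtype (float16 under mixed precision), and only the numerically sensitive activation (e.g. gelu) is promoted to float32.

```python
import tensorflow as tf

def dense_2d_projection(inputs, kernel, bias, activation=None,
                        fp32_activation=False):
  """Projects [batch, seq, in_dim] -> [batch, seq, out_dim].

  Optionally runs the activation in float32 when the math is in float16.
  (Hypothetical helper; mirrors the diff above, not the actual layer.)
  """
  ret = tf.einsum("abc,cd->abd", inputs, kernel)
  ret += bias
  if activation is not None:
    if ret.dtype == tf.float16 and fp32_activation:
      # Promote only the activation input to float32 for numeric stability.
      ret = tf.cast(ret, tf.float32)
    return activation(ret)
  return ret
```

This keeps the expensive matmul in float16 while avoiding float16 precision loss inside the activation; the caller (as in the transformer block above) is then responsible for casting the float32 result back to float16 before the next float16 op.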