@@ -345,6 +345,7 @@ def transformer_moe_layer_v1(
       group_size_dim.size,
       int((group_size_dim.size * capacity_factor) / experts_dim.size))
   expert_capacity = max(expert_capacity, 4)
+  tf.logging.info("expert_capacity: %d" % expert_capacity)
   expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity)
 
   experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size)
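Note on the value being logged above: expert_capacity is the per-group token budget for each expert, derived from the group size and the capacity factor and floored at 4. A minimal standalone sketch of the same arithmetic, using made-up numbers (group size, capacity factor, and expert count are hypothetical, not taken from any real config):

# Sketch of the expert_capacity computation with hypothetical values.
group_size = 1024        # group_size_dim.size in the diff
capacity_factor = 1.25   # capacity factor (hypothetical value)
num_experts = 64         # experts_dim.size

expert_capacity = min(group_size,
                      int((group_size * capacity_factor) / num_experts))
expert_capacity = max(expert_capacity, 4)
print(expert_capacity)   # 20: each expert accepts at most 20 tokens per group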
@@ -704,8 +705,6 @@ def _rand_1_gating(
     inputs, outer_expert_dims, experts_dim, expert_capacity_dim,
     hparams, train, variable_dtype, importance=None, name="rand_1_gating"):
   """Compute a random top-1 gating."""
-  del importance
-
   # SELECT EXPERT
   if train:
     policy = hparams.moe_rand_1_policy_train
@@ -724,6 +723,10 @@ def _rand_1_gating(
       name=name)
   raw_gates = mtf.softmax(gate_logits, reduced_dim=experts_dim)
 
+  # The internals of this function run in float32.
+  # bfloat16 seems to reduce quality.
+  raw_gates = mtf.to_float(raw_gates)
+
   if policy == "argmax" or policy == "input_dropout" or policy == "input_jitter":
     expert_gate, expert_index = mtf.top_1(raw_gates, reduced_dim=experts_dim)
   elif policy == "sample":
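One plausible reason the gating internals are kept in float32, sketched below: bfloat16 carries only 8 bits of mantissa, so gate probabilities that are distinct in float32 can collapse to the same value, turning the top-1 choice into an arbitrary tie-break. This is only an illustration, the values are made up, and it assumes the ml_dtypes package for a numpy-compatible bfloat16.

import numpy as np
import ml_dtypes  # assumed available; provides a numpy-compatible bfloat16 dtype

a = np.float32(0.4499)   # hypothetical gate values that are distinct in float32
b = np.float32(0.4502)
print(a == b)                                                         # False
print(a.astype(ml_dtypes.bfloat16) == b.astype(ml_dtypes.bfloat16))   # True: both round to the same bfloat16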
@@ -740,8 +743,14 @@ def _rand_1_gating(
   group_size_dim = inputs.shape[-2]
   density_1 = mtf.reduce_mean(expert_mask, reduced_dim=group_size_dim)
   density_1_proxy = mtf.reduce_mean(raw_gates, reduced_dim=group_size_dim)
-  loss = (mtf.reduce_mean(density_1_proxy * density_1)
-          * float(experts_dim.size * experts_dim.size))
+  if importance is not None:
+    expert_mask *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
+    expert_gate *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
+    density_1_proxy *= mtf.cast(
+        mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
+  loss = (
+      mtf.reduce_mean(density_1_proxy * density_1) *
+      float(experts_dim.size * experts_dim.size))
 
   # Logging
   if train:
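The new importance branch keeps padded positions (importance != 1.0) from contributing to the dispatch statistics or to the auxiliary balancing loss. Below is a simplified numpy sketch of that loss with padding excluded; it mirrors the intent rather than the exact mtf ordering, and all shapes and values are hypothetical.

import numpy as np

raw_gates = np.array([[0.9, 0.1],   # hypothetical softmax gates, 4 tokens x 2 experts
                      [0.8, 0.2],
                      [0.3, 0.7],
                      [0.5, 0.5]], dtype=np.float32)
importance = np.array([1.0, 1.0, 1.0, 0.0], dtype=np.float32)  # last token is padding
num_experts = raw_gates.shape[-1]

expert_mask = np.eye(num_experts, dtype=np.float32)[raw_gates.argmax(-1)]  # one-hot top-1 choice
keep = (importance == 1.0).astype(np.float32)[:, None]   # analogous to mtf.equal(importance, 1.0)
expert_mask *= keep

density_1 = expert_mask.mean(axis=0)               # fraction of tokens dispatched to each expert
density_1_proxy = (raw_gates * keep).mean(axis=0)  # mean gate probability per expert
loss = (density_1_proxy * density_1).mean() * float(num_experts * num_experts)
print(loss)   # 0.625 for these made-up numbers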
@@ -767,20 +776,25 @@ def _rand_1_gating(
   # the batch indices, to each expert, with position_in_expert
   position_in_expert = mtf.cumsum(
       expert_mask, group_size_dim, exclusive=True) * expert_mask
+  position_in_expert = mtf.cast(position_in_expert, dtype=raw_gates.dtype)
   # Keep only tokens that fit within expert_capacity.
   expert_capacity_float = float(expert_capacity_dim.size)
-  expert_mask *= mtf.to_float(mtf.less(position_in_expert,
-                                       expert_capacity_float))
+  expert_mask *= mtf.cast(
+      mtf.less(position_in_expert, expert_capacity_float),
+      dtype=raw_gates.dtype)
   expert_mask_flat = mtf.reduce_sum(expert_mask, reduced_dim=experts_dim)
 
   # Mask out the experts that have overflowed expert capacity. Sparsify the
   # expert_gate.
   expert_gate *= expert_mask_flat
 
   combine_tensor = (
-      expert_gate * expert_mask_flat
-      * mtf.one_hot(expert_index, experts_dim)
-      * mtf.one_hot(mtf.to_int32(position_in_expert), expert_capacity_dim))
+      expert_gate * expert_mask_flat *
+      mtf.one_hot(expert_index, experts_dim, dtype=raw_gates.dtype) *
+      mtf.one_hot(
+          mtf.to_int32(position_in_expert),
+          expert_capacity_dim,
+          dtype=raw_gates.dtype))
 
   # Match the inputs dtype.
   combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
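For readers following the bookkeeping above: the exclusive cumulative sum over the group dimension assigns each routed token a slot index within its expert, and the capacity mask then drops any token whose slot index reaches expert_capacity; combine_tensor finally scatters the surviving gate values into (expert, slot) cells. A minimal numpy sketch of that slotting for a single expert column, with a hypothetical routing pattern:

import numpy as np

expert_mask = np.array([1, 0, 1, 1, 0, 1, 1, 1], dtype=np.float32)  # 1 = routed to this expert
expert_capacity = 3

# Exclusive cumsum times the mask: the k-th routed token gets slot index k (0-based);
# unrouted tokens keep slot 0 but are already zeroed by the mask.
position_in_expert = (np.cumsum(expert_mask) - expert_mask) * expert_mask
print(position_in_expert)   # [0. 0. 1. 2. 0. 3. 4. 5.]

# Drop tokens whose slot index does not fit within expert_capacity.
expert_mask *= (position_in_expert < expert_capacity).astype(np.float32)
print(expert_mask)          # [1. 0. 1. 1. 0. 0. 0. 0.]  -> overflow tokens are not dispatched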