@@ -52,17 +52,20 @@ def __init__(self,
                dropout_rate=0.0,
                activation="relu",
                moe_gating="top_2",
+               min_expert_capacity=4,
                rand_1_policy_train="input_jitter",
                rand_1_policy_eval="input_jitter",
                rand_1_dropout=0.1,
                rand_1_temperature=1.0,
-               rand_1_jitter=1e-2):
+               rand_1_jitter=1e-2,
+               switch_top_k=4):
     self._hparams = HParams(
         moe_gating=moe_gating,
         moe_num_experts=num_experts,
         moe_loss_coef=loss_coef,
         moe_hidden_size=hidden_size,
         moe_group_size=group_size,
+        moe_min_expert_capacity=min_expert_capacity,
         moe_capacity_factor_train=capacity_factor_train,
         moe_capacity_factor_eval=capacity_factor_eval,
         moe_use_second_place_loss=use_second_place_loss,
@@ -75,7 +78,8 @@ def __init__(self,
         moe_rand_1_policy_eval=rand_1_policy_eval,
         moe_rand_1_dropout=rand_1_dropout,
         moe_rand_1_temperature=rand_1_temperature,
-        moe_rand_1_jitter=rand_1_jitter)
+        moe_rand_1_jitter=rand_1_jitter,
+        moe_switch_top_k=switch_top_k)
     self._activation = activation
 
   def call(self, context, x, losses=None):
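The two new constructor arguments thread straight into the hparams above: min_expert_capacity replaces the hard-coded floor of 4 on expert_capacity, and switch_top_k sets how many routing passes the new "switch" gating makes. A minimal usage sketch, assuming this constructor belongs to mesh_tensorflow's MoE1D layer (the class name is not visible in this hunk) and using illustrative values:

    from mesh_tensorflow.transformer import moe

    # Hypothetical configuration exercising the new arguments.
    moe_layer = moe.MoE1D(
        num_experts=32,
        hidden_size=4096,
        moe_gating="switch",       # take the new _switch_gating path
        min_expert_capacity=4,     # floor applied to expert_capacity
        switch_top_k=4)            # number of no-token-left-behind passes
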
@@ -344,10 +348,9 @@ def transformer_moe_layer_v1(
   expert_capacity = min(
       group_size_dim.size,
       int((group_size_dim.size * capacity_factor) / experts_dim.size))
-  expert_capacity = max(expert_capacity, 4)
+  expert_capacity = max(expert_capacity, hparams.moe_min_expert_capacity)
   tf.logging.info("expert_capacity: %d" % expert_capacity)
   expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity)
-
   experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size)
   batch_dim_unsplit = mtf.Dimension("batch_unsplit", num_groups_dim.size)
   if nonpadding is not None:
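For reference, expert_capacity above works out to roughly group_size * capacity_factor / num_experts tokens per expert, clipped to the group size and now floored at hparams.moe_min_expert_capacity instead of the previous hard-coded 4. A worked example with illustrative values:

    group_size = 1024         # group_size_dim.size
    capacity_factor = 1.25    # hparams.moe_capacity_factor_train
    num_experts = 32          # experts_dim.size
    expert_capacity = min(group_size,
                          int((group_size * capacity_factor) / num_experts))  # -> 40
    expert_capacity = max(expert_capacity, 4)  # min_expert_capacity floor; still 40
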
@@ -377,6 +380,16 @@ def transformer_moe_layer_v1(
         train=train,
         variable_dtype=variable_dtype,
         importance=nonpadding)
+  elif hparams.moe_gating == "switch":
+    dispatch_tensor, combine_tensor, loss = _switch_gating(
+        inputs=inputs,
+        outer_expert_dims=None,
+        experts_dim=experts_dim_unsplit,
+        expert_capacity_dim=expert_capacity_dim,
+        hparams=hparams,
+        train=train,
+        variable_dtype=variable_dtype,
+        importance=nonpadding)
   else:
     raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)
 
@@ -571,7 +584,7 @@ def transformer_moe_layer_v2(
   else:
     capacity_factor = hparams.moe_capacity_factor_eval
   expert_capacity = min(s.size, int((s.size * capacity_factor) / x.size))
-  expert_capacity = max(expert_capacity, 4)
+  expert_capacity = max(expert_capacity, hparams.moe_min_expert_capacity)
   c = mtf.Dimension("expert_capacity_x", expert_capacity)
 
   # We "cheat" here and look at the mesh shape and layout. This is to ensure
@@ -588,7 +601,7 @@ def transformer_moe_layer_v2(
   expert_capacity = min(
       t.size,
       int((t.size * hparams.moe_capacity_factor_second_level) / y.size))
-  expert_capacity = max(expert_capacity, 4)
+  expert_capacity = max(expert_capacity, hparams.moe_min_expert_capacity)
   d = mtf.Dimension("expert_capacity_y", expert_capacity)
 
   # First level of expert routing
@@ -701,6 +714,140 @@ def transformer_moe_layer_v2(
   return output, (loss_outer + loss_inner) * hparams.moe_loss_coef
 
 
+def _switch_gating(inputs,
+                   outer_expert_dims,
+                   experts_dim,
+                   expert_capacity_dim,
+                   hparams,
+                   train,
+                   variable_dtype,
+                   importance=None,
+                   name="switch_gating"):
726+ """Compute a switch top-1 gating with no-token-left behind behavior."""
+  # SELECT EXPERT
+  if train:
+    policy = hparams.moe_rand_1_policy_train
+  else:
+    policy = hparams.moe_rand_1_policy_eval
+
+  # Input perturbations
+  if train and policy == "input_jitter":
+    inputs = mtf.layers.multiplicative_jitter(inputs, hparams.moe_rand_1_jitter)
+
+  gate_logits = mtf.layers.dense(
+      inputs,
+      experts_dim,
+      use_bias=False,
+      expert_dims=outer_expert_dims,
+      variable_dtype=variable_dtype,
+      name=name)
+  raw_gates = mtf.softmax(gate_logits, reduced_dim=experts_dim)
+
+  # The internals of this function run in float32.
+  # bfloat16 seems to reduce quality.
+  raw_gates = mtf.to_float(raw_gates)
+
+  # Top-k operation
+  k_dim = mtf.Dimension("k", hparams.moe_switch_top_k)
+  expert_gate, expert_index = mtf.top_k(
+      raw_gates, reduced_dim=experts_dim, k_dim=k_dim)
+  expert_mask = mtf.one_hot(expert_index, experts_dim)
+
+  # LOAD BALANCING LOSS
+  outer_batch_dim = inputs.shape[0]
+  batch_dim = inputs.shape[1]
+  group_size_dim = inputs.shape[-2]
+  density_1 = mtf.reduce_mean(expert_mask, reduced_dim=group_size_dim)
+  density_1_proxy = mtf.reduce_mean(raw_gates, reduced_dim=group_size_dim)
+  if importance is not None:
+    expert_mask *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
+    expert_gate *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
+    density_1_proxy *= mtf.cast(
+        mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
+  loss = (
+      mtf.reduce_mean(density_1_proxy * density_1) *
+      float(experts_dim.size * experts_dim.size))
+
+  # Logging
+  if train:
+    entropy = mtf.reduce_sum(
+        -raw_gates * mtf.log(raw_gates + 1e-9), reduced_dim=experts_dim)
+    batch_entropy = mtf.reduce_mean(entropy)
+    mtf.scalar_summary(name + "/entropy", batch_entropy)
+
+    mask_count_experts = mtf.reduce_sum(expert_mask, output_shape=[experts_dim])
+    total_routed = mtf.reduce_sum(mask_count_experts)
+    expert_fraction = mtf.to_float(mask_count_experts / total_routed)
+    split_fractions = mtf.split(
+        expert_fraction,
+        split_dim=experts_dim,
+        num_or_size_splits=experts_dim.size)
+    for fraction in split_fractions:
+      mtf.scalar_summary("experts/" + fraction.name.replace(":", "/"),
+                         mtf.reduce_mean(fraction))
+    mtf.scalar_summary("aux_loss", mtf.reduce_mean(loss))
+
+  # COMPUTE ASSIGNMENT TO EXPERT
+  # Iteratively route tokens (no-token-left-behind). The idea is to route as
+  # many tokens as possible to top-i before then trying top-(i+1).
+  top_k_masks = mtf.split(
+      expert_mask, split_dim=k_dim, num_or_size_splits=k_dim.size)
+  top_k_gates = mtf.split(
+      expert_gate, split_dim=k_dim, num_or_size_splits=k_dim.size)
+  top_k_indices = mtf.split(
+      expert_index, split_dim=k_dim, num_or_size_splits=k_dim.size)
+
+  # Tensors accumulating values over the iterative routing process.
+  combine_tensor = mtf.constant(
+      inputs.mesh,
+      value=0,
+      shape=[outer_batch_dim, batch_dim, experts_dim, expert_capacity_dim])
+  cum_tokens = mtf.constant(
+      inputs.mesh, value=0, shape=[outer_batch_dim, batch_dim, experts_dim])
+  tokens_left_to_route = mtf.constant(
+      inputs.mesh, value=1., shape=[outer_batch_dim, batch_dim, group_size_dim])
+
+  expert_capacity_float = float(expert_capacity_dim.size)
+  for (top_i_mask, top_i_gate, top_i_index) in zip(top_k_masks, top_k_gates,
+                                                   top_k_indices):
+    top_i_mask = mtf.reshape(
+        top_i_mask,
+        new_shape=[outer_batch_dim, batch_dim, group_size_dim, experts_dim])
+    # Operate only on the unrouted tokens.
+    top_i_mask *= tokens_left_to_route
+
+    # Record cumulative number of tokens to each expert across iterations.
+    cumulative_tokens_in_expert = cum_tokens + mtf.cumsum(
+        top_i_mask, group_size_dim)
+
+    expert_overflow = mtf.to_float(
+        mtf.less_equal(cumulative_tokens_in_expert, expert_capacity_float))
+    output_i_tokens = top_i_mask * expert_overflow
+
+    # Update the cumulative tokens routed to each expert.
+    cum_tokens += mtf.reduce_sum(output_i_tokens, reduced_dim=group_size_dim)
+    tokens_left_to_route -= (
+        mtf.reduce_sum(output_i_tokens, reduced_dim=experts_dim))
+
+    # Combine-tensor for this iteration
+    output_i_tokens_flat = mtf.reduce_sum(
+        output_i_tokens, reduced_dim=experts_dim)
+    position_in_expert = cumulative_tokens_in_expert - 1
+    top_i_combine_tensor = (
+        top_i_gate * output_i_tokens_flat *
+        mtf.one_hot(top_i_index, experts_dim) *
+        mtf.one_hot(mtf.to_int32(position_in_expert), expert_capacity_dim))
+    combine_tensor += top_i_combine_tensor
+
+  # Match the inputs dtype.
+  combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
+  loss = mtf.cast(loss, inputs.dtype)
+  dispatch_tensor = mtf.cast(
+      mtf.cast(combine_tensor, tf.bool), combine_tensor.dtype)
+
+  return dispatch_tensor, combine_tensor, loss
+
+
 def _rand_1_gating(
     inputs, outer_expert_dims, experts_dim, expert_capacity_dim,
     hparams, train, variable_dtype, importance=None, name="rand_1_gating"):
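The heart of _switch_gating is the iterative no-token-left-behind assignment: on pass i, every still-unrouted token is sent to its i-th-choice expert if that expert has spare capacity, its gate value is written into combine_tensor at its position within that expert, and only then does the next pass consider the leftovers. Below is a minimal NumPy sketch of that loop for a single group, with hypothetical names (switch_route is not part of this change) and without the outer_batch/batch dimensions or the load-balancing loss:

    import numpy as np

    def switch_route(raw_gates, expert_capacity, top_k):
      """raw_gates: [group_size, num_experts] softmax gate probabilities."""
      group_size, num_experts = raw_gates.shape
      top_idx = np.argsort(-raw_gates, axis=-1)[:, :top_k]      # per-token choices, best first
      combine = np.zeros((group_size, num_experts, expert_capacity))
      tokens_in_expert = np.zeros(num_experts, dtype=np.int64)  # plays the role of cum_tokens
      unrouted = np.ones(group_size, dtype=bool)                # plays the role of tokens_left_to_route
      for i in range(top_k):                                    # try top-1, then top-2, ...
        for t in range(group_size):                             # scan the group in order (cf. mtf.cumsum)
          e = top_idx[t, i]
          if unrouted[t] and tokens_in_expert[e] < expert_capacity:
            pos = tokens_in_expert[e]                           # position_in_expert
            combine[t, e, pos] = raw_gates[t, e]                # gate-weighted combine entry
            tokens_in_expert[e] += 1
            unrouted[t] = False
      dispatch = (combine > 0).astype(combine.dtype)            # boolean dispatch mask
      return dispatch, combine

    # Example: 8 tokens, 4 experts, capacity 2, two routing passes.
    rng = np.random.default_rng(0)
    logits = rng.normal(size=(8, 4))
    gates = np.exp(logits) / np.exp(logits).sum(-1, keepdims=True)
    dispatch_np, combine_np = switch_route(gates, expert_capacity=2, top_k=2)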