@@ -53,12 +53,12 @@ def __init__(self,
5353 activation = "relu" ,
5454 moe_gating = "top_2" ,
5555 min_expert_capacity = 4 ,
56- rand_1_policy_train = "input_jitter" ,
57- rand_1_policy_eval = "input_jitter" ,
58- rand_1_dropout = 0.1 ,
59- rand_1_temperature = 1.0 ,
60- rand_1_jitter = 1e-2 ,
61- switch_top_k = 4 ,
56+ switch_policy_train = "input_jitter" ,
57+ switch_policy_eval = "input_jitter" ,
58+ switch_dropout = 0.1 ,
59+ switch_temperature = 1.0 ,
60+ switch_jitter = 1e-2 ,
61+ ntlb_top_k = 4 ,
6262 output_dim = None ,
6363 use_experts_attention = False ):
6464 self ._hparams = HParams (
@@ -76,13 +76,13 @@ def __init__(self,
7676 moe_second_threshold_train = second_threshold_train ,
7777 moe_second_threshold_eval = second_threshold_eval ,
7878 moe_dropout_rate = dropout_rate ,
79- moe_rand_1_policy_train = rand_1_policy_train ,
80- moe_rand_1_policy_eval = rand_1_policy_eval ,
81- moe_rand_1_dropout = rand_1_dropout ,
82- moe_rand_1_temperature = rand_1_temperature ,
83- moe_rand_1_jitter = rand_1_jitter ,
79+ moe_switch_policy_train = switch_policy_train ,
80+ moe_switch_policy_eval = switch_policy_eval ,
81+ moe_switch_dropout = switch_dropout ,
82+ moe_switch_temperature = switch_temperature ,
83+ moe_switch_jitter = switch_jitter ,
8484 moe_output_dim = output_dim ,
85- moe_switch_top_k = switch_top_k ,
85+ moe_ntlb_top_k = ntlb_top_k ,
8686 moe_use_experts_attention = use_experts_attention )
8787 self ._activation = activation
8888
@@ -389,8 +389,8 @@ def transformer_moe_layer_v1(
389389 variable_dtype = variable_dtype ,
390390 importance = nonpadding ,
391391 num_microbatches = num_microbatches )
392- elif hparams .moe_gating == "rand_1 " :
393- dispatch_tensor , combine_tensor , loss = _rand_1_gating (
392+ elif hparams .moe_gating == "switch " :
393+ dispatch_tensor , combine_tensor , loss = _switch_gating (
394394 inputs = inputs ,
395395 outer_expert_dims = None ,
396396 experts_dim = experts_dim_unsplit ,
@@ -400,8 +400,8 @@ def transformer_moe_layer_v1(
400400 variable_dtype = variable_dtype ,
401401 importance = nonpadding ,
402402 num_microbatches = num_microbatches )
403- elif hparams .moe_gating == "switch " :
404- dispatch_tensor , combine_tensor , loss = _switch_gating (
403+ elif hparams .moe_gating == "ntlb " :
404+ dispatch_tensor , combine_tensor , loss = _ntlb_gating (
405405 inputs = inputs ,
406406 outer_expert_dims = None ,
407407 experts_dim = experts_dim_unsplit ,
@@ -774,26 +774,26 @@ def transformer_moe_layer_v2(
774774 return output , (loss_outer + loss_inner ) * hparams .moe_loss_coef
775775
776776
777- def _switch_gating (inputs ,
778- outer_expert_dims ,
779- experts_dim ,
780- expert_capacity_dim ,
781- hparams ,
782- train ,
783- variable_dtype ,
784- importance = None ,
785- name = "switch_gating " ,
786- num_microbatches = None ):
787- """Compute a switch top-1 gating with no-token-left behind behavior."""
777+ def _ntlb_gating (inputs ,
778+ outer_expert_dims ,
779+ experts_dim ,
780+ expert_capacity_dim ,
781+ hparams ,
782+ train ,
783+ variable_dtype ,
784+ importance = None ,
785+ name = "ntlb_gating " ,
786+ num_microbatches = None ):
787+ """Compute Switch gating with no-token-left-behind (NTLB) behavior."""
788788 # SELECT EXPERT
789789 if train :
790- policy = hparams .moe_rand_1_policy_train
790+ policy = hparams .moe_switch_policy_train
791791 else :
792- policy = hparams .moe_rand_1_policy_eval
792+ policy = hparams .moe_switch_policy_eval
793793
794794 # Input perturbations
795795 if train and policy == "input_jitter" :
796- inputs = mtf .layers .multiplicative_jitter (inputs , hparams .moe_rand_1_jitter )
796+ inputs = mtf .layers .multiplicative_jitter (inputs , hparams .moe_switch_jitter )
797797
798798 gate_logits = mtf .layers .dense (
799799 inputs ,
@@ -809,7 +809,7 @@ def _switch_gating(inputs,
809809 raw_gates = mtf .to_float (raw_gates )
810810
811811 # Top-k operation
812- k_dim = mtf .Dimension ("k" , hparams .moe_switch_top_k )
812+ k_dim = mtf .Dimension ("k" , hparams .moe_ntlb_top_k )
813813 expert_gate , expert_index = mtf .top_k (
814814 raw_gates , reduced_dim = experts_dim , k_dim = k_dim )
815815 expert_mask = mtf .one_hot (expert_index , experts_dim )
@@ -913,27 +913,27 @@ def _switch_gating(inputs,
913913 return dispatch_tensor , combine_tensor , loss
914914
915915
916- def _rand_1_gating (
916+ def _switch_gating (
917917 inputs , outer_expert_dims , experts_dim , expert_capacity_dim ,
918- hparams , train , variable_dtype , importance = None , name = "rand_1_gating " ,
918+ hparams , train , variable_dtype , importance = None , name = "switch_gating " ,
919919 num_microbatches = None ):
920- """Compute a random top-1 gating."""
920+ """Compute Switch gating."""
921921 # SELECT EXPERT
922922 if train :
923- policy = hparams .moe_rand_1_policy_train
923+ policy = hparams .moe_switch_policy_train
924924 else :
925- policy = hparams .moe_rand_1_policy_eval
925+ policy = hparams .moe_switch_policy_eval
926926
927927 # The internals of this function run in float32.
928928 # bfloat16 seems to reduce quality.
929929 gate_inputs = mtf .to_float (inputs )
930930
931931 # Input perturbations
932932 if train and policy == "input_dropout" :
933- gate_inputs = mtf .dropout (gate_inputs , 1.0 - hparams .moe_rand_1_dropout )
933+ gate_inputs = mtf .dropout (gate_inputs , 1.0 - hparams .moe_switch_dropout )
934934 elif train and policy == "input_jitter" :
935935 gate_inputs = mtf .layers .multiplicative_jitter (gate_inputs ,
936- hparams .moe_rand_1_jitter )
936+ hparams .moe_switch_jitter )
937937
938938 gate_logits = mtf .layers .dense (
939939 gate_inputs ,
@@ -950,15 +950,14 @@ def _rand_1_gating(
950950 mtf .scalar_summary ("expert_gate" , mtf .reduce_mean (expert_gate ))
951951 elif policy == "sample" :
952952 expert_index = mtf .sample_with_temperature (
953- gate_logits , experts_dim , temperature = hparams .moe_rand_1_temperature )
953+ gate_logits , experts_dim , temperature = hparams .moe_switch_temperature )
954954 expert_gate = mtf .gather (raw_gates , expert_index , dim = experts_dim )
955955 else :
956- raise ValueError ("Unknown rand_1 policy %s" % policy )
956+ raise ValueError ("Unknown Switch gating policy %s" % policy )
957957
958958 expert_mask = mtf .one_hot (expert_index , experts_dim , dtype = raw_gates .dtype )
959959
960960 # LOAD BALANCING LOSS
961- # TODO(liamfedus): Check entropy loss.
962961 group_size_dim = inputs .shape [- 2 ]
963962 density_1 = mtf .reduce_mean (expert_mask , reduced_dim = group_size_dim )
964963 density_1_proxy = mtf .reduce_mean (raw_gates , reduced_dim = group_size_dim )