@@ -424,6 +424,19 @@ def transformer_moe_layer_v1(
424424 variable_dtype = variable_dtype ,
425425 importance = nonpadding ,
426426 num_microbatches = num_microbatches )
427+ elif hparams .moe_gating == "expert_selection" :
428+ dispatch_tensor , combine_tensor , loss = _expert_selection_gating (
429+ inputs = inputs ,
430+ outer_expert_dims = None ,
431+ experts_dim = experts_dim_unsplit ,
432+ group_size_dim = group_size_dim ,
433+ expert_capacity_dim = expert_capacity_dim ,
434+ hparams = hparams ,
435+ train = train ,
436+ variable_dtype = variable_dtype ,
437+ importance = nonpadding ,
438+ name = "expert_selection_gating" ,
439+ num_microbatches = num_microbatches )
427440 else :
428441 raise ValueError ("unknown hparams.moe_gating=%s" % hparams .moe_gating )
429442
@@ -1038,6 +1051,124 @@ def _switch_max_gating(
10381051 return dispatch_tensor , combine_tensor , loss
10391052
10401053
def _expert_selection_gating(
    inputs, outer_expert_dims, experts_dim, group_size_dim,
    expert_capacity_dim, hparams, train, variable_dtype, importance=None,
    name="expert_selection_gating", num_microbatches=None,
    normalize_by_num_experts_routed=True):
  """Compute gating where each expert chooses which tokens it wants.

  Instead of tokens choosing their top experts, each expert takes the softmax
  over all tokens in the group and selects its top `expert_capacity` tokens,
  so every expert is filled to capacity by construction.

  Args:
    inputs: mtf.Tensor of tokens, including a `group_size_dim` dimension.
    outer_expert_dims: optional list of extra expert mtf.Dimensions for the
      gate projection weights (None for the common case).
    experts_dim: mtf.Dimension of (unsplit) experts.
    group_size_dim: mtf.Dimension of tokens per group.
    expert_capacity_dim: mtf.Dimension of tokens each expert selects.
    hparams: hyperparameters; reads moe_switch_policy_train/eval,
      moe_switch_dropout, moe_switch_jitter.
    train: bool, True in training mode (enables exploration noise and
      summaries).
    variable_dtype: mtf.VariableDType for the gate projection.
    importance: optional 0/1 mtf.Tensor over `group_size_dim` marking
      non-padding tokens; None means every token counts.
    name: variable-scope name for the gate projection.
    num_microbatches: if set and > 1, the load-balance loss is divided by it.
    normalize_by_num_experts_routed: if True, divide each token's combined
      output by the number of experts that selected it, so aggregation is an
      average rather than a sum.

  Returns:
    Tuple of (dispatch_tensor, combine_tensor, loss), all cast to
    inputs.dtype.
  """
  # Select the input-perturbation (exploration) policy.
  if train:
    policy = hparams.moe_switch_policy_train
  else:
    policy = hparams.moe_switch_policy_eval

  # The internals of this function run in float32, otherwise instabilities
  # can occur.
  gate_inputs = mtf.to_float(inputs)

  # Input perturbations for exploration (training only).
  if train and policy == "input_dropout":
    gate_inputs = mtf.dropout(gate_inputs, 1.0 - hparams.moe_switch_dropout)
  elif train and policy == "input_jitter":
    gate_inputs = mtf.layers.multiplicative_jitter(gate_inputs,
                                                   hparams.moe_switch_jitter)

  # Compute expert logits for each token.
  # gate_logits shape: [outer_batch, batch, group, expert_unsplit]
  gate_logits = mtf.layers.dense(
      gate_inputs,
      experts_dim,
      use_bias=False,
      expert_dims=outer_expert_dims,
      variable_dtype=variable_dtype,
      name=name)

  # Push padding tokens to -inf before the softmax, since the softmax is
  # normalized over all tokens in the group.
  if importance is not None:
    gate_logits += mtf.cast(
        mtf.equal(importance, 0.0), dtype=gate_logits.dtype) * -1e9
  raw_gates = mtf.softmax(gate_logits, reduced_dim=group_size_dim)

  # Each expert picks its top `expert_capacity` tokens within the group.
  # expert_gate_probs shape:
  # [outer_batch, batch, expert_unsplit, expert_capacity]
  # expert_gate_indices shape:
  # [outer_batch, batch, expert_unsplit, expert_capacity]
  expert_gate_probs, expert_gate_indices = mtf.top_k(
      raw_gates, reduced_dim=group_size_dim, k_dim=expert_capacity_dim)

  # dispatch_tensor shape:
  # [outer_batch, batch, expert_unsplit, expert_capacity, group]
  dispatch_tensor = mtf.one_hot(
      expert_gate_indices, group_size_dim, dtype=raw_gates.dtype)

  # combine_tensor shape:
  # [outer_batch, batch, expert_unsplit, expert_capacity, group]
  combine_tensor = dispatch_tensor * expert_gate_probs

  # A token may be selected by many experts and its outputs will be summed,
  # not normalized. Optionally normalize by the number of experts each token
  # is sent to (clamped to >= 1 to avoid dividing by zero for unrouted
  # tokens).
  if normalize_by_num_experts_routed:
    num_experts_routed = mtf.reduce_sum(
        dispatch_tensor,
        output_shape=(dispatch_tensor.shape[:2] + [group_size_dim]))
    combine_tensor /= mtf.maximum(num_experts_routed, 1.0)

  ################### Compute the load balancing loss ###################
  # Push `aggregated_group_probs` of size `group` (which sums to num_experts)
  # to be uniform.
  # aggregated_group_probs shape: [outer_batch, batch, group]
  # importance shape: [outer_batch, batch, group]
  aggregated_group_probs = mtf.reduce_mean(raw_gates, reduced_dim=experts_dim)
  if importance is not None:
    aggregated_group_probs *= mtf.cast(
        mtf.equal(importance, 1.0), dtype=raw_gates.dtype)

  # Scale loss by group_size to keep loss constant across different
  # group_sizes. true_group_size is the number of tokens per group that are
  # not masked out.
  if importance is None:
    # BUGFIX: the original reduced `importance` unconditionally, which raises
    # when importance is None (an allowed default — see the guards above).
    # With no mask, every token in the group is real.
    true_group_size = float(group_size_dim.size)
  else:
    true_group_size = mtf.cast(
        mtf.reduce_sum(importance, reduced_dim=group_size_dim),
        dtype=raw_gates.dtype)
  loss = (mtf.reduce_mean(
      aggregated_group_probs * aggregated_group_probs * true_group_size) *
          float(group_size_dim.size))

  if num_microbatches and num_microbatches > 1:
    tf.logging.info("Dividing load-balance loss by num_microbatches={}".format(
        num_microbatches))
    loss /= num_microbatches

  ################### Logging ###################
  if train:
    # Entropy of the per-expert distribution over tokens; low entropy means
    # experts concentrate on few tokens.
    entropy = mtf.reduce_sum(-raw_gates * mtf.log(raw_gates + 1e-9),
                             reduced_dim=group_size_dim)
    batch_entropy = mtf.reduce_mean(entropy)
    mtf.scalar_summary(name + "/entropy", batch_entropy)

    # Log for each token position in the group how many experts it gets sent
    # to. NOTE(review): the `* experts * capacity` rescaling looks calibrated
    # for a reduce_mean over those dims rather than reduce_sum — confirm the
    # intended scale (logging only, does not affect routing).
    num_experts_sent_per_token = (
        mtf.reduce_sum(dispatch_tensor, output_shape=[group_size_dim]) *
        float(experts_dim.size * expert_capacity_dim.size))
    split_fractions = mtf.split(
        num_experts_sent_per_token,
        split_dim=group_size_dim,
        num_or_size_splits=group_size_dim.size)
    for fraction in split_fractions:
      mtf.scalar_summary("group_token/" + fraction.name.replace(":", "/"),
                         mtf.reduce_sum(fraction))
    mtf.scalar_summary("aux_loss", mtf.reduce_mean(loss))

  #################### Match the inputs dtype ###################
  combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
  loss = mtf.cast(loss, inputs.dtype)
  # Booleanize first so dispatch entries are exactly 0/1 in the output dtype.
  dispatch_tensor = mtf.cast(
      mtf.cast(dispatch_tensor, tf.bool), combine_tensor.dtype)

  return dispatch_tensor, combine_tensor, loss
1170+
1171+
10411172def _switch_gating (
10421173 inputs , outer_expert_dims , experts_dim , expert_capacity_dim ,
10431174 hparams , train , variable_dtype , importance = None , name = "switch_gating" ,
0 commit comments