This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit 3dc48f5

Author: Mesh TensorFlow Team

Add new routing method where each expert chooses which tokens it wants. A token can be chosen multiple times across different experts.

PiperOrigin-RevId: 368316075

1 parent 90f9edc commit 3dc48f5
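
For intuition, here is a minimal NumPy sketch of the routing idea in this commit: gate scores are softmaxed over the tokens of a group, each expert picks its top-capacity tokens, and a token's combine weights are optionally divided by the number of experts that chose it. The helper name, the single [group, experts] layout, and the Python loop are illustrative assumptions, not the Mesh TensorFlow implementation in the diff below.

    # Minimal sketch of expert-selection ("expert choice") routing for one
    # group, in plain NumPy. Hypothetical helper, not the mtf code below.
    import numpy as np

    def expert_selection_route(gate_logits, capacity, normalize=True):
      """gate_logits: [group, experts]. Each expert takes its top-`capacity` tokens."""
      # Softmax over the *group* dimension: each expert holds a distribution
      # over tokens (mirrors reduced_dim=group_size_dim in the diff).
      z = gate_logits - gate_logits.max(axis=0, keepdims=True)
      gates = np.exp(z) / np.exp(z).sum(axis=0, keepdims=True)
      group, experts = gates.shape
      dispatch = np.zeros((experts, capacity, group))
      combine = np.zeros((experts, capacity, group))
      for e in range(experts):
        # Each expert independently picks its top-`capacity` tokens, so a
        # token can be picked by several experts, or by none.
        for slot, t in enumerate(np.argsort(-gates[:, e])[:capacity]):
          dispatch[e, slot, t] = 1.0
          combine[e, slot, t] = gates[t, e]
      if normalize:
        # Mirrors normalize_by_num_experts_routed: divide each token's
        # combine weights by how many experts selected it (min 1, no 0/0).
        combine /= np.maximum(dispatch.sum(axis=(0, 1)), 1.0)
      return dispatch, combine

    dispatch, combine = expert_selection_route(
        np.random.default_rng(0).normal(size=(8, 4)), capacity=2)
    print(dispatch.sum(axis=(0, 1)))  # experts chosen per token: 0, 1, or more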

File tree

  • mesh_tensorflow/transformer/moe.py

1 file changed: +131 −0 lines changed

mesh_tensorflow/transformer/moe.py (131 additions, 0 deletions)
@@ -424,6 +424,19 @@ def transformer_moe_layer_v1(
         variable_dtype=variable_dtype,
         importance=nonpadding,
         num_microbatches=num_microbatches)
+  elif hparams.moe_gating == "expert_selection":
+    dispatch_tensor, combine_tensor, loss = _expert_selection_gating(
+        inputs=inputs,
+        outer_expert_dims=None,
+        experts_dim=experts_dim_unsplit,
+        group_size_dim=group_size_dim,
+        expert_capacity_dim=expert_capacity_dim,
+        hparams=hparams,
+        train=train,
+        variable_dtype=variable_dtype,
+        importance=nonpadding,
+        name="expert_selection_gating",
+        num_microbatches=num_microbatches)
   else:
     raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)
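
The new branch rides on the existing moe_gating hyperparameter, so turning it on is a one-line change (sketch; the surrounding MoE hparams are assumed to be configured the same way as for the other gating variants):

    # Route MoE layers through _expert_selection_gating below; any other
    # value (e.g. "switch") keeps its existing behavior.
    hparams.moe_gating = "expert_selection"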

@@ -1038,6 +1051,124 @@ def _switch_max_gating(
   return dispatch_tensor, combine_tensor, loss
 
 
+def _expert_selection_gating(
+    inputs, outer_expert_dims, experts_dim, group_size_dim,
+    expert_capacity_dim, hparams, train, variable_dtype, importance=None,
+    name="expert_selection_gating", num_microbatches=None,
+    normalize_by_num_experts_routed=True):
+  """Compute gating where each expert chooses what tokens it wants."""
+  # Select the randomization policy.
+  if train:
+    policy = hparams.moe_switch_policy_train
+  else:
+    policy = hparams.moe_switch_policy_eval
+
+  # The internals of this function run in float32, otherwise instabilities
+  # can occur.
+  gate_inputs = mtf.to_float(inputs)
+
+  # Input perturbations for exploration.
+  if train and policy == "input_dropout":
+    gate_inputs = mtf.dropout(gate_inputs, 1.0 - hparams.moe_switch_dropout)
+  elif train and policy == "input_jitter":
+    gate_inputs = mtf.layers.multiplicative_jitter(gate_inputs,
+                                                   hparams.moe_switch_jitter)
+
+  # Compute expert logits for each token.
+  # gate_logits shape: [outer_batch, batch, group, expert_unsplit]
+  gate_logits = mtf.layers.dense(
+      gate_inputs,
+      experts_dim,
+      use_bias=False,
+      expert_dims=outer_expert_dims,
+      variable_dtype=variable_dtype,
+      name=name)
+
+  # Set tokens to -inf before the softmax if importance is zero, as the
+  # softmax is normalized over all tokens in the group.
+  if importance is not None:
+    gate_logits += mtf.cast(
+        mtf.equal(importance, 0.0), dtype=gate_logits.dtype) * -1e9
+  raw_gates = mtf.softmax(gate_logits, reduced_dim=group_size_dim)
+
+  # expert_gate_probs shape:
+  #   [outer_batch, batch, expert_unsplit, expert_capacity]
+  # expert_gate_indices shape:
+  #   [outer_batch, batch, expert_unsplit, expert_capacity]
+  expert_gate_probs, expert_gate_indices = mtf.top_k(
+      raw_gates, reduced_dim=group_size_dim, k_dim=expert_capacity_dim)
+
+  # dispatch_tensor shape:
+  #   [outer_batch, batch, expert_unsplit, expert_capacity, group]
+  dispatch_tensor = mtf.one_hot(
+      expert_gate_indices, group_size_dim, dtype=raw_gates.dtype)
+
+  # combine_tensor shape:
+  #   [outer_batch, batch, expert_unsplit, expert_capacity, group]
+  combine_tensor = dispatch_tensor * expert_gate_probs
+
+  # Tokens will be aggregated across many experts and will not be
+  # normalized. This could be an issue, so we might want to normalize by
+  # the number of experts each token is sent to.
+  if normalize_by_num_experts_routed:
+    num_experts_routed = mtf.reduce_sum(
+        dispatch_tensor,
+        output_shape=(dispatch_tensor.shape[:2] + [group_size_dim]))
+    combine_tensor /= mtf.maximum(num_experts_routed, 1.0)
+
+  ################### Compute the load balancing loss ###################
+  # Push `aggregated_group_probs` of size `group` (which sums to
+  # num_experts) to be uniform.
+  # aggregated_group_probs shape: [outer_batch, batch, group]
+  # importance shape: [outer_batch, batch, group]
+  aggregated_group_probs = mtf.reduce_mean(raw_gates, reduced_dim=experts_dim)
+  if importance is not None:
+    aggregated_group_probs *= mtf.cast(
+        mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
+
+  # Scale the loss by group_size to keep it constant across different
+  # group sizes. true_group_size is the number of tokens per group that
+  # are not masked out.
+  true_group_size = mtf.cast(
+      mtf.reduce_sum(importance, reduced_dim=group_size_dim),
+      dtype=raw_gates.dtype)
+  loss = (mtf.reduce_mean(
+      aggregated_group_probs * aggregated_group_probs * true_group_size) *
+          float(group_size_dim.size))
+
+  if num_microbatches and num_microbatches > 1:
+    tf.logging.info("Dividing load-balance loss by num_microbatches={}".format(
+        num_microbatches))
+    loss /= num_microbatches
+
+  ################### Logging ###################
+  if train:
+    entropy = mtf.reduce_sum(-raw_gates * mtf.log(raw_gates + 1e-9),
+                             reduced_dim=group_size_dim)
+    batch_entropy = mtf.reduce_mean(entropy)
+    mtf.scalar_summary(name + "/entropy", batch_entropy)
+
+    # Log for each token in the group how many experts it gets sent to.
+    num_experts_sent_per_token = (
+        mtf.reduce_sum(dispatch_tensor, output_shape=[group_size_dim]) *
+        float(experts_dim.size * expert_capacity_dim.size))
+    split_fractions = mtf.split(
+        num_experts_sent_per_token,
+        split_dim=group_size_dim,
+        num_or_size_splits=group_size_dim.size)
+    for fraction in split_fractions:
+      mtf.scalar_summary("group_token/" + fraction.name.replace(":", "/"),
+                         mtf.reduce_sum(fraction))
+    mtf.scalar_summary("aux_loss", mtf.reduce_mean(loss))
+
+  #################### Match the inputs dtype ###################
+  combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
+  loss = mtf.cast(loss, inputs.dtype)
+  dispatch_tensor = mtf.cast(
+      mtf.cast(dispatch_tensor, tf.bool), combine_tensor.dtype)
+
+  return dispatch_tensor, combine_tensor, loss
+
+
 def _switch_gating(
     inputs, outer_expert_dims, experts_dim, expert_capacity_dim,
     hparams, train, variable_dtype, importance=None, name="switch_gating",
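
One way to read the load-balancing term above (a sketch, assuming the reduce_mean over experts as written and ignoring the batch averaging): with group size G, E experts, and n = true_group_size unmasked tokens,

    p_g = \frac{1}{E}\sum_{e=1}^{E} \mathrm{raw\_gates}_{e,g},
    \qquad
    \mathcal{L} = G \cdot \Big(\frac{1}{G}\sum_{g=1}^{G} p_g^2 \, n\Big)
                = n \sum_{g=1}^{G} p_g^2 .

Each expert's softmax over the group sums to one, so \sum_g p_g = 1; by Cauchy-Schwarz, \sum_g p_g^2 \ge 1/G, with equality exactly when p_g = 1/G. Minimizing the loss therefore pushes every token to attract the same aggregate probability mass across experts, matching the "to be uniform" comment in the code.

Downstream, the returned tensors act as gather/scatter weights around the experts. A hedged NumPy continuation of the earlier sketch (the real layer performs the analogous contractions in Mesh TensorFlow; moe_apply and expert_fns are illustrative names):

    # x: [group, d_model]; dispatch/combine: [experts, capacity, group],
    # as returned by the sketch near the top of this page.
    import numpy as np

    def moe_apply(x, dispatch, combine, expert_fns):
      # Gather each expert's chosen tokens: [experts, capacity, d_model].
      expert_inputs = np.einsum("ecg,gd->ecd", dispatch, x)
      # Run expert e on its capacity slots ([capacity, d_model] -> same shape).
      expert_outputs = np.stack(
          [f(expert_inputs[e]) for e, f in enumerate(expert_fns)])
      # Scatter back to tokens, weighted by the (normalized) gate
      # probabilities; tokens that no expert picked come out as zeros.
      return np.einsum("ecg,ecd->gd", combine, expert_outputs)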
