
Commit 5a9d503

William Fedus authored and Mesh TensorFlow Team committed
No-token-left-behind routing by passing token to next best expert
PiperOrigin-RevId: 326073610
1 parent 2d983bd commit 5a9d503

File tree: mesh_tensorflow/transformer
1 file changed: +153 −6 lines

mesh_tensorflow/transformer/moe.py

Lines changed: 153 additions & 6 deletions
@@ -52,17 +52,20 @@ def __init__(self,
                dropout_rate=0.0,
                activation="relu",
                moe_gating="top_2",
+               min_expert_capacity=4,
                rand_1_policy_train="input_jitter",
                rand_1_policy_eval="input_jitter",
                rand_1_dropout=0.1,
                rand_1_temperature=1.0,
-               rand_1_jitter=1e-2):
+               rand_1_jitter=1e-2,
+               switch_top_k=4):
     self._hparams = HParams(
         moe_gating=moe_gating,
         moe_num_experts=num_experts,
         moe_loss_coef=loss_coef,
         moe_hidden_size=hidden_size,
         moe_group_size=group_size,
+        moe_min_expert_capacity=min_expert_capacity,
         moe_capacity_factor_train=capacity_factor_train,
         moe_capacity_factor_eval=capacity_factor_eval,
         moe_use_second_place_loss=use_second_place_loss,
@@ -75,7 +78,8 @@ def __init__(self,
         moe_rand_1_policy_eval=rand_1_policy_eval,
         moe_rand_1_dropout=rand_1_dropout,
         moe_rand_1_temperature=rand_1_temperature,
-        moe_rand_1_jitter=rand_1_jitter)
+        moe_rand_1_jitter=rand_1_jitter,
+        moe_switch_top_k=switch_top_k)
     self._activation = activation
 
   def call(self, context, x, losses=None):
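
The two new constructor arguments surface as the moe_min_expert_capacity and moe_switch_top_k hparams. A minimal usage sketch (not part of this commit; it assumes the MoE1D layer class this __init__ belongs to, and the values are illustrative):

# Hypothetical usage: enable the new gating and capacity floor via the
# layer config. "switch" selects the _switch_gating path added below.
from mesh_tensorflow.transformer import moe

moe_layer = moe.MoE1D(
    num_experts=64,
    hidden_size=4096,
    moe_gating="switch",      # route through the new _switch_gating
    min_expert_capacity=4,    # floor on the computed expert capacity
    switch_top_k=4)           # number of next-best experts to try per token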
@@ -344,10 +348,9 @@ def transformer_moe_layer_v1(
   expert_capacity = min(
       group_size_dim.size,
       int((group_size_dim.size * capacity_factor) / experts_dim.size))
-  expert_capacity = max(expert_capacity, 4)
+  expert_capacity = max(expert_capacity, hparams.moe_min_expert_capacity)
   tf.logging.info("expert_capacity: %d" % expert_capacity)
   expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity)
-
   experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size)
   batch_dim_unsplit = mtf.Dimension("batch_unsplit", num_groups_dim.size)
   if nonpadding is not None:
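
Making the previously hard-coded floor of 4 configurable matters when the computed capacity gets small, e.g. with many experts per group. A quick arithmetic sketch with made-up numbers:

# Illustrative only: the capacity formula above, with a large expert count.
group_size, capacity_factor, num_experts = 1024, 1.25, 512
expert_capacity = min(group_size,
                      int((group_size * capacity_factor) / num_experts))  # = 2
expert_capacity = max(expert_capacity, 4)  # moe_min_expert_capacity floors it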
@@ -377,6 +380,16 @@ def transformer_moe_layer_v1(
         train=train,
         variable_dtype=variable_dtype,
         importance=nonpadding)
+  elif hparams.moe_gating == "switch":
+    dispatch_tensor, combine_tensor, loss = _switch_gating(
+        inputs=inputs,
+        outer_expert_dims=None,
+        experts_dim=experts_dim_unsplit,
+        expert_capacity_dim=expert_capacity_dim,
+        hparams=hparams,
+        train=train,
+        variable_dtype=variable_dtype,
+        importance=nonpadding)
   else:
     raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)
 
@@ -571,7 +584,7 @@ def transformer_moe_layer_v2(
   else:
     capacity_factor = hparams.moe_capacity_factor_eval
   expert_capacity = min(s.size, int((s.size * capacity_factor) / x.size))
-  expert_capacity = max(expert_capacity, 4)
+  expert_capacity = max(expert_capacity, hparams.moe_min_expert_capacity)
   c = mtf.Dimension("expert_capacity_x", expert_capacity)
 
   # We "cheat" here and look at the mesh shape and layout. This is to ensure
@@ -588,7 +601,7 @@ def transformer_moe_layer_v2(
   expert_capacity = min(
       t.size,
       int((t.size * hparams.moe_capacity_factor_second_level) / y.size))
-  expert_capacity = max(expert_capacity, 4)
+  expert_capacity = max(expert_capacity, hparams.moe_min_expert_capacity)
   d = mtf.Dimension("expert_capacity_y", expert_capacity)
 
   # First level of expert routing
@@ -701,6 +714,140 @@ def transformer_moe_layer_v2(
   return output, (loss_outer + loss_inner) * hparams.moe_loss_coef
 
 
+def _switch_gating(inputs,
+                   outer_expert_dims,
+                   experts_dim,
+                   expert_capacity_dim,
+                   hparams,
+                   train,
+                   variable_dtype,
+                   importance=None,
+                   name="switch_gating"):
+  """Compute a switch top-1 gating with no-token-left-behind behavior."""
+  # SELECT EXPERT
+  if train:
+    policy = hparams.moe_rand_1_policy_train
+  else:
+    policy = hparams.moe_rand_1_policy_eval
+
+  # Input perturbations
+  if train and policy == "input_jitter":
+    inputs = mtf.layers.multiplicative_jitter(inputs, hparams.moe_rand_1_jitter)
+
+  gate_logits = mtf.layers.dense(
+      inputs,
+      experts_dim,
+      use_bias=False,
+      expert_dims=outer_expert_dims,
+      variable_dtype=variable_dtype,
+      name=name)
+  raw_gates = mtf.softmax(gate_logits, reduced_dim=experts_dim)
+
+  # The internals of this function run in float32.
+  # bfloat16 seems to reduce quality.
+  raw_gates = mtf.to_float(raw_gates)
+
+  # Top-k operation
+  k_dim = mtf.Dimension("k", hparams.moe_switch_top_k)
+  expert_gate, expert_index = mtf.top_k(
+      raw_gates, reduced_dim=experts_dim, k_dim=k_dim)
+  expert_mask = mtf.one_hot(expert_index, experts_dim)
+
+  # LOAD BALANCING LOSS
+  outer_batch_dim = inputs.shape[0]
+  batch_dim = inputs.shape[1]
+  group_size_dim = inputs.shape[-2]
+  density_1 = mtf.reduce_mean(expert_mask, reduced_dim=group_size_dim)
+  density_1_proxy = mtf.reduce_mean(raw_gates, reduced_dim=group_size_dim)
+  if importance is not None:
+    expert_mask *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
+    expert_gate *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
+    density_1_proxy *= mtf.cast(
+        mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
+  loss = (
+      mtf.reduce_mean(density_1_proxy * density_1) *
+      float(experts_dim.size * experts_dim.size))
+
+  # Logging
+  if train:
+    entropy = mtf.reduce_sum(
+        -raw_gates * mtf.log(raw_gates + 1e-9), reduced_dim=experts_dim)
+    batch_entropy = mtf.reduce_mean(entropy)
+    mtf.scalar_summary(name + "/entropy", batch_entropy)
+
+    mask_count_experts = mtf.reduce_sum(expert_mask, output_shape=[experts_dim])
+    total_routed = mtf.reduce_sum(mask_count_experts)
+    expert_fraction = mtf.to_float(mask_count_experts / total_routed)
+    split_fractions = mtf.split(
+        expert_fraction,
+        split_dim=experts_dim,
+        num_or_size_splits=experts_dim.size)
+    for fraction in split_fractions:
+      mtf.scalar_summary("experts/" + fraction.name.replace(":", "/"),
+                         mtf.reduce_mean(fraction))
+    mtf.scalar_summary("aux_loss", mtf.reduce_mean(loss))
+
+  # COMPUTE ASSIGNMENT TO EXPERT
+  # Iteratively route tokens (no-token-left-behind). The idea is to route as
+  # many tokens as possible to top-i before then trying top-(i+1).
+  top_k_masks = mtf.split(
+      expert_mask, split_dim=k_dim, num_or_size_splits=k_dim.size)
+  top_k_gates = mtf.split(
+      expert_gate, split_dim=k_dim, num_or_size_splits=k_dim.size)
+  top_k_indices = mtf.split(
+      expert_index, split_dim=k_dim, num_or_size_splits=k_dim.size)
+
+  # Tensors accumulating values over the iterative routing process.
+  combine_tensor = mtf.constant(
+      inputs.mesh,
+      value=0,
+      shape=[outer_batch_dim, batch_dim, experts_dim, expert_capacity_dim])
+  cum_tokens = mtf.constant(
+      inputs.mesh, value=0, shape=[outer_batch_dim, batch_dim, experts_dim])
+  tokens_left_to_route = mtf.constant(
+      inputs.mesh, value=1., shape=[outer_batch_dim, batch_dim, group_size_dim])
+
+  expert_capacity_float = float(expert_capacity_dim.size)
+  for (top_i_mask, top_i_gate, top_i_index) in zip(top_k_masks, top_k_gates,
+                                                   top_k_indices):
+    top_i_mask = mtf.reshape(
+        top_i_mask,
+        new_shape=[outer_batch_dim, batch_dim, group_size_dim, experts_dim])
+    # Operate only on the unrouted tokens.
+    top_i_mask *= tokens_left_to_route
+
+    # Record the cumulative number of tokens sent to each expert across
+    # iterations.
+    cumulative_tokens_in_expert = cum_tokens + mtf.cumsum(
+        top_i_mask, group_size_dim)
+
+    expert_overflow = mtf.to_float(
+        mtf.less_equal(cumulative_tokens_in_expert, expert_capacity_float))
+    output_i_tokens = top_i_mask * expert_overflow
+
+    # Update the cumulative tokens routed to each expert.
+    cum_tokens += mtf.reduce_sum(output_i_tokens, reduced_dim=group_size_dim)
+    tokens_left_to_route -= (
+        mtf.reduce_sum(output_i_tokens, reduced_dim=experts_dim))
+
+    # Combine-tensor for this iteration
+    output_i_tokens_flat = mtf.reduce_sum(
+        output_i_tokens, reduced_dim=experts_dim)
+    position_in_expert = cumulative_tokens_in_expert - 1
+    top_i_combine_tensor = (
+        top_i_gate * output_i_tokens_flat *
+        mtf.one_hot(top_i_index, experts_dim) *
+        mtf.one_hot(mtf.to_int32(position_in_expert), expert_capacity_dim))
+    combine_tensor += top_i_combine_tensor
+
+  # Match the inputs dtype.
+  combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
+  loss = mtf.cast(loss, inputs.dtype)
+  dispatch_tensor = mtf.cast(
+      mtf.cast(combine_tensor, tf.bool), combine_tensor.dtype)
+
+  return dispatch_tensor, combine_tensor, loss
+
+
 def _rand_1_gating(
     inputs, outer_expert_dims, experts_dim, expert_capacity_dim,
     hparams, train, variable_dtype, importance=None, name="rand_1_gating"):
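
To see what the iterative loop computes, here is a minimal NumPy sketch of the same no-token-left-behind routing for a single group, without batch or mesh dimensions. It is not the Mesh TensorFlow API; the function and variable names are hypothetical, and only the logic mirrors _switch_gating above.

# Hypothetical NumPy re-creation of the routing loop, for one group only.
import numpy as np

def switch_route(gates, expert_capacity, top_k):
  """gates: [group_size, num_experts] softmax gate probabilities."""
  group_size, num_experts = gates.shape
  order = np.argsort(-gates, axis=-1)[:, :top_k]   # best experts first
  combine = np.zeros((group_size, num_experts, expert_capacity))
  tokens_left = np.ones(group_size)    # 1.0 while a token is still unrouted
  cum_tokens = np.zeros(num_experts)   # tokens already placed in each expert
  for i in range(top_k):               # pass 1: top-1, pass 2: top-2, ...
    idx = order[:, i]
    mask = np.zeros((group_size, num_experts))
    mask[np.arange(group_size), idx] = 1.0
    mask *= tokens_left[:, None]       # operate only on unrouted tokens
    # Position each token would occupy inside its chosen expert (1-based).
    position = cum_tokens[None, :] + np.cumsum(mask, axis=0)
    routed = mask * (position <= expert_capacity)  # keep tokens that fit
    cum_tokens += routed.sum(axis=0)
    tokens_left -= routed.sum(axis=1)
    for t in range(group_size):
      e = idx[t]
      if routed[t, e]:
        combine[t, e, int(position[t, e]) - 1] = gates[t, e]
  dispatch = combine > 0               # boolean dispatch pattern
  return dispatch, combine

A token whose top-1 expert is already full is retried with its top-2 expert on the next pass, and so on for up to moe_switch_top_k passes; any token that still does not fit is dropped. The dispatch tensor is simply the nonzero pattern of the combine tensor, matching the final mtf.cast(mtf.cast(combine_tensor, tf.bool), ...) above, and the auxiliary load-balancing loss is the mean of density_1_proxy * density_1 scaled by num_experts**2, which is minimized by a uniform expert load.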
