Skip to content
This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit 2d983bd

Browse files
William Fedus and the Mesh TensorFlow Team
authored and committed
Distill large top-1 sparse models into small dense models
PiperOrigin-RevId: 325864101
1 parent 20034bf commit 2d983bd

File tree

1 file changed

+23
-9
lines changed
  • mesh_tensorflow/transformer

1 file changed

+23
-9
lines changed

mesh_tensorflow/transformer/moe.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,7 @@ def transformer_moe_layer_v1(
345345
group_size_dim.size,
346346
int((group_size_dim.size * capacity_factor) / experts_dim.size))
347347
expert_capacity = max(expert_capacity, 4)
348+
tf.logging.info("expert_capacity: %d" % expert_capacity)
348349
expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity)
349350

350351
experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size)
@@ -704,8 +705,6 @@ def _rand_1_gating(
704705
inputs, outer_expert_dims, experts_dim, expert_capacity_dim,
705706
hparams, train, variable_dtype, importance=None, name="rand_1_gating"):
706707
"""Compute a random top-1 gating."""
707-
del importance
708-
709708
# SELECT EXPERT
710709
if train:
711710
policy = hparams.moe_rand_1_policy_train
@@ -724,6 +723,10 @@ def _rand_1_gating(
724723
name=name)
725724
raw_gates = mtf.softmax(gate_logits, reduced_dim=experts_dim)
726725

726+
# The internals of this function run in float32.
727+
# bfloat16 seems to reduce quality.
728+
raw_gates = mtf.to_float(raw_gates)
729+
727730
if policy == "argmax" or policy == "input_dropout" or policy == "input_jitter":
728731
expert_gate, expert_index = mtf.top_1(raw_gates, reduced_dim=experts_dim)
729732
elif policy == "sample":
@@ -740,8 +743,14 @@ def _rand_1_gating(
740743
group_size_dim = inputs.shape[-2]
741744
density_1 = mtf.reduce_mean(expert_mask, reduced_dim=group_size_dim)
742745
density_1_proxy = mtf.reduce_mean(raw_gates, reduced_dim=group_size_dim)
743-
loss = (mtf.reduce_mean(density_1_proxy * density_1)
744-
* float(experts_dim.size * experts_dim.size))
746+
if importance is not None:
747+
expert_mask *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
748+
expert_gate *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
749+
density_1_proxy *= mtf.cast(
750+
mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
751+
loss = (
752+
mtf.reduce_mean(density_1_proxy * density_1) *
753+
float(experts_dim.size * experts_dim.size))
745754

746755
# Logging
747756
if train:
@@ -767,20 +776,25 @@ def _rand_1_gating(
767776
# the batch indices, to each expert, with position_in_expert
768777
position_in_expert = mtf.cumsum(
769778
expert_mask, group_size_dim, exclusive=True) * expert_mask
779+
position_in_expert = mtf.cast(position_in_expert, dtype=raw_gates.dtype)
770780
# Keep only tokens that fit within expert_capacity.
771781
expert_capacity_float = float(expert_capacity_dim.size)
772-
expert_mask *= mtf.to_float(mtf.less(position_in_expert,
773-
expert_capacity_float))
782+
expert_mask *= mtf.cast(
783+
mtf.less(position_in_expert, expert_capacity_float),
784+
dtype=raw_gates.dtype)
774785
expert_mask_flat = mtf.reduce_sum(expert_mask, reduced_dim=experts_dim)
775786

776787
# Mask out the experts that have overflowed expert capacity. Sparsify the
777788
# expert_gate.
778789
expert_gate *= expert_mask_flat
779790

780791
combine_tensor = (
781-
expert_gate * expert_mask_flat
782-
* mtf.one_hot(expert_index, experts_dim)
783-
* mtf.one_hot(mtf.to_int32(position_in_expert), expert_capacity_dim))
792+
expert_gate * expert_mask_flat *
793+
mtf.one_hot(expert_index, experts_dim, dtype=raw_gates.dtype) *
794+
mtf.one_hot(
795+
mtf.to_int32(position_in_expert),
796+
expert_capacity_dim,
797+
dtype=raw_gates.dtype))
784798

785799
# Match the inputs dtype.
786800
combine_tensor = mtf.cast(combine_tensor, inputs.dtype)

0 commit comments

Comments (0)