This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit 13db970

William Fedus authored and Mesh TensorFlow Team committed
Option to use mtf.Print to log which tokens are sent to which experts when run on CPU.
PiperOrigin-RevId: 368137313
1 parent a54f5cf commit 13db970

File tree (3 files changed, +85 -11 lines):

  mesh_tensorflow/transformer/moe.py
  mesh_tensorflow/transformer/transformer.py
  mesh_tensorflow/transformer/utils.py
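As a rough sketch of how the new option might be enabled: the token_logging flag on MoE1D turns on the mtf.Print calls, and the _windows helper (marked @gin.configurable in the diff below) controls how much surrounding context is attached to each dispatched token. The bindings below are hypothetical illustrative values and assume MoE1D itself is registered with gin; as the commit message notes, the printed output is meant to be inspected when running on CPU.

# Hypothetical gin bindings -- illustrative values only, not part of this commit.
MoE1D.token_logging = True    # detokenize inputs and mtf.Print which tokens reach which expert
_windows.window_start = 0     # log the dispatched token itself...
_windows.window_end = 2       # ...plus the next two tokens as context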

mesh_tensorflow/transformer/moe.py

Lines changed: 66 additions & 6 deletions
@@ -61,7 +61,8 @@ def __init__(self,
                ntlb_top_k=4,
                output_dim=None,
                use_experts_attention=False,
-               z_loss=None):
+               z_loss=None,
+               token_logging=False):
     self._hparams = HParams(
         moe_gating=moe_gating,
         moe_num_experts=num_experts,
@@ -87,6 +88,7 @@ def __init__(self,
         moe_use_experts_attention=use_experts_attention,
         moe_z_loss=z_loss)
     self._activation = activation
+    self.token_logging = token_logging

   def call(self, context, x, losses=None):
     """Call the layer."""
@@ -106,7 +108,13 @@ def call(self, context, x, losses=None):
       output_dim = self._hparams.moe_output_dim
     else:
       output_dim = context.model.model_dim
-    y, loss = transformer_moe_layer_v1(
+    if self.token_logging:
+      tokens = _detokenize(context.inputs, context.model.vocabulary)
+      x = mtf.Print(x, [tokens], "tokens", summarize=1000)
+      extras = _windows(context.inputs, context.length_dim)
+    else:
+      extras = None
+    y, loss, extras = transformer_moe_layer_v1(
         x,
         output_dim,
         self._hparams,
@@ -116,7 +124,16 @@ def call(self, context, x, losses=None):
         mesh_shape=context.model.mesh_shape,
         nonpadding=context.nonpadding,
         activation=self._activation,
-        num_microbatches=context.num_microbatches)
+        num_microbatches=context.num_microbatches,
+        extras=extras)
+
+    if extras:
+      extras = _detokenize(extras, context.model.vocabulary)
+      experts_dim = mtf.Dimension("experts", self._hparams.moe_num_experts)
+      extras = mtf.unstack(extras, experts_dim)
+      for i, t in enumerate(extras):
+        y = mtf.Print(y, [t], "EXPERT %s" % i, summarize=1000)
+
     if context.losses is not None:
       context.losses.append(loss)
     if not has_length_dim:
@@ -128,6 +145,23 @@ def call(self, context, x, losses=None):
     return y


+@gin.configurable
+def _windows(ids, length_dim, window_start=0, window_end=0):
+  to_stack = []
+  for offset in range(window_start, window_end + 1):
+    to_stack.append(mtf.shift(ids, -offset, length_dim, wrap=False))
+  return mtf.stack(to_stack, "window", axis=ids.shape.ndims)
+
+
+def _detokenize(ids, vocabulary):
+  return mtf.slicewise(
+      vocabulary.decode_tf,
+      [ids],
+      output_shape=mtf.Shape(ids.shape.dims[:-1]),
+      output_dtype=tf.string,
+      splittable_dims=ids.shape.dims[:-1])
+
+
 class MoE2D(transformer.TransformerLayer):
   """Mixture of Experts Layer."""

@@ -191,7 +225,7 @@ def call(self, context, x, losses=None):
 def transformer_moe_layer_v1(
     inputs, output_dim, hparams, train, variable_dtype,
     layout=None, mesh_shape=None, nonpadding=None, activation=mtf.relu,
-    num_microbatches=None):
+    num_microbatches=None, extras=None):
   """Local mixture of experts that works well on TPU.

   Adapted from the paper https://arxiv.org/abs/1701.06538
@@ -266,6 +300,7 @@ def transformer_moe_layer_v1(
       and zeros(padding).
     activation: a function.
     num_microbatches: number of microbatches.
+    extras: a tensor to dispatch (for debugging purposes)

   Returns:
     outputs: a Tensor with shape [batch_dim(s), length_dim, output_dim]
@@ -329,6 +364,10 @@ def transformer_moe_layer_v1(
   # over which those groups are split.
   batch_and_length_dims, input_dim = (orig_inputs.shape.dims[:-1],
                                       orig_inputs.shape.dims[-1])
+
+  if extras:
+    extras_dims = extras.shape.dims[len(batch_and_length_dims):]
+
   # Hack: we assume that
   # "outer_batch" == replication of experts
   # mesh_dim_size can be derived from mesh_shape and orig_batch_dim
@@ -360,6 +399,11 @@ def transformer_moe_layer_v1(
   # OGSM Tensor
   inputs = mtf.reshape(inputs, moe_input_dims)

+  if extras:
+    extras = mtf.reshape(
+        extras,
+        [outer_batch_dim, num_groups_dim, group_size_dim] + extras_dims)
+
   # Each sequence sends expert_capacity positions to each expert.
   if train:
     capacity_factor = hparams.moe_capacity_factor_train
@@ -465,6 +509,17 @@ def transformer_moe_layer_v1(
           input_dim
       ]))

+  if extras:
+    extras = mtf.einsum([extras, mtf.cast(dispatch_tensor, extras.dtype)],
+                        mtf.Shape([
+                            outer_batch_dim, experts_dim_unsplit,
+                            num_groups_dim, expert_capacity_dim] + extras_dims))
+    extras = mtf.reshape(
+        extras,
+        mtf.Shape([
+            outer_batch_dim, experts_dim, batch_dim_unsplit,
+            expert_capacity_dim] + extras_dims))
+
   # Now feed the expert inputs through the experts.
   h = mtf.layers.dense_product(
       expert_inputs,
@@ -519,10 +574,15 @@ def _compute_output(hidden, layer_name):
     k = _compute_output(k_h, layer_name="k_wo")
     outputs.append(q)
     outputs.append(k)
-    return outputs, loss * hparams.moe_loss_coef
+    return outputs, loss * hparams.moe_loss_coef, None
   else:
     output = _compute_output(h, layer_name="wo")
-    return output, loss * hparams.moe_loss_coef
+    loss *= hparams.moe_loss_coef
+
+    if extras:
+      return output, loss, extras
+    else:
+      return output, loss, None


 def transformer_moe_layer_v2(
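To make concrete what the new _windows helper produces, here is a small NumPy analogue (a sketch only, not part of the commit; it assumes mtf.shift(ids, -offset, length_dim, wrap=False) pulls the value from offset positions later in the sequence and zero-pads past the end):

import numpy as np

def windows_np(ids, window_start=0, window_end=0):
  # Stack shifted copies of the token ids along a new trailing "window" axis,
  # mirroring mtf.stack(to_stack, "window", axis=ids.shape.ndims) above.
  length = ids.shape[-1]
  to_stack = []
  for offset in range(window_start, window_end + 1):
    shifted = np.zeros_like(ids)
    if offset >= 0:
      shifted[..., :length - offset] = ids[..., offset:]
    else:
      shifted[..., -offset:] = ids[..., :length + offset]
    to_stack.append(shifted)
  return np.stack(to_stack, axis=-1)

print(windows_np(np.array([[11, 12, 13, 14]]), window_start=0, window_end=2))
# [[[11 12 13]
#   [12 13 14]
#   [13 14  0]
#   [14  0  0]]]

Each dispatched position thus carries a short window of token ids, which _detokenize later decodes into a readable string before it is attached to the per-expert mtf.Print calls.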

mesh_tensorflow/transformer/transformer.py

Lines changed: 4 additions & 1 deletion
@@ -721,7 +721,8 @@ def __init__(self,
                input_full_attention=False,
                loss_on_targets_only=False,
                loss_denominator=None,
-               token_dropout_rate=0.0):
+               token_dropout_rate=0.0,
+               vocabulary=None):
     """Create a Unitransformer.

     Args:
@@ -766,6 +767,7 @@ def __init__(self,
         same denominator as was used for the pretraining. This complication
         might be avoided by always using loss_denominator = 1.0.
       token_dropout_rate: an optional floating point value
+      vocabulary: an optional vocabularies.Vocabulary
     """
     self.layer_stack = layer_stack
     self.model_dim = mtf.Dimension("d_model", d_model)
@@ -806,6 +808,7 @@ def __init__(self,
       raise ValueError(
           "input_full_attention only makes sense with autoregressive")
     self.token_dropout_rate = token_dropout_rate
+    self.vocabulary = vocabulary

   @property
   def fully_autoregressive(self):

mesh_tensorflow/transformer/utils.py

Lines changed: 15 additions & 4 deletions
@@ -170,7 +170,9 @@ def build_model(model_type="bitransformer",
                 input_vocab_size=gin.REQUIRED,
                 output_vocab_size=gin.REQUIRED,
                 layout_rules=None,
-                mesh_shape=None):
+                mesh_shape=None,
+                input_vocabulary=None,
+                target_vocabulary=None):
   """Build a transformer model.

   Currently, four types of models are supported:
@@ -212,15 +214,21 @@ def build_model(model_type="bitransformer",
     output_vocab_size: an integer
     layout_rules: optional, input to mtf.convert_to_layout_rules
     mesh_shape: optional, an input to mtf.convert_to_shape()
+    input_vocabulary: optional, a vocabularies.Vocabulary
+    target_vocabulary: optional, a vocabularies.Vocabulary
+
   Returns:
     a Unitransformer or Bitransformer
   """
   if model_type == "bitransformer":
-    return transformer.make_bitransformer(
+    ret = transformer.make_bitransformer(
         input_vocab_size=input_vocab_size,
         output_vocab_size=output_vocab_size,
         mesh_shape=mesh_shape,
         layout=layout_rules)
+    ret.encoder.vocabulary = input_vocabulary
+    ret.decoder.vocabulary = target_vocabulary
+    return ret
   elif model_type == "bi_student_teacher":
     return transformer.make_bi_student_teacher(
         input_vocab_size=input_vocab_size,
@@ -234,7 +242,8 @@ def build_model(model_type="bitransformer",
         input_vocab_size=input_vocab_size,
         output_vocab_size=output_vocab_size,
         mesh_shape=mesh_shape,
-        layout=layout_rules)
+        layout=layout_rules,
+        vocabulary=input_vocabulary)
   else:
     raise ValueError("unknown model_type")

@@ -1928,7 +1937,9 @@ def get_estimator(model_type, vocabulary, mesh_shape,
       input_vocab_size=inputs_vocabulary(vocabulary).vocab_size,
       output_vocab_size=targets_vocabulary(vocabulary).vocab_size,
       layout_rules=layout_rules,
-      mesh_shape=mesh_shape)
+      mesh_shape=mesh_shape,
+      input_vocabulary=inputs_vocabulary(vocabulary),
+      target_vocabulary=targets_vocabulary(vocabulary))

   model_fn = tpu_estimator_model_fn(
       model_type=model_type,
