This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit 4758510

Mesh TensorFlow Team authored and committed
Expert Attention Fixes:

- Allow moe.py to work with a tensor that has a "memory_length" dimension.
- Fix an ExpertsAttention bug in moe.py where decoding would break if the input dimension differed from the output dimension.
- Fix a bug in ExpertsEncDecAttention where it was only doing self-attention on the decoder side.
- Factorize the expert_computation code to easily allow using different query and memory antecedents.

PiperOrigin-RevId: 395552483
1 parent acf6247 commit 4758510

4 files changed: +48 −99 lines
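The refactor below centers on the expert_computation setting, which controls whether the MoE layer produces the query, the key/value, or both. As a rough orientation before the diffs, here is a minimal, hypothetical Python sketch (toy callables, not the mtf API) of the dispatch that the factorized _compute_merge_qkv performs:

def compute_qk(expert_computation, antecedent, moe_call, dense_q, dense_kv):
  """Dispatches on the expert_computation mode ("qkv", "q" or "kv")."""
  if expert_computation == "qkv":
    # The MoE layer returns both projections; query and memory antecedent
    # must be the same tensor in this mode.
    q, k = moe_call(antecedent)
  elif expert_computation == "q":
    q = moe_call(antecedent)      # experts produce the query
    k = dense_kv(antecedent)      # key/value via an ordinary dense projection
  elif expert_computation == "kv":
    k = moe_call(antecedent)      # experts produce the key/value
    q = dense_q(antecedent)       # query via an ordinary dense projection
  else:
    raise ValueError(
        "Invalid expert computation mode: {}".format(expert_computation))
  return q, k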

mesh_tensorflow/transformer/attention.py

Lines changed: 41 additions & 55 deletions
@@ -631,11 +631,10 @@ def __init__(self,
                combine_dims=True,
                ensemble_dim=None,
                keep_query_heads_dims=False,
-               fold_scaling_into_initializer=False,
+               fold_scaling_into_initializer=True,
                context=None,
                experts_hparams=None,
-               expert_computation="qkv",
-               is_encdec=False):
+               expert_computation="qkv"):
     super(ExpertsAttentionParams, self).__init__(
         mesh=mesh,
         query_input_dim=query_input_dim,
@@ -656,7 +655,6 @@ def __init__(self,
 
     self.context = context
     self.expert_computation = expert_computation
-    self.is_encdec = is_encdec
 
     # Unless we want to compute both q and kv, we can use the normal MoE
     # settings.
@@ -698,6 +696,9 @@ def __init__(self,
     # we want to partition both "experts_hidden" and "heads".
     moe_output_dims = mtf.Dimension("d_model", self.q_shape[-1].size)
 
+    tf.logging.info("ExpertsAttention moe_hidden_size: {}".format(
+        experts_hparams.hidden_size))
+    tf.logging.info("moe_output_dims: {}".format(moe_output_dims))
     self.moe_layer = mtf.transformer.moe.MoE1D(
         moe_gating=experts_hparams.moe_gating,
         num_experts=experts_hparams.num_experts,
@@ -718,70 +719,55 @@ def __init__(self,
         activation=experts_hparams.activation,
         z_loss=experts_hparams.z_loss)
 
-  def _replace_d_model_dim(self, t):
-    """Used to replace the `d_model` dim with `heads`."""
-    new_last_dim = mtf.Dimension(self.q_shape[-1].name, t.shape[-1].size)
-    return mtf.reshape(t, new_shape=mtf.Shape(t.shape[:-1] + [new_last_dim]))
-
-  def _compute_q_with_experts(self, antecedent):
-    q = self.moe_layer.call(self.context, antecedent)
-    q = self._replace_d_model_dim(q)
-    return q
-
-  def _compute_kv_with_experts(self, antecedent):
-    kv = self.moe_layer.call(
-        self.context, antecedent, use_enc_nonpadding=self.is_encdec)
-    kv = self._replace_d_model_dim(kv)
-    return kv
-
   def _compute_merge_qkv(self, antecedent):
     """Computes qkv all in one call using MoE layer."""
-    # This mode assumes query and memory antecedent are the same.
-    qkv = self.moe_layer.call(self.context, antecedent)
-    q, kv = qkv
-    q = self._replace_d_model_dim(q)
-    kv = self._replace_d_model_dim(kv)
-    self._q = q
-    self._kv = kv
-
-  def compute_q(self, query_antecedent):
+    def _replace_d_model_dim(t):
+      """Used to replace the `d_model` dim with `heads`."""
+      new_last_dim = mtf.Dimension(self.q_shape[-1].name, t.shape[-1].size)
+      return mtf.reshape(
+          t, new_shape=mtf.Shape(t.shape[:-1] + [new_last_dim]))
     if self.expert_computation == "qkv":
-      self._compute_merge_qkv(query_antecedent)
-      q = self._q
+      # NOTE: This assumes query and memory antecedent are the same
+      qk = self.moe_layer.call(self.context, antecedent)
+      # Split qk here since they went through experts-layers
+      q, k = qk
+      q = _replace_d_model_dim(q)
+      k = _replace_d_model_dim(k)
     elif self.expert_computation == "q":
-      q = self._compute_q_with_experts(query_antecedent)
-    # If computing "kv" with experts, then compute q normally.
+      q = self.moe_layer.call(self.context, antecedent)
+      q = _replace_d_model_dim(q)
+      # Compute key/value normally
+      k = mtf.layers.us_einsum(
+          [antecedent, self.wkv], reduced_dims=[self.memory_input_dim])
     elif self.expert_computation == "kv":
+      k = self.moe_layer.call(self.context, antecedent)
+      k = _replace_d_model_dim(k)
+      # Compute query normally
       q = mtf.layers.us_einsum(
-          [query_antecedent, self.wq], reduced_dims=[self.query_input_dim])
+          [antecedent, self.wq], reduced_dims=[self.query_input_dim])
+    else:
+      raise ValueError("Invalid expert computation mode: {}".format(
+          self.expert_computation))
+
+    # Scale query
     q *= self.key_dim.size ** -0.5
-    return mtf.replace_dimensions(q, q.shape.dims[-1], self.q_dims)
+    self._q = mtf.replace_dimensions(q, q.shape.dims[-1], self.q_dims)
+    self._k = mtf.replace_dimensions(k, k.shape.dims[-1], self.k_dims)
+
+  def compute_q(self, query_antecedent):
+    self._compute_merge_qkv(query_antecedent)
+    return self._q
 
   def compute_k(self, memory_antecedent):
-    raise NotImplementedError("ExpertsAttention uses shared_kv = True.")
+    del memory_antecedent
+    return self._k
 
   def compute_kv(self, memory_antecedent):
-    if self.expert_computation == "qkv":
-      # We have already computing "kv" with "q", so just return its value.
-      kv = self._kv
-      # Check if the "length" dimension should be "memory_length" since both
-      # q and kv were computed using the same antecedent. This is why we must
-      # always have the same query and memory antecedent for the qkv mode.
-      if self.context.length_dim in kv.shape.dims:
-        memory_length = mtf.Dimension(
-            "memory_length", self.context.length_dim.size)
-        kv = mtf.replace_dimensions(
-            kv, self.context.length_dim, memory_length)
-    # If computing "q" with experts, then compute "kv" normally.
-    elif self.expert_computation == "q":
-      kv = mtf.layers.us_einsum(
-          [memory_antecedent, self.wkv], reduced_dims=[self.memory_input_dim])
-    elif self.expert_computation == "kv":
-      kv = self._compute_kv_with_experts(memory_antecedent)
-    kv = mtf.replace_dimensions(kv, kv.shape.dims[-1], self.k_dims)
-    return kv
+    del memory_antecedent
+    return self._k
 
   def compute_v(self, memory_antecedent):
+    del memory_antecedent
     raise NotImplementedError("ExpertsAttention uses shared_kv = True.")
 
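After this change, compute_q drives the whole projection: it calls _compute_merge_qkv once and caches both tensors, while compute_k and compute_kv simply return the cached key/value (shared_kv). Below is a toy sketch of that caching pattern, with hypothetical callables standing in for the MoE and einsum projections; it is an illustration of the pattern, not the mtf API.

class CachedQKParams(object):
  """Caches q and k the first time the query is requested (shared_kv)."""

  def __init__(self, project_q, project_k):
    self._project_q = project_q   # callable: antecedent -> query tensor
    self._project_k = project_k   # callable: antecedent -> key/value tensor
    self._q = None
    self._k = None

  def compute_q(self, query_antecedent):
    # Both projections are computed from the same antecedent in one pass.
    self._q = self._project_q(query_antecedent)
    self._k = self._project_k(query_antecedent)
    return self._q

  def compute_k(self, memory_antecedent):
    del memory_antecedent  # unused: key/value was cached by compute_q
    return self._k

  def compute_kv(self, memory_antecedent):
    del memory_antecedent  # shared key/value, also served from the cache
    return self._k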

mesh_tensorflow/transformer/moe.py

Lines changed: 6 additions & 26 deletions
@@ -98,30 +98,13 @@ def __init__(self,
         moe_top_n_num_experts_per_token=top_n_num_experts_per_token)
     self._activation = activation
 
-  def call(self, context, x, losses=None, use_enc_nonpadding=False):
+  def call(self, context, x, losses=None):
     """Call the layer."""
     if context.model.ensemble_dim:
       raise NotImplementedError("MoE not yet implemented with ensembles")
 
     has_length_dim = context.length_dim in x.shape.dims
-    has_memory_length_dim = "memory_length" in x.shape.dimension_names
-    # Used for EncDec attention if we have the MoE layer produce the kv.
-    if use_enc_nonpadding:
-      nonpadding = context.nonpadding_encoder
-    else:
-      nonpadding = context.nonpadding
-    # If a memory_length dimension exists, then we make sure the
-    # length dimension of the nonpadding tensor matches it.
-    if (has_memory_length_dim and isinstance(nonpadding, mtf.Tensor)
-        and "length" in nonpadding.shape.dimension_names):
-      old_length_dim = nonpadding.shape.get_dim_by_name("length")
-      new_length_dim = mtf.Dimension("memory_length", old_length_dim.size)
-      nonpadding = mtf.replace_dimensions(
-          nonpadding, old_length_dim, new_length_dim)
-    # Insert a length dimension if one does not exist.
-    # Typically no length dims will occur on the decoder during autoregressive
-    # decoding.
-    if not has_length_dim and not has_memory_length_dim:
+    if not has_length_dim:
       x_shape = x.shape
       shape_with_length = mtf.Shape(
           x_shape.dims[:-1] + [mtf.Dimension("length", 1)]
@@ -141,21 +124,18 @@ def call(self, context, x, losses=None, use_enc_nonpadding=False):
         context.variable_dtype,
         layout=context.model.layout,
         mesh_shape=context.model.mesh_shape,
-        nonpadding=nonpadding,
+        nonpadding=context.nonpadding,
         activation=self._activation,
         num_microbatches=context.num_microbatches,
         token_embeddings=context.input_embeddings)
     if context.losses is not None:
       context.losses.append(loss)
-    if not has_length_dim and not has_memory_length_dim:
-      # Shapes will differ if the input and output dimension of the layer do not
-      # match.
-      new_y_shape = mtf.Shape(x_shape.dims[:-1] + [output_dim])
+    if not has_length_dim:
       if self._hparams.moe_use_experts_attention:
-        y_reshape = [mtf.reshape(y_out, new_y_shape) for y_out in y]
+        y_reshape = [mtf.reshape(y_out, x_shape) for y_out in y]
        y = y_reshape
       else:
-        y = mtf.reshape(y, new_y_shape)
+        y = mtf.reshape(y, x_shape)
     return y
 
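The surviving branch above handles autoregressive decoding, where the activations carry no "length" dimension: a length-1 dimension is inserted before the MoE layer runs and the original shape is restored afterwards. A minimal mesh_tensorflow sketch of that reshape trick follows; the graph/mesh setup and toy dimension sizes are illustrative, and the continuation of the truncated shape_with_length line is an assumption.

import mesh_tensorflow as mtf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "example_mesh")
batch = mtf.Dimension("batch", 2)
d_model = mtf.Dimension("d_model", 16)

# During autoregressive decoding the activations look like [batch, d_model].
x = mtf.zeros(mesh, mtf.Shape([batch, d_model]))
x_shape = x.shape

# Insert a length-1 dimension so the MoE layer sees [batch, length, d_model].
shape_with_length = mtf.Shape(
    x_shape.dims[:-1] + [mtf.Dimension("length", 1)] + x_shape.dims[-1:])
x_with_length = mtf.reshape(x, shape_with_length)

# ... the MoE layer would run on x_with_length here ...

# Reshape the output back to the original decoder shape.
y = mtf.reshape(x_with_length, x_shape)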

mesh_tensorflow/transformer/transformer.py

Lines changed: 0 additions & 11 deletions
@@ -304,17 +304,6 @@ def nonpadding(self):
     return mtf.cast(
         mtf.not_equal(self.sequence_id, 0), self.activation_dtype)
 
-  @property
-  def nonpadding_encoder(self):
-    """Tensor with zeros in padding positions and ones elsewhere for encoder."""
-    if self.encoder_sequence_id is None:
-      return None
-    if self.encoder_sequence_id == 1:
-      return 1
-    else:
-      return mtf.cast(
-          mtf.not_equal(self.encoder_sequence_id, 0), self.activation_dtype)
-
   def get_position(self):
     if self.position_is_default:
       return mtf.range(self.mesh, self.length_dim, tf.int32)
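The removed nonpadding_encoder property mirrored the nonpadding property shown in context: a float mask that is 1.0 at non-padding positions and 0.0 wherever the sequence id is zero. A minimal standalone sketch of that computation, with toy shapes and a made-up encoder_sequence_id tensor:

import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "example_mesh")
batch = mtf.Dimension("batch", 2)
length = mtf.Dimension("length", 4)

# Hypothetical encoder sequence ids; a value of 0 would mark padding.
encoder_sequence_id = mtf.ones(mesh, mtf.Shape([batch, length]), dtype=tf.int32)

# 1.0 wherever the id is non-zero, 0.0 at padding positions.
nonpadding = mtf.cast(mtf.not_equal(encoder_sequence_id, 0), tf.float32)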

mesh_tensorflow/transformer/transformer_layers.py

Lines changed: 1 addition & 7 deletions
@@ -410,7 +410,6 @@ def __init__(self,
                **kwargs):
     super(ExpertsSelfAttention, self).__init__(**kwargs)
     self.expert_computation = expert_computation
-    self.is_encdec = False  # Overrided in ExpertsEncDecAttention
     self._hparams = mtf.transformer.moe.HParams(
         moe_gating=moe_gating,
         num_experts=num_experts,
@@ -466,8 +465,7 @@ def make_params(self, context):
         fold_scaling_into_initializer=self.fold_scaling_into_initializer,
         context=context,
         experts_hparams=self._hparams,
-        expert_computation=self.expert_computation,
-        is_encdec=self.is_encdec)
+        expert_computation=self.expert_computation)
 
 
 @gin.configurable
@@ -477,10 +475,6 @@ class ExpertsEncDecAttention(ExpertsSelfAttention):
   def __init__(self, relative_attention_type=None, **kwargs):
     super(ExpertsEncDecAttention, self).__init__(
         relative_attention_type=relative_attention_type, **kwargs)
-    self.is_encdec = True
-    if self.expert_computation == "qkv":
-      raise ValueError("ExpertsEncDecAttention must use expert_computation of "
-                       "q or kv.")
 
   def _get_memory_antecedent(self, context):
     return context.encoder_output
