This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit 7e3d57f

Author: Mesh TensorFlow Team

Fixed some bugs for Synthesizer model.

PiperOrigin-RevId: 323770262
1 parent: d46ff87

File tree: 2 files changed (+13, -19 lines)

mesh_tensorflow/transformer/attention.py

Lines changed: 4 additions & 6 deletions
@@ -209,9 +209,7 @@ def synthetic_attention(q,
     tf.logging.info("Using Random Synthesizers")
     r_shape = mtf.Shape([mtf.Dimension("length", max_length),
                          mtf.Dimension("heads", num_heads.size),
-                         mtf.Dimension("memory_length",
-                                       num_heads, max_length)])
-    initializer = tf.random_uniform_initializer()
+                         mtf.Dimension("memory_length", max_length)])
     r = mtf.get_variable(context.mesh, "R", r_shape,
                          initializer=None,
                          dtype=context.variable_dtype)
@@ -235,12 +233,11 @@ def synthetic_attention(q,
     r_shape = mtf.Shape([mtf.Dimension("length", 512),
                          mtf.Dimension("heads", num_heads.size),
                          mtf.Dimension("memory_length", 512)])
-    initializer = tf.random_normal_initializer()
     r1 = mtf.get_variable(context.mesh, "R1", r1_shape,
-                          initializer=initializer,
+                          initializer=None,
                           dtype=context.variable_dtype)
     r2 = mtf.get_variable(context.mesh, "R2", r2_shape,
-                          initializer=initializer,
+                          initializer=None,
                           dtype=context.variable_dtype)
     r = mtf.einsum([r1, r2], r_shape)
     r = mtf.slice(r, 0, memory_length_dim.size, memory_length_dim.name)
@@ -324,6 +321,7 @@ def synthetic_attention(q,
     outputs_shape = mtf.Shape(q.shape.dims[:-1] + [num_heads, value_dim])
   else:
     outputs_shape = q.shape - [key_dim] + value_dim
+
   outputs = mtf.einsum([weights, v], outputs_shape)
   return outputs
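For context on the first hunk: mtf.Dimension takes a name and a size, so the pre-fix call that passed both num_heads and max_length could not construct the memory_length dimension needed for the random Synthesizer weights. Below is a minimal sketch of the fixed shape construction, using placeholder sizes that stand in for the values the model supplies at call time:

    import mesh_tensorflow as mtf

    # Placeholder sizes for illustration only.
    max_length = 512
    num_heads_size = 8

    # mtf.Dimension(name, size) takes exactly two arguments; the fixed code
    # therefore sizes memory_length with max_length alone.
    r_shape = mtf.Shape([mtf.Dimension("length", max_length),
                         mtf.Dimension("heads", num_heads_size),
                         mtf.Dimension("memory_length", max_length)])

    assert r_shape.dims[-1].name == "memory_length"
    assert r_shape.dims[-1].size == max_length

The second hunk removes an unused random-uniform initializer and switches R1 and R2 to initializer=None, matching how the "R" variable above is already created.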

mesh_tensorflow/transformer/transformer_layers.py

Lines changed: 9 additions & 13 deletions
@@ -438,7 +438,7 @@ def __init__(self,
       self.no_query = False
     else:
       self.shared_kv = True
-      self.shared_q = True
+      self.no_query = True

   def make_params(self, context):
     return attention_params(context=context,
@@ -451,7 +451,6 @@ def make_params(self, context):
   def call(self, context, x, losses=None):
     """Call the layer."""
     params = self.make_params(context)
-    q = params.compute_q(x)
     memory_length = self.memory_length(context)
     if context.mode == "incremental":
       m = x
@@ -467,26 +466,23 @@ def call(self, context, x, losses=None):
       q = x
     else:
       q = params.compute_q(x)
+    if self.shared_kv:
+      k = kv
+      v = kv
     if context.mode == "incremental":
       one_hot = mtf.one_hot(
           context.position, memory_length, dtype=context.activation_dtype)
       inv_one_hot = 1.0 - one_hot
-      if self.shared_kv:
-        old_kv = context.get_states(1)
-        kv = old_kv * inv_one_hot + kv * one_hot
-      else:
-        old_k, old_v = context.get_states(2)
-        k = old_k * inv_one_hot + k * one_hot
-        v = old_v * inv_one_hot + v * one_hot
+      old_k, old_v = context.get_states(2)
+      k = old_k * inv_one_hot + k * one_hot
+      v = old_v * inv_one_hot + v * one_hot
       memory_position = mtf.range(context.mesh, memory_length, tf.int32)
     else:
       memory_position = self.rename_length_to_memory_length(
           context.position, context)
     if context.mode == "incremental" or context.mode == "first_part":
-      context.record_new_states([kv] if self.shared_kv else [k, v])
-      if self.shared_kv:
-        k = kv
-        v = kv
+      context.record_new_states([k, v])
+
     o = attention.synthetic_attention(q, k, v, memory_length,
                                       self.kv_dim, self.kv_dim,
                                       self.compute_bias(context,
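In the layer itself, the first two hunks make the constructor's else branch set self.no_query (mirroring the if branch) instead of self.shared_q, and drop a q = params.compute_q(x) call that is redundant with the one later in call(). The last hunk aliases k and v to the shared kv projection before the incremental-decoding update, so a single code path updates and records the [k, v] states. The cache update formula itself is unchanged; the following NumPy sketch of that one-hot write pattern is an illustration only (the real code operates on mtf tensors with a memory_length dimension):

    import numpy as np

    memory_length = 4   # cached decode positions (illustrative size)
    d_kv = 3            # key/value depth (illustrative size)
    position = 2        # current decode step

    # Cached keys from earlier steps and the key for the current token,
    # both laid out over [memory_length, d_kv].
    old_k = np.zeros((memory_length, d_kv), dtype=np.float32)
    new_k = np.ones((memory_length, d_kv), dtype=np.float32)

    # one_hot selects the slot for the current position; inv_one_hot keeps the rest.
    one_hot = np.eye(memory_length, dtype=np.float32)[position][:, None]
    inv_one_hot = 1.0 - one_hot

    # Mirrors `k = old_k * inv_one_hot + k * one_hot` from the diff above.
    k = old_k * inv_one_hot + new_k * one_hot
    assert (k[position] == 1.0).all() and (k[:position] == 0.0).all()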
