This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit 4a3e81f

nshazeer authored and Copybara-Service committed
New mesh_tensorflow transformer now handles encoder-decoder models and beam search.
Fix bug in variable stacking code (operations in while loops did not get their inputs redirected).
Fix datatype bug in simd_mesh_impl.py.

PiperOrigin-RevId: 224369334
1 parent 8f257a3 · commit 4a3e81f

5 files changed: +603 -89 lines


mesh_tensorflow/layers.py

Lines changed: 11 additions & 8 deletions
@@ -1157,12 +1157,15 @@ def compress_mean(x, dim, compression_factor):
   return x


-def embedding(indices, vocab_dim, output_dim, variable_dtype, name=None):
+def embedding_weights(
+    mesh, vocab_dim, output_dim, variable_dtype, name="embedding"):
+  return mtf.get_variable(
+      mesh, name, mtf.Shape([vocab_dim, output_dim]),
+      dtype=variable_dtype, initializer=tf.random_normal_initializer())
+
+
+def embedding(indices, vocab_dim, output_dim, variable_dtype, name="embedding"):
   """Embedding layer."""
-  with tf.variable_scope(name, default_name="embedding"):
-    weights = mtf.get_variable(
-        indices.mesh, "w",
-        mtf.Shape([vocab_dim, output_dim]),
-        dtype=variable_dtype,
-        initializer=tf.random_normal_initializer())
-    return mtf.gather(weights, indices, vocab_dim)
+  weights = embedding_weights(
+      indices.mesh, vocab_dim, output_dim, variable_dtype, name)
+  return mtf.gather(weights, indices, vocab_dim)
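For context, a minimal sketch of how the split API might be used; the graph, mesh, dimension names, and sizes below are illustrative, not part of this commit. Factoring out embedding_weights means the table can be created once and reused, for example for weight tying between the input embedding and the output logits.

import tensorflow as tf
import mesh_tensorflow as mtf
from mesh_tensorflow import layers as mtf_layers

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "example_mesh")      # illustrative mesh name
vocab_dim = mtf.Dimension("vocab", 32000)   # illustrative sizes
model_dim = mtf.Dimension("d_model", 512)
length_dim = mtf.Dimension("length", 128)

# Create the table once via the new helper...
weights = mtf_layers.embedding_weights(
    mesh, vocab_dim, model_dim, variable_dtype=tf.float32,
    name="shared_embedding")

# ...gather from it for the input embedding (what embedding() now does)...
token_ids = mtf.zeros(mesh, mtf.Shape([length_dim]), dtype=tf.int32)
token_embeddings = mtf.gather(weights, token_ids, vocab_dim)

# ...and reuse the same weights for the output logits (weight tying).
logits = mtf.einsum([token_embeddings, weights],
                    output_shape=mtf.Shape([length_dim, vocab_dim]))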

mesh_tensorflow/ops.py

Lines changed: 18 additions & 21 deletions
@@ -336,9 +336,9 @@ def mesh_axis_to_tensor_axis(self, mesh_ndims):
     Returns:
       Tuple of optional integers, with length mesh_ndims.
     """
+    ta2ma = self._tensor_axis_to_mesh_axis
     return tuple(
-        [self._tensor_axis_to_mesh_axis.index(mesh_axis)
-         if mesh_axis in self._tensor_axis_to_mesh_axis else None
+        [ta2ma.index(mesh_axis) if mesh_axis in ta2ma else None
         for mesh_axis in xrange(mesh_ndims)])

@@ -459,11 +459,6 @@ def var_key(v):
     key_to_vars = collections.defaultdict(collections.deque)
     for v in all_variables:
       key_to_vars[var_key(v)].append(v)
-    # We need to point the inputs of other operations at the outputs of unstack
-    # instead of the outputs of the deleted Variables. We construct this
-    # mapping from old input tensors to new input tensors.
-    tensor_mapping = {}
-    # maps from old variable name to (stacked_variable, position)
     individual_to_stacked = {}
     for op in operations:
       if isinstance(op, StackedVariable):
@@ -489,8 +484,13 @@ def var_key(v):
           stacked_var = StackedVariable(to_stack)
           stack_dim = stacked_var.shape.dims[0]
           unstacked = unstack(stacked_var.outputs[0], stack_dim)
-          for v, t in zip(to_stack, unstacked):
-            tensor_mapping[v.outputs[0]] = t
+          unstack_op = unstacked[0].operation
+          # replace the output Tensors of the unstack operation with the
+          # Tensors which were the outputs of the original variable operations.
+          # Later operations use these Tensors as inputs.
+          unstack_op._outputs = [v.outputs[0] for v in to_stack]
+          for t in unstack_op._outputs:
+            t._operation = unstack_op
           for idx, v in enumerate(to_stack):
             individual_to_stacked[v.name] = stacked_var, idx
         else:
@@ -500,10 +500,6 @@ def var_key(v):
         if op.trainable:
           self._trainable_variables.append(op)
       else:
-        # Point inputs of other operations to the outputs of unstack.
-        for i in xrange(len(op._inputs)):
-          if op._inputs[i] in tensor_mapping:
-            op._inputs[i] = tensor_mapping[op._inputs[i]]
         if isinstance(op, Assign):
           # Rewrite the grouped assignment to stack up the values and then
           # assign to the stacked variables.
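The while-loop bug named in the commit message comes down to graph rewiring. A toy sketch (not mtf code; all names here are made up) of why re-parenting the original output tensors onto the unstack op is more robust than rewriting each consumer's inputs: a consumer the rewrite loop never visits, such as one buried in a while-loop body, keeps pointing at a stale tensor under the mapping approach, but automatically sees the new producer once the old tensors themselves are adopted by it.

# Toy graph classes, purely illustrative.
class Tensor(object):
  def __init__(self, op):
    self.operation = op

class Op(object):
  def __init__(self, inputs=()):
    self.inputs = list(inputs)
    self.outputs = [Tensor(self)]

old_var = Op()                   # stands in for an original Variable op
consumer = Op(old_var.outputs)   # a consumer we might never get to visit
unstack_op = Op()                # stands in for the new unstack operation

# The approach in this commit: re-parent the original output tensors.
unstack_op.outputs = list(old_var.outputs)
for t in unstack_op.outputs:
  t.operation = unstack_op

# The consumer is rewired without ever touching its input list.
assert consumer.inputs[0].operation is unstack_op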
@@ -930,11 +926,11 @@ def allsplit(self, x, mesh_axis, split_axis, which=None):
     num_splits = self.shape[mesh_axis].size
     def my_fn(x, which):
       slice_begin = [
-          dimsize // num_splits * which if i == split_axis
-          else 0 for i, dimsize in enumerate(x.shape.as_list())]
+          dimsize // num_splits * which if i == split_axis else 0
+          for i, dimsize in enumerate(x.shape.as_list())]
       slice_size = [
-          dimsize // num_splits if i == split_axis
-          else dimsize for i, dimsize in enumerate(x.shape.as_list())]
+          dimsize // num_splits if i == split_axis else dimsize
+          for i, dimsize in enumerate(x.shape.as_list())]
       return tf.slice(x, slice_begin, slice_size)
     return self.slicewise(my_fn, x, which)

@@ -3657,6 +3653,8 @@ def add(x1, x2, output_shape=None, name=None):


 def add_n(xs):
+  if not xs:
+    return 0
   return functools.reduce(add, xs)

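The new guard matters because functools.reduce raises a TypeError on an empty sequence when no initial value is given, so summing an empty list of tensors used to fail. A tiny stand-alone check of the same pattern, using plain ints in place of mtf Tensors:

import functools

def add(x1, x2):   # stand-in for mtf.add
  return x1 + x2

def add_n(xs):
  if not xs:
    return 0
  return functools.reduce(add, xs)

assert add_n([]) == 0
assert add_n([1, 2, 3]) == 6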

@@ -4037,10 +4035,9 @@ def reduce_logsumexp(x, reduced_dim, extra_logit=None, name=None):
     reduced_shape = x.shape - reduced_dim
     max_logit = reduce_max(stop_gradient(x), output_shape=reduced_shape)
     if extra_logit is not None:
-      max_logit = maximum(
-          max_logit,
-          stop_gradient(extra_logit) if isinstance(extra_logit, Tensor)
-          else extra_logit)
+      if isinstance(extra_logit, Tensor):
+        extra_logit = stop_gradient(extra_logit)
+      max_logit = maximum(max_logit, extra_logit)
     x -= max_logit
     exp_x = exp(x)
     sum_exp_x = reduce_sum(exp_x, output_shape=reduced_shape)
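The reduce_logsumexp change is a readability refactor of the standard max-subtraction trick that the surrounding code relies on: subtracting the maximum before exponentiating keeps the sum from overflowing. A NumPy illustration of the identity (not mtf code):

import numpy as np

x = np.array([1000.0, 1001.0])
# naive np.log(np.sum(np.exp(x))) overflows: exp(1000) is inf in float64
m = np.max(x)
stable = m + np.log(np.sum(np.exp(x - m)))
print(stable)  # ~1001.313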

mesh_tensorflow/simd_mesh_impl.py

Lines changed: 4 additions & 3 deletions
@@ -102,7 +102,7 @@ def __init__(self, variable, mesh_impl):
     slices = []
     slices_with_master_dtype = []
     with tf.device(variable.master_device), utils.outside_all_rewrites():
-      zero_tensor = tf.zeros(slice_shape)
+      zero_tensor = tf.zeros(slice_shape, dtype=variable.slice_dtype)

     # pylint: disable=protected-access
     init_device_stack = tf.get_default_graph()._device_function_stack
@@ -398,8 +398,9 @@ def slicewise(self, fn, *inputs):
       return inputs[0] + inputs[1]
     # convert all inputs to LaidOutTensor where possible
     inputs = mtf.convert_args_to_laid_out_tensors(inputs)
-    ret = fn(*[x.one_slice if isinstance(x, self.LaidOutTensor)
-               else x for x in inputs])
+    ret = fn(*[
+        x.one_slice if isinstance(x, self.LaidOutTensor) else x
+        for x in inputs])
     if isinstance(ret, tuple):
       return tuple([self.LaidOutTensor([t]) for t in ret])
     else:
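The datatype fix in the first hunk addresses the fact that tf.zeros defaults to float32, so the zero initializer did not match variables whose slices use another dtype such as bfloat16. A small plain-TensorFlow illustration of the difference; the shape and dtypes here are arbitrary:

import tensorflow as tf

slice_shape = [4, 8]
implicit = tf.zeros(slice_shape)                     # dtype defaults to float32
explicit = tf.zeros(slice_shape, dtype=tf.bfloat16)  # matches a bfloat16 slice
print(implicit.dtype, explicit.dtype)                # float32 bfloat16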
