This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit 274ed90

Mesh TensorFlow Team authored and Copybara-Service committed
Append output index to tensor names.
PiperOrigin-RevId: 218705140
1 parent 4b4a180 commit 274ed90
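
What this changes in practice: a Tensor whose name is not given explicitly now defaults to `<operation name>:<output index>` instead of the bare operation name, matching TensorFlow's own `op_name:output_index` convention for tensor names. Previously, every output of a multi-output operation (SplitOperation, UnstackOperation, GenericGradOperation) inherited the same default name; the index disambiguates them. A minimal standalone sketch of the new rule, where StubOperation and the name "split" are hypothetical stand-ins, not part of this commit:

# Standalone sketch of the naming rule introduced below; StubOperation and
# "split" are illustrative stand-ins for mesh_tensorflow's real classes.
class StubOperation(object):
  def __init__(self, name):
    self.name = name

def default_tensor_name(operation, name=None, index=0):
  # Mirrors the new default in Tensor.__init__: append ":<output index>".
  if name is None:
    name = operation.name + ":" + str(index)
  return name

op = StubOperation("split")
print(default_tensor_name(op, index=0))  # split:0
print(default_tensor_name(op, index=1))  # split:1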

File tree

1 file changed (+17, -6 lines)


mesh_tensorflow/ops.py

Lines changed: 17 additions & 6 deletions
@@ -1012,7 +1012,16 @@ def convert_args_to_laid_out_tensors(xs):
 class Tensor(object):
   """A Distributed Tensor."""

-  def __init__(self, operation, shape, dtype, name=None):
+  def __init__(self, operation, shape, dtype, name=None, index=0):
+    """Create a Tensor.
+
+    Args:
+      operation: the Operation that outputs this tensor
+      shape: a Shape
+      dtype: a tf.DType
+      name: an optional string
+      index: optional integer, the index among operation's output tensors
+    """
     if not isinstance(shape, Shape):
       raise ValueError("shape must be a Shape got %s" % shape.to_string)
     if not isinstance(dtype, tf.DType):
@@ -1022,7 +1031,7 @@ def __init__(self, operation, shape, dtype, name=None):
     self._shape = shape
     self._dtype = dtype
     if name is None:
-      name = self.operation.name
+      name = self.operation.name + ":" + str(index)
     self._name = name
     self._mesh.graph.tensors.append(self)

@@ -1380,7 +1389,8 @@ def __init__(self, forward_op, grad_ys, name=None):
         name=name or "generic_grad")
     self._grad_ys = grad_ys
     self._forward_op = forward_op
-    self._outputs = [Tensor(self, x.shape, x.dtype) for x in forward_op.inputs]
+    self._outputs = [Tensor(self, x.shape, x.dtype, index=i)
+                     for i, x in enumerate(forward_op.inputs)]

   def lower(self, lowering):
     # lists of lists of tf.Tensor
@@ -1809,7 +1819,8 @@ def __init__(self, x, split_dim, num_or_size_splits, name=None):

     self._outputs = [
         Tensor(self, x.shape.resize_dimension(split_dim.name, output_size),
-               x.dtype) for output_size in self._output_sizes]
+               x.dtype, index=i)
+        for i, output_size in enumerate(self._output_sizes)]

   def gradient(self, grad_ys):
     return [concat(grad_ys, self._split_dim.name)]
@@ -1898,7 +1909,7 @@ def __init__(self, x, dim, name=None):
     self._axis = x.shape.dims.index(dim)
     output_shape = x.shape - dim
     self._outputs = [
-        Tensor(self, output_shape, x.dtype) for _ in xrange(dim.size)]
+        Tensor(self, output_shape, x.dtype, index=i) for i in xrange(dim.size)]

   def gradient(self, grad_ys):
     return [stack(grad_ys, self._dim.name, self._axis)]
@@ -3797,7 +3808,7 @@ def __init__(self, cond_fn, body_fn, inputs,
     ops = self.graph.operations
     before = len(ops)
     def make_placeholders(name):
-      return [Tensor(self, t.shape, t.dtype, name="%s_%d" % (name, i))
+      return [Tensor(self, t.shape, t.dtype, name="%s:%d" % (name, i))
               for i, t in enumerate(inputs)]
     self._cond_inputs = make_placeholders("cond_input")
     self._cond_output = self._cond_fn(*self._cond_inputs)
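
A note on the last hunk: the while-loop placeholders switch from an underscore separator to a colon, so explicitly named tensors follow the same `name:index` pattern as the new default above. A quick illustrative sketch of the difference (the input count of 2 is arbitrary):

name = "cond_input"
old_style = ["%s_%d" % (name, i) for i in range(2)]  # ['cond_input_0', 'cond_input_1']
new_style = ["%s:%d" % (name, i) for i in range(2)]  # ['cond_input:0', 'cond_input:1']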
