@@ -1012,7 +1012,16 @@ def convert_args_to_laid_out_tensors(xs):
10121012class Tensor (object ):
10131013 """A Distributed Tensor."""
10141014
1015- def __init__ (self , operation , shape , dtype , name = None ):
1015+ def __init__ (self , operation , shape , dtype , name = None , index = 0 ):
1016+ """Create a Tensor.
1017+
1018+ Args:
1019+ operation: the Operation that outputs this tensor
1020+ shape: a Shape
1021+ dtype: a tf.DType
1022+ name: an optional string
1023+ index: optional integer, the index among operation's output tensors
1024+ """
10161025 if not isinstance (shape , Shape ):
10171026 raise ValueError ("shape must be a Shape got %s" % shape .to_string )
10181027 if not isinstance (dtype , tf .DType ):
@@ -1022,7 +1031,7 @@ def __init__(self, operation, shape, dtype, name=None):
10221031 self ._shape = shape
10231032 self ._dtype = dtype
10241033 if name is None :
1025- name = self .operation .name
1034+ name = self .operation .name + ":" + str ( index )
10261035 self ._name = name
10271036 self ._mesh .graph .tensors .append (self )
10281037
@@ -1380,7 +1389,8 @@ def __init__(self, forward_op, grad_ys, name=None):
13801389 name = name or "generic_grad" )
13811390 self ._grad_ys = grad_ys
13821391 self ._forward_op = forward_op
1383- self ._outputs = [Tensor (self , x .shape , x .dtype ) for x in forward_op .inputs ]
1392+ self ._outputs = [Tensor (self , x .shape , x .dtype , index = i )
1393+ for i , x in enumerate (forward_op .inputs )]
13841394
13851395 def lower (self , lowering ):
13861396 # lists of lists of tf.Tensor
@@ -1809,7 +1819,8 @@ def __init__(self, x, split_dim, num_or_size_splits, name=None):
18091819
18101820 self ._outputs = [
18111821 Tensor (self , x .shape .resize_dimension (split_dim .name , output_size ),
1812- x .dtype ) for output_size in self ._output_sizes ]
1822+ x .dtype , index = i )
1823+ for i , output_size in enumerate (self ._output_sizes )]
18131824
18141825 def gradient (self , grad_ys ):
18151826 return [concat (grad_ys , self ._split_dim .name )]
@@ -1898,7 +1909,7 @@ def __init__(self, x, dim, name=None):
18981909 self ._axis = x .shape .dims .index (dim )
18991910 output_shape = x .shape - dim
19001911 self ._outputs = [
1901- Tensor (self , output_shape , x .dtype ) for _ in xrange (dim .size )]
1912+ Tensor (self , output_shape , x .dtype , index = i ) for i in xrange (dim .size )]
19021913
19031914 def gradient (self , grad_ys ):
19041915 return [stack (grad_ys , self ._dim .name , self ._axis )]
@@ -3797,7 +3808,7 @@ def __init__(self, cond_fn, body_fn, inputs,
37973808 ops = self .graph .operations
37983809 before = len (ops )
37993810 def make_placeholders (name ):
3800- return [Tensor (self , t .shape , t .dtype , name = "%s_ %d" % (name , i ))
3811+ return [Tensor (self , t .shape , t .dtype , name = "%s: %d" % (name , i ))
38013812 for i , t in enumerate (inputs )]
38023813 self ._cond_inputs = make_placeholders ("cond_input" )
38033814 self ._cond_output = self ._cond_fn (* self ._cond_inputs )
0 commit comments