@@ -336,9 +336,9 @@ def mesh_axis_to_tensor_axis(self, mesh_ndims):
     Returns:
       Tuple of optional integers, with length mesh_ndims.
     """
+    ta2ma = self._tensor_axis_to_mesh_axis
     return tuple(
-        [self._tensor_axis_to_mesh_axis.index(mesh_axis)
-         if mesh_axis in self._tensor_axis_to_mesh_axis else None
+        [ta2ma.index(mesh_axis) if mesh_axis in ta2ma else None
          for mesh_axis in xrange(mesh_ndims)])

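Note: this hunk only hoists the repeated attribute lookup into a local `ta2ma`; the inverse-mapping logic is unchanged. A minimal standalone sketch of what the lookup computes, using a hypothetical layout tuple:

    # _tensor_axis_to_mesh_axis maps each tensor axis to a mesh axis (or None).
    ta2ma = (1, None, 0)  # hypothetical: tensor axis 0 -> mesh axis 1, etc.
    mesh_ndims = 2
    inverse = tuple(ta2ma.index(m) if m in ta2ma else None
                    for m in range(mesh_ndims))
    assert inverse == (2, 0)  # mesh axis 0 <- tensor axis 2, mesh axis 1 <- tensor axis 0
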
@@ -459,11 +459,6 @@ def var_key(v):
     key_to_vars = collections.defaultdict(collections.deque)
     for v in all_variables:
       key_to_vars[var_key(v)].append(v)
-    # We need to point the inputs of other operations at the outputs of unstack
-    # instead of the outputs of the deleted Variables. We construct this
-    # mapping from old input tensors to new input tensors.
-    tensor_mapping = {}
-    # maps from old variable name to (stacked_variable, position)
     individual_to_stacked = {}
     for op in operations:
       if isinstance(op, StackedVariable):
@@ -489,8 +484,13 @@ def var_key(v):
           stacked_var = StackedVariable(to_stack)
           stack_dim = stacked_var.shape.dims[0]
           unstacked = unstack(stacked_var.outputs[0], stack_dim)
-          for v, t in zip(to_stack, unstacked):
-            tensor_mapping[v.outputs[0]] = t
+          unstack_op = unstacked[0].operation
+          # Replace the output Tensors of the unstack operation with the
+          # Tensors which were the outputs of the original variable operations.
+          # Later operations use these Tensors as inputs.
+          unstack_op._outputs = [v.outputs[0] for v in to_stack]
+          for t in unstack_op._outputs:
+            t._operation = unstack_op
           for idx, v in enumerate(to_stack):
             individual_to_stacked[v.name] = stacked_var, idx
         else:
@@ -500,10 +500,6 @@ def var_key(v):
           if op.trainable:
             self._trainable_variables.append(op)
       else:
-        # Point inputs of other operations to the outputs of unstack.
-        for i in xrange(len(op._inputs)):
-          if op._inputs[i] in tensor_mapping:
-            op._inputs[i] = tensor_mapping[op._inputs[i]]
         if isinstance(op, Assign):
           # Rewrite the grouped assignment to stack up the values and then
           # assign to the stacked variables.
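Taken together, these three hunks change the rewiring strategy: instead of recording a `tensor_mapping` and patching every consumer's `_inputs`, the unstack op's `_outputs` are replaced with the original variables' output Tensors, so existing consumer references stay valid and now point at the unstack op. A toy sketch of that trick, with stub classes standing in for the real mtf `Operation`/`Tensor` (hypothetical, not the library API):

    class Operation(object):
      def __init__(self):
        self._outputs = []

    class Tensor(object):
      def __init__(self, operation):
        self._operation = operation

    var_op = Operation()
    var_out = Tensor(var_op)          # downstream ops already reference var_out
    var_op._outputs = [var_out]

    unstack_op = Operation()
    unstack_op._outputs = [Tensor(unstack_op)]  # freshly created output

    # Retarget: reuse the variable's output Tensor as the unstack output, so
    # every existing consumer of var_out now transparently reads from unstack_op.
    unstack_op._outputs = [var_out]
    var_out._operation = unstack_op
    assert var_out._operation is unstack_op
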
@@ -930,11 +926,11 @@ def allsplit(self, x, mesh_axis, split_axis, which=None):
     num_splits = self.shape[mesh_axis].size
     def my_fn(x, which):
       slice_begin = [
-          dimsize // num_splits * which if i == split_axis
-          else 0 for i, dimsize in enumerate(x.shape.as_list())]
+          dimsize // num_splits * which if i == split_axis else 0
+          for i, dimsize in enumerate(x.shape.as_list())]
       slice_size = [
-          dimsize // num_splits if i == split_axis
-          else dimsize for i, dimsize in enumerate(x.shape.as_list())]
+          dimsize // num_splits if i == split_axis else dimsize
+          for i, dimsize in enumerate(x.shape.as_list())]
       return tf.slice(x, slice_begin, slice_size)
     return self.slicewise(my_fn, x, which)

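Only the line wrapping changes in this hunk (the conditional now ends before the `for` clause). For reference, what the slice bookkeeping computes, as a plain-Python worked example with hypothetical sizes:

    shape = [8, 12]   # local tensor shape
    num_splits = 4    # size of the mesh axis being split over
    split_axis = 1
    which = 2         # this processor's coordinate along the mesh axis
    slice_begin = [d // num_splits * which if i == split_axis else 0
                   for i, d in enumerate(shape)]
    slice_size = [d // num_splits if i == split_axis else d
                  for i, d in enumerate(shape)]
    # Processor 2 takes the third contiguous block along axis 1.
    assert slice_begin == [0, 6]
    assert slice_size == [8, 3]
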
@@ -3657,6 +3653,8 @@ def add(x1, x2, output_shape=None, name=None):


 def add_n(xs):
+  if not xs:
+    return 0
   return functools.reduce(add, xs)


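The guard is needed because `functools.reduce` with no initializer raises on an empty sequence; returning 0 gives `add_n([])` the same behavior as the built-in `sum([])`. A quick demonstration:

    import functools
    import operator

    try:
      functools.reduce(operator.add, [])
    except TypeError as e:
      print(e)  # reduce() of empty sequence with no initial value

    assert sum([]) == 0  # add_n([]) now follows this convention
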
@@ -4037,10 +4035,9 @@ def reduce_logsumexp(x, reduced_dim, extra_logit=None, name=None):
     reduced_shape = x.shape - reduced_dim
     max_logit = reduce_max(stop_gradient(x), output_shape=reduced_shape)
     if extra_logit is not None:
-      max_logit = maximum(
-          max_logit,
-          stop_gradient(extra_logit) if isinstance(extra_logit, Tensor)
-          else extra_logit)
+      if isinstance(extra_logit, Tensor):
+        extra_logit = stop_gradient(extra_logit)
+      max_logit = maximum(max_logit, extra_logit)
     x -= max_logit
     exp_x = exp(x)
     sum_exp_x = reduce_sum(exp_x, output_shape=reduced_shape)
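This hunk only hoists the `stop_gradient` wrapping of `extra_logit` out of the `maximum()` call; the numerics are identical. For context, the max-subtraction trick the surrounding function relies on, as a NumPy sketch (not the mtf implementation):

    import numpy as np

    def logsumexp(x):
      # Subtract the max before exponentiating so exp() cannot overflow;
      # the max is added back outside the log.
      m = np.max(x)
      return m + np.log(np.sum(np.exp(x - m)))

    x = np.array([1000.0, 1000.0])
    # Naive np.log(np.sum(np.exp(x))) overflows to inf; the stable form
    # stays finite and accurate.
    assert np.isclose(logsumexp(x), 1000.0 + np.log(2.0))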