@@ -347,7 +347,6 @@ class Graph(object):
 
   def __init__(self):
     self._operations = []
-    self._tensors = []
     self._trainable_variables = []
     self._all_variables = []
     # Maps a name used in the graph to the next id to use for that name.
@@ -360,10 +359,6 @@ def __repr__(self):
   def operations(self):
     return self._operations
 
-  @property
-  def tensors(self):
-    return self._tensors
-
   @property
   def trainable_variables(self):
     return self._trainable_variables
@@ -407,6 +402,91 @@ def unique_name(self, name, mark_as_used=True):
 
     return name
 
+  def rewrite_stack_variables(self, max_combined_size=2 ** 30):
+    """Rewrite the current graph to combine variables.
+
+    This helps speed up graph construction times in the case of large meshes
+    and large numbers of variables.
+
+    This function should be called after the forward pass (before any variable
+    assignments). Some similar variables are stacked to form larger variables.
+
+    Variables created prior to this call are checkpointed as separate
+    variables, even though they are combined internally, so the checkpoints
+    are compatible for inference purposes with or without this call. However,
+    the optimizer accumulators, which are created after this call, are
+    checkpointed as combined variables.
+
+    When we find a set of variables with the same shape/dtype/etc., we replace
+    them with one StackedVariable and an "unstack" operation. The
+    StackedVariable has multiple master variables (so as to maintain
+    checkpoint compatibility), but only one slice variable per device. We
+    point the inputs of later operations to the outputs of the
+    "unstack" operations, instead of the outputs of the defunct single
+    variables.
+
+    TODO(noam, ylc): Rewrite assignments as well, so that this can be applied
+    at the end of graph construction and be fully checkpoint-compatible.
+    Alternatively, find another solution for speeding up graph construction.
+
+    Args:
+      max_combined_size: an integer - maximum size for combined variables.
+    """
+    all_variables = self._all_variables
+    operations = self._operations
+    self._operations = []
+    self._all_variables = []
+    self._trainable_variables = []
+    def var_key(v):
+      return str([v.shape,
+                  v.master_dtype,
+                  v.slice_dtype,
+                  v.activation_dtype,
+                  v.trainable])
+    key_to_vars = collections.defaultdict(list)
+    for v in all_variables:
+      key_to_vars[var_key(v)].append(v)
+    deleted_vars = set()
+    # We need to point the inputs of other operations at the outputs of unstack
+    # instead of the outputs of the deleted Variables. We construct this
+    # mapping from old input tensors to new input tensors.
+    tensor_mapping = {}
+    for op in operations:
+      if isinstance(op, Assign):
+        raise ValueError("stack_variables() should be called before any "
+                         "variable assignment.")
+      if isinstance(op, StackedVariable):
+        raise ValueError("stack_variables() should not be called twice.")
+      if isinstance(op, Variable):
+        if op in deleted_vars:
+          continue
+        similar_vars = key_to_vars[var_key(op)]
+        num_to_stack = max(1, min(
+            len(similar_vars),
+            max_combined_size // op.shape.size))
+        to_stack = similar_vars[:num_to_stack]
+        key_to_vars[var_key(op)] = similar_vars[num_to_stack:]
+        if num_to_stack > 1:
+          stacked_var = StackedVariable(to_stack)
+          stack_dim = stacked_var.shape.dims[0]
+          deleted_vars.update(to_stack)
+          unstacked = unstack(stacked_var.outputs[0], stack_dim)
+          for v, t in zip(to_stack, unstacked):
+            tensor_mapping[v.outputs[0]] = t
+        else:
+          self._operations.append(op)
+          self._all_variables.append(op)
+          if op.trainable:
+            self.trainable_variables.append(op)
+      else:
+        self._operations.append(op)
+        # Point inputs of other operations to the outputs of unstack.
+        # pylint: disable=protected-access
+        for i in xrange(len(op._inputs)):
+          if op._inputs[i] in tensor_mapping:
+            op._inputs[i] = tensor_mapping[op._inputs[i]]
+        # pylint: enable=protected-access
+
 
 class Lowering(object):
   """Lowering of a Graph from Mesh-TensorFlow to TensorFlow.
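
The new rewrite pass above is easiest to read as a call-ordering contract. Below is a minimal usage sketch, assuming the usual Mesh-TensorFlow entry points (mtf.Graph, mtf.Mesh) and a hypothetical build_forward_pass model function; only the ordering itself comes from the docstring: forward pass first, then rewrite_stack_variables(), then any Assign operations such as optimizer updates.

    import mesh_tensorflow as mtf

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")   # mesh name is a placeholder

    logits = build_forward_pass(mesh)   # hypothetical model-building function

    # Stack similar variables before any variable assignments exist.
    # Calling this after an Assign op has been created raises ValueError.
    graph.rewrite_stack_variables(max_combined_size=2 ** 30)

    # Optimizer accumulators and update ops (Assign operations) are created
    # after this point and are checkpointed as combined variables.
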
@@ -1087,7 +1167,6 @@ def __init__(self, operation, shape, dtype, name=None, index=0):
     if name is None:
       name = self.operation.name + ":" + str(index)
     self._name = name
-    self._mesh.graph.tensors.append(self)
 
   @property
   def shape(self):
@@ -2204,7 +2283,7 @@ def conv2d(conv_input, conv_filter, strides, padding, name=None):
 
 
 class Conv2dBackpropInputOperation(Operation):
-  """like tf.nn.conv2d_backprop_input"""
+  """like tf.nn.conv2d_backprop_input."""
 
   def __init__(self, input_shape, conv_filter, dy, strides, padding, name=None):
     super(Conv2dBackpropInputOperation, self).__init__(
@@ -2618,11 +2697,12 @@ def __init__(self, mesh, name, shape, master_dtype, slice_dtype,
     self._slice_dtype = slice_dtype
     self._activation_dtype = activation_dtype
     self._trainable = trainable
-    with tf.device(mesh.variable_placer_fn), utils.outside_all_rewrites():
-      self.master = tf.get_variable(
-          name, shape.to_integer_list, dtype=master_dtype,
-          initializer=initializer, **kwargs)
-    self._name = self.master.name[:self.master.name.find(":")]
+    if not isinstance(self, StackedVariable):
+      with tf.device(mesh.variable_placer_fn), utils.outside_all_rewrites():
+        self._master = tf.get_variable(
+            name, shape.to_integer_list, dtype=master_dtype,
+            initializer=initializer, **kwargs)
+      self._name = self._master.name[:self._master.name.find(":")]
     self._outputs = [Tensor(self, shape, activation_dtype)]
     self.graph.all_variables.append(self)
     if trainable:
@@ -2667,6 +2747,61 @@ def slice_dtype(self):
   def activation_dtype(self):
     return self._activation_dtype
 
+  @property
+  def trainable(self):
+    return self._trainable
+
+  @property
+  def master_device(self):
+    return self._master.device
+
+  def get_master(self):
+    return self._master
+
+  def assign_to_master(self, val):
+    return tf.assign(self._master, val)
+
+
+class StackedVariable(Variable):
+  """A Variable which combines many variables into one.
+
+  This is a performance optimization to reduce the time associated with large
+  numbers of slice variables. See Graph.rewrite_stack_variables() for usage.
+  """
+
+  def __init__(self, vs):
+    """Create a StackedVariable.
+
+    Args:
+      vs: a list of Variables
+    """
+    shape = Shape([Dimension("stacked", len(vs))] + vs[0].shape.dims)
+    name = "stacked/" + vs[0].name
+    # TODO(noam): verify that vs are the same shape, etc.
+    super(StackedVariable, self).__init__(
+        vs[0].mesh, name, shape, vs[0].master_dtype, vs[0].slice_dtype,
+        vs[0].activation_dtype, None, vs[0].trainable)
+    self._name = name
+    self._masters = [v.get_master() for v in vs]
+    self._original_names = [v.name for v in vs]
+
+  @property
+  def original_names(self):
+    return self._original_names
+
+  @property
+  def master_device(self):
+    return self._masters[0].device
+
+  def get_master(self):
+    with tf.device(self.master_device):
+      return tf.stack(self._masters)
+
+  def assign_to_master(self, val):
+    return tf.group([
+        tf.assign(var_slice, val_slice) for var_slice, val_slice
+        in zip(self._masters, tf.unstack(val))])
+
 
 class ReadVariable(Operation):
   """Read a variable."""
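
To see why the stacked form stays checkpoint-compatible: StackedVariable.get_master() is just a tf.stack of the original master variables, and assign_to_master() unstacks a combined value and writes each slice back to its own original master, so the checkpoint keeps one entry per original variable name. A standalone sketch of that round trip in plain TF1-style TensorFlow follows (the variable names and shapes are made up for illustration):

    import tensorflow as tf  # assumes the TF1-style API used elsewhere in this file

    # Two "original" master variables, as they exist before stacking.
    masters = [tf.get_variable("layer_%d/kernel" % i, shape=[4, 8])
               for i in range(2)]

    # get_master(): one combined tensor with a leading "stacked" dimension.
    combined = tf.stack(masters)        # shape [2, 4, 8]

    # assign_to_master(val): split the combined value and write each slice
    # back to its own master variable, preserving per-variable checkpoints.
    new_val = tf.ones_like(combined)
    assign_op = tf.group([tf.assign(var, val)
                          for var, val in zip(masters, tf.unstack(new_val))])
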
@@ -4014,6 +4149,9 @@ def log_variable_sizes(var_list, tag, verbose=True):
       tf.logging.info("Weight %s\tshape %s\tsize %d",
                       v.name.ljust(80),
                       str(v.shape).ljust(30), v_size)
+      if isinstance(v, StackedVariable):
+        for n in v.original_names:
+          tf.logging.info(" " + n)
     total_size += v_size
   tf.logging.info("%s Total size: %d", tag, total_size)
 