@@ -402,90 +402,165 @@ def unique_name(self, name, mark_as_used=True):
 
     return name
 
-  def rewrite_stack_variables(self, max_combined_size=2 ** 30):
+  def rewrite_stack_variables(self,
+                              max_combined_variable_size=2 ** 30,
+                              max_combined_slice_size=2 ** 27,
+                              mesh_to_impl=None):
     """Rewrite the current graph to combine variables.
 
     This helps speed up graph construction times in the case of large meshes
     and large numbers of variables.
 
-    This function should be called after the forward pass (before any variable
-    assignments).  Some similar variables are stacked to form larger variables.
-
-    Variables created prior to this call are checkpointed as separate
-    variables, even though they are combined internally.  So the checkpoints
-    are compatible for inference purposes with/without this call.  However,
-    the optimizer accumulators, which are created after this call, are
-    checkpointed as combined variables.
+    This function should be called after graph construction (it is called by
+    default in the Lowering constructor).
 
     When we find a set of variables with the same shape/dtype/etc, we replace
     them with one StackedVariable and an "unstack" operation.  The
-    StackedVariable has multiple master variables (so as to maintain
-    checkpoint compatibility), but only one slice variable per device.  We
-    point the inputs of later operations to the outputs of the
-    "unstack" operations, instead of the outputs of the defunct single
-    variables.
+    StackedVariable has multiple master variables (so as to maintain checkpoint
+    compatibility), but only one slice variable per device.  We point the
+    inputs of later operations to the outputs of the "unstack" operations,
+    instead of the outputs of the defunct single variables.
+
+    In order for variables to be combinable, they must be set in the same
+    Assign operation(s) - so it is necessary to call mtf.grouped_assign() from
+    the optimizer instead of many separate calls to mtf.assign().  The assign
+    operations get rewritten to set the appropriate stacked variables.
 
-    TODO(noam, ylc): Rewrite assignments as well, so that this can be applied
-    at the end of graph construction and be fully checkpoint-compatible.
-    Alternatively, find another solution for speeding up graph construction.
+    TODO(noam): Combining to larger sizes seems to cause errors on TPU.
+      Debug this.  Perhaps we should try to keep the combined master variables
+      on the same device.
 
     Args:
-      max_combined_size: an integer - maximum size for combined variables.
+      max_combined_variable_size: an integer - maximum size of a combined
+        (master) variable.
+      max_combined_slice_size: an integer - maximum size of any one device's
+        slice of a combined variable.
+      mesh_to_impl: an optional dictionary from Mesh to MeshImpl
     """
+    # pylint: disable=protected-access
     all_variables = self._all_variables
     operations = self._operations
     self._operations = []
     self._all_variables = []
     self._trainable_variables = []
+    # We can only stack variables which share the same set of assignment
+    # operations.
+    var_to_assign_ops = collections.defaultdict(str)
+    for op in operations:
+      if isinstance(op, Assign):
+        for v in op._variables:
+          var_to_assign_ops[v.name] += op.name + ", "
+    # Two variables with the same "key" can be stacked together.
     def var_key(v):
-      return str([v.shape,
+      return str([v.mesh,
+                  v.shape,
                   v.master_dtype,
                   v.slice_dtype,
                   v.activation_dtype,
-                  v.trainable])
-    key_to_vars = collections.defaultdict(list)
+                  v.trainable,
+                  var_to_assign_ops[v.name]])
+    key_to_vars = collections.defaultdict(collections.deque)
     for v in all_variables:
       key_to_vars[var_key(v)].append(v)
-    deleted_vars = set()
     # We need to point the inputs of other operations at the outputs of unstack
     # instead of the outputs of the deleted Variables.  We construct this
     # mapping from old input tensors to new input tensors.
     tensor_mapping = {}
+    # maps from old variable name to (stacked_variable, position)
+    individual_to_stacked = {}
     for op in operations:
-      if isinstance(op, Assign):
-        raise ValueError("stack_variables() should be called before any "
-                         "variable assignment.")
       if isinstance(op, StackedVariable):
         raise ValueError("stack_variables() should not be called twice.")
-      if isinstance(op, Variable):
-        if op in deleted_vars:
+      elif isinstance(op, Variable):
+        if op.name in individual_to_stacked:
           continue
         similar_vars = key_to_vars[var_key(op)]
-        num_to_stack = max(1, min(
-            len(similar_vars),
-            max_combined_size // op.shape.size))
-        to_stack = similar_vars[:num_to_stack]
-        key_to_vars[var_key(op)] = similar_vars[num_to_stack:]
+        num_to_stack = len(similar_vars)
+        if max_combined_variable_size is not None:
+          num_to_stack = min(
+              num_to_stack, max_combined_variable_size // op.shape.size)
+        if mesh_to_impl is not None:
+          mesh_impl = mesh_to_impl[op.mesh]
+          if mesh_impl.size == 1:
+            num_to_stack = 1  # no point in stacking for single processors.
+          slice_size = mesh_impl.slice_size(op.shape)
+          num_to_stack = min(
+              num_to_stack, max_combined_slice_size // slice_size)
+        num_to_stack = max(1, num_to_stack)
+        to_stack = [similar_vars.popleft() for _ in xrange(num_to_stack)]
         if num_to_stack > 1:
           stacked_var = StackedVariable(to_stack)
           stack_dim = stacked_var.shape.dims[0]
-          deleted_vars.update(to_stack)
           unstacked = unstack(stacked_var.outputs[0], stack_dim)
           for v, t in zip(to_stack, unstacked):
             tensor_mapping[v.outputs[0]] = t
+          for idx, v in enumerate(to_stack):
+            individual_to_stacked[v.name] = stacked_var, idx
         else:
+          assert op == to_stack[0]
           self._operations.append(op)
           self._all_variables.append(op)
           if op.trainable:
-            self.trainable_variables.append(op)
+            self._trainable_variables.append(op)
       else:
-        self._operations.append(op)
         # Point inputs of other operations to the outputs of unstack.
-        # pylint: disable=protected-access
         for i in xrange(len(op._inputs)):
           if op._inputs[i] in tensor_mapping:
             op._inputs[i] = tensor_mapping[op._inputs[i]]
-        # pylint: enable=protected-access
+        if isinstance(op, Assign):
+          # Rewrite the grouped assignment to stack up the values and then
+          # assign to the stacked variables.
+          new_variables = []
+          new_values = []
+          var_to_val = dict(zip([v.name for v in op._variables], op._inputs))
+          for var, val in zip(op._variables, op._inputs):
+            if var.name in individual_to_stacked:
+              stacked_var, pos = individual_to_stacked[var.name]
+              if pos == 0:
+                vals = [var_to_val[n] for n in stacked_var.original_names]
+                new_variables.append(stacked_var)
+                new_values.append(
+                    stack(vals, stacked_var.shape.dims[0].name, 0))
+            else:
+              new_variables.append(var)
+              new_values.append(val)
+          op._variables = new_variables
+          op._inputs = new_values
+        self._operations.append(op)
+    # pylint: enable=protected-access
+
+  def combine_assignments(self, assignments):
+    """Rewrite the current graph to combine "Assign" operations.
+
+    Combine similar Assign operations into grouped Assign operations.
+    This is useful when using the rewrite_stack_variables() optimization,
+    since variables can only be stacked if they are present in the same set
+    of Assign operations.
+
+    This function takes a list of Assign operations and returns a possibly
+    shorter list of Assign operations.  The input Assignment operations
+    are removed from the graph and become invalid.
+
+    Args:
+      assignments: a list of Assign objects
+    Returns:
+      a list of Assign objects
+    """
+    group_by_fn = collections.defaultdict(list)
+    for a in assignments:
+      if not isinstance(a, Assign):
+        raise ValueError("ops should be instances of mtf.Assign")
+      group_by_fn[a.assign_fn].append(a)
+    assignments_set = set(assignments)
+    self._operations = [
+        op for op in self._operations if op not in assignments_set]
+    ret = []
+    for fn, ops in six.iteritems(group_by_fn):
+      variables = []
+      values = []
+      for a in ops:
+        variables.extend(a.variables)
+        values.extend(a.inputs)
+      ret.append(Assign(variables, values, fn))
+    return ret
 
 
 class Lowering(object):
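
The two new rewrites are meant to be used in sequence: the optimizer emits one `Assign` per variable, `combine_assignments()` merges them into grouped `Assign` operations, and `rewrite_stack_variables()` can then stack variables that share the same grouped assignment. A minimal sketch of that flow, where `graph`, `assign_ops`, `mesh`, and `mesh_impl` are placeholder names rather than identifiers from this change:

```python
# Sketch only: assumes `graph` is a fully-constructed mtf.Graph and
# `assign_ops` is a list of per-variable mtf.Assign ops from an optimizer.
grouped = graph.combine_assignments(assign_ops)  # fewer, grouped Assigns
# Variables are now combinable if they share shape/dtype/mesh/trainability
# and the same set of (grouped) Assign operations:
graph.rewrite_stack_variables(
    max_combined_variable_size=2 ** 30,
    max_combined_slice_size=2 ** 27,
    mesh_to_impl={mesh: mesh_impl})
```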
@@ -512,18 +587,23 @@ class Lowering(object):
   ```
   """
 
-  def __init__(self, graph, mesh_to_impl):
+  def __init__(self, graph, mesh_to_impl, autostack=True):
     """Creates a Lowering of a Graph.
 
     Args:
       graph: Graph.
       mesh_to_impl: {Mesh: MeshImpl}. Keys are the Mesh's in the graph and
         their values are MeshImpl's, which map Tensor Dimension names to
         Mesh Dimension names.
+      autostack: a boolean.  If True, then the graph gets rewritten to
+        reduce the number of variables (see rewrite_stack_variables()).
+        This is a helpful performance optimization for large meshes.
     """
     # tf.logging.info("LOWERING GRAPH:\n%s" % graph.to_string)
     self.mesh_to_impl = mesh_to_impl  # {Mesh: MeshImpl}
     self.graph = graph
+    if autostack:
+      self.autostack()
     self._counters = []
     self.tensors = {}  # {Tensor: Mesh.LaidOutTensor}
     self.operations = {}  # {Operation: tf.Operation}
@@ -537,7 +617,11 @@ def __init__(self, graph, mesh_to_impl):
           "output/%s" % type(op).__name__, self.laid_out_size(out))
       self.add_counter("output_unique/%s" % type(op).__name__, out.size)
     log_variable_sizes(
-        graph.trainable_variables, "Trainable Variables", verbose=True)
+        graph.trainable_variables, "Trainable Variables", verbose=True,
+        mesh_to_impl=self.mesh_to_impl)
+    log_variable_sizes(
+        graph.all_variables, "All Variables", verbose=False,
+        mesh_to_impl=self.mesh_to_impl)
     tf.logging.info("Counters:\n" + pretty_print_counters(self._counters))
 
   def mesh_impl(self, m):
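
Because the rewrite mutates the graph in place, it now runs by default inside the `Lowering` constructor, so existing callers pick it up automatically and can opt out when they need the original per-variable graph. A small sketch, with `graph`, `mesh`, and `mesh_impl` standing in for a real Mesh-TensorFlow setup:

```python
# Default: the graph is rewritten to stack similar variables before lowering.
lowering = mtf.Lowering(graph, {mesh: mesh_impl})
# Alternatively, opt out and keep every variable separate:
lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=False)
```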
@@ -601,6 +685,10 @@ def verify_slice_shapes(self, tensor, laid_out_tensor):
           "Wrong slice shape: correct_shape = %s actual shape = %s"
           % (correct_shape, actual_shape))
 
+  def autostack(self):
+    """Rewrite graph to stack similar variables (performance optimization)."""
+    self.graph.rewrite_stack_variables(mesh_to_impl=self.mesh_to_impl)
+
 
 class Mesh(object):
   """A placeholder with no functionality.
@@ -766,6 +854,9 @@ def slice_begin(self, tensor_shape, pnum):
           dim_size // self.shape[mesh_axis].size * coordinates[mesh_axis])
     return ret
 
+  def slice_size(self, tensor_shape):
+    return list_product(self.slice_shape(tensor_shape))
+
   def laid_out_size(self, tensor_shape):
     """Total size of all slices.
 
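
The new `slice_size` helper reports the number of elements in one processor's slice of a tensor - the product of the per-dimension slice sizes - which is what `rewrite_stack_variables()` compares against `max_combined_slice_size`. An illustrative example of the arithmetic, with made-up numbers:

```python
# Hypothetical: a [batch=1024, hidden=4096] tensor on a mesh that splits
# `batch` 8 ways has slice_shape [128, 4096], so each processor holds
slice_size = 128 * 4096  # == 524288 elements, as list_product would compute
```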
@@ -2850,18 +2941,30 @@ def assign_sub_slice(variable, slice_var, val):
 
 
 class Assign(Operation):
-  """Assign to a variable."""
+  """Assign to one or more variables."""
 
-  def __init__(self, var, new_val, assign_fn=assign_slice, name=None):
-    super(Assign, self).__init__([new_val], var.mesh, name=name or "assign")
-    self._var = var
+  def __init__(self, variables, new_values, assign_fn=assign_slice, name=None):
+    super(Assign, self).__init__(
+        new_values, variables[0].mesh, name=name or "assign")
+    self._variables = variables
     self._assign_fn = assign_fn
     self._outputs = []
 
   def lower(self, lowering):
-    lowering.operations[self] = lowering.variables[self._var].assign_to_slices(
-        self._assign_fn,
-        lowering.tensors[self.inputs[0]].to_laid_out_tensor().all_slices)
+    ops = []
+    for var, val in zip(self._variables, self.inputs):
+      ops.append(lowering.variables[var].assign_to_slices(
+          self._assign_fn,
+          lowering.tensors[val].to_laid_out_tensor().all_slices))
+    lowering.operations[self] = tf.group(ops)
+
+  @property
+  def assign_fn(self):
+    return self._assign_fn
+
+  @property
+  def variables(self):
+    return self._variables
 
 
 def assign(var, new_val, assign_fn=assign_slice):
@@ -2881,7 +2984,7 @@ def assign(var, new_val, assign_fn=assign_slice):
28812984 var = var .operation
28822985 if not isinstance (var , Variable ):
28832986 raise ValueError ("var must be a mtf.Variable or its output Tensor." )
2884- return Assign (var , new_val , assign_fn = assign_fn )
2987+ return Assign ([ var ], [ new_val ] , assign_fn = assign_fn )
28852988
28862989
28872990def assign_add (var , new_val ):
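
`Assign` now carries parallel lists of variables and values, and its lowering wraps the per-variable slice assignments in a single `tf.group`; the one-variable `assign()` helper stays source-compatible by wrapping its arguments in singleton lists. A hedged usage sketch, where `w`, `b`, and the new-value tensors are placeholders:

```python
# Single-variable form, unchanged for existing callers:
op1 = mtf.assign(w, w - update_w)  # becomes Assign([w_var], [new_w]) inside
# Grouped form - what combine_assignments() produces and what
# rewrite_stack_variables() knows how to rewrite:
op2 = mtf.Assign([w.operation, b.operation], [new_w, new_b])
```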
@@ -4129,31 +4232,44 @@ def _cumprod(l):
   return ret
 
 
-def log_variable_sizes(var_list, tag, verbose=True):
+def log_variable_sizes(var_list, tag, verbose=True, mesh_to_impl=None):
   """Log the sizes and shapes of variables, and the total size.
 
   Args:
     var_list: a list of variables; defaults to trainable_variables
     tag: a string; defaults to "Trainable Variables"
     verbose: bool, if True, log every weight; otherwise, log total size only.
+    mesh_to_impl: an optional map from Mesh to MeshImpl
   """
   if not var_list:
     return
 
   name_to_var = {v.name: v for v in var_list}
   total_size = 0
+  total_slice_size = 0
   for v_name in sorted(list(name_to_var)):
     v = name_to_var[v_name]
     v_size = v.shape.size
+    if mesh_to_impl is not None:
+      slice_size = mesh_to_impl[v.mesh].slice_size(v.shape)
+    else:
+      slice_size = 0
+    total_slice_size += slice_size
     if verbose:
-      tf.logging.info("Weight    %s\tshape    %s\tsize    %d",
-                      v.name.ljust(80),
-                      str(v.shape).ljust(30), v_size)
+      tf.logging.info(
+          "Variable %s size %s slice_size %s %s",
+          v.name.ljust(60),
+          str(v_size).ljust(12),
+          str(slice_size).ljust(12),
+          str(v.shape).ljust(60))
     if isinstance(v, StackedVariable):
       for n in v.original_names:
         tf.logging.info("    " + n)
     total_size += v_size
-  tf.logging.info("%s Total size: %d", tag, total_size)
+  tf.logging.info("%s count: %s  Total size: %s  Total slice_size: %s",
+                  tag.ljust(30), str(len(var_list)).ljust(6),
+                  str(total_size).ljust(15),
+                  str(total_slice_size).ljust(15))
 
 
 class WhileLoopOperation(Operation):
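
The new `mesh_to_impl` argument lets the logger report each variable's per-processor slice size alongside its full size, matching the two calls added to the `Lowering` constructor above. A sketch of a direct call, with `graph`, `mesh`, and `mesh_impl` as placeholder names:

```python
# Sketch: log every trainable variable with both its full size and the size
# of one processor's slice.
mtf.log_variable_sizes(graph.trainable_variables, "Trainable Variables",
                       verbose=True, mesh_to_impl={mesh: mesh_impl})
```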