
Commit 982bcb1

nshazeer authored and Copybara-Service committed
MTF rewrite_stack_variables() now handles assignments as well. It is fully checkpoint-compatible and mathematically compatible, and is turned on by default via the "autostack" option in Lowering. Client code needs to change to call optimizer.apply_grads() instead of optimizer.apply_grad().
PiperOrigin-RevId: 223220539
1 parent e1da243 commit 982bcb1
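
The client-side migration is a one-line change: instead of looping over gradient/variable pairs with apply_grad(), callers pass both lists to a single apply_grads() call and receive the update ops directly. A minimal sketch, mirroring the examples/mnist.py diff below:

    # Before: one apply_grad() call per variable.
    update_ops = []
    for grad, var in zip(var_grads, graph.trainable_variables):
      update_ops.extend(optimizer.apply_grad(grad, var))

    # After: a single grouped call, which keeps the assignments together so the
    # "autostack" rewrite can combine variables.
    update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)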

File tree
4 files changed: +205 -59 lines changed


examples/mnist.py
Lines changed: 1 addition & 3 deletions

@@ -137,9 +137,7 @@ def model_fn(features, labels, mode, params):
     var_grads = mtf.gradients(
         [loss], [v.outputs[0] for v in graph.trainable_variables])
     optimizer = mtf.optimize.AdafactorOptimizer()
-    update_ops = []
-    for grad, var in zip(var_grads, graph.trainable_variables):
-      update_ops.extend(optimizer.apply_grad(grad, var))
+    update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)

   lowering = mtf.Lowering(graph, {mesh: mesh_impl})
   restore_hook = mtf.MtfRestoreHook(lowering)

examples/toy_model_tpu.py
Lines changed: 1 addition & 3 deletions

@@ -173,9 +173,7 @@ def model_fn(features, labels, mode, params):
     else:
       assert FLAGS.optimizer == 'SGD'
       optimizer = mtf.optimize.SgdOptimizer(lr=1e-4)
-    update_ops = []
-    for grad, var in zip(var_grads, graph.trainable_variables):
-      update_ops.extend(optimizer.apply_grad(grad, var))
+    update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)
   else:
     # for now, we can only export fully-replicated tensors.
     fully_replicated_logits = mtf.anonymize(logits)

mesh_tensorflow/ops.py
Lines changed: 168 additions & 52 deletions
@@ -402,90 +402,165 @@ def unique_name(self, name, mark_as_used=True):

     return name

-  def rewrite_stack_variables(self, max_combined_size=2 ** 30):
+  def rewrite_stack_variables(self,
+                              max_combined_variable_size=2 ** 30,
+                              max_combined_slice_size=2 ** 27,
+                              mesh_to_impl=None):
     """Rewrite the current graph to combine variables.

     This helps speed up graph construction times in the case of large meshes
     and large numbers of variables.

-    This function should be called after the forward pass, (before any variable
-    assignemnts). Some similar variables are stacked to form larger variables.
-
-    Variables created prior to this call are checkpointed as separate variables,
-    even though they are combined internally. So the checkpoints are
-    compatible for inference purposes with/without this call. However, the
-    optimizer accumulators, which are created after this call are checkpointed
-    as combined variables.
+    This function should be called after graph construction (it is called by
+    default in the Lowering constuctor).

     When we find a set of variables with the same shape/dtype/etc, we replace
     them with one StackedVariable and an "unstack" operation. The
-    StackedVariable has multiple master variables (so as to maintain
-    checkpiont compatibility), but only one slice variable per device. We
-    point the inputs of later operations to the outputs of the
-    "unstack" operations, instead of the outputs of the defunct single
-    variables.
+    StackedVariable has multiple master variables (so as to maintain checkpiont
+    compatibility), but only one slice variable per device. We point the inputs
+    of later operations to the outputs of the "unstack" operations, instead of
+    the outputs of the defunct single variables.
+
+    In order for variables to be combinable, they must be set in the same Assign
+    operation(s) - so it is necessary to call mtf.grouped_assign() from the
+    optimizer instead of many separate calls to mtf.assign(). The assign
+    operations get rewritten to set the appropriate stacked variables.

-    TODO(noam, ylc): Rewrite assignments as well, so that this can be applied at
-    the end of graph construction and be fully checkpoint-compatible.
-    Alternatively, find another solution for speeding up graph construction.
+    TODO(noam): Combining to larger sizes seems to cause errors on TPU.
+    debug this. Perhaps we should try to keep the combined master variables
+    on the same device.

     Args:
-      max_combined_size: an integer - maximum size for combined variables.
+      max_combined_variable_size: an integer
+      max_combined_slice_size: an integer
+      mesh_to_impl: an optional dictionary from Mesh to MeshImpl
     """
+    # pylint: disable=protected-access
     all_variables = self._all_variables
     operations = self._operations
     self._operations = []
     self._all_variables = []
     self._trainable_variables = []
+    # We can only stack varaibles which share the same set of assignment
+    # operations.
+    var_to_assign_ops = collections.defaultdict(str)
+    for op in operations:
+      if isinstance(op, Assign):
+        for v in op._variables:
+          var_to_assign_ops[v.name] += op.name + ", "
+    # Two variables with the same "key" can be stacked together.
     def var_key(v):
-      return str([v.shape,
+      return str([v.mesh,
+                  v.shape,
                   v.master_dtype,
                   v.slice_dtype,
                   v.activation_dtype,
-                  v.trainable])
-    key_to_vars = collections.defaultdict(list)
+                  v.trainable,
+                  var_to_assign_ops[v]])
+    key_to_vars = collections.defaultdict(collections.deque)
     for v in all_variables:
       key_to_vars[var_key(v)].append(v)
-    deleted_vars = set()
     # We need to point the inputs of other operations at the outputs of unstack
     # instead of the outputs of the deleted Variables. We construct this
     # mapping from old input tensors to new input tensors.
     tensor_mapping = {}
+    # maps from old variable name to (stacked_variable, position)
+    individual_to_stacked = {}
     for op in operations:
-      if isinstance(op, Assign):
-        raise ValueError("stack_variables() should be called before any "
-                         "variable assignment.")
       if isinstance(op, StackedVariable):
         raise ValueError("stack_variables() should not be called twice.")
-      if isinstance(op, Variable):
-        if op in deleted_vars:
+      elif isinstance(op, Variable):
+        if op.name in individual_to_stacked:
           continue
         similar_vars = key_to_vars[var_key(op)]
-        num_to_stack = max(1, min(
-            len(similar_vars),
-            max_combined_size // op.shape.size))
-        to_stack = similar_vars[:num_to_stack]
-        key_to_vars[var_key(op)] = similar_vars[num_to_stack:]
+        num_to_stack = len(similar_vars)
+        if max_combined_variable_size is not None:
+          num_to_stack = min(
+              num_to_stack, max_combined_variable_size // op.shape.size)
+        if mesh_to_impl is not None:
+          mesh_impl = mesh_to_impl[op.mesh]
+          if mesh_impl.size == 1:
+            num_to_stack = 1  # no point in stacking for single processors.
+          slice_size = mesh_impl.slice_size(op.shape)
+          num_to_stack = min(
+              num_to_stack, max_combined_slice_size // slice_size)
+        num_to_stack = max(1, num_to_stack)
+        to_stack = [similar_vars.popleft() for _ in xrange(num_to_stack)]
         if num_to_stack > 1:
           stacked_var = StackedVariable(to_stack)
           stack_dim = stacked_var.shape.dims[0]
-          deleted_vars.update(to_stack)
           unstacked = unstack(stacked_var.outputs[0], stack_dim)
           for v, t in zip(to_stack, unstacked):
             tensor_mapping[v.outputs[0]] = t
+          for idx, v in enumerate(to_stack):
+            individual_to_stacked[v.name] = stacked_var, idx
         else:
+          assert op == to_stack[0]
           self._operations.append(op)
           self._all_variables.append(op)
           if op.trainable:
-            self.trainable_variables.append(op)
+            self._trainable_variables.append(op)
       else:
-        self._operations.append(op)
         # Point inputs of other operations to the outputs of unstack.
-        # pylint: disable=protected-access
         for i in xrange(len(op._inputs)):
           if op._inputs[i] in tensor_mapping:
             op._inputs[i] = tensor_mapping[op._inputs[i]]
-        # pylint: enable=protected-access
+        if isinstance(op, Assign):
+          # Rewrite the grouped assignment to stack up the values and then
+          # assign to the stacked variables.
+          new_variables = []
+          new_values = []
+          var_to_val = dict(zip([v.name for v in op._variables], op._inputs))
+          for var, val in zip(op._variables, op._inputs):
+            if var.name in individual_to_stacked:
+              stacked_var, pos = individual_to_stacked[var.name]
+              if pos == 0:
+                vals = [var_to_val[n] for n in stacked_var.original_names]
+                new_variables.append(stacked_var)
+                new_values.append(
+                    stack(vals, stacked_var.shape.dims[0].name, 0))
+            else:
+              new_variables.append(var)
+              new_values.append(val)
+          op._variables = new_variables
+          op._inputs = new_values
+        self._operations.append(op)
+    # pylint: enable=protected-access
+
+  def combine_assignments(self, assignments):
+    """Rewrite the current graph to combine "Assign" operations.
+
+    Combine similar Assign operations into grouped Assign operations.
+    This is useful when using the rewrite_stack_variables() optimization,
+    since variables can only be stacked if they are present in the same set
+    of Assign operations.
+
+    This function takes a list of Assign operations and returns a possibly
+    shorter list of Assign operations. The input Assignment operations
+    are removed from the graph and become invalid.
+
+    Args:
+      assignments: a list of Assign objects
+    Returns:
+      a list of Assign objects
+    """
+    group_by_fn = collections.defaultdict(list)
+    for a in assignments:
+      if not isinstance(a, Assign):
+        raise ValueError("ops should be instances of mtf.Assign")
+      group_by_fn[a.assign_fn].append(a)
+    assignments_set = set(assignments)
+    self._operations = [
+        op for op in self._operations if op not in assignments_set]
+    ret = []
+    for fn, ops in six.iteritems(group_by_fn):
+      variables = []
+      values = []
+      for a in ops:
+        variables.extend(a.variables)
+        values.extend(a.inputs)
+      ret.append(Assign(variables, values, fn))
+    return ret


 class Lowering(object):
@@ -512,18 +587,23 @@ class Lowering(object):
   ```
   """

-  def __init__(self, graph, mesh_to_impl):
+  def __init__(self, graph, mesh_to_impl, autostack=True):
     """Creates a Lowering of a Graph.

     Args:
       graph: Graph.
       mesh_to_impl: {Mesh: MeshImpl}. Keys are the Mesh's in the graph and
         their values are MeshImpl's, which map Tensor Dimension names to
         Mesh Dimension names.
+      autostack: a boolean. If True, then the graph gets rewritten to
+        reduce the number of variables (see rewrite_stack_variables()).
+        This is a helpful performance optimization for large meshes.
     """
     # tf.logging.info("LOWERING GRAPH:\n%s" % graph.to_string)
     self.mesh_to_impl = mesh_to_impl  # {Mesh: MeshImpl}
     self.graph = graph
+    if autostack:
+      self.autostack()
     self._counters = []
     self.tensors = {}  # {Tensor: Mesh.LaidOutTensor}
     self.operations = {}  # {Operation: tf.Operation}
@@ -537,7 +617,11 @@ def __init__(self, graph, mesh_to_impl):
             "output/%s" % type(op).__name__, self.laid_out_size(out))
         self.add_counter("output_unique/%s" % type(op).__name__, out.size)
     log_variable_sizes(
-        graph.trainable_variables, "Trainable Variables", verbose=True)
+        graph.trainable_variables, "Trainable Variables", verbose=True,
+        mesh_to_impl=self.mesh_to_impl)
+    log_variable_sizes(
+        graph.all_variables, "All Variables", verbose=False,
+        mesh_to_impl=self.mesh_to_impl)
     tf.logging.info("Counters:\n" + pretty_print_counters(self._counters))

   def mesh_impl(self, m):
@@ -601,6 +685,10 @@ def verify_slice_shapes(self, tensor, laid_out_tensor):
           "Wrong slice shape: correct_shape = %s actual shape = %s"
           % (correct_shape, actual_shape))

+  def autostack(self):
+    """Rewrite graph to stack similar variables (performance optimization)."""
+    self.graph.rewrite_stack_variables(mesh_to_impl=self.mesh_to_impl)
+

 class Mesh(object):
   """A placeholder with no functionality.
@@ -766,6 +854,9 @@ def slice_begin(self, tensor_shape, pnum):
           dim_size // self.shape[mesh_axis].size * coordinates[mesh_axis])
     return ret

+  def slice_size(self, tensor_shape):
+    return list_product(self.slice_shape(tensor_shape))
+
   def laid_out_size(self, tensor_shape):
     """Total size of all slices.

@@ -2850,18 +2941,30 @@ def assign_sub_slice(variable, slice_var, val):


 class Assign(Operation):
-  """Assign to a variable."""
+  """Assign to one or more variables."""

-  def __init__(self, var, new_val, assign_fn=assign_slice, name=None):
-    super(Assign, self).__init__([new_val], var.mesh, name=name or "assign")
-    self._var = var
+  def __init__(self, variables, new_values, assign_fn=assign_slice, name=None):
+    super(Assign, self).__init__(
+        new_values, variables[0].mesh, name=name or "assign")
+    self._variables = variables
     self._assign_fn = assign_fn
     self._outputs = []

   def lower(self, lowering):
-    lowering.operations[self] = lowering.variables[self._var].assign_to_slices(
-        self._assign_fn,
-        lowering.tensors[self.inputs[0]].to_laid_out_tensor().all_slices)
+    ops = []
+    for var, val in zip(self._variables, self.inputs):
+      ops.append(lowering.variables[var].assign_to_slices(
+          self._assign_fn,
+          lowering.tensors[val].to_laid_out_tensor().all_slices))
+    lowering.operations[self] = tf.group(ops)
+
+  @property
+  def assign_fn(self):
+    return self._assign_fn
+
+  @property
+  def variables(self):
+    return self._variables


 def assign(var, new_val, assign_fn=assign_slice):
@@ -2881,7 +2984,7 @@ def assign(var, new_val, assign_fn=assign_slice):
     var = var.operation
   if not isinstance(var, Variable):
     raise ValueError("var must be a mtf.Variable or its output Tensor.")
-  return Assign(var, new_val, assign_fn=assign_fn)
+  return Assign([var], [new_val], assign_fn=assign_fn)


 def assign_add(var, new_val):
@@ -4129,31 +4232,44 @@ def _cumprod(l):
   return ret


-def log_variable_sizes(var_list, tag, verbose=True):
+def log_variable_sizes(var_list, tag, verbose=True, mesh_to_impl=None):
   """Log the sizes and shapes of variables, and the total size.

   Args:
     var_list: a list of variables; defaults to trainable_variables
     tag: a string; defaults to "Trainable Variables"
     verbose: bool, if True, log every weight; otherwise, log total size only.
+    mesh_to_impl: an optional map from Mesh to MeshImpl
   """
   if not var_list:
     return

   name_to_var = {v.name: v for v in var_list}
   total_size = 0
+  total_slice_size = 0
   for v_name in sorted(list(name_to_var)):
     v = name_to_var[v_name]
     v_size = v.shape.size
+    if mesh_to_impl is not None:
+      slice_size = mesh_to_impl[v.mesh].slice_size(v.shape)
+    else:
+      slice_size = 0
+    total_slice_size += slice_size
     if verbose:
-      tf.logging.info("Weight %s\tshape %s\tsize %d",
-                      v.name.ljust(80),
-                      str(v.shape).ljust(30), v_size)
+      tf.logging.info(
+          "Variable %s size %s slice_size %s %s",
+          v.name.ljust(60),
+          str(v_size).ljust(12),
+          str(slice_size).ljust(12),
+          str(v.shape).ljust(60))
     if isinstance(v, StackedVariable):
       for n in v.original_names:
         tf.logging.info("  " + n)
     total_size += v_size
-  tf.logging.info("%s Total size: %d", tag, total_size)
+  tf.logging.info("%s count: %s Total size: %s Total slice_size: %s",
+                  tag.ljust(30), str(len(var_list)).ljust(6),
+                  str(total_size).ljust(15),
+                  str(total_slice_size).ljust(15))


 class WhileLoopOperation(Operation):

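Taken together, the ops.py changes mean a typical training graph picks up variable stacking automatically at lowering time. A rough usage sketch under the new defaults (model-building details elided; graph, loss, mesh, and mesh_impl are assumed to be set up as in the example scripts above):

    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])
    optimizer = mtf.optimize.AdafactorOptimizer()
    # apply_grads() emits grouped assignments, which rewrite_stack_variables()
    # can retarget to the StackedVariables (see its docstring above).
    update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)

    # Lowering now calls graph.rewrite_stack_variables() by default;
    # pass autostack=False to keep the original per-variable graph.
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)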