
Commit 3d8ce2b

nshazeer authored and Copybara-Service committed
Option to combine variables in mesh-tensorflow to improve graph construction time for models with many variables on many cores. A more transparent solution would still be preferable. To use this feature, call Graph.rewrite_stack_variables() after the forward pass; the graph is rewritten to contain fewer variables.
PiperOrigin-RevId: 223029564
1 parent 07417c9 commit 3d8ce2b
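A minimal usage sketch of the new call (an assumed call pattern only, not code from this commit: build_model is a hypothetical placeholder for user model code, and optimizer wiring is elided):

    import mesh_tensorflow as mtf
    import tensorflow as tf

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")

    # Forward pass only: build_model stands in for model code that creates
    # many mtf.get_variable() calls on this mesh.
    loss = build_model(mesh)

    # Rewrite the graph so that similar variables are stacked into larger
    # ones. This must run before any Assign operations exist (e.g. before
    # optimizer update ops are built), otherwise it raises ValueError.
    graph.rewrite_stack_variables(max_combined_size=2 ** 30)

    # Gradients and optimizer update ops are created afterwards as usual.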

4 files changed: +168 -33 lines changed

mesh_tensorflow/ops.py

Lines changed: 150 additions & 12 deletions
@@ -347,7 +347,6 @@ class Graph(object):
 
   def __init__(self):
     self._operations = []
-    self._tensors = []
     self._trainable_variables = []
     self._all_variables = []
     # Maps a name used in the graph to the next id to use for that name.
@@ -360,10 +359,6 @@ def __repr__(self):
   def operations(self):
     return self._operations
 
-  @property
-  def tensors(self):
-    return self._tensors
-
   @property
   def trainable_variables(self):
     return self._trainable_variables
@@ -407,6 +402,91 @@ def unique_name(self, name, mark_as_used=True):
 
     return name
 
+  def rewrite_stack_variables(self, max_combined_size=2 ** 30):
+    """Rewrite the current graph to combine variables.
+
+    This helps speed up graph construction times in the case of large meshes
+    and large numbers of variables.
+
+    This function should be called after the forward pass (before any variable
+    assignments). Some similar variables are stacked to form larger variables.
+
+    Variables created prior to this call are checkpointed as separate variables,
+    even though they are combined internally. So the checkpoints are
+    compatible for inference purposes with/without this call. However, the
+    optimizer accumulators, which are created after this call, are checkpointed
+    as combined variables.
+
+    When we find a set of variables with the same shape/dtype/etc., we replace
+    them with one StackedVariable and an "unstack" operation. The
+    StackedVariable has multiple master variables (so as to maintain
+    checkpoint compatibility), but only one slice variable per device. We
+    point the inputs of later operations to the outputs of the
+    "unstack" operations, instead of the outputs of the defunct single
+    variables.
+
+    TODO(noam, ylc): Rewrite assignments as well, so that this can be applied at
+    the end of graph construction and be fully checkpoint-compatible.
+    Alternatively, find another solution for speeding up graph construction.
+
+    Args:
+      max_combined_size: an integer - maximum size for combined variables.
+    """
+    all_variables = self._all_variables
+    operations = self._operations
+    self._operations = []
+    self._all_variables = []
+    self._trainable_variables = []
+    def var_key(v):
+      return str([v.shape,
+                  v.master_dtype,
+                  v.slice_dtype,
+                  v.activation_dtype,
+                  v.trainable])
+    key_to_vars = collections.defaultdict(list)
+    for v in all_variables:
+      key_to_vars[var_key(v)].append(v)
+    deleted_vars = set()
+    # We need to point the inputs of other operations at the outputs of unstack
+    # instead of the outputs of the deleted Variables. We construct this
+    # mapping from old input tensors to new input tensors.
+    tensor_mapping = {}
+    for op in operations:
+      if isinstance(op, Assign):
+        raise ValueError("stack_variables() should be called before any "
+                         "variable assignment.")
+      if isinstance(op, StackedVariable):
+        raise ValueError("stack_variables() should not be called twice.")
+      if isinstance(op, Variable):
+        if op in deleted_vars:
+          continue
+        similar_vars = key_to_vars[var_key(op)]
+        num_to_stack = max(1, min(
+            len(similar_vars),
+            max_combined_size // op.shape.size))
+        to_stack = similar_vars[:num_to_stack]
+        key_to_vars[var_key(op)] = similar_vars[num_to_stack:]
+        if num_to_stack > 1:
+          stacked_var = StackedVariable(to_stack)
+          stack_dim = stacked_var.shape.dims[0]
+          deleted_vars.update(to_stack)
+          unstacked = unstack(stacked_var.outputs[0], stack_dim)
+          for v, t in zip(to_stack, unstacked):
+            tensor_mapping[v.outputs[0]] = t
+        else:
+          self._operations.append(op)
+          self._all_variables.append(op)
+          if op.trainable:
+            self.trainable_variables.append(op)
+      else:
+        self._operations.append(op)
+        # Point inputs of other operations to the outputs of unstack.
+        # pylint: disable=protected-access
+        for i in xrange(len(op._inputs)):
+          if op._inputs[i] in tensor_mapping:
+            op._inputs[i] = tensor_mapping[op._inputs[i]]
+        # pylint: enable=protected-access
+
 
 class Lowering(object):
   """Lowering of a Graph from Mesh-TensorFlow to TensorFlow.
@@ -1087,7 +1167,6 @@ def __init__(self, operation, shape, dtype, name=None, index=0):
     if name is None:
       name = self.operation.name + ":" + str(index)
     self._name = name
-    self._mesh.graph.tensors.append(self)
 
   @property
   def shape(self):
@@ -2204,7 +2283,7 @@ def conv2d(conv_input, conv_filter, strides, padding, name=None):
 
 
 class Conv2dBackpropInputOperation(Operation):
-  """like tf.nn.conv2d_backprop_input"""
+  """like tf.nn.conv2d_backprop_input."""
 
   def __init__(self, input_shape, conv_filter, dy, strides, padding, name=None):
     super(Conv2dBackpropInputOperation, self).__init__(
@@ -2618,11 +2697,12 @@ def __init__(self, mesh, name, shape, master_dtype, slice_dtype,
     self._slice_dtype = slice_dtype
     self._activation_dtype = activation_dtype
     self._trainable = trainable
-    with tf.device(mesh.variable_placer_fn), utils.outside_all_rewrites():
-      self.master = tf.get_variable(
-          name, shape.to_integer_list, dtype=master_dtype,
-          initializer=initializer, **kwargs)
-    self._name = self.master.name[:self.master.name.find(":")]
+    if not isinstance(self, StackedVariable):
+      with tf.device(mesh.variable_placer_fn), utils.outside_all_rewrites():
+        self._master = tf.get_variable(
+            name, shape.to_integer_list, dtype=master_dtype,
+            initializer=initializer, **kwargs)
+      self._name = self._master.name[:self._master.name.find(":")]
     self._outputs = [Tensor(self, shape, activation_dtype)]
     self.graph.all_variables.append(self)
     if trainable:
@@ -2667,6 +2747,61 @@ def slice_dtype(self):
   def activation_dtype(self):
     return self._activation_dtype
 
+  @property
+  def trainable(self):
+    return self._trainable
+
+  @property
+  def master_device(self):
+    return self._master.device
+
+  def get_master(self):
+    return self._master
+
+  def assign_to_master(self, val):
+    return tf.assign(self._master, val)
+
+
+class StackedVariable(Variable):
+  """A Variable which combines many variables into one.
+
+  This is a performance optimization to reduce the time associated with large
+  numbers of slice variables. See Graph.rewrite_stack_variables() for usage.
+  """
+
+  def __init__(self, vs):
+    """Create a StackedVariable.
+
+    Args:
+      vs: a list of Variables
+    """
+    shape = Shape([Dimension("stacked", len(vs))] + vs[0].shape.dims)
+    name = "stacked/" + vs[0].name
+    # TODO(noam): verify that vs are the same shape, etc.
+    super(StackedVariable, self).__init__(
+        vs[0].mesh, name, shape, vs[0].master_dtype, vs[0].slice_dtype,
+        vs[0].activation_dtype, None, vs[0].trainable)
+    self._name = name
+    self._masters = [v.get_master() for v in vs]
+    self._original_names = [v.name for v in vs]
+
+  @property
+  def original_names(self):
+    return self._original_names
+
+  @property
+  def master_device(self):
+    return self._masters[0].device
+
+  def get_master(self):
+    with tf.device(self.master_device):
+      return tf.stack(self._masters)
+
+  def assign_to_master(self, val):
+    return tf.group([
+        tf.assign(var_slice, val_slice) for var_slice, val_slice
+        in zip(self._masters, tf.unstack(val))])
+
 
 class ReadVariable(Operation):
   """Read a variable."""
@@ -4014,6 +4149,9 @@ def log_variable_sizes(var_list, tag, verbose=True):
       tf.logging.info("Weight %s\tshape %s\tsize %d",
                       v.name.ljust(80),
                       str(v.shape).ljust(30), v_size)
+      if isinstance(v, StackedVariable):
+        for n in v.original_names:
+          tf.logging.info(" " + n)
     total_size += v_size
   tf.logging.info("%s Total size: %d", tag, total_size)

mesh_tensorflow/ops_test.py

Lines changed: 6 additions & 10 deletions
@@ -93,26 +93,22 @@ def testTensorLayout(self):
 
   def testGraph(self):
     graph = mtf.Graph()
-    self.assertLen(graph.operations, 0)
-    self.assertLen(graph.tensors, 0)
-    self.assertLen(graph.trainable_variables, 0)
-    self.assertLen(graph.all_variables, 0)
+    self.assertEmpty(graph.operations)
+    self.assertEmpty(graph.trainable_variables)
+    self.assertEmpty(graph.all_variables)
     mesh = mtf.Mesh(graph, "mesh_test")
     _ = mtf.import_tf_tensor(mesh,
                              tf_tensor=tf.constant(0.),
                              shape=mtf.Shape([]))
     self.assertLen(graph.operations, 1)
-    self.assertLen(graph.tensors, 1)
-    self.assertLen(graph.trainable_variables, 0)
-    self.assertLen(graph.all_variables, 0)
+    self.assertEmpty(graph.trainable_variables)
+    self.assertEmpty(graph.all_variables)
     _ = mtf.get_variable(mesh, "variable_0", mtf.Shape([]), trainable=True)
     self.assertLen(graph.operations, 2)
-    self.assertLen(graph.tensors, 2)
     self.assertLen(graph.trainable_variables, 1)
     self.assertLen(graph.all_variables, 1)
     _ = mtf.get_variable(mesh, "variable_1", mtf.Shape([]), trainable=False)
     self.assertLen(graph.operations, 3)
-    self.assertLen(graph.tensors, 3)
     self.assertLen(graph.trainable_variables, 1)
     self.assertLen(graph.all_variables, 2)

@@ -172,7 +168,7 @@ def testMeshImpl(self):
                                      ("heads", "model")])
     mesh_impl = mtf.MeshImpl(shape=shape, layout_rules=layout_rules)
     self.assertEqual(mesh_impl.shape, shape)
-    self.assertEqual(mesh_impl.ndims, len(shape))
+    self.assertLen(shape, mesh_impl.ndims)
     self.assertEqual(mesh_impl.layout_rules, layout_rules)
     self.assertEqual(mesh_impl.size, shape.size)
     self.assertTrue(mesh_impl.supports_control_dependencies)

mesh_tensorflow/placement_mesh_impl.py

Lines changed: 7 additions & 5 deletions
@@ -79,7 +79,7 @@ def __init__(self, variable, mesh_impl):
     if self.slice_is_master:
       tf.logging.info(
          "Single slice is indentical to master - avoid creating extra vars.")
-      slices = [variable.master]
+      slices = [variable.get_master()]
      self._laid_out_tensor = mesh_impl.LaidOutTensor(slices)
      self._copy_slices_to_master = tf.group([])
      self._copy_master_to_slices = tf.group([])
@@ -96,9 +96,9 @@ def __init__(self, variable, mesh_impl):
           tf.cast(slices[-1], variable.master_dtype))
       self._laid_out_tensor = mesh_impl.LaidOutTensor(slices)
       self._copy_master_to_slices = self.assign_to_slices(
-          mtf.assign_slice, mesh_impl.make_slices(variable.master, shape))
-      self._copy_slices_to_master = tf.assign(
-          variable.master,
+          mtf.assign_slice, mesh_impl.make_slices(
+              variable.get_master(), shape))
+      self._copy_slices_to_master = variable.assign_to_master(
           mesh_impl.combine_slices(slices_with_master_dtype, shape))
 
   @property
@@ -108,7 +108,9 @@ def slice_is_master(self):
       return False
     if self._variable.master_dtype != self._variable.slice_dtype:
       return False
-    master_device = self._variable.master.device
+    if isinstance(self._variable, mtf.StackedVariable):
+      return False
+    master_device = self._variable.master_device
     slice_device = self._mesh_impl.devices[0]
     return slice_device == master_device or not slice_device

mesh_tensorflow/simd_mesh_impl.py

Lines changed: 5 additions & 6 deletions
@@ -101,7 +101,7 @@ def __init__(self, variable, mesh_impl):
     base_name = variable.name
     slices = []
     slices_with_master_dtype = []
-    with tf.device(variable.master.device), utils.outside_all_rewrites():
+    with tf.device(variable.master_device), utils.outside_all_rewrites():
       zero_tensor = tf.zeros(slice_shape)
 
     # pylint: disable=protected-access
@@ -138,15 +138,14 @@ def __init__(self, variable, mesh_impl):
 
     self._laid_out_tensor = mesh_impl.LaidOutTensor(
         [tpu_variables.ReplicatedVariable(base_name, slices)])
-    with tf.device(variable.master.device), utils.outside_all_rewrites():
+    with tf.device(variable.master_device), utils.outside_all_rewrites():
       self._copy_master_to_slices = self._generate_copy_master_to_slices_op(
-          variable.master, shape, slices, slice_shape)
+          variable.get_master(), shape, slices, slice_shape)
       slices_with_master_dtype = [
           tf.cast(s, variable.master_dtype) for s in slices]
-      self._copy_slices_to_master = tf.assign(
-          variable.master,
+      self._copy_slices_to_master = variable.assign_to_master(
           mesh_impl.combine_slices(slices_with_master_dtype, shape,
-                                   device=variable.master.device))
+                                   device=variable.master_device))
 
   def _generate_copy_master_to_slices_op(self, master_variable, master_shape,
                                          slices, slice_shape):
