This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit 07417c9

Youlong Cheng authored and Copybara-Service committed
reduce TPU variable creation time.
PiperOrigin-RevId: 222707060
1 parent e45bbd5 · commit 07417c9

File tree

1 file changed: +28 −6 lines

1 file changed

+28
-6
lines changed

mesh_tensorflow/simd_mesh_impl.py

Lines changed: 28 additions & 6 deletions
@@ -39,6 +39,7 @@ def __init__(self, shape, layout, devices, device_assignment):
     self._device_assignment = device_assignment
     tf.logging.info("SimdMeshImpl init: {0} {1}".format(shape, layout))
     self._pnum_tensor = None
+    self.graph_device_function_stacks = []
 
   @property
   def pnum_tensor(self):
@@ -103,16 +104,37 @@ def __init__(self, variable, mesh_impl):
       with tf.device(variable.master.device), utils.outside_all_rewrites():
         zero_tensor = tf.zeros(slice_shape)
 
+      # pylint: disable=protected-access
+      init_device_stack = tf.get_default_graph()._device_function_stack
+
+      if not mesh_impl.graph_device_function_stacks:
+        for pnum in xrange(mesh_impl.size):
+          tpu_device = mesh_impl.device_assignment.tpu_device(replica=pnum)
+          with ops.device(tpu_device):
+            mesh_impl.graph_device_function_stacks.append(
+                tf.get_default_graph()._device_function_stack.copy())
+
       for pnum in xrange(mesh_impl.size):
         slice_var_name = base_name + "_slice_%d" % pnum
-        tpu_device = mesh_impl.device_assignment.tpu_device(replica=pnum)
         # Use tf.Variable instead of tf.get_variable since latter adds lots of
         # useless operations to the TF graph.
-        with ops.device(tpu_device):
-          slices.append(tf.Variable(
-              initial_value=zero_tensor,
-              trainable=True, collections=[], dtype=variable.slice_dtype,
-              name=slice_var_name, expected_shape=slice_shape))
+        # Note: Repeatedly 'with tf.device():' slows down the graph
+        # construction. Therefore we directly use the cached device_stack here.
+        tf.get_default_graph(
+        )._device_function_stack = mesh_impl.graph_device_function_stacks[pnum]
+
+        slices.append(
+            tf.Variable(
+                initial_value=zero_tensor,
+                trainable=True,
+                collections=[],
+                dtype=variable.slice_dtype,
+                name=slice_var_name,
+                expected_shape=slice_shape))
+
+      # Restore the initial stack
+      tf.get_default_graph()._device_function_stack = init_device_stack
+      # pylint: enable=protected-access
 
       self._laid_out_tensor = mesh_impl.LaidOutTensor(
           [tpu_variables.ReplicatedVariable(base_name, slices)])
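
For context, here is a minimal standalone sketch of the caching pattern this commit introduces, assuming TensorFlow 1.x graph mode, where tf.Graph keeps a private _device_function_stack attribute that supports .copy() (the same private API the diff relies on). The device names, slice count, and slice shape below are illustrative, not taken from the commit:

import tensorflow as tf

graph = tf.get_default_graph()

# pylint: disable=protected-access
# Remember the device function stack in effect before we start.
init_device_stack = graph._device_function_stack

# Enter each device scope exactly once and cache the resulting stack.
num_slices = 4  # illustrative; the commit uses mesh_impl.size
cached_stacks = []
for pnum in range(num_slices):
  with tf.device("/device:TPU:%d" % pnum):  # illustrative device names
    cached_stacks.append(graph._device_function_stack.copy())

# Reuse the cached stacks instead of re-entering 'with tf.device():'
# for every variable; the repeated context-manager entry is what
# dominates graph-construction time for large meshes.
slices = []
for pnum in range(num_slices):
  graph._device_function_stack = cached_stacks[pnum]
  slices.append(tf.Variable(tf.zeros([8]), trainable=True,
                            collections=[], name="slice_%d" % pnum))

# Restore the original stack so subsequent ops are placed as before.
graph._device_function_stack = init_device_stack
# pylint: enable=protected-access

The saving comes from replacing one context-manager entry and exit per slice variable with a single attribute assignment; the trade-off is that _device_function_stack is a private attribute, so the trick is tied to this TF 1.x implementation detail.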
