@@ -39,6 +39,7 @@ def __init__(self, shape, layout, devices, device_assignment):
3939 self ._device_assignment = device_assignment
4040 tf .logging .info ("SimdMeshImpl init: {0} {1}" .format (shape , layout ))
4141 self ._pnum_tensor = None
42+ self .graph_device_function_stacks = []
4243
4344 @property
4445 def pnum_tensor (self ):
@@ -103,16 +104,37 @@ def __init__(self, variable, mesh_impl):
103104 with tf .device (variable .master .device ), utils .outside_all_rewrites ():
104105 zero_tensor = tf .zeros (slice_shape )
105106
107+ # pylint: disable=protected-access
108+ init_device_stack = tf .get_default_graph ()._device_function_stack
109+
110+ if not mesh_impl .graph_device_function_stacks :
111+ for pnum in xrange (mesh_impl .size ):
112+ tpu_device = mesh_impl .device_assignment .tpu_device (replica = pnum )
113+ with ops .device (tpu_device ):
114+ mesh_impl .graph_device_function_stacks .append (
115+ tf .get_default_graph ()._device_function_stack .copy ())
116+
106117 for pnum in xrange (mesh_impl .size ):
107118 slice_var_name = base_name + "_slice_%d" % pnum
108- tpu_device = mesh_impl .device_assignment .tpu_device (replica = pnum )
109119 # Use tf.Variable instead of tf.get_variable since latter adds lots of
110120 # useless operations to the TF graph.
111- with ops .device (tpu_device ):
112- slices .append (tf .Variable (
113- initial_value = zero_tensor ,
114- trainable = True , collections = [], dtype = variable .slice_dtype ,
115- name = slice_var_name , expected_shape = slice_shape ))
121+ # Note: Repeatedly entering 'with tf.device():' slows down graph
122+ # construction. Therefore we directly use the cached device_stack here.
123+ tf .get_default_graph (
124+ )._device_function_stack = mesh_impl .graph_device_function_stacks [pnum ]
125+
126+ slices .append (
127+ tf .Variable (
128+ initial_value = zero_tensor ,
129+ trainable = True ,
130+ collections = [],
131+ dtype = variable .slice_dtype ,
132+ name = slice_var_name ,
133+ expected_shape = slice_shape ))
134+
135+ # Restore the initial stack
136+ tf .get_default_graph ()._device_function_stack = init_device_stack
137+ # pylint: enable=protected-access
116138
117139 self ._laid_out_tensor = mesh_impl .LaidOutTensor (
118140 [tpu_variables .ReplicatedVariable (base_name , slices )])
0 commit comments