Skip to content

Commit a86d561

Browse files
committed
support tf.estimator.train_and_evaluate; use local file instead of ppid; surface GPU allocation errors; add reservation timeout
1 parent f685adb commit a86d561

File tree

9 files changed

+285
-143
lines changed

9 files changed

+285
-143
lines changed

tensorflowonspark/TFCluster.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,13 @@
3636
# status of TF background job
3737
tf_status = {}
3838

39+
3940
class InputMode(object):
4041
"""Enum for the input modes of data feeding."""
4142
TENSORFLOW = 0 #: TensorFlow application is responsible for reading any data.
4243
SPARK = 1 #: Spark is responsible for feeding data to the TensorFlow application via an RDD.
4344

45+
4446
class TFCluster(object):
4547

4648
sc = None #: SparkContext
@@ -197,8 +199,9 @@ def tensorboard_url(self):
197199
tb_url = "http://{0}:{1}".format(node['host'], node['tb_port'])
198200
return tb_url
199201

202+
200203
def run(sc, map_fun, tf_args, num_executors, num_ps, tensorboard=False, input_mode=InputMode.TENSORFLOW,
201-
log_dir=None, driver_ps_nodes=False, reservation_timeout=600, queues=['input', 'output', 'error']):
204+
log_dir=None, driver_ps_nodes=False, master_node=None, reservation_timeout=600, queues=['input', 'output', 'error']):
202205
"""Starts the TensorFlowOnSpark cluster and Runs the TensorFlow "main" function on the Spark executors
203206
204207
Args:
@@ -211,6 +214,7 @@ def run(sc, map_fun, tf_args, num_executors, num_ps, tensorboard=False, input_mo
211214
:input_mode: TFCluster.InputMode
212215
:log_dir: directory to save tensorboard event logs. If None, defaults to a fixed path on local filesystem.
213216
:driver_ps_nodes: run the PS nodes locally on the driver instead of on the Spark executors; this helps maximize computing resources (esp. GPU). You will need to set cluster_size = num_executors + num_ps
217+
:master_node: name of the "master" or "chief" node in the cluster_template, used for `tf.estimator` applications.
214218
:reservation_timeout: number of seconds after which cluster reservation times out (600 sec default)
215219
:queues: *INTERNAL_USE*
216220
@@ -226,8 +230,13 @@ def run(sc, map_fun, tf_args, num_executors, num_ps, tensorboard=False, input_mo
226230
# build a cluster_spec template using worker_nums
227231
cluster_template = {}
228232
cluster_template['ps'] = range(num_ps)
229-
cluster_template['worker'] = range(num_ps, num_executors)
230-
logging.info("worker node range %s, ps node range %s" % (cluster_template['worker'], cluster_template['ps']))
233+
if master_node is None:
234+
cluster_template['worker'] = range(num_ps, num_executors)
235+
else:
236+
cluster_template[master_node] = range(num_ps, num_ps + 1)
237+
if num_executors > num_ps + 1:
238+
cluster_template['worker'] = range(num_ps + 1, num_executors)
239+
logging.info("cluster_template: {}".format(cluster_template))
231240

232241
# get default filesystem from spark
233242
defaultFS = sc._jsc.hadoopConfiguration().get("fs.defaultFS")
@@ -311,13 +320,17 @@ def _start(status):
311320
logging.info("")
312321
logging.info("========================================================================================")
313322

314-
# since our "primary key" for each executor's TFManager is (host, ppid), sanity check for duplicates
323+
# since our "primary key" for each executor's TFManager is (host, executor_id), sanity check for duplicates
315324
# Note: this may occur if Spark retries failed Python tasks on the same executor.
316325
tb_nodes = set()
317326
for node in cluster_info:
318-
node_id = (node['host'],node['ppid'])
327+
node_id = (node['host'], node['executor_id'])
319328
if node_id in tb_nodes:
320-
raise Exception("Duplicate cluster node id detected (host={0}, ppid={1}). Please ensure that (1) the number of executors >= number of TensorFlow nodes, (2) the number of tasks per executors == 1, and (3) TFCluster.shutdown() is successfully invoked when done.".format(node_id[0], node_id[1]))
329+
raise Exception("Duplicate cluster node id detected (host={0}, executor_id={1})".format(node_id[0], node_id[1]) +
330+
"Please ensure that:\n" +
331+
"1. Number of executors >= number of TensorFlow nodes\n" +
332+
"2. Number of tasks per executors is 1\n" +
333+
"3, TFCluster.shutdown() is successfully invoked when done.")
321334
else:
322335
tb_nodes.add(node_id)
323336

tensorflowonspark/TFNode.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from six.moves.queue import Empty
2222
from . import marker
2323

24+
2425
def hdfs_path(ctx, path):
2526
"""Convenience function to create a Tensorflow-compatible absolute HDFS path from relative paths
2627
@@ -47,6 +48,7 @@ def hdfs_path(ctx, path):
4748
logging.warn("Unknown scheme {0} with relative path: {1}".format(ctx.defaultFS, path))
4849
return "{0}/{1}".format(ctx.defaultFS, path)
4950

51+
5052
def start_cluster_server(ctx, num_gpus=1, rdma=False):
5153
"""Function that wraps the creation of TensorFlow ``tf.train.Server`` for a node in a distributed TensorFlow cluster.
5254
@@ -71,7 +73,8 @@ def start_cluster_server(ctx, num_gpus=1, rdma=False):
7173
if tf.test.is_built_with_cuda():
7274
# GPU
7375
gpu_initialized = False
74-
while not gpu_initialized:
76+
retries = 3
77+
while not gpu_initialized and retries > 0:
7578
try:
7679
# override PS jobs to only reserve one GPU
7780
if ctx.job_name == 'ps':
@@ -97,7 +100,10 @@ def start_cluster_server(ctx, num_gpus=1, rdma=False):
97100
except Exception as e:
98101
print(e)
99102
logging.error("{0}: Failed to allocate GPU, trying again...".format(ctx.worker_num))
103+
retries -= 1
100104
time.sleep(10)
105+
if not gpu_initialized:
106+
raise Exception("Failed to allocate GPU")
101107
else:
102108
# CPU
103109
os.environ['CUDA_VISIBLE_DEVICES'] = ''
@@ -111,10 +117,12 @@ def start_cluster_server(ctx, num_gpus=1, rdma=False):
111117

112118
return (cluster, server)
113119

120+
114121
def next_batch(mgr, batch_size, qname='input'):
115122
"""*DEPRECATED*. Use TFNode.DataFeed class instead."""
116123
raise Exception("DEPRECATED: Use TFNode.DataFeed class instead")
117124

125+
118126
def export_saved_model(sess, export_dir, tag_set, signatures):
119127
"""Convenience function to export a saved_model using provided arguments
120128
@@ -148,25 +156,29 @@ def export_saved_model(sess, export_dir, tag_set, signatures):
148156
signature_def_map = {}
149157
for key, sig in signatures.items():
150158
signature_def_map[key] = tf.saved_model.signature_def_utils.build_signature_def(
151-
inputs={ name:tf.saved_model.utils.build_tensor_info(tensor) for name, tensor in sig['inputs'].items() },
152-
outputs={ name:tf.saved_model.utils.build_tensor_info(tensor) for name, tensor in sig['outputs'].items() },
153-
method_name=sig['method_name'] if 'method_name' in sig else key)
159+
inputs={name: tf.saved_model.utils.build_tensor_info(tensor) for name, tensor in sig['inputs'].items()},
160+
outputs={name: tf.saved_model.utils.build_tensor_info(tensor) for name, tensor in sig['outputs'].items()},
161+
method_name=sig['method_name'] if 'method_name' in sig else key)
154162
logging.info("===== signature_def_map: {}".format(signature_def_map))
155-
builder.add_meta_graph_and_variables(sess,
156-
tag_set.split(','),
157-
signature_def_map=signature_def_map,
158-
clear_devices=True)
163+
builder.add_meta_graph_and_variables(
164+
sess,
165+
tag_set.split(','),
166+
signature_def_map=signature_def_map,
167+
clear_devices=True)
159168
g.finalize()
160169
builder.save()
161170

171+
162172
def batch_results(mgr, results, qname='output'):
163173
"""*DEPRECATED*. Use TFNode.DataFeed class instead."""
164174
raise Exception("DEPRECATED: Use TFNode.DataFeed class instead")
165175

176+
166177
def terminate(mgr, qname='input'):
167178
"""*DEPRECATED*. Use TFNode.DataFeed class instead."""
168179
raise Exception("DEPRECATED: Use TFNode.DataFeed class instead")
169180

181+
170182
class DataFeed(object):
171183
"""This class manages the *InputMode.SPARK* data feeding process from the perspective of the TensorFlow application.
172184
@@ -184,7 +196,7 @@ def __init__(self, mgr, train_mode=True, qname_in='input', qname_out='output', i
184196
self.qname_in = qname_in
185197
self.qname_out = qname_out
186198
self.done_feeding = False
187-
self.input_tensors = [ tensor for col, tensor in sorted(input_mapping.items()) ] if input_mapping is not None else None
199+
self.input_tensors = [tensor for col, tensor in sorted(input_mapping.items())] if input_mapping is not None else None
188200

189201
def next_batch(self, batch_size):
190202
"""Gets a batch of items from the input RDD.
@@ -206,7 +218,7 @@ def next_batch(self, batch_size):
206218
"""
207219
logging.debug("next_batch() invoked")
208220
queue = self.mgr.get_queue(self.qname_in)
209-
tensors = [] if self.input_tensors is None else { tensor:[] for tensor in self.input_tensors }
221+
tensors = [] if self.input_tensors is None else {tensor: [] for tensor in self.input_tensors}
210222
count = 0
211223
while count < batch_size:
212224
item = queue.get(block=True)
@@ -276,4 +288,3 @@ def terminate(self):
276288
except Empty:
277289
logging.info("dropped {0} items from queue".format(count))
278290
done = True
279-

0 commit comments

Comments
 (0)