Commit 313f5cf

sync w/ internal
1 parent bc8bddd commit 313f5cf

10 files changed: +152 additions, -51 deletions

examples/cifar10/README.md
Lines changed: 0 additions & 1 deletion

@@ -18,7 +18,6 @@ Also, you will need to download the CIFAR-10 dataset per the [original example](
 # set environment variables (if not already done)
 export PYTHON_ROOT=~/Python
 export PYSPARK_PYTHON=${PYTHON_ROOT}/bin/python
-export SPARK_YARN_USER_ENV="PYSPARK_PYTHON=Python/bin/python"
 export PATH=${PYTHON_ROOT}/bin/:$PATH
 export QUEUE=gpu
 export CIFAR10_DATA=<HDFS path to your downloaded files>

examples/imagenet/README.md
Lines changed: 0 additions & 1 deletion

@@ -19,7 +19,6 @@ Also, you will need to [download the Imagenet dataset per the original example](
 # set environment variables (if not already done)
 export PYTHON_ROOT=~/Python
 export PYSPARK_PYTHON=${PYTHON_ROOT}/bin/python
-export SPARK_YARN_USER_ENV="PYSPARK_PYTHON=Python/bin/python"
 export PATH=${PYTHON_ROOT}/bin/:$PATH
 export QUEUE=gpu
 export IMAGENET_DATA=<HDFS path to your downloaded files>

examples/imagenet/inception/imagenet_distributed_train_pipeline.py
Lines changed: 1 addition & 1 deletion

@@ -13,7 +13,7 @@
 from tensorflowonspark.pipeline import TFEstimator
 from datetime import datetime
 
-import inception_export
+from inception import inception_export
 
 import sys
 import tensorflow as tf

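A note on the one-line change above: the two import forms place different requirements on the executor's Python path. A minimal illustration (not part of this commit; the packaging remark is an assumption):

# old form: inception_export.py must be importable as a top-level module
import inception_export

# new form: the inception/ package itself must be importable, e.g. when the
# example code is shipped to the executors as a package rather than as loose files
from inception import inception_export
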
examples/imagenet/inception/inception_export.py
Lines changed: 3 additions & 3 deletions

@@ -27,7 +27,7 @@
 tf.app.flags.DEFINE_string('subset', 'validation',
                            """Either 'validation' or 'train'.""")
 
-def export(args):
+def export(_):
   FLAGS = tf.app.flags.FLAGS
 
   """Evaluate model on Dataset for a number of steps."""
@@ -99,7 +99,7 @@ def preprocess_image(image_buffer):
     print('Successfully loaded model from %s at step=%s.' %
           (ckpt.model_checkpoint_path, global_step))
 
-    print("Exporting saved_model to: {}".format(args.export_dir))
+    print("Exporting saved_model to: {}".format(FLAGS.export_dir))
     # exported signatures defined in code
     signatures = {
       tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: {
@@ -109,7 +109,7 @@ def preprocess_image(image_buffer):
      }
     }
     TFNode.export_saved_model(sess,
-                              args.export_dir,
+                              FLAGS.export_dir,
                               tf.saved_model.tag_constants.SERVING,
                               signatures)
     print("Exported saved_model")

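The export() changes above read the export directory from tf.app.flags instead of a function argument, matching how tf.app.run() invokes its main function with the leftover argv. A minimal sketch of that pattern, with a hypothetical flag default (not part of this commit):

import tensorflow as tf

tf.app.flags.DEFINE_string('export_dir', '/tmp/saved_model', 'Where to write the exported model.')
FLAGS = tf.app.flags.FLAGS

def export(_):
  # tf.app.run() passes the parsed argv, which is ignored here; options come from FLAGS
  print("Exporting saved_model to: {}".format(FLAGS.export_dir))

if __name__ == '__main__':
  tf.app.run(main=export)
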
examples/slim/README.md
Lines changed: 0 additions & 2 deletions

@@ -17,7 +17,6 @@ And, you will need to [download an image dataset](https://github.com/tensorflow/
 # set environment variables (if not already done)
 export PYTHON_ROOT=~/Python
 export PYSPARK_PYTHON=${PYTHON_ROOT}/bin/python
-export SPARK_YARN_USER_ENV="PYSPARK_PYTHON=Python/bin/python"
 export PATH=${PYTHON_ROOT}/bin/:$PATH
 export QUEUE=gpu
 export DATASET_DIR=<HDFS path to your downloaded files>
@@ -63,7 +62,6 @@ And, you will need to [download an image dataset](https://github.com/tensorflow/
 --conf spark.dynamicAllocation.enabled=false \
 --conf spark.yarn.maxAppAttempts=1 \
 --conf spark.ui.view.acls=* \
---conf spark.task.maxFailures=1 \
 --archives hdfs:///user/${USER}/Python.zip#Python \
 --conf spark.executorEnv.LD_LIBRARY_PATH="/usr/local/cuda-7.5/lib64:$JAVA_HOME/jre/lib/amd64/server" \
 --driver-library-path="/usr/local/cuda-7.5/lib64" \

tensorflowonspark/TFCluster.py
Lines changed: 32 additions & 12 deletions

@@ -25,13 +25,17 @@
 import logging
 import os
 import random
+import sys
 import threading
 import time
 from pyspark.streaming import DStream
 from . import reservation
 from . import TFManager
 from . import TFSparkNode
 
+# status of TF background job
+tf_status = {}
+
 class InputMode(object):
   """Enum for the input modes of data feeding."""
   TENSORFLOW = 0    #: TensorFlow application is responsible for reading any data.
@@ -158,8 +162,15 @@ def shutdown(self, ssc=None):
   workerRDD = self.sc.parallelize(range(workers), workers)
   workerRDD.foreachPartition(TFSparkNode.shutdown(self.cluster_info, self.queues))
 
+  # exit Spark application w/ err status if TF job had any errors
+  if 'error' in tf_status:
+    logging.error("Exiting Spark application with error status.")
+    self.sc.cancelAllJobs()
+    self.sc.stop()
+    sys.exit(1)
+
   logging.info("Shutting down cluster")
-  # shutdown queues and manageres for "PS" executors.
+  # shutdown queues and managers for "PS" executors.
   # note: we have to connect/shutdown from the spark driver, because these executors are "busy" and won't accept any other tasks.
   for node in ps_list:
     addr = node['addr']
@@ -187,7 +198,7 @@ def tensorboard_url(self):
   return tb_url
 
 def run(sc, map_fun, tf_args, num_executors, num_ps, tensorboard=False, input_mode=InputMode.TENSORFLOW,
-        log_dir=None, driver_ps_nodes=False, queues=['input', 'output']):
+        log_dir=None, driver_ps_nodes=False, reservation_timeout=600, queues=['input', 'output', 'error']):
   """Starts the TensorFlowOnSpark cluster and Runs the TensorFlow "main" function on the Spark executors
 
   Args:
@@ -200,6 +211,7 @@ def run(sc, map_fun, tf_args, num_executors, num_ps, tensorboard=False, input_mo
     :input_mode: TFCluster.InputMode
     :log_dir: directory to save tensorboard event logs. If None, defaults to a fixed path on local filesystem.
     :driver_ps_nodes: run the PS nodes on the driver locally instead of on the spark executors; this help maximizing computing resources (esp. GPU). You will need to set cluster_size = num_executors + num_ps
+    :reservation_timeout: number of seconds after which cluster reservation times out (600 sec default)
     :queues: *INTERNAL_USE*
 
   Returns:
@@ -261,20 +273,28 @@ def _start_ps(node_index):
     ps_thread.start()
 
   # start TF on a background thread (on Spark driver) to allow for feeding job
-  def _start():
-    nodeRDD.foreachPartition(TFSparkNode.run(map_fun,
-                                             tf_args,
-                                             cluster_meta,
-                                             tensorboard,
-                                             log_dir,
-                                             queues,
-                                             background=(input_mode == InputMode.SPARK)))
-  t = threading.Thread(target=_start)
+  def _start(status):
+    try:
+      nodeRDD.foreachPartition(TFSparkNode.run(map_fun,
+                                               tf_args,
+                                               cluster_meta,
+                                               tensorboard,
+                                               log_dir,
+                                               queues,
+                                               background=(input_mode == InputMode.SPARK)))
+    except Exception as e:
+      logging.error("Exception in TF background thread")
+      status['error'] = str(e)
+
+  t = threading.Thread(target=_start, args=(tf_status,))
+  # run as daemon thread so that in spark mode main thread can exit
+  # if feeder spark stage fails and main thread can't do explicit shutdown
+  t.daemon = True
   t.start()
 
   # wait for executors to register and start TFNodes before continuing
   logging.info("Waiting for TFSparkNodes to start")
-  cluster_info = server.await_reservations()
+  cluster_info = server.await_reservations(sc, tf_status, reservation_timeout)
   logging.info("All TFSparkNodes started")
 
   # print cluster_info and extract TensorBoard URL

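For context, a hedged sketch of how a driver program might use the new run() signature and the error handling added to shutdown(); sc, main_fun, args, dataRDD and the numeric values are placeholders, not part of this commit:

from tensorflowonspark import TFCluster

# reserve the cluster; the reservation now times out after reservation_timeout seconds
cluster = TFCluster.run(sc, main_fun, args, num_executors=4, num_ps=1,
                        tensorboard=False, input_mode=TFCluster.InputMode.SPARK,
                        log_dir=None, driver_ps_nodes=False,
                        reservation_timeout=300)

cluster.train(dataRDD, num_epochs=1)   # feeder stage; a failure here is recorded in tf_status by _start()

# shutdown() now cancels all jobs and exits with status 1 if the background TF job stored an 'error'
cluster.shutdown()
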
tensorflowonspark/TFManager.py
Lines changed: 8 additions & 1 deletion

@@ -26,6 +26,12 @@ def _get(key):
 def _set(key, value):
   kdict[key] = value
 
+def _get_queue(qname):
+  try:
+    return qdict[qname]
+  except KeyError:
+    return None
+
 def start(authkey, queues, mode='local'):
   """Create a new multiprocess.Manager (or return existing one).
 
@@ -42,7 +48,8 @@ def start(authkey, queues, mode='local'):
   kdict.clear()
   for q in queues:
     qdict[q] = JoinableQueue()
-  TFManager.register('get_queue', callable=lambda qname: qdict[qname])
+
+  TFManager.register('get_queue', callable=lambda qname: _get_queue(qname))
   TFManager.register('get', callable=lambda key: _get(key))
   TFManager.register('set', callable=lambda key, value: _set(key, value))
   if mode == 'remote':

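The registered 'get_queue' callable now goes through _get_queue(), so looking up an unknown queue name no longer raises a KeyError inside the manager; callers in TFSparkNode.py (below) instead catch the error when they touch the missing queue. A standalone sketch of the pattern (module and names assumed, not part of this commit):

from multiprocessing import JoinableQueue
from multiprocessing.managers import BaseManager

qdict = {'input': JoinableQueue(), 'output': JoinableQueue(), 'error': JoinableQueue()}

def _get_queue(qname):
  try:
    return qdict[qname]
  except KeyError:
    return None   # unknown queue: return a placeholder instead of raising in the server

class _Manager(BaseManager):
  pass

_Manager.register('get_queue', callable=_get_queue)
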
tensorflowonspark/TFSparkNode.py
Lines changed: 76 additions & 13 deletions

@@ -9,18 +9,21 @@
 from __future__ import print_function
 
 import logging
+import multiprocessing
 import os
-import sys
 import platform
 import socket
 import subprocess
-import multiprocessing
+import sys
 import uuid
+import time
+import traceback
+from threading import Thread
 
 from . import TFManager
 from . import TFNode
-from . import reservation
 from . import marker
+from . import reservation
 from . import util
 
 class TFNodeContext:
@@ -97,6 +100,14 @@ def _get_manager(cluster_info, host, ppid):
     authkey = node['authkey']
     TFSparkNode.mgr = TFManager.connect(addr,authkey)
     break
+
+  if TFSparkNode.mgr is None:
+    msg = "No TFManager found on this node, please ensure that:\n" + \
+          "1. Spark num_executors matches TensorFlow cluster_size\n" + \
+          "2. Spark cores/tasks per executor is 1.\n" + \
+          "3. Spark dynamic allocation is disabled."
+    raise Exception(msg)
+
   logging.info("Connected to TFSparkNode.mgr on {0}, ppid={1}, state={2}".format(host, ppid, str(TFSparkNode.mgr.get('state'))))
   return TFSparkNode.mgr
 
@@ -152,7 +163,7 @@ def _mapfn(iter):
   addr = None
   if job_name == 'ps':
     # PS nodes must be remotely accessible in order to shutdown from Spark driver.
-    TFSparkNode.mgr = TFManager.start(authkey, ['control'], 'remote')
+    TFSparkNode.mgr = TFManager.start(authkey, ['control', 'error'], 'remote')
     addr = (host, TFSparkNode.mgr.address[1])
   else:
     # worker nodes only need to be locally accessible within the executor for data feeding
@@ -238,7 +249,11 @@ def _mapfn(iter):
   # construct a TensorFlow clusterspec from cluster_info
   sorted_cluster_info = sorted(cluster_info, key=lambda k: k['worker_num'])
   spec = {}
+  last_worker_num = -1
   for node in sorted_cluster_info:
+    if (node['worker_num'] == last_worker_num):
+      raise Exception("Duplicate worker/task in cluster_info")
+    last_worker_num = node['worker_num']
     logging.info("node: {0}".format(node))
     (njob, nhost, nport) = (node['job_name'], node['host'], node['port'])
     hosts = [] if njob not in spec else spec[njob]
@@ -268,20 +283,37 @@ def wrapper_fn(args, context):
   sys.argv = args
   fn(args, context)
 
+def wrapper_fn_background(args, context):
+  """Wrapper function that signals exceptions to foreground process."""
+  errq = TFSparkNode.mgr.get_queue('error')
+  try:
+    wrapper_fn(args, context)
+  except Exception:
+    errq.put(traceback.format_exc())
+    errq.join()
+
 if job_name == 'ps' or background:
   # invoke the TensorFlow main function in a background thread
   logging.info("Starting TensorFlow {0}:{1} as {2} on cluster node {3} on background process".format(
     job_name, task_index, job_name, worker_num))
-  p = multiprocessing.Process(target=wrapper_fn, args=(tf_args, ctx))
+
+  p = multiprocessing.Process(target=wrapper_fn_background, args=(tf_args, ctx))
   if job_name == 'ps':
     p.daemon = True
   p.start()
 
   # for ps nodes only, wait indefinitely in foreground thread for a "control" event (None == "stop")
   if job_name == 'ps':
     queue = TFSparkNode.mgr.get_queue('control')
+    equeue = TFSparkNode.mgr.get_queue('error')
     done = False
     while not done:
+      while (queue.empty() and equeue.empty()):
+        time.sleep(1)
+      if (not equeue.empty()):
+        e_str = equeue.get()
+        equeue.task_done()
+        raise Exception("exception in ps:\n" + e_str)
       msg = queue.get(block=True)
       logging.info("Got msg: {0}".format(msg))
       if msg is None:
@@ -311,7 +343,13 @@ def train(cluster_info, cluster_meta, qname='input'):
  def _train(iter):
    # get shared queue, reconnecting if necessary
    mgr = _get_manager(cluster_info, util.get_ip_address(), os.getppid())
-    queue = mgr.get_queue(qname)
+    try:
+      queue = mgr.get_queue(qname)
+      equeue = mgr.get_queue('error')
+    except (AttributeError, KeyError):
+      msg = "Queue '{}' not found on this node, check for exceptions on other nodes.".format(qname)
+      raise Exception(msg)
+
    state = str(mgr.get('state'))
    logging.info("mgr.state={0}".format(state))
    terminating = state == "'terminating'"
@@ -321,15 +359,23 @@ def _train(iter):
     for item in iter:
       count += 1
     logging.info("Skipped {0} items from partition".format(count))
-
   else:
     logging.info("Feeding partition {0} into {1} queue {2}".format(iter, qname, queue))
     count = 0
     for item in iter:
       count += 1
       queue.put(item, block=True)
+
     # wait for consumers to finish processing all items in queue before "finishing" this iterator
-    queue.join()
+    joinThr = Thread(target=queue.join)
+    joinThr.start()
+    while (joinThr.isAlive()):
+      if (not equeue.empty()):
+        e_str = equeue.get()
+        equeue.task_done()
+        raise Exception("exception in worker:\n" + e_str)
+      time.sleep(1)
+    # queue.join()
   logging.info("Processed {0} items in partition".format(count))
 
   # check if TF is terminating feed after this partition
@@ -361,7 +407,12 @@ def inference(cluster_info, qname='input'):
  def _inference(iter):
    # get shared queue, reconnecting if necessary
    mgr = _get_manager(cluster_info, util.get_ip_address(), os.getppid())
-    queue_in = mgr.get_queue(qname)
+    try:
+      queue_in = mgr.get_queue(qname)
+      equeue = mgr.get_queue('error')
+    except (AttributeError, KeyError):
+      msg = "Queue '{}' not found on this node, check for exceptions on other nodes.".format(qname)
+      raise Exception(msg)
 
    logging.info("Feeding partition {0} into {1} queue {2}".format(iter, qname, queue_in))
    count = 0
@@ -377,7 +428,15 @@ def _inference(iter):
     return []
 
   # wait for consumers to finish processing all items in queue before "finishing" this iterator
-  queue_in.join()
+  joinThr = Thread(target=queue_in.join)
+  joinThr.start()
+  while (joinThr.isAlive()):
+    if (not equeue.empty()):
+      e_str = equeue.get()
+      equeue.task_done()
+      raise Exception("exception in worker:\n" + e_str)
+    time.sleep(1)
+
   logging.info("Processed {0} items in partition".format(count))
 
   # read result queue
@@ -422,9 +481,13 @@ def _shutdown(iter):
   # terminate any listening queues
   logging.info("Stopping all queues")
   for q in queues:
-    queue = mgr.get_queue(q)
-    logging.info("Feeding None into {0} queue".format(q))
-    queue.put(None, block=True)
+    try:
+      queue = mgr.get_queue(q)
+      logging.info("Feeding None into {0} queue".format(q))
+      queue.put(None, block=True)
+    except (AttributeError, KeyError):
+      msg = "Queue '{}' not found on this node, check for exceptions on other nodes.".format(q)
+      raise Exception(msg)
 
   logging.info("Setting mgr.state to 'stopped'")
   mgr.set('state', 'stopped')

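Taken together, the TFSparkNode.py changes add an 'error' queue so that a failure in the background TensorFlow process is surfaced by the foreground code that feeds (or joins) the data queue, instead of hanging in queue.join(). A self-contained, simplified sketch of that pattern; all names here are illustrative and not part of this commit:

import time
import traceback
from multiprocessing import JoinableQueue, Process
from threading import Thread

def worker(data_q, err_q):
  """Simulates a TF node: fails before task_done(), then reports its traceback."""
  try:
    data_q.get()
    raise ValueError("boom")            # simulated failure inside the node
  except Exception:
    err_q.put(traceback.format_exc())

if __name__ == '__main__':
  data_q, err_q = JoinableQueue(), JoinableQueue()
  Process(target=worker, args=(data_q, err_q)).start()
  data_q.put("record")

  # like _train()/_inference() above: join the data queue on a helper thread and
  # poll the error queue, rather than blocking forever on data_q.join()
  join_thr = Thread(target=data_q.join)
  join_thr.daemon = True                # lets this demo exit once the error is raised
  join_thr.start()
  while join_thr.is_alive():
    if not err_q.empty():
      raise Exception("exception in worker:\n" + err_q.get())
    time.sleep(1)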