@@ -9,18 +9,21 @@
 from __future__ import print_function
 
 import logging
+import multiprocessing
 import os
-import sys
 import platform
 import socket
 import subprocess
-import multiprocessing
+import sys
 import uuid
+import time
+import traceback
+from threading import Thread
 
 from . import TFManager
 from . import TFNode
-from . import reservation
 from . import marker
+from . import reservation
 from . import util
 
 class TFNodeContext:
@@ -97,6 +100,14 @@ def _get_manager(cluster_info, host, ppid):
       authkey = node['authkey']
       TFSparkNode.mgr = TFManager.connect(addr, authkey)
       break
+
+  if TFSparkNode.mgr is None:
+    msg = "No TFManager found on this node, please ensure that:\n" + \
+          "1. Spark num_executors matches TensorFlow cluster_size\n" + \
+          "2. Spark cores/tasks per executor is 1.\n" + \
+          "3. Spark dynamic allocation is disabled."
+    raise Exception(msg)
+
   logging.info("Connected to TFSparkNode.mgr on {0}, ppid={1}, state={2}".format(host, ppid, str(TFSparkNode.mgr.get('state'))))
   return TFSparkNode.mgr
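
Note: the new guard fires when a Spark python worker lands on an executor that never started a TFManager, which usually means the executor layout does not match the TensorFlow cluster. A rough driver-side pre-flight check for the three conditions in the message could look like the following sketch (hypothetical helper, not part of this commit; assumes a YARN-style submission where spark.executor.instances is set):

    def check_executor_layout(sc, cluster_size):
      """Hypothetical pre-flight check mirroring the error message above."""
      conf = sc.getConf()
      if conf.get("spark.dynamicAllocation.enabled", "false") != "false":
        raise Exception("Spark dynamic allocation must be disabled")
      if int(conf.get("spark.executor.cores", "1")) != 1:
        raise Exception("Expected 1 core (task) per executor")
      if int(conf.get("spark.executor.instances", "-1")) != cluster_size:
        raise Exception("Spark num_executors must match TensorFlow cluster_size")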
@@ -152,7 +163,7 @@ def _mapfn(iter):
     addr = None
     if job_name == 'ps':
       # PS nodes must be remotely accessible in order to shutdown from Spark driver.
-      TFSparkNode.mgr = TFManager.start(authkey, ['control'], 'remote')
+      TFSparkNode.mgr = TFManager.start(authkey, ['control', 'error'], 'remote')
       addr = (host, TFSparkNode.mgr.address[1])
     else:
       # worker nodes only need to be locally accessible within the executor for data feeding
@@ -238,7 +249,11 @@ def _mapfn(iter):
     # construct a TensorFlow clusterspec from cluster_info
     sorted_cluster_info = sorted(cluster_info, key=lambda k: k['worker_num'])
     spec = {}
+    last_worker_num = -1
     for node in sorted_cluster_info:
+      if (node['worker_num'] == last_worker_num):
+        raise Exception("Duplicate worker/task in cluster_info")
+      last_worker_num = node['worker_num']
       logging.info("node: {0}".format(node))
       (njob, nhost, nport) = (node['job_name'], node['host'], node['port'])
       hosts = [] if njob not in spec else spec[njob]
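
The duplicate worker_num guard catches the case where two Spark tasks report the same TensorFlow task slot, typically a symptom of multiple tasks landing on one executor. For reference, the loop builds the standard {job_name: [host:port, ...]} mapping; with one PS and two workers it looks roughly like this (hosts and ports are made up):

    import tensorflow as tf

    spec = {
      'ps':     ['10.0.0.1:2222'],
      'worker': ['10.0.0.2:2223', '10.0.0.3:2224'],
    }
    cluster = tf.train.ClusterSpec(spec)  # how user code typically consumes the resulting dict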
@@ -268,20 +283,37 @@ def wrapper_fn(args, context):
       sys.argv = args
       fn(args, context)
 
+    def wrapper_fn_background(args, context):
+      """Wrapper function that signals exceptions to foreground process."""
+      errq = TFSparkNode.mgr.get_queue('error')
+      try:
+        wrapper_fn(args, context)
+      except Exception:
+        errq.put(traceback.format_exc())
+        errq.join()
+
     if job_name == 'ps' or background:
       # invoke the TensorFlow main function in a background thread
       logging.info("Starting TensorFlow {0}:{1} as {2} on cluster node {3} on background process".format(
           job_name, task_index, job_name, worker_num))
-      p = multiprocessing.Process(target=wrapper_fn, args=(tf_args, ctx))
+
+      p = multiprocessing.Process(target=wrapper_fn_background, args=(tf_args, ctx))
       if job_name == 'ps':
         p.daemon = True
       p.start()
 
       # for ps nodes only, wait indefinitely in foreground thread for a "control" event (None == "stop")
       if job_name == 'ps':
         queue = TFSparkNode.mgr.get_queue('control')
+        equeue = TFSparkNode.mgr.get_queue('error')
         done = False
         while not done:
+          while (queue.empty() and equeue.empty()):
+            time.sleep(1)
+          if (not equeue.empty()):
+            e_str = equeue.get()
+            equeue.task_done()
+            raise Exception("exception in ps:\n" + e_str)
           msg = queue.get(block=True)
           logging.info("Got msg: {0}".format(msg))
           if msg is None:
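
Taken together, wrapper_fn_background and the foreground loop form a simple parent/child error channel: the background process catches any exception, puts the formatted traceback on the shared 'error' queue, then blocks on errq.join() until the foreground acknowledges with task_done() and re-raises, which fails the Spark task instead of hanging it. A self-contained sketch of the same pattern using only the standard library (names are illustrative, not TFSparkNode's):

    import time
    import traceback
    from multiprocessing import Process, JoinableQueue

    def _child(fn, errq):
      """Run fn; on failure, ship the traceback to the parent and wait for an ack."""
      try:
        fn()
      except Exception:
        errq.put(traceback.format_exc())
        errq.join()  # stay alive until the parent calls task_done()

    def run_with_error_queue(fn):
      errq = JoinableQueue()
      p = Process(target=_child, args=(fn, errq))
      p.start()
      while p.is_alive():  # poll instead of join() so child errors surface promptly
        if not errq.empty():
          tb = errq.get()
          errq.task_done()
          raise Exception("exception in child:\n" + tb)
        time.sleep(1)
      if not errq.empty():  # defensive re-check after the child exits
        tb = errq.get()
        errq.task_done()
        raise Exception("exception in child:\n" + tb)

    def _boom():
      raise ValueError("boom")

    if __name__ == '__main__':
      run_with_error_queue(_boom)  # raises in the parent with the child's traceback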
@@ -311,7 +343,13 @@ def train(cluster_info, cluster_meta, qname='input'):
   def _train(iter):
     # get shared queue, reconnecting if necessary
     mgr = _get_manager(cluster_info, util.get_ip_address(), os.getppid())
-    queue = mgr.get_queue(qname)
+    try:
+      queue = mgr.get_queue(qname)
+      equeue = mgr.get_queue('error')
+    except (AttributeError, KeyError):
+      msg = "Queue '{}' not found on this node, check for exceptions on other nodes.".format(qname)
+      raise Exception(msg)
+
     state = str(mgr.get('state'))
     logging.info("mgr.state={0}".format(state))
     terminating = state == "'terminating'"
@@ -321,15 +359,23 @@ def _train(iter):
       for item in iter:
         count += 1
       logging.info("Skipped {0} items from partition".format(count))
-
     else:
       logging.info("Feeding partition {0} into {1} queue {2}".format(iter, qname, queue))
       count = 0
       for item in iter:
         count += 1
         queue.put(item, block=True)
+
       # wait for consumers to finish processing all items in queue before "finishing" this iterator
-      queue.join()
+      joinThr = Thread(target=queue.join)
+      joinThr.start()
+      while (joinThr.isAlive()):
+        if (not equeue.empty()):
+          e_str = equeue.get()
+          equeue.task_done()
+          raise Exception("exception in worker:\n" + e_str)
+        time.sleep(1)
+      # queue.join()
       logging.info("Processed {0} items in partition".format(count))
 
     # check if TF is terminating feed after this partition
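
On the feeding side, the blocking queue.join() is replaced by a join running in a helper thread so the Spark task can keep polling the 'error' queue; if a TensorFlow worker dies, the feeder raises instead of waiting forever for task_done() calls that will never come. A minimal standalone sketch of that idiom (illustrative names, stdlib only):

    import time
    from threading import Thread

    def join_or_raise(data_queue, error_queue, poll_secs=1):
      """Join a JoinableQueue while surfacing errors reported by consumers."""
      joiner = Thread(target=data_queue.join)
      joiner.daemon = True  # in this standalone sketch, don't let a pending join block interpreter exit
      joiner.start()
      while joiner.is_alive():
        if not error_queue.empty():
          tb = error_queue.get()
          error_queue.task_done()
          raise Exception("exception in worker:\n" + tb)
        time.sleep(poll_secs)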
@@ -361,7 +407,12 @@ def inference(cluster_info, qname='input'):
   def _inference(iter):
     # get shared queue, reconnecting if necessary
     mgr = _get_manager(cluster_info, util.get_ip_address(), os.getppid())
-    queue_in = mgr.get_queue(qname)
+    try:
+      queue_in = mgr.get_queue(qname)
+      equeue = mgr.get_queue('error')
+    except (AttributeError, KeyError):
+      msg = "Queue '{}' not found on this node, check for exceptions on other nodes.".format(qname)
+      raise Exception(msg)
 
     logging.info("Feeding partition {0} into {1} queue {2}".format(iter, qname, queue_in))
     count = 0
@@ -377,7 +428,15 @@ def _inference(iter):
       return []
 
     # wait for consumers to finish processing all items in queue before "finishing" this iterator
-    queue_in.join()
+    joinThr = Thread(target=queue_in.join)
+    joinThr.start()
+    while (joinThr.isAlive()):
+      if (not equeue.empty()):
+        e_str = equeue.get()
+        equeue.task_done()
+        raise Exception("exception in worker:\n" + e_str)
+      time.sleep(1)
+
     logging.info("Processed {0} items in partition".format(count))
 
     # read result queue
@@ -422,9 +481,13 @@ def _shutdown(iter):
     # terminate any listening queues
     logging.info("Stopping all queues")
     for q in queues:
-      queue = mgr.get_queue(q)
-      logging.info("Feeding None into {0} queue".format(q))
-      queue.put(None, block=True)
+      try:
+        queue = mgr.get_queue(q)
+        logging.info("Feeding None into {0} queue".format(q))
+        queue.put(None, block=True)
+      except (AttributeError, KeyError):
+        msg = "Queue '{}' not found on this node, check for exceptions on other nodes.".format(q)
+        raise Exception(msg)
 
     logging.info("Setting mgr.state to 'stopped'")
     mgr.set('state', 'stopped')