 from . import TFManager
 from . import TFSparkNode
 
+logger = logging.getLogger(__name__)
+
 # status of TF background job
 tf_status = {}
 
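Note: the hunk above is the heart of this change. Calls on the root `logging` module are replaced by a module-level named logger, so applications can filter or format these messages per module. A minimal sketch of how a driver program might surface them, assuming only the Python standard library (the logger name `tensorflowonspark.TFCluster` reflects the package layout implied by the relative imports and is an assumption here):

```python
import logging

# Configure the root logger once in the driver; named loggers created via
# logging.getLogger(__name__) propagate their records here by default.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s (%(name)s) %(message)s")

# Optionally tune just this module's logger without affecting others.
# The exact logger name depends on the import path (assumed here).
logging.getLogger("tensorflowonspark.TFCluster").setLevel(logging.DEBUG)
```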
@@ -73,7 +75,7 @@ def train(self, dataRDD, num_epochs=0, feed_timeout=600, qname='input'):
       :feed_timeout: number of seconds after which data feeding times out (600 sec default)
       :qname: *INTERNAL USE*.
     """
-    logging.info("Feeding training data")
+    logger.info("Feeding training data")
     assert self.input_mode == InputMode.SPARK, "TFCluster.train() requires InputMode.SPARK"
     assert qname in self.queues, "Unknown queue: {}".format(qname)
     assert num_epochs >= 0, "num_epochs cannot be negative"
@@ -107,7 +109,7 @@ def inference(self, dataRDD, feed_timeout=600, qname='input'):
     Returns:
       A Spark RDD representing the output of the TensorFlow inferencing
     """
-    logging.info("Feeding inference data")
+    logger.info("Feeding inference data")
     assert self.input_mode == InputMode.SPARK, "TFCluster.inference() requires InputMode.SPARK"
     assert qname in self.queues, "Unknown queue: {}".format(qname)
     return dataRDD.mapPartitions(TFSparkNode.inference(self.cluster_info, feed_timeout=feed_timeout, qname=qname))
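The two hunks above touch the `InputMode.SPARK` feeding paths. As a rough usage sketch matching the `train()` and `inference()` signatures in the hunk headers (the `sc`, `map_fun`, `args` and `dataRDD` names are placeholders, a real job would typically do either training or inference, and the `TFCluster.run(...)` call is the one shown further down in this diff):

```python
from tensorflowonspark import TFCluster  # package path assumed

# Start the cluster in SPARK input mode so RDD partitions can be fed to it.
cluster = TFCluster.run(sc, map_fun, args, num_executors=4, num_ps=1,
                        tensorboard=False, input_mode=TFCluster.InputMode.SPARK)

# Training: feeds dataRDD partitions to the workers
# (this is where "Feeding training data" is logged).
cluster.train(dataRDD, num_epochs=1)

# Inference: feeds an RDD and returns an RDD of results
# (this is where "Feeding inference data" is logged).
predictions = cluster.inference(dataRDD)
predictions.saveAsTextFile("predictions")  # hypothetical output path

cluster.shutdown()
```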
@@ -123,7 +125,7 @@ def shutdown(self, ssc=None, grace_secs=0, timeout=259200):
       :grace_secs: Grace period to wait after all executors have completed their tasks before terminating the Spark application, e.g. to allow the chief worker to perform any final/cleanup duties like exporting or evaluating the model. Default is 0.
       :timeout: Time in seconds to wait for TF cluster to complete before terminating the Spark application. This can be useful if the TF code hangs for any reason. Default is 3 days. Use -1 to disable timeout.
     """
-    logging.info("Stopping TensorFlow nodes")
+    logger.info("Waiting for TensorFlow nodes to complete...")
 
     # identify ps/workers
     ps_list, worker_list, eval_list = [], [], []
@@ -133,7 +135,7 @@ def shutdown(self, ssc=None, grace_secs=0, timeout=259200):
     # setup execution timeout
     if timeout > 0:
       def timeout_handler(signum, frame):
-        logging.error("TensorFlow execution timed out, exiting Spark application with error status")
+        logger.error("TensorFlow execution timed out, exiting Spark application with error status")
         self.sc.cancelAllJobs()
         self.sc.stop()
         sys.exit(1)
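The `timeout_handler` above is only half of the timeout mechanism; the `signal.alarm` call that arms it sits outside this hunk. A self-contained sketch of the same standard-library pattern (Unix-only; the durations and the sleep are illustrative, not taken from this file):

```python
import signal
import sys
import time

def timeout_handler(signum, frame):
    # Fired by SIGALRM if the work does not finish in time; exit with an
    # error status, mirroring what shutdown() does for a hung TF job.
    print("execution timed out", file=sys.stderr)
    sys.exit(1)

# Register the handler and arm a one-shot alarm (in seconds).
signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(10)

time.sleep(2)    # stand-in for waiting on the TF nodes to complete
signal.alarm(0)  # disarm the alarm once the wait finishes normally
```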
@@ -146,7 +148,7 @@ def timeout_handler(signum, frame):
       # Spark Streaming
       while not ssc.awaitTerminationOrTimeout(1):
         if self.server.done:
-          logging.info("Server done, stopping StreamingContext")
+          logger.info("Server done, stopping StreamingContext")
           ssc.stop(stopSparkContext=False, stopGraceFully=True)
           break
     elif self.input_mode == InputMode.TENSORFLOW:
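For context, the streaming branch above is driven by passing a StreamingContext into `shutdown()`. A brief sketch, with `sc` and `cluster` as placeholders for the application's own SparkContext and TFCluster objects:

```python
from pyspark.streaming import StreamingContext

# Spark Streaming case (sketch): shutdown(ssc) then polls
# ssc.awaitTerminationOrTimeout(1) as in the loop above and stops the
# StreamingContext once the TF server reports done.
ssc = StreamingContext(sc, batchDuration=5)
# ... build DStreams and feed them to the cluster here ...
ssc.start()
cluster.shutdown(ssc)
```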
@@ -175,12 +177,12 @@ def timeout_handler(signum, frame):
 
     # exit Spark application w/ err status if TF job had any errors
     if 'error' in tf_status:
-      logging.error("Exiting Spark application with error status.")
+      logger.error("Exiting Spark application with error status.")
       self.sc.cancelAllJobs()
       self.sc.stop()
       sys.exit(1)
 
-    logging.info("Shutting down cluster")
+    logger.info("Shutting down cluster")
     # shutdown queues and managers for "PS" executors.
     # note: we have to connect/shutdown from the spark driver, because these executors are "busy" and won't accept any other tasks.
     for node in ps_list + eval_list:
@@ -230,7 +232,7 @@ def run(sc, map_fun, tf_args, num_executors, num_ps, tensorboard=False, input_mo
   Returns:
     A TFCluster object representing the started cluster.
   """
-  logging.info("Reserving TFSparkNodes {0}".format("w/ TensorBoard" if tensorboard else ""))
+  logger.info("Reserving TFSparkNodes {0}".format("w/ TensorBoard" if tensorboard else ""))
 
   if driver_ps_nodes and input_mode != InputMode.TENSORFLOW:
     raise Exception('running PS nodes on driver locally is only supported in InputMode.TENSORFLOW')
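The check above ties `driver_ps_nodes` to `InputMode.TENSORFLOW`. A brief sketch of that configuration (the `run()` signature is truncated in the hunk header, so treating `driver_ps_nodes` as a keyword argument of `run()` here is an assumption; `sc`, `main_fun`, `args` are placeholders):

```python
from tensorflowonspark import TFCluster  # package path assumed

# InputMode.TENSORFLOW: workers read their own data (no RDD feeding), and the
# PS tasks may run locally on the driver when driver_ps_nodes is enabled.
cluster = TFCluster.run(sc, main_fun, args,
                        num_executors=4, num_ps=1,
                        tensorboard=True,
                        input_mode=TFCluster.InputMode.TENSORFLOW,
                        driver_ps_nodes=True)
cluster.shutdown()  # waits for the TF nodes, then releases the executors
```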
@@ -263,7 +265,7 @@ def run(sc, map_fun, tf_args, num_executors, num_ps, tensorboard=False, input_mo
   if num_workers > 0:
     cluster_template['worker'] = executors[:num_workers]
 
-  logging.info("cluster_template: {}".format(cluster_template))
+  logger.info("cluster_template: {}".format(cluster_template))
 
   # get default filesystem from spark
   defaultFS = sc._jsc.hadoopConfiguration().get("fs.defaultFS")
@@ -279,7 +281,7 @@ def run(sc, map_fun, tf_args, num_executors, num_ps, tensorboard=False, input_mo
   server_addr = server.start()
 
   # start TF nodes on all executors
-  logging.info("Starting TensorFlow on executors")
+  logger.info("Starting TensorFlow on executors")
   cluster_meta = {
     'id': random.getrandbits(64),
     'cluster_template': cluster_template,
@@ -295,7 +297,7 @@ def run(sc, map_fun, tf_args, num_executors, num_ps, tensorboard=False, input_mo
 
   if driver_ps_nodes:
     def _start_ps(node_index):
-      logging.info("starting ps node locally %d" % node_index)
+      logger.info("starting ps node locally %d" % node_index)
       TFSparkNode.run(map_fun,
                       tf_args,
                       cluster_meta,
@@ -319,7 +321,7 @@ def _start(status):
                                                queues,
                                                background=(input_mode == InputMode.SPARK)))
     except Exception as e:
-      logging.error("Exception in TF background thread")
+      logger.error("Exception in TF background thread")
       status['error'] = str(e)
 
   t = threading.Thread(target=_start, args=(tf_status,))
@@ -329,23 +331,23 @@ def _start(status):
   t.start()
 
   # wait for executors to register and start TFNodes before continuing
-  logging.info("Waiting for TFSparkNodes to start")
+  logger.info("Waiting for TFSparkNodes to start")
   cluster_info = server.await_reservations(sc, tf_status, reservation_timeout)
-  logging.info("All TFSparkNodes started")
+  logger.info("All TFSparkNodes started")
 
   # print cluster_info and extract TensorBoard URL
   tb_url = None
   for node in cluster_info:
-    logging.info(node)
+    logger.info(node)
     if node['tb_port'] != 0:
       tb_url = "http://{0}:{1}".format(node['host'], node['tb_port'])
 
   if tb_url is not None:
-    logging.info("========================================================================================")
-    logging.info("")
-    logging.info("TensorBoard running at: {0}".format(tb_url))
-    logging.info("")
-    logging.info("========================================================================================")
+    logger.info("========================================================================================")
+    logger.info("")
+    logger.info("TensorBoard running at: {0}".format(tb_url))
+    logger.info("")
+    logger.info("========================================================================================")
 
   # since our "primary key" for each executor's TFManager is (host, executor_id), sanity check for duplicates
   # Note: this may occur if Spark retries failed Python tasks on the same executor.