
Commit dca5538

Merge pull request #466 from yahoo/leewyang_loggers
use module loggers
2 parents: 286fd65 + e228ebd

8 files changed: +102 additions, -89 deletions
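
The change itself is mechanical: every module that previously logged through the root `logging` module now logs through a module-level logger created with `logging.getLogger(__name__)`. A minimal sketch of the before/after pattern (the message text is taken from the diff; the surrounding code is illustrative only):

import logging

# Before: records go through the root logger, so all modules share
# one logger name and one effective level.
logging.info("Feeding training data")

# After: each module owns a logger named after itself
# (e.g. "tensorflowonspark.TFCluster"), which applications can filter,
# format, or silence independently of other modules.
logger = logging.getLogger(__name__)
logger.info("Feeding training data")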


tensorflowonspark/TFCluster.py

Lines changed: 22 additions & 20 deletions
@@ -34,6 +34,8 @@
 from . import TFManager
 from . import TFSparkNode
 
+logger = logging.getLogger(__name__)
+
 # status of TF background job
 tf_status = {}
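
Because the new logger is named after the module (`__name__` resolves to `tensorflowonspark.TFCluster` here), log records now carry that name. A hypothetical driver-side configuration, not part of this commit, that would surface these INFO messages with their module names (whether the library configures any handlers elsewhere is outside the scope of this diff):

import logging

# Assumed application-side setup on the Spark driver.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(name)s %(levelname)s: %(message)s")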

@@ -73,7 +75,7 @@ def train(self, dataRDD, num_epochs=0, feed_timeout=600, qname='input'):
     :feed_timeout: number of seconds after which data feeding times out (600 sec default)
     :qname: *INTERNAL USE*.
     """
-    logging.info("Feeding training data")
+    logger.info("Feeding training data")
     assert self.input_mode == InputMode.SPARK, "TFCluster.train() requires InputMode.SPARK"
     assert qname in self.queues, "Unknown queue: {}".format(qname)
     assert num_epochs >= 0, "num_epochs cannot be negative"
@@ -107,7 +109,7 @@ def inference(self, dataRDD, feed_timeout=600, qname='input'):
     Returns:
       A Spark RDD representing the output of the TensorFlow inferencing
     """
-    logging.info("Feeding inference data")
+    logger.info("Feeding inference data")
     assert self.input_mode == InputMode.SPARK, "TFCluster.inference() requires InputMode.SPARK"
     assert qname in self.queues, "Unknown queue: {}".format(qname)
     return dataRDD.mapPartitions(TFSparkNode.inference(self.cluster_info, feed_timeout=feed_timeout, qname=qname))
@@ -123,7 +125,7 @@ def shutdown(self, ssc=None, grace_secs=0, timeout=259200):
     :grace_secs: Grace period to wait after all executors have completed their tasks before terminating the Spark application, e.g. to allow the chief worker to perform any final/cleanup duties like exporting or evaluating the model. Default is 0.
     :timeout: Time in seconds to wait for TF cluster to complete before terminating the Spark application. This can be useful if the TF code hangs for any reason. Default is 3 days. Use -1 to disable timeout.
     """
-    logging.info("Stopping TensorFlow nodes")
+    logger.info("Waiting for TensorFlow nodes to complete...")
 
     # identify ps/workers
     ps_list, worker_list, eval_list = [], [], []
@@ -133,7 +135,7 @@ def shutdown(self, ssc=None, grace_secs=0, timeout=259200):
     # setup execution timeout
     if timeout > 0:
       def timeout_handler(signum, frame):
-        logging.error("TensorFlow execution timed out, exiting Spark application with error status")
+        logger.error("TensorFlow execution timed out, exiting Spark application with error status")
         self.sc.cancelAllJobs()
         self.sc.stop()
         sys.exit(1)
@@ -146,7 +148,7 @@ def timeout_handler(signum, frame):
       # Spark Streaming
       while not ssc.awaitTerminationOrTimeout(1):
         if self.server.done:
-          logging.info("Server done, stopping StreamingContext")
+          logger.info("Server done, stopping StreamingContext")
           ssc.stop(stopSparkContext=False, stopGraceFully=True)
           break
     elif self.input_mode == InputMode.TENSORFLOW:
@@ -175,12 +177,12 @@ def timeout_handler(signum, frame):
 
     # exit Spark application w/ err status if TF job had any errors
     if 'error' in tf_status:
-      logging.error("Exiting Spark application with error status.")
+      logger.error("Exiting Spark application with error status.")
       self.sc.cancelAllJobs()
       self.sc.stop()
       sys.exit(1)
 
-    logging.info("Shutting down cluster")
+    logger.info("Shutting down cluster")
     # shutdown queues and managers for "PS" executors.
     # note: we have to connect/shutdown from the spark driver, because these executors are "busy" and won't accept any other tasks.
     for node in ps_list + eval_list:
@@ -230,7 +232,7 @@ def run(sc, map_fun, tf_args, num_executors, num_ps, tensorboard=False, input_mo
   Returns:
     A TFCluster object representing the started cluster.
   """
-  logging.info("Reserving TFSparkNodes {0}".format("w/ TensorBoard" if tensorboard else ""))
+  logger.info("Reserving TFSparkNodes {0}".format("w/ TensorBoard" if tensorboard else ""))
 
   if driver_ps_nodes and input_mode != InputMode.TENSORFLOW:
     raise Exception('running PS nodes on driver locally is only supported in InputMode.TENSORFLOW')
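
For context, `run()` is the entry point whose log statements change in the hunks above and below; based on the signature shown in the hunk header, a driver program might invoke it roughly like this (argument values and variable names are illustrative, not from the diff):

from tensorflowonspark import TFCluster

# Hypothetical usage: map_fun is the user's TensorFlow training function,
# args its parsed arguments, sc the active SparkContext.
cluster = TFCluster.run(sc, map_fun, args, num_executors=4, num_ps=1,
                        tensorboard=True, input_mode=TFCluster.InputMode.SPARK)
cluster.train(dataRDD, num_epochs=1)   # feed a Spark RDD (InputMode.SPARK)
cluster.shutdown()
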
@@ -263,7 +265,7 @@ def run(sc, map_fun, tf_args, num_executors, num_ps, tensorboard=False, input_mo
   if num_workers > 0:
     cluster_template['worker'] = executors[:num_workers]
 
-  logging.info("cluster_template: {}".format(cluster_template))
+  logger.info("cluster_template: {}".format(cluster_template))
 
   # get default filesystem from spark
   defaultFS = sc._jsc.hadoopConfiguration().get("fs.defaultFS")
@@ -279,7 +281,7 @@ def run(sc, map_fun, tf_args, num_executors, num_ps, tensorboard=False, input_mo
   server_addr = server.start()
 
   # start TF nodes on all executors
-  logging.info("Starting TensorFlow on executors")
+  logger.info("Starting TensorFlow on executors")
   cluster_meta = {
     'id': random.getrandbits(64),
     'cluster_template': cluster_template,
@@ -295,7 +297,7 @@ def run(sc, map_fun, tf_args, num_executors, num_ps, tensorboard=False, input_mo
 
   if driver_ps_nodes:
     def _start_ps(node_index):
-      logging.info("starting ps node locally %d" % node_index)
+      logger.info("starting ps node locally %d" % node_index)
       TFSparkNode.run(map_fun,
                       tf_args,
                       cluster_meta,
@@ -319,7 +321,7 @@ def _start(status):
                           queues,
                           background=(input_mode == InputMode.SPARK)))
     except Exception as e:
-      logging.error("Exception in TF background thread")
+      logger.error("Exception in TF background thread")
       status['error'] = str(e)
 
   t = threading.Thread(target=_start, args=(tf_status,))
@@ -329,23 +331,23 @@ def _start(status):
   t.start()
 
   # wait for executors to register and start TFNodes before continuing
-  logging.info("Waiting for TFSparkNodes to start")
+  logger.info("Waiting for TFSparkNodes to start")
   cluster_info = server.await_reservations(sc, tf_status, reservation_timeout)
-  logging.info("All TFSparkNodes started")
+  logger.info("All TFSparkNodes started")
 
   # print cluster_info and extract TensorBoard URL
   tb_url = None
   for node in cluster_info:
-    logging.info(node)
+    logger.info(node)
     if node['tb_port'] != 0:
       tb_url = "http://{0}:{1}".format(node['host'], node['tb_port'])
 
   if tb_url is not None:
-    logging.info("========================================================================================")
-    logging.info("")
-    logging.info("TensorBoard running at: {0}".format(tb_url))
-    logging.info("")
-    logging.info("========================================================================================")
+    logger.info("========================================================================================")
+    logger.info("")
+    logger.info("TensorBoard running at: {0}".format(tb_url))
+    logger.info("")
+    logger.info("========================================================================================")
 
   # since our "primary key" for each executor's TFManager is (host, executor_id), sanity check for duplicates
   # Note: this may occur if Spark retries failed Python tasks on the same executor.

tensorflowonspark/TFNode.py

Lines changed: 10 additions & 9 deletions
@@ -19,6 +19,7 @@
 from six.moves.queue import Empty
 from . import marker
 
+logger = logging.getLogger(__name__)
 
 def hdfs_path(ctx, path):
   """Convenience function to create a Tensorflow-compatible absolute HDFS path from relative paths
@@ -54,7 +55,7 @@ def hdfs_path(ctx, path):
   elif ctx.defaultFS.startswith("file://"):
     return "{0}/{1}/{2}".format(ctx.defaultFS, ctx.working_dir[1:], path)
   else:
-    logging.warn("Unknown scheme {0} with relative path: {1}".format(ctx.defaultFS, path))
+    logger.warn("Unknown scheme {0} with relative path: {1}".format(ctx.defaultFS, path))
     return "{0}/{1}".format(ctx.defaultFS, path)
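
One nit visible in this hunk: the call keeps the `.warn` spelling. In the standard library, `Logger.warn` is a deprecated alias of `Logger.warning`, so an equivalent, non-deprecated form of the new line would be:

logger.warning("Unknown scheme {0} with relative path: {1}".format(ctx.defaultFS, path))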

@@ -120,21 +121,21 @@ def next_batch(self, batch_size):
     Returns:
       A batch of items or a dictionary of tensors.
     """
-    logging.debug("next_batch() invoked")
+    logger.debug("next_batch() invoked")
     queue = self.mgr.get_queue(self.qname_in)
     tensors = [] if self.input_tensors is None else {tensor: [] for tensor in self.input_tensors}
     count = 0
     while count < batch_size:
       item = queue.get(block=True)
       if item is None:
         # End of Feed
-        logging.info("next_batch() got None")
+        logger.info("next_batch() got None")
         queue.task_done()
         self.done_feeding = True
         break
       elif type(item) is marker.EndPartition:
         # End of Partition
-        logging.info("next_batch() got EndPartition")
+        logger.info("next_batch() got EndPartition")
         queue.task_done()
         if not self.train_mode and count > 0:
           break
@@ -147,7 +148,7 @@ def next_batch(self, batch_size):
           tensors[self.input_tensors[i]].append(item[i])
         count += 1
         queue.task_done()
-    logging.debug("next_batch() returning {0} items".format(count))
+    logger.debug("next_batch() returning {0} items".format(count))
     return tensors
 
   def should_stop(self):
@@ -163,11 +164,11 @@ def batch_results(self, results):
     Args:
       :results: array of output data for the equivalent batch of input data.
     """
-    logging.debug("batch_results() invoked")
+    logger.debug("batch_results() invoked")
     queue = self.mgr.get_queue(self.qname_out)
     for item in results:
       queue.put(item, block=True)
-    logging.debug("batch_results() returning data")
+    logger.debug("batch_results() returning data")
 
   def terminate(self):
     """Terminate data feeding early.
@@ -177,7 +178,7 @@ def terminate(self):
     to terminate an RDD operation early, so the extra partitions will still be sent to the executors (but will be ignored). Because
     of this, you should size your input data accordingly to avoid excessive overhead.
     """
-    logging.info("terminate() invoked")
+    logger.info("terminate() invoked")
     self.mgr.set('state', 'terminating')
 
     # drop remaining items in the queue
@@ -190,5 +191,5 @@ def terminate(self):
         queue.task_done()
         count += 1
       except Empty:
-        logging.info("dropped {0} items from queue".format(count))
+        logger.info("dropped {0} items from queue".format(count))
         done = True
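
Since every logger here is created via `logging.getLogger(__name__)`, the TFNode and TFCluster loggers nest under the `tensorflowonspark` package logger, so applications can tune them collectively or individually. A hypothetical example (logger names follow the package/module paths shown above):

import logging

# Quiet all TensorFlowOnSpark modules at once...
logging.getLogger("tensorflowonspark").setLevel(logging.WARNING)

# ...or re-enable detailed data-feeding diagnostics for TFNode only.
logging.getLogger("tensorflowonspark.TFNode").setLevel(logging.DEBUG)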
