
Commit 6c0d143

add grace period to TFCluster.shutdown; add ExportHook as a SessionRunHook example
1 parent 3856c05 commit 6c0d143

3 files changed (+41, -14 lines)

examples/mnist/spark/mnist_dist.py

Lines changed: 29 additions & 6 deletions
@@ -9,16 +9,40 @@
 from __future__ import nested_scopes
 from __future__ import print_function

+from datetime import datetime
+import tensorflow as tf
+from tensorflowonspark import TFNode
+

 def print_log(worker_num, arg):
   print("{0}: {1}".format(worker_num, arg))


+class ExportHook(tf.train.SessionRunHook):
+  def __init__(self, export_dir, input_tensor, output_tensor):
+    self.export_dir = export_dir
+    self.input_tensor = input_tensor
+    self.output_tensor = output_tensor
+
+  def end(self, session):
+    print("{} ======= Exporting to: {}".format(datetime.now().isoformat(), self.export_dir))
+    signatures = {
+      tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: {
+        'inputs': {'image': self.input_tensor},
+        'outputs': {'prediction': self.output_tensor},
+        'method_name': tf.saved_model.signature_constants.PREDICT_METHOD_NAME
+      }
+    }
+    TFNode.export_saved_model(session,
+                              self.export_dir,
+                              tf.saved_model.tag_constants.SERVING,
+                              signatures)
+    print("{} ======= Done exporting".format(datetime.now().isoformat()))
+
+
 def map_fun(args, ctx):
-  from datetime import datetime
   import math
   import numpy
-  import tensorflow as tf
   import time

   worker_num = ctx.worker_num
@@ -105,7 +129,6 @@ def feed_dict(batch):

   logdir = ctx.absolute_path(args.model)
   print("tensorflow model path: {0}".format(logdir))
-  hooks = [tf.train.StopAtStepHook(last_step=100000)]

   if job_name == "worker" and task_index == 0:
     summary_writer = tf.summary.FileWriter(logdir, graph=tf.get_default_graph())
@@ -115,11 +138,11 @@ def feed_dict(batch):
     with tf.train.MonitoredTrainingSession(master=server.target,
                                            is_chief=(task_index == 0),
                                            checkpoint_dir=logdir,
-                                           hooks=hooks) as mon_sess:
-
+                                           hooks=[tf.train.StopAtStepHook(last_step=args.steps)],
+                                           chief_only_hooks=[ExportHook(ctx.absolute_path(args.export_dir), x, prediction)]) as mon_sess:
       step = 0
       tf_feed = ctx.get_data_feed(args.mode == "train")
-      while not mon_sess.should_stop() and not tf_feed.should_stop() and step < args.steps:
+      while not mon_sess.should_stop() and not tf_feed.should_stop():
         # Run a training step asynchronously
         # See `tf.train.SyncReplicasOptimizer` for additional details on how to
         # perform *synchronous* training.
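
Note (not part of the diff): a minimal sketch of how the SavedModel written by ExportHook could be loaded back for local inference, assuming a TF 1.x runtime, an export directory copied to the local filesystem, and the 'image'/'prediction' signature keys defined above.

# Sketch only: load the SavedModel produced by ExportHook and run one prediction.
import numpy as np
import tensorflow as tf

export_dir = "mnist_export"  # hypothetical local copy of args.export_dir

with tf.Session(graph=tf.Graph()) as sess:
  # restores the graph and variables saved by TFNode.export_saved_model
  meta_graph = tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], export_dir)
  sig = meta_graph.signature_def[tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
  input_name = sig.inputs['image'].name
  output_name = sig.outputs['prediction'].name

  # feed one dummy 28x28 MNIST image, flattened to 784 floats
  batch = np.zeros((1, 784), dtype=np.float32)
  print(sess.run(output_name, feed_dict={input_name: batch}))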

examples/mnist/spark/mnist_spark.py

Lines changed: 3 additions & 1 deletion
@@ -25,6 +25,7 @@
 parser = argparse.ArgumentParser()
 parser.add_argument("--batch_size", help="number of records per batch", type=int, default=100)
 parser.add_argument("--epochs", help="number of epochs", type=int, default=1)
+parser.add_argument("--export_dir", help="HDFS path to export saved_model", default="mnist_export")
 parser.add_argument("--format", help="example format: (csv|pickle|tfr)", choices=["csv", "pickle", "tfr"], default="csv")
 parser.add_argument("--images", help="HDFS path to MNIST images in parallelized format")
 parser.add_argument("--labels", help="HDFS path to MNIST labels in parallelized format")
@@ -71,6 +72,7 @@ def toNumpy(bytestr):
 else:
   labelRDD = cluster.inference(dataRDD)
   labelRDD.saveAsTextFile(args.output)
-cluster.shutdown()
+
+cluster.shutdown(grace_secs=30)

 print("{0} ===== Stop".format(datetime.now().isoformat()))
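
Note (assumed context, not shown in this diff): how the new --export_dir flag and grace period fit into the driver-side flow of mnist_spark.py; names like sc, args (including args.cluster_size) and dataRDD are set up earlier in the script.

# Driver-side sketch, assuming the surrounding mnist_spark.py setup:
from tensorflowonspark import TFCluster
import mnist_dist

cluster = TFCluster.run(sc, mnist_dist.map_fun, args, args.cluster_size,
                        1,                          # number of PS tasks (assumed)
                        False,                      # no TensorBoard
                        TFCluster.InputMode.SPARK)  # feed data via Spark RDDs
if args.mode == "train":
  cluster.train(dataRDD, args.epochs)
else:
  labelRDD = cluster.inference(dataRDD)
  labelRDD.saveAsTextFile(args.output)

# the grace period keeps the Spark application alive long enough for the chief
# worker's ExportHook.end() to finish writing the SavedModel to args.export_dir
cluster.shutdown(grace_secs=30)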

tensorflowonspark/TFCluster.py

Lines changed: 9 additions & 7 deletions
@@ -109,11 +109,12 @@ def inference(self, dataRDD, qname='input'):
     assert(qname in self.queues)
     return dataRDD.mapPartitions(TFSparkNode.inference(self.cluster_info, qname))

-  def shutdown(self, ssc=None):
+  def shutdown(self, ssc=None, grace_secs=0):
     """Stops the distributed TensorFlow cluster.

     Args:
       :ssc: *For Streaming applications only*. Spark StreamingContext
+      :grace_secs: Grace period to wait before terminating the Spark application, e.g. to allow the chief worker to perform any final/cleanup duties like exporting or evaluating the model.
     """
     logging.info("Stopping TensorFlow nodes")

@@ -146,12 +147,13 @@ def shutdown(self, ssc=None):
       count += 1
       time.sleep(5)

-    # shutdown queues and managers for "worker" executors.
-    # note: in SPARK mode, this job will immediately queue up behind the "data feeding" job.
-    # in TENSORFLOW mode, this will only run after all workers have finished.
-    workers = len(worker_list)
-    workerRDD = self.sc.parallelize(range(workers), workers)
-    workerRDD.foreachPartition(TFSparkNode.shutdown(self.cluster_info, self.queues))
+    # shutdown queues and managers for "worker" executors.
+    # note: in SPARK mode, this job will immediately queue up behind the "data feeding" job.
+    # in TENSORFLOW mode, this will only run after all workers have finished.
+    workers = len(worker_list)
+    workerRDD = self.sc.parallelize(range(workers), workers)
+    workerRDD.foreachPartition(TFSparkNode.shutdown(self.cluster_info, self.queues))
+    time.sleep(grace_secs)

     # exit Spark application w/ err status if TF job had any errors
     if 'error' in tf_status:
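
Note: the grace period matters because tf.train.SessionRunHook.end() only runs when the MonitoredTrainingSession closes, i.e. after the chief worker leaves its with block; without the extra wait, shutdown() could tear down executors before the export completes. A standalone TF 1.x illustration of that hook lifecycle (not from this repo):

# Sketch: end() fires once, as the monitored session is closed.
import tensorflow as tf

class LoggingHook(tf.train.SessionRunHook):
  def end(self, session):
    # called after the last training step, just before the session is closed
    print("final global_step: {}".format(session.run(tf.train.get_global_step())))

global_step = tf.train.get_or_create_global_step()
incr = tf.assign_add(global_step, 1)

with tf.train.MonitoredTrainingSession(hooks=[tf.train.StopAtStepHook(last_step=5),
                                              LoggingHook()]) as sess:
  while not sess.should_stop():
    sess.run(incr)
# by the time this block exits, LoggingHook.end() has printed "final global_step: 5"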
