Commit 0aac8b2

Author: pikachu
Update mnist_dist.py
tf.train.Supervisor is deprecated in TensorFlow 1.6, so I updated the example to use tf.train.MonitoredTrainingSession.
1 parent 93bec3d commit 0aac8b2
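As a quick reference for the commit message above, here is a rough sketch (not copied verbatim from the file; argument values are illustrative) of how the deprecated Supervisor arguments map onto the TF 1.x MonitoredTrainingSession call that the diff below introduces:

    # Sketch only -- values such as last_step are illustrative, not taken from mnist_dist.py.
    # Old (deprecated per the commit message): Supervisor owned init, checkpointing, and stopping.
    # sv = tf.train.Supervisor(is_chief=(task_index == 0), logdir=logdir,
    #                          init_op=init_op, saver=saver, global_step=global_step,
    #                          save_model_secs=10)
    # New: MonitoredTrainingSession covers the same responsibilities.
    hooks = [tf.train.StopAtStepHook(last_step=100000)]   # stop condition, replaces manual stop logic
    with tf.train.MonitoredTrainingSession(
            master=server.target,              # tf.train.Server target, as in the example
            is_chief=(task_index == 0),        # same chief flag Supervisor took
            checkpoint_dir=logdir,             # replaces logdir= / saver= / save_model_secs=
            hooks=hooks) as mon_sess:          # init_op is no longer needed; variables are
        ...                                    # initialized or restored automatically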

File tree: 1 file changed

examples/mnist/spark/mnist_dist.py: 19 additions, 26 deletions
@@ -82,7 +82,7 @@ def feed_dict(batch):
 
       y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
 
-      global_step = tf.Variable(0)
+      global_step = tf.train.get_or_create_global_step()
 
       loss = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
       tf.summary.scalar("loss", loss)
@@ -98,27 +98,22 @@ def feed_dict(batch):
       accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name="accuracy")
       tf.summary.scalar("acc", accuracy)
 
-      saver = tf.train.Saver()
       summary_op = tf.summary.merge_all()
-      init_op = tf.global_variables_initializer()
 
-    # Create a "supervisor", which oversees the training process and stores model state into HDFS
+    # Create a "MonitoredTrainingSession", which oversees the training process and stores model state into HDFS
     logdir = ctx.absolute_path(args.model)
     print("tensorflow model path: {0}".format(logdir))
-
+    hooks = [tf.train.StopAtStepHook(last_step=100000)]
+
     if job_name == "worker" and task_index == 0:
       summary_writer = tf.summary.FileWriter(logdir, graph=tf.get_default_graph())
 
     if args.mode == "train":
-      sv = tf.train.Supervisor(is_chief=(task_index == 0),
-                               logdir=logdir,
-                               init_op=init_op,
-                               summary_op=None,
-                               summary_writer=None,
-                               saver=saver,
-                               global_step=global_step,
-                               stop_grace_secs=300,
-                               save_model_secs=10)
+      with tf.train.MonitoredTrainingSession(master=server.target,
+                                             is_chief=(task_index == 0),
+                                             checkpoint_dir=logdir,
+                                             hooks=hooks,
+                                             ) as mon_sess:
     else:
       sv = tf.train.Supervisor(is_chief=(task_index == 0),
                                logdir=logdir,
@@ -128,15 +123,13 @@ def feed_dict(batch):
                                stop_grace_secs=300,
                                save_model_secs=0)
 
-    # The supervisor takes care of session initialization, restoring from
-    # a checkpoint, and closing when done or an error occurs.
-    with sv.managed_session(server.target) as sess:
-      print("{0} session ready".format(datetime.now().isoformat()))
+    # The MonitoredTrainingSession takes care of session initialization, restoring from
+    # a checkpoint, and closing when done or an error occurs
 
       # Loop until the supervisor shuts down or 1000000 steps have completed.
       step = 0
       tf_feed = ctx.get_data_feed(args.mode == "train")
-      while not sv.should_stop() and not tf_feed.should_stop() and step < args.steps:
+      while not mon_sess.should_stop() and not tf_feed.should_stop() and step < args.steps:
         # Run a training step asynchronously.
         # See `tf.train.SyncReplicasOptimizer` for additional details on how to
         # perform *synchronous* training.
@@ -147,24 +140,24 @@ def feed_dict(batch):
 
         if len(batch_xs) > 0:
           if args.mode == "train":
-            _, summary, step = sess.run([train_op, summary_op, global_step], feed_dict=feed)
+            _, summary, step = mon_sess.run([train_op, summary_op, global_step], feed_dict=feed)
             # print accuracy and save model checkpoint to HDFS every 100 steps
             if (step % 100 == 0):
-              print("{0} step: {1} accuracy: {2}".format(datetime.now().isoformat(), step, sess.run(accuracy,{x: batch_xs, y_: batch_ys})))
+              print("{0} step: {1} accuracy: {2}".format(datetime.now().isoformat(), step, mon_sess.run(accuracy,{x: batch_xs, y_: batch_ys})))
 
-            if sv.is_chief:
+            if task_index == 0:
               summary_writer.add_summary(summary, step)
           else: # args.mode == "inference"
-            labels, preds, acc = sess.run([label, prediction, accuracy], feed_dict=feed)
+            labels, preds, acc = mon_sess.run([label, prediction, accuracy], feed_dict=feed)
 
             results = ["{0} Label: {1}, Prediction: {2}".format(datetime.now().isoformat(), l, p) for l,p in zip(labels,preds)]
             tf_feed.batch_results(results)
             print("acc: {0}".format(acc))
 
-        if sv.should_stop() or step >= args.steps:
+        if mon_sess.should_stop() or step >= args.steps:
           tf_feed.terminate()
 
     # Ask for all the services to stop.
-    print("{0} stopping supervisor".format(datetime.now().isoformat()))
-    sv.stop()
+    print("{0} stopping MonitoredTrainingSession".format(datetime.now().isoformat()))
+
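The comment carried over in the diff ("The MonitoredTrainingSession takes care of session initialization, restoring from a checkpoint, and closing when done or an error occurs") is the heart of the change. Below is a minimal, single-process sketch of that pattern, independent of TensorFlowOnSpark and using only the TF 1.x APIs that appear in the diff; the toy graph, checkpoint path, and step counts are made up for illustration:

    import tensorflow as tf

    # Toy graph so the sketch runs without any input pipeline.
    w = tf.Variable(0.0)
    loss = tf.square(1.0 - w)
    tf.summary.scalar("loss", loss)

    # Replaces `global_step = tf.Variable(0)`; reuses an existing global step if one is defined.
    global_step = tf.train.get_or_create_global_step()
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss, global_step=global_step)

    # Ends the session once the global step reaches last_step, so no explicit stop call is needed.
    hooks = [tf.train.StopAtStepHook(last_step=100)]

    # MonitoredTrainingSession initializes (or restores) variables, writes checkpoints and
    # summaries under checkpoint_dir, and runs the hooks -- the jobs Saver, init_op, and
    # Supervisor handled before this commit. checkpoint_dir here is a made-up local path.
    with tf.train.MonitoredTrainingSession(checkpoint_dir="/tmp/mnist_mts_demo",
                                           hooks=hooks) as mon_sess:
      while not mon_sess.should_stop():
        _, step, loss_val = mon_sess.run([train_op, global_step, loss])
        if step % 20 == 0:
          print("step: {0} loss: {1}".format(step, loss_val))

In the distributed example, master=server.target and is_chief=(task_index == 0) are passed in addition, so sessions attach to the cluster and only the chief worker writes checkpoints and summaries, as shown in the diff above.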
