Skip to content

Commit 51b7912

Browse files
authored
Merge pull request #400 from yahoo/leewyang_terminate_feed
add back code to terminate feed
2 parents df79f63 + a4af831 commit 51b7912

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

examples/mnist/spark/mnist_dist.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def rdd_generator():
117117

118118
if args.mode == "train":
119119
_, summary, step = sess.run([train_op, summary_op, global_step])
120-
if (step % 100 == 0):
120+
if (step % 100 == 0) and (not sess.should_stop()):
121121
print("{} step: {} accuracy: {}".format(datetime.now().isoformat(), step, sess.run(accuracy)))
122122
if task_index == 0:
123123
summary_writer.add_summary(summary, step)
@@ -129,6 +129,9 @@ def rdd_generator():
129129

130130
print("{} stopping MonitoredTrainingSession".format(datetime.now().isoformat()))
131131

132+
if sess.should_stop() or step >= args.steps:
133+
tf_feed.terminate()
134+
132135
# WORKAROUND FOR https://github.com/tensorflow/tensorflow/issues/21745
133136
# wait for all other nodes to complete (via done files)
134137
done_dir = "{}/{}/done".format(ctx.absolute_path(args.model), args.mode)

0 commit comments

Comments
 (0)