Commit db0a726

Merge pull request #247 from wuyifan18/patch-1
Update mnist_dist.py using TensorFlow 1.6
2 parents: 93bec3d + 929adca

File tree

1 file changed: +55 additions, -74 deletions


examples/mnist/spark/mnist_dist.py

Lines changed: 55 additions & 74 deletions
@@ -1,4 +1,4 @@
-# Copyright 2017 Yahoo Inc.
+#Copyright 2018 Yahoo Inc.
 # Licensed under the terms of the Apache 2.0 license.
 # Please see LICENSE file in the project root for terms.

@@ -58,37 +58,38 @@ def feed_dict(batch):
         worker_device="/job:worker/task:%d" % task_index,
         cluster=cluster)):

-      # Variables of the hidden layer
-      hid_w = tf.Variable(tf.truncated_normal([IMAGE_PIXELS * IMAGE_PIXELS, hidden_units],
-                          stddev=1.0 / IMAGE_PIXELS), name="hid_w")
-      hid_b = tf.Variable(tf.zeros([hidden_units]), name="hid_b")
-      tf.summary.histogram("hidden_weights", hid_w)
-
-      # Variables of the softmax layer
-      sm_w = tf.Variable(tf.truncated_normal([hidden_units, 10],
-                         stddev=1.0 / math.sqrt(hidden_units)), name="sm_w")
-      sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
-      tf.summary.histogram("softmax_weights", sm_w)
-
-      # Placeholders or QueueRunner/Readers for input data
-      x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS], name="x")
-      y_ = tf.placeholder(tf.float32, [None, 10], name="y_")
-
-      x_img = tf.reshape(x, [-1, IMAGE_PIXELS, IMAGE_PIXELS, 1])
-      tf.summary.image("x_img", x_img)
-
-      hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
-      hid = tf.nn.relu(hid_lin)
-
-      y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
-
-      global_step = tf.Variable(0)
-
-      loss = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
-      tf.summary.scalar("loss", loss)
-
-      train_op = tf.train.AdagradOptimizer(0.01).minimize(
-          loss, global_step=global_step)
+      # Placeholders or QueueRunner/Readers for input data
+      with tf.name_scope('inputs'):
+        x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS], name="x")
+        y_ = tf.placeholder(tf.float32, [None, 10], name="y_")
+
+        x_img = tf.reshape(x, [-1, IMAGE_PIXELS, IMAGE_PIXELS, 1])
+        tf.summary.image("x_img", x_img)
+
+      with tf.name_scope('layer'):
+        # Variables of the hidden layer
+        with tf.name_scope('hidden_layer'):
+          hid_w = tf.Variable(tf.truncated_normal([IMAGE_PIXELS * IMAGE_PIXELS, hidden_units], stddev=1.0 / IMAGE_PIXELS), name="hid_w")
+          hid_b = tf.Variable(tf.zeros([hidden_units]), name="hid_b")
+          tf.summary.histogram("hidden_weights", hid_w)
+          hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
+          hid = tf.nn.relu(hid_lin)
+
+        # Variables of the softmax layer
+        with tf.name_scope('softmax_layer'):
+          sm_w = tf.Variable(tf.truncated_normal([hidden_units, 10], stddev=1.0 / math.sqrt(hidden_units)), name="sm_w")
+          sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
+          tf.summary.histogram("softmax_weights", sm_w)
+          y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
+
+      global_step = tf.train.get_or_create_global_step()
+
+      with tf.name_scope('loss'):
+        loss = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
+        tf.summary.scalar("loss", loss)
+
+      with tf.name_scope('train'):
+        train_op = tf.train.AdagradOptimizer(0.01).minimize(loss, global_step=global_step)

       # Test trained model
       label = tf.argmax(y_, 1, name="label")
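
Two graph-level changes in the hunk above are worth calling out: the ops are grouped under tf.name_scope blocks, which only affects how TensorBoard draws the graph, and the hand-rolled tf.Variable(0) step counter is replaced by tf.train.get_or_create_global_step(), which registers the counter in the GLOBAL_STEP collection where session hooks expect to find it. A minimal sketch of that pattern, separate from the commit itself (the toy parameter and loss below are illustrative assumptions, not code from this repo):

import tensorflow as tf

with tf.Graph().as_default():
  # Canonical global step: created once per graph and discoverable by hooks
  # such as tf.train.StopAtStepHook; a plain tf.Variable(0) is invisible to them.
  global_step = tf.train.get_or_create_global_step()

  with tf.name_scope('train'):
    w = tf.Variable([1.0, 2.0])          # toy parameter (illustrative)
    loss = tf.reduce_sum(tf.square(w))   # toy loss (illustrative)
    # Passing global_step makes minimize() increment it on every update.
    train_op = tf.train.AdagradOptimizer(0.01).minimize(loss, global_step=global_step)
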
@@ -98,73 +99,53 @@ def feed_dict(batch):
       accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name="accuracy")
       tf.summary.scalar("acc", accuracy)

-      saver = tf.train.Saver()
       summary_op = tf.summary.merge_all()
-      init_op = tf.global_variables_initializer()

-    # Create a "supervisor", which oversees the training process and stores model state into HDFS
     logdir = ctx.absolute_path(args.model)
+    # logdir = args.model
     print("tensorflow model path: {0}".format(logdir))
-
+    hooks = [tf.train.StopAtStepHook(last_step=100000)]
+
     if job_name == "worker" and task_index == 0:
       summary_writer = tf.summary.FileWriter(logdir, graph=tf.get_default_graph())

-    if args.mode == "train":
-      sv = tf.train.Supervisor(is_chief=(task_index == 0),
-                               logdir=logdir,
-                               init_op=init_op,
-                               summary_op=None,
-                               summary_writer=None,
-                               saver=saver,
-                               global_step=global_step,
-                               stop_grace_secs=300,
-                               save_model_secs=10)
-    else:
-      sv = tf.train.Supervisor(is_chief=(task_index == 0),
-                               logdir=logdir,
-                               summary_op=None,
-                               saver=saver,
-                               global_step=global_step,
-                               stop_grace_secs=300,
-                               save_model_secs=0)
-
-    # The supervisor takes care of session initialization, restoring from
-    # a checkpoint, and closing when done or an error occurs.
-    with sv.managed_session(server.target) as sess:
-      print("{0} session ready".format(datetime.now().isoformat()))
-
-      # Loop until the supervisor shuts down or 1000000 steps have completed.
+    # The MonitoredTrainingSession takes care of session initialization, restoring from
+    # a checkpoint, and closing when done or an error occurs
+    with tf.train.MonitoredTrainingSession(master=server.target,
+                                           is_chief=(task_index == 0),
+                                           checkpoint_dir=logdir,
+                                           hooks=hooks) as mon_sess:
+
       step = 0
       tf_feed = ctx.get_data_feed(args.mode == "train")
-      while not sv.should_stop() and not tf_feed.should_stop() and step < args.steps:
-        # Run a training step asynchronously.
-        # See `tf.train.SyncReplicasOptimizer` for additional details on how to
-        # perform *synchronous* training.
+      while not mon_sess.should_stop() and not tf_feed.should_stop() and step < args.steps:
+        # Run a training step asynchronously
+        # See `tf.train.SyncReplicasOptimizer` for additional details on how to
+        # perform *synchronous* training.

         # using feed_dict
         batch_xs, batch_ys = feed_dict(tf_feed.next_batch(batch_size))
         feed = {x: batch_xs, y_: batch_ys}

         if len(batch_xs) > 0:
           if args.mode == "train":
-            _, summary, step = sess.run([train_op, summary_op, global_step], feed_dict=feed)
+            _, summary, step = mon_sess.run([train_op, summary_op, global_step], feed_dict=feed)
             # print accuracy and save model checkpoint to HDFS every 100 steps
             if (step % 100 == 0):
-              print("{0} step: {1} accuracy: {2}".format(datetime.now().isoformat(), step, sess.run(accuracy, {x: batch_xs, y_: batch_ys})))
+              print("{0} step: {1} accuracy: {2}".format(datetime.now().isoformat(), step, mon_sess.run(accuracy, {x: batch_xs, y_: batch_ys})))

-            if sv.is_chief:
+            if task_index == 0:
               summary_writer.add_summary(summary, step)
           else:  # args.mode == "inference"
-            labels, preds, acc = sess.run([label, prediction, accuracy], feed_dict=feed)
+            labels, preds, acc = mon_sess.run([label, prediction, accuracy], feed_dict=feed)

             results = ["{0} Label: {1}, Prediction: {2}".format(datetime.now().isoformat(), l, p) for l, p in zip(labels, preds)]
             tf_feed.batch_results(results)
-            print("acc: {0}".format(acc))
+            print("results: {0}, acc: {1}".format(results, acc))

-        if sv.should_stop() or step >= args.steps:
+        if mon_sess.should_stop() or step >= args.steps:
           tf_feed.terminate()

     # Ask for all the services to stop.
-    print("{0} stopping supervisor".format(datetime.now().isoformat()))
-    sv.stop()
-
+    print("{0} stopping MonitoredTrainingSession".format(datetime.now().isoformat()))
+    summary_writer.close()
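
The second hunk is the core of the TensorFlow 1.6 update: the older tf.train.Supervisor is replaced by tf.train.MonitoredTrainingSession, which folds variable initialization, checkpoint save/restore, and coordinated shutdown into one hook-driven session, so the separate saver, init_op, and the two Supervisor configurations all disappear. A minimal single-process sketch of the same API, assuming a local /tmp/mnist_model directory as a stand-in for the HDFS logdir (toy graph for illustration, not the commit's code):

import tensorflow as tf

global_step = tf.train.get_or_create_global_step()
w = tf.Variable([1.0, 2.0])                 # toy parameter (illustrative)
loss = tf.reduce_sum(tf.square(w))          # toy loss (illustrative)
train_op = tf.train.AdagradOptimizer(0.01).minimize(loss, global_step=global_step)

# StopAtStepHook replaces the Supervisor-era manual step bookkeeping, and
# checkpoint_dir subsumes logdir/saver/save_model_secs; initialization and
# checkpoint restore happen automatically when the session is created.
hooks = [tf.train.StopAtStepHook(last_step=100)]
with tf.train.MonitoredTrainingSession(checkpoint_dir="/tmp/mnist_model",
                                       hooks=hooks) as mon_sess:
  while not mon_sess.should_stop():
    mon_sess.run(train_op)
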
