
Commit 929adca

Author: pikachu
Update mnist_dist.py
1 parent 0aac8b2 commit 929adca

File tree

1 file changed: +39 additions, -51 deletions


examples/mnist/spark/mnist_dist.py

Lines changed: 39 additions & 51 deletions
@@ -1,4 +1,4 @@
-# Copyright 2017 Yahoo Inc.
+# Copyright 2018 Yahoo Inc.
 # Licensed under the terms of the Apache 2.0 license.
 # Please see LICENSE file in the project root for terms.
 
@@ -58,37 +58,38 @@ def feed_dict(batch):
       worker_device="/job:worker/task:%d" % task_index,
       cluster=cluster)):
 
-      # Variables of the hidden layer
-      hid_w = tf.Variable(tf.truncated_normal([IMAGE_PIXELS * IMAGE_PIXELS, hidden_units],
-                          stddev=1.0 / IMAGE_PIXELS), name="hid_w")
-      hid_b = tf.Variable(tf.zeros([hidden_units]), name="hid_b")
-      tf.summary.histogram("hidden_weights", hid_w)
-
-      # Variables of the softmax layer
-      sm_w = tf.Variable(tf.truncated_normal([hidden_units, 10],
-                         stddev=1.0 / math.sqrt(hidden_units)), name="sm_w")
-      sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
-      tf.summary.histogram("softmax_weights", sm_w)
-
-      # Placeholders or QueueRunner/Readers for input data
-      x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS], name="x")
-      y_ = tf.placeholder(tf.float32, [None, 10], name="y_")
-
-      x_img = tf.reshape(x, [-1, IMAGE_PIXELS, IMAGE_PIXELS, 1])
-      tf.summary.image("x_img", x_img)
-
-      hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
-      hid = tf.nn.relu(hid_lin)
-
-      y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
+      # Placeholders or QueueRunner/Readers for input data
+      with tf.name_scope('inputs'):
+        x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS], name="x")
+        y_ = tf.placeholder(tf.float32, [None, 10], name="y_")
+
+        x_img = tf.reshape(x, [-1, IMAGE_PIXELS, IMAGE_PIXELS, 1])
+        tf.summary.image("x_img", x_img)
+
+      with tf.name_scope('layer'):
+        # Variables of the hidden layer
+        with tf.name_scope('hidden_layer'):
+          hid_w = tf.Variable(tf.truncated_normal([IMAGE_PIXELS * IMAGE_PIXELS, hidden_units], stddev=1.0 / IMAGE_PIXELS), name="hid_w")
+          hid_b = tf.Variable(tf.zeros([hidden_units]), name="hid_b")
+          tf.summary.histogram("hidden_weights", hid_w)
+          hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
+          hid = tf.nn.relu(hid_lin)
+
+        # Variables of the softmax layer
+        with tf.name_scope('softmax_layer'):
+          sm_w = tf.Variable(tf.truncated_normal([hidden_units, 10], stddev=1.0 / math.sqrt(hidden_units)), name="sm_w")
+          sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
+          tf.summary.histogram("softmax_weights", sm_w)
+          y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
 
       global_step = tf.train.get_or_create_global_step()
 
-      loss = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
-      tf.summary.scalar("loss", loss)
+      with tf.name_scope('loss'):
+        loss = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
+        tf.summary.scalar("loss", loss)
 
-      train_op = tf.train.AdagradOptimizer(0.01).minimize(
-          loss, global_step=global_step)
+      with tf.name_scope('train'):
+        train_op = tf.train.AdagradOptimizer(0.01).minimize(loss, global_step=global_step)
 
       # Test trained model
       label = tf.argmax(y_, 1, name="label")
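
A note on the name_scope pattern this hunk introduces: tf.name_scope only changes op names (hid_w becomes layer/hidden_layer/hid_w), so the math is untouched, but TensorBoard's graph view can then collapse each scope into one expandable node. A minimal standalone sketch of the effect, using the same TF 1.x API (the scope and tensor names here are illustrative, not from the file):

import tensorflow as tf

with tf.name_scope('hidden_layer'):
    w = tf.Variable(tf.zeros([784, 128]), name="w")  # op name becomes hidden_layer/w
with tf.name_scope('loss'):
    total = tf.reduce_sum(w, name="total")           # op name becomes loss/total

print(w.op.name)      # hidden_layer/w
print(total.op.name)  # loss/total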
@@ -100,39 +101,27 @@ def feed_dict(batch):
 
       summary_op = tf.summary.merge_all()
 
-    # Create a "MonitoredTrainingSession", which oversees the training process and stores model state into HDFS
     logdir = ctx.absolute_path(args.model)
+    # logdir = args.model
     print("tensorflow model path: {0}".format(logdir))
     hooks = [tf.train.StopAtStepHook(last_step=100000)]
 
     if job_name == "worker" and task_index == 0:
       summary_writer = tf.summary.FileWriter(logdir, graph=tf.get_default_graph())
 
-    if args.mode == "train":
-      with tf.train.MonitoredTrainingSession(master=server.target,
-                                             is_chief=(task_index == 0),
-                                             checkpoint_dir=logdir,
-                                             hooks=hooks,
-                                             ) as mon_sess:
-    else:
-      sv = tf.train.Supervisor(is_chief=(task_index == 0),
-                               logdir=logdir,
-                               summary_op=None,
-                               saver=saver,
-                               global_step=global_step,
-                               stop_grace_secs=300,
-                               save_model_secs=0)
-
     # The MonitoredTrainingSession takes care of session initialization, restoring from
     # a checkpoint, and closing when done or an error occurs
+    with tf.train.MonitoredTrainingSession(master=server.target,
+                                           is_chief=(task_index == 0),
+                                           checkpoint_dir=logdir,
+                                           hooks=hooks) as mon_sess:
 
-      # Loop until the supervisor shuts down or 1000000 steps have completed.
       step = 0
       tf_feed = ctx.get_data_feed(args.mode == "train")
       while not mon_sess.should_stop() and not tf_feed.should_stop() and step < args.steps:
-        # Run a training step asynchronously.
-        # See `tf.train.SyncReplicasOptimizer` for additional details on how to
-        # perform *synchronous* training.
+        # Run a training step asynchronously
+        # See `tf.train.SyncReplicasOptimizer` for additional details on how to
+        # perform *synchronous* training.
 
         # using feed_dict
         batch_xs, batch_ys = feed_dict(tf_feed.next_batch(batch_size))
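
The session handling above centers on tf.train.MonitoredTrainingSession, which, as the retained comment notes, takes care of initialization, checkpoint restore, and clean shutdown, so the old Supervisor branch is no longer needed. A minimal sketch of the same pattern in isolation; master="" (in-process) and the /tmp checkpoint path are placeholders for this sketch, where the file itself passes server.target and the HDFS logdir:

import tensorflow as tf

global_step = tf.train.get_or_create_global_step()
train_op = tf.assign_add(global_step, 1)  # stand-in for an optimizer step

hooks = [tf.train.StopAtStepHook(last_step=100)]
with tf.train.MonitoredTrainingSession(master="",                         # file: server.target
                                       is_chief=True,                     # file: (task_index == 0)
                                       checkpoint_dir="/tmp/mnist_ckpt",  # placeholder path
                                       hooks=hooks) as mon_sess:
  while not mon_sess.should_stop():
    mon_sess.run(train_op)  # StopAtStepHook stops the loop at step 100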
@@ -152,12 +141,11 @@ def feed_dict(batch):
 
         results = ["{0} Label: {1}, Prediction: {2}".format(datetime.now().isoformat(), l, p) for l,p in zip(labels,preds)]
         tf_feed.batch_results(results)
-        print("acc: {0}".format(acc))
+        print("results: {0}, acc: {1}".format(results, acc))
 
       if mon_sess.should_stop() or step >= args.steps:
         tf_feed.terminate()
 
     # Ask for all the services to stop.
     print("{0} stopping MonitoredTrainingSession".format(datetime.now().isoformat()))
-
-
+    summary_writer.close()
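
For readers tracing the data path in the loop above: ctx.get_data_feed returns the TensorFlowOnSpark queue that streams RDD partitions into this process, tf_feed.next_batch(batch_size) pulls a list of records, and the feed_dict helper named in the hunk headers (its body lies outside this diff) turns them into arrays for the x and y_ placeholders. A rough sketch of what such a helper could look like; the (image, label) record layout and the 255.0 pixel scaling are assumptions, not taken from this diff:

import numpy

def feed_dict(batch):
  # batch: list of (image, label) records pulled from the Spark data feed
  images = [item[0] for item in batch]
  labels = [item[1] for item in batch]
  xs = numpy.array(images, dtype=numpy.float32) / 255.0  # assumed [0, 255] -> [0, 1] scaling
  ys = numpy.array(labels, dtype=numpy.float32)          # one-hot labels for y_
  return (xs, ys)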
