
Commit a9d661f

Merge pull request #249 from yahoo/leewyang_mnist_pep8
pep8 for mnist examples; minor fix for latest spark/mnist_dist.py
2 parents: db0a726 + b133b8e

14 files changed: +186, -185 lines

examples/mnist/keras/mnist_mlp.py

Lines changed: 19 additions & 22 deletions
@@ -6,6 +6,7 @@
 
 from __future__ import print_function
 
+
 def main_fun(args, ctx):
   import numpy
   import os
@@ -16,12 +17,9 @@ def main_fun(args, ctx):
   from tensorflow.contrib.keras.api.keras.layers import Dense, Dropout
   from tensorflow.contrib.keras.api.keras.optimizers import RMSprop
   from tensorflow.contrib.keras.python.keras.callbacks import LambdaCallback, TensorBoard
-
   from tensorflow.python.saved_model import builder as saved_model_builder
   from tensorflow.python.saved_model import tag_constants
   from tensorflow.python.saved_model.signature_def_utils_impl import predict_signature_def
-
-
   from tensorflowonspark import TFNode
 
   cluster, server = TFNode.start_cluster_server(ctx)
@@ -44,8 +42,8 @@ def generate_rdd_data(tf_feed, batch_size):
       yield (images, labels)
 
   with tf.device(tf.train.replica_device_setter(
-    worker_device="/job:worker/task:%d" % ctx.task_index,
-    cluster=cluster)):
+      worker_device="/job:worker/task:%d" % ctx.task_index,
+      cluster=cluster)):
 
     IMAGE_PIXELS = 28
     batch_size = 100
@@ -98,21 +96,20 @@ def save_checkpoint(epoch, logs=None):
 
     if args.input_mode == 'tf':
       # train & validate on in-memory data
-      history = model.fit(x_train, y_train,
-                batch_size=batch_size,
-                epochs=args.epochs,
-                verbose=1,
-                validation_data=(x_test, y_test),
-                callbacks=callbacks)
+      model.fit(x_train, y_train,
+                batch_size=batch_size,
+                epochs=args.epochs,
+                verbose=1,
+                validation_data=(x_test, y_test),
+                callbacks=callbacks)
     else:  # args.input_mode == 'spark':
       # train on data read from a generator which is producing data from a Spark RDD
       tf_feed = TFNode.DataFeed(ctx.mgr)
-      history = model.fit_generator(
-        generator=generate_rdd_data(tf_feed, batch_size),
-        steps_per_epoch=args.steps_per_epoch,
-        epochs=args.epochs,
-        verbose=1,
-        callbacks=callbacks)
+      model.fit_generator(generator=generate_rdd_data(tf_feed, batch_size),
+                          steps_per_epoch=args.steps_per_epoch,
+                          epochs=args.epochs,
+                          verbose=1,
+                          callbacks=callbacks)
 
     if args.export_dir and ctx.job_name == 'worker' and ctx.task_index == 0:
       # save a local Keras model, so we can reload it with an inferencing learning_phase
@@ -125,11 +122,11 @@ def save_checkpoint(epoch, logs=None):
       # export a saved_model for inferencing
      builder = saved_model_builder.SavedModelBuilder(args.export_dir)
      signature = predict_signature_def(inputs={'images': new_model.input},
-                 outputs={'scores': new_model.output})
+                                        outputs={'scores': new_model.output})
      builder.add_meta_graph_and_variables(sess=sess,
-                 tags=[tag_constants.SERVING],
-                 signature_def_map={'predict': signature},
-                 clear_devices=True)
+                                           tags=[tag_constants.SERVING],
+                                           signature_def_map={'predict': signature},
+                                           clear_devices=True)
      builder.save()
 
   if args.input_mode == 'spark':
@@ -160,7 +157,7 @@ def save_checkpoint(epoch, logs=None):
   parser.add_argument("--tensorboard", help="launch tensorboard process", action="store_true")
 
   args = parser.parse_args()
-  print("args:",args)
+  print("args:", args)
 
   if args.input_mode == 'tf':
     cluster = TFCluster.run(sc, main_fun, args, args.cluster_size, args.num_ps, args.tensorboard, TFCluster.InputMode.TENSORFLOW, log_dir=args.model_dir)
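
Note: the fit_generator call reflowed above trains from a Python generator backed by a TensorFlowOnSpark DataFeed. Only the generator's final yield appears in these hunks, so the following is a minimal sketch of that pattern rather than the file's exact code; the array conversion details are illustrative assumptions.

  # Sketch only: generator feeding Keras from a TensorFlowOnSpark DataFeed.
  # tf_feed.should_stop() and tf_feed.next_batch() are the DataFeed calls used
  # elsewhere in this commit; the numpy conversion here is an assumption.
  import numpy

  def generate_rdd_data(tf_feed, batch_size):
    while not tf_feed.should_stop():
      batch = tf_feed.next_batch(batch_size)  # list of (image, label) records from the Spark RDD
      if not batch:
        return
      images = numpy.array([item[0] for item in batch], dtype=numpy.float32)  # shape (batch, 784)
      labels = numpy.array([item[1] for item in batch], dtype=numpy.float32)  # one-hot, shape (batch, 10)
      yield (images, labels)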

examples/mnist/spark/mnist_dist.py

Lines changed: 23 additions & 20 deletions
@@ -1,4 +1,4 @@
-#Copyright 2018 Yahoo Inc.
+# Copyright 2018 Yahoo Inc.
 # Licensed under the terms of the Apache 2.0 license.
 # Please see LICENSE file in the project root for terms.
 
@@ -9,9 +9,11 @@
 from __future__ import nested_scopes
 from __future__ import print_function
 
+
 def print_log(worker_num, arg):
   print("{0}: {1}".format(worker_num, arg))
 
+
 def map_fun(args, ctx):
   from datetime import datetime
   import math
@@ -30,7 +32,7 @@ def map_fun(args, ctx):
   # Parameters
   IMAGE_PIXELS = 28
   hidden_units = 128
-  batch_size = args.batch_size
+  batch_size = args.batch_size
 
   # Get TF cluster and server instances
   cluster, server = ctx.start_cluster_server(1, args.rdma)
@@ -55,28 +57,28 @@ def feed_dict(batch):
 
   # Assigns ops to the local worker by default.
   with tf.device(tf.train.replica_device_setter(
-    worker_device="/job:worker/task:%d" % task_index,
-    cluster=cluster)):
+      worker_device="/job:worker/task:%d" % task_index,
+      cluster=cluster)):
 
-    # Placeholders or QueueRunner/Readers for input data
+    # Placeholders or QueueRunner/Readers for input data
     with tf.name_scope('inputs'):
-      x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS] , name="x")
+      x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS], name="x")
       y_ = tf.placeholder(tf.float32, [None, 10], name="y_")
-
+
       x_img = tf.reshape(x, [-1, IMAGE_PIXELS, IMAGE_PIXELS, 1])
       tf.summary.image("x_img", x_img)
-
+
     with tf.name_scope('layer'):
       # Variables of the hidden layer
       with tf.name_scope('hidden_layer'):
-        hid_w = tf.Variable(tf.truncated_normal([IMAGE_PIXELS * IMAGE_PIXELS, hidden_units], stddev=1.0 / IMAGE_PIXELS), name="hid_w")
+        hid_w = tf.Variable(tf.truncated_normal([IMAGE_PIXELS * IMAGE_PIXELS, hidden_units], stddev=1.0 / IMAGE_PIXELS), name="hid_w")
         hid_b = tf.Variable(tf.zeros([hidden_units]), name="hid_b")
         tf.summary.histogram("hidden_weights", hid_w)
         hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
         hid = tf.nn.relu(hid_lin)
-
+
       # Variables of the softmax layer
-      with tf.name_scope('softmax_layer'):
+      with tf.name_scope('softmax_layer'):
         sm_w = tf.Variable(tf.truncated_normal([hidden_units, 10], stddev=1.0 / math.sqrt(hidden_units)), name="sm_w")
         sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
         tf.summary.histogram("softmax_weights", sm_w)
@@ -93,7 +95,7 @@ def feed_dict(batch):
 
     # Test trained model
     label = tf.argmax(y_, 1, name="label")
-    prediction = tf.argmax(y, 1,name="prediction")
+    prediction = tf.argmax(y, 1, name="prediction")
     correct_prediction = tf.equal(prediction, label)
 
     accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name="accuracy")
@@ -102,10 +104,9 @@ def feed_dict(batch):
     summary_op = tf.summary.merge_all()
 
   logdir = ctx.absolute_path(args.model)
-  # logdir = args.model
   print("tensorflow model path: {0}".format(logdir))
   hooks = [tf.train.StopAtStepHook(last_step=100000)]
-
+
   if job_name == "worker" and task_index == 0:
     summary_writer = tf.summary.FileWriter(logdir, graph=tf.get_default_graph())
 
@@ -119,9 +120,9 @@ def feed_dict(batch):
     step = 0
     tf_feed = ctx.get_data_feed(args.mode == "train")
     while not mon_sess.should_stop() and not tf_feed.should_stop() and step < args.steps:
-      # Run a training step asynchronously
-      # See `tf.train.SyncReplicasOptimizer` for additional details on how to
-      # perform *synchronous* training.
+      # Run a training step asynchronously
+      # See `tf.train.SyncReplicasOptimizer` for additional details on how to
+      # perform *synchronous* training.
 
       # using feed_dict
       batch_xs, batch_ys = feed_dict(tf_feed.next_batch(batch_size))
@@ -132,14 +133,14 @@ def feed_dict(batch):
         _, summary, step = mon_sess.run([train_op, summary_op, global_step], feed_dict=feed)
         # print accuracy and save model checkpoint to HDFS every 100 steps
         if (step % 100 == 0):
-          print("{0} step: {1} accuracy: {2}".format(datetime.now().isoformat(), step, mon_sess.run(accuracy,{x: batch_xs, y_: batch_ys})))
+          print("{0} step: {1} accuracy: {2}".format(datetime.now().isoformat(), step, mon_sess.run(accuracy, {x: batch_xs, y_: batch_ys})))
 
         if task_index == 0:
           summary_writer.add_summary(summary, step)
       else:  # args.mode == "inference"
         labels, preds, acc = mon_sess.run([label, prediction, accuracy], feed_dict=feed)
 
-        results = ["{0} Label: {1}, Prediction: {2}".format(datetime.now().isoformat(), l, p) for l,p in zip(labels,preds)]
+        results = ["{0} Label: {1}, Prediction: {2}".format(datetime.now().isoformat(), l, p) for l, p in zip(labels, preds)]
         tf_feed.batch_results(results)
         print("results: {0}, acc: {1}".format(results, acc))
 
@@ -148,4 +149,6 @@ def feed_dict(batch):
 
   # Ask for all the services to stop.
   print("{0} stopping MonitoredTrainingSession".format(datetime.now().isoformat()))
-  summary_writer.close()
+
+  if job_name == "worker" and task_index == 0:
+    summary_writer.close()
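
Note: the last hunk is the "minor fix" called out in the commit message. The FileWriter is created only on the chief worker (job_name == "worker" and task_index == 0), so the previously unconditional summary_writer.close() at shutdown referenced an unassigned local on every other worker task. Condensed from the hunks above, the creation and close are now symmetric:

  # Create and close the summary writer only on the chief worker.
  if job_name == "worker" and task_index == 0:
    summary_writer = tf.summary.FileWriter(logdir, graph=tf.get_default_graph())
  # ... training / inference loop ...
  if job_name == "worker" and task_index == 0:
    summary_writer.close()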

examples/mnist/spark/mnist_dist_dataset.py

Lines changed: 8 additions & 7 deletions
@@ -9,9 +9,11 @@
 from __future__ import nested_scopes
 from __future__ import print_function
 
+
 def print_log(worker_num, arg):
   print("{0}: {1}".format(worker_num, arg))
 
+
 def map_fun(args, ctx):
   from tensorflowonspark import TFNode
   from datetime import datetime
@@ -48,8 +50,8 @@ def rdd_generator():
 
   # Assigns ops to the local worker by default.
   with tf.device(tf.train.replica_device_setter(
-    worker_device="/job:worker/task:%d" % task_index,
-    cluster=cluster)):
+      worker_device="/job:worker/task:%d" % task_index,
+      cluster=cluster)):
 
     # Dataset for input data
     ds = tf.data.Dataset.from_generator(rdd_generator, (tf.float32, tf.float32), (tf.TensorShape([IMAGE_PIXELS * IMAGE_PIXELS]), tf.TensorShape([10]))).batch(args.batch_size)
@@ -58,13 +60,13 @@ def rdd_generator():
 
     # Variables of the hidden layer
     hid_w = tf.Variable(tf.truncated_normal([IMAGE_PIXELS * IMAGE_PIXELS, hidden_units],
-        stddev=1.0 / IMAGE_PIXELS), name="hid_w")
+                                            stddev=1.0 / IMAGE_PIXELS), name="hid_w")
     hid_b = tf.Variable(tf.zeros([hidden_units]), name="hid_b")
     tf.summary.histogram("hidden_weights", hid_w)
 
     # Variables of the softmax layer
     sm_w = tf.Variable(tf.truncated_normal([hidden_units, 10],
-        stddev=1.0 / math.sqrt(hidden_units)), name="sm_w")
+                                           stddev=1.0 / math.sqrt(hidden_units)), name="sm_w")
     sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
     tf.summary.histogram("softmax_weights", sm_w)
 
@@ -90,7 +92,7 @@ def rdd_generator():
 
     # Test trained model
     label = tf.argmax(y_, 1, name="label")
-    prediction = tf.argmax(y, 1,name="prediction")
+    prediction = tf.argmax(y, 1, name="prediction")
     correct_prediction = tf.equal(prediction, label)
 
     accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name="accuracy")
@@ -146,7 +148,7 @@ def rdd_generator():
       else:  # args.mode == "inference"
         labels, preds, acc = sess.run([label, prediction, accuracy])
 
-        results = ["{0} Label: {1}, Prediction: {2}".format(datetime.now().isoformat(), l, p) for l,p in zip(labels,preds)]
+        results = ["{0} Label: {1}, Prediction: {2}".format(datetime.now().isoformat(), l, p) for l, p in zip(labels, preds)]
         tf_feed.batch_results(results)
         print("acc: {0}".format(acc))
 
@@ -156,4 +158,3 @@ def rdd_generator():
   # Ask for all the services to stop.
   print("{0} stopping supervisor".format(datetime.now().isoformat()))
   sv.stop()
-
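
Note: in the Dataset hunk above, rdd_generator wraps the TensorFlowOnSpark feed and tf.data.Dataset.from_generator turns it into batched tensors. The iterator wiring sits outside the hunks shown, so this is a hedged TF 1.x sketch of how such a Dataset is typically consumed, not the file's exact code:

  # Sketch only: consuming the RDD-backed Dataset (TF 1.x style).
  ds = tf.data.Dataset.from_generator(
      rdd_generator, (tf.float32, tf.float32),
      (tf.TensorShape([IMAGE_PIXELS * IMAGE_PIXELS]), tf.TensorShape([10]))).batch(args.batch_size)
  iterator = ds.make_one_shot_iterator()
  x, y_ = iterator.get_next()  # batched image/label tensors for the layers defined above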

examples/mnist/spark/mnist_dist_pipeline.py

Lines changed: 13 additions & 12 deletions
@@ -9,9 +9,11 @@
 from __future__ import nested_scopes
 from __future__ import print_function
 
+
 def print_log(worker_num, arg):
   print("{0}: {1}".format(worker_num, arg))
 
+
 def map_fun(args, ctx):
   from tensorflowonspark import TFNode
   from datetime import datetime
@@ -32,7 +34,7 @@ def map_fun(args, ctx):
 
   # Parameters
   hidden_units = 128
-  batch_size = args.batch_size
+  batch_size = args.batch_size
 
   # Get TF cluster and server instances
   cluster, server = TFNode.start_cluster_server(ctx, 1, args.protocol == 'rdma')
@@ -54,18 +56,18 @@ def feed_dict(batch):
 
   # Assigns ops to the local worker by default.
   with tf.device(tf.train.replica_device_setter(
-    worker_device="/job:worker/task:%d" % task_index,
-    cluster=cluster)):
+      worker_device="/job:worker/task:%d" % task_index,
+      cluster=cluster)):
 
     # Variables of the hidden layer
     hid_w = tf.Variable(tf.truncated_normal([IMAGE_PIXELS * IMAGE_PIXELS, hidden_units],
-        stddev=1.0 / IMAGE_PIXELS), name="hid_w")
+                                            stddev=1.0 / IMAGE_PIXELS), name="hid_w")
     hid_b = tf.Variable(tf.zeros([hidden_units]), name="hid_b")
     tf.summary.histogram("hidden_weights", hid_w)
 
     # Variables of the softmax layer
     sm_w = tf.Variable(tf.truncated_normal([hidden_units, 10],
-        stddev=1.0 / math.sqrt(hidden_units)), name="sm_w")
+                                           stddev=1.0 / math.sqrt(hidden_units)), name="sm_w")
     sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
     tf.summary.histogram("softmax_weights", sm_w)
 
@@ -91,7 +93,7 @@ def feed_dict(batch):
 
     # Test trained model
     label = tf.argmax(y_, 1, name="label")
-    prediction = tf.argmax(y, 1,name="prediction")
+    prediction = tf.argmax(y, 1, name="prediction")
     correct_prediction = tf.equal(prediction, label)
 
     accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name="accuracy")
@@ -122,7 +124,6 @@ def feed_dict(batch):
 
     # Loop until the supervisor shuts down or 1000000 steps have completed.
     step = 0
-    #tf_feed = TFNode.DataFeed(ctx.mgr)
     tf_feed = TFNode.DataFeed(ctx.mgr, input_mapping=args.input_mapping)
     while not sv.should_stop() and not tf_feed.should_stop() and step < args.steps:
       # Run a training step asynchronously.
@@ -137,7 +138,7 @@ def feed_dict(batch):
        _, summary, step = sess.run([train_op, summary_op, global_step], feed_dict=feed)
        # print accuracy and save model checkpoint to HDFS every 100 steps
        if (step % 100 == 0):
-          print("{0} step: {1} accuracy: {2}".format(datetime.now().isoformat(), step, sess.run(accuracy,{x: batch_xs, y_: batch_ys})))
+          print("{0} step: {1} accuracy: {2}".format(datetime.now().isoformat(), step, sess.run(accuracy, {x: batch_xs, y_: batch_ys})))
 
        if sv.is_chief:
          summary_writer.add_summary(summary, step)
@@ -150,13 +151,13 @@ def feed_dict(batch):
      # exported signatures defined in code
      signatures = {
        tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: {
-          'inputs': { 'image': x },
-          'outputs': { 'prediction': prediction },
+          'inputs': {'image': x},
+          'outputs': {'prediction': prediction},
          'method_name': tf.saved_model.signature_constants.PREDICT_METHOD_NAME
        },
        'featurize': {
-          'inputs': { 'image': x },
-          'outputs': { 'features': hid },
+          'inputs': {'image': x},
+          'outputs': {'features': hid},
          'method_name': 'featurize'
        }
      }
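
Note: the signatures dict in the last hunk declares two SavedModel signatures — the default serving signature (image in, prediction out) and a 'featurize' signature exposing the hidden-layer activations. The export call itself is outside these hunks; in TensorFlowOnSpark such a dict is typically handed to TFNode.export_saved_model, roughly as sketched here (the export-path argument and exact argument order are assumptions):

  # Sketch only: exporting the signatures defined above via TensorFlowOnSpark.
  TFNode.export_saved_model(sess,                                  # session holding the trained variables
                            args.export_dir,                       # assumed export location argument
                            tf.saved_model.tag_constants.SERVING,  # tag set for the saved graph
                            signatures)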
