
Commit 31d6464

add tensorboard arg to estimator example; add keras/estimator example
1 parent 3c8b44d commit 31d6464

2 files changed: 136 additions, 1 deletion


examples/mnist/estimator/mnist_estimator.py

Lines changed: 3 additions & 1 deletion
@@ -179,8 +179,10 @@ def main(args, ctx):
   parser.add_argument("--output", help="HDFS path to save test/inference output", default="predictions")
   parser.add_argument("--num_ps", help="number of PS nodes in cluster", type=int, default=1)
   parser.add_argument("--steps", help="maximum number of steps", type=int, default=1000)
+  parser.add_argument("--tensorboard", help="launch tensorboard process", action="store_true")
+
   args = parser.parse_args()
   print("args:", args)
 
-  cluster = TFCluster.run(sc, main, args, args.cluster_size, args.num_ps, tensorboard=False, input_mode=TFCluster.InputMode.TENSORFLOW, log_dir=args.model, master_node='master')
+  cluster = TFCluster.run(sc, main, args, args.cluster_size, args.num_ps, tensorboard=args.tensorboard, input_mode=TFCluster.InputMode.TENSORFLOW, log_dir=args.model, master_node='master')
   cluster.shutdown()
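
With this change, TensorBoard becomes opt-in at launch time: the new --tensorboard flag is passed straight through to TFCluster.run. A minimal launch sketch, assuming a YARN cluster; the master, executor count, model path, and any dataset arguments the example requires are illustrative, not part of this commit:

  spark-submit \
    --master yarn \
    --num-executors 4 \
    examples/mnist/estimator/mnist_estimator.py \
    --cluster_size 4 \
    --model mnist_model \
    --tensorboard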
Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
import numpy
import tensorflow as tf
from tensorflow.python import keras
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.layers import Dense, Dropout
from tensorflow.python.keras.optimizers import RMSprop
from tensorflowonspark import TFNode


def main_fun(args, ctx):
  IMAGE_PIXELS = 28
  num_classes = 10

  # use Keras API to load data
  from tensorflow.python.keras.datasets import mnist
  (x_train, y_train), (x_test, y_test) = mnist.load_data()
  x_train = x_train.reshape(60000, 784)
  x_test = x_test.reshape(10000, 784)
  x_train = x_train.astype('float32') / 255
  x_test = x_test.astype('float32') / 255

  # convert class vectors to binary class matrices
  y_train = keras.utils.to_categorical(y_train, num_classes)
  y_test = keras.utils.to_categorical(y_test, num_classes)

  # setup a Keras model
  model = Sequential()
  model.add(Dense(512, activation='relu', input_shape=(784,)))
  model.add(Dropout(0.2))
  model.add(Dense(512, activation='relu'))
  model.add(Dropout(0.2))
  model.add(Dense(10, activation='softmax'))
  model.compile(loss='categorical_crossentropy',
                optimizer=RMSprop(),
                metrics=['accuracy'])
  model.summary()

  # convert Keras model to tf.estimator
  estimator = tf.keras.estimator.model_to_estimator(model, model_dir=args.model_dir)

  # setup train_input_fn for InputMode.TENSORFLOW or InputMode.SPARK
  if args.input_mode == 'tf':
    train_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={"dense_1_input": x_train},
        y=y_train,
        batch_size=128,
        num_epochs=None,
        shuffle=True)
  else:  # 'spark'
    tf_feed = TFNode.DataFeed(ctx.mgr)

    def rdd_generator():
      while not tf_feed.should_stop():
        batch = tf_feed.next_batch(1)
        if len(batch) > 0:
          record = batch[0]
          image = numpy.array(record[0]).astype(numpy.float32) / 255.0
          label = numpy.array(record[1]).astype(numpy.float32)
          yield (image, label)

    def train_input_fn():
      ds = tf.data.Dataset.from_generator(rdd_generator,
                                          (tf.float32, tf.float32),
                                          (tf.TensorShape([IMAGE_PIXELS * IMAGE_PIXELS]), tf.TensorShape([10])))
      ds = ds.batch(args.batch_size)
      return ds

  # eval_input_fn ALWAYS uses data loaded in memory, since InputMode.SPARK can only feed one RDD at a time
  eval_input_fn = tf.estimator.inputs.numpy_input_fn(
      x={"dense_1_input": x_test},
      y=y_test,
      num_epochs=args.epochs,
      shuffle=False)

  # setup tf.estimator.train_and_evaluate()
  train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=args.steps)
  eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)
  tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

  # export a saved_model, if export_dir provided
  if args.export_dir:
    def serving_input_receiver_fn():
      """An input receiver that expects a serialized tf.Example."""
      serialized_tf_example = tf.placeholder(dtype=tf.string,
                                             shape=[args.batch_size],
                                             name='input_example_tensor')
      receiver_tensors = {'dense_1_input': serialized_tf_example}
      feature_spec = {'dense_1_input': tf.FixedLenFeature(784, tf.string)}
      features = tf.parse_example(serialized_tf_example, feature_spec)
      return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)

    estimator.export_savedmodel(args.export_dir, serving_input_receiver_fn)


if __name__ == '__main__':
  import argparse
  from pyspark.context import SparkContext
  from pyspark.conf import SparkConf
  from tensorflowonspark import TFCluster

  sc = SparkContext(conf=SparkConf().setAppName("mnist_mlp"))
  executors = sc._conf.get("spark.executor.instances")
  num_executors = int(executors) if executors is not None else 1
  num_ps = 1

  parser = argparse.ArgumentParser()
  parser.add_argument("--batch_size", help="number of records per batch", type=int, default=100)
  parser.add_argument("--cluster_size", help="number of nodes in the cluster", type=int, default=num_executors)
  parser.add_argument("--epochs", help="number of epochs of training data", type=int, default=1)
  parser.add_argument("--export_dir", help="directory to export saved_model")
  parser.add_argument("--images", help="HDFS path to MNIST images in parallelized CSV format")
  parser.add_argument("--input_mode", help="input mode (tf|spark)", default="tf")
  parser.add_argument("--labels", help="HDFS path to MNIST labels in parallelized CSV format")
  parser.add_argument("--model_dir", help="directory to write model checkpoints")
  parser.add_argument("--num_ps", help="number of ps nodes", type=int, default=1)
  parser.add_argument("--steps", help="max number of steps to train", type=int, default=2000)
  parser.add_argument("--tensorboard", help="launch tensorboard process", action="store_true")

  args = parser.parse_args()
  print("args:", args)

  if args.input_mode == 'tf':
    # for TENSORFLOW mode, each node will load/train entire dataset in memory per original example
    cluster = TFCluster.run(sc, main_fun, args, args.cluster_size, args.num_ps, args.tensorboard, TFCluster.InputMode.TENSORFLOW, log_dir=args.model_dir, master_node='master')
    cluster.shutdown()
  else:  # 'spark'
    # for SPARK mode, just use CSV format as an example
    images = sc.textFile(args.images).map(lambda ln: [float(x) for x in ln.split(',')])
    labels = sc.textFile(args.labels).map(lambda ln: [float(x) for x in ln.split(',')])
    dataRDD = images.zip(labels)
    cluster = TFCluster.run(sc, main_fun, args, args.cluster_size, args.num_ps, args.tensorboard, TFCluster.InputMode.SPARK, log_dir=args.model_dir, master_node='master')
    cluster.train(dataRDD, args.epochs)
    cluster.shutdown()
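
For context, a hedged launch sketch for each input mode of the new example. The script name mnist_mlp_estimator.py and the HDFS paths are assumptions (the file's path header is not shown on this page); the flags themselves come from the argument parser above:

  # InputMode.TENSORFLOW ('tf'): each executor loads the Keras MNIST dataset itself
  spark-submit mnist_mlp_estimator.py \
    --input_mode tf --model_dir mnist_model --export_dir mnist_export --steps 2000

  # InputMode.SPARK ('spark'): feeds a zipped RDD of (784-value image CSV, 10-value one-hot label CSV) records
  spark-submit mnist_mlp_estimator.py \
    --input_mode spark --images mnist/csv/train/images --labels mnist/csv/train/labels \
    --model_dir mnist_model --epochs 1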
