
Commit d6acd8b

Merge pull request #379 from yahoo/leewyang_more_keras
Default keras example to tf.estimator
2 parents e392a02 + 5fc70a4 commit d6acd8b


5 files changed: +236 −97 lines changed


examples/mnist/keras/README.md

Lines changed: 58 additions & 10 deletions
@@ -2,11 +2,10 @@
 
 Original Source: https://github.com/fchollet/keras/blob/master/examples/mnist_mlp.py
 
-This is the MNIST Multi Layer Perceptron example from the [Keras examples](https://github.com/fchollet/keras/blob/master/examples), adapted for TensorFlowOnSpark.
+This is the MNIST Multi Layer Perceptron example from the [Keras examples](https://github.com/fchollet/keras/blob/master/examples), adapted for the `tf.estimator` API and TensorFlowOnSpark.
 
 Notes:
 - This example assumes that Spark, TensorFlow, and TensorFlowOnSpark are already installed.
-- Keras currently saves model checkpoints as [HDF5](https://support.hdfgroup.org/HDF5/) using the [h5py package](http://www.h5py.org/). Unfortunately, this is not currently supported on HDFS. Consequently, this example demonstrates how to save standard TensorFlow model checkpoints on HDFS via a Keras LambdaCallback. If you don't need HDFS support, you can use the standard ModelCheckpoint instead.
 - InputMode.SPARK only supports feeding data from a single RDD, so the validation dataset/code is disabled in the corresponding example.
 
 #### Launch the Spark Standalone cluster
@@ -24,19 +23,18 @@ Notes:
 In this mode, each worker will load the entire MNIST dataset into memory (automatically downloading the dataset if needed).
 
 # remove any old artifacts
-rm -rf ${TFoS_HOME}/mnist_model ${TFoS_HOME}/mnist_export
+rm -rf ${TFoS_HOME}/mnist_model
 
 # train and validate
 ${SPARK_HOME}/bin/spark-submit \
 --master ${MASTER} \
 --conf spark.cores.max=${TOTAL_CORES} \
 --conf spark.task.cpus=${CORES_PER_WORKER} \
 --conf spark.executorEnv.JAVA_HOME="$JAVA_HOME" \
-${TFoS_HOME}/examples/mnist/keras/mnist_mlp.py \
+${TFoS_HOME}/examples/mnist/keras/mnist_mlp_estimator.py \
 --cluster_size ${SPARK_WORKER_INSTANCES} \
 --input_mode tf \
 --model_dir ${TFoS_HOME}/mnist_model \
---export_dir ${TFoS_HOME}/mnist_export \
 --epochs 5 \
 --tensorboard
 
@@ -56,25 +54,75 @@ In this mode, Spark will distribute the MNIST dataset (as CSV) across the worker
 ls -lR ${TFoS_HOME}/mnist/csv
 
 # remove any old artifacts
-rm -rf ${TFoS_HOME}/mnist_model ${TFoS_HOME}/mnist_export
+rm -rf ${TFoS_HOME}/mnist_model
 
-# train and validate
+# train
 ${SPARK_HOME}/bin/spark-submit \
 --master ${MASTER} \
 --conf spark.cores.max=${TOTAL_CORES} \
 --conf spark.task.cpus=${CORES_PER_WORKER} \
 --conf spark.executorEnv.JAVA_HOME="$JAVA_HOME" \
-${TFoS_HOME}/examples/mnist/keras/mnist_mlp.py \
+${TFoS_HOME}/examples/mnist/keras/mnist_mlp_estimator.py \
 --cluster_size ${SPARK_WORKER_INSTANCES} \
 --input_mode spark \
 --images ${TFoS_HOME}/mnist/csv/train/images \
 --labels ${TFoS_HOME}/mnist/csv/train/labels \
 --epochs 5 \
 --model_dir ${TFoS_HOME}/mnist_model \
---export_dir ${TFoS_HOME}/mnist_export \
 --tensorboard
 
+#### Inference via saved_model_cli
+
+The training code will automatically export a TensorFlow SavedModel, which can be used with the `saved_model_cli` from the command line, as follows:
+
+# path to the SavedModel export
+export SAVED_MODEL=${TFoS_HOME}/mnist_model/export/serving/*
+
+# use a CSV formatted test example
+IMG=$(head -n 1 $TFoS_HOME/examples/mnist/csv/test/images/part-00000)
+
+# introspect model
+saved_model_cli show --dir $SAVED_MODEL --all
+
+# inference via saved_model_cli
+saved_model_cli run --dir $SAVED_MODEL --tag_set serve --signature_def serving_default --input_exprs "dense_input=[[$IMG]]"
+
+#### Inference via TF-Serving
+
+For online inferencing use cases, you can serve the SavedModel via a TensorFlow Serving instance as follows. Note that TF-Serving provides both GRPC and REST APIs, but we will only
+demonstrate the use of the REST API. Also, [per the TensorFlow Serving instructions](https://www.tensorflow.org/serving/), we will run the serving instance inside a Docker container.
+
+# Start the TF-Serving instance in a docker container
+docker pull tensorflow/serving
+docker run -t --rm -p 8501:8501 -v "${TFoS_HOME}/mnist_model/export/serving:/models/mnist" -e MODEL_NAME=mnist tensorflow/serving &
+
+# GET model status
+curl http://localhost:8501/v1/models/mnist
+
+# GET model metadata
+curl http://localhost:8501/v1/models/mnist/metadata
+
+# POST example for inferencing
+curl -v -d "{\"instances\": [ {\"dense_input\": [$IMG] } ]}" -X POST http://localhost:8501/v1/models/mnist:predict
+
+# Stop the TF-Serving container
+docker stop $(docker ps -q)
+
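The same REST endpoint can also be called programmatically. The following is a minimal sketch (an editorial illustration, not part of this commit) using the Python `requests` package against the container started above; the `image` values are placeholder data standing in for the 784 normalized pixels of one MNIST test example.

```python
import requests

# one flattened MNIST image: 784 pixel values in [0, 1]; placeholder data for illustration
image = [0.0] * 784

# same payload shape as the curl example above ("dense_input" is the exported input key)
payload = {"instances": [{"dense_input": image}]}

resp = requests.post("http://localhost:8501/v1/models/mnist:predict", json=payload)
resp.raise_for_status()

# TF-Serving's REST predict API returns one entry per instance; here, the class probabilities
print(resp.json()["predictions"][0])
```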
+#### Run Parallel Inferencing via Spark
+
+For batch inferencing use cases, you can use Spark to run multiple single-node TensorFlow instances in parallel (on the Spark executors). Each executor/instance will operate independently on a shard of the dataset. Note that this requires that the model fits in the memory of each executor.
+
+# remove any old artifacts
+rm -Rf ${TFoS_HOME}/predictions
+
+# inference
+${SPARK_HOME}/bin/spark-submit \
+--master $MASTER ${TFoS_HOME}/examples/mnist/keras/mnist_inference.py \
+--cluster_size 3 \
+--images_labels ${TFoS_HOME}/mnist/tfr/test \
+--export ${TFoS_HOME}/mnist_model/export/serving/* \
+--output ${TFoS_HOME}/predictions
+
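Each executor writes one `part-*` file of predicted digits (one `argmax` per line) under `${TFoS_HOME}/predictions`, as implemented in `mnist_inference.py` below. A quick sanity check of that output, sketched here as an editorial illustration (not part of this commit) and assuming the prediction files are on the local filesystem:

```python
import glob
import os
from collections import Counter

# assumes TFoS_HOME is set, as in the shell commands above
pred_dir = os.path.join(os.environ["TFoS_HOME"], "predictions")

counts = Counter()
for path in sorted(glob.glob(os.path.join(pred_dir, "part-*"))):
  with open(path) as f:
    counts.update(line.strip() for line in f if line.strip())

# distribution of predicted digit classes across all shards
print(counts.most_common())
```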
 #### Shutdown the Spark Standalone cluster
 
 ${SPARK_HOME}/sbin/stop-slave.sh; ${SPARK_HOME}/sbin/stop-master.sh
examples/mnist/keras/mnist_inference.py

Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
# Copyright 2018 Yahoo Inc.
# Licensed under the terms of the Apache 2.0 license.
# Please see LICENSE file in the project root for terms.

# This example demonstrates how to leverage Spark for parallel inferencing from a SavedModel.
#
# Normally, you can use TensorFlowOnSpark to just form a TensorFlow cluster for training and inferencing.
# However, in some situations, you may have a SavedModel without the original code for defining the inferencing
# graph. In these situations, we can use Spark to instantiate a single-node TensorFlow instance on each executor,
# where each executor can independently load the model and inference on input data.
#
# Note: this particular example demonstrates use of `tf.data.Dataset` to read the input data for inferencing,
# but it could also be adapted to just use an RDD of TFRecords from Spark.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import numpy as np
import tensorflow as tf

IMAGE_PIXELS = 28


def inference(it, num_workers, args):
  from tensorflowonspark import util

  # consume worker number from RDD partition iterator
  for i in it:
    worker_num = i
  print("worker_num: {}".format(i))

  # setup env for single-node TF
  util.single_node_env()

  # load saved_model using default tag and signature
  sess = tf.Session()
  tf.saved_model.loader.load(sess, ['serve'], args.export)

  # parse function for TFRecords
  def parse_tfr(example_proto):
    feature_def = {"label": tf.FixedLenFeature(10, tf.int64),
                   "image": tf.FixedLenFeature(IMAGE_PIXELS * IMAGE_PIXELS, tf.int64)}
    features = tf.parse_single_example(example_proto, feature_def)
    norm = tf.constant(255, dtype=tf.float32, shape=(784,))
    image = tf.div(tf.to_float(features['image']), norm)
    label = tf.to_float(features['label'])
    return (image, label)

  # define a new tf.data.Dataset (for inferencing)
  ds = tf.data.Dataset.list_files("{}/part-*".format(args.images_labels))
  ds = ds.shard(num_workers, worker_num)
  ds = ds.interleave(tf.data.TFRecordDataset, cycle_length=1)
  ds = ds.map(parse_tfr).batch(10)
  iterator = ds.make_one_shot_iterator()
  image_label = iterator.get_next(name='inf_image')

  # create an output file per spark worker for the predictions
  tf.gfile.MakeDirs(args.output)
  output_file = tf.gfile.GFile("{}/part-{:05d}".format(args.output, worker_num), mode='w')

  while True:
    try:
      # get images and labels from tf.data.Dataset
      img, lbl = sess.run(['inf_image:0', 'inf_image:1'])

      # inference by feeding these images and labels into the input tensors
      # you can view the exported model signatures via:
      # saved_model_cli show --dir <export_dir> --all

      # note that we feed directly into the graph tensors (bypassing the exported signatures)
      # these tensors will be shown in the "name" field of the signature definitions

      outputs = sess.run(['dense_2/Softmax:0'], feed_dict={'Placeholder:0': img})
      for p in outputs[0]:
        output_file.write("{}\n".format(np.argmax(p)))
    except tf.errors.OutOfRangeError:
      break

  output_file.close()


if __name__ == '__main__':
  from pyspark.context import SparkContext
  from pyspark.conf import SparkConf

  sc = SparkContext(conf=SparkConf().setAppName("mnist_inference"))
  executors = sc._conf.get("spark.executor.instances")
  num_executors = int(executors) if executors is not None else 1

  parser = argparse.ArgumentParser()
  parser.add_argument("--cluster_size", help="number of nodes in the cluster (for Spark Standalone)", type=int, default=num_executors)
  parser.add_argument('--images_labels', type=str, help='Directory for input images with labels')
  parser.add_argument("--export", help="HDFS path to export model", type=str, default="mnist_export")
  parser.add_argument("--output", help="HDFS path to save predictions", type=str, default="predictions")
  args, _ = parser.parse_known_args()
  print("args: {}".format(args))

  # Not using TFCluster... just running single-node TF instances on each executor
  nodes = list(range(args.cluster_size))
  nodeRDD = sc.parallelize(list(range(args.cluster_size)), args.cluster_size)
  nodeRDD.foreachPartition(lambda worker_num: inference(worker_num, args.cluster_size, args))
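As the inline comments note, the script feeds graph tensors by hard-coded names ('Placeholder:0' and 'dense_2/Softmax:0') rather than going through the exported signature. When those names are not known up front, one option is to read them from the SavedModel's signature definitions. The sketch below is an editorial illustration, not part of this commit; the 'serving_default' signature and the 'dense_input' input key match the saved_model_cli example earlier, but the exact keys for a given model should be confirmed with `saved_model_cli show`.

```python
import tensorflow as tf


def tensor_names_from_signature(sess, export_dir, signature_key='serving_default'):
  """Load a SavedModel and return {input_key: tensor_name} and {output_key: tensor_name} maps."""
  # tf.saved_model.loader.load returns the MetaGraphDef for the loaded tags,
  # which carries the signature definitions exported with the model.
  meta_graph_def = tf.saved_model.loader.load(sess, ['serve'], export_dir)
  sig = meta_graph_def.signature_def[signature_key]
  input_names = {key: info.name for key, info in sig.inputs.items()}    # e.g. {'dense_input': 'Placeholder:0'} (assumed)
  output_names = {key: info.name for key, info in sig.outputs.items()}  # e.g. {'dense_2': 'dense_2/Softmax:0'} (assumed)
  return input_names, output_names
```

With these maps in hand, the `sess.run` call above could feed `input_names['dense_input']` and fetch the signature's output tensor name instead of hard-coding them.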
