
Commit 5fc70a4

add parallel inferencing code to mnist/keras example
1 parent 07494c0

File tree

examples/mnist/keras/README.md
examples/mnist/keras/mnist_inference.py
examples/mnist/keras/mnist_mlp_estimator.py
examples/mnist/tf/mnist_inference.py

4 files changed: +152 −45 lines changed

examples/mnist/keras/README.md

Lines changed: 41 additions & 26 deletions
@@ -71,43 +71,58 @@ In this mode, Spark will distribute the MNIST dataset (as CSV) across the worker
     --model_dir ${TFoS_HOME}/mnist_model \
     --tensorboard
 
+#### Inference via saved_model_cli
 
-#### Shutdown the Spark Standalone cluster
+The training code will automatically export a TensorFlow SavedModel, which can be used with the `saved_model_cli` from the command line, as follows:
 
-    ${SPARK_HOME}/sbin/stop-slave.sh; ${SPARK_HOME}/sbin/stop-master.sh
+    # path to the SavedModel export
+    export SAVED_MODEL=${TFoS_HOME}/mnist_model/export/serving/*
+
+    # use a CSV formatted test example
+    IMG=$(head -n 1 $TFoS_HOME/examples/mnist/csv/test/images/part-00000)
+
+    # introspect model
+    saved_model_cli show --dir $SAVED_MODEL --all
+
+    # inference via saved_model_cli
+    saved_model_cli run --dir $SAVED_MODEL --tag_set serve --signature_def serving_default --input_exp "dense_input=[[$IMG]]"
 
 #### Inference via TF-Serving
 
-The training code will automatically export a TensorFlow SavedModel, which can be used with TensorFlow Serving as follows.
+For online inferencing use cases, you can serve the SavedModel via a TensorFlow Serving instance as follows. Note that TF-Serving provides both GRPC and REST APIs, but we will only
+demonstrate the use of the REST API. Also, [per the TensorFlow Serving instructions](https://www.tensorflow.org/serving/), we will run the serving instance inside a Docker container.
 
-Note: we use Docker to run the TF-Serving instance, per [recommendation](https://www.tensorflow.org/serving/).
-```
-# path to the SavedModel export
-export MODEL=${TFoS_HOME}/mnist_model/export/serving/*
+    # Start the TF-Serving instance in a docker container
+    docker pull tensorflow/serving
+    docker run -t --rm -p 8501:8501 -v "${TFoS_HOME}/mnist_model/export/serving:/models/mnist" -e MODEL_NAME=mnist tensorflow/serving &
 
-# use the CSV formatted data as a single example
-IMG=$(head -n 1 $TFoS_HOME/examples/mnist/csv/test/images/part-00000)
+    # GET model status
+    curl http://localhost:8501/v1/models/mnist
 
-# introspect model
-saved_model_cli show --dir $MODEL --all
+    # GET model metadata
+    curl http://localhost:8501/v1/models/mnist/metadata
 
-# inference via saved_model_cli
-saved_model_cli run --dir $MODEL --tag_set serve --signature_def serving_default --input_exp "dense_input=[[$IMG]]"
-# [[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]]
+    # POST example for inferencing
+    curl -v -d "{\"instances\": [ {\"dense_input\": [$IMG] } ]}" -X POST http://localhost:8501/v1/models/mnist:predict
 
-# START the TF-Serving instance in a docker container
-docker pull tensorflow/serving
-docker run -t --rm -p 8501:8501 -v "${TFoS_HOME}/mnist_model/export/serving:/models/mnist" -e MODEL_NAME=mnist tensorflow/serving &
+    # Stop the TF-Serving container
+    docker stop $(docker ps -q)
 
-# GET model status
-curl http://localhost:8501/v1/models/mnist
+#### Run Parallel Inferencing via Spark
 
-# GET model metadata
-curl http://localhost:8501/v1/models/mnist/metadata
+For batch inferencing use cases, you can use Spark to run multiple single-node TensorFlow instances in parallel (on the Spark executors). Each executor/instance will operate independently on a shard of the dataset. Note that this requires that the model fits in the memory of each executor.
 
-# POST example for inferencing
-curl -v -d "{\"instances\": [ {\"dense_input\": [$IMG] } ]}" -X POST http://localhost:8501/v1/models/mnist:predict
+    # remove any old artifacts
+    rm -Rf ${TFoS_HOME}/predictions
 
-# STOP the TF-Serving container
-docker stop $(docker ps -q)
-```
+    # inference
+    ${SPARK_HOME}/bin/spark-submit \
+    --master $MASTER ${TFoS_HOME}/examples/mnist/keras/mnist_inference.py \
+    --cluster_size 3 \
+    --images_labels ${TFoS_HOME}/mnist/tfr/test \
+    --export ${TFoS_HOME}/mnist_model/export/serving/* \
+    --output ${TFoS_HOME}/predictions
+
+#### Shutdown the Spark Standalone cluster
+
+    ${SPARK_HOME}/sbin/stop-slave.sh; ${SPARK_HOME}/sbin/stop-master.sh
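
For reference, the REST `predict` call added to the README above can also be issued from Python. This is a minimal sketch, assuming the TF-Serving container from the `docker run` command is listening on `localhost:8501` and that the CSV test file path matches the one used in the README; it is not part of the commit itself:

```python
# Hedged sketch: issue the same POST request as the curl example above.
import requests

# read one test example (784 comma-separated pixel values)
with open("examples/mnist/csv/test/images/part-00000") as f:
    img = [float(x) for x in f.readline().split(",")]

# POST to the TF-Serving REST API, mirroring the curl command
resp = requests.post(
    "http://localhost:8501/v1/models/mnist:predict",
    json={"instances": [{"dense_input": img}]})
print(resp.json())  # ten softmax scores per instance
```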
examples/mnist/keras/mnist_inference.py

Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
+# Copyright 2018 Yahoo Inc.
+# Licensed under the terms of the Apache 2.0 license.
+# Please see LICENSE file in the project root for terms.
+
+# This example demonstrates how to leverage Spark for parallel inferencing from a SavedModel.
+#
+# Normally, you can use TensorFlowOnSpark to just form a TensorFlow cluster for training and inferencing.
+# However, in some situations, you may have a SavedModel without the original code for defining the inferencing
+# graph. In these situations, we can use Spark to instantiate a single-node TensorFlow instance on each executor,
+# where each executor can independently load the model and inference on input data.
+#
+# Note: this particular example demonstrates use of `tf.data.Dataset` to read the input data for inferencing,
+# but it could also be adapted to just use an RDD of TFRecords from Spark.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import numpy as np
+import tensorflow as tf
+
+IMAGE_PIXELS = 28
+
+
+def inference(it, num_workers, args):
+  from tensorflowonspark import util
+
+  # consume worker number from RDD partition iterator
+  for i in it:
+    worker_num = i
+  print("worker_num: {}".format(i))
+
+  # setup env for single-node TF
+  util.single_node_env()
+
+  # load saved_model using default tag and signature
+  sess = tf.Session()
+  tf.saved_model.loader.load(sess, ['serve'], args.export)
+
+  # parse function for TFRecords
+  def parse_tfr(example_proto):
+    feature_def = {"label": tf.FixedLenFeature(10, tf.int64),
+                   "image": tf.FixedLenFeature(IMAGE_PIXELS * IMAGE_PIXELS, tf.int64)}
+    features = tf.parse_single_example(example_proto, feature_def)
+    norm = tf.constant(255, dtype=tf.float32, shape=(784,))
+    image = tf.div(tf.to_float(features['image']), norm)
+    label = tf.to_float(features['label'])
+    return (image, label)
+
+  # define a new tf.data.Dataset (for inferencing)
+  ds = tf.data.Dataset.list_files("{}/part-*".format(args.images_labels))
+  ds = ds.shard(num_workers, worker_num)
+  ds = ds.interleave(tf.data.TFRecordDataset, cycle_length=1)
+  ds = ds.map(parse_tfr).batch(10)
+  iterator = ds.make_one_shot_iterator()
+  image_label = iterator.get_next(name='inf_image')
+
+  # create an output file per spark worker for the predictions
+  tf.gfile.MakeDirs(args.output)
+  output_file = tf.gfile.GFile("{}/part-{:05d}".format(args.output, worker_num), mode='w')
+
+  while True:
+    try:
+      # get images and labels from tf.data.Dataset
+      img, lbl = sess.run(['inf_image:0', 'inf_image:1'])
+
+      # inference by feeding these images and labels into the input tensors
+      # you can view the exported model signatures via:
+      # saved_model_cli show --dir <export_dir> --all
+
+      # note that we feed directly into the graph tensors (bypassing the exported signatures)
+      # these tensors will be shown in the "name" field of the signature definitions
+
+      outputs = sess.run(['dense_2/Softmax:0'], feed_dict={'Placeholder:0': img})
+      for p in outputs[0]:
+        output_file.write("{}\n".format(np.argmax(p)))
+    except tf.errors.OutOfRangeError:
+      break
+
+  output_file.close()
+
+
+if __name__ == '__main__':
+  from pyspark.context import SparkContext
+  from pyspark.conf import SparkConf
+
+  sc = SparkContext(conf=SparkConf().setAppName("mnist_inference"))
+  executors = sc._conf.get("spark.executor.instances")
+  num_executors = int(executors) if executors is not None else 1
+
+  parser = argparse.ArgumentParser()
+  parser.add_argument("--cluster_size", help="number of nodes in the cluster (for Spark Standalone)", type=int, default=num_executors)
+  parser.add_argument('--images_labels', type=str, help='Directory for input images with labels')
+  parser.add_argument("--export", help="HDFS path to export model", type=str, default="mnist_export")
+  parser.add_argument("--output", help="HDFS path to save predictions", type=str, default="predictions")
+  args, _ = parser.parse_known_args()
+  print("args: {}".format(args))
+
+  # Not using TFCluster... just running single-node TF instances on each executor
+  nodes = list(range(args.cluster_size))
+  nodeRDD = sc.parallelize(list(range(args.cluster_size)), args.cluster_size)
+  nodeRDD.foreachPartition(lambda worker_num: inference(worker_num, args.cluster_size, args))
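
The `nodeRDD.foreachPartition(...)` pattern in the new file relies on `sc.parallelize(range(n), n)` producing `n` partitions with exactly one worker number each. The following is a minimal local sketch of that mapping, not code from this commit (names like `show_partition` and the `local[3]` master are illustrative assumptions):

```python
# Hedged sketch: each partition's iterator yields exactly one worker number,
# so each executor can independently load the model for its shard.
from pyspark import SparkConf, SparkContext

sc = SparkContext(conf=SparkConf().setMaster("local[3]").setAppName("partition_demo"))

def show_partition(it):
    # the iterator yields the single worker number assigned to this partition
    for worker_num in it:
        print("this partition received worker_num: {}".format(worker_num))

num_workers = 3
sc.parallelize(list(range(num_workers)), num_workers).foreachPartition(show_partition)
sc.stop()
```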

examples/mnist/keras/mnist_mlp_estimator.py

Lines changed: 4 additions & 12 deletions
@@ -107,7 +107,7 @@ def train_input_fn():
 
   # WORKAROUND FOR https://github.com/tensorflow/tensorflow/issues/21745
   # wait for all other nodes to complete (via done files)
-  done_dir = "{}/{}/done".format(ctx.absolute_path(args.model_dir), args.mode)
+  done_dir = "{}/done".format(ctx.absolute_path(args.model_dir))
   print("Writing done file to: {}".format(done_dir))
   tf.gfile.MakeDirs(done_dir)
   with tf.gfile.GFile("{}/{}".format(done_dir, ctx.task_index), 'w') as done_file:
@@ -157,14 +157,6 @@ def train_input_fn():
   images = sc.textFile(args.images).map(lambda ln: [float(x) for x in ln.split(',')])
   labels = sc.textFile(args.labels).map(lambda ln: [float(x) for x in ln.split(',')])
   dataRDD = images.zip(labels)
-  if args.mode == 'train':
-    cluster = TFCluster.run(sc, main_fun, args, args.cluster_size, args.num_ps, args.tensorboard, TFCluster.InputMode.SPARK, log_dir=args.model_dir, master_node='master')
-    cluster.train(dataRDD, args.epochs)
-    cluster.shutdown()
-  else:
-    # Note: using "parallel" inferencing, not "cluster"
-    # each node loads the model and runs independently of others
-    cluster = TFCluster.run(sc, main_fun, args, args.cluster_size, 0, args.tensorboard, TFCluster.InputMode.SPARK, log_dir=args.model_dir)
-    resultRDD = cluster.inference(dataRDD)
-    resultRDD.saveAsTextFile(args.output)
-    cluster.shutdown()
+  cluster = TFCluster.run(sc, main_fun, args, args.cluster_size, args.num_ps, args.tensorboard, TFCluster.InputMode.SPARK, log_dir=args.model_dir, master_node='master')
+  cluster.train(dataRDD, args.epochs)
+  cluster.shutdown()
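
The `done_dir` change above touches the done-file barrier used to work around tensorflow/tensorflow#21745: each node writes a done file, then waits until all peers have done the same. As a rough sketch of that pattern (the `wait_for_done_files` helper and its arguments are assumptions for illustration, not code from this commit):

```python
# Hedged sketch of the done-file barrier pattern (TF 1.x APIs, as in this repo).
import time
import tensorflow as tf

def wait_for_done_files(done_dir, num_workers, poll_secs=1):
  # hypothetical helper: block until every worker has written its done file
  while len(tf.gfile.ListDirectory(done_dir)) < num_workers:
    print("waiting for done files in {}".format(done_dir))
    time.sleep(poll_secs)
```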

examples/mnist/tf/mnist_inference.py

Lines changed: 4 additions & 7 deletions
@@ -17,14 +17,11 @@
 from __future__ import print_function
 
 import argparse
-import logging
-import sys
 import tensorflow as tf
-import time
-import traceback
 
 IMAGE_PIXELS = 28
 
+
 def inference(it, num_workers, args):
   from tensorflowonspark import util
 
@@ -69,9 +66,10 @@ def parse_tfr(example_proto):
 
       # inference by feeding these images and labels into the input tensors
       # you can view the exported model signatures via:
-      # saved_model_cli show --dir mnist_export --all
+      # saved_model_cli show --dir <saved_model> --all
 
       # note that we feed directly into the graph tensors (bypassing the exported signatures)
+      # these tensors will be shown in the "name" field of the signature definitions
       # also note that we can feed/fetch tensors that were not explicitly exported, e.g. `y_` and `label:0`
 
       labels, preds = sess.run(['label:0', 'prediction:0'], feed_dict={'x:0': img, 'y_:0': lbl})
@@ -82,8 +80,8 @@ def parse_tfr(example_proto):
 
   output_file.close()
 
+
 if __name__ == '__main__':
-  import os
   from pyspark.context import SparkContext
   from pyspark.conf import SparkConf
 
@@ -103,4 +101,3 @@ def parse_tfr(example_proto):
   nodes = list(range(args.cluster_size))
   nodeRDD = sc.parallelize(list(range(args.cluster_size)), args.cluster_size)
   nodeRDD.foreachPartition(lambda worker_num: inference(worker_num, args.cluster_size, args))
-
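
The updated comments above note that the underlying graph tensor names appear in the "name" field of the signature definitions. A minimal sketch, assuming a SavedModel export directory `export_dir` (the path is a placeholder, not from this commit), of listing those names programmatically with the TF 1.x loader, which returns the MetaGraphDef:

```python
# Hedged sketch: print the graph tensor names recorded in the SavedModel
# signatures (the "name" fields referenced in the comments above).
import tensorflow as tf

export_dir = "mnist_model/export/serving/1234567890"  # assumed path; adjust to your export

with tf.Session(graph=tf.Graph()) as sess:
  meta_graph = tf.saved_model.loader.load(sess, ['serve'], export_dir)
  sig = meta_graph.signature_def['serving_default']
  for key, info in sig.inputs.items():
    print("input '{}' -> tensor '{}'".format(key, info.name))
  for key, info in sig.outputs.items():
    print("output '{}' -> tensor '{}'".format(key, info.name))
```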

0 commit comments