# Copyright 2018 Yahoo Inc.
# Licensed under the terms of the Apache 2.0 license.
# Please see LICENSE file in the project root for terms.

# This example demonstrates how to leverage Spark for parallel inferencing from a SavedModel.
#
# Normally, you would use TensorFlowOnSpark to form a TensorFlow cluster for training and inferencing.
# However, you may sometimes have a SavedModel without the original code that defined the inferencing
# graph. In that case, we can use Spark to instantiate a single-node TensorFlow instance on each executor,
# where each executor independently loads the model and runs inference on its shard of the input data.
#
# Note: this particular example uses `tf.data.Dataset` to read the input data for inferencing,
# but it could also be adapted to use an RDD of TFRecords from Spark (see the commented sketch below).
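#
# A minimal sketch of that RDD-based variant (commented out, not used by this example). It assumes
# the tensorflow-hadoop connector's TFRecordFileInputFormat is available on the classpath, and
# `run_inference` is a hypothetical helper that feeds the serialized records into the loaded session:
#
#   tfr_rdd = sc.newAPIHadoopFile(args.images_labels,
#                                 "org.tensorflow.hadoop.io.TFRecordFileInputFormat",
#                                 keyClass="org.apache.hadoop.io.BytesWritable",
#                                 valueClass="org.apache.hadoop.io.NullWritable")
#   predictions = tfr_rdd.mapPartitions(lambda recs: run_inference(bytes(r[0]) for r in recs))
#   predictions.saveAsTextFile(args.output)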

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import numpy as np
import tensorflow as tf

IMAGE_PIXELS = 28


def inference(it, num_workers, args):
  from tensorflowonspark import util

  # consume worker number from RDD partition iterator
  for i in it:
    worker_num = i
  print("worker_num: {}".format(worker_num))

  # setup env for single-node TF
  util.single_node_env()

  # load saved_model using default tag and signature
  sess = tf.Session()
  tf.saved_model.loader.load(sess, ['serve'], args.export)

  # parse function for TFRecords
  def parse_tfr(example_proto):
    feature_def = {"label": tf.FixedLenFeature(10, tf.int64),
                   "image": tf.FixedLenFeature(IMAGE_PIXELS * IMAGE_PIXELS, tf.int64)}
    features = tf.parse_single_example(example_proto, feature_def)
    norm = tf.constant(255, dtype=tf.float32, shape=(IMAGE_PIXELS * IMAGE_PIXELS,))
    image = tf.div(tf.to_float(features['image']), norm)
    label = tf.to_float(features['label'])
    return (image, label)

  # define a new tf.data.Dataset (for inferencing), sharded across the workers
  ds = tf.data.Dataset.list_files("{}/part-*".format(args.images_labels))
  ds = ds.shard(num_workers, worker_num)
  ds = ds.interleave(tf.data.TFRecordDataset, cycle_length=1)
  ds = ds.map(parse_tfr).batch(10)
  iterator = ds.make_one_shot_iterator()
  # the op name 'inf_image' lets us fetch the image/label tensors by name below
  image_label = iterator.get_next(name='inf_image')

  # create an output file per spark worker for the predictions
  tf.gfile.MakeDirs(args.output)
  output_file = tf.gfile.GFile("{}/part-{:05d}".format(args.output, worker_num), mode='w')

  while True:
    try:
      # get images and labels from tf.data.Dataset
      img, lbl = sess.run(['inf_image:0', 'inf_image:1'])

      # inference by feeding these images and labels into the input tensors
      # you can view the exported model signatures via:
      #   saved_model_cli show --dir <export_dir> --all

      # note that we feed directly into the graph tensors (bypassing the exported signatures)
      # these tensors will be shown in the "name" field of the signature definitions
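      # (sketch, not executed here: if you prefer not to hard-code the tensor names, they can be
      # looked up from the MetaGraphDef returned by the loader; the signature/key names
      # 'serving_default', 'image' and 'prediction' below are assumptions and must match your export)
      #   meta_graph_def = tf.saved_model.loader.load(sess, ['serve'], args.export)
      #   signature = meta_graph_def.signature_def['serving_default']
      #   input_tensor_name = signature.inputs['image'].name
      #   output_tensor_name = signature.outputs['prediction'].name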

      outputs = sess.run(['dense_2/Softmax:0'], feed_dict={'Placeholder:0': img})
      for p in outputs[0]:
        output_file.write("{}\n".format(np.argmax(p)))
    except tf.errors.OutOfRangeError:
      break

  output_file.close()


if __name__ == '__main__':
  from pyspark.context import SparkContext
  from pyspark.conf import SparkConf

  sc = SparkContext(conf=SparkConf().setAppName("mnist_inference"))
  executors = sc._conf.get("spark.executor.instances")
  num_executors = int(executors) if executors is not None else 1

  parser = argparse.ArgumentParser()
  parser.add_argument("--cluster_size", help="number of nodes in the cluster (for Spark Standalone)", type=int, default=num_executors)
  parser.add_argument("--images_labels", help="directory of input TFRecord files containing images with labels", type=str)
  parser.add_argument("--export", help="HDFS path to the exported saved_model", type=str, default="mnist_export")
  parser.add_argument("--output", help="HDFS path to save predictions", type=str, default="predictions")
  args, _ = parser.parse_known_args()
  print("args: {}".format(args))

  # Not using TFCluster... just running single-node TF instances on each executor
  nodes = list(range(args.cluster_size))
  nodeRDD = sc.parallelize(nodes, args.cluster_size)
  nodeRDD.foreachPartition(lambda it: inference(it, args.cluster_size, args))
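

# Example submission (a sketch; the master, deploy mode, executor count and data paths below are
# placeholders that depend on your cluster and on where your MNIST TFRecords and export live):
#
#   spark-submit --master yarn --deploy-mode client --num-executors 4 \
#       mnist_inference.py \
#       --cluster_size 4 \
#       --images_labels <path/to/mnist/tfr/test> \
#       --export mnist_export \
#       --output predictions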