2121import tensorflow as tf
2222
2323
24- def inference (it , num_workers , args ):
25- from tensorflowonspark import util
26-
27- # consume worker number from RDD partition iterator
28- for i in it :
29- worker_num = i
30- print ("worker_num: {}" .format (i ))
31-
32- # setup env for single-node TF
33- util .single_node_env ()
24+ def inference (args , ctx ):
3425
3526 # load saved_model
3627 saved_model = tf .saved_model .load (args .export_dir , tags = 'serve' )
@@ -48,14 +39,14 @@ def parse_tfr(example_proto):
4839
4940 # define a new tf.data.Dataset (for inferencing)
5041 ds = tf .data .Dataset .list_files ("{}/part-*" .format (args .images_labels ))
51- ds = ds .shard (num_workers , worker_num )
42+ ds = ds .shard (ctx . num_workers , ctx . worker_num )
5243 ds = ds .interleave (tf .data .TFRecordDataset )
5344 ds = ds .map (parse_tfr )
5445 ds = ds .batch (10 )
5546
5647 # create an output file per spark worker for the predictions
5748 tf .io .gfile .makedirs (args .output )
58- output_file = tf .io .gfile .GFile ("{}/part-{:05d}" .format (args .output , worker_num ), mode = 'w' )
49+ output_file = tf .io .gfile .GFile ("{}/part-{:05d}" .format (args .output , ctx . worker_num ), mode = 'w' )
5950
6051 for batch in ds :
6152 predictions = predict (conv2d_input = batch [0 ])
@@ -70,6 +61,7 @@ def parse_tfr(example_proto):
7061if __name__ == '__main__' :
7162 from pyspark .context import SparkContext
7263 from pyspark .conf import SparkConf
64+ from tensorflowonspark import TFParallel
7365
7466 sc = SparkContext (conf = SparkConf ().setAppName ("mnist_inference" ))
7567 executors = sc ._conf .get ("spark.executor.instances" )
@@ -83,7 +75,5 @@ def parse_tfr(example_proto):
8375 args , _ = parser .parse_known_args ()
8476 print ("args: {}" .format (args ))
8577
86- # Not using TFCluster... just running single-node TF instances on each executor
87- nodes = list (range (args .cluster_size ))
88- nodeRDD = sc .parallelize (list (range (args .cluster_size )), args .cluster_size )
89- nodeRDD .foreachPartition (lambda worker_num : inference (worker_num , args .cluster_size , args ))
78+ # Running single-node TF instances on each executor
79+ TFParallel .run (sc , inference , args , args .cluster_size )
0 commit comments