@@ -39,46 +39,43 @@ def main_fun(args, ctx):
3939 estimator = tf .keras .estimator .model_to_estimator (model , model_dir = args .model_dir )
4040
4141 # setup train_input_fn for InputMode.TENSORFLOW or InputMode.SPARK
42- if args .input_mode == 'tf' :
43- train_input_fn = tf .estimator .inputs .numpy_input_fn (
44- x = {"dense_1_input" : x_train },
45- y = y_train ,
46- batch_size = 128 ,
47- num_epochs = None ,
48- shuffle = True )
49- else : # 'spark'
50- tf_feed = TFNode .DataFeed (ctx .mgr )
51-
52- def rdd_generator ():
53- while not tf_feed .should_stop ():
54- batch = tf_feed .next_batch (1 )
55- if len (batch ) > 0 :
56- record = batch [0 ]
57- image = numpy .array (record [0 ]).astype (numpy .float32 ) / 255.0
58- label = numpy .array (record [1 ]).astype (numpy .float32 )
59- yield (image , label )
60-
61- def train_input_fn ():
62- ds = tf .data .Dataset .from_generator (rdd_generator ,
63- (tf .float32 , tf .float32 ),
64- (tf .TensorShape ([IMAGE_PIXELS * IMAGE_PIXELS ]), tf .TensorShape ([10 ])))
65- ds = ds .batch (args .batch_size )
66- return ds
67-
68- # eval_input_fn ALWAYS uses data loaded in memory, since InputMode.SPARK can only feed one RDD at a time
69- eval_input_fn = tf .estimator .inputs .numpy_input_fn (
70- x = {"dense_1_input" : x_test },
71- y = y_test ,
72- num_epochs = args .epochs ,
73- shuffle = False )
74-
75- # setup tf.estimator.train_and_evaluate()
76- train_spec = tf .estimator .TrainSpec (input_fn = train_input_fn , max_steps = args .steps )
77- eval_spec = tf .estimator .EvalSpec (input_fn = eval_input_fn )
78- tf .estimator .train_and_evaluate (estimator , train_spec , eval_spec )
79-
80- # export a saved_model, if export_dir provided
81- if args .export_dir :
42+ if args .mode == 'train' :
43+ if args .input_mode == 'tf' :
44+ # For InputMode.TENSORFLOW, just use data in memory
45+ train_input_fn = tf .estimator .inputs .numpy_input_fn (
46+ x = {"dense_1_input" : x_train },
47+ y = y_train ,
48+ batch_size = 128 ,
49+ num_epochs = None ,
50+ shuffle = True )
51+ else : # 'spark'
52+ # For InputMode.SPARK, read data from RDD
53+ tf_feed = TFNode .DataFeed (ctx .mgr )
54+
55+ def rdd_generator ():
56+ while not tf_feed .should_stop ():
57+ batch = tf_feed .next_batch (1 )
58+ if len (batch ) > 0 :
59+ record = batch [0 ]
60+ image = numpy .array (record [0 ]).astype (numpy .float32 ) / 255.0
61+ label = numpy .array (record [1 ]).astype (numpy .float32 )
62+ yield (image , label )
63+
64+ def train_input_fn ():
65+ ds = tf .data .Dataset .from_generator (rdd_generator ,
66+ (tf .float32 , tf .float32 ),
67+ (tf .TensorShape ([IMAGE_PIXELS * IMAGE_PIXELS ]), tf .TensorShape ([10 ])))
68+ ds = ds .batch (args .batch_size )
69+ return ds
70+
71+ # eval_input_fn ALWAYS uses data loaded in memory, since InputMode.SPARK can only feed one RDD at a time
72+ eval_input_fn = tf .estimator .inputs .numpy_input_fn (
73+ x = {"dense_1_input" : x_test },
74+ y = y_test ,
75+ num_epochs = args .epochs ,
76+ shuffle = False )
77+
78+ # serving_input_receiver_fn ALWAYS expects serialized TFExamples in a placeholder.
8279 def serving_input_receiver_fn ():
8380 """An input receiver that expects a serialized tf.Example."""
8481 serialized_tf_example = tf .placeholder (dtype = tf .string ,
@@ -89,7 +86,35 @@ def serving_input_receiver_fn():
8986 features = tf .parse_example (serialized_tf_example , feature_spec )
9087 return tf .estimator .export .ServingInputReceiver (features , receiver_tensors )
9188
92- estimator .export_savedmodel (args .export_dir , serving_input_receiver_fn )
89+ # set up tf.estimator.train_and_evaluate() with a FinalExporter
90+ exporter = tf .estimator .FinalExporter ("serving" , serving_input_receiver_fn = serving_input_receiver_fn )
91+ train_spec = tf .estimator .TrainSpec (input_fn = train_input_fn , max_steps = args .steps )
92+ eval_spec = tf .estimator .EvalSpec (input_fn = eval_input_fn , exporters = exporter )
93+ tf .estimator .train_and_evaluate (estimator , train_spec , eval_spec )
94+
95+ else : # mode == 'inference'
96+ if args .input_mode == 'spark' :
97+ tf_feed = TFNode .DataFeed (ctx .mgr )
98+
99+ def rdd_generator ():
100+ while not tf_feed .should_stop ():
101+ batch = tf_feed .next_batch (1 )
102+ if len (batch ) > 0 :
103+ record = batch [0 ]
104+ image = numpy .array (record [0 ]).astype (numpy .float32 ) / 255.0
105+ label = numpy .array (record [1 ]).astype (numpy .float32 )
106+ yield (image , label )
107+
108+ def predict_input_fn ():
109+ ds = tf .data .Dataset .from_generator (rdd_generator ,
110+ (tf .float32 , tf .float32 ),
111+ (tf .TensorShape ([IMAGE_PIXELS * IMAGE_PIXELS ]), tf .TensorShape ([10 ])))
112+ ds = ds .batch (args .batch_size )
113+ return ds
114+
115+ predictions = estimator .predict (predict_input_fn )
116+ for result in predictions :
117+ tf_feed .batch_results ([result ])
93118
94119
95120if __name__ == '__main__' :
@@ -112,6 +137,8 @@ def serving_input_receiver_fn():
112137 parser .add_argument ("--input_mode" , help = "input mode (tf|spark)" , default = "tf" )
113138 parser .add_argument ("--labels" , help = "HDFS path to MNIST labels in parallelized CSV format" )
114139 parser .add_argument ("--model_dir" , help = "directory to write model checkpoints" )
140+ parser .add_argument ("--mode" , help = "(train|inference)" )
141+ parser .add_argument ("--output" , help = "HDFS path to save test/inference output" , default = "predictions" )
115142 parser .add_argument ("--num_ps" , help = "number of ps nodes" , type = int , default = 1 )
116143 parser .add_argument ("--steps" , help = "max number of steps to train" , type = int , default = 2000 )
117144 parser .add_argument ("--tensorboard" , help = "launch tensorboard process" , action = "store_true" )
@@ -120,14 +147,22 @@ def serving_input_receiver_fn():
120147 print ("args:" , args )
121148
122149 if args .input_mode == 'tf' :
123- # for TENSORFLOW mode, each node will load/train entire dataset in memory per original example
150+ # for TENSORFLOW mode, each node loads the entire dataset into memory and trains or infers on it, per the original example
124151 cluster = TFCluster .run (sc , main_fun , args , args .cluster_size , args .num_ps , args .tensorboard , TFCluster .InputMode .TENSORFLOW , log_dir = args .model_dir , master_node = 'master' )
125152 cluster .shutdown ()
126153 else : # 'spark'
127154 # for SPARK mode, just use CSV format as an example
128155 images = sc .textFile (args .images ).map (lambda ln : [float (x ) for x in ln .split (',' )])
129156 labels = sc .textFile (args .labels ).map (lambda ln : [float (x ) for x in ln .split (',' )])
130157 dataRDD = images .zip (labels )
131- cluster = TFCluster .run (sc , main_fun , args , args .cluster_size , args .num_ps , args .tensorboard , TFCluster .InputMode .SPARK , log_dir = args .model_dir , master_node = 'master' )
132- cluster .train (dataRDD , args .epochs )
133- cluster .shutdown ()
158+ if args .mode == 'train' :
159+ cluster = TFCluster .run (sc , main_fun , args , args .cluster_size , args .num_ps , args .tensorboard , TFCluster .InputMode .SPARK , log_dir = args .model_dir , master_node = 'master' )
160+ cluster .train (dataRDD , args .epochs )
161+ cluster .shutdown ()
162+ else :
163+ # Note: using "parallel" inferencing, not "cluster"
164+ # each node loads the model and runs independently of others
165+ cluster = TFCluster .run (sc , main_fun , args , args .cluster_size , 0 , args .tensorboard , TFCluster .InputMode .SPARK , log_dir = args .model_dir )
166+ resultRDD = cluster .inference (dataRDD )
167+ resultRDD .saveAsTextFile (args .output )
168+ cluster .shutdown ()
0 commit comments