@@ -16,6 +16,7 @@ def print_log(worker_num, arg):
1616
1717def map_fun (args , ctx ):
1818 from datetime import datetime
19+ from tensorflowonspark import TFNode
1920 import math
2021 import os
2122 import tensorflow as tf
@@ -54,6 +55,27 @@ def _parse_tfr(example_proto):
5455 label = tf .to_float (features ['label' ])
5556 return (image , label )
5657
def build_model(graph, x):
    """Add the core network to `graph` and return its output tensors.

    The network is a single ReLU hidden layer followed by a 10-way softmax
    layer. `IMAGE_PIXELS`, `hidden_units`, `math` and `tf` are read from the
    enclosing scope.

    Args:
      graph: graph in which the ops are created (entered via
        `graph.as_default()`).
      x: input tensor fed to the hidden layer; it is multiplied by a
        `[IMAGE_PIXELS * IMAGE_PIXELS, hidden_units]` weight matrix.

    Returns:
      A `(y, prediction)` tuple: the softmax output and its argmax along
      axis 1 (the latter op is named "prediction").
    """
    with graph.as_default():
        # Hidden layer parameters (weights are histogrammed for TensorBoard).
        hid_w_init = tf.truncated_normal(
            [IMAGE_PIXELS * IMAGE_PIXELS, hidden_units],
            stddev=1.0 / IMAGE_PIXELS)
        hid_w = tf.Variable(hid_w_init, name="hid_w")
        hid_b = tf.Variable(tf.zeros([hidden_units]), name="hid_b")
        tf.summary.histogram("hidden_weights", hid_w)

        # Softmax layer parameters.
        sm_w_init = tf.truncated_normal(
            [hidden_units, 10],
            stddev=1.0 / math.sqrt(hidden_units))
        sm_w = tf.Variable(sm_w_init, name="sm_w")
        sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
        tf.summary.histogram("softmax_weights", sm_w)

        # Forward pass: x -> ReLU(x * hid_w + hid_b) -> softmax.
        hid = tf.nn.relu(tf.nn.xw_plus_b(x, hid_w, hid_b))
        logits = tf.nn.xw_plus_b(hid, sm_w, sm_b)
        y = tf.nn.softmax(logits)

        # The explicit op name lets callers look the prediction up by name.
        prediction = tf.argmax(y, 1, name="prediction")
        return y, prediction
78+
5779 if job_name == "ps" :
5880 server .join ()
5981 elif job_name == "worker" :
@@ -78,36 +100,21 @@ def _parse_tfr(example_proto):
78100 iterator = ds .make_one_shot_iterator ()
79101 x , y_ = iterator .get_next ()
80102
81- # Variables of the hidden layer
82- hid_w = tf .Variable (tf .truncated_normal ([IMAGE_PIXELS * IMAGE_PIXELS , hidden_units ],
83- stddev = 1.0 / IMAGE_PIXELS ), name = "hid_w" )
84- hid_b = tf .Variable (tf .zeros ([hidden_units ]), name = "hid_b" )
85- tf .summary .histogram ("hidden_weights" , hid_w )
86-
87- # Variables of the softmax layer
88- sm_w = tf .Variable (tf .truncated_normal ([hidden_units , 10 ],
89- stddev = 1.0 / math .sqrt (hidden_units )), name = "sm_w" )
90- sm_b = tf .Variable (tf .zeros ([10 ]), name = "sm_b" )
91- tf .summary .histogram ("softmax_weights" , sm_w )
103+ # Build core model
104+ y , prediction = build_model (tf .get_default_graph (), x )
92105
106+ # Add training bits
93107 x_img = tf .reshape (x , [- 1 , IMAGE_PIXELS , IMAGE_PIXELS , 1 ])
94108 tf .summary .image ("x_img" , x_img )
95109
96- hid_lin = tf .nn .xw_plus_b (x , hid_w , hid_b )
97- hid = tf .nn .relu (hid_lin )
98-
99- y = tf .nn .softmax (tf .nn .xw_plus_b (hid , sm_w , sm_b ))
100-
101110 global_step = tf .train .get_or_create_global_step ()
102111
103112 loss = - tf .reduce_sum (y_ * tf .log (tf .clip_by_value (y , 1e-10 , 1.0 )))
104113 tf .summary .scalar ("loss" , loss )
105114 train_op = tf .train .AdagradOptimizer (0.01 ).minimize (
106115 loss , global_step = global_step )
107116
108- # Test trained model
109117 label = tf .argmax (y_ , 1 , name = "label" )
110- prediction = tf .argmax (y , 1 , name = "prediction" )
111118 correct_prediction = tf .equal (prediction , label )
112119 accuracy = tf .reduce_mean (tf .cast (correct_prediction , tf .float32 ), name = "accuracy" )
113120 tf .summary .scalar ("acc" , accuracy )
@@ -117,8 +124,10 @@ def _parse_tfr(example_proto):
117124 init_op = tf .global_variables_initializer ()
118125
119126 # Create a "supervisor", which oversees the training process and stores model state into HDFS
120- logdir = ctx .absolute_path (args .model )
121- print ("tensorflow model path: {0}" .format (logdir ))
127+ model_dir = ctx .absolute_path (args .model )
128+ export_dir = ctx .absolute_path (args .export )
129+ print ("tensorflow model path: {0}" .format (model_dir ))
130+ print ("tensorflow export path: {0}" .format (export_dir ))
122131 summary_writer = tf .summary .FileWriter ("tensorboard_%d" % worker_num , graph = tf .get_default_graph ())
123132
124133 if args .mode == 'inference' :
@@ -130,7 +139,7 @@ def _parse_tfr(example_proto):
130139 with tf .train .MonitoredTrainingSession (master = server .target ,
131140 is_chief = (task_index == 0 ),
132141 scaffold = tf .train .Scaffold (init_op = init_op , summary_op = summary_op , saver = saver ),
133- checkpoint_dir = logdir ,
142+ checkpoint_dir = model_dir ,
134143 hooks = [tf .train .StopAtStepHook (last_step = args .steps )]) as sess :
135144 print ("{} session ready" .format (datetime .now ().isoformat ()))
136145
@@ -163,6 +172,41 @@ def _parse_tfr(example_proto):
163172
164173 print ("{} stopping MonitoredTrainingSession" .format (datetime .now ().isoformat ()))
165174
# Export the trained model as a saved_model (on chief worker only).
# A fresh inference graph is built with placeholder inputs, the trained
# weights are restored from the latest checkpoint in `model_dir`, and the
# result is written to `export_dir`.
if args.mode == "train" and task_index == 0:
    # Discard the training graph so the exported graph holds only the
    # placeholders and the core model.
    tf.reset_default_graph()

    # add placeholders for input images (and optional labels)
    x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS], name='x')
    y_ = tf.placeholder(tf.float32, [None, 10], name='y_')
    label = tf.argmax(y_, 1, name="label")

    # add core model
    y, prediction = build_model(tf.get_default_graph(), x)

    # restore from last checkpoint
    saver = tf.train.Saver()
    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state(model_dir)
        print("ckpt: {}".format(ckpt))
        # Raise explicitly rather than `assert`: assertions are stripped under
        # `python -O`, which would turn a missing checkpoint into an opaque
        # AttributeError on the restore call below.
        if not ckpt or not ckpt.model_checkpoint_path:
            raise ValueError("Invalid model checkpoint path: {}".format(model_dir))
        saver.restore(sess, ckpt.model_checkpoint_path)

        print("Exporting saved_model to: {}".format(export_dir))
        # exported signatures defined in code
        signatures = {
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: {
                'inputs': {'image': x},
                'outputs': {'prediction': prediction},
                'method_name': tf.saved_model.signature_constants.PREDICT_METHOD_NAME
            }
        }
        # The export must run while the session holding the restored
        # variables is still open.
        TFNode.export_saved_model(sess,
                                  export_dir,
                                  tf.saved_model.tag_constants.SERVING,
                                  signatures)
        print("Exported saved_model")
209+
166210 # WORKAROUND for https://github.com/tensorflow/tensorflow/issues/21745
167211 # wait for all other nodes to complete (via done files)
168212 done_dir = "{}/{}/done" .format (ctx .absolute_path (args .model ), args .mode )
0 commit comments