99from __future__ import nested_scopes
1010from __future__ import print_function
1111
12+ from datetime import datetime
13+ import tensorflow as tf
14+ from tensorflowonspark import TFNode
15+
1216
def print_log(worker_num, arg):
    """Print a log message to stdout, prefixed with the worker number.

    Args:
        worker_num: identifier of the worker emitting the message.
        arg: value to log; it is interpolated with ``str.format``.
    """
    print("{0}: {1}".format(worker_num, arg))
1519
1620
class ExportHook(tf.train.SessionRunHook):
    """Session hook that exports a saved model from its ``end`` callback.

    The export uses a single default-serving signature that maps an
    ``'image'`` input to a ``'prediction'`` output with the predict method.
    """

    def __init__(self, export_dir, input_tensor, output_tensor):
        """Store the export destination and the signature tensors.

        Args:
            export_dir: directory the saved model is written to.
            input_tensor: tensor published as the ``'image'`` input.
            output_tensor: tensor published as the ``'prediction'`` output.
        """
        self.export_dir = export_dir
        self.input_tensor = input_tensor
        self.output_tensor = output_tensor

    def end(self, session):
        """Export the model held by ``session`` to ``self.export_dir``.

        Prints a timestamped message before and after the export.

        Args:
            session: the session handed to the hook; it is forwarded
                unchanged to ``TFNode.export_saved_model``.
        """
        print("{} ======= Exporting to: {}".format(datetime.now().isoformat(), self.export_dir))
        # Signature description in the dict layout passed to
        # TFNode.export_saved_model: one entry keyed by the default
        # serving signature name.
        signatures = {
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: {
                'inputs': {'image': self.input_tensor},
                'outputs': {'prediction': self.output_tensor},
                'method_name': tf.saved_model.signature_constants.PREDICT_METHOD_NAME
            }
        }
        TFNode.export_saved_model(session,
                                  self.export_dir,
                                  tf.saved_model.tag_constants.SERVING,
                                  signatures)
        print("{} ======= Done exporting".format(datetime.now().isoformat()))
41+
42+
1743def map_fun (args , ctx ):
18- from datetime import datetime
1944 import math
2045 import numpy
21- import tensorflow as tf
2246 import time
2347
2448 worker_num = ctx .worker_num
@@ -105,7 +129,6 @@ def feed_dict(batch):
105129
106130 logdir = ctx .absolute_path (args .model )
107131 print ("tensorflow model path: {0}" .format (logdir ))
108- hooks = [tf .train .StopAtStepHook (last_step = 100000 )]
109132
110133 if job_name == "worker" and task_index == 0 :
111134 summary_writer = tf .summary .FileWriter (logdir , graph = tf .get_default_graph ())
@@ -115,11 +138,11 @@ def feed_dict(batch):
115138 with tf .train .MonitoredTrainingSession (master = server .target ,
116139 is_chief = (task_index == 0 ),
117140 checkpoint_dir = logdir ,
118- hooks = hooks ) as mon_sess :
119-
141+ hooks = [ tf . train . StopAtStepHook ( last_step = args . steps )],
142+ chief_only_hooks = [ ExportHook ( ctx . absolute_path ( args . export_dir ), x , prediction )]) as mon_sess :
120143 step = 0
121144 tf_feed = ctx .get_data_feed (args .mode == "train" )
122- while not mon_sess .should_stop () and not tf_feed .should_stop () and step < args . steps :
145+ while not mon_sess .should_stop () and not tf_feed .should_stop ():
123146 # Run a training step asynchronously
124147 # See `tf.train.SyncReplicasOptimizer` for additional details on how to
125148 # perform *synchronous* training.
0 commit comments