-# Copyright 2017 Yahoo Inc.
+# Copyright 2018 Yahoo Inc.
 # Licensed under the terms of the Apache 2.0 license.
 # Please see LICENSE file in the project root for terms.
 
@@ -58,37 +58,38 @@ def feed_dict(batch):
       worker_device="/job:worker/task:%d" % task_index,
       cluster=cluster)):
 
-      # Variables of the hidden layer
-      hid_w = tf.Variable(tf.truncated_normal([IMAGE_PIXELS * IMAGE_PIXELS, hidden_units],
-                          stddev=1.0 / IMAGE_PIXELS), name="hid_w")
-      hid_b = tf.Variable(tf.zeros([hidden_units]), name="hid_b")
-      tf.summary.histogram("hidden_weights", hid_w)
-
-      # Variables of the softmax layer
-      sm_w = tf.Variable(tf.truncated_normal([hidden_units, 10],
-                         stddev=1.0 / math.sqrt(hidden_units)), name="sm_w")
-      sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
-      tf.summary.histogram("softmax_weights", sm_w)
-
-      # Placeholders or QueueRunner/Readers for input data
-      x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS], name="x")
-      y_ = tf.placeholder(tf.float32, [None, 10], name="y_")
-
-      x_img = tf.reshape(x, [-1, IMAGE_PIXELS, IMAGE_PIXELS, 1])
-      tf.summary.image("x_img", x_img)
-
-      hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
-      hid = tf.nn.relu(hid_lin)
-
-      y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
-
-      global_step = tf.Variable(0)
-
-      loss = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
-      tf.summary.scalar("loss", loss)
-
-      train_op = tf.train.AdagradOptimizer(0.01).minimize(
-          loss, global_step=global_step)
+      # Placeholders or QueueRunner/Readers for input data
+      with tf.name_scope('inputs'):
+        x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS], name="x")
+        y_ = tf.placeholder(tf.float32, [None, 10], name="y_")
+
+        x_img = tf.reshape(x, [-1, IMAGE_PIXELS, IMAGE_PIXELS, 1])
+        tf.summary.image("x_img", x_img)
+
+      with tf.name_scope('layer'):
+        # Variables of the hidden layer
+        with tf.name_scope('hidden_layer'):
+          hid_w = tf.Variable(tf.truncated_normal([IMAGE_PIXELS * IMAGE_PIXELS, hidden_units], stddev=1.0 / IMAGE_PIXELS), name="hid_w")
+          hid_b = tf.Variable(tf.zeros([hidden_units]), name="hid_b")
+          tf.summary.histogram("hidden_weights", hid_w)
+          hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
+          hid = tf.nn.relu(hid_lin)
+
+        # Variables of the softmax layer
+        with tf.name_scope('softmax_layer'):
+          sm_w = tf.Variable(tf.truncated_normal([hidden_units, 10], stddev=1.0 / math.sqrt(hidden_units)), name="sm_w")
+          sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
+          tf.summary.histogram("softmax_weights", sm_w)
+          y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
+
+      global_step = tf.train.get_or_create_global_step()
+
+      with tf.name_scope('loss'):
+        loss = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
+        tf.summary.scalar("loss", loss)
+
+      with tf.name_scope('train'):
+        train_op = tf.train.AdagradOptimizer(0.01).minimize(loss, global_step=global_step)
 
       # Test trained model
       label = tf.argmax(y_, 1, name="label")
@@ -98,73 +99,53 @@ def feed_dict(batch):
       accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name="accuracy")
       tf.summary.scalar("acc", accuracy)
 
-      saver = tf.train.Saver()
       summary_op = tf.summary.merge_all()
-      init_op = tf.global_variables_initializer()
 
-    # Create a "supervisor", which oversees the training process and stores model state into HDFS
     logdir = ctx.absolute_path(args.model)
+    # logdir = args.model
     print("tensorflow model path: {0}".format(logdir))
-
+    hooks = [tf.train.StopAtStepHook(last_step=100000)]
+
     if job_name == "worker" and task_index == 0:
       summary_writer = tf.summary.FileWriter(logdir, graph=tf.get_default_graph())
 
-    if args.mode == "train":
-      sv = tf.train.Supervisor(is_chief=(task_index == 0),
-                               logdir=logdir,
-                               init_op=init_op,
-                               summary_op=None,
-                               summary_writer=None,
-                               saver=saver,
-                               global_step=global_step,
-                               stop_grace_secs=300,
-                               save_model_secs=10)
-    else:
-      sv = tf.train.Supervisor(is_chief=(task_index == 0),
-                               logdir=logdir,
-                               summary_op=None,
-                               saver=saver,
-                               global_step=global_step,
-                               stop_grace_secs=300,
-                               save_model_secs=0)
-
-    # The supervisor takes care of session initialization, restoring from
-    # a checkpoint, and closing when done or an error occurs.
-    with sv.managed_session(server.target) as sess:
-      print("{0} session ready".format(datetime.now().isoformat()))
-
-      # Loop until the supervisor shuts down or 1000000 steps have completed.
+    # The MonitoredTrainingSession takes care of session initialization, restoring from
+    # a checkpoint, and closing when done or an error occurs
+    with tf.train.MonitoredTrainingSession(master=server.target,
+                                           is_chief=(task_index == 0),
+                                           checkpoint_dir=logdir,
+                                           hooks=hooks) as mon_sess:
+
       step = 0
       tf_feed = ctx.get_data_feed(args.mode == "train")
-      while not sv.should_stop() and not tf_feed.should_stop() and step < args.steps:
-        # Run a training step asynchronously.
-        # See `tf.train.SyncReplicasOptimizer` for additional details on how to
-        # perform *synchronous* training.
+      while not mon_sess.should_stop() and not tf_feed.should_stop() and step < args.steps:
+        # Run a training step asynchronously
+        # See `tf.train.SyncReplicasOptimizer` for additional details on how to
+        # perform *synchronous* training.
 
         # using feed_dict
         batch_xs, batch_ys = feed_dict(tf_feed.next_batch(batch_size))
         feed = {x: batch_xs, y_: batch_ys}
 
         if len(batch_xs) > 0:
           if args.mode == "train":
-            _, summary, step = sess.run([train_op, summary_op, global_step], feed_dict=feed)
+            _, summary, step = mon_sess.run([train_op, summary_op, global_step], feed_dict=feed)
             # print accuracy and save model checkpoint to HDFS every 100 steps
             if (step % 100 == 0):
-              print("{0} step: {1} accuracy: {2}".format(datetime.now().isoformat(), step, sess.run(accuracy, {x: batch_xs, y_: batch_ys})))
+              print("{0} step: {1} accuracy: {2}".format(datetime.now().isoformat(), step, mon_sess.run(accuracy, {x: batch_xs, y_: batch_ys})))
 
-            if sv.is_chief:
+            if task_index == 0:
               summary_writer.add_summary(summary, step)
           else:  # args.mode == "inference"
-            labels, preds, acc = sess.run([label, prediction, accuracy], feed_dict=feed)
+            labels, preds, acc = mon_sess.run([label, prediction, accuracy], feed_dict=feed)
 
             results = ["{0} Label: {1}, Prediction: {2}".format(datetime.now().isoformat(), l, p) for l, p in zip(labels, preds)]
             tf_feed.batch_results(results)
-            print("acc: {0}".format(acc))
+            print("results: {0}, acc: {1}".format(results, acc))
 
-      if sv.should_stop() or step >= args.steps:
+      if mon_sess.should_stop() or step >= args.steps:
         tf_feed.terminate()
 
     # Ask for all the services to stop.
-    print("{0} stopping supervisor".format(datetime.now().isoformat()))
-    sv.stop()
-
+    print("{0} stopping MonitoredTrainingSession".format(datetime.now().isoformat()))
+    summary_writer.close()
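
The heart of this change is swapping tf.train.Supervisor for tf.train.MonitoredTrainingSession, which bundles variable initialization, checkpoint save/restore, and step-based stopping behind hooks. A minimal single-process sketch of that pattern, assuming the TensorFlow 1.x API; the toy variable and the /tmp/mnist_ckpt path are illustrative only and are not part of this commit:

import tensorflow as tf

# Toy graph: one trainable scalar nudged toward 5.0.
w = tf.Variable(0.0, name="w")
loss = tf.square(w - 5.0)
tf.summary.scalar("loss", loss)

# Replaces the hand-rolled tf.Variable(0) global step used before this commit.
global_step = tf.train.get_or_create_global_step()
train_op = tf.train.AdagradOptimizer(0.01).minimize(loss, global_step=global_step)

# StopAtStepHook ends the training loop; checkpoint_dir gives automatic save/restore,
# so no explicit tf.train.Saver or init_op is needed.
hooks = [tf.train.StopAtStepHook(last_step=100)]
with tf.train.MonitoredTrainingSession(checkpoint_dir="/tmp/mnist_ckpt",  # illustrative path
                                       hooks=hooks) as mon_sess:
  while not mon_sess.should_stop():
    _, step = mon_sess.run([train_op, global_step])
    if step % 20 == 0:
      print("step: {0}".format(step))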