@@ -82,7 +82,7 @@ def feed_dict(batch):
8282
8383 y = tf .nn .softmax (tf .nn .xw_plus_b (hid , sm_w , sm_b ))
8484
85- global_step = tf .Variable ( 0 )
85+ global_step = tf .train . get_or_create_global_step ( )
8686
8787 loss = - tf .reduce_sum (y_ * tf .log (tf .clip_by_value (y , 1e-10 , 1.0 )))
8888 tf .summary .scalar ("loss" , loss )
@@ -98,27 +98,22 @@ def feed_dict(batch):
9898 accuracy = tf .reduce_mean (tf .cast (correct_prediction , tf .float32 ), name = "accuracy" )
9999 tf .summary .scalar ("acc" , accuracy )
100100
101- saver = tf .train .Saver ()
102101 summary_op = tf .summary .merge_all ()
103- init_op = tf .global_variables_initializer ()
104102
105- # Create a "supervisor ", which oversees the training process and stores model state into HDFS
103+ # Create a "MonitoredTrainingSession ", which oversees the training process and stores model state into HDFS
106104 logdir = ctx .absolute_path (args .model )
107105 print ("tensorflow model path: {0}" .format (logdir ))
108-
106+ hooks = [tf .train .StopAtStepHook (last_step = 100000 )]
107+
109108 if job_name == "worker" and task_index == 0 :
110109 summary_writer = tf .summary .FileWriter (logdir , graph = tf .get_default_graph ())
111110
112111 if args .mode == "train" :
113- sv = tf .train .Supervisor (is_chief = (task_index == 0 ),
114- logdir = logdir ,
115- init_op = init_op ,
116- summary_op = None ,
117- summary_writer = None ,
118- saver = saver ,
119- global_step = global_step ,
120- stop_grace_secs = 300 ,
121- save_model_secs = 10 )
112+ with tf .train .MonitoredTrainingSession (master = server .target ,
113+ is_chief = (task_index == 0 ),
114+ checkpoint_dir = logdir ,
115+ hooks = hooks ,
116+ ) as mon_sess :
122117 else :
123118 sv = tf .train .Supervisor (is_chief = (task_index == 0 ),
124119 logdir = logdir ,
@@ -128,15 +123,13 @@ def feed_dict(batch):
128123 stop_grace_secs = 300 ,
129124 save_model_secs = 0 )
130125
131- # The supervisor takes care of session initialization, restoring from
132- # a checkpoint, and closing when done or an error occurs.
133- with sv .managed_session (server .target ) as sess :
134- print ("{0} session ready" .format (datetime .now ().isoformat ()))
126+ # The MonitoredTrainingSession takes care of session initialization, restoring from
127+ # a checkpoint, and closing when done or an error occurs
135128
136129 # Loop until the supervisor shuts down or 1000000 steps have completed.
137130 step = 0
138131 tf_feed = ctx .get_data_feed (args .mode == "train" )
139- while not sv .should_stop () and not tf_feed .should_stop () and step < args .steps :
132+ while not mon_sess .should_stop () and not tf_feed .should_stop () and step < args .steps :
140133 # Run a training step asynchronously.
141134 # See `tf.train.SyncReplicasOptimizer` for additional details on how to
142135 # perform *synchronous* training.
@@ -147,24 +140,24 @@ def feed_dict(batch):
147140
148141 if len (batch_xs ) > 0 :
149142 if args .mode == "train" :
150- _ , summary , step = sess .run ([train_op , summary_op , global_step ], feed_dict = feed )
143+ _ , summary , step = mon_sess .run ([train_op , summary_op , global_step ], feed_dict = feed )
151144 # print accuracy and save model checkpoint to HDFS every 100 steps
152145 if (step % 100 == 0 ):
153- print ("{0} step: {1} accuracy: {2}" .format (datetime .now ().isoformat (), step , sess .run (accuracy ,{x : batch_xs , y_ : batch_ys })))
146+ print ("{0} step: {1} accuracy: {2}" .format (datetime .now ().isoformat (), step , mon_sess .run (accuracy ,{x : batch_xs , y_ : batch_ys })))
154147
155- if sv . is_chief :
148+ if task_index == 0 :
156149 summary_writer .add_summary (summary , step )
157150 else : # args.mode == "inference"
158- labels , preds , acc = sess .run ([label , prediction , accuracy ], feed_dict = feed )
151+ labels , preds , acc = mon_sess .run ([label , prediction , accuracy ], feed_dict = feed )
159152
160153 results = ["{0} Label: {1}, Prediction: {2}" .format (datetime .now ().isoformat (), l , p ) for l ,p in zip (labels ,preds )]
161154 tf_feed .batch_results (results )
162155 print ("acc: {0}" .format (acc ))
163156
164- if sv .should_stop () or step >= args .steps :
157+ if mon_sess .should_stop () or step >= args .steps :
165158 tf_feed .terminate ()
166159
167160 # Ask for all the services to stop.
168- print ("{0} stopping supervisor " .format (datetime .now ().isoformat ()))
169- sv . stop ()
161+ print ("{0} stopping MonitoredTrainingSession " .format (datetime .now ().isoformat ()))
162+
170163
0 commit comments