# Copyright 2018 Yahoo Inc.
# Licensed under the terms of the Apache 2.0 license.
# Please see LICENSE file in the project root for terms.
99from __future__ import nested_scopes
1010from __future__ import print_function
1111
12+
def print_log(worker_num, arg):
    """Write *arg* to stdout, tagged with the originating worker's number."""
    message = "{0}: {1}".format(worker_num, arg)
    print(message)
1415
16+
1517def map_fun (args , ctx ):
1618 from datetime import datetime
1719 import math
@@ -30,7 +32,7 @@ def map_fun(args, ctx):
3032 # Parameters
3133 IMAGE_PIXELS = 28
3234 hidden_units = 128
33- batch_size = args .batch_size
35+ batch_size = args .batch_size
3436
3537 # Get TF cluster and server instances
3638 cluster , server = ctx .start_cluster_server (1 , args .rdma )
@@ -55,28 +57,28 @@ def feed_dict(batch):
5557
5658 # Assigns ops to the local worker by default.
5759 with tf .device (tf .train .replica_device_setter (
58- worker_device = "/job:worker/task:%d" % task_index ,
59- cluster = cluster )):
60+ worker_device = "/job:worker/task:%d" % task_index ,
61+ cluster = cluster )):
6062
61- # Placeholders or QueueRunner/Readers for input data
63+ # Placeholders or QueueRunner/Readers for input data
6264 with tf .name_scope ('inputs' ):
63- x = tf .placeholder (tf .float32 , [None , IMAGE_PIXELS * IMAGE_PIXELS ] , name = "x" )
65+ x = tf .placeholder (tf .float32 , [None , IMAGE_PIXELS * IMAGE_PIXELS ], name = "x" )
6466 y_ = tf .placeholder (tf .float32 , [None , 10 ], name = "y_" )
65-
67+
6668 x_img = tf .reshape (x , [- 1 , IMAGE_PIXELS , IMAGE_PIXELS , 1 ])
6769 tf .summary .image ("x_img" , x_img )
68-
70+
6971 with tf .name_scope ('layer' ):
7072 # Variables of the hidden layer
7173 with tf .name_scope ('hidden_layer' ):
72- hid_w = tf .Variable (tf .truncated_normal ([IMAGE_PIXELS * IMAGE_PIXELS , hidden_units ], stddev = 1.0 / IMAGE_PIXELS ), name = "hid_w" )
74+ hid_w = tf .Variable (tf .truncated_normal ([IMAGE_PIXELS * IMAGE_PIXELS , hidden_units ], stddev = 1.0 / IMAGE_PIXELS ), name = "hid_w" )
7375 hid_b = tf .Variable (tf .zeros ([hidden_units ]), name = "hid_b" )
7476 tf .summary .histogram ("hidden_weights" , hid_w )
7577 hid_lin = tf .nn .xw_plus_b (x , hid_w , hid_b )
7678 hid = tf .nn .relu (hid_lin )
77-
79+
7880 # Variables of the softmax layer
79- with tf .name_scope ('softmax_layer' ):
81+ with tf .name_scope ('softmax_layer' ):
8082 sm_w = tf .Variable (tf .truncated_normal ([hidden_units , 10 ], stddev = 1.0 / math .sqrt (hidden_units )), name = "sm_w" )
8183 sm_b = tf .Variable (tf .zeros ([10 ]), name = "sm_b" )
8284 tf .summary .histogram ("softmax_weights" , sm_w )
@@ -93,7 +95,7 @@ def feed_dict(batch):
9395
9496 # Test trained model
9597 label = tf .argmax (y_ , 1 , name = "label" )
96- prediction = tf .argmax (y , 1 ,name = "prediction" )
98+ prediction = tf .argmax (y , 1 , name = "prediction" )
9799 correct_prediction = tf .equal (prediction , label )
98100
99101 accuracy = tf .reduce_mean (tf .cast (correct_prediction , tf .float32 ), name = "accuracy" )
@@ -102,10 +104,9 @@ def feed_dict(batch):
102104 summary_op = tf .summary .merge_all ()
103105
104106 logdir = ctx .absolute_path (args .model )
105- # logdir = args.model
106107 print ("tensorflow model path: {0}" .format (logdir ))
107108 hooks = [tf .train .StopAtStepHook (last_step = 100000 )]
108-
109+
109110 if job_name == "worker" and task_index == 0 :
110111 summary_writer = tf .summary .FileWriter (logdir , graph = tf .get_default_graph ())
111112
@@ -119,9 +120,9 @@ def feed_dict(batch):
119120 step = 0
120121 tf_feed = ctx .get_data_feed (args .mode == "train" )
121122 while not mon_sess .should_stop () and not tf_feed .should_stop () and step < args .steps :
122- # Run a training step asynchronously
123- # See `tf.train.SyncReplicasOptimizer` for additional details on how to
124- # perform *synchronous* training.
123+ # Run a training step asynchronously
124+ # See `tf.train.SyncReplicasOptimizer` for additional details on how to
125+ # perform *synchronous* training.
125126
126127 # using feed_dict
127128 batch_xs , batch_ys = feed_dict (tf_feed .next_batch (batch_size ))
@@ -132,14 +133,14 @@ def feed_dict(batch):
132133 _ , summary , step = mon_sess .run ([train_op , summary_op , global_step ], feed_dict = feed )
133134 # print accuracy and save model checkpoint to HDFS every 100 steps
134135 if (step % 100 == 0 ):
135- print ("{0} step: {1} accuracy: {2}" .format (datetime .now ().isoformat (), step , mon_sess .run (accuracy ,{x : batch_xs , y_ : batch_ys })))
136+ print ("{0} step: {1} accuracy: {2}" .format (datetime .now ().isoformat (), step , mon_sess .run (accuracy , {x : batch_xs , y_ : batch_ys })))
136137
137138 if task_index == 0 :
138139 summary_writer .add_summary (summary , step )
139140 else : # args.mode == "inference"
140141 labels , preds , acc = mon_sess .run ([label , prediction , accuracy ], feed_dict = feed )
141142
142- results = ["{0} Label: {1}, Prediction: {2}" .format (datetime .now ().isoformat (), l , p ) for l ,p in zip (labels ,preds )]
143+ results = ["{0} Label: {1}, Prediction: {2}" .format (datetime .now ().isoformat (), l , p ) for l , p in zip (labels , preds )]
143144 tf_feed .batch_results (results )
144145 print ("results: {0}, acc: {1}" .format (results , acc ))
145146
@@ -148,4 +149,6 @@ def feed_dict(batch):
148149
149150 # Ask for all the services to stop.
150151 print ("{0} stopping MonitoredTrainingSession" .format (datetime .now ().isoformat ()))
151- summary_writer .close ()
152+
153+ if job_name == "worker" and task_index == 0 :
154+ summary_writer .close ()
0 commit comments