-# Copyright 2017 Yahoo Inc.
+# Copyright 2018 Yahoo Inc.
 # Licensed under the terms of the Apache 2.0 license.
 # Please see LICENSE file in the project root for terms.

@@ -58,37 +58,38 @@ def feed_dict(batch):
         worker_device="/job:worker/task:%d" % task_index,
         cluster=cluster)):

-      # Variables of the hidden layer
-      hid_w = tf.Variable(tf.truncated_normal([IMAGE_PIXELS * IMAGE_PIXELS, hidden_units],
-                          stddev=1.0 / IMAGE_PIXELS), name="hid_w")
-      hid_b = tf.Variable(tf.zeros([hidden_units]), name="hid_b")
-      tf.summary.histogram("hidden_weights", hid_w)
-
-      # Variables of the softmax layer
-      sm_w = tf.Variable(tf.truncated_normal([hidden_units, 10],
-                         stddev=1.0 / math.sqrt(hidden_units)), name="sm_w")
-      sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
-      tf.summary.histogram("softmax_weights", sm_w)
-
-      # Placeholders or QueueRunner/Readers for input data
-      x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS], name="x")
-      y_ = tf.placeholder(tf.float32, [None, 10], name="y_")
-
-      x_img = tf.reshape(x, [-1, IMAGE_PIXELS, IMAGE_PIXELS, 1])
-      tf.summary.image("x_img", x_img)
-
-      hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
-      hid = tf.nn.relu(hid_lin)
-
-      y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
+      # Placeholders or QueueRunner/Readers for input data
+      with tf.name_scope('inputs'):
+        x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS], name="x")
+        y_ = tf.placeholder(tf.float32, [None, 10], name="y_")
+
+        x_img = tf.reshape(x, [-1, IMAGE_PIXELS, IMAGE_PIXELS, 1])
+        tf.summary.image("x_img", x_img)
+
+      with tf.name_scope('layer'):
+        # Variables of the hidden layer
+        with tf.name_scope('hidden_layer'):
+          hid_w = tf.Variable(tf.truncated_normal([IMAGE_PIXELS * IMAGE_PIXELS, hidden_units], stddev=1.0 / IMAGE_PIXELS), name="hid_w")
+          hid_b = tf.Variable(tf.zeros([hidden_units]), name="hid_b")
+          tf.summary.histogram("hidden_weights", hid_w)
+          hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
+          hid = tf.nn.relu(hid_lin)
+
+        # Variables of the softmax layer
+        with tf.name_scope('softmax_layer'):
+          sm_w = tf.Variable(tf.truncated_normal([hidden_units, 10], stddev=1.0 / math.sqrt(hidden_units)), name="sm_w")
+          sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
+          tf.summary.histogram("softmax_weights", sm_w)
+          y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))

       global_step = tf.train.get_or_create_global_step()

-      loss = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
-      tf.summary.scalar("loss", loss)
+      with tf.name_scope('loss'):
+        loss = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
+        tf.summary.scalar("loss", loss)

-      train_op = tf.train.AdagradOptimizer(0.01).minimize(
-          loss, global_step=global_step)
+      with tf.name_scope('train'):
+        train_op = tf.train.AdagradOptimizer(0.01).minimize(loss, global_step=global_step)

       # Test trained model
       label = tf.argmax(y_, 1, name="label")
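
The main change in this hunk wraps the graph in tf.name_scope blocks. A minimal sketch of the effect, separate from this commit (the shape below is illustrative): tf.name_scope prefixes the names of ops created inside it, so TensorBoard's graph view collapses each scope into a single expandable node.

    import tensorflow as tf

    with tf.name_scope('layer'):
        with tf.name_scope('hidden_layer'):
            # tf.Variable (unlike tf.get_variable) honors the enclosing name_scope
            hid_w = tf.Variable(tf.zeros([784, 128]), name="hid_w")

    print(hid_w.op.name)  # "layer/hidden_layer/hid_w" in TF 1.x graph mode
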
@@ -100,39 +101,27 @@ def feed_dict(batch):

       summary_op = tf.summary.merge_all()

-    # Create a "MonitoredTrainingSession", which oversees the training process and stores model state into HDFS
     logdir = ctx.absolute_path(args.model)
+    # logdir = args.model
     print("tensorflow model path: {0}".format(logdir))
     hooks = [tf.train.StopAtStepHook(last_step=100000)]

     if job_name == "worker" and task_index == 0:
       summary_writer = tf.summary.FileWriter(logdir, graph=tf.get_default_graph())

-    if args.mode == "train":
-      with tf.train.MonitoredTrainingSession(master=server.target,
-                                             is_chief=(task_index == 0),
-                                             checkpoint_dir=logdir,
-                                             hooks=hooks,
-                                             ) as mon_sess:
-    else:
-      sv = tf.train.Supervisor(is_chief=(task_index == 0),
-                               logdir=logdir,
-                               summary_op=None,
-                               saver=saver,
-                               global_step=global_step,
-                               stop_grace_secs=300,
-                               save_model_secs=0)
-
     # The MonitoredTrainingSession takes care of session initialization, restoring from
     # a checkpoint, and closing when done or an error occurs
+    with tf.train.MonitoredTrainingSession(master=server.target,
+                                           is_chief=(task_index == 0),
+                                           checkpoint_dir=logdir,
+                                           hooks=hooks) as mon_sess:

-      # Loop until the supervisor shuts down or 1000000 steps have completed.
       step = 0
       tf_feed = ctx.get_data_feed(args.mode == "train")
       while not mon_sess.should_stop() and not tf_feed.should_stop() and step < args.steps:
-        # Run a training step asynchronously.
-        # See `tf.train.SyncReplicasOptimizer` for additional details on how to
-        # perform *synchronous* training.
+        # Run a training step asynchronously
+        # See `tf.train.SyncReplicasOptimizer` for additional details on how to
+        # perform *synchronous* training.

         # using feed_dict
         batch_xs, batch_ys = feed_dict(tf_feed.next_batch(batch_size))
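
This hunk collapses the old train/inference split between MonitoredTrainingSession and the deprecated tf.train.Supervisor into a single MonitoredTrainingSession for both modes. A minimal self-contained sketch of the pattern, assuming a single-process TF 1.x run (the checkpoint path is illustrative):

    import tensorflow as tf

    global_step = tf.train.get_or_create_global_step()
    train_op = tf.assign_add(global_step, 1)  # stand-in for an optimizer step
    hooks = [tf.train.StopAtStepHook(last_step=10)]

    # The session initializes variables, restores from / writes checkpoints under
    # checkpoint_dir, and flips should_stop() once a hook fires.
    with tf.train.MonitoredTrainingSession(checkpoint_dir="/tmp/mnist_demo",
                                           hooks=hooks) as mon_sess:
        while not mon_sess.should_stop():
            mon_sess.run(train_op)
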
@@ -152,12 +141,11 @@ def feed_dict(batch):

         results = ["{0} Label: {1}, Prediction: {2}".format(datetime.now().isoformat(), l, p) for l, p in zip(labels, preds)]
         tf_feed.batch_results(results)
-        print("acc: {0}".format(acc))
+        print("results: {0}, acc: {1}".format(results, acc))

       if mon_sess.should_stop() or step >= args.steps:
         tf_feed.terminate()

     # Ask for all the services to stop.
     print("{0} stopping MonitoredTrainingSession".format(datetime.now().isoformat()))
-
-
+    summary_writer.close()
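
Throughout the loop, tf_feed is TensorFlowOnSpark's bridge between Spark partitions and the TF process. A hedged sketch of the worker-side call pattern, assuming the TensorFlowOnSpark 1.x DataFeed API used in this file; worker_loop is a hypothetical helper for framing, and this only runs inside a Spark executor, not standalone:

    def worker_loop(ctx, mon_sess, args, batch_size):
        # DataFeed pulling records that Spark pushes to this executor
        tf_feed = ctx.get_data_feed(args.mode == "train")
        while not mon_sess.should_stop() and not tf_feed.should_stop():
            batch = tf_feed.next_batch(batch_size)
            if args.mode == "train":
                pass  # run train_op on the batch here
            else:
                # inference: return one result string per record back to Spark
                tf_feed.batch_results(["<prediction>" for _ in batch])
        tf_feed.terminate()  # tell Spark this worker is done consuming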