@@ -18,12 +18,15 @@
 from __future__ import print_function
 
 import argparse
-import os
 import sys
 
 import tensorflow as tf
+
 from official.mnist import dataset
+from official.utils.arg_parsers import parsers
+from official.utils.logging import hooks_helper
 
+LEARNING_RATE = 1e-4
 
 class Model(tf.keras.Model):
   """Model to recognize digits in the MNIST dataset.
@@ -104,7 +107,7 @@ def model_fn(features, labels, mode, params):
             'classify': tf.estimator.export.PredictOutput(predictions)
         })
   if mode == tf.estimator.ModeKeys.TRAIN:
-    optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
+    optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)
 
     # If we are running multi-GPU, we need to wrap the optimizer.
     if params.get('multi_gpu'):
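Note: the "wrap the optimizer" comment refers to TF 1.x multi-GPU Estimator training. The wrapped form isn't shown in this hunk; a minimal sketch, assuming the `tf.contrib.estimator` APIs of that era (removed in TF 2.x):

```python
import tensorflow as tf

def maybe_wrap_optimizer(optimizer, params):
  """Sketch only: wrap the optimizer for multi-GPU tower training."""
  if params.get('multi_gpu'):
    # TowerOptimizer aggregates tower gradients when the model_fn is
    # replicated with tf.contrib.estimator.replicate_model_fn.
    optimizer = tf.contrib.estimator.TowerOptimizer(optimizer)
  return optimizer
```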
@@ -114,10 +117,15 @@ def model_fn(features, labels, mode, params):
     loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
     accuracy = tf.metrics.accuracy(
         labels=labels, predictions=tf.argmax(logits, axis=1))
-    # Name the accuracy tensor 'train_accuracy' to demonstrate the
-    # LoggingTensorHook.
+
+    # Name tensors to be logged with LoggingTensorHook.
+    tf.identity(LEARNING_RATE, 'learning_rate')
+    tf.identity(loss, 'cross_entropy')
     tf.identity(accuracy[1], name='train_accuracy')
+
+    # Save accuracy scalar to Tensorboard output.
     tf.summary.scalar('train_accuracy', accuracy[1])
+
     return tf.estimator.EstimatorSpec(
         mode=tf.estimator.ModeKeys.TRAIN,
         loss=loss,
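Note: the `tf.identity` calls above exist only to give these tensors stable graph names; that is what lets a `tf.train.LoggingTensorHook` (such as the ones `hooks_helper.get_train_hooks` builds later in this diff) look them up by name. A minimal sketch of a hook consuming them, reusing the 100-step cadence from the code this change removes:

```python
import tensorflow as tf

# Maps display labels to the graph names created via tf.identity above.
logging_hook = tf.train.LoggingTensorHook(
    tensors={
        'learning_rate': 'learning_rate',
        'cross_entropy': 'cross_entropy',
        'train_accuracy': 'train_accuracy',
    },
    every_n_iter=100)  # log every 100 training steps
```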
@@ -185,30 +193,32 @@ def main(unused_argv):
           'multi_gpu': FLAGS.multi_gpu
       })
 
-  # Train the model
+  # Set up training and evaluation input functions.
   def train_input_fn():
     # When choosing shuffle buffer sizes, larger sizes result in better
     # randomness, while smaller sizes use less memory. MNIST is a small
     # enough dataset that we can easily shuffle the full epoch.
     ds = dataset.train(FLAGS.data_dir)
-    ds = ds.cache().shuffle(buffer_size=50000).batch(FLAGS.batch_size).repeat(
-        FLAGS.train_epochs)
-    return ds
+    ds = ds.cache().shuffle(buffer_size=50000).batch(FLAGS.batch_size)
 
-  # Set up training hook that logs the training accuracy every 100 steps.
-  tensors_to_log = {'train_accuracy': 'train_accuracy'}
-  logging_hook = tf.train.LoggingTensorHook(
-      tensors=tensors_to_log, every_n_iter=100)
-  mnist_classifier.train(input_fn=train_input_fn, hooks=[logging_hook])
+    # Iterate through the dataset a set number (`epochs_between_evals`) of times
+    # during each training session.
+    ds = ds.repeat(FLAGS.epochs_between_evals)
+    return ds
 
-  # Evaluate the model and print results
   def eval_input_fn():
     return dataset.test(FLAGS.data_dir).batch(
         FLAGS.batch_size).make_one_shot_iterator().get_next()
 
-  eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
-  print()
-  print('Evaluation results:\n\t%s' % eval_results)
+  # Set up hook that outputs training logs every 100 steps.
+  train_hooks = hooks_helper.get_train_hooks(
+      FLAGS.hooks, batch_size=FLAGS.batch_size)
+
+  # Train and evaluate model.
+  for n in range(FLAGS.train_epochs // FLAGS.epochs_between_evals):
+    mnist_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
+    eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
+    print('\nEvaluation results:\n\t%s\n' % eval_results)
 
   # Export the model
   if FLAGS.export_dir is not None:
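Note: the loop above replaces one monolithic `train()` call with alternating train/evaluate sessions. `train_input_fn` now repeats the dataset `epochs_between_evals` times, so each `train()` call consumes exactly that many epochs. A worked sketch of the schedule arithmetic (the value of `epochs_between_evals` is illustrative; its real default comes from `parsers.BaseParser`):

```python
train_epochs = 40         # default set via set_defaults below
epochs_between_evals = 2  # assumed value, for illustration only

# 40 // 2 = 20 cycles; each cycle trains 2 epochs, then evaluates once.
for n in range(train_epochs // epochs_between_evals):
  print('cycle %d: train %d epochs, then evaluate' % (n, epochs_between_evals))
```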
@@ -220,51 +230,28 @@ def eval_input_fn():
 
 
 class MNISTArgParser(argparse.ArgumentParser):
-
+  """Argument parser for running MNIST model."""
   def __init__(self):
-    super(MNISTArgParser, self).__init__()
+    super(MNISTArgParser, self).__init__(parents=[
+        parsers.BaseParser(),
+        parsers.ImageModelParser()])
 
-    self.add_argument(
-        '--multi_gpu', action='store_true',
-        help='If set, run across all available GPUs.')
-    self.add_argument(
-        '--batch_size',
-        type=int,
-        default=100,
-        help='Number of images to process in a batch')
-    self.add_argument(
-        '--data_dir',
-        type=str,
-        default='/tmp/mnist_data',
-        help='Path to directory containing the MNIST dataset')
-    self.add_argument(
-        '--model_dir',
-        type=str,
-        default='/tmp/mnist_model',
-        help='The directory where the model will be stored.')
-    self.add_argument(
-        '--train_epochs',
-        type=int,
-        default=40,
-        help='Number of epochs to train.')
-    self.add_argument(
-        '--data_format',
-        type=str,
-        default=None,
-        choices=['channels_first', 'channels_last'],
-        help='A flag to override the data format used in the model. '
-        'channels_first provides a performance boost on GPU but is not always '
-        'compatible with CPU. If left unspecified, the data format will be '
-        'chosen automatically based on whether TensorFlow was built for CPU or '
-        'GPU.')
     self.add_argument(
         '--export_dir',
         type=str,
-        help='The directory where the exported SavedModel will be stored.')
+        help='[default: %(default)s] If set, a SavedModel serialization of the '
+        'model will be exported to this directory at the end of training. '
+        'See the README for more details and relevant links.')
+
+    self.set_defaults(
+        data_dir='/tmp/mnist_data',
+        model_dir='/tmp/mnist_model',
+        batch_size=100,
+        train_epochs=40)
 
 
 if __name__ == '__main__':
-  parser = MNISTArgParser()
   tf.logging.set_verbosity(tf.logging.INFO)
+  parser = MNISTArgParser()
   FLAGS, unparsed = parser.parse_known_args()
   tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
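Note: the parser rewrite leans on two stock `argparse` behaviors: inheriting flags from parent parsers (`parents=[...]`) and overriding inherited defaults with `set_defaults()`. A self-contained sketch with a hypothetical `--batch_size` flag standing in for the official `parsers` modules:

```python
import argparse

# Parent parsers used via parents= must set add_help=False.
base = argparse.ArgumentParser(add_help=False)
base.add_argument('--batch_size', type=int, default=32)

# The child inherits --batch_size, then overrides its default, mirroring
# MNISTArgParser.set_defaults(batch_size=100, ...).
parser = argparse.ArgumentParser(parents=[base])
parser.set_defaults(batch_size=100)

print(parser.parse_args([]).batch_size)  # -> 100
```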