61
61
'--test_data' , type = str , default = '/tmp/census_data/adult.test' ,
62
62
help = 'Path to the test data.' )
63
63
64
+ _SHUFFLE_BUFFER = 100000
65
+
64
66
65
67
def build_model_columns ():
66
68
"""Builds a set of wide and deep feature columns."""
@@ -167,6 +169,7 @@ def input_fn(data_file, num_epochs, shuffle, batch_size):
167
169
assert tf .gfile .Exists (data_file ), (
168
170
'%s not found. Please make sure you have either run data_download.py or '
169
171
'set both arguments --train_data and --test_data.' % data_file )
172
+
170
173
def parse_csv (value ):
171
174
print ('Parsing' , data_file )
172
175
columns = tf .decode_csv (value , record_defaults = _CSV_COLUMN_DEFAULTS )
@@ -178,49 +181,36 @@ def parse_csv(value):
178
181
dataset = tf .contrib .data .TextLineDataset (data_file )
179
182
dataset = dataset .map (parse_csv , num_threads = 5 )
180
183
181
- # Apply transformations to the Dataset
182
- dataset = dataset .batch (batch_size )
184
+ if shuffle :
185
+ dataset = dataset .shuffle (buffer_size = _SHUFFLE_BUFFER )
186
+
187
+ # We call repeat after shuffling, rather than before, to prevent separate
188
+ # epochs from blending together.
183
189
dataset = dataset .repeat (num_epochs )
190
+ dataset = dataset .batch (batch_size )
184
191
185
- # Input function that is called by the Estimator
186
- def _input_fn ():
187
- if shuffle :
188
- # Apply shuffle transformation to re-shuffle the dataset in each call.
189
- shuffled_dataset = dataset .shuffle (buffer_size = 100000 )
190
- iterator = shuffled_dataset .make_one_shot_iterator ()
191
- else :
192
- iterator = dataset .make_one_shot_iterator ()
193
- features , labels = iterator .get_next ()
194
- return features , labels
195
- return _input_fn
192
+ iterator = dataset .make_one_shot_iterator ()
193
+ features , labels = iterator .get_next ()
194
+ return features , labels
196
195
197
196
198
197
def main(unused_argv):
  """Train and periodically evaluate the wide/deep census model.

  Args:
    unused_argv: Leftover positional args from the flag parser (ignored).

  Side effects:
    Deletes any existing model directory at FLAGS.model_dir, then writes
    checkpoints/summaries there while training; prints evaluation metrics
    to stdout after every FLAGS.epochs_per_eval epochs.
  """
  # Clean up the model directory if present so each run starts fresh.
  shutil.rmtree(FLAGS.model_dir, ignore_errors=True)
  model = build_estimator(FLAGS.model_dir, FLAGS.model_type)

  # Train and evaluate the model every `FLAGS.epochs_per_eval` epochs.
  # The lambdas defer input_fn construction until the Estimator calls it,
  # so the input pipeline is built inside the Estimator's own graph.
  for n in range(FLAGS.train_epochs // FLAGS.epochs_per_eval):
    model.train(input_fn=lambda: input_fn(
        FLAGS.train_data, FLAGS.epochs_per_eval, True, FLAGS.batch_size))

    # Evaluation runs a single pass over the test set, unshuffled.
    results = model.evaluate(input_fn=lambda: input_fn(
        FLAGS.test_data, 1, False, FLAGS.batch_size))

    # Display evaluation metrics
    print('Results at epoch', (n + 1) * FLAGS.epochs_per_eval)
    print('-' * 30)

    for key in sorted(results):
      print('%s: %s' % (key, results[key]))