
Commit 4cfa0d3

Input improvements (#2706)
1 parent dcc2368 commit 4cfa0d3

5 files changed: +44 -44 lines changed

official/mnist/mnist.py

Lines changed: 4 additions & 1 deletion
@@ -79,13 +79,16 @@ def example_parser(serialized_example):
   # a small dataset, we can easily shuffle the full epoch.
   dataset = dataset.shuffle(buffer_size=_NUM_IMAGES['train'])
 
+  # We call repeat after shuffling, rather than before, to prevent separate
+  # epochs from blending together.
   dataset = dataset.repeat(num_epochs)
 
   # Map example_parser over dataset, and batch results by up to batch_size
   dataset = dataset.map(
       example_parser, num_threads=1, output_buffer_size=batch_size)
   dataset = dataset.batch(batch_size)
-  images, labels = dataset.make_one_shot_iterator().get_next()
+  iterator = dataset.make_one_shot_iterator()
+  images, labels = iterator.get_next()
 
   return images, labels
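Taken together, the mnist.py change produces the shuffle -> repeat -> map -> batch -> one-shot-iterator pipeline that the rest of this commit also adopts. A minimal sketch of that pattern, assuming the TF 1.x tf.contrib.data API this commit targets; the function name, synthetic data, and constants below are illustrative, not code from the repo:

import tensorflow as tf

def toy_mnist_style_input_fn(num_epochs=2, batch_size=4):
  # Stand-ins for decoded MNIST examples: ten (image, label) pairs.
  images = tf.reshape(tf.range(40, dtype=tf.float32), [10, 4])
  labels = tf.range(10, dtype=tf.int32)
  dataset = tf.contrib.data.Dataset.from_tensor_slices((images, labels))

  # Shuffle over the full (tiny) dataset first ...
  dataset = dataset.shuffle(buffer_size=10)
  # ... then repeat, so examples from different epochs never share a buffer.
  dataset = dataset.repeat(num_epochs)
  dataset = dataset.batch(batch_size)

  # A one-shot iterator needs no explicit initialization, so input_fn can
  # simply hand the (images, labels) tensors to the Estimator.
  iterator = dataset.make_one_shot_iterator()
  images, labels = iterator.get_next()
  return images, labels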

official/resnet/cifar10_main.py

Lines changed: 4 additions & 1 deletion
@@ -166,11 +166,14 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
       num_threads=1,
       output_buffer_size=2 * batch_size)
 
+  # We call repeat after shuffling, rather than before, to prevent separate
+  # epochs from blending together.
   dataset = dataset.repeat(num_epochs)
 
   # Batch results by up to batch_size, and then fetch the tuple from the
   # iterator.
-  iterator = dataset.batch(batch_size).make_one_shot_iterator()
+  dataset = dataset.batch(batch_size)
+  iterator = dataset.make_one_shot_iterator()
   images, labels = iterator.get_next()
 
   return images, labels
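The cifar10_main.py hunk applies the same reordering. To make the "epochs blending together" comment concrete, the toy comparison below (illustrative only, not part of this commit) draws two passes over five elements under each ordering: with repeat before shuffle, a large shuffle buffer can emit the same element twice before some other element has appeared at all, whereas shuffle before repeat yields one permutation per epoch.

import tensorflow as tf

def _two_epochs(shuffle_first):
  dataset = tf.contrib.data.Dataset.range(5)
  if shuffle_first:
    # Each epoch is shuffled on its own; epochs stay separate.
    dataset = dataset.shuffle(buffer_size=5).repeat(2)
  else:
    # The buffer spans both epochs, so their elements can interleave.
    dataset = dataset.repeat(2).shuffle(buffer_size=10)
  return dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
  for shuffle_first in (True, False):
    next_element = _two_epochs(shuffle_first)
    print(shuffle_first, [sess.run(next_element) for _ in range(10)])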

official/resnet/imagenet_main.py

Lines changed: 13 additions & 9 deletions
@@ -73,6 +73,7 @@
     'validation': 50000,
 }
 
+_FILE_SHUFFLE_BUFFER = 1024
 _SHUFFLE_BUFFER = 1500
 
 
@@ -81,11 +82,11 @@ def filenames(is_training, data_dir):
   if is_training:
     return [
         os.path.join(data_dir, 'train-%05d-of-01024' % i)
-        for i in range(0, 1024)]
+        for i in range(1024)]
   else:
     return [
         os.path.join(data_dir, 'validation-%05d-of-00128' % i)
-        for i in range(0, 128)]
+        for i in range(128)]
 
 
 def dataset_parser(value, is_training):
@@ -137,11 +138,9 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
       filenames(is_training, data_dir))
 
   if is_training:
-    dataset = dataset.shuffle(buffer_size=1024)
-    dataset = dataset.flat_map(tf.contrib.data.TFRecordDataset)
+    dataset = dataset.shuffle(buffer_size=_FILE_SHUFFLE_BUFFER)
 
-  if is_training:
-    dataset = dataset.repeat(num_epochs)
+  dataset = dataset.flat_map(tf.contrib.data.TFRecordDataset)
 
   dataset = dataset.map(lambda value: dataset_parser(value, is_training),
                         num_threads=5,
@@ -152,7 +151,12 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
   # randomness, while smaller sizes have better performance.
   dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
 
-  iterator = dataset.batch(batch_size).make_one_shot_iterator()
+  # We call repeat after shuffling, rather than before, to prevent separate
+  # epochs from blending together.
+  dataset = dataset.repeat(num_epochs)
+  dataset = dataset.batch(batch_size)
+
+  iterator = dataset.make_one_shot_iterator()
   images, labels = iterator.get_next()
   return images, labels
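For ImageNet the pipeline is two-level: the list of shard filenames is shuffled (training only), flat_map expands every shard into its records for both training and evaluation, and only then come the record-level shuffle, repeat, and batch. A rough sketch of that shape, with illustrative paths and buffer sizes and the record parsing omitted (this is not the full input_fn from the repo):

import os
import tensorflow as tf

_FILE_SHUFFLE_BUFFER = 1024
_SHUFFLE_BUFFER = 1500

def sketch_input_fn(data_dir, is_training, batch_size, num_epochs=1):
  files = [os.path.join(data_dir, 'train-%05d-of-01024' % i)
           for i in range(1024)]
  dataset = tf.contrib.data.Dataset.from_tensor_slices(files)

  if is_training:
    # First-level shuffle: visit the 1024 shards in a fresh order.
    dataset = dataset.shuffle(buffer_size=_FILE_SHUFFLE_BUFFER)

  # Expand each shard into its serialized records (training and eval alike).
  dataset = dataset.flat_map(tf.contrib.data.TFRecordDataset)

  # (Record parsing via dataset.map(...) is omitted in this sketch.)

  # Second-level shuffle over records, then repeat and batch; repeat comes
  # after shuffle so separate epochs do not blend together.
  dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
  dataset = dataset.repeat(num_epochs)
  dataset = dataset.batch(batch_size)

  iterator = dataset.make_one_shot_iterator()
  return iterator.get_next()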

@@ -188,8 +192,8 @@ def resnet_model_fn(features, labels, mode, params):
       [tf.nn.l2_loss(v) for v in tf.trainable_variables()])
 
   if mode == tf.estimator.ModeKeys.TRAIN:
-    # Scale the learning rate linearly with the batch size. When the batch size is
-    # 256, the learning rate should be 0.1.
+    # Scale the learning rate linearly with the batch size. When the batch size
+    # is 256, the learning rate should be 0.1.
     initial_learning_rate = 0.1 * params['batch_size'] / 256
     batches_per_epoch = _NUM_IMAGES['train'] / params['batch_size']
     global_step = tf.train.get_or_create_global_step()
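This last hunk only re-wraps the comment, but the rule it documents is easy to sanity-check against the line below it: initial_learning_rate = 0.1 * batch_size / 256, so a batch size of 256 gives exactly 0.1. A quick check of a few values (not code from this commit):

# Linear learning-rate scaling: lr = 0.1 * batch_size / 256.
for batch_size in (64, 128, 256, 512):
  print(batch_size, 0.1 * batch_size / 256)  # ~0.025, 0.05, 0.1, 0.2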

official/wide_deep/wide_deep.py

Lines changed: 18 additions & 28 deletions
@@ -61,6 +61,8 @@
     '--test_data', type=str, default='/tmp/census_data/adult.test',
     help='Path to the test data.')
 
+_SHUFFLE_BUFFER = 100000
+
 
 def build_model_columns():
   """Builds a set of wide and deep feature columns."""
@@ -167,6 +169,7 @@ def input_fn(data_file, num_epochs, shuffle, batch_size):
   assert tf.gfile.Exists(data_file), (
       '%s not found. Please make sure you have either run data_download.py or '
       'set both arguments --train_data and --test_data.' % data_file)
+
   def parse_csv(value):
     print('Parsing', data_file)
     columns = tf.decode_csv(value, record_defaults=_CSV_COLUMN_DEFAULTS)
@@ -178,49 +181,36 @@ def parse_csv(value):
   dataset = tf.contrib.data.TextLineDataset(data_file)
   dataset = dataset.map(parse_csv, num_threads=5)
 
-  # Apply transformations to the Dataset
-  dataset = dataset.batch(batch_size)
+  if shuffle:
+    dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
+
+  # We call repeat after shuffling, rather than before, to prevent separate
+  # epochs from blending together.
   dataset = dataset.repeat(num_epochs)
+  dataset = dataset.batch(batch_size)
 
-  # Input function that is called by the Estimator
-  def _input_fn():
-    if shuffle:
-      # Apply shuffle transformation to re-shuffle the dataset in each call.
-      shuffled_dataset = dataset.shuffle(buffer_size=100000)
-      iterator = shuffled_dataset.make_one_shot_iterator()
-    else:
-      iterator = dataset.make_one_shot_iterator()
-    features, labels = iterator.get_next()
-    return features, labels
-  return _input_fn
+  iterator = dataset.make_one_shot_iterator()
+  features, labels = iterator.get_next()
+  return features, labels
 
 
 def main(unused_argv):
   # Clean up the model directory if present
   shutil.rmtree(FLAGS.model_dir, ignore_errors=True)
-
   model = build_estimator(FLAGS.model_dir, FLAGS.model_type)
 
-  # Set up input function generators for the train and test data files.
-  train_input_fn = input_fn(
-      data_file=FLAGS.train_data,
-      num_epochs=FLAGS.epochs_per_eval,
-      shuffle=True,
-      batch_size=FLAGS.batch_size)
-  eval_input_fn = input_fn(
-      data_file=FLAGS.test_data,
-      num_epochs=1,
-      shuffle=False,
-      batch_size=FLAGS.batch_size)
-
   # Train and evaluate the model every `FLAGS.epochs_per_eval` epochs.
   for n in range(FLAGS.train_epochs // FLAGS.epochs_per_eval):
-    model.train(input_fn=train_input_fn)
-    results = model.evaluate(input_fn=eval_input_fn)
+    model.train(input_fn=lambda: input_fn(
+        FLAGS.train_data, FLAGS.epochs_per_eval, True, FLAGS.batch_size))
+
+    results = model.evaluate(input_fn=lambda: input_fn(
+        FLAGS.test_data, 1, False, FLAGS.batch_size))
 
     # Display evaluation metrics
     print('Results at epoch', (n + 1) * FLAGS.epochs_per_eval)
     print('-' * 30)
+
     for key in sorted(results):
       print('%s: %s' % (key, results[key]))
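The structural change in wide_deep.py is that input_fn is no longer a factory returning a closure: it builds the pipeline once and returns (features, labels) tensors, so every Estimator call site wraps it in a zero-argument lambda. A small self-contained sketch of that calling convention, assuming TF 1.x; the toy feature column, estimator, and data below are illustrative, not the census wide/deep model:

import tensorflow as tf

def toy_input_fn(num_epochs, shuffle, batch_size):
  # Returns (features, labels) tensors directly, like the new wide_deep.input_fn.
  features = {'x': tf.range(8, dtype=tf.float32)}
  labels = tf.constant([0, 1] * 4, dtype=tf.int32)
  dataset = tf.contrib.data.Dataset.from_tensor_slices((features, labels))
  if shuffle:
    dataset = dataset.shuffle(buffer_size=8)
  dataset = dataset.repeat(num_epochs)
  dataset = dataset.batch(batch_size)
  return dataset.make_one_shot_iterator().get_next()

# The Estimator expects a zero-argument callable, so the arguments are bound
# with a lambda at each call site, as main() now does above.
model = tf.estimator.LinearClassifier(
    feature_columns=[tf.feature_column.numeric_column('x')])
model.train(input_fn=lambda: toy_input_fn(num_epochs=2, shuffle=True, batch_size=4))
print(model.evaluate(input_fn=lambda: toy_input_fn(num_epochs=1, shuffle=False, batch_size=4)))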

official/wide_deep/wide_deep_test.py

Lines changed: 5 additions & 5 deletions
@@ -54,7 +54,7 @@ def setUp(self):
       temp_csv.write(TEST_INPUT)
 
   def test_input_fn(self):
-    features, labels = wide_deep.input_fn(self.input_csv, 1, False, 1)()
+    features, labels = wide_deep.input_fn(self.input_csv, 1, False, 1)
     with tf.Session() as sess:
       features, labels = sess.run((features, labels))
 
@@ -78,20 +78,20 @@ def build_and_test_estimator(self, model_type):
 
     # Train for 1 step to initialize model and evaluate initial loss
     model.train(
-        input_fn=wide_deep.input_fn(
+        input_fn=lambda: wide_deep.input_fn(
            TEST_CSV, num_epochs=1, shuffle=True, batch_size=1),
        steps=1)
    initial_results = model.evaluate(
-        input_fn=wide_deep.input_fn(
+        input_fn=lambda: wide_deep.input_fn(
            TEST_CSV, num_epochs=1, shuffle=False, batch_size=1))
 
    # Train for 40 steps at batch size 2 and evaluate final loss
    model.train(
-        input_fn=wide_deep.input_fn(
+        input_fn=lambda: wide_deep.input_fn(
            TEST_CSV, num_epochs=None, shuffle=True, batch_size=2),
        steps=40)
    final_results = model.evaluate(
-        input_fn=wide_deep.input_fn(
+        input_fn=lambda: wide_deep.input_fn(
            TEST_CSV, num_epochs=1, shuffle=False, batch_size=1))
 
    print('%s initial results:' % model_type, initial_results)
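Conversely, because input_fn now returns tensors rather than a closure, the unit test drops the trailing () and runs the result directly in a session, while the Estimator-based tests switch to the lambda wrapping shown above. A sketch of the direct-call style; the import line assumes the repo root is on PYTHONPATH and the CSV path is the repo's documented default, so treat both as illustrative:

import tensorflow as tf
from official.wide_deep import wide_deep

# input_fn builds the pipeline and hands back one batch of tensors.
features, labels = wide_deep.input_fn(
    '/tmp/census_data/adult.test', num_epochs=1, shuffle=False, batch_size=1)

with tf.Session() as sess:
  features_out, labels_out = sess.run((features, labels))
  print(sorted(features_out.keys()), labels_out)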
