Skip to content

Commit 2f673d6

Browse files
authored
Merge pull request #501 from qsbao/list_files_no_shuffle
list_files in inference examples should be deterministically
2 parents 3a0d837 + 2eed5e5 commit 2f673d6

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

examples/mnist/estimator/mnist_inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def parse_tfr(example_proto):
4747
return (image, label)
4848

4949
# define a new tf.data.Dataset (for inferencing)
50-
ds = tf.data.Dataset.list_files("{}/part-*".format(args.images_labels))
50+
ds = tf.data.Dataset.list_files("{}/part-*".format(args.images_labels), shuffle=False)
5151
ds = ds.shard(num_workers, worker_num)
5252
ds = ds.interleave(tf.data.TFRecordDataset)
5353
ds = ds.map(parse_tfr)

examples/mnist/keras/mnist_inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def parse_tfr(example_proto):
3838
return (image, label)
3939

4040
# define a new tf.data.Dataset (for inferencing)
41-
ds = tf.data.Dataset.list_files("{}/part-*".format(args.images_labels))
41+
ds = tf.data.Dataset.list_files("{}/part-*".format(args.images_labels), shuffle=False)
4242
ds = ds.shard(ctx.num_workers, ctx.worker_num)
4343
ds = ds.interleave(tf.data.TFRecordDataset)
4444
ds = ds.map(parse_tfr)

0 commit comments

Comments
 (0)