Skip to content

Commit fc24556

Browse files
authored
Merge pull request #481 from yahoo/leewyang_mnist_ds
fix data_format parsing; remove auto_shard option
2 parents df468ff + 01569e7 commit fc24556

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

examples/mnist/keras/mnist_tf_ds.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,14 @@ def parse_tfos(example_proto):
3939
# tfos: /path/to/mnist/tfr/train/part-r-*
4040
image_pattern = ctx.absolute_path(args.images_labels)
4141

42-
options = tf.data.Options()
43-
options.experimental_distribute.auto_shard = False
44-
4542
ds = tf.data.Dataset.list_files(image_pattern)
46-
ds = ds.with_options(options)
4743
ds = ds.repeat(args.epochs).shuffle(BUFFER_SIZE)
4844
ds = ds.interleave(tf.data.TFRecordDataset)
49-
train_datasets_unbatched = ds.map(parse_tfos)
45+
46+
if args.data_format == 'tfds':
47+
train_datasets_unbatched = ds.map(parse_tfds)
48+
else: # 'tfos'
49+
train_datasets_unbatched = ds.map(parse_tfos)
5050

5151
def build_and_compile_cnn_model():
5252
model = tf.keras.Sequential([

0 commit comments

Comments
 (0)