Skip to content

Commit e7027fe

Browse files
Internal change
PiperOrigin-RevId: 481251614
1 parent 6f521d9 commit e7027fe

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

official/core/file_writers.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,14 @@ def write_small_dataset(examples: Sequence[Union[tf.train.Example,
3333
examples: List of tf.train.Example or tf.train.SequenceExample.
3434
output_path: Output path for the dataset.
3535
file_type: A string indicating the file format, could be: 'tfrecord',
36-
'tfrecord_compressed', 'riegeli'.
36+
'tfrecords', 'tfrecord_compressed', 'tfrecords_gzip', 'riegeli'. The
37+
string is case insensitive.
3738
"""
3839
file_type = file_type.lower()
3940

40-
if file_type == 'tfrecord':
41+
if file_type == 'tfrecord' or file_type == 'tfrecords':
4142
_write_tfrecord(examples, output_path)
42-
elif file_type == 'tfrecord_compressed':
43+
elif file_type == 'tfrecord_compressed' or file_type == 'tfrecords_gzip':
4344
_write_tfrecord(examples, output_path,
4445
tf.io.TFRecordOptions(compression_type='GZIP'))
4546
elif file_type == 'riegeli':

official/core/file_writers_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,9 @@ def setUp(self):
3030
example_builder.add_bytes_feature('foo', 'Hello World!')
3131
self._example = example_builder.example
3232

33-
@parameterized.parameters('tfrecord', 'TFRecord', 'tfrecord_compressed',
34-
'TFRecord_Compressed')
33+
@parameterized.parameters('tfrecord', 'TFRecord', 'tfrecords',
34+
'tfrecord_compressed', 'TFRecord_Compressed',
35+
'tfrecords_gzip')
3536
def test_write_small_dataset_success(self, file_type):
3637
temp_dir = self.create_tempdir()
3738
temp_dataset_file = os.path.join(temp_dir.full_path, 'train')

0 commit comments

Comments
 (0)