
Commit b9f906d

tomvdw authored and The TensorFlow Datasets Authors committed
Add example writer that writes to multiple file formats
This can be used when you want to generate the dataset in multiple formats, e.g. tfrecord, riegeli, and parquet.

PiperOrigin-RevId: 646800706
1 parent 82c0cc8 commit b9f906d
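
As a rough orientation for the diff below, here is a minimal usage sketch of the new MultiOutputExampleWriter. Only the MultiOutputExampleWriter constructor and the write() call mirror the code and tests in this commit; constructing the per-format ExampleWriter instances with a file_format= keyword is an assumption, not something this commit shows.

# Hedged sketch: MultiOutputExampleWriter and write() come from this commit;
# building ExampleWriter(file_format=...) is an assumption for illustration.
from tensorflow_datasets.core import file_adapters
from tensorflow_datasets.core import writer as writer_lib

# One single-format writer per desired output format (assumed constructor).
writers = [
    writer_lib.ExampleWriter(file_format=file_adapters.FileFormat.TFRECORD),
    writer_lib.ExampleWriter(file_format=file_adapters.FileFormat.RIEGELI),
]
multi_writer = writer_lib.MultiOutputExampleWriter(writers)

# The format infix in the shard filename is rewritten per writer, so a single
# call produces both dataset-train.tfrecord-... and dataset-train.riegeli-...
multi_writer.write(
    path='/tmp/dataset-train.tfrecord-00000-of-00001',
    examples=[('key1', b'value1'), ('key2', b'value2')],
)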

File tree

2 files changed: +122 -13 lines changed


tensorflow_datasets/core/writer.py

Lines changed: 75 additions & 13 deletions

@@ -23,23 +23,28 @@
 import itertools
 import json
 import os
+import re
 from typing import Any

-from absl import logging
-from etils import epath
-from tensorflow_datasets.core import example_parser
-from tensorflow_datasets.core import example_serializer
-from tensorflow_datasets.core import file_adapters
-from tensorflow_datasets.core import hashing
-from tensorflow_datasets.core import lazy_imports_lib
-from tensorflow_datasets.core import naming
-from tensorflow_datasets.core import shuffle
-from tensorflow_datasets.core import utils
-from tensorflow_datasets.core.utils import file_utils
-from tensorflow_datasets.core.utils import shard_utils
-from tensorflow_datasets.core.utils import type_utils
+from etils import epy
 from tensorflow_datasets.core.utils.lazy_imports_utils import apache_beam as beam

+with epy.lazy_imports():
+  # pylint: disable=g-import-not-at-top
+  from absl import logging
+  from etils import epath
+  from tensorflow_datasets.core import example_parser
+  from tensorflow_datasets.core import example_serializer
+  from tensorflow_datasets.core import file_adapters
+  from tensorflow_datasets.core import hashing
+  from tensorflow_datasets.core import naming
+  from tensorflow_datasets.core import shuffle
+  from tensorflow_datasets.core import utils
+  from tensorflow_datasets.core.utils import file_utils
+  from tensorflow_datasets.core.utils import shard_utils
+  from tensorflow_datasets.core.utils import type_utils
+
+  # pylint: enable=g-import-not-at-top

 # TODO(tfds): Should be `TreeDict[FeatureValue]`
 Example = Any
@@ -186,6 +191,63 @@ def write(
     return adapter.write_examples(path, examples)


+def _convert_path_to_file_format(
+    path: epath.PathLike, file_format: file_adapters.FileFormat
+) -> epath.Path:
+  """Returns the path to a specific shard in a different file format.
+
+  TFDS typically stores the file format in the filename. For example,
+  `dataset-train.tfrecord-00000-of-00001` is a TFRecord file and
+  `dataset-train-00000-of-00001.bagz` is a Bagz file. This function converts
+  the filename to the desired file format.
+
+  Args:
+    path: The path of a specific shard to convert. Can be the path for
+      different file formats.
+    file_format: The file format to which the shard path should be converted.
+  """
+  path = epath.Path(path)
+
+  infix_formats = [
+      f.value
+      for f in file_adapters.FileFormat
+  ]
+  infix_format_concat = "|".join(infix_formats)
+
+  file_name = re.sub(
+      rf"\.({infix_format_concat})", f".{file_format.value}", path.name
+  )
+  return path.parent / file_name
+
+
+class MultiOutputExampleWriter(ExampleWriter):
+  """Example writer that can write multiple outputs."""
+
+  def __init__(self, writers: Sequence[ExampleWriter]):  # pylint: disable=super-init-not-called
+    self._writers = writers
+
+  def write(
+      self,
+      path: epath.PathLike,
+      examples: Iterable[type_utils.KeySerializedExample],
+  ) -> file_adapters.ExamplePositions | None:
+    """Writes examples to multiple outputs."""
+    write_fns = []
+    for writer, my_iter in zip(
+        self._writers, itertools.tee(examples, len(self._writers))
+    ):
+      if file_format := writer.file_format:
+        shard_path = os.fspath(
+            _convert_path_to_file_format(path=path, file_format=file_format)
+        )
+        write_fns.append(functools.partial(writer.write, shard_path, my_iter))
+      else:
+        write_fns.append(functools.partial(writer.write, path, my_iter))
+
+    for write_fn in write_fns:
+      write_fn()
+
+
 class Writer:
   """Shuffles and writes Examples to sharded files.
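
To make the shard-path rewriting above concrete, here is a small self-contained sketch of the same re.sub substitution that _convert_path_to_file_format applies to a filename. The list of format infixes is hard-coded here as an assumption for illustration; the real helper derives it from file_adapters.FileFormat.

import re

# Hard-coded subset of format infixes for illustration only; the real code
# builds this list from file_adapters.FileFormat.
infix_formats = ["tfrecord", "riegeli", "parquet"]
pattern = rf"\.({'|'.join(infix_formats)})"

# Replace the format infix embedded in the shard filename with the target one.
file_name = "dataset-train.tfrecord-00000-of-00001"
converted = re.sub(pattern, ".riegeli", file_name)
print(converted)  # dataset-train.riegeli-00000-of-00001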

tensorflow_datasets/core/writer_test.py

Lines changed: 47 additions & 0 deletions

@@ -20,6 +20,7 @@
 from typing import Optional
 from unittest import mock

+from absl.testing import parameterized
 from etils import epath
 import tensorflow as tf
 from tensorflow_datasets import testing
@@ -591,5 +592,51 @@ def write(self, path, examples) -> file_adapters.ExamplePositions | None:
     epath.Path(path).touch()


+class ExampleWriterTest(parameterized.TestCase):
+
+  @parameterized.parameters(
+      dict(
+          path='/tmp/dataset-train.tfrecord-00000-of-00001',
+          file_format=file_adapters.FileFormat.TFRECORD,
+          expected_path='/tmp/dataset-train.tfrecord-00000-of-00001',
+      ),
+      dict(
+          path='/tmp/dataset-train.riegeli-00000-of-00001',
+          file_format=file_adapters.FileFormat.TFRECORD,
+          expected_path='/tmp/dataset-train.tfrecord-00000-of-00001',
+      ),
+      dict(
+          path='/tmp/dataset-train.tfrecord-00000-of-00001',
+          file_format=file_adapters.FileFormat.RIEGELI,
+          expected_path='/tmp/dataset-train.riegeli-00000-of-00001',
+      ),
+  )
+  def test_convert_path_to_file_format(self, path, file_format, expected_path):
+    converted_path = writer_lib._convert_path_to_file_format(path, file_format)
+    self.assertEqual(os.fspath(converted_path), expected_path)
+
+  def test_multi_output_example_writer(self):
+    tfrecord_writer = mock.create_autospec(writer_lib.ExampleWriter)
+    tfrecord_writer.file_format = file_adapters.FileFormat.TFRECORD
+
+    riegeli_writer = mock.create_autospec(writer_lib.ExampleWriter)
+    riegeli_writer.file_format = file_adapters.FileFormat.RIEGELI
+
+    path = '/tmp/dataset-train.tfrecord-00000-of-00001'
+    iterator = [
+        ('key1', b'value1'),
+        ('key2', b'value2'),
+    ]
+    writer = writer_lib.MultiOutputExampleWriter([
+        tfrecord_writer,
+        riegeli_writer,
+    ])
+    writer.write(path=path, examples=iterator)
+    tfrecord_writer.write.assert_called_once_with(path, mock.ANY)
+    riegeli_writer.write.assert_called_once_with(
+        '/tmp/dataset-train.riegeli-00000-of-00001', mock.ANY
+    )
+
+
 if __name__ == '__main__':
   testing.test_main()
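
For completeness, the fan-out in MultiOutputExampleWriter.write that the test above exercises rests on itertools.tee, which gives each underlying writer its own copy of the single examples iterable. Below is a standalone sketch of that pattern under the assumption that sink_a and sink_b are made-up stand-ins for per-format writers.

import functools
import itertools

# Hypothetical stand-ins for per-format writers.
def sink_a(items):
  return ["a:" + key for key, _ in items]

def sink_b(items):
  return ["b:" + key for key, _ in items]

examples = iter([("key1", b"v1"), ("key2", b"v2")])
sinks = [sink_a, sink_b]

# tee() duplicates the iterator so every sink sees every example, even though
# the underlying source can only be consumed once overall.
write_fns = [
    functools.partial(sink, copy)
    for sink, copy in zip(sinks, itertools.tee(examples, len(sinks)))
]
print([fn() for fn in write_fns])  # [['a:key1', 'a:key2'], ['b:key1', 'b:key2']]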
