|
23 | 23 | import itertools
|
24 | 24 | import json
|
25 | 25 | import os
|
| 26 | +import re |
26 | 27 | from typing import Any
|
27 | 28 |
|
28 |
| -from absl import logging |
29 |
| -from etils import epath |
30 |
| -from tensorflow_datasets.core import example_parser |
31 |
| -from tensorflow_datasets.core import example_serializer |
32 |
| -from tensorflow_datasets.core import file_adapters |
33 |
| -from tensorflow_datasets.core import hashing |
34 |
| -from tensorflow_datasets.core import lazy_imports_lib |
35 |
| -from tensorflow_datasets.core import naming |
36 |
| -from tensorflow_datasets.core import shuffle |
37 |
| -from tensorflow_datasets.core import utils |
38 |
| -from tensorflow_datasets.core.utils import file_utils |
39 |
| -from tensorflow_datasets.core.utils import shard_utils |
40 |
| -from tensorflow_datasets.core.utils import type_utils |
| 29 | +from etils import epy |
41 | 30 | from tensorflow_datasets.core.utils.lazy_imports_utils import apache_beam as beam
|
42 | 31 |
|
| 32 | +with epy.lazy_imports(): |
| 33 | + # pylint: disable=g-import-not-at-top |
| 34 | + from absl import logging |
| 35 | + from etils import epath |
| 36 | + from tensorflow_datasets.core import example_parser |
| 37 | + from tensorflow_datasets.core import example_serializer |
| 38 | + from tensorflow_datasets.core import file_adapters |
| 39 | + from tensorflow_datasets.core import hashing |
| 40 | + from tensorflow_datasets.core import naming |
| 41 | + from tensorflow_datasets.core import shuffle |
| 42 | + from tensorflow_datasets.core import utils |
| 43 | + from tensorflow_datasets.core.utils import file_utils |
| 44 | + from tensorflow_datasets.core.utils import shard_utils |
| 45 | + from tensorflow_datasets.core.utils import type_utils |
| 46 | + |
| 47 | + # pylint: enable=g-import-not-at-top |
43 | 48 |
|
44 | 49 | # TODO(tfds): Should be `TreeDict[FeatureValue]`
|
45 | 50 | Example = Any
|
@@ -186,6 +191,63 @@ def write(
|
186 | 191 | return adapter.write_examples(path, examples)
|
187 | 192 |
|
188 | 193 |
|
def _convert_path_to_file_format(
    path: epath.PathLike, file_format: file_adapters.FileFormat
) -> epath.Path:
  """Returns the path to a specific shard in a different file format.

  TFDS typically stores the file format in the filename. For example,
  `dataset-train.tfrecord-00000-of-00001` is a TFRecord file and
  `dataset-train-00000-of-00001.bagz` is a Bagz file. This function converts
  the filename to the desired file format.

  Args:
    path: The path of a specific shard to convert. Can be the path for
      different file formats.
    file_format: The file format to which the shard path should be converted.

  Returns:
    `path` with the file-format infix replaced by `file_format.value`. If no
    known format infix is present in the filename, the path is returned
    unchanged.
  """
  path = epath.Path(path)

  # Sort longer format names first: `re` alternation matches branches
  # left-to-right, so a format whose name is a prefix of another (e.g.
  # `tfrecord` vs. a hypothetical `tfrecord_gz`) would otherwise shadow the
  # longer match. Escape the values in case a format name ever contains a
  # regex metacharacter.
  infix_formats = sorted(
      (f.value for f in file_adapters.FileFormat), key=len, reverse=True
  )
  infix_format_concat = "|".join(re.escape(infix) for infix in infix_formats)

  file_name = re.sub(
      rf"\.({infix_format_concat})", f".{file_format.value}", path.name
  )
  return path.parent / file_name
| 221 | + |
| 222 | + |
class MultiOutputExampleWriter(ExampleWriter):
  """Example writer that fans each example stream out to several writers."""

  def __init__(self, writers: Sequence[ExampleWriter]):  # pylint: disable=super-init-not-called
    self._writers = writers

  def write(
      self,
      path: epath.PathLike,
      examples: Iterable[type_utils.KeySerializedExample],
  ) -> file_adapters.ExamplePositions | None:
    """Writes examples to multiple outputs."""
    # Duplicate the example stream once per writer so that every writer can
    # consume the examples independently.
    example_copies = itertools.tee(examples, len(self._writers))

    # Resolve each writer's target path first, then perform all writes.
    pending_writes = []
    for writer, examples_for_writer in zip(self._writers, example_copies):
      target = path
      if file_format := writer.file_format:
        # Rewrite the shard path so its file-format infix matches this
        # writer's format.
        target = os.fspath(
            _convert_path_to_file_format(path=path, file_format=file_format)
        )
      pending_writes.append(
          functools.partial(writer.write, target, examples_for_writer)
      )

    for pending_write in pending_writes:
      pending_write()
| 250 | + |
189 | 251 | class Writer:
|
190 | 252 | """Shuffles and writes Examples to sharded files.
|
191 | 253 |
|
|
0 commit comments