
Commit 5dee51e

tomvdw (The TensorFlow Datasets Authors) authored and committed
Add annotation to file adapters whether tf.data is supported
PiperOrigin-RevId: 648368969
1 parent 839e49e commit 5dee51e

5 files changed: +164 −150 lines


tensorflow_datasets/core/__init__.py

Lines changed: 2 additions & 12 deletions
@@ -18,41 +18,29 @@
 # Allow to use `tfds.core.Path` in dataset implementation which seems more
 # natural than having to import a third party module.
 from etils.epath import Path
-
 from tensorflow_datasets.core import community
 from tensorflow_datasets.core.dataset_builder import BeamBasedBuilder
 from tensorflow_datasets.core.dataset_builder import BuilderConfig
 from tensorflow_datasets.core.dataset_builder import DatasetBuilder
 from tensorflow_datasets.core.dataset_builder import GeneratorBasedBuilder
-
 from tensorflow_datasets.core.dataset_info import BeamMetadataDict
 from tensorflow_datasets.core.dataset_info import DatasetIdentity
 from tensorflow_datasets.core.dataset_info import DatasetInfo
 from tensorflow_datasets.core.dataset_info import Metadata
 from tensorflow_datasets.core.dataset_info import MetadataDict
-
 from tensorflow_datasets.core.example_serializer import ExampleSerializer
-
 from tensorflow_datasets.core.file_adapters import FileFormat
-
 from tensorflow_datasets.core.lazy_imports_lib import lazy_imports
-
 from tensorflow_datasets.core.load import DatasetCollectionLoader
-
 from tensorflow_datasets.core.naming import ShardedFileTemplate
-
 from tensorflow_datasets.core.registered import DatasetNotFoundError
-
 from tensorflow_datasets.core.sequential_writer import SequentialWriter
-
 from tensorflow_datasets.core.split_builder import SplitGeneratorLegacy as SplitGenerator
-
 from tensorflow_datasets.core.splits import ReadInstruction
 from tensorflow_datasets.core.splits import Split
 from tensorflow_datasets.core.splits import SplitDict
 from tensorflow_datasets.core.splits import SplitInfo
 from tensorflow_datasets.core.splits import SubSplitInfo
-
 from tensorflow_datasets.core.utils import Experiment
 from tensorflow_datasets.core.utils import gcs_path
 from tensorflow_datasets.core.utils import lazy_imports_utils
@@ -61,6 +49,7 @@
 from tensorflow_datasets.core.utils.benchmark import BenchmarkResult
 from tensorflow_datasets.core.utils.file_utils import add_data_dir
 from tensorflow_datasets.core.utils.file_utils import as_path
+from tensorflow_datasets.core.writer import ExampleWriter


 def benchmark(*args, **kwargs):
@@ -81,6 +70,7 @@ def benchmark(*args, **kwargs):
     "DatasetInfo",
     "DatasetIdentity",
     "DatasetNotFoundError",
+    "ExampleWriter",
     "Experiment",
     "FileFormat",
     "GeneratorBasedBuilder",

tensorflow_datasets/core/file_adapters.py

Lines changed: 101 additions & 16 deletions
@@ -18,11 +18,12 @@
 from __future__ import annotations

 import abc
-from collections.abc import Iterator
+from collections.abc import Iterable, Iterator
 import enum
 import itertools
 import os
-from typing import Any, ClassVar, Dict, Iterable, List, Optional, Type, TypeVar, Union
+import re
+from typing import Any, ClassVar, Type, TypeVar

 from etils import epath
 from tensorflow_datasets.core.utils import file_utils
@@ -32,7 +33,7 @@
 from tensorflow_datasets.core.utils.lazy_imports_utils import pyarrow as pa
 from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf

-ExamplePositions = List[Any]
+ExamplePositions = list[Any]
 T = TypeVar('T')


@@ -61,7 +62,34 @@ def with_random_access(cls) -> set[FileFormat]:
     }

   @classmethod
-  def from_value(cls, file_format: Union[str, 'FileFormat']) -> 'FileFormat':
+  def with_tf_data(cls) -> set[FileFormat]:
+    """File formats with tf.data support."""
+    return {
+        file_format
+        for file_format, adapter in ADAPTER_FOR_FORMAT.items()
+        if adapter.SUPPORTS_TF_DATA
+    }
+
+  @classmethod
+  def with_suffix_before_shard_spec(cls) -> set[FileFormat]:
+    """File formats with suffix before shard spec."""
+    return {
+        file_format
+        for file_format, adapter in ADAPTER_FOR_FORMAT.items()
+        if adapter.SUFFIX_BEFORE_SHARD_SPEC
+    }
+
+  @classmethod
+  def with_suffix_after_shard_spec(cls) -> set[FileFormat]:
+    """File formats with suffix after shard spec."""
+    return {
+        file_format
+        for file_format, adapter in ADAPTER_FOR_FORMAT.items()
+        if not adapter.SUFFIX_BEFORE_SHARD_SPEC
+    }
+
+  @classmethod
+  def from_value(cls, file_format: str | FileFormat) -> FileFormat:
     try:
       return cls(file_format)
     except ValueError as e:
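For context, `with_tf_data()` mirrors the existing `with_random_access()` helper: it filters `ADAPTER_FOR_FORMAT` by the new `SUPPORTS_TF_DATA` flag. A small usage sketch (the chosen format is arbitrary):

from tensorflow_datasets.core import file_adapters

fmt = file_adapters.FileFormat.ARRAY_RECORD
if fmt in file_adapters.FileFormat.with_tf_data():
  print(f'{fmt} shards can be read through tf.data')
else:
  # ArrayRecord declares SUPPORTS_TF_DATA = False below, so this branch runs.
  print(f'{fmt} shards need a random-access reader instead')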
@@ -79,15 +107,22 @@ class FileAdapter(abc.ABC):
   """Interface for Adapter objects which read and write examples in a format."""

   FILE_SUFFIX: ClassVar[str]
+
+  # Whether the file format suffix should go before the shard spec.
+  # For example, `dataset-train.tfrecord-00000-of-00001` if `True`,
+  # otherwise `dataset-train-00000-of-00001.tfrecord`.
+  SUFFIX_BEFORE_SHARD_SPEC: ClassVar[bool] = True
+
   SUPPORTS_RANDOM_ACCESS: ClassVar[bool]
+  SUPPORTS_TF_DATA: ClassVar[bool]
   BUFFER_SIZE = 8 << 20  # 8 MiB per file.

   @classmethod
   @abc.abstractmethod
   def make_tf_data(
       cls,
       filename: epath.PathLike,
-      buffer_size: Optional[int] = None,
+      buffer_size: int | None = None,
   ) -> tf.data.Dataset:
     """Returns TensorFlow Dataset comprising given record file."""
     raise NotImplementedError()
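Concrete adapters now carry three capability flags plus the suffix-placement switch. A hypothetical subclass (not part of this commit) illustrating how the annotations compose:

class MyFormatAdapter(FileAdapter):
  """Hypothetical adapter showing the class-level annotations."""

  FILE_SUFFIX = 'myformat'
  SUPPORTS_RANDOM_ACCESS = False
  SUPPORTS_TF_DATA = True          # readable via make_tf_data
  SUFFIX_BEFORE_SHARD_SPEC = True  # dataset-train.myformat-00000-of-00001

  @classmethod
  def make_tf_data(cls, filename, buffer_size=None):
    raise NotImplementedError()

  @classmethod
  def write_examples(cls, path, iterator):
    raise NotImplementedError()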
@@ -98,7 +133,7 @@ def write_examples(
     cls,
     path: epath.PathLike,
     iterator: Iterable[type_utils.KeySerializedExample],
-  ) -> Optional[ExamplePositions]:
+  ) -> ExamplePositions | None:
     """Write examples from given iterator in given path.

     Args:
@@ -117,12 +152,13 @@ class TfRecordFileAdapter(FileAdapter):

   FILE_SUFFIX = 'tfrecord'
   SUPPORTS_RANDOM_ACCESS = False
+  SUPPORTS_TF_DATA = True

   @classmethod
   def make_tf_data(
       cls,
       filename: epath.PathLike,
-      buffer_size: Optional[int] = None,
+      buffer_size: int | None = None,
   ) -> tf.data.Dataset:
     """Returns TensorFlow Dataset comprising given record file."""
     buffer_size = buffer_size or cls.BUFFER_SIZE
@@ -133,7 +169,7 @@ def write_examples(
     cls,
     path: epath.PathLike,
     iterator: Iterable[type_utils.KeySerializedExample],
-  ) -> Optional[ExamplePositions]:
+  ) -> ExamplePositions | None:
     """Write examples from given iterator in given path.

     Args:
@@ -154,12 +190,13 @@ class RiegeliFileAdapter(FileAdapter):

   FILE_SUFFIX = 'riegeli'
   SUPPORTS_RANDOM_ACCESS = False
+  SUPPORTS_TF_DATA = True

   @classmethod
   def make_tf_data(
       cls,
       filename: epath.PathLike,
-      buffer_size: Optional[int] = None,
+      buffer_size: int | None = None,
   ) -> tf.data.Dataset:
     buffer_size = buffer_size or cls.BUFFER_SIZE
     from riegeli.tensorflow.ops import riegeli_dataset_ops as riegeli_tf  # pylint: disable=g-import-not-at-top  # pytype: disable=import-error
@@ -171,7 +208,7 @@ def write_examples(
     cls,
     path: epath.PathLike,
     iterator: Iterable[type_utils.KeySerializedExample],
-  ) -> Optional[ExamplePositions]:
+  ) -> ExamplePositions | None:
     """Write examples from given iterator in given path.

     Args:
@@ -197,12 +234,13 @@ class ArrayRecordFileAdapter(FileAdapter):

   FILE_SUFFIX = 'array_record'
   SUPPORTS_RANDOM_ACCESS = True
+  SUPPORTS_TF_DATA = False

   @classmethod
   def make_tf_data(
       cls,
       filename: epath.PathLike,
-      buffer_size: Optional[int] = None,
+      buffer_size: int | None = None,
   ) -> tf.data.Dataset:
     """Returns TensorFlow Dataset comprising given array record file."""
     raise NotImplementedError(
@@ -215,7 +253,7 @@ def write_examples(
     cls,
     path: epath.PathLike,
     iterator: Iterable[type_utils.KeySerializedExample],
-  ) -> Optional[ExamplePositions]:
+  ) -> ExamplePositions | None:
     """Write examples from given iterator in given path.

     Args:
@@ -249,6 +287,7 @@ class ParquetFileAdapter(FileAdapter):

   FILE_SUFFIX = 'parquet'
   SUPPORTS_RANDOM_ACCESS = True
+  SUPPORTS_TF_DATA = True
   _PARQUET_FIELD = 'data'
   _BATCH_SIZE = 100

@@ -319,11 +358,11 @@ def _to_bytes(key: type_utils.Key) -> bytes:


 # Create a mapping from FileFormat -> FileAdapter.
-ADAPTER_FOR_FORMAT: Dict[FileFormat, Type[FileAdapter]] = {
-    FileFormat.RIEGELI: RiegeliFileAdapter,
-    FileFormat.TFRECORD: TfRecordFileAdapter,
+ADAPTER_FOR_FORMAT: dict[FileFormat, Type[FileAdapter]] = {
     FileFormat.ARRAY_RECORD: ArrayRecordFileAdapter,
     FileFormat.PARQUET: ParquetFileAdapter,
+    FileFormat.RIEGELI: RiegeliFileAdapter,
+    FileFormat.TFRECORD: TfRecordFileAdapter,
 }

 _FILE_SUFFIX_TO_FORMAT = {
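Reordering the mapping alphabetically is cosmetic; lookup behavior is unchanged. For reference, resolving an adapter from a format string still goes through `from_value` plus a dict lookup:

from tensorflow_datasets.core import file_adapters

fmt = file_adapters.FileFormat.from_value('tfrecord')
adapter = file_adapters.ADAPTER_FOR_FORMAT[fmt]
assert adapter.FILE_SUFFIX == 'tfrecord'
assert adapter.SUPPORTS_TF_DATA  # annotated True above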
@@ -350,7 +389,7 @@ def is_example_file(filename: str) -> bool:
   )


-def _batched(iterator: Iterator[T] | Iterable[T], n: int) -> Iterator[List[T]]:
+def _batched(iterator: Iterator[T] | Iterable[T], n: int) -> Iterator[list[T]]:
   """Batches the result of an iterator into lists of length n.

   This function is built-in the standard library from 3.12 (source:
@@ -371,3 +410,49 @@ def _batched(iterator: Iterator[T] | Iterable[T], n: int) -> Iterator[list[T]]:
       return
     yield batch
     i += n
+
+
+def convert_path_to_file_format(
+    path: epath.PathLike, file_format: FileFormat
+) -> epath.Path:
+  """Returns the path to a specific shard for a different file format.
+
+  TFDS can store the file format in the filename as a suffix or as an infix. For
+  example:
+
+  - `dataset-train.<FILE_FORMAT>-00000-of-00001`, a so-called infix format
+    because the file format comes before the shard spec.
+  - `dataset-train-00000-of-00001.<FILE_FORMAT>`, a so-called suffix format
+    because the file format comes after the shard spec.
+
+  Args:
+    path: The path of a specific shard to convert. Can be the path for
+      different file formats.
+    file_format: The file format to which the shard path should be converted.
+  """
+  path = epath.Path(path)
+  file_name: str = path.name
+  if file_format.file_suffix in file_name:
+    # Already has the right file format in the file name.
+    return path
+
+  infix_formats = FileFormat.with_suffix_before_shard_spec()
+  suffix_formats = FileFormat.with_suffix_after_shard_spec()
+
+  # Remove any existing file format from the file name.
+  infix_format_concat = '|'.join(f.file_suffix for f in infix_formats)
+  file_name = re.sub(rf'(\.({infix_format_concat}))', '', file_name)
+
+  suffix_formats_concat = '|'.join(f.file_suffix for f in suffix_formats)
+  file_name = re.sub(rf'(\.({suffix_formats_concat}))$', '', file_name)
+
+  # Add back the proper file format.
+  if file_format in suffix_formats:
+    file_name = f'{file_name}.{file_format.file_suffix}'
+  else:
+    file_name = re.sub(
+        r'-(\d+)-of-(\d+)',
+        rf'.{file_format.file_suffix}-\1-of-\2',
+        file_name,
+    )
+  return path.parent / file_name
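A quick sketch of the new conversion helper in action (the dataset path is hypothetical). Since every adapter in this commit keeps the default `SUFFIX_BEFORE_SHARD_SPEC = True`, converting an infix-style tfrecord shard to riegeli rewrites the infix in place:

from etils import epath
from tensorflow_datasets.core import file_adapters

src = epath.Path('/data/mnist/dataset-train.tfrecord-00000-of-00010')
dst = file_adapters.convert_path_to_file_format(
    src, file_adapters.FileFormat.RIEGELI
)
print(dst)  # /data/mnist/dataset-train.riegeli-00000-of-00010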
