18
18
from __future__ import annotations
19
19
20
20
import abc
21
- from collections .abc import Iterator
21
+ from collections .abc import Iterable , Iterator
22
22
import enum
23
23
import itertools
24
24
import os
25
- from typing import Any , ClassVar , Dict , Iterable , List , Optional , Type , TypeVar , Union
25
+ import re
26
+ from typing import Any , ClassVar , Type , TypeVar
26
27
27
28
from etils import epath
28
29
from tensorflow_datasets .core .utils import file_utils
32
33
from tensorflow_datasets .core .utils .lazy_imports_utils import pyarrow as pa
33
34
from tensorflow_datasets .core .utils .lazy_imports_utils import tensorflow as tf
34
35
35
- ExamplePositions = List [Any ]
36
+ ExamplePositions = list [Any ]
36
37
T = TypeVar ('T' )
37
38
38
39
@@ -61,7 +62,34 @@ def with_random_access(cls) -> set[FileFormat]:
61
62
}
62
63
63
64
@classmethod
64
- def from_value (cls , file_format : Union [str , 'FileFormat' ]) -> 'FileFormat' :
65
+ def with_tf_data (cls ) -> set [FileFormat ]:
66
+ """File formats with tf.data support."""
67
+ return {
68
+ file_format
69
+ for file_format , adapter in ADAPTER_FOR_FORMAT .items ()
70
+ if adapter .SUPPORTS_TF_DATA
71
+ }
72
+
73
+ @classmethod
74
+ def with_suffix_before_shard_spec (cls ) -> set [FileFormat ]:
75
+ """File formats with suffix before shard spec."""
76
+ return {
77
+ file_format
78
+ for file_format , adapter in ADAPTER_FOR_FORMAT .items ()
79
+ if adapter .SUFFIX_BEFORE_SHARD_SPEC
80
+ }
81
+
82
+ @classmethod
83
+ def with_suffix_after_shard_spec (cls ) -> set [FileFormat ]:
84
+ """File formats with suffix after shard spec."""
85
+ return {
86
+ file_format
87
+ for file_format , adapter in ADAPTER_FOR_FORMAT .items ()
88
+ if not adapter .SUFFIX_BEFORE_SHARD_SPEC
89
+ }
90
+
91
+ @classmethod
92
+ def from_value (cls , file_format : str | FileFormat ) -> FileFormat :
65
93
try :
66
94
return cls (file_format )
67
95
except ValueError as e :
@@ -79,15 +107,22 @@ class FileAdapter(abc.ABC):
79
107
"""Interface for Adapter objects which read and write examples in a format."""
80
108
81
109
  # File extension used by this format (without the leading dot),
  # e.g. 'tfrecord' or 'parquet'.
  FILE_SUFFIX: ClassVar[str]

  # Whether the file format suffix should go before the shard spec.
  # For example, `dataset-train.tfrecord-00000-of-00001` if `True`,
  # otherwise `dataset-train-00000-of-00001.tfrecord`.
  SUFFIX_BEFORE_SHARD_SPEC: ClassVar[bool] = True

  # Whether examples can be looked up without scanning the whole file.
  SUPPORTS_RANDOM_ACCESS: ClassVar[bool]
  # Whether the format can be read through tf.data (see `make_tf_data`).
  SUPPORTS_TF_DATA: ClassVar[bool]

  BUFFER_SIZE = 8 << 20  # 8 MiB per file.
84
119
85
120
  @classmethod
  @abc.abstractmethod
  def make_tf_data(
      cls,
      filename: epath.PathLike,
      buffer_size: int | None = None,
  ) -> tf.data.Dataset:
    """Returns TensorFlow Dataset comprising given record file.

    Args:
      filename: Path of the record file to read.
      buffer_size: Read buffer size in bytes. Implementations fall back to
        `cls.BUFFER_SIZE` when `None`.

    Returns:
      A `tf.data.Dataset` over the serialized examples stored in the file.
    """
    raise NotImplementedError()
@@ -98,7 +133,7 @@ def write_examples(
98
133
cls ,
99
134
path : epath .PathLike ,
100
135
iterator : Iterable [type_utils .KeySerializedExample ],
101
- ) -> Optional [ ExamplePositions ] :
136
+ ) -> ExamplePositions | None :
102
137
"""Write examples from given iterator in given path.
103
138
104
139
Args:
@@ -117,12 +152,13 @@ class TfRecordFileAdapter(FileAdapter):
117
152
118
153
FILE_SUFFIX = 'tfrecord'
119
154
SUPPORTS_RANDOM_ACCESS = False
155
+ SUPPORTS_TF_DATA = True
120
156
121
157
@classmethod
122
158
def make_tf_data (
123
159
cls ,
124
160
filename : epath .PathLike ,
125
- buffer_size : Optional [ int ] = None ,
161
+ buffer_size : int | None = None ,
126
162
) -> tf .data .Dataset :
127
163
"""Returns TensorFlow Dataset comprising given record file."""
128
164
buffer_size = buffer_size or cls .BUFFER_SIZE
@@ -133,7 +169,7 @@ def write_examples(
133
169
cls ,
134
170
path : epath .PathLike ,
135
171
iterator : Iterable [type_utils .KeySerializedExample ],
136
- ) -> Optional [ ExamplePositions ] :
172
+ ) -> ExamplePositions | None :
137
173
"""Write examples from given iterator in given path.
138
174
139
175
Args:
@@ -154,12 +190,13 @@ class RiegeliFileAdapter(FileAdapter):
154
190
155
191
FILE_SUFFIX = 'riegeli'
156
192
SUPPORTS_RANDOM_ACCESS = False
193
+ SUPPORTS_TF_DATA = True
157
194
158
195
@classmethod
159
196
def make_tf_data (
160
197
cls ,
161
198
filename : epath .PathLike ,
162
- buffer_size : Optional [ int ] = None ,
199
+ buffer_size : int | None = None ,
163
200
) -> tf .data .Dataset :
164
201
buffer_size = buffer_size or cls .BUFFER_SIZE
165
202
from riegeli .tensorflow .ops import riegeli_dataset_ops as riegeli_tf # pylint: disable=g-import-not-at-top # pytype: disable=import-error
@@ -171,7 +208,7 @@ def write_examples(
171
208
cls ,
172
209
path : epath .PathLike ,
173
210
iterator : Iterable [type_utils .KeySerializedExample ],
174
- ) -> Optional [ ExamplePositions ] :
211
+ ) -> ExamplePositions | None :
175
212
"""Write examples from given iterator in given path.
176
213
177
214
Args:
@@ -197,12 +234,13 @@ class ArrayRecordFileAdapter(FileAdapter):
197
234
198
235
FILE_SUFFIX = 'array_record'
199
236
SUPPORTS_RANDOM_ACCESS = True
237
+ SUPPORTS_TF_DATA = False
200
238
201
239
@classmethod
202
240
def make_tf_data (
203
241
cls ,
204
242
filename : epath .PathLike ,
205
- buffer_size : Optional [ int ] = None ,
243
+ buffer_size : int | None = None ,
206
244
) -> tf .data .Dataset :
207
245
"""Returns TensorFlow Dataset comprising given array record file."""
208
246
raise NotImplementedError (
@@ -215,7 +253,7 @@ def write_examples(
215
253
cls ,
216
254
path : epath .PathLike ,
217
255
iterator : Iterable [type_utils .KeySerializedExample ],
218
- ) -> Optional [ ExamplePositions ] :
256
+ ) -> ExamplePositions | None :
219
257
"""Write examples from given iterator in given path.
220
258
221
259
Args:
@@ -249,6 +287,7 @@ class ParquetFileAdapter(FileAdapter):
249
287
250
288
FILE_SUFFIX = 'parquet'
251
289
SUPPORTS_RANDOM_ACCESS = True
290
+ SUPPORTS_TF_DATA = True
252
291
_PARQUET_FIELD = 'data'
253
292
_BATCH_SIZE = 100
254
293
@@ -319,11 +358,11 @@ def _to_bytes(key: type_utils.Key) -> bytes:
319
358
320
359
321
360
# Create a mapping from FileFormat -> FileAdapter.
# Note: `type[FileAdapter]` (builtin generic) is used instead of the
# deprecated `typing.Type` for consistency with the `dict[...]`/`list[...]`
# annotations used elsewhere in this module; safe because the module uses
# `from __future__ import annotations`.
ADAPTER_FOR_FORMAT: dict[FileFormat, type[FileAdapter]] = {
    FileFormat.ARRAY_RECORD: ArrayRecordFileAdapter,
    FileFormat.PARQUET: ParquetFileAdapter,
    FileFormat.RIEGELI: RiegeliFileAdapter,
    FileFormat.TFRECORD: TfRecordFileAdapter,
}
328
367
329
368
_FILE_SUFFIX_TO_FORMAT = {
@@ -350,7 +389,7 @@ def is_example_file(filename: str) -> bool:
350
389
)
351
390
352
391
353
- def _batched (iterator : Iterator [T ] | Iterable [T ], n : int ) -> Iterator [List [T ]]:
392
+ def _batched (iterator : Iterator [T ] | Iterable [T ], n : int ) -> Iterator [list [T ]]:
354
393
"""Batches the result of an iterator into lists of length n.
355
394
356
395
This function is built-in the standard library from 3.12 (source:
@@ -371,3 +410,49 @@ def _batched(iterator: Iterator[T] | Iterable[T], n: int) -> Iterator[List[T]]:
371
410
return
372
411
yield batch
373
412
i += n
413
+
414
+
415
def convert_path_to_file_format(
    path: epath.PathLike, file_format: FileFormat
) -> epath.Path:
  """Returns the path to a specific shard for a different file format.

  TFDS can store the file format in the filename as a suffix or as an infix.
  For example:

  - `dataset-train.<FILE_FORMAT>-00000-of-00001`, a so-called infix format
    because the file format comes before the shard spec.
  - `dataset-train-00000-of-00001.<FILE_FORMAT>`, a so-called suffix format
    because the file format comes after the shard spec.

  Args:
    path: The path of a specific shard to convert. Can be the path for
      different file formats.
    file_format: The file format to which the shard path should be converted.
  """
  shard_path = epath.Path(path)
  name = shard_path.name
  # NOTE(review): substring containment check — a dataset name that happens
  # to contain the target suffix also short-circuits here.
  if file_format.file_suffix in name:
    # Already has the right file format in the file name.
    return shard_path

  infix_formats = FileFormat.with_suffix_before_shard_spec()
  suffix_formats = FileFormat.with_suffix_after_shard_spec()

  # Strip any existing file-format marker: infix occurrences first, then a
  # trailing suffix.
  infix_alternatives = '|'.join(f.file_suffix for f in infix_formats)
  name = re.sub(rf'(\.({infix_alternatives}))', '', name)

  suffix_alternatives = '|'.join(f.file_suffix for f in suffix_formats)
  name = re.sub(rf'(\.({suffix_alternatives}))$', '', name)

  # Re-attach the requested format in its conventional position.
  if file_format in suffix_formats:
    name = f'{name}.{file_format.file_suffix}'
  else:
    # Infix format: insert the suffix just before the `-NNNNN-of-NNNNN` spec.
    name = re.sub(
        r'-(\d+)-of-(\d+)',
        rf'.{file_format.file_suffix}-\1-of-\2',
        name,
    )
  return shard_path.parent / name
0 commit comments