Skip to content

Commit c815f93

Browse files
tomvdwThe TensorFlow Datasets Authors
authored andcommitted
Add helper function to check whether data has been generated in a specific file format
PiperOrigin-RevId: 664771687
1 parent cb24f9a commit c815f93

File tree

3 files changed

+53
-21
lines changed

3 files changed

+53
-21
lines changed

tensorflow_datasets/core/dataset_builder.py

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -836,32 +836,28 @@ def as_data_source(
836836
" choose another data_dir or delete the data."
837837
)
838838

839-
if info.file_format is None and not info.alternative_file_formats:
839+
available_formats = info.available_file_formats()
840+
if not available_formats:
840841
raise ValueError(
841842
"Dataset info file format is not set! For random access, one of the"
842843
f" following formats is required: {random_access_formats_msg}"
843844
)
844845

845-
if (
846-
info.file_format is None
847-
or info.file_format not in random_access_formats
848-
):
849-
available_formats = set(info.alternative_file_formats)
850-
suitable_formats = available_formats.intersection(random_access_formats)
851-
if suitable_formats:
852-
chosen_format = suitable_formats.pop()
853-
logging.info(
854-
"Found random access formats: %s. Chose to use %s. Overriding file"
855-
" format in the dataset info.",
856-
", ".join([f.name for f in suitable_formats]),
857-
chosen_format,
858-
)
859-
# Change the dataset info to read from a random access format.
860-
info.set_file_format(
861-
chosen_format, override=True, override_if_initialized=True
862-
)
863-
else:
864-
raise NotImplementedError(unsupported_format_msg)
846+
suitable_formats = available_formats.intersection(random_access_formats)
847+
if suitable_formats:
848+
chosen_format = suitable_formats.pop()
849+
logging.info(
850+
"Found random access formats: %s. Chose to use %s. Overriding file"
851+
" format in the dataset info.",
852+
", ".join([f.name for f in suitable_formats]),
853+
chosen_format,
854+
)
855+
# Change the dataset info to read from a random access format.
856+
info.set_file_format(
857+
chosen_format, override=True, override_if_initialized=True
858+
)
859+
else:
860+
raise NotImplementedError(unsupported_format_msg)
865861

866862
# Create a dataset for each of the given splits
867863
def build_single_data_source(split: str) -> Sequence[Any]:

tensorflow_datasets/core/dataset_info.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,13 @@ def add_alternative_file_format(
532532
self._alternative_file_formats.append(file_format)
533533
self.as_proto.alternative_file_formats.append(file_format.value)
534534

535+
def available_file_formats(self) -> set[file_adapters.FileFormat]:
536+
formats = set()
537+
if self.file_format:
538+
formats.add(self.file_format)
539+
formats.update(self.alternative_file_formats)
540+
return formats
541+
535542
@property
536543
def splits(self) -> splits_lib.SplitDict:
537544
return self._splits
@@ -1265,6 +1272,19 @@ def update_info_proto_with_features(
12651272
return completed_info_proto
12661273

12671274

1275+
def supports_file_format(
1276+
dataset_info_proto: dataset_info_pb2.DatasetInfo,
1277+
file_format: str | file_adapters.FileFormat,
1278+
) -> bool:
1279+
"""Returns whether the given file format is supported by the dataset."""
1280+
if isinstance(file_format, file_adapters.FileFormat):
1281+
file_format = file_format.value
1282+
return (
1283+
file_format == dataset_info_proto.file_format
1284+
or file_format in dataset_info_proto.alternative_file_formats
1285+
)
1286+
1287+
12681288
class MetadataDict(Metadata, dict):
12691289
"""A `tfds.core.Metadata` object that acts as a `dict`.
12701290

tensorflow_datasets/core/dataset_info_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -682,6 +682,22 @@ def test_create_redistribution_info_proto_unsupported_fields():
682682
)
683683

684684

685+
def test_supports_file_format():
686+
dataset_info_proto = dataset_info_pb2.DatasetInfo(
687+
file_format=file_adapters.FileFormat.TFRECORD.value,
688+
alternative_file_formats=[file_adapters.FileFormat.RIEGELI.value],
689+
)
690+
assert dataset_info.supports_file_format(
691+
dataset_info_proto, file_format=file_adapters.FileFormat.TFRECORD
692+
)
693+
assert dataset_info.supports_file_format(
694+
dataset_info_proto, file_format=file_adapters.FileFormat.RIEGELI
695+
)
696+
assert not dataset_info.supports_file_format(
697+
dataset_info_proto, file_format=file_adapters.FileFormat.PARQUET
698+
)
699+
700+
685701
# pylint: disable=g-inconsistent-quotes
686702
_INFO_STR = '''tfds.core.DatasetInfo(
687703
name='mnist',

0 commit comments

Comments
 (0)