Skip to content

Commit 58ae072

Browse files
tomvdwThe TensorFlow Datasets Authors
authored andcommitted
Add support for storing a dataset in multiple formats
Add alternative file formats to dataset info and make `tfds.data_source` pick a supported file format when the default one doesn't support random access. PiperOrigin-RevId: 645307409
1 parent f2f0868 commit 58ae072

File tree

6 files changed

+172
-38
lines changed

6 files changed

+172
-38
lines changed

tensorflow_datasets/core/dataset_builder.py

Lines changed: 47 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -783,33 +783,56 @@ def as_data_source(
783783
if split is None:
784784
split = {s: s for s in self.info.splits}
785785

786-
# Create a dataset for each of the given splits
787-
def build_single_data_source(
788-
split: str,
789-
) -> Sequence[Any]:
790-
file_format = self.info.file_format
791-
if file_format == file_adapters.FileFormat.ARRAY_RECORD:
792-
return array_record.ArrayRecordDataSource(
793-
self.info,
794-
split=split,
795-
decoders=decoders,
786+
info = self.info
787+
788+
random_access_formats = file_adapters.FileFormat.with_random_access()
789+
random_access_formats_msg = " or ".join(
790+
[f.value for f in random_access_formats]
791+
)
792+
unsupported_format_msg = (
793+
f"Random access data source for file format {info.file_format} is"
794+
" not supported. Can you try to run download_and_prepare with"
795+
f" file_format set to one of: {random_access_formats_msg}?"
796+
)
797+
798+
if info.file_format is None and not info.alternative_file_formats:
799+
raise ValueError(
800+
"Dataset info file format is not set! For random access, one of the"
801+
f" following formats is required: {random_access_formats_msg}"
802+
)
803+
804+
if (
805+
info.file_format is None
806+
or info.file_format not in random_access_formats
807+
):
808+
available_formats = set(info.alternative_file_formats)
809+
suitable_formats = available_formats.intersection(random_access_formats)
810+
if suitable_formats:
811+
chosen_format = suitable_formats.pop()
812+
logging.info(
813+
"Found random access formats: %s. Chose to use %s. Overriding file"
814+
" format in the dataset info.",
815+
", ".join([f.name for f in suitable_formats]),
816+
chosen_format,
796817
)
797-
elif file_format == file_adapters.FileFormat.PARQUET:
798-
return parquet.ParquetDataSource(
799-
self.info,
800-
split=split,
801-
decoders=decoders,
818+
# Change the dataset info to read from a random access format.
819+
info.set_file_format(
820+
chosen_format, override=True, override_if_initialized=True
802821
)
803822
else:
804-
args = [
805-
f"`file_format='{file_format.value}'`"
806-
for file_format in file_adapters.FileFormat.with_random_access()
807-
]
808-
raise NotImplementedError(
809-
f"Random access data source for file format {file_format} is not"
810-
" supported. Can you try to run download_and_prepare with"
811-
f" {' or '.join(args)}?"
812-
)
823+
raise NotImplementedError(unsupported_format_msg)
824+
825+
# Create a dataset for each of the given splits
826+
def build_single_data_source(split: str) -> Sequence[Any]:
827+
match info.file_format:
828+
case file_adapters.FileFormat.ARRAY_RECORD:
829+
return array_record.ArrayRecordDataSource(
830+
info, split=split, decoders=decoders
831+
)
832+
case file_adapters.FileFormat.PARQUET:
833+
return parquet.ParquetDataSource(info, split=split, decoders=decoders)
834+
case _:
835+
raise NotImplementedError(unsupported_format_msg)
813836

814837
all_ds = tree.map_structure(build_single_data_source, split)
815838
return all_ds

tensorflow_datasets/core/dataset_builder_test.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,32 @@ def test_load_as_data_source(self):
578578
assert len(data_source) == 10
579579
assert data_source[0]["x"] == 28
580580

581+
def test_load_as_data_source_alternative_file_format(self):
582+
data_dir = self.get_temp_dir()
583+
builder = DummyDatasetWithConfigs(
584+
data_dir=data_dir,
585+
config="plus1",
586+
file_format=file_adapters.FileFormat.ARRAY_RECORD,
587+
)
588+
builder.download_and_prepare()
589+
# Change the default file format and add alternative file format.
590+
builder.info.as_proto.file_format = "tfrecord"
591+
builder.info.add_alternative_file_format("array_record")
592+
593+
data_source = builder.as_data_source()
594+
assert isinstance(data_source, dict)
595+
assert isinstance(data_source["train"], array_record.ArrayRecordDataSource)
596+
assert isinstance(data_source["test"], array_record.ArrayRecordDataSource)
597+
assert len(data_source["test"]) == 10
598+
assert data_source["test"][0]["x"] == 28
599+
assert len(data_source["train"]) == 20
600+
assert data_source["train"][0]["x"] == 7
601+
602+
data_source = builder.as_data_source(split="test")
603+
assert isinstance(data_source, array_record.ArrayRecordDataSource)
604+
assert len(data_source) == 10
605+
assert data_source[0]["x"] == 28
606+
581607
@parameterized.named_parameters(
582608
*[
583609
{"file_format": file_format, "testcase_name": file_format.value}

tensorflow_datasets/core/dataset_info.py

Lines changed: 56 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from __future__ import annotations
3434

3535
import abc
36-
from collections.abc import Iterable
36+
from collections.abc import Iterable, Sequence
3737
import dataclasses
3838
import json
3939
import os
@@ -194,6 +194,9 @@ def __init__(
194194
license: str | None = None, # pylint: disable=redefined-builtin
195195
redistribution_info: Optional[dict[str, str]] = None,
196196
split_dict: Optional[splits_lib.SplitDict] = None,
197+
alternative_file_formats: (
198+
Sequence[str | file_adapters.FileFormat] | None
199+
) = None,
197200
# LINT.ThenChange(:setstate)
198201
):
199202
# pyformat: disable
@@ -238,6 +241,8 @@ def __init__(
238241
subfield will automatically be written to a LICENSE file stored with the
239242
dataset.
240243
split_dict: information about the splits in this dataset.
244+
alternative_file_formats: alternative file formats that are availablefor
245+
this dataset.
241246
"""
242247
# pyformat: enable
243248
self._builder_or_identity = builder
@@ -246,6 +251,13 @@ def __init__(
246251
else:
247252
self._identity = DatasetIdentity.from_builder(builder)
248253

254+
self._alternative_file_formats: list[file_adapters.FileFormat] = []
255+
if alternative_file_formats:
256+
for file_format in alternative_file_formats:
257+
if isinstance(file_format, str):
258+
file_format = file_adapters.FileFormat.from_value(file_format)
259+
self.add_alternative_file_format(file_format)
260+
249261
self._info_proto = dataset_info_pb2.DatasetInfo(
250262
name=self._identity.name,
251263
description=utils.dedent(description),
@@ -260,6 +272,9 @@ def __init__(
260272
redistribution_info=_create_redistribution_info_proto(
261273
license=license, redistribution_info=redistribution_info
262274
),
275+
alternative_file_formats=[
276+
f.value for f in self._alternative_file_formats
277+
],
263278
)
264279

265280
if homepage:
@@ -328,6 +343,7 @@ def from_proto(
328343
repeated_split_infos=proto.splits,
329344
filename_template=filename_template,
330345
),
346+
alternative_file_formats=proto.alternative_file_formats,
331347
)
332348

333349
@property
@@ -415,6 +431,10 @@ def download_size(self, size):
415431
def features(self):
416432
return self._features
417433

434+
@property
435+
def alternative_file_formats(self) -> Sequence[file_adapters.FileFormat]:
436+
return self._alternative_file_formats
437+
418438
@property
419439
def metadata(self) -> Metadata | None:
420440
return self._metadata
@@ -444,6 +464,7 @@ def set_file_format(
444464
self,
445465
file_format: None | str | file_adapters.FileFormat,
446466
override: bool = False,
467+
override_if_initialized: bool = False,
447468
) -> None:
448469
"""Internal function to define the file format.
449470
@@ -454,6 +475,8 @@ def set_file_format(
454475
file_format: The file format.
455476
override: Whether the file format should be overridden if it is already
456477
set.
478+
override_if_initialized: Whether the file format should be overridden if
479+
the DatasetInfo is already fully initialized.
457480
458481
Raises:
459482
ValueError: if the file format was already set and the `override`
@@ -474,12 +497,39 @@ def set_file_format(
474497
raise ValueError(
475498
f"File format is already set to {self.file_format}. Got {file_format}"
476499
)
477-
if override and self._fully_initialized:
500+
if override and self._fully_initialized and not override_if_initialized:
478501
raise RuntimeError(
479-
"Cannot override the file format "
480-
"when the DatasetInfo is already fully initialized!"
502+
"Cannot override the file format when the DatasetInfo is already"
503+
" fully initialized!"
481504
)
482505
self._info_proto.file_format = file_format.value
506+
if override_if_initialized:
507+
# Update the splits to point to the new file format.
508+
updated_split_infos = []
509+
for split_info in self.splits.values():
510+
if split_info.filename_template is None:
511+
continue
512+
updated_split_info = split_info.replace(
513+
filename_template=split_info.filename_template.replace(
514+
filetype_suffix=file_format.value
515+
)
516+
)
517+
updated_split_infos.append(updated_split_info)
518+
self._splits = splits_lib.SplitDict(updated_split_infos)
519+
520+
def add_alternative_file_format(
521+
self,
522+
file_format: str | file_adapters.FileFormat,
523+
) -> None:
524+
"""Adds an alternative file format to the dataset info."""
525+
if isinstance(file_format, str):
526+
file_format = file_adapters.FileFormat.from_value(file_format)
527+
if file_format in self.alternative_file_formats:
528+
raise ValueError(
529+
f"Alternative file format {file_format} is already present."
530+
)
531+
self._alternative_file_formats.append(file_format)
532+
self.as_proto.alternative_file_formats.append(file_format.value)
483533

484534
@property
485535
def splits(self) -> splits_lib.SplitDict:
@@ -882,6 +932,7 @@ def __getstate__(self):
882932
"metadata": self.metadata,
883933
"license": self.redistribution_info.license,
884934
"split_dict": self.splits,
935+
"alternative_file_formats": self.alternative_file_formats,
885936
}
886937
def __setstate__(self, state):
887938
# LINT.IfChange(setstate)
@@ -896,6 +947,7 @@ def __setstate__(self, state):
896947
metadata=state["metadata"],
897948
license=state["license"],
898949
split_dict=state["split_dict"],
950+
alternative_file_formats=state["alternative_file_formats"],
899951
)
900952
# LINT.ThenChange(:dataset_info_args)
901953

tensorflow_datasets/core/dataset_info_test.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,32 @@ def test_set_file_format_override(self):
422422
info.set_file_format(file_adapters.FileFormat.RIEGELI, override=True)
423423
self.assertEqual(info.file_format, file_adapters.FileFormat.RIEGELI)
424424

425+
def test_set_file_format_override_failes_when_fully_initialized(self):
426+
info = dataset_info.DatasetInfo(builder=self._builder)
427+
info.set_file_format(file_adapters.FileFormat.TFRECORD)
428+
info._fully_initialized = True
429+
self.assertEqual(info.file_format, file_adapters.FileFormat.TFRECORD)
430+
with pytest.raises(
431+
ValueError,
432+
match=(
433+
"File format is already set to FileFormat.TFRECORD. Got"
434+
" FileFormat.RIEGELI"
435+
),
436+
):
437+
info.set_file_format(file_adapters.FileFormat.RIEGELI)
438+
439+
def test_set_file_format_override_fully_initialized(self):
440+
info = dataset_info.DatasetInfo(builder=self._builder)
441+
info.set_file_format(file_adapters.FileFormat.TFRECORD)
442+
info._fully_initialized = True
443+
self.assertEqual(info.file_format, file_adapters.FileFormat.TFRECORD)
444+
info.set_file_format(
445+
file_adapters.FileFormat.RIEGELI,
446+
override=True,
447+
override_if_initialized=True,
448+
)
449+
self.assertEqual(info.file_format, file_adapters.FileFormat.RIEGELI)
450+
425451
def test_update_info_proto_with_features(self):
426452
info_proto = dataset_info.DatasetInfo(builder=self._builder).as_proto
427453
new_features = features.FeaturesDict({"text": features.Text()})

tensorflow_datasets/core/proto/dataset_info.proto

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,10 +222,16 @@ message DatasetInfo {
222222
// Specifies whether examples should be shuffled.
223223
bool disable_shuffling = 16;
224224

225-
// File format used.
226-
// Use string to allow format extension without regenerating the proto.
225+
// Default file format to use. Note that alternative file formats may be
226+
// available too and that depending on how the dataset is loaded, the default
227+
// file format may be ignored.
227228
string file_format = 17;
228229

230+
// Alternative file formats available for this dataset. Note that the number
231+
// of shards and the number of examples per shard must be the same for all
232+
// file formats.
233+
repeated string alternative_file_formats = 22;
234+
229235
// The data that was used to generate this dataset.
230236
repeated DataSourceAccess data_source_accesses = 20;
231237

tensorflow_datasets/core/proto/dataset_info_generated_pb2.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@
6767
b' \x01(\t\x12\x10\n\x08\x64\x61ta_dir\x18\x04'
6868
b' \x01(\t\x12\x14\n\x0c\x64s_namespace\x18\x05'
6969
b' \x01(\t\x12\r\n\x05split\x18\x06'
70-
b' \x01(\t"\xb4\x07\n\x0b\x44\x61tasetInfo\x12\x0c\n\x04name\x18\x01'
70+
b' \x01(\t"\xd6\x07\n\x0b\x44\x61tasetInfo\x12\x0c\n\x04name\x18\x01'
7171
b' \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x02'
7272
b' \x01(\t\x12\x0f\n\x07version\x18\t \x01(\t\x12I\n\rrelease_notes\x18\x12'
7373
b' \x03(\x0b\x32\x32.tensorflow_datasets.DatasetInfo.ReleaseNotesEntry\x12\x13\n\x0b\x63onfig_name\x18\r'
@@ -85,8 +85,9 @@
8585
b' \x01(\x0b\x32#.tensorflow_datasets.SupervisedKeys\x12\x44\n\x13redistribution_info\x18\x0b'
8686
b" \x01(\x0b\x32'.tensorflow_datasets.RedistributionInfo\x12\x13\n\x0bmodule_name\x18\x0f"
8787
b' \x01(\t\x12\x19\n\x11\x64isable_shuffling\x18\x10'
88-
b' \x01(\x08\x12\x13\n\x0b\x66ile_format\x18\x11'
89-
b' \x01(\t\x12\x43\n\x14\x64\x61ta_source_accesses\x18\x14'
88+
b' \x01(\x08\x12\x13\n\x0b\x66ile_format\x18\x11 \x01(\t\x12'
89+
b' \n\x18\x61lternative_file_formats\x18\x16'
90+
b' \x03(\t\x12\x43\n\x14\x64\x61ta_source_accesses\x18\x14'
9091
b' \x03(\x0b\x32%.tensorflow_datasets.DataSourceAccess\x1a\x33\n\x11ReleaseNotesEntry\x12\x0b\n\x03key\x18\x01'
9192
b' \x01(\t\x12\r\n\x05value\x18\x02'
9293
b' \x01(\t:\x02\x38\x01\x1a\x38\n\x16\x44ownloadChecksumsEntry\x12\x0b\n\x03key\x18\x01'
@@ -145,9 +146,9 @@
145146
_TFDSDATASETREFERENCE._serialized_start = 1280
146147
_TFDSDATASETREFERENCE._serialized_end = 1404
147148
_DATASETINFO._serialized_start = 1407
148-
_DATASETINFO._serialized_end = 2355
149-
_DATASETINFO_RELEASENOTESENTRY._serialized_start = 2246
150-
_DATASETINFO_RELEASENOTESENTRY._serialized_end = 2297
151-
_DATASETINFO_DOWNLOADCHECKSUMSENTRY._serialized_start = 2299
152-
_DATASETINFO_DOWNLOADCHECKSUMSENTRY._serialized_end = 2355
149+
_DATASETINFO._serialized_end = 2389
150+
_DATASETINFO_RELEASENOTESENTRY._serialized_start = 2280
151+
_DATASETINFO_RELEASENOTESENTRY._serialized_end = 2331
152+
_DATASETINFO_DOWNLOADCHECKSUMSENTRY._serialized_start = 2333
153+
_DATASETINFO_DOWNLOADCHECKSUMSENTRY._serialized_end = 2389
153154
# @@protoc_insertion_point(module_scope)

0 commit comments

Comments
 (0)