
Commit 9cd4051

tomvdw authored and The TensorFlow Datasets Authors committed
Do not fail the convert format pipeline when some dataset variants cannot be loaded

Also use the dataset info proto instead of the class to avoid problems with datasets that have legacy features.

PiperOrigin-RevId: 688512744
1 parent 354315c commit 9cd4051
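
The convert-format pipeline files themselves are not among the hunks shown below, so the following is only a hedged sketch of the behavior the commit title describes; `convert_variant`, `convert_all_variants`, and `variant_dirs` are hypothetical names, not the actual pipeline code:

```python
# Hedged sketch only: the convert-format pipeline files are not shown in this
# commit view, so convert_variant and variant_dirs are hypothetical names
# illustrating the behavior the commit title describes.
import logging

from etils import epath

from tensorflow_datasets.core import dataset_info


def convert_variant(info_proto, variant_dir: epath.Path) -> None:
  """Placeholder for the actual per-variant conversion step."""
  logging.info("Converting %s (%s)", info_proto.name, variant_dir)


def convert_all_variants(variant_dirs: list[epath.Path]) -> None:
  for variant_dir in variant_dirs:
    try:
      # Load the dataset info proto directly from disk instead of building
      # the DatasetBuilder class, which can fail for legacy features.
      info_proto = dataset_info.read_from_json(
          variant_dir / "dataset_info.json"
      )
    except Exception:  # pylint: disable=broad-except
      # Log and move on: one unloadable variant should not fail the pipeline.
      logging.exception("Skipping variant that cannot be loaded: %s", variant_dir)
      continue
    convert_variant(info_proto, variant_dir)
```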

File tree

5 files changed: +168 −68 lines changed


tensorflow_datasets/core/dataset_info.py

Lines changed: 85 additions & 29 deletions
```diff
@@ -637,7 +637,7 @@ def initialized(self) -> bool:
 
   @property
   def as_json(self) -> str:
-    return json_format.MessageToJson(self.as_proto, sort_keys=True)
+    return get_dataset_info_json(self.as_proto)
 
   def write_to_directory(
       self, dataset_info_dir: epath.PathLike, all_metadata=True
@@ -671,7 +671,7 @@ def write_to_directory(
 
   def write_dataset_info_json(self, dataset_info_dir: epath.PathLike) -> None:
     """Writes only the dataset_info.json file to the given directory."""
-    dataset_info_path(dataset_info_dir).write_text(self.as_json)
+    write_dataset_info_proto(self.as_proto, dataset_info_dir=dataset_info_dir)
 
   def read_from_directory(self, dataset_info_dir: epath.PathLike) -> None:
     """Update DatasetInfo from the metadata files in `dataset_info_dir`.
@@ -852,18 +852,10 @@ def add_tfds_data_source_access(
       dataset_reference:
       url: a URL referring to the TFDS dataset.
     """
-    self._info_proto.data_source_accesses.append(
-        dataset_info_pb2.DataSourceAccess(
-            access_timestamp_ms=_now_in_milliseconds(),
-            tfds_dataset=dataset_info_pb2.TfdsDatasetReference(
-                name=dataset_reference.dataset_name,
-                config=dataset_reference.config,
-                version=str(dataset_reference.version),
-                data_dir=os.fspath(dataset_reference.data_dir),
-                ds_namespace=dataset_reference.namespace,
-            ),
-            url=dataset_info_pb2.Url(url=url),
-        )
+    add_tfds_data_source_access(
+        dataset_info_proto=self._info_proto,
+        dataset_reference=dataset_reference,
+        url=url,
     )
 
   def initialize_from_bucket(self) -> None:
@@ -1130,6 +1122,22 @@ def get_dataset_feature_statistics(builder, split):
   return statistics.datasets[0], schema
 
 
+def get_dataset_info_json(
+    dataset_info_proto: dataset_info_pb2.DatasetInfo,
+) -> str:
+  return json_format.MessageToJson(dataset_info_proto, sort_keys=True)
+
+
+def write_dataset_info_proto(
+    dataset_info_proto: dataset_info_pb2.DatasetInfo,
+    dataset_info_dir: epath.PathLike,
+) -> None:
+  """Writes the dataset info proto to the given path."""
+  dataset_info_dir = epath.Path(dataset_info_dir)
+  json_str = get_dataset_info_json(dataset_info_proto)
+  dataset_info_path(dataset_info_dir).write_text(json_str)
+
+
 def read_from_json(path: epath.PathLike) -> dataset_info_pb2.DatasetInfo:
   """Read JSON-formatted proto into DatasetInfo proto.
@@ -1308,6 +1316,36 @@ def supports_file_format(
   return file_format in available_file_formats(dataset_info_proto)
 
 
+def get_split_dict_from_proto(
+    dataset_info_proto: dataset_info_pb2.DatasetInfo,
+    data_dir: epath.PathLike,
+    file_format: str | file_adapters.FileFormat | None = None,
+) -> splits_lib.SplitDict:
+  """Returns the split dict with all split infos from the given dataset.
+
+  Args:
+    dataset_info_proto: the proto with the dataset info and split infos.
+    data_dir: the directory where the data is stored.
+    file_format: the file format for which to get the split dict. If the file
+      format is not specified, the file format from the dataset info proto is
+      used.
+  """
+  if file_format:
+    file_format = file_adapters.FileFormat(file_format)
+  else:
+    file_format = file_adapters.FileFormat(dataset_info_proto.file_format)
+
+  filename_template = naming.ShardedFileTemplate(
+      dataset_name=dataset_info_proto.name,
+      data_dir=epath.Path(data_dir),
+      filetype_suffix=file_format.file_suffix,
+  )
+  return splits_lib.SplitDict.from_proto(
+      repeated_split_infos=dataset_info_proto.splits,
+      filename_template=filename_template,
+  )
+
+
 def get_split_info_from_proto(
     dataset_info_proto: dataset_info_pb2.DatasetInfo,
     split_name: str,
@@ -1328,22 +1366,40 @@ def get_split_info_from_proto(
         f"File format {file_format.value} does not match available dataset file"
         f" formats: {sorted(available_format)}."
     )
-  for split_info in dataset_info_proto.splits:
-    if split_info.name == split_name:
-      filename_template = naming.ShardedFileTemplate(
-          dataset_name=dataset_info_proto.name,
-          data_dir=epath.Path(data_dir),
-          filetype_suffix=file_format.file_suffix,
-      )
-      # Override the default file name template if it was set.
-      if split_info.filepath_template:
-        filename_template = filename_template.replace(
-            template=split_info.filepath_template
-        )
-      return splits_lib.SplitInfo.from_proto(
-          proto=split_info, filename_template=filename_template
+
+  splits_dict = get_split_dict_from_proto(
+      dataset_info_proto=dataset_info_proto,
+      data_dir=data_dir,
+      file_format=file_format,
+  )
+  return splits_dict.get(split_name)
+
+
+def add_tfds_data_source_access(
+    dataset_info_proto: dataset_info_pb2.DatasetInfo,
+    dataset_reference: naming.DatasetReference,
+    url: str | None = None,
+) -> None:
+  """Records that the given query was used to generate this dataset.
+
+  Args:
+    dataset_info_proto: the proto with the dataset info to update.
+    dataset_reference: the dataset reference to record.
+    url: a URL referring to the TFDS dataset.
+  """
+  dataset_info_proto.data_source_accesses.append(
+      dataset_info_pb2.DataSourceAccess(
+          access_timestamp_ms=_now_in_milliseconds(),
+          tfds_dataset=dataset_info_pb2.TfdsDatasetReference(
+              name=dataset_reference.dataset_name,
+              config=dataset_reference.config,
+              version=str(dataset_reference.version),
+              data_dir=os.fspath(dataset_reference.data_dir),
+              ds_namespace=dataset_reference.namespace,
+          ),
+          url=dataset_info_pb2.Url(url=url),
       )
-  return None
+  )
 
 
 class MetadataDict(Metadata, dict):
```
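
A minimal usage sketch of the new module-level helpers, assuming a `dataset_info.json` already exists on disk; the `data_dir` path and the `TFRECORD` choice below are illustrative, not taken from the commit:

```python
# Illustrative only: read a dataset_info.json, resolve all splits through the
# new proto-based helpers, then write the proto back, with no DatasetInfo
# class instance involved anywhere.
from etils import epath

from tensorflow_datasets.core import dataset_info
from tensorflow_datasets.core import file_adapters

data_dir = epath.Path("/data/my_dataset/1.0.0")  # hypothetical path
info_proto = dataset_info.read_from_json(data_dir / "dataset_info.json")

# Resolve every split at once instead of calling get_split_info_from_proto()
# once per split name.
splits = dataset_info.get_split_dict_from_proto(
    dataset_info_proto=info_proto,
    data_dir=data_dir,
    file_format=file_adapters.FileFormat.TFRECORD,
)
for name, split_info in splits.items():
  print(name, split_info.num_examples)

# Persist the (possibly updated) proto without going through the class.
dataset_info.write_dataset_info_proto(info_proto, dataset_info_dir=data_dir)
```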

tensorflow_datasets/core/dataset_info_test.py

Lines changed: 26 additions & 0 deletions
```diff
@@ -721,6 +721,32 @@ def _dataset_info_proto_with_splits(self):
         ],
     )
 
+  def test_get_split_dict_from_proto(self):
+    actual = dataset_info.get_split_dict_from_proto(
+        dataset_info_proto=self._dataset_info_proto_with_splits(),
+        data_dir="/data",
+        file_format=file_adapters.FileFormat.PARQUET,
+    )
+    assert set(["train", "test"]) == set(actual.keys())
+
+    train = actual["train"]
+    assert train.name == "train"
+    assert train.shard_lengths == [1, 2, 3]
+    assert train.num_bytes == 42
+    assert train.filename_template.dataset_name == "dataset"
+    assert train.filename_template.template == naming.DEFAULT_FILENAME_TEMPLATE
+    assert train.filename_template.filetype_suffix == "parquet"
+
+    test = actual["test"]
+    assert test.name == "test"
+    assert test.shard_lengths == [1, 2, 3]
+    assert test.num_bytes == 42
+    assert test.filename_template.dataset_name == "dataset"
+    assert (
+        test.filename_template.template == "{SPLIT}.{FILEFORMAT}-{SHARD_X_OF_Y}"
+    )
+    assert test.filename_template.filetype_suffix == "parquet"
+
   def test_get_split_info_from_proto_undefined_filename_template(self):
     actual = dataset_info.get_split_info_from_proto(
         dataset_info_proto=self._dataset_info_proto_with_splits(),
```
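
The template strings asserted above follow TFDS's sharded-file naming convention. As a rough illustration of what they expand to, using plain string substitution rather than the actual `naming.ShardedFileTemplate` logic:

```python
# Informal illustration of how the asserted template maps to a shard file
# name; this mimics the substitution by hand and is not the real
# naming.ShardedFileTemplate implementation.
template = "{SPLIT}.{FILEFORMAT}-{SHARD_X_OF_Y}"  # the "test" split's template
filename = (
    template.replace("{SPLIT}", "test")
    .replace("{FILEFORMAT}", "parquet")
    .replace("{SHARD_X_OF_Y}", "00000-of-00003")
)
print(filename)  # -> test.parquet-00000-of-00003
```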

tensorflow_datasets/core/splits.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -336,7 +336,7 @@ def filepaths(self) -> list[epath.Path]:
       result.extend(split_info.filepaths)
     return result
 
-  def replace(self, **kwargs: Any) -> 'MultiSplitInfo':
+  def replace(self, **kwargs: Any) -> MultiSplitInfo:
     raise RuntimeError('replace is not supported on MultiSplitInfo')
 
 
@@ -471,7 +471,7 @@ def from_proto(
       cls,
       repeated_split_infos: Iterable[proto_lib.SplitInfo],
       filename_template: naming.ShardedFileTemplate,
-  ) -> 'SplitDict':
+  ) -> SplitDict:
     """Returns a new SplitDict initialized from the `repeated_split_infos`."""
     split_infos = [
         SplitInfo.from_proto(
@@ -491,7 +491,7 @@ def total_num_examples(self):
     return sum(s.num_examples for s in self.values())
 
   @classmethod
-  def merge_multiple(cls, split_dicts: list['SplitDict']) -> 'SplitDict':
+  def merge_multiple(cls, split_dicts: list[SplitDict]) -> SplitDict:
     info_per_split = []
     for split in set(itertools.chain(*split_dicts)):
       infos_of_split = []
```
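
The splits.py hunks drop the quotes around self-referencing type annotations. Unquoted forward references like these only work when annotations are evaluated lazily, which presumably means the module has (or gains) a `from __future__ import annotations` import. A generic illustration, with a hypothetical class name:

```python
# Generic illustration (PEP 563): with postponed evaluation of annotations,
# a class can name itself in an annotation without quoting the name.
from __future__ import annotations


class SplitDictExample(dict):
  """Hypothetical stand-in for SplitDict; not the TFDS class."""

  @classmethod
  def merge_multiple(
      cls, split_dicts: list[SplitDictExample]
  ) -> SplitDictExample:
    merged = SplitDictExample()
    for d in split_dicts:
      merged.update(d)  # last writer wins; illustration only
    return merged
```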
