Skip to content

Commit 62c9456

Browse files
tomvdwThe TensorFlow Datasets Authors
authored andcommitted
Add method to get the split info from a dataset info proto.
PiperOrigin-RevId: 664803703
1 parent fa4eda5 commit 62c9456

File tree

2 files changed

+125
-4
lines changed

2 files changed

+125
-4
lines changed

tensorflow_datasets/core/dataset_info.py

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1272,17 +1272,62 @@ def update_info_proto_with_features(
12721272
return completed_info_proto
12731273

12741274

1275+
def available_file_formats(
1276+
dataset_info_proto: dataset_info_pb2.DatasetInfo,
1277+
) -> set[str]:
1278+
"""Returns the available file formats for the given dataset."""
1279+
return set(
1280+
[dataset_info_proto.file_format]
1281+
+ list(dataset_info_proto.alternative_file_formats)
1282+
)
1283+
1284+
12751285
def supports_file_format(
12761286
dataset_info_proto: dataset_info_pb2.DatasetInfo,
12771287
file_format: str | file_adapters.FileFormat,
12781288
) -> bool:
12791289
"""Returns whether the given file format is supported by the dataset."""
12801290
if isinstance(file_format, file_adapters.FileFormat):
12811291
file_format = file_format.value
1282-
return (
1283-
file_format == dataset_info_proto.file_format
1284-
or file_format in dataset_info_proto.alternative_file_formats
1285-
)
1292+
return file_format in available_file_formats(dataset_info_proto)
1293+
1294+
1295+
def get_split_info_from_proto(
1296+
dataset_info_proto: dataset_info_pb2.DatasetInfo,
1297+
split_name: str,
1298+
data_dir: epath.PathLike,
1299+
file_format: file_adapters.FileFormat,
1300+
) -> splits_lib.SplitInfo | None:
1301+
"""Returns split info from the given dataset info proto.
1302+
1303+
Args:
1304+
dataset_info_proto: the proto with the dataset info.
1305+
split_name: the split for which to retrieve info for.
1306+
data_dir: the directory where the data is stored.
1307+
file_format: the file format for which to get the split info.
1308+
"""
1309+
if not supports_file_format(dataset_info_proto, file_format):
1310+
available_format = available_file_formats(dataset_info_proto)
1311+
raise ValueError(
1312+
f"File format {file_format.value} does not match available dataset file"
1313+
f" formats: {sorted(available_format)}."
1314+
)
1315+
for split_info in dataset_info_proto.splits:
1316+
if split_info.name == split_name:
1317+
filename_template = naming.ShardedFileTemplate(
1318+
dataset_name=dataset_info_proto.name,
1319+
data_dir=epath.Path(data_dir),
1320+
filetype_suffix=file_format.file_suffix,
1321+
)
1322+
# Override the default file name template if it was set.
1323+
if split_info.filepath_template:
1324+
filename_template = filename_template.replace(
1325+
template=split_info.filepath_template
1326+
)
1327+
return splits_lib.SplitInfo.from_proto(
1328+
proto=split_info, filename_template=filename_template
1329+
)
1330+
return None
12861331

12871332

12881333
class MetadataDict(Metadata, dict):

tensorflow_datasets/core/dataset_info_test.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import json
1919
import os
2020
import pathlib
21+
import re
2122
import tempfile
2223
import time
2324
from typing import Union
@@ -698,6 +699,81 @@ def test_supports_file_format():
698699
)
699700

700701

702+
class GetSplitInfoFromProtoTest(testing.TestCase):
703+
704+
def _dataset_info_proto_with_splits(self):
705+
return dataset_info_pb2.DatasetInfo(
706+
name="dataset",
707+
file_format="tfrecord",
708+
alternative_file_formats=["riegeli"],
709+
splits=[
710+
dataset_info_pb2.SplitInfo(
711+
name="train",
712+
shard_lengths=[1, 2, 3],
713+
num_bytes=42,
714+
),
715+
dataset_info_pb2.SplitInfo(
716+
name="test",
717+
shard_lengths=[1, 2, 3],
718+
num_bytes=42,
719+
filepath_template="{SPLIT}.{FILEFORMAT}-{SHARD_X_OF_Y}",
720+
),
721+
],
722+
)
723+
724+
def test_get_split_info_from_proto_undefined_filename_template(self):
725+
actual = dataset_info.get_split_info_from_proto(
726+
dataset_info_proto=self._dataset_info_proto_with_splits(),
727+
split_name="train",
728+
data_dir="/path/to/data",
729+
file_format=file_adapters.FileFormat.TFRECORD,
730+
)
731+
assert actual.name == "train"
732+
assert actual.shard_lengths == [1, 2, 3]
733+
assert actual.num_bytes == 42
734+
assert actual.filename_template.dataset_name == "dataset"
735+
assert actual.filename_template.template == naming.DEFAULT_FILENAME_TEMPLATE
736+
737+
def test_get_split_info_from_proto_defined_filename_template(self):
738+
actual = dataset_info.get_split_info_from_proto(
739+
dataset_info_proto=self._dataset_info_proto_with_splits(),
740+
split_name="test",
741+
data_dir="/path/to/data",
742+
file_format=file_adapters.FileFormat.TFRECORD,
743+
)
744+
assert actual.name == "test"
745+
assert actual.shard_lengths == [1, 2, 3]
746+
assert actual.filename_template.dataset_name == "dataset"
747+
assert (
748+
actual.filename_template.template
749+
== "{SPLIT}.{FILEFORMAT}-{SHARD_X_OF_Y}"
750+
)
751+
752+
def test_get_split_info_from_proto_non_existing_split(self):
753+
actual = dataset_info.get_split_info_from_proto(
754+
dataset_info_proto=self._dataset_info_proto_with_splits(),
755+
split_name="undefined",
756+
data_dir="/path/to/data",
757+
file_format=file_adapters.FileFormat.TFRECORD,
758+
)
759+
assert actual is None
760+
761+
def test_get_split_info_from_proto_unavailable_format(self):
762+
with pytest.raises(
763+
ValueError,
764+
match=re.escape(
765+
"File format parquet does not match available dataset file formats:"
766+
" ['riegeli', 'tfrecord']."
767+
),
768+
):
769+
dataset_info.get_split_info_from_proto(
770+
dataset_info_proto=self._dataset_info_proto_with_splits(),
771+
split_name="undefined",
772+
data_dir="/path/to/data",
773+
file_format=file_adapters.FileFormat.PARQUET,
774+
)
775+
776+
701777
# pylint: disable=g-inconsistent-quotes
702778
_INFO_STR = '''tfds.core.DatasetInfo(
703779
name='mnist',

0 commit comments

Comments
 (0)